kumahiyo committed
Commit a79e202
1 Parent(s): d4119c7

add multi lang

Files changed (1): main.py +29 -3
main.py CHANGED
@@ -10,6 +10,11 @@ from fastapi.staticfiles import StaticFiles
 from pydantic import BaseModel
 from pydantic import Field
 from diffusers import StableDiffusionPipeline
+from transformers import (
+    pipeline,
+    MBart50TokenizerFast,
+    MBartForConditionalGeneration,
+)
 
 app = FastAPI()
 
@@ -28,7 +33,9 @@ def index():
 @app.post("/draw", response_model=ItemOut)
 def draw(data: Data):
     if data.member_secret != "" and data.member_secret == os.environ.get("MEMBER_SECRET"):
-        print(f"Is CUDA available: {torch.cuda.is_available()}")
+
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        device_dict = {"cuda": 0, "cpu": -1}
 
         seedno = 0
         if '_seed' in data.string:
@@ -50,13 +57,32 @@ def draw(data: Data):
         # https://huggingface.co/docs/hub/spaces-sdks-docker-first-demo
         # how to validation: https://qiita.com/bee2/items/75d9c0d7ba20e7a4a0e9
         # https://github.com/huggingface/diffusers
+        # https://github.com/huggingface/diffusers/pull/1142
+
+        # Add language detection pipeline
+        language_detection_model_ckpt = "papluca/xlm-roberta-base-language-detection"
+        language_detection_pipeline = pipeline("text-classification",
+                                               model=language_detection_model_ckpt,
+                                               device=device_dict[device])
+
+        # Add model for language translation
+        trans_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")
+        trans_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-one-mmt").to(device)
 
         model_id = 'stabilityai/stable-diffusion-2'
 
         #pipe = StableDiffusionPipeline.from_pretrained(model_id)
-        pipe = StableDiffusionPipeline.from_pretrained(model_id, revision='fp16', torch_dtype=torch.float16)
+        pipe = StableDiffusionPipeline.from_pretrained(
+            model_id,
+            custom_pipeline="multilingual_stable_diffusion",
+            detection_pipeline=language_detection_pipeline,
+            translation_model=trans_model,
+            translation_tokenizer=trans_tokenizer,
+            revision='fp16',
+            torch_dtype=torch.float16
+        )
         pipe.enable_attention_slicing() # reduce gpu usage
-        pipe = pipe.to('cuda')
+        pipe = pipe.to(device)
 
         generator = torch.Generator("cuda").manual_seed(seed)
         images = pipe(prompt, negative_prompt=n_prompt, guidance_scale=7.5, generator=generator, num_images_per_prompt=1).images
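For context: custom_pipeline="multilingual_stable_diffusion" loads the community pipeline added to diffusers in the PR referenced above (huggingface/diffusers#1142), which chains the three models this commit instantiates: the XLM-RoBERTa classifier detects the prompt's language, non-English prompts are translated to English with mBART-50, and the translated prompt goes to Stable Diffusion. Below is a minimal standalone sketch of that flow, not the committed code; it assumes a CUDA device for the fp16 weights, and the MBART_LANG mapping from the detector's ISO labels to mBART-50 language codes is hypothetical and abbreviated.

    import torch
    from diffusers import StableDiffusionPipeline
    from transformers import (
        pipeline,
        MBart50TokenizerFast,
        MBartForConditionalGeneration,
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Language detector: returns labels such as "en" or "ja" with a score.
    detector = pipeline(
        "text-classification",
        model="papluca/xlm-roberta-base-language-detection",
        device=0 if device == "cuda" else -1,
    )

    # Many-to-one translator: any supported source language -> English.
    trans_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")
    trans_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-one-mmt").to(device)

    # Hypothetical mapping from the detector's ISO-639-1 labels to the
    # language codes mBART-50 expects; only a few languages shown.
    MBART_LANG = {"ja": "ja_JP", "fr": "fr_XX", "de": "de_DE", "zh": "zh_CN"}

    def to_english(prompt: str) -> str:
        lang = detector(prompt)[0]["label"]
        if lang == "en":
            return prompt
        trans_tokenizer.src_lang = MBART_LANG[lang]
        encoded = trans_tokenizer(prompt, return_tensors="pt").to(device)
        generated = trans_model.generate(**encoded)
        return trans_tokenizer.batch_decode(generated, skip_special_tokens=True)[0]

    # fp16 weights assume a CUDA device, as in the commit.
    pipe = StableDiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2",
        revision="fp16",
        torch_dtype=torch.float16,
    ).to(device)
    pipe.enable_attention_slicing()

    image = pipe(to_english("馬に乗っている宇宙飛行士")).images[0]  # "an astronaut riding a horse"

With the custom pipeline loaded as in the commit, none of this glue code lives in main.py; the detection and translation are expected to run inside the pipeline before the usual denoising loop.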
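And a hypothetical client call against the /draw endpoint, assuming the Space is reachable locally and that Data has only the two fields visible in this diff (string and member_secret); the response shape depends on ItemOut, which this commit does not touch.

    import requests

    # Hypothetical local URL; Spaces typically serve on port 7860.
    resp = requests.post(
        "http://localhost:7860/draw",
        json={"string": "綺麗な夕焼けの絵", "member_secret": "..."},  # Japanese prompt: "a picture of a beautiful sunset"; secret elided
    )
    print(resp.status_code, resp.json())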