kumahiyo committed
Commit: 79da350
1 parent: d012c6a

add feature_extractor

Files changed (2)
  1. Dockerfile +1 -2
  2. main.py +5 -3
Dockerfile CHANGED
@@ -37,8 +37,7 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \
 numpy \
 scipy \
 tensorboard \
-transformers \
-streamlit
+transformers
 
 # Set the working directory to /code
 WORKDIR /code
main.py CHANGED
@@ -4,17 +4,17 @@ import sys
 import re
 import random
 import torch
-import streamlit as st
 from PIL import Image
 from fastapi import FastAPI
 from fastapi.staticfiles import StaticFiles
 from pydantic import BaseModel
 from pydantic import Field
-from diffusers import DiffusionPipeline
+from diffusers import StableDiffusionPipeline
 from transformers import (
     pipeline,
     MBart50TokenizerFast,
     MBartForConditionalGeneration,
+    CLIPFeatureExtractor
 )
 
 app = FastAPI()
@@ -69,16 +69,18 @@ def draw(data: Data):
 # Add model for language translation
 trans_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")
 trans_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-one-mmt").to(device)
+feature_extractor = CLIPFeatureExtractor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
 
 model_id = 'stabilityai/stable-diffusion-2'
 
 #pipe = StableDiffusionPipeline.from_pretrained(model_id)
-pipe = DiffusionPipeline.from_pretrained(
+pipe = StableDiffusionPipeline.from_pretrained(
     model_id,
     custom_pipeline="multilingual_stable_diffusion",
     detection_pipeline=language_detection_pipeline,
     translation_model=trans_model,
     translation_tokenizer=trans_tokenizer,
+    feature_extractor=feature_extractor,
     revision='fp16',
     torch_dtype=torch.float16
 )
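For reference, a minimal usage sketch of the pipeline built above (not part of this commit): it assumes `pipe` and `device` are defined as in main.py, that the multilingual_stable_diffusion community pipeline keeps the usual StableDiffusionPipeline-style call, and that the new feature_extractor is used to preprocess generated images for the safety checker. The prompt, step count, and output path below are illustrative only.

# Usage sketch (not part of this commit). Assumes `pipe` and `device` from main.py and a
# StableDiffusionPipeline-style __call__ on the multilingual custom pipeline.
pipe = pipe.to(device)
prompt = "夕焼けの海の上を飛ぶ赤い気球"  # Japanese for "a red balloon flying over the sea at sunset"
result = pipe(prompt, num_inference_steps=30, guidance_scale=7.5)  # non-English prompt is detected and translated internally
result.images[0].save("output.png")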