RickyMartin-dev committed on
Commit ac5a8f8
1 Parent(s): 680f307

first push

Files changed (3):
  1. app.py +19 -0
  2. requiremnets.txt +76 -0
  3. text_to_image.py +49 -0
app.py ADDED
@@ -0,0 +1,19 @@
+ # Imports
+ from text_to_image import TextToImageTool
+ import gradio as gr
+
+ # Define Text to Image Tool
+ tool = TextToImageTool()
+
+ # Helper Function, necessary for Gradio
+ def fn(*args, **kwargs):
+     return tool(*args, **kwargs)
+
+ # Gradio Interface
+ gr.Interface(
+     fn=fn,
+     inputs=tool.inputs,
+     outputs=tool.outputs,
+     title="TextToImageTool",
+     article=tool.description,
+ ).queue(concurrency_count=5).launch()
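
Once the app is running (for example via "python app.py" locally, or on a Space), the Gradio endpoint can also be queried programmatically. The snippet below is a minimal sketch and not part of the commit: the local URL and the prompt are assumptions, and gradio_client is already pinned in requiremnets.txt.

# Sketch only: call the running Gradio app from Python.
# The URL and prompt are examples; "/predict" is the default endpoint
# name that gr.Interface registers.
from gradio_client import Client

client = Client("http://127.0.0.1:7860")
result = client.predict(
    "an astronaut riding a horse on the moon",  # example prompt
    api_name="/predict",
)
print(result)  # path to the image file produced by the app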
requiremnets.txt ADDED
@@ -0,0 +1,76 @@
+ accelerate==0.20.3
+ aiofiles==23.1.0
+ aiohttp==3.8.4
+ aiosignal==1.3.1
+ altair==5.0.1
+ anyio==3.7.0
+ async-timeout==4.0.2
+ attrs==23.1.0
+ certifi==2023.5.7
+ charset-normalizer==3.1.0
+ click==8.1.3
+ contourpy==1.1.0
+ cycler==0.11.0
+ diffusers @ git+https://github.com/huggingface/diffusers@666743302ff5bd1e02c204b81a80e566648d60de
+ fastapi==0.97.0
+ ffmpy==0.3.0
+ filelock==3.12.2
+ fonttools==4.40.0
+ frozenlist==1.3.3
+ fsspec==2023.6.0
+ gradio==3.35.2
+ gradio_client==0.2.7
+ h11==0.14.0
+ httpcore==0.17.2
+ httpx==0.24.1
+ huggingface-hub==0.15.1
+ idna==3.4
+ importlib-metadata==6.7.0
+ Jinja2==3.1.2
+ jsonschema==4.17.3
+ kiwisolver==1.4.4
+ linkify-it-py==2.0.2
+ markdown-it-py==2.2.0
+ MarkupSafe==2.1.3
+ matplotlib==3.7.1
+ mdit-py-plugins==0.3.3
+ mdurl==0.1.2
+ mpmath==1.3.0
+ multidict==6.0.4
+ networkx==3.1
+ numpy==1.25.0
+ orjson==3.9.1
+ packaging==23.1
+ pandas==2.0.2
+ Pillow==9.5.0
+ psutil==5.9.5
+ pydantic==1.10.9
+ pydub==0.25.1
+ Pygments==2.15.1
+ pyparsing==3.1.0
+ pyrsistent==0.19.3
+ python-dateutil==2.8.2
+ python-multipart==0.0.6
+ pytz==2023.3
+ PyYAML==6.0
+ regex==2023.6.3
+ requests==2.31.0
+ safetensors==0.3.1
+ semantic-version==2.10.0
+ six==1.16.0
+ sniffio==1.3.0
+ starlette==0.27.0
+ sympy==1.12
+ tokenizers==0.13.3
+ toolz==0.12.0
+ torch==2.0.1
+ tqdm==4.65.0
+ transformers==4.30.2
+ typing_extensions==4.6.3
+ tzdata==2023.3
+ uc-micro-py==1.0.2
+ urllib3==2.0.3
+ uvicorn==0.22.0
+ websockets==11.0.3
+ yarl==1.9.2
+ zipp==3.15.0
text_to_image.py ADDED
@@ -0,0 +1,49 @@
+ from transformers.tools.base import Tool, get_default_device
+ from transformers.utils import is_accelerate_available
+ import torch
+
+ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+
+ # Description of the text-to-image tool
+ TEXT_TO_IMAGE_DESCRIPTION = (
+     "This is a tool that creates an image according to a prompt"
+ )
+
+ # Defining a stable diffusion tool
+ class TextToImageTool(Tool):
+     default_checkpoint = "runwayml/stable-diffusion-v1-5"
+     description = TEXT_TO_IMAGE_DESCRIPTION
+     inputs = ['text']
+     outputs = ['image']
+
+     def __init__(self, device=None, **hub_kwargs) -> None:
+         if not is_accelerate_available():
+             raise ImportError("Accelerate should be installed in order to use tools.")
+
+         super().__init__()
+
+         self.device = device
+         self.pipeline = None
+         self.hub_kwargs = hub_kwargs
+
+     def setup(self):
+         if self.device is None:
+             self.device = get_default_device()
+
+         self.pipeline = DiffusionPipeline.from_pretrained(self.default_checkpoint)
+         self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(self.pipeline.scheduler.config)
+         self.pipeline.to(self.device)
+
+         if self.device.type == "cuda":
+             self.pipeline.to(torch_dtype=torch.float16)
+
+         self.is_initialized = True
+
+     def __call__(self, prompt):
+         if not self.is_initialized:
+             self.setup()
+
+         negative_prompt = "low quality, bad quality, deformed, low resolution, janky"
+         added_prompt = " , highest quality, highly realistic, very high resolution"
+
+         return self.pipeline(prompt + added_prompt, negative_prompt=negative_prompt, num_inference_steps=25).images[0]
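
For reference, here is a minimal sketch of using the tool outside Gradio; it is not part of the commit, and the prompt and output filename are assumptions. The tool's __call__ returns the first image from the pipeline output, which is a PIL image, so it can be saved directly.

# Sketch only: run TextToImageTool without the Gradio front end.
# Assumes the dependencies from requiremnets.txt are installed and the
# Stable Diffusion checkpoint can be fetched from the Hugging Face Hub.
from text_to_image import TextToImageTool

tool = TextToImageTool()   # setup() runs lazily on the first call
image = tool("a watercolor painting of a lighthouse at dusk")  # example prompt
image.save("output.png")   # the pipeline returns a PIL image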