blanchon commited on
Commit
67237f1
·
1 Parent(s): 3cc9869

Enable batch size

Browse files
Files changed (2) hide show
  1. app.py +8 -7
  2. pyproject.toml +1 -1
app.py CHANGED
@@ -20,8 +20,8 @@ SYSTEM_PROMPT = r"""This two-panel split-frame image showcases a furniture in as
20
 
21
  if not torch.cuda.is_available():
22
 
23
- def _dummy_pipe(image: Image.Image, *args, **kwargs): # noqa: ARG001
24
- return {"images": [image]}
25
 
26
  pipe = _dummy_pipe
27
  else:
@@ -59,7 +59,7 @@ def calculate_optimal_dimensions(image: Image.Image) -> tuple[int, int]:
59
  def infer(
60
  furniture_image: Image.Image,
61
  room_image: EditorValue,
62
- prompt: str,
63
  seed: int = 42,
64
  randomize_seed: bool = False,
65
  guidance_scale: float = 3.5,
@@ -132,15 +132,16 @@ def infer(
132
  if randomize_seed:
133
  seed = secrets.randbelow(MAX_SEED)
134
 
 
 
135
  results_images = pipe(
136
- prompt=prompt + ".\n" + SYSTEM_PROMPT,
137
- image=image,
138
- mask_image=mask,
139
  height=height,
140
  width=width,
141
  guidance_scale=guidance_scale,
142
  num_inference_steps=num_inference_steps,
143
- batch_size=4,
144
  generator=torch.Generator("cpu").manual_seed(seed),
145
  )["images"]
146
 
 
20
 
21
  if not torch.cuda.is_available():
22
 
23
+ def _dummy_pipe(image: list[Image.Image], *args, **kwargs): # noqa: ARG001
24
+ return {"images": image}
25
 
26
  pipe = _dummy_pipe
27
  else:
 
59
  def infer(
60
  furniture_image: Image.Image,
61
  room_image: EditorValue,
62
+ prompt: str = "",
63
  seed: int = 42,
64
  randomize_seed: bool = False,
65
  guidance_scale: float = 3.5,
 
132
  if randomize_seed:
133
  seed = secrets.randbelow(MAX_SEED)
134
 
135
+ prompt = prompt + ".\n" + SYSTEM_PROMPT if prompt else SYSTEM_PROMPT
136
+ batch_size = 4
137
  results_images = pipe(
138
+ prompt=[prompt] * batch_size,
139
+ image=[image] * batch_size,
140
+ mask_image=[mask] * batch_size,
141
  height=height,
142
  width=width,
143
  guidance_scale=guidance_scale,
144
  num_inference_steps=num_inference_steps,
 
145
  generator=torch.Generator("cpu").manual_seed(seed),
146
  )["images"]
147
 
pyproject.toml CHANGED
@@ -18,4 +18,4 @@ dependencies = [
18
  ]
19
 
20
  [tool.uv.sources]
21
- diffusers = { git = "https://github.com/huggingface/diffusers.git" }
 
18
  ]
19
 
20
  [tool.uv.sources]
21
+ diffusers = { git = "https://github.com/huggingface/diffusers.git" }