DJStomp committed on
Commit 2683ee0 · verified · 1 Parent(s): f5756b3

ZeroGPU fix

Files changed (1)
1. app.py (+25 -19)
app.py CHANGED
@@ -7,11 +7,13 @@ from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipe
 from diffusers.models.controlnet_flux import FluxControlNetModel
 import numpy as np
 from huggingface_hub import login, snapshot_download
+import spaces
+
 
 # Configuration
-base_model = 'black-forest-labs/FLUX.1-dev'
-controlnet_model = 'promeai/FLUX.1-controlnet-lineart-promeai'
-css = """
+BASE_MODEL = 'black-forest-labs/FLUX.1-dev'
+CONTROLNET_MODEL = 'promeai/FLUX.1-controlnet-lineart-promeai'
+CSS = """
 #col-container {
     margin: 0 auto;
     max-width: 640px;
@@ -19,28 +21,28 @@ css = """
 """
 
 # Setup
-auth_token = os.getenv("HF_AUTH_TOKEN")
-if not auth_token:
+AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
+if AUTH_TOKEN:
+    login(AUTH_TOKEN)
+else:
     raise ValueError("Hugging Face auth token not found. Please set HF_AUTH_TOKEN in the environment.")
 
-login(auth_token)
-
-model_dir = snapshot_download(
-    repo_id=base_model,
+MODEL_DIR = snapshot_download(
+    repo_id=BASE_MODEL,
     revision="main",
-    use_auth_token=auth_token
+    use_auth_token=AUTH_TOKEN
 )
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
-print(f"Using device: {device} (torch_dtype={torch_dtype})")
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+TORCH_DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
 
-controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch_dtype)
-pipe = FluxControlNetPipeline.from_pretrained(model_dir, controlnet=controlnet, torch_dtype=torch_dtype)
-pipe = pipe.to(device)
+CONTROLNET = FluxControlNetModel.from_pretrained(CONTROLNET_MODEL, torch_dtype=TORCH_DTYPE)
+PIPE = FluxControlNetPipeline.from_pretrained(MODEL_DIR, controlnet=CONTROLNET, torch_dtype=TORCH_DTYPE)
+PIPE = PIPE.to(DEVICE)
 
 MAX_SEED = np.iinfo(np.int32).max
 
+@spaces.GPU
 def infer(
     prompt,
     control_image_path,
@@ -50,6 +52,9 @@ def infer(
     seed,
     randomize_seed,
 ):
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+    print(f"Inference: using device: {device} (torch_dtype={torch_dtype})")
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
 
@@ -57,7 +62,7 @@ def infer(
     control_image = load_image(control_image_path) if control_image_path else None
 
     # Generate image
-    result = pipe(
+    result = PIPE(
         prompt=prompt,
         control_image=control_image,
         controlnet_conditioning_scale=controlnet_conditioning_scale,
@@ -68,7 +73,7 @@ def infer(
 
     return result, seed
 
-with gr.Blocks(css=css) as demo:
+with gr.Blocks(css=CSS) as demo:
     with gr.Column(elem_id="col-container"):
         gr.Markdown("Flux.1[dev] LineArt")
         gr.Markdown("### Zero-shot Partial Style Transfer for Line Art Images, Powered by FLUX.1")
@@ -118,7 +123,8 @@ with gr.Blocks(css=css) as demo:
         gr.Examples(
             examples=[
                 "Shiba Inu wearing dinosaur costume riding skateboard",
-                "Victorian style mansion interior with candlelight"
+                "Victorian style mansion interior with candlelight",
+                "Loading screen for Grand Theft Otter: Clam Andreas"
             ],
             inputs=[prompt]
         )
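
For context: on ZeroGPU Spaces a GPU is attached only while a function decorated with @spaces.GPU is running, which is why this commit adds `import spaces`, decorates `infer`, and re-checks the device inside the call. Below is a minimal sketch of that pattern, assuming the `spaces` package that Hugging Face provides on ZeroGPU hardware; the tiny model and the `run` function are illustrative placeholders, not this Space's code.

import spaces
import torch

# Module-level setup: build the model once, outside any GPU allocation.
model = torch.nn.Linear(8, 8)

@spaces.GPU  # a GPU is allocated for the duration of each call to run()
def run(values):
    # Pick the device inside the decorated function rather than at import
    # time; on ZeroGPU this is where CUDA is expected to be usable.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.tensor(values, dtype=torch.float32, device=device)
    with torch.no_grad():
        return model.to(device)(x).cpu().tolist()

Called as run([0.0] * 8), this returns a plain Python list and falls back to CPU when no GPU is attached, so the same code also runs locally.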