multimodalart HF staff commited on
Commit
ad0cea8
·
verified ·
1 Parent(s): 52b0321

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -65
app.py CHANGED
@@ -8,6 +8,7 @@ from pyramid_dit import PyramidDiTForVideoGeneration
8
  from diffusers.utils import export_to_video
9
 
10
  import spaces
 
11
 
12
  import subprocess
13
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
@@ -62,49 +63,62 @@ model = load_model()
62
 
63
  # Text-to-video generation function
64
  @spaces.GPU(duration=240)
65
- def generate_video(prompt, duration, guidance_scale, video_guidance_scale):
66
  temp = int(duration * 2.4) # Convert seconds to temp value (assuming 24 FPS)
67
  torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32
68
-
69
- with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
70
- frames = model.generate(
71
- prompt=prompt,
72
- num_inference_steps=[20, 20, 20],
73
- video_num_inference_steps=[10, 10, 10],
74
- height=768,
75
- width=1280,
76
- temp=temp,
77
- guidance_scale=guidance_scale,
78
- video_guidance_scale=video_guidance_scale,
79
- output_type="pil",
80
- save_memory=True,
81
- )
82
-
83
- output_path = "output_video.mp4"
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  export_to_video(frames, output_path, fps=24)
85
  return output_path
86
 
87
  # Image-to-video generation function
88
- @spaces.GPU(duration=240)
89
- def generate_video_from_image(image, prompt, duration, video_guidance_scale):
90
- temp = int(duration * 2.4) # Convert seconds to temp value (assuming 24 FPS)
91
- torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32
92
-
93
- target_size = (1280, 720)
94
- cropped_image = center_crop(image, 1280, 720)
95
- resized_image = cropped_image.resize((1280, 720))
96
-
97
- with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
98
- frames = model.generate_i2v(
99
- prompt=prompt,
100
- input_image=resized_image,
101
- num_inference_steps=[10, 10, 10],
102
- temp=temp,
103
- guidance_scale=7.0,
104
- video_guidance_scale=video_guidance_scale,
105
- output_type="pil",
106
- save_memory=True,
107
- )
108
 
109
  output_path = "output_video_i2v.mp4"
110
  export_to_video(frames, output_path, fps=24)
@@ -114,38 +128,41 @@ def generate_video_from_image(image, prompt, duration, video_guidance_scale):
114
  with gr.Blocks() as demo:
115
  gr.Markdown("# Pyramid Flow Video Generation Demo")
116
 
117
- with gr.Tab("Text-to-Video"):
118
- with gr.Row():
119
- with gr.Column():
120
- t2v_prompt = gr.Textbox(label="Prompt")
 
 
 
121
  t2v_duration = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Duration (seconds)")
122
  t2v_guidance_scale = gr.Slider(minimum=1, maximum=15, value=9, step=0.1, label="Guidance Scale")
123
  t2v_video_guidance_scale = gr.Slider(minimum=1, maximum=15, value=5, step=0.1, label="Video Guidance Scale")
124
- t2v_generate_btn = gr.Button("Generate Video")
125
- with gr.Column():
126
- t2v_output = gr.Video(label="Generated Video")
127
-
128
- t2v_generate_btn.click(
129
- generate_video,
130
- inputs=[t2v_prompt, t2v_duration, t2v_guidance_scale, t2v_video_guidance_scale],
131
- outputs=t2v_output
132
- )
133
 
134
- with gr.Tab("Image-to-Video"):
135
- with gr.Row():
136
- with gr.Column():
137
- i2v_image = gr.Image(type="pil", label="Input Image")
138
- i2v_prompt = gr.Textbox(label="Prompt")
139
- i2v_duration = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Duration (seconds)")
140
- i2v_video_guidance_scale = gr.Slider(minimum=1, maximum=15, value=4, step=0.1, label="Video Guidance Scale")
141
- i2v_generate_btn = gr.Button("Generate Video")
142
- with gr.Column():
143
- i2v_output = gr.Video(label="Generated Video")
 
 
 
 
 
 
144
 
145
- i2v_generate_btn.click(
146
- generate_video_from_image,
147
- inputs=[i2v_image, i2v_prompt, i2v_duration, i2v_video_guidance_scale],
148
- outputs=i2v_output
149
- )
150
 
151
  demo.launch()
 
8
  from diffusers.utils import export_to_video
9
 
10
  import spaces
11
+ import uuid
12
 
13
  import subprocess
14
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
63
 
64
  # Text-to-video generation function
65
  @spaces.GPU(duration=240)
66
+ def generate_video(image, prompt, duration, guidance_scale, video_guidance_scale):
67
  temp = int(duration * 2.4) # Convert seconds to temp value (assuming 24 FPS)
68
  torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32
69
+ if(image):
70
+ cropped_image = center_crop(image, 1280, 720)
71
+ resized_image = cropped_image.resize((1280, 720))
72
+ with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
73
+ frames = model.generate_i2v(
74
+ prompt=prompt,
75
+ input_image=resized_image,
76
+ num_inference_steps=[10, 10, 10],
77
+ temp=temp,
78
+ guidance_scale=7.0,
79
+ video_guidance_scale=video_guidance_scale,
80
+ output_type="pil",
81
+ save_memory=True,
82
+ )
83
+ else:
84
+ with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
85
+ frames = model.generate(
86
+ prompt=prompt,
87
+ num_inference_steps=[20, 20, 20],
88
+ video_num_inference_steps=[10, 10, 10],
89
+ height=768,
90
+ width=1280,
91
+ temp=temp,
92
+ guidance_scale=guidance_scale,
93
+ video_guidance_scale=video_guidance_scale,
94
+ output_type="pil",
95
+ save_memory=True,
96
+ )
97
+ output_path = f"{str(uuid.uuid4())}_output_video.mp4"
98
  export_to_video(frames, output_path, fps=24)
99
  return output_path
100
 
101
  # Image-to-video generation function
102
+ #@spaces.GPU(duration=240)
103
+ #def generate_video_from_image(image, prompt, duration, video_guidance_scale):
104
+ # temp = int(duration * 2.4) # Convert seconds to temp value (assuming 24 FPS)
105
+ # torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32
106
+ #
107
+ # target_size = (1280, 720)
108
+ # cropped_image = center_crop(image, 1280, 720)
109
+ # resized_image = cropped_image.resize((1280, 720))
110
+ #
111
+ # with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
112
+ # frames = model.generate_i2v(
113
+ # prompt=prompt,
114
+ # input_image=resized_image,
115
+ # num_inference_steps=[10, 10, 10],
116
+ # temp=temp,
117
+ # guidance_scale=7.0,
118
+ # video_guidance_scale=video_guidance_scale,
119
+ # output_type="pil",
120
+ # save_memory=True,
121
+ # )
122
 
123
  output_path = "output_video_i2v.mp4"
124
  export_to_video(frames, output_path, fps=24)
 
128
  with gr.Blocks() as demo:
129
  gr.Markdown("# Pyramid Flow Video Generation Demo")
130
 
131
+ #with gr.Tab("Text-to-Video"):
132
+ with gr.Row():
133
+ with gr.Column():
134
+ with gr.Accordion("Image to Video (optional)", open=False):
135
+ i2v_image = gr.Image(type="pil", label="Input Image")
136
+ t2v_prompt = gr.Textbox(label="Prompt")
137
+ with gr.Accordion("Advanced settings", open=False):
138
  t2v_duration = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Duration (seconds)")
139
  t2v_guidance_scale = gr.Slider(minimum=1, maximum=15, value=9, step=0.1, label="Guidance Scale")
140
  t2v_video_guidance_scale = gr.Slider(minimum=1, maximum=15, value=5, step=0.1, label="Video Guidance Scale")
141
+ t2v_generate_btn = gr.Button("Generate Video")
142
+ with gr.Column():
143
+ t2v_output = gr.Video(label="Generated Video")
 
 
 
 
 
 
144
 
145
+ t2v_generate_btn.click(
146
+ generate_video,
147
+ inputs=[i2v_image, t2v_prompt, t2v_duration, t2v_guidance_scale, t2v_video_guidance_scale],
148
+ outputs=t2v_output
149
+ )
150
+
151
+ #with gr.Tab("Image-to-Video"):
152
+ # with gr.Row():
153
+ # with gr.Column():
154
+
155
+ # i2v_prompt = gr.Textbox(label="Prompt")
156
+ # i2v_duration = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Duration (seconds)")
157
+ # i2v_video_guidance_scale = gr.Slider(minimum=1, maximum=15, value=4, step=0.1, label="Video Guidance Scale")
158
+ # i2v_generate_btn = gr.Button("Generate Video")
159
+ # with gr.Column():
160
+ # i2v_output = gr.Video(label="Generated Video")
161
 
162
+ #i2v_generate_btn.click(
163
+ # generate_video_from_image,
164
+ # inputs=[i2v_image, i2v_prompt, i2v_duration, i2v_video_guidance_scale],
165
+ # outputs=i2v_output
166
+ #)
167
 
168
  demo.launch()