Roman Bachmann committed
Commit 57876e1
1 parent: f742b9c

Initial commit

README.md CHANGED
@@ -1,8 +1,8 @@
  ---
- title: 4M
- emoji: 👀
- colorFrom: blue
- colorTo: pink
+ title: 4M Demo
+ emoji:
+ colorFrom: purple
+ colorTo: red
  sdk: gradio
  sdk_version: 4.36.1
  app_file: app.py
@@ -10,4 +10,4 @@ pinned: false
  license: apache-2.0
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,153 @@
+ import os
+ try:
+     # Try to install detectron2 from source. Needed for semseg plotting functionality.
+     os.system("python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'")
+ except Exception as e:
+     print('detectron2 cannot be installed. Falling back to simple semseg visualization.')
+     print(e)
+
+ import torch
+ # We recommend running this demo on an A100 GPU
+ if torch.cuda.is_available():
+     device = "cuda"
+     gpu_type = torch.cuda.get_device_name(torch.cuda.current_device())
+     power_device = f"{gpu_type} GPU"
+     torch.cuda.max_memory_allocated(device=device)
+ else:
+     device = "cpu"
+     power_device = "CPU"
+     os.system("pip uninstall -y xformers")  # Only use xformers on GPU
+
+ import spaces
+ import gradio as gr
+ import random
+ import numpy as np
+ from torchvision.transforms.functional import center_crop
+ from fourm.demo_4M_sampler import Demo4MSampler
+ from fourm.data.modality_transforms import RGBTransform
+
+
+ # The flag below controls whether to allow TF32 on matmul. This flag defaults to False in PyTorch 1.12 and later.
+ torch.backends.cuda.matmul.allow_tf32 = True
+ # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
+ torch.backends.cudnn.allow_tf32 = True
+
+ MAX_SEED = np.iinfo(np.int32).max
+
+ FM_MODEL_ID = 'EPFL-VILAB/4M-21_B'
+ MODEL_NAME = FM_MODEL_ID.split('/')[1].replace('_', ' ')
+
+ # Human poses visualization is disabled, since it needs SMPL weights. To enable human pose prediction and rendering:
+ # 1) Install via `pip install timm yacs smplx pyrender pyopengl==3.1.4`
+ #    You may need to follow the pyrender install instructions: https://pyrender.readthedocs.io/en/latest/install/index.html
+ # 2) Download SMPL data from https://smpl.is.tue.mpg.de/. See https://github.com/shubham-goel/4D-Humans/ for an example
+ # 3) Copy the required SMPL files (smpl_mean_params.npz, SMPL_to_J19.pkl, smpl/SMPL_NEUTRAL.pkl) to fourm/utils/hmr2_utils/data .
+
+ sampler = Demo4MSampler(
+     fm=FM_MODEL_ID,
+     fm_sr=None,
+     tok_human_poses=None,
+     tok_text='./text_tokenizer_4m_wordpiece_30k.json',
+ ).to(device)
+
+
+ def img_from_path(img_path: str):
+     rgb_transform = RGBTransform(imagenet_default_mean_and_std=True)
+     img_pil = rgb_transform.load(img_path)
+     img_pil = rgb_transform.preprocess(img_pil)
+     img_pil = center_crop(img_pil, (min(img_pil.size), min(img_pil.size))).resize((224, 224))
+     img = rgb_transform.postprocess(img_pil).unsqueeze(0)
+     return img
+
+ @spaces.GPU
+ def infer(img_path, seed=0, randomize_seed=False, target_modalities=None, top_p=0.8, top_k=0.0):
+     if randomize_seed:
+         seed = None
+     img = img_from_path(img_path).to(device)
+     preds = sampler({'rgb@224': img}, seed=seed, target_modalities=target_modalities, top_p=top_p, top_k=top_k)
+     sampler.plot_modalities(preds, ncols_max=4, use_fixed_plotting_order=True, save_path='./output.png')
+     return './output.png'
+
+
+ examples = [
+     'examples/example_0.png', 'examples/example_1.png', 'examples/example_2.png',
+     'examples/example_3.png', 'examples/example_4.png', 'examples/example_5.png',
+ ]
+
+ css = """
+ #col-container {
+     margin: 0 auto;
+     max-width: 1500px;
+ }
+ #col-input-container {
+     margin: 0 auto;
+     max-width: 400px;
+ }
+ #run-button {
+     margin: 0 auto;
+ }
+ """
+
+ with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
+
+     with gr.Column(elem_id="col-container"):
+         gr.Markdown(f"""
+         # 4M: Massively Multimodal Masked Modeling
+         """)
+
+         with gr.Row():
+             with gr.Column(elem_id="col-input-container"):
+                 gr.Markdown(f"""
+                 *A framework for training any-to-any multimodal foundation models. Scalable. Open-sourced. Across tens of modalities and tasks.*
+
+                 [`Website`](https://4m.epfl.ch) | [`GitHub`](https://github.com/apple/ml-4m) <br>[`4M Paper (NeurIPS'23)`](https://arxiv.org/abs/2312.06647) | [`4M-21 Paper (arXiv'24)`](https://arxiv.org/abs/2406.09406)
+
+                 This demo predicts all modalities from a given RGB input, using [{FM_MODEL_ID}](https://huggingface.co/{FM_MODEL_ID}), running on *{power_device}*.
+                 For more generative examples, and to enable human pose visualizations, please see our [GitHub repo](https://github.com/apple/ml-4m).
+
+                 (Disclaimer: The demo is a work in progress. We will switch it to using 4M-21 XL when running on GPU. Until then, this space runs on CPU and takes several minutes for inference.)
+                 """)
+
+                 img_path = gr.Image(label='RGB input image', type='filepath')
+                 run_button = gr.Button(f"Predict with {MODEL_NAME}", scale=0, elem_id="run-button")
+
+                 with gr.Accordion("Advanced Settings", open=False):
+                     target_modalities = gr.CheckboxGroup(
+                         choices=[
+                             ('CLIP-B/16', 'tok_clip@224'), ('DINOv2-B/14', 'tok_dinov2@224'), ('ImageBind-H/14', 'tok_imagebind@224'),
+                             ('Depth', 'tok_depth@224'), ('Surface normals', 'tok_normal@224'), ('Semantic segmentation', 'tok_semseg@224'),
+                             ('Canny edges', 'tok_canny_edge@224'), ('SAM edges', 'tok_sam_edge@224'), ('Caption', 'caption'),
+                             ('Bounding boxes', 'det'), ('SAM instances', 'sam_instance'), ('Color palette', 'color_palette'),
+                             ('Metadata', 'metadata'),
+                         ],
+                         value=[
+                             'tok_clip@224', 'tok_dinov2@224', 'tok_imagebind@224',
+                             'tok_depth@224', 'tok_normal@224', 'tok_semseg@224',
+                             'tok_canny_edge@224', 'tok_sam_edge@224', 'caption',
+                             'det', 'sam_instance', 'color_palette', 'metadata'
+                         ],
+                         label="Target modalities",
+                         info='Choose which modalities are predicted (in this order).'
+                     )
+                     seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+                     randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
+                     top_p = gr.Slider(label="Top-p", minimum=0.0, maximum=1.0, step=0.01, value=0.8)
+                     top_k = gr.Slider(label="Top-k", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
+
+             result = gr.Image(label="Predictions", show_label=False)
+
+         gr.Examples(
+             examples = examples,
+             fn = infer,
+             inputs = [img_path],
+             outputs = [result],
+             cache_examples='lazy',
+         )
+
+     run_button.click(
+         fn = infer,
+         inputs = [img_path, seed, randomize_seed, target_modalities, top_p, top_k],
+         outputs = [result]
+     )
+
+ demo.queue(max_size=10).launch()
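
For readers who want to use the model outside the Gradio UI, the snippet below is a minimal sketch that reuses the calls from app.py above (same model ID, tokenizer path, preprocessing, and sampling arguments). The example image path points to one of the files added below, and the chosen target modalities are just an illustrative subset.

```python
# Minimal sketch: drive the 4M sampler directly, mirroring app.py above.
import torch
from torchvision.transforms.functional import center_crop
from fourm.demo_4M_sampler import Demo4MSampler
from fourm.data.modality_transforms import RGBTransform

device = "cuda" if torch.cuda.is_available() else "cpu"

# Same constructor arguments as in app.py; human pose outputs stay disabled.
sampler = Demo4MSampler(
    fm='EPFL-VILAB/4M-21_B',
    fm_sr=None,
    tok_human_poses=None,
    tok_text='./text_tokenizer_4m_wordpiece_30k.json',
).to(device)

# Load an RGB image and center-crop it to 224x224, as img_from_path() does.
rgb_transform = RGBTransform(imagenet_default_mean_and_std=True)
img_pil = rgb_transform.preprocess(rgb_transform.load('examples/example_0.png'))
img_pil = center_crop(img_pil, (min(img_pil.size), min(img_pil.size))).resize((224, 224))
img = rgb_transform.postprocess(img_pil).unsqueeze(0).to(device)

# Predict a subset of modalities and save the plotted results, as infer() does.
preds = sampler({'rgb@224': img}, seed=0,
                target_modalities=['tok_depth@224', 'tok_semseg@224', 'caption'],
                top_p=0.8, top_k=0.0)
sampler.plot_modalities(preds, ncols_max=4, use_fixed_plotting_order=True,
                        save_path='./output.png')
```

As in infer(), passing seed=None draws a random seed, and the order of target_modalities determines the order in which modalities are predicted.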
examples/example_0.png ADDED
examples/example_1.png ADDED
examples/example_2.png ADDED
examples/example_3.png ADDED
examples/example_4.png ADDED
examples/example_5.png ADDED
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ fourm @ git+https://github.com/apple/ml-4m@4573d6e
+ xformers>=0.0.24
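
requirements.txt pins fourm to a specific commit of apple/ml-4m and adds xformers, which app.py uninstalls again when no GPU is available. The following is a small, purely illustrative check (not part of the Space) that both packages resolved in the current environment:

```python
# Illustrative sanity check for the two pinned requirements above.
import importlib.util

for pkg in ("fourm", "xformers"):
    found = importlib.util.find_spec(pkg) is not None
    # xformers may legitimately be absent on CPU-only machines, since app.py removes it there.
    print(f"{pkg}: {'installed' if found else 'not installed'}")
```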
text_tokenizer_4m_wordpiece_30k.json ADDED
The diff for this file is too large to render.