Francke committed on
Commit 24c345c · 1 Parent(s): 5d63776
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. latentsync/data/syncnet_dataset.py +153 -0
  2. latentsync/data/unet_dataset.py +164 -0
  3. latentsync/models/attention.py +492 -0
  4. latentsync/models/motion_module.py +332 -0
  5. latentsync/models/resnet.py +234 -0
  6. latentsync/models/syncnet.py +233 -0
  7. latentsync/models/syncnet_wav2lip.py +90 -0
  8. latentsync/models/unet.py +528 -0
  9. latentsync/models/unet_blocks.py +903 -0
  10. latentsync/models/utils.py +19 -0
  11. latentsync/pipelines/lipsync_pipeline.py +470 -0
  12. latentsync/trepa/__init__.py +64 -0
  13. latentsync/trepa/third_party/VideoMAEv2/__init__.py +0 -0
  14. latentsync/trepa/third_party/VideoMAEv2/utils.py +81 -0
  15. latentsync/trepa/third_party/VideoMAEv2/videomaev2_finetune.py +539 -0
  16. latentsync/trepa/third_party/VideoMAEv2/videomaev2_pretrain.py +469 -0
  17. latentsync/trepa/third_party/__init__.py +0 -0
  18. latentsync/trepa/utils/__init__.py +0 -0
  19. latentsync/trepa/utils/data_utils.py +321 -0
  20. latentsync/trepa/utils/metric_utils.py +161 -0
  21. latentsync/utils/affine_transform.py +138 -0
  22. latentsync/utils/audio.py +194 -0
  23. latentsync/utils/av_reader.py +157 -0
  24. latentsync/utils/image_processor.py +342 -0
  25. latentsync/utils/mask.png +0 -0
  26. latentsync/utils/util.py +365 -0
  27. latentsync/whisper/audio2feature.py +166 -0
  28. latentsync/whisper/whisper/__init__.py +119 -0
  29. latentsync/whisper/whisper/__main__.py +4 -0
  30. latentsync/whisper/whisper/assets/gpt2/merges.txt +0 -0
  31. latentsync/whisper/whisper/assets/gpt2/special_tokens_map.json +1 -0
  32. latentsync/whisper/whisper/assets/gpt2/tokenizer_config.json +1 -0
  33. latentsync/whisper/whisper/assets/gpt2/vocab.json +0 -0
  34. latentsync/whisper/whisper/assets/mel_filters.npz +3 -0
  35. latentsync/whisper/whisper/assets/multilingual/added_tokens.json +1 -0
  36. latentsync/whisper/whisper/assets/multilingual/merges.txt +0 -0
  37. latentsync/whisper/whisper/assets/multilingual/special_tokens_map.json +1 -0
  38. latentsync/whisper/whisper/assets/multilingual/tokenizer_config.json +1 -0
  39. latentsync/whisper/whisper/assets/multilingual/vocab.json +0 -0
  40. latentsync/whisper/whisper/audio.py +125 -0
  41. latentsync/whisper/whisper/decoding.py +729 -0
  42. latentsync/whisper/whisper/model.py +290 -0
  43. latentsync/whisper/whisper/normalizers/__init__.py +2 -0
  44. latentsync/whisper/whisper/normalizers/basic.py +71 -0
  45. latentsync/whisper/whisper/normalizers/english.json +1742 -0
  46. latentsync/whisper/whisper/normalizers/english.py +543 -0
  47. latentsync/whisper/whisper/tokenizer.py +331 -0
  48. latentsync/whisper/whisper/transcribe.py +207 -0
  49. latentsync/whisper/whisper/utils.py +87 -0
  50. preprocess/affine_transform.py +137 -0
latentsync/data/syncnet_dataset.py ADDED
@@ -0,0 +1,153 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import numpy as np
17
+ from torch.utils.data import Dataset
18
+ import torch
19
+ import random
20
+ from ..utils.util import gather_video_paths_recursively
21
+ from ..utils.image_processor import ImageProcessor
22
+ from ..utils.audio import melspectrogram
23
+ import math
24
+
25
+ from decord import AudioReader, VideoReader, cpu
26
+
27
+
28
+ class SyncNetDataset(Dataset):
29
+ def __init__(self, data_dir: str, fileslist: str, config):
30
+ if fileslist != "":
31
+ with open(fileslist) as file:
32
+ self.video_paths = [line.rstrip() for line in file]
33
+ elif data_dir != "":
34
+ self.video_paths = gather_video_paths_recursively(data_dir)
35
+ else:
36
+ raise ValueError("data_dir and fileslist cannot be both empty")
37
+
38
+ self.resolution = config.data.resolution
39
+ self.num_frames = config.data.num_frames
40
+
41
+ self.mel_window_length = math.ceil(self.num_frames / 5 * 16)
42
+
43
+ self.audio_sample_rate = config.data.audio_sample_rate
44
+ self.video_fps = config.data.video_fps
45
+ self.audio_samples_length = int(
46
+ config.data.audio_sample_rate // config.data.video_fps * config.data.num_frames
47
+ )
48
+ self.image_processor = ImageProcessor(resolution=config.data.resolution, mask="half")
49
+ self.audio_mel_cache_dir = config.data.audio_mel_cache_dir
50
+ os.makedirs(self.audio_mel_cache_dir, exist_ok=True)
51
+
52
+ def __len__(self):
53
+ return len(self.video_paths)
54
+
55
+ def read_audio(self, video_path: str):
56
+ ar = AudioReader(video_path, ctx=cpu(self.worker_id), sample_rate=self.audio_sample_rate)
57
+ original_mel = melspectrogram(ar[:].asnumpy().squeeze(0))
58
+ return torch.from_numpy(original_mel)
59
+
60
+ def crop_audio_window(self, original_mel, start_index):
61
+ start_idx = int(80.0 * (start_index / float(self.video_fps)))
62
+ end_idx = start_idx + self.mel_window_length
63
+ return original_mel[:, start_idx:end_idx].unsqueeze(0)
64
+
65
+ def get_frames(self, video_reader: VideoReader):
66
+ total_num_frames = len(video_reader)
67
+
68
+ start_idx = random.randint(0, total_num_frames - self.num_frames)
69
+ frames_index = np.arange(start_idx, start_idx + self.num_frames, dtype=int)
70
+
71
+ while True:
72
+ wrong_start_idx = random.randint(0, total_num_frames - self.num_frames)
73
+ # wrong_start_idx = random.randint(
74
+ # max(0, start_idx - 25), min(total_num_frames - self.num_frames, start_idx + 25)
75
+ # )
76
+ if wrong_start_idx == start_idx:
77
+ continue
78
+ # if wrong_start_idx >= start_idx - self.num_frames and wrong_start_idx <= start_idx + self.num_frames:
79
+ # continue
80
+ wrong_frames_index = np.arange(wrong_start_idx, wrong_start_idx + self.num_frames, dtype=int)
81
+ break
82
+
83
+ frames = video_reader.get_batch(frames_index).asnumpy()
84
+ wrong_frames = video_reader.get_batch(wrong_frames_index).asnumpy()
85
+
86
+ return frames, wrong_frames, start_idx
87
+
88
+ def worker_init_fn(self, worker_id):
89
+ # Initialize the face mesh object in each worker process,
90
+ # because the face mesh object cannot be called in subprocesses
91
+ self.worker_id = worker_id
92
+ # setattr(self, f"image_processor_{worker_id}", ImageProcessor(self.resolution, self.mask))
93
+
94
+ def __getitem__(self, idx):
95
+ # image_processor = getattr(self, f"image_processor_{self.worker_id}")
96
+ while True:
97
+ try:
98
+ idx = random.randint(0, len(self) - 1)
99
+
100
+ # Get video file path
101
+ video_path = self.video_paths[idx]
102
+
103
+ vr = VideoReader(video_path, ctx=cpu(self.worker_id))
104
+
105
+ if len(vr) < 2 * self.num_frames:
106
+ continue
107
+
108
+ frames, wrong_frames, start_idx = self.get_frames(vr)
109
+
110
+ mel_cache_path = os.path.join(
111
+ self.audio_mel_cache_dir, os.path.basename(video_path).replace(".mp4", "_mel.pt")
112
+ )
113
+
114
+ if os.path.isfile(mel_cache_path):
115
+ try:
116
+ original_mel = torch.load(mel_cache_path)
117
+ except Exception as e:
118
+ print(f"{type(e).__name__} - {e} - {mel_cache_path}")
119
+ os.remove(mel_cache_path)
120
+ original_mel = self.read_audio(video_path)
121
+ torch.save(original_mel, mel_cache_path)
122
+ else:
123
+ original_mel = self.read_audio(video_path)
124
+ torch.save(original_mel, mel_cache_path)
125
+
126
+ mel = self.crop_audio_window(original_mel, start_idx)
127
+
128
+ if mel.shape[-1] != self.mel_window_length:
129
+ continue
130
+
131
+ if random.choice([True, False]):
132
+ y = torch.ones(1).float()
133
+ chosen_frames = frames
134
+ else:
135
+ y = torch.zeros(1).float()
136
+ chosen_frames = wrong_frames
137
+
138
+ chosen_frames = self.image_processor.process_images(chosen_frames)
139
+ # chosen_frames, _, _ = image_processor.prepare_masks_and_masked_images(
140
+ # chosen_frames, affine_transform=True
141
+ # )
142
+
143
+ vr.seek(0) # avoid memory leak
144
+ break
145
+
146
+ except Exception as e: # Handle the exception when a face is not detected
147
+ print(f"{type(e).__name__} - {e} - {video_path}")
148
+ if "vr" in locals():
149
+ vr.seek(0) # avoid memory leak
150
+
151
+ sample = dict(frames=chosen_frames, audio_samples=mel, y=y)
152
+
153
+ return sample
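
Editor's note (not part of the commit): a minimal sketch of how SyncNetDataset might be wired into a PyTorch DataLoader. The config fields mirror the attributes read in __init__; the SimpleNamespace config, the paths, and the batch sizes are illustrative assumptions, and worker_init_fn is passed so that self.worker_id exists before read_audio and __getitem__ use it.

from types import SimpleNamespace
from torch.utils.data import DataLoader
from latentsync.data.syncnet_dataset import SyncNetDataset

# Hypothetical config; the real project presumably loads these values from a config file.
config = SimpleNamespace(
    data=SimpleNamespace(
        resolution=256,
        num_frames=16,
        audio_sample_rate=16000,
        video_fps=25,
        audio_mel_cache_dir="cache/mel",
    )
)

dataset = SyncNetDataset(data_dir="data/videos", fileslist="", config=config)
loader = DataLoader(
    dataset,
    batch_size=4,
    num_workers=4,
    worker_init_fn=dataset.worker_init_fn,  # sets self.worker_id, used as the decord cpu() context id
)

batch = next(iter(loader))
frames, mel, y = batch["frames"], batch["audio_samples"], batch["y"]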
latentsync/data/unet_dataset.py ADDED
@@ -0,0 +1,164 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import numpy as np
17
+ from torch.utils.data import Dataset
18
+ import torch
19
+ import random
20
+ import cv2
21
+ from ..utils.image_processor import ImageProcessor, load_fixed_mask
22
+ from ..utils.audio import melspectrogram
23
+ from decord import AudioReader, VideoReader, cpu
24
+
25
+
26
+ class UNetDataset(Dataset):
27
+ def __init__(self, train_data_dir: str, config):
28
+ if config.data.train_fileslist != "":
29
+ with open(config.data.train_fileslist) as file:
30
+ self.video_paths = [line.rstrip() for line in file]
31
+ elif train_data_dir != "":
32
+ self.video_paths = []
33
+ for file in os.listdir(train_data_dir):
34
+ if file.endswith(".mp4"):
35
+ self.video_paths.append(os.path.join(train_data_dir, file))
36
+ else:
37
+ raise ValueError("data_dir and fileslist cannot be both empty")
38
+
39
+ self.resolution = config.data.resolution
40
+ self.num_frames = config.data.num_frames
41
+
42
+ if self.num_frames == 16:
43
+ self.mel_window_length = 52
44
+ elif self.num_frames == 5:
45
+ self.mel_window_length = 16
46
+ else:
47
+ raise NotImplementedError("Only support 16 and 5 frames now")
48
+
49
+ self.audio_sample_rate = config.data.audio_sample_rate
50
+ self.video_fps = config.data.video_fps
51
+ self.mask = config.data.mask
52
+ self.mask_image = load_fixed_mask(self.resolution)
53
+ self.load_audio_data = config.model.add_audio_layer and config.run.use_syncnet
54
+ self.audio_mel_cache_dir = config.data.audio_mel_cache_dir
55
+ os.makedirs(self.audio_mel_cache_dir, exist_ok=True)
56
+
57
+ def __len__(self):
58
+ return len(self.video_paths)
59
+
60
+ def read_audio(self, video_path: str):
61
+ ar = AudioReader(video_path, ctx=cpu(self.worker_id), sample_rate=self.audio_sample_rate)
62
+ original_mel = melspectrogram(ar[:].asnumpy().squeeze(0))
63
+ return torch.from_numpy(original_mel)
64
+
65
+ def crop_audio_window(self, original_mel, start_index):
66
+ start_idx = int(80.0 * (start_index / float(self.video_fps)))
67
+ end_idx = start_idx + self.mel_window_length
68
+ return original_mel[:, start_idx:end_idx].unsqueeze(0)
69
+
70
+ def get_frames(self, video_reader: VideoReader):
71
+ total_num_frames = len(video_reader)
72
+
73
+ start_idx = random.randint(self.num_frames // 2, total_num_frames - self.num_frames - self.num_frames // 2)
74
+ frames_index = np.arange(start_idx, start_idx + self.num_frames, dtype=int)
75
+
76
+ while True:
77
+ wrong_start_idx = random.randint(0, total_num_frames - self.num_frames)
78
+ if wrong_start_idx > start_idx - self.num_frames and wrong_start_idx < start_idx + self.num_frames:
79
+ continue
80
+ wrong_frames_index = np.arange(wrong_start_idx, wrong_start_idx + self.num_frames, dtype=int)
81
+ break
82
+
83
+ frames = video_reader.get_batch(frames_index).asnumpy()
84
+ wrong_frames = video_reader.get_batch(wrong_frames_index).asnumpy()
85
+
86
+ return frames, wrong_frames, start_idx
87
+
88
+ def worker_init_fn(self, worker_id):
89
+ # Initialize the face mesh object in each worker process,
90
+ # because the face mesh object cannot be called in subprocesses
91
+ self.worker_id = worker_id
92
+ setattr(
93
+ self,
94
+ f"image_processor_{worker_id}",
95
+ ImageProcessor(self.resolution, self.mask, mask_image=self.mask_image),
96
+ )
97
+
98
+ def __getitem__(self, idx):
99
+ image_processor = getattr(self, f"image_processor_{self.worker_id}")
100
+ while True:
101
+ try:
102
+ idx = random.randint(0, len(self) - 1)
103
+
104
+ # Get video file path
105
+ video_path = self.video_paths[idx]
106
+
107
+ vr = VideoReader(video_path, ctx=cpu(self.worker_id))
108
+
109
+ if len(vr) < 3 * self.num_frames:
110
+ continue
111
+
112
+ continuous_frames, ref_frames, start_idx = self.get_frames(vr)
113
+
114
+ if self.load_audio_data:
115
+ mel_cache_path = os.path.join(
116
+ self.audio_mel_cache_dir, os.path.basename(video_path).replace(".mp4", "_mel.pt")
117
+ )
118
+
119
+ if os.path.isfile(mel_cache_path):
120
+ try:
121
+ original_mel = torch.load(mel_cache_path)
122
+ except Exception as e:
123
+ print(f"{type(e).__name__} - {e} - {mel_cache_path}")
124
+ os.remove(mel_cache_path)
125
+ original_mel = self.read_audio(video_path)
126
+ torch.save(original_mel, mel_cache_path)
127
+ else:
128
+ original_mel = self.read_audio(video_path)
129
+ torch.save(original_mel, mel_cache_path)
130
+
131
+ mel = self.crop_audio_window(original_mel, start_idx)
132
+
133
+ if mel.shape[-1] != self.mel_window_length:
134
+ continue
135
+ else:
136
+ mel = []
137
+
138
+ gt, masked_gt, mask = image_processor.prepare_masks_and_masked_images(
139
+ continuous_frames, affine_transform=False
140
+ )
141
+
142
+ if self.mask == "fix_mask":
143
+ ref, _, _ = image_processor.prepare_masks_and_masked_images(ref_frames, affine_transform=False)
144
+ else:
145
+ ref = image_processor.process_images(ref_frames)
146
+ vr.seek(0) # avoid memory leak
147
+ break
148
+
149
+ except Exception as e: # Handle the exception of face not detcted
150
+ print(f"{type(e).__name__} - {e} - {video_path}")
151
+ if "vr" in locals():
152
+ vr.seek(0) # avoid memory leak
153
+
154
+ sample = dict(
155
+ gt=gt,
156
+ masked_gt=masked_gt,
157
+ ref=ref,
158
+ mel=mel,
159
+ mask=mask,
160
+ video_path=video_path,
161
+ start_idx=start_idx,
162
+ )
163
+
164
+ return sample
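
Editor's note (not part of the commit): the hard-coded mel window lengths above (52 for 16 frames, 16 for 5 frames) and the ceil(num_frames / 5 * 16) formula in SyncNetDataset both follow from crop_audio_window(), which converts frame indices to mel indices with a factor of 80.0 / video_fps, i.e. 80 mel frames per second of audio. The sketch below only checks that arithmetic; the 25 fps default and the implied hop of 16000 / 80 = 200 samples are inferences, not stated in the diff.

import math

def mel_window_length(num_frames: int, video_fps: float = 25.0, mel_fps: float = 80.0) -> int:
    # Number of mel frames spanned by `num_frames` video frames.
    return math.ceil(num_frames / video_fps * mel_fps)

assert mel_window_length(5) == 16    # matches the 5-frame branch above
assert mel_window_length(16) == 52   # matches the 16-frame branch (ceil of 51.2)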
latentsync/models/attention.py ADDED
@@ -0,0 +1,492 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+
3
+ from dataclasses import dataclass
4
+ from turtle import forward
5
+ from typing import Optional
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.modeling_utils import ModelMixin
13
+ from diffusers.utils import BaseOutput
14
+ from diffusers.utils.import_utils import is_xformers_available
15
+ from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm
16
+
17
+ from einops import rearrange, repeat
18
+ from .utils import zero_module
19
+
20
+
21
+ @dataclass
22
+ class Transformer3DModelOutput(BaseOutput):
23
+ sample: torch.FloatTensor
24
+
25
+
26
+ if is_xformers_available():
27
+ import xformers
28
+ import xformers.ops
29
+ else:
30
+ xformers = None
31
+
32
+
33
+ class Transformer3DModel(ModelMixin, ConfigMixin):
34
+ @register_to_config
35
+ def __init__(
36
+ self,
37
+ num_attention_heads: int = 16,
38
+ attention_head_dim: int = 88,
39
+ in_channels: Optional[int] = None,
40
+ num_layers: int = 1,
41
+ dropout: float = 0.0,
42
+ norm_num_groups: int = 32,
43
+ cross_attention_dim: Optional[int] = None,
44
+ attention_bias: bool = False,
45
+ activation_fn: str = "geglu",
46
+ num_embeds_ada_norm: Optional[int] = None,
47
+ use_linear_projection: bool = False,
48
+ only_cross_attention: bool = False,
49
+ upcast_attention: bool = False,
50
+ use_motion_module: bool = False,
51
+ unet_use_cross_frame_attention=None,
52
+ unet_use_temporal_attention=None,
53
+ add_audio_layer=False,
54
+ audio_condition_method="cross_attn",
55
+ custom_audio_layer: bool = False,
56
+ ):
57
+ super().__init__()
58
+ self.use_linear_projection = use_linear_projection
59
+ self.num_attention_heads = num_attention_heads
60
+ self.attention_head_dim = attention_head_dim
61
+ inner_dim = num_attention_heads * attention_head_dim
62
+
63
+ # Define input layers
64
+ self.in_channels = in_channels
65
+
66
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
67
+ if use_linear_projection:
68
+ self.proj_in = nn.Linear(in_channels, inner_dim)
69
+ else:
70
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
71
+
72
+ if not custom_audio_layer:
73
+ # Define transformer blocks
74
+ self.transformer_blocks = nn.ModuleList(
75
+ [
76
+ BasicTransformerBlock(
77
+ inner_dim,
78
+ num_attention_heads,
79
+ attention_head_dim,
80
+ dropout=dropout,
81
+ cross_attention_dim=cross_attention_dim,
82
+ activation_fn=activation_fn,
83
+ num_embeds_ada_norm=num_embeds_ada_norm,
84
+ attention_bias=attention_bias,
85
+ only_cross_attention=only_cross_attention,
86
+ upcast_attention=upcast_attention,
87
+ use_motion_module=use_motion_module,
88
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
89
+ unet_use_temporal_attention=unet_use_temporal_attention,
90
+ add_audio_layer=add_audio_layer,
91
+ custom_audio_layer=custom_audio_layer,
92
+ audio_condition_method=audio_condition_method,
93
+ )
94
+ for d in range(num_layers)
95
+ ]
96
+ )
97
+ else:
98
+ self.transformer_blocks = nn.ModuleList(
99
+ [
100
+ AudioTransformerBlock(
101
+ inner_dim,
102
+ num_attention_heads,
103
+ attention_head_dim,
104
+ dropout=dropout,
105
+ cross_attention_dim=cross_attention_dim,
106
+ activation_fn=activation_fn,
107
+ num_embeds_ada_norm=num_embeds_ada_norm,
108
+ attention_bias=attention_bias,
109
+ only_cross_attention=only_cross_attention,
110
+ upcast_attention=upcast_attention,
111
+ use_motion_module=use_motion_module,
112
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
113
+ unet_use_temporal_attention=unet_use_temporal_attention,
114
+ add_audio_layer=add_audio_layer,
115
+ )
116
+ for d in range(num_layers)
117
+ ]
118
+ )
119
+
120
+ # Define output layers
121
+ if use_linear_projection:
122
+ self.proj_out = nn.Linear(in_channels, inner_dim)
123
+ else:
124
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
125
+
126
+ if custom_audio_layer:
127
+ self.proj_out = zero_module(self.proj_out)
128
+
129
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
130
+ # Input
131
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
132
+ video_length = hidden_states.shape[2]
133
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
134
+
135
+ # No need to do this for audio input, because different audio samples are independent
136
+ # encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
137
+
138
+ batch, channel, height, weight = hidden_states.shape
139
+ residual = hidden_states
140
+
141
+ hidden_states = self.norm(hidden_states)
142
+ if not self.use_linear_projection:
143
+ hidden_states = self.proj_in(hidden_states)
144
+ inner_dim = hidden_states.shape[1]
145
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
146
+ else:
147
+ inner_dim = hidden_states.shape[1]
148
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
149
+ hidden_states = self.proj_in(hidden_states)
150
+
151
+ # Blocks
152
+ for block in self.transformer_blocks:
153
+ hidden_states = block(
154
+ hidden_states,
155
+ encoder_hidden_states=encoder_hidden_states,
156
+ timestep=timestep,
157
+ video_length=video_length,
158
+ )
159
+
160
+ # Output
161
+ if not self.use_linear_projection:
162
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
163
+ hidden_states = self.proj_out(hidden_states)
164
+ else:
165
+ hidden_states = self.proj_out(hidden_states)
166
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
167
+
168
+ output = hidden_states + residual
169
+
170
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
171
+ if not return_dict:
172
+ return (output,)
173
+
174
+ return Transformer3DModelOutput(sample=output)
175
+
176
+
177
+ class BasicTransformerBlock(nn.Module):
178
+ def __init__(
179
+ self,
180
+ dim: int,
181
+ num_attention_heads: int,
182
+ attention_head_dim: int,
183
+ dropout=0.0,
184
+ cross_attention_dim: Optional[int] = None,
185
+ activation_fn: str = "geglu",
186
+ num_embeds_ada_norm: Optional[int] = None,
187
+ attention_bias: bool = False,
188
+ only_cross_attention: bool = False,
189
+ upcast_attention: bool = False,
190
+ use_motion_module: bool = False,
191
+ unet_use_cross_frame_attention=None,
192
+ unet_use_temporal_attention=None,
193
+ add_audio_layer=False,
194
+ custom_audio_layer=False,
195
+ audio_condition_method="cross_attn",
196
+ ):
197
+ super().__init__()
198
+ self.only_cross_attention = only_cross_attention
199
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
200
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
201
+ self.unet_use_temporal_attention = unet_use_temporal_attention
202
+ self.use_motion_module = use_motion_module
203
+ self.add_audio_layer = add_audio_layer
204
+
205
+ # SC-Attn
206
+ assert unet_use_cross_frame_attention is not None
207
+ if unet_use_cross_frame_attention:
208
+ raise NotImplementedError("SparseCausalAttention2D not implemented yet.")
209
+ else:
210
+ self.attn1 = CrossAttention(
211
+ query_dim=dim,
212
+ heads=num_attention_heads,
213
+ dim_head=attention_head_dim,
214
+ dropout=dropout,
215
+ bias=attention_bias,
216
+ upcast_attention=upcast_attention,
217
+ )
218
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
219
+
220
+ # Cross-Attn
221
+ if add_audio_layer and audio_condition_method == "cross_attn" and not custom_audio_layer:
222
+ self.audio_cross_attn = AudioCrossAttn(
223
+ dim=dim,
224
+ cross_attention_dim=cross_attention_dim,
225
+ num_attention_heads=num_attention_heads,
226
+ attention_head_dim=attention_head_dim,
227
+ dropout=dropout,
228
+ attention_bias=attention_bias,
229
+ upcast_attention=upcast_attention,
230
+ num_embeds_ada_norm=num_embeds_ada_norm,
231
+ use_ada_layer_norm=self.use_ada_layer_norm,
232
+ zero_proj_out=False,
233
+ )
234
+ else:
235
+ self.audio_cross_attn = None
236
+
237
+ # Feed-forward
238
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
239
+ self.norm3 = nn.LayerNorm(dim)
240
+
241
+ # Temp-Attn
242
+ assert unet_use_temporal_attention is not None
243
+ if unet_use_temporal_attention:
244
+ self.attn_temp = CrossAttention(
245
+ query_dim=dim,
246
+ heads=num_attention_heads,
247
+ dim_head=attention_head_dim,
248
+ dropout=dropout,
249
+ bias=attention_bias,
250
+ upcast_attention=upcast_attention,
251
+ )
252
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
253
+ self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
254
+
255
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
256
+ if not is_xformers_available():
257
+ print("Here is how to install it")
258
+ raise ModuleNotFoundError(
259
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
260
+ " xformers",
261
+ name="xformers",
262
+ )
263
+ elif not torch.cuda.is_available():
264
+ raise ValueError(
265
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
266
+ " available for GPU "
267
+ )
268
+ else:
269
+ try:
270
+ # Make sure we can run the memory efficient attention
271
+ _ = xformers.ops.memory_efficient_attention(
272
+ torch.randn((1, 2, 40), device="cuda"),
273
+ torch.randn((1, 2, 40), device="cuda"),
274
+ torch.randn((1, 2, 40), device="cuda"),
275
+ )
276
+ except Exception as e:
277
+ raise e
278
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
279
+ if self.audio_cross_attn is not None:
280
+ self.audio_cross_attn.attn._use_memory_efficient_attention_xformers = (
281
+ use_memory_efficient_attention_xformers
282
+ )
283
+ # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
284
+
285
+ def forward(
286
+ self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None
287
+ ):
288
+ # SparseCausal-Attention
289
+ norm_hidden_states = (
290
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
291
+ )
292
+
293
+ # if self.only_cross_attention:
294
+ # hidden_states = (
295
+ # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
296
+ # )
297
+ # else:
298
+ # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
299
+
300
+ # pdb.set_trace()
301
+ if self.unet_use_cross_frame_attention:
302
+ hidden_states = (
303
+ self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length)
304
+ + hidden_states
305
+ )
306
+ else:
307
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
308
+
309
+ if self.audio_cross_attn is not None and encoder_hidden_states is not None:
310
+ hidden_states = self.audio_cross_attn(
311
+ hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
312
+ )
313
+
314
+ # Feed-forward
315
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
316
+
317
+ # Temporal-Attention
318
+ if self.unet_use_temporal_attention:
319
+ d = hidden_states.shape[1]
320
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
321
+ norm_hidden_states = (
322
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
323
+ )
324
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
325
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
326
+
327
+ return hidden_states
328
+
329
+
330
+ class AudioTransformerBlock(nn.Module):
331
+ def __init__(
332
+ self,
333
+ dim: int,
334
+ num_attention_heads: int,
335
+ attention_head_dim: int,
336
+ dropout=0.0,
337
+ cross_attention_dim: Optional[int] = None,
338
+ activation_fn: str = "geglu",
339
+ num_embeds_ada_norm: Optional[int] = None,
340
+ attention_bias: bool = False,
341
+ only_cross_attention: bool = False,
342
+ upcast_attention: bool = False,
343
+ use_motion_module: bool = False,
344
+ unet_use_cross_frame_attention=None,
345
+ unet_use_temporal_attention=None,
346
+ add_audio_layer=False,
347
+ ):
348
+ super().__init__()
349
+ self.only_cross_attention = only_cross_attention
350
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
351
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
352
+ self.unet_use_temporal_attention = unet_use_temporal_attention
353
+ self.use_motion_module = use_motion_module
354
+ self.add_audio_layer = add_audio_layer
355
+
356
+ # SC-Attn
357
+ assert unet_use_cross_frame_attention is not None
358
+ if unet_use_cross_frame_attention:
359
+ raise NotImplementedError("SparseCausalAttention2D not implemented yet.")
360
+ else:
361
+ self.attn1 = CrossAttention(
362
+ query_dim=dim,
363
+ heads=num_attention_heads,
364
+ dim_head=attention_head_dim,
365
+ dropout=dropout,
366
+ bias=attention_bias,
367
+ upcast_attention=upcast_attention,
368
+ )
369
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
370
+
371
+ self.audio_cross_attn = AudioCrossAttn(
372
+ dim=dim,
373
+ cross_attention_dim=cross_attention_dim,
374
+ num_attention_heads=num_attention_heads,
375
+ attention_head_dim=attention_head_dim,
376
+ dropout=dropout,
377
+ attention_bias=attention_bias,
378
+ upcast_attention=upcast_attention,
379
+ num_embeds_ada_norm=num_embeds_ada_norm,
380
+ use_ada_layer_norm=self.use_ada_layer_norm,
381
+ zero_proj_out=False,
382
+ )
383
+
384
+ # Feed-forward
385
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
386
+ self.norm3 = nn.LayerNorm(dim)
387
+
388
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
389
+ if not is_xformers_available():
390
+ print("Here is how to install it")
391
+ raise ModuleNotFoundError(
392
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
393
+ " xformers",
394
+ name="xformers",
395
+ )
396
+ elif not torch.cuda.is_available():
397
+ raise ValueError(
398
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
399
+ " available for GPU "
400
+ )
401
+ else:
402
+ try:
403
+ # Make sure we can run the memory efficient attention
404
+ _ = xformers.ops.memory_efficient_attention(
405
+ torch.randn((1, 2, 40), device="cuda"),
406
+ torch.randn((1, 2, 40), device="cuda"),
407
+ torch.randn((1, 2, 40), device="cuda"),
408
+ )
409
+ except Exception as e:
410
+ raise e
411
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
412
+ if self.audio_cross_attn is not None:
413
+ self.audio_cross_attn.attn._use_memory_efficient_attention_xformers = (
414
+ use_memory_efficient_attention_xformers
415
+ )
416
+ # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
417
+
418
+ def forward(
419
+ self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None
420
+ ):
421
+ # SparseCausal-Attention
422
+ norm_hidden_states = (
423
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
424
+ )
425
+
426
+ # pdb.set_trace()
427
+ if self.unet_use_cross_frame_attention:
428
+ hidden_states = (
429
+ self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length)
430
+ + hidden_states
431
+ )
432
+ else:
433
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
434
+
435
+ if self.audio_cross_attn is not None and encoder_hidden_states is not None:
436
+ hidden_states = self.audio_cross_attn(
437
+ hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
438
+ )
439
+
440
+ # Feed-forward
441
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
442
+
443
+ return hidden_states
444
+
445
+
446
+ class AudioCrossAttn(nn.Module):
447
+ def __init__(
448
+ self,
449
+ dim,
450
+ cross_attention_dim,
451
+ num_attention_heads,
452
+ attention_head_dim,
453
+ dropout,
454
+ attention_bias,
455
+ upcast_attention,
456
+ num_embeds_ada_norm,
457
+ use_ada_layer_norm,
458
+ zero_proj_out=False,
459
+ ):
460
+ super().__init__()
461
+
462
+ self.norm = AdaLayerNorm(dim, num_embeds_ada_norm) if use_ada_layer_norm else nn.LayerNorm(dim)
463
+ self.attn = CrossAttention(
464
+ query_dim=dim,
465
+ cross_attention_dim=cross_attention_dim,
466
+ heads=num_attention_heads,
467
+ dim_head=attention_head_dim,
468
+ dropout=dropout,
469
+ bias=attention_bias,
470
+ upcast_attention=upcast_attention,
471
+ )
472
+
473
+ if zero_proj_out:
474
+ self.proj_out = zero_module(nn.Linear(dim, dim))
475
+
476
+ self.zero_proj_out = zero_proj_out
477
+ self.use_ada_layer_norm = use_ada_layer_norm
478
+
479
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None):
480
+ previous_hidden_states = hidden_states
481
+ hidden_states = self.norm(hidden_states, timestep) if self.use_ada_layer_norm else self.norm(hidden_states)
482
+
483
+ if encoder_hidden_states.dim() == 4:
484
+ encoder_hidden_states = rearrange(encoder_hidden_states, "b f n d -> (b f) n d")
485
+
486
+ hidden_states = self.attn(
487
+ hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
488
+ )
489
+
490
+ if self.zero_proj_out:
491
+ hidden_states = self.proj_out(hidden_states)
492
+ return hidden_states + previous_hidden_states
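
Editor's sketch (not part of the commit): the shape contract of Transformer3DModel with the audio cross-attention path enabled. It assumes an older diffusers release that still exposes diffusers.modeling_utils.ModelMixin and diffusers.models.attention.CrossAttention; every size below is illustrative, and the two unet_use_* flags must be booleans because the blocks assert they are not None.

import torch
from latentsync.models.attention import Transformer3DModel

model = Transformer3DModel(
    num_attention_heads=8,
    attention_head_dim=40,                 # inner_dim = 8 * 40 = 320 = in_channels
    in_channels=320,
    num_layers=1,
    cross_attention_dim=384,               # assumed width of the audio features
    unet_use_cross_frame_attention=False,
    unet_use_temporal_attention=False,
    add_audio_layer=True,
    audio_condition_method="cross_attn",
)

video_latents = torch.randn(1, 320, 4, 16, 16)   # (batch, channels, frames, height, width)
audio_features = torch.randn(1, 4, 50, 384)      # (batch, frames, tokens, dim); flattened to (b*f, n, d) inside

out = model(video_latents, encoder_hidden_states=audio_features).sample
print(out.shape)                                  # torch.Size([1, 320, 4, 16, 16])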
latentsync/models/motion_module.py ADDED
@@ -0,0 +1,332 @@
1
+ # Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
2
+
3
+ # Actually, we don't use the motion module in the final version of LatentSync.
4
+ # When we started the project, we used the AnimateDiff codebase and tried the motion module,
5
+ # but the results were poor, so we decided to leave the code here for possible future use.
6
+
7
+ from dataclasses import dataclass
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn
12
+
13
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
14
+ from diffusers.modeling_utils import ModelMixin
15
+ from diffusers.utils import BaseOutput
16
+ from diffusers.utils.import_utils import is_xformers_available
17
+ from diffusers.models.attention import CrossAttention, FeedForward
18
+
19
+ from einops import rearrange, repeat
20
+ import math
21
+ from .utils import zero_module
22
+
23
+
24
+ @dataclass
25
+ class TemporalTransformer3DModelOutput(BaseOutput):
26
+ sample: torch.FloatTensor
27
+
28
+
29
+ if is_xformers_available():
30
+ import xformers
31
+ import xformers.ops
32
+ else:
33
+ xformers = None
34
+
35
+
36
+ def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
37
+ if motion_module_type == "Vanilla":
38
+ return VanillaTemporalModule(
39
+ in_channels=in_channels,
40
+ **motion_module_kwargs,
41
+ )
42
+ else:
43
+ raise ValueError
44
+
45
+
46
+ class VanillaTemporalModule(nn.Module):
47
+ def __init__(
48
+ self,
49
+ in_channels,
50
+ num_attention_heads=8,
51
+ num_transformer_block=2,
52
+ attention_block_types=("Temporal_Self", "Temporal_Self"),
53
+ cross_frame_attention_mode=None,
54
+ temporal_position_encoding=False,
55
+ temporal_position_encoding_max_len=24,
56
+ temporal_attention_dim_div=1,
57
+ zero_initialize=True,
58
+ ):
59
+ super().__init__()
60
+
61
+ self.temporal_transformer = TemporalTransformer3DModel(
62
+ in_channels=in_channels,
63
+ num_attention_heads=num_attention_heads,
64
+ attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
65
+ num_layers=num_transformer_block,
66
+ attention_block_types=attention_block_types,
67
+ cross_frame_attention_mode=cross_frame_attention_mode,
68
+ temporal_position_encoding=temporal_position_encoding,
69
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
70
+ )
71
+
72
+ if zero_initialize:
73
+ self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
74
+
75
+ def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
76
+ hidden_states = input_tensor
77
+ hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
78
+
79
+ output = hidden_states
80
+ return output
81
+
82
+
83
+ class TemporalTransformer3DModel(nn.Module):
84
+ def __init__(
85
+ self,
86
+ in_channels,
87
+ num_attention_heads,
88
+ attention_head_dim,
89
+ num_layers,
90
+ attention_block_types=(
91
+ "Temporal_Self",
92
+ "Temporal_Self",
93
+ ),
94
+ dropout=0.0,
95
+ norm_num_groups=32,
96
+ cross_attention_dim=768,
97
+ activation_fn="geglu",
98
+ attention_bias=False,
99
+ upcast_attention=False,
100
+ cross_frame_attention_mode=None,
101
+ temporal_position_encoding=False,
102
+ temporal_position_encoding_max_len=24,
103
+ ):
104
+ super().__init__()
105
+
106
+ inner_dim = num_attention_heads * attention_head_dim
107
+
108
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
109
+ self.proj_in = nn.Linear(in_channels, inner_dim)
110
+
111
+ self.transformer_blocks = nn.ModuleList(
112
+ [
113
+ TemporalTransformerBlock(
114
+ dim=inner_dim,
115
+ num_attention_heads=num_attention_heads,
116
+ attention_head_dim=attention_head_dim,
117
+ attention_block_types=attention_block_types,
118
+ dropout=dropout,
119
+ norm_num_groups=norm_num_groups,
120
+ cross_attention_dim=cross_attention_dim,
121
+ activation_fn=activation_fn,
122
+ attention_bias=attention_bias,
123
+ upcast_attention=upcast_attention,
124
+ cross_frame_attention_mode=cross_frame_attention_mode,
125
+ temporal_position_encoding=temporal_position_encoding,
126
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
127
+ )
128
+ for d in range(num_layers)
129
+ ]
130
+ )
131
+ self.proj_out = nn.Linear(inner_dim, in_channels)
132
+
133
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
134
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
135
+ video_length = hidden_states.shape[2]
136
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
137
+
138
+ batch, channel, height, weight = hidden_states.shape
139
+ residual = hidden_states
140
+
141
+ hidden_states = self.norm(hidden_states)
142
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
143
+ hidden_states = self.proj_in(hidden_states)
144
+
145
+ # Transformer Blocks
146
+ for block in self.transformer_blocks:
147
+ hidden_states = block(
148
+ hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length
149
+ )
150
+
151
+ # output
152
+ hidden_states = self.proj_out(hidden_states)
153
+ hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2).contiguous()
154
+
155
+ output = hidden_states + residual
156
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
157
+
158
+ return output
159
+
160
+
161
+ class TemporalTransformerBlock(nn.Module):
162
+ def __init__(
163
+ self,
164
+ dim,
165
+ num_attention_heads,
166
+ attention_head_dim,
167
+ attention_block_types=(
168
+ "Temporal_Self",
169
+ "Temporal_Self",
170
+ ),
171
+ dropout=0.0,
172
+ norm_num_groups=32,
173
+ cross_attention_dim=768,
174
+ activation_fn="geglu",
175
+ attention_bias=False,
176
+ upcast_attention=False,
177
+ cross_frame_attention_mode=None,
178
+ temporal_position_encoding=False,
179
+ temporal_position_encoding_max_len=24,
180
+ ):
181
+ super().__init__()
182
+
183
+ attention_blocks = []
184
+ norms = []
185
+
186
+ for block_name in attention_block_types:
187
+ attention_blocks.append(
188
+ VersatileAttention(
189
+ attention_mode=block_name.split("_")[0],
190
+ cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
191
+ query_dim=dim,
192
+ heads=num_attention_heads,
193
+ dim_head=attention_head_dim,
194
+ dropout=dropout,
195
+ bias=attention_bias,
196
+ upcast_attention=upcast_attention,
197
+ cross_frame_attention_mode=cross_frame_attention_mode,
198
+ temporal_position_encoding=temporal_position_encoding,
199
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
200
+ )
201
+ )
202
+ norms.append(nn.LayerNorm(dim))
203
+
204
+ self.attention_blocks = nn.ModuleList(attention_blocks)
205
+ self.norms = nn.ModuleList(norms)
206
+
207
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
208
+ self.ff_norm = nn.LayerNorm(dim)
209
+
210
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
211
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
212
+ norm_hidden_states = norm(hidden_states)
213
+ hidden_states = (
214
+ attention_block(
215
+ norm_hidden_states,
216
+ encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
217
+ video_length=video_length,
218
+ )
219
+ + hidden_states
220
+ )
221
+
222
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
223
+
224
+ output = hidden_states
225
+ return output
226
+
227
+
228
+ class PositionalEncoding(nn.Module):
229
+ def __init__(self, d_model, dropout=0.0, max_len=24):
230
+ super().__init__()
231
+ self.dropout = nn.Dropout(p=dropout)
232
+ position = torch.arange(max_len).unsqueeze(1)
233
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
234
+ pe = torch.zeros(1, max_len, d_model)
235
+ pe[0, :, 0::2] = torch.sin(position * div_term)
236
+ pe[0, :, 1::2] = torch.cos(position * div_term)
237
+ self.register_buffer("pe", pe)
238
+
239
+ def forward(self, x):
240
+ x = x + self.pe[:, : x.size(1)]
241
+ return self.dropout(x)
242
+
243
+
244
+ class VersatileAttention(CrossAttention):
245
+ def __init__(
246
+ self,
247
+ attention_mode=None,
248
+ cross_frame_attention_mode=None,
249
+ temporal_position_encoding=False,
250
+ temporal_position_encoding_max_len=24,
251
+ *args,
252
+ **kwargs,
253
+ ):
254
+ super().__init__(*args, **kwargs)
255
+ assert attention_mode == "Temporal"
256
+
257
+ self.attention_mode = attention_mode
258
+ self.is_cross_attention = kwargs["cross_attention_dim"] is not None
259
+
260
+ self.pos_encoder = (
261
+ PositionalEncoding(kwargs["query_dim"], dropout=0.0, max_len=temporal_position_encoding_max_len)
262
+ if (temporal_position_encoding and attention_mode == "Temporal")
263
+ else None
264
+ )
265
+
266
+ def extra_repr(self):
267
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
268
+
269
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
270
+ batch_size, sequence_length, _ = hidden_states.shape
271
+
272
+ if self.attention_mode == "Temporal":
273
+ d = hidden_states.shape[1]
274
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
275
+
276
+ if self.pos_encoder is not None:
277
+ hidden_states = self.pos_encoder(hidden_states)
278
+
279
+ encoder_hidden_states = (
280
+ repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
281
+ if encoder_hidden_states is not None
282
+ else encoder_hidden_states
283
+ )
284
+ else:
285
+ raise NotImplementedError
286
+
287
+ # encoder_hidden_states = encoder_hidden_states
288
+
289
+ if self.group_norm is not None:
290
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
291
+
292
+ query = self.to_q(hidden_states)
293
+ dim = query.shape[-1]
294
+ query = self.reshape_heads_to_batch_dim(query)
295
+
296
+ if self.added_kv_proj_dim is not None:
297
+ raise NotImplementedError
298
+
299
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
300
+ key = self.to_k(encoder_hidden_states)
301
+ value = self.to_v(encoder_hidden_states)
302
+
303
+ key = self.reshape_heads_to_batch_dim(key)
304
+ value = self.reshape_heads_to_batch_dim(value)
305
+
306
+ if attention_mask is not None:
307
+ if attention_mask.shape[-1] != query.shape[1]:
308
+ target_length = query.shape[1]
309
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
310
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
311
+
312
+ # attention, what we cannot get enough of
313
+ if self._use_memory_efficient_attention_xformers:
314
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
315
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
316
+ hidden_states = hidden_states.to(query.dtype)
317
+ else:
318
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
319
+ hidden_states = self._attention(query, key, value, attention_mask)
320
+ else:
321
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
322
+
323
+ # linear proj
324
+ hidden_states = self.to_out[0](hidden_states)
325
+
326
+ # dropout
327
+ hidden_states = self.to_out[1](hidden_states)
328
+
329
+ if self.attention_mode == "Temporal":
330
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
331
+
332
+ return hidden_states
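
Editor's sketch (not part of the commit): the sinusoidal PositionalEncoding above is applied along the frame axis after VersatileAttention rearranges tokens to (batch * spatial, frames, channels). The sizes are illustrative; importing this module requires the same older diffusers release as attention.py.

import torch
from latentsync.models.motion_module import PositionalEncoding

pos_enc = PositionalEncoding(d_model=320, dropout=0.0, max_len=24)

tokens = torch.randn(4 * 64, 16, 320)   # 4 samples x 64 spatial positions, 16 frames, 320 channels
encoded = pos_enc(tokens)               # adds pe[:, :16, :] to every row
print(encoded.shape)                    # torch.Size([256, 16, 320])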
latentsync/models/resnet.py ADDED
@@ -0,0 +1,234 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from einops import rearrange
8
+
9
+
10
+ class InflatedConv3d(nn.Conv2d):
11
+ def forward(self, x):
12
+ video_length = x.shape[2]
13
+
14
+ x = rearrange(x, "b c f h w -> (b f) c h w")
15
+ x = super().forward(x)
16
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
17
+
18
+ return x
19
+
20
+
21
+ class InflatedGroupNorm(nn.GroupNorm):
22
+ def forward(self, x):
23
+ video_length = x.shape[2]
24
+
25
+ x = rearrange(x, "b c f h w -> (b f) c h w")
26
+ x = super().forward(x)
27
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
28
+
29
+ return x
30
+
31
+
32
+ class Upsample3D(nn.Module):
33
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
34
+ super().__init__()
35
+ self.channels = channels
36
+ self.out_channels = out_channels or channels
37
+ self.use_conv = use_conv
38
+ self.use_conv_transpose = use_conv_transpose
39
+ self.name = name
40
+
41
+ conv = None
42
+ if use_conv_transpose:
43
+ raise NotImplementedError
44
+ elif use_conv:
45
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
46
+
47
+ def forward(self, hidden_states, output_size=None):
48
+ assert hidden_states.shape[1] == self.channels
49
+
50
+ if self.use_conv_transpose:
51
+ raise NotImplementedError
52
+
53
+ # Cast to float32, as the 'upsample_nearest2d_out_frame' op does not support bfloat16
54
+ dtype = hidden_states.dtype
55
+ if dtype == torch.bfloat16:
56
+ hidden_states = hidden_states.to(torch.float32)
57
+
58
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
59
+ if hidden_states.shape[0] >= 64:
60
+ hidden_states = hidden_states.contiguous()
61
+
62
+ # if `output_size` is passed we force the interpolation output
63
+ # size and do not make use of `scale_factor=2`
64
+ if output_size is None:
65
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
66
+ else:
67
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
68
+
69
+ # If the input is bfloat16, we cast back to bfloat16
70
+ if dtype == torch.bfloat16:
71
+ hidden_states = hidden_states.to(dtype)
72
+
73
+ # if self.use_conv:
74
+ # if self.name == "conv":
75
+ # hidden_states = self.conv(hidden_states)
76
+ # else:
77
+ # hidden_states = self.Conv2d_0(hidden_states)
78
+ hidden_states = self.conv(hidden_states)
79
+
80
+ return hidden_states
81
+
82
+
83
+ class Downsample3D(nn.Module):
84
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
85
+ super().__init__()
86
+ self.channels = channels
87
+ self.out_channels = out_channels or channels
88
+ self.use_conv = use_conv
89
+ self.padding = padding
90
+ stride = 2
91
+ self.name = name
92
+
93
+ if use_conv:
94
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
95
+ else:
96
+ raise NotImplementedError
97
+
98
+ def forward(self, hidden_states):
99
+ assert hidden_states.shape[1] == self.channels
100
+ if self.use_conv and self.padding == 0:
101
+ raise NotImplementedError
102
+
103
+ assert hidden_states.shape[1] == self.channels
104
+ hidden_states = self.conv(hidden_states)
105
+
106
+ return hidden_states
107
+
108
+
109
+ class ResnetBlock3D(nn.Module):
110
+ def __init__(
111
+ self,
112
+ *,
113
+ in_channels,
114
+ out_channels=None,
115
+ conv_shortcut=False,
116
+ dropout=0.0,
117
+ temb_channels=512,
118
+ groups=32,
119
+ groups_out=None,
120
+ pre_norm=True,
121
+ eps=1e-6,
122
+ non_linearity="swish",
123
+ time_embedding_norm="default",
124
+ output_scale_factor=1.0,
125
+ use_in_shortcut=None,
126
+ use_inflated_groupnorm=False,
127
+ ):
128
+ super().__init__()
129
+ self.pre_norm = pre_norm
130
+ self.pre_norm = True
131
+ self.in_channels = in_channels
132
+ out_channels = in_channels if out_channels is None else out_channels
133
+ self.out_channels = out_channels
134
+ self.use_conv_shortcut = conv_shortcut
135
+ self.time_embedding_norm = time_embedding_norm
136
+ self.output_scale_factor = output_scale_factor
137
+
138
+ if groups_out is None:
139
+ groups_out = groups
140
+
141
+ assert use_inflated_groupnorm != None
142
+ if use_inflated_groupnorm:
143
+ self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
144
+ else:
145
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
146
+
147
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
148
+
149
+ if temb_channels is not None:
150
+ time_emb_proj_out_channels = out_channels
151
+ # if self.time_embedding_norm == "default":
152
+ # time_emb_proj_out_channels = out_channels
153
+ # elif self.time_embedding_norm == "scale_shift":
154
+ # time_emb_proj_out_channels = out_channels * 2
155
+ # else:
156
+ # raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
157
+
158
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
159
+ else:
160
+ self.time_emb_proj = None
161
+
162
+ if self.time_embedding_norm == "scale_shift":
163
+ self.double_len_linear = torch.nn.Linear(time_emb_proj_out_channels, 2 * time_emb_proj_out_channels)
164
+ else:
165
+ self.double_len_linear = None
166
+
167
+ if use_inflated_groupnorm:
168
+ self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
169
+ else:
170
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
171
+
172
+ self.dropout = torch.nn.Dropout(dropout)
173
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
174
+
175
+ if non_linearity == "swish":
176
+ self.nonlinearity = lambda x: F.silu(x)
177
+ elif non_linearity == "mish":
178
+ self.nonlinearity = Mish()
179
+ elif non_linearity == "silu":
180
+ self.nonlinearity = nn.SiLU()
181
+
182
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
183
+
184
+ self.conv_shortcut = None
185
+ if self.use_in_shortcut:
186
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
187
+
188
+ def forward(self, input_tensor, temb):
189
+ hidden_states = input_tensor
190
+
191
+ hidden_states = self.norm1(hidden_states)
192
+ hidden_states = self.nonlinearity(hidden_states)
193
+
194
+ hidden_states = self.conv1(hidden_states)
195
+
196
+ if temb is not None:
197
+ if temb.dim() == 2:
198
+ # input (1, 1280)
199
+ temb = self.time_emb_proj(self.nonlinearity(temb))
200
+ temb = temb[:, :, None, None, None] # unsqueeze
201
+ else:
202
+ # input (1, 1280, 16)
203
+ temb = temb.permute(0, 2, 1)
204
+ temb = self.time_emb_proj(self.nonlinearity(temb))
205
+ if self.double_len_linear is not None:
206
+ temb = self.double_len_linear(self.nonlinearity(temb))
207
+ temb = temb.permute(0, 2, 1)
208
+ temb = temb[:, :, :, None, None]
209
+
210
+ if temb is not None and self.time_embedding_norm == "default":
211
+ hidden_states = hidden_states + temb
212
+
213
+ hidden_states = self.norm2(hidden_states)
214
+
215
+ if temb is not None and self.time_embedding_norm == "scale_shift":
216
+ scale, shift = torch.chunk(temb, 2, dim=1)
217
+ hidden_states = hidden_states * (1 + scale) + shift
218
+
219
+ hidden_states = self.nonlinearity(hidden_states)
220
+
221
+ hidden_states = self.dropout(hidden_states)
222
+ hidden_states = self.conv2(hidden_states)
223
+
224
+ if self.conv_shortcut is not None:
225
+ input_tensor = self.conv_shortcut(input_tensor)
226
+
227
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
228
+
229
+ return output_tensor
230
+
231
+
232
+ class Mish(torch.nn.Module):
233
+ def forward(self, hidden_states):
234
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
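
Editor's sketch (not part of the commit): InflatedConv3d and InflatedGroupNorm simply fold the frame axis into the batch axis, run the ordinary 2D op, and unfold again, so 5-D video tensors keep their temporal length. Sizes are illustrative.

import torch
from latentsync.models.resnet import InflatedConv3d, InflatedGroupNorm

conv = InflatedConv3d(4, 8, kernel_size=3, padding=1)   # a plain nn.Conv2d applied frame by frame
norm = InflatedGroupNorm(num_groups=8, num_channels=8)

x = torch.randn(2, 4, 16, 32, 32)   # (batch, channels, frames, height, width)
y = norm(conv(x))
print(y.shape)                       # torch.Size([2, 8, 16, 32, 32])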
latentsync/models/syncnet.py ADDED
@@ -0,0 +1,233 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ from torch import nn
17
+ from einops import rearrange
18
+ from torch.nn import functional as F
19
+ from ..utils.util import cosine_loss
20
+
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+
24
+ from diffusers.models.attention import CrossAttention, FeedForward
25
+ from diffusers.utils.import_utils import is_xformers_available
26
+ from einops import rearrange
27
+
28
+
29
+ class SyncNet(nn.Module):
30
+ def __init__(self, config):
31
+ super().__init__()
32
+ self.audio_encoder = DownEncoder2D(
33
+ in_channels=config["audio_encoder"]["in_channels"],
34
+ block_out_channels=config["audio_encoder"]["block_out_channels"],
35
+ downsample_factors=config["audio_encoder"]["downsample_factors"],
36
+ dropout=config["audio_encoder"]["dropout"],
37
+ attn_blocks=config["audio_encoder"]["attn_blocks"],
38
+ )
39
+
40
+ self.visual_encoder = DownEncoder2D(
41
+ in_channels=config["visual_encoder"]["in_channels"],
42
+ block_out_channels=config["visual_encoder"]["block_out_channels"],
43
+ downsample_factors=config["visual_encoder"]["downsample_factors"],
44
+ dropout=config["visual_encoder"]["dropout"],
45
+ attn_blocks=config["visual_encoder"]["attn_blocks"],
46
+ )
47
+
48
+ self.eval()
49
+
50
+ def forward(self, image_sequences, audio_sequences):
51
+ vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1)
52
+ audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1)
53
+
54
+ vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c)
55
+ audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c)
56
+
57
+ # Make them unit vectors
58
+ vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
59
+ audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
60
+
61
+ return vision_embeds, audio_embeds
62
+
63
+
64
+ class ResnetBlock2D(nn.Module):
65
+ def __init__(
66
+ self,
67
+ in_channels: int,
68
+ out_channels: int,
69
+ dropout: float = 0.0,
70
+ norm_num_groups: int = 32,
71
+ eps: float = 1e-6,
72
+ act_fn: str = "silu",
73
+ downsample_factor=2,
74
+ ):
75
+ super().__init__()
76
+
77
+ self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
78
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
79
+
80
+ self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=eps, affine=True)
81
+ self.dropout = nn.Dropout(dropout)
82
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
83
+
84
+ if act_fn == "relu":
85
+ self.act_fn = nn.ReLU()
86
+ elif act_fn == "silu":
87
+ self.act_fn = nn.SiLU()
88
+
89
+ if in_channels != out_channels:
90
+ self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
91
+ else:
92
+ self.conv_shortcut = None
93
+
94
+ if isinstance(downsample_factor, list):
95
+ downsample_factor = tuple(downsample_factor)
96
+
97
+ if downsample_factor == 1:
98
+ self.downsample_conv = None
99
+ else:
100
+ self.downsample_conv = nn.Conv2d(
101
+ out_channels, out_channels, kernel_size=3, stride=downsample_factor, padding=0
102
+ )
103
+ self.pad = (0, 1, 0, 1)
104
+ if isinstance(downsample_factor, tuple):
105
+ if downsample_factor[0] == 1:
106
+ self.pad = (0, 1, 1, 1) # The padding order is from back to front
107
+ elif downsample_factor[1] == 1:
108
+ self.pad = (1, 1, 0, 1)
109
+
110
+ def forward(self, input_tensor):
111
+ hidden_states = input_tensor
112
+
113
+ hidden_states = self.norm1(hidden_states)
114
+ hidden_states = self.act_fn(hidden_states)
115
+
116
+ hidden_states = self.conv1(hidden_states)
117
+ hidden_states = self.norm2(hidden_states)
118
+ hidden_states = self.act_fn(hidden_states)
119
+
120
+ hidden_states = self.dropout(hidden_states)
121
+ hidden_states = self.conv2(hidden_states)
122
+
123
+ if self.conv_shortcut is not None:
124
+ input_tensor = self.conv_shortcut(input_tensor)
125
+
126
+ hidden_states += input_tensor
127
+
128
+ if self.downsample_conv is not None:
129
+ hidden_states = F.pad(hidden_states, self.pad, mode="constant", value=0)
130
+ hidden_states = self.downsample_conv(hidden_states)
131
+
132
+ return hidden_states
133
+
134
+
135
+ class AttentionBlock2D(nn.Module):
136
+ def __init__(self, query_dim, norm_num_groups=32, dropout=0.0):
137
+ super().__init__()
138
+ if not is_xformers_available():
139
+ raise ModuleNotFoundError(
140
+ "You have to install xformers to enable memory efficient attention", name="xformers"
141
+ )
142
+ # inner_dim = dim_head * heads
143
+ self.norm1 = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=query_dim, eps=1e-6, affine=True)
144
+ self.norm2 = nn.LayerNorm(query_dim)
145
+ self.norm3 = nn.LayerNorm(query_dim)
146
+
147
+ self.ff = FeedForward(query_dim, dropout=dropout, activation_fn="geglu")
148
+
149
+ self.conv_in = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
150
+ self.conv_out = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
151
+
152
+ self.attn = CrossAttention(query_dim=query_dim, heads=8, dim_head=query_dim // 8, dropout=dropout, bias=True)
153
+ self.attn._use_memory_efficient_attention_xformers = True
154
+
155
+ def forward(self, hidden_states):
156
+ assert hidden_states.dim() == 4, f"Expected hidden_states to have ndim=4, but got ndim={hidden_states.dim()}."
157
+
158
+ batch, channel, height, width = hidden_states.shape
159
+ residual = hidden_states
160
+
161
+ hidden_states = self.norm1(hidden_states)
162
+ hidden_states = self.conv_in(hidden_states)
163
+ hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c")
164
+
165
+ norm_hidden_states = self.norm2(hidden_states)
166
+ hidden_states = self.attn(norm_hidden_states, attention_mask=None) + hidden_states
167
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
168
+
169
+ hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=height, w=width)
170
+ hidden_states = self.conv_out(hidden_states)
171
+
172
+ hidden_states = hidden_states + residual
173
+ return hidden_states
174
+
175
+
176
+ class DownEncoder2D(nn.Module):
177
+ def __init__(
178
+ self,
179
+ in_channels=4 * 16,
180
+ block_out_channels=[64, 128, 256, 256],
181
+ downsample_factors=[2, 2, 2, 2],
182
+ layers_per_block=2,
183
+ norm_num_groups=32,
184
+ attn_blocks=[1, 1, 1, 1],
185
+ dropout: float = 0.0,
186
+ act_fn="silu",
187
+ ):
188
+ super().__init__()
189
+ self.layers_per_block = layers_per_block
190
+
191
+ # in
192
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
193
+
194
+ # down
195
+ self.down_blocks = nn.ModuleList([])
196
+
197
+ output_channels = block_out_channels[0]
198
+ for i, block_out_channel in enumerate(block_out_channels):
199
+ input_channels = output_channels
200
+ output_channels = block_out_channel
201
+ # is_final_block = i == len(block_out_channels) - 1
202
+
203
+ down_block = ResnetBlock2D(
204
+ in_channels=input_channels,
205
+ out_channels=output_channels,
206
+ downsample_factor=downsample_factors[i],
207
+ norm_num_groups=norm_num_groups,
208
+ dropout=dropout,
209
+ act_fn=act_fn,
210
+ )
211
+
212
+ self.down_blocks.append(down_block)
213
+
214
+ if attn_blocks[i] == 1:
215
+ attention_block = AttentionBlock2D(query_dim=output_channels, dropout=dropout)
216
+ self.down_blocks.append(attention_block)
217
+
218
+ # out
219
+ self.norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
220
+ self.act_fn_out = nn.ReLU()
221
+
222
+ def forward(self, hidden_states):
223
+ hidden_states = self.conv_in(hidden_states)
224
+
225
+ # down
226
+ for down_block in self.down_blocks:
227
+ hidden_states = down_block(hidden_states)
228
+
229
+ # post-process
230
+ hidden_states = self.norm_out(hidden_states)
231
+ hidden_states = self.act_fn_out(hidden_states)
232
+
233
+ return hidden_states
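Both encoders above return L2-normalized embeddings, so audio-visual sync can be scored directly with their cosine similarity (the dot product of unit vectors). The `cosine_loss` helper imported from `..utils.util` is not part of this diff, so the sketch below uses plain cosine similarity plus a BCE term as an illustrative stand-in rather than the repository's exact loss:

import torch
import torch.nn.functional as F

def sync_scores(vision_embeds: torch.Tensor, audio_embeds: torch.Tensor) -> torch.Tensor:
    # Inputs are already unit-normalized (b, c); the row-wise dot product is the cosine.
    return (vision_embeds * audio_embeds).sum(dim=1)  # shape (b,), values in [-1, 1]

def illustrative_sync_loss(vision_embeds, audio_embeds, labels):
    # labels: 1.0 for in-sync pairs, 0.0 for out-of-sync pairs, shape (b,)
    prob = (sync_scores(vision_embeds, audio_embeds) + 1) / 2  # map cosine into (0, 1)
    return F.binary_cross_entropy(prob.clamp(1e-6, 1 - 1e-6), labels)

v = F.normalize(torch.randn(4, 512), dim=1)
a = F.normalize(torch.randn(4, 512), dim=1)
print(illustrative_sync_loss(v, a, torch.tensor([1.0, 0.0, 1.0, 0.0])))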
latentsync/models/syncnet_wav2lip.py ADDED
@@ -0,0 +1,90 @@
1
+ # Adapted from https://github.com/primepake/wav2lip_288x288/blob/master/models/syncnetv2.py
2
+ # The code here is for ablation study.
3
+
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+
8
+ class SyncNetWav2Lip(nn.Module):
9
+ def __init__(self, act_fn="leaky"):
10
+ super().__init__()
11
+
12
+ # input image sequences: (15, 128, 256)
13
+ self.visual_encoder = nn.Sequential(
14
+ Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3, act_fn=act_fn), # (128, 256)
15
+ Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1, act_fn=act_fn), # (126, 127)
16
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
17
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
18
+ Conv2d(64, 128, kernel_size=3, stride=2, padding=1, act_fn=act_fn), # (63, 64)
19
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
20
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
21
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
22
+ Conv2d(128, 256, kernel_size=3, stride=3, padding=1, act_fn=act_fn), # (21, 22)
23
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
24
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
25
+ Conv2d(256, 512, kernel_size=3, stride=2, padding=1, act_fn=act_fn), # (11, 11)
26
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
27
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
28
+ Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, act_fn=act_fn), # (6, 6)
29
+ Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
30
+ Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
31
+ Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1, act_fn="relu"), # (3, 3)
32
+ Conv2d(1024, 1024, kernel_size=3, stride=1, padding=0, act_fn="relu"), # (1, 1)
33
+ Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, act_fn="relu"),
34
+ )
35
+
36
+ # input audio sequences: (1, 80, 16)
37
+ self.audio_encoder = nn.Sequential(
38
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1, act_fn=act_fn),
39
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
40
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
41
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1, act_fn=act_fn), # (27, 16)
42
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
43
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
44
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1, act_fn=act_fn), # (9, 6)
45
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
46
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
47
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1, act_fn=act_fn), # (3, 3)
48
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
49
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
50
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=1, act_fn=act_fn),
51
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
52
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
53
+ Conv2d(512, 1024, kernel_size=3, stride=1, padding=0, act_fn="relu"), # (1, 1)
54
+ Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, act_fn="relu"),
55
+ )
56
+
57
+ def forward(self, image_sequences, audio_sequences):
58
+ vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1)
59
+ audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1)
60
+
61
+ vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c)
62
+ audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c)
63
+
64
+ # Make them unit vectors
65
+ vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
66
+ audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
67
+
68
+ return vision_embeds, audio_embeds
69
+
70
+
71
+ class Conv2d(nn.Module):
72
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, act_fn="relu", *args, **kwargs):
73
+ super().__init__(*args, **kwargs)
74
+ self.conv_block = nn.Sequential(nn.Conv2d(cin, cout, kernel_size, stride, padding), nn.BatchNorm2d(cout))
75
+ if act_fn == "relu":
76
+ self.act_fn = nn.ReLU()
77
+ elif act_fn == "tanh":
78
+ self.act_fn = nn.Tanh()
79
+ elif act_fn == "silu":
80
+ self.act_fn = nn.SiLU()
81
+ elif act_fn == "leaky":
82
+ self.act_fn = nn.LeakyReLU(0.2, inplace=True)
83
+
84
+ self.residual = residual
85
+
86
+ def forward(self, x):
87
+ out = self.conv_block(x)
88
+ if self.residual:
89
+ out += x
90
+ return self.act_fn(out)
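A hedged usage sketch for the Wav2Lip-style variant above, using the input shapes documented in its constructor comments: (b, 15, 128, 256) image stacks (typically 5 RGB face crops concatenated on the channel axis) and (b, 1, 80, 16) mel windows. The import path assumes the package layout of this commit:

import torch
from latentsync.models.syncnet_wav2lip import SyncNetWav2Lip

model = SyncNetWav2Lip(act_fn="leaky").eval()
frames = torch.randn(2, 15, 128, 256)  # stacked face crops -> 15 channels
mels = torch.randn(2, 1, 80, 16)       # 16 mel frames x 80 mel bins

with torch.no_grad():
    vision_embeds, audio_embeds = model(frames, mels)   # each (2, 1024), unit norm
    cosine = (vision_embeds * audio_embeds).sum(dim=1)  # per-pair sync score
print(cosine.shape)  # torch.Size([2])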
latentsync/models/unet.py ADDED
@@ -0,0 +1,528 @@
1
+ # Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Tuple, Union
5
+ import copy
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.utils.checkpoint
10
+
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.modeling_utils import ModelMixin
13
+ from diffusers import UNet2DConditionModel
14
+ from diffusers.utils import BaseOutput, logging
15
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
16
+ from .unet_blocks import (
17
+ CrossAttnDownBlock3D,
18
+ CrossAttnUpBlock3D,
19
+ DownBlock3D,
20
+ UNetMidBlock3DCrossAttn,
21
+ UpBlock3D,
22
+ get_down_block,
23
+ get_up_block,
24
+ )
25
+ from .resnet import InflatedConv3d, InflatedGroupNorm
26
+
27
+ from ..utils.util import zero_rank_log
28
+ from einops import rearrange
29
+ from .utils import zero_module
30
+
31
+
32
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
33
+
34
+
35
+ @dataclass
36
+ class UNet3DConditionOutput(BaseOutput):
37
+ sample: torch.FloatTensor
38
+
39
+
40
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
41
+ _supports_gradient_checkpointing = True
42
+
43
+ @register_to_config
44
+ def __init__(
45
+ self,
46
+ sample_size: Optional[int] = None,
47
+ in_channels: int = 4,
48
+ out_channels: int = 4,
49
+ center_input_sample: bool = False,
50
+ flip_sin_to_cos: bool = True,
51
+ freq_shift: int = 0,
52
+ down_block_types: Tuple[str] = (
53
+ "CrossAttnDownBlock3D",
54
+ "CrossAttnDownBlock3D",
55
+ "CrossAttnDownBlock3D",
56
+ "DownBlock3D",
57
+ ),
58
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
59
+ up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
60
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
61
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
62
+ layers_per_block: int = 2,
63
+ downsample_padding: int = 1,
64
+ mid_block_scale_factor: float = 1,
65
+ act_fn: str = "silu",
66
+ norm_num_groups: int = 32,
67
+ norm_eps: float = 1e-5,
68
+ cross_attention_dim: int = 1280,
69
+ attention_head_dim: Union[int, Tuple[int]] = 8,
70
+ dual_cross_attention: bool = False,
71
+ use_linear_projection: bool = False,
72
+ class_embed_type: Optional[str] = None,
73
+ num_class_embeds: Optional[int] = None,
74
+ upcast_attention: bool = False,
75
+ resnet_time_scale_shift: str = "default",
76
+ use_inflated_groupnorm=False,
77
+ # Additional
78
+ use_motion_module=False,
79
+ motion_module_resolutions=(1, 2, 4, 8),
80
+ motion_module_mid_block=False,
81
+ motion_module_decoder_only=False,
82
+ motion_module_type=None,
83
+ motion_module_kwargs={},
84
+ unet_use_cross_frame_attention=False,
85
+ unet_use_temporal_attention=False,
86
+ add_audio_layer=False,
87
+ audio_condition_method: str = "cross_attn",
88
+ custom_audio_layer=False,
89
+ ):
90
+ super().__init__()
91
+
92
+ self.sample_size = sample_size
93
+ time_embed_dim = block_out_channels[0] * 4
94
+ self.use_motion_module = use_motion_module
95
+ self.add_audio_layer = add_audio_layer
96
+
97
+ self.conv_in = zero_module(InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)))
98
+
99
+ # time
100
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
101
+ timestep_input_dim = block_out_channels[0]
102
+
103
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
104
+
105
+ # class embedding
106
+ if class_embed_type is None and num_class_embeds is not None:
107
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
108
+ elif class_embed_type == "timestep":
109
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
110
+ elif class_embed_type == "identity":
111
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
112
+ else:
113
+ self.class_embedding = None
114
+
115
+ self.down_blocks = nn.ModuleList([])
116
+ self.mid_block = None
117
+ self.up_blocks = nn.ModuleList([])
118
+
119
+ if isinstance(only_cross_attention, bool):
120
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
121
+
122
+ if isinstance(attention_head_dim, int):
123
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
124
+
125
+ # down
126
+ output_channel = block_out_channels[0]
127
+ for i, down_block_type in enumerate(down_block_types):
128
+ res = 2**i
129
+ input_channel = output_channel
130
+ output_channel = block_out_channels[i]
131
+ is_final_block = i == len(block_out_channels) - 1
132
+
133
+ down_block = get_down_block(
134
+ down_block_type,
135
+ num_layers=layers_per_block,
136
+ in_channels=input_channel,
137
+ out_channels=output_channel,
138
+ temb_channels=time_embed_dim,
139
+ add_downsample=not is_final_block,
140
+ resnet_eps=norm_eps,
141
+ resnet_act_fn=act_fn,
142
+ resnet_groups=norm_num_groups,
143
+ cross_attention_dim=cross_attention_dim,
144
+ attn_num_head_channels=attention_head_dim[i],
145
+ downsample_padding=downsample_padding,
146
+ dual_cross_attention=dual_cross_attention,
147
+ use_linear_projection=use_linear_projection,
148
+ only_cross_attention=only_cross_attention[i],
149
+ upcast_attention=upcast_attention,
150
+ resnet_time_scale_shift=resnet_time_scale_shift,
151
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
152
+ unet_use_temporal_attention=unet_use_temporal_attention,
153
+ use_inflated_groupnorm=use_inflated_groupnorm,
154
+ use_motion_module=use_motion_module
155
+ and (res in motion_module_resolutions)
156
+ and (not motion_module_decoder_only),
157
+ motion_module_type=motion_module_type,
158
+ motion_module_kwargs=motion_module_kwargs,
159
+ add_audio_layer=add_audio_layer,
160
+ audio_condition_method=audio_condition_method,
161
+ custom_audio_layer=custom_audio_layer,
162
+ )
163
+ self.down_blocks.append(down_block)
164
+
165
+ # mid
166
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
167
+ self.mid_block = UNetMidBlock3DCrossAttn(
168
+ in_channels=block_out_channels[-1],
169
+ temb_channels=time_embed_dim,
170
+ resnet_eps=norm_eps,
171
+ resnet_act_fn=act_fn,
172
+ output_scale_factor=mid_block_scale_factor,
173
+ resnet_time_scale_shift=resnet_time_scale_shift,
174
+ cross_attention_dim=cross_attention_dim,
175
+ attn_num_head_channels=attention_head_dim[-1],
176
+ resnet_groups=norm_num_groups,
177
+ dual_cross_attention=dual_cross_attention,
178
+ use_linear_projection=use_linear_projection,
179
+ upcast_attention=upcast_attention,
180
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
181
+ unet_use_temporal_attention=unet_use_temporal_attention,
182
+ use_inflated_groupnorm=use_inflated_groupnorm,
183
+ use_motion_module=use_motion_module and motion_module_mid_block,
184
+ motion_module_type=motion_module_type,
185
+ motion_module_kwargs=motion_module_kwargs,
186
+ add_audio_layer=add_audio_layer,
187
+ audio_condition_method=audio_condition_method,
188
+ custom_audio_layer=custom_audio_layer,
189
+ )
190
+ else:
191
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
192
+
193
+ # count how many layers upsample the videos
194
+ self.num_upsamplers = 0
195
+
196
+ # up
197
+ reversed_block_out_channels = list(reversed(block_out_channels))
198
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
199
+ only_cross_attention = list(reversed(only_cross_attention))
200
+ output_channel = reversed_block_out_channels[0]
201
+ for i, up_block_type in enumerate(up_block_types):
202
+ res = 2 ** (3 - i)
203
+ is_final_block = i == len(block_out_channels) - 1
204
+
205
+ prev_output_channel = output_channel
206
+ output_channel = reversed_block_out_channels[i]
207
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
208
+
209
+ # add upsample block for all BUT final layer
210
+ if not is_final_block:
211
+ add_upsample = True
212
+ self.num_upsamplers += 1
213
+ else:
214
+ add_upsample = False
215
+
216
+ up_block = get_up_block(
217
+ up_block_type,
218
+ num_layers=layers_per_block + 1,
219
+ in_channels=input_channel,
220
+ out_channels=output_channel,
221
+ prev_output_channel=prev_output_channel,
222
+ temb_channels=time_embed_dim,
223
+ add_upsample=add_upsample,
224
+ resnet_eps=norm_eps,
225
+ resnet_act_fn=act_fn,
226
+ resnet_groups=norm_num_groups,
227
+ cross_attention_dim=cross_attention_dim,
228
+ attn_num_head_channels=reversed_attention_head_dim[i],
229
+ dual_cross_attention=dual_cross_attention,
230
+ use_linear_projection=use_linear_projection,
231
+ only_cross_attention=only_cross_attention[i],
232
+ upcast_attention=upcast_attention,
233
+ resnet_time_scale_shift=resnet_time_scale_shift,
234
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
235
+ unet_use_temporal_attention=unet_use_temporal_attention,
236
+ use_inflated_groupnorm=use_inflated_groupnorm,
237
+ use_motion_module=use_motion_module and (res in motion_module_resolutions),
238
+ motion_module_type=motion_module_type,
239
+ motion_module_kwargs=motion_module_kwargs,
240
+ add_audio_layer=add_audio_layer,
241
+ audio_condition_method=audio_condition_method,
242
+ custom_audio_layer=custom_audio_layer,
243
+ )
244
+ self.up_blocks.append(up_block)
245
+ prev_output_channel = output_channel
246
+
247
+ # out
248
+ if use_inflated_groupnorm:
249
+ self.conv_norm_out = InflatedGroupNorm(
250
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
251
+ )
252
+ else:
253
+ self.conv_norm_out = nn.GroupNorm(
254
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
255
+ )
256
+ self.conv_act = nn.SiLU()
257
+
258
+ self.conv_out = zero_module(InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1))
259
+
260
+ def set_attention_slice(self, slice_size):
261
+ r"""
262
+ Enable sliced attention computation.
263
+
264
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
265
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
266
+
267
+ Args:
268
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
269
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
270
+ `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
271
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
272
+ must be a multiple of `slice_size`.
273
+ """
274
+ sliceable_head_dims = []
275
+
276
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
277
+ if hasattr(module, "set_attention_slice"):
278
+ sliceable_head_dims.append(module.sliceable_head_dim)
279
+
280
+ for child in module.children():
281
+ fn_recursive_retrieve_slicable_dims(child)
282
+
283
+ # retrieve number of attention layers
284
+ for module in self.children():
285
+ fn_recursive_retrieve_slicable_dims(module)
286
+
287
+ num_slicable_layers = len(sliceable_head_dims)
288
+
289
+ if slice_size == "auto":
290
+ # half the attention head size is usually a good trade-off between
291
+ # speed and memory
292
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
293
+ elif slice_size == "max":
294
+ # make smallest slice possible
295
+ slice_size = num_slicable_layers * [1]
296
+
297
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
298
+
299
+ if len(slice_size) != len(sliceable_head_dims):
300
+ raise ValueError(
301
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
302
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
303
+ )
304
+
305
+ for i in range(len(slice_size)):
306
+ size = slice_size[i]
307
+ dim = sliceable_head_dims[i]
308
+ if size is not None and size > dim:
309
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
310
+
311
+ # Recursively walk through all the children.
312
+ # Any children which exposes the set_attention_slice method
313
+ # gets the message
314
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
315
+ if hasattr(module, "set_attention_slice"):
316
+ module.set_attention_slice(slice_size.pop())
317
+
318
+ for child in module.children():
319
+ fn_recursive_set_attention_slice(child, slice_size)
320
+
321
+ reversed_slice_size = list(reversed(slice_size))
322
+ for module in self.children():
323
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
324
+
325
+ def _set_gradient_checkpointing(self, module, value=False):
326
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
327
+ module.gradient_checkpointing = value
328
+
329
+ def forward(
330
+ self,
331
+ sample: torch.FloatTensor,
332
+ timestep: Union[torch.Tensor, float, int],
333
+ encoder_hidden_states: torch.Tensor,
334
+ class_labels: Optional[torch.Tensor] = None,
335
+ attention_mask: Optional[torch.Tensor] = None,
336
+ # support controlnet
337
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
338
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
339
+ return_dict: bool = True,
340
+ ) -> Union[UNet3DConditionOutput, Tuple]:
341
+ r"""
342
+ Args:
343
+ sample (`torch.FloatTensor`): (batch, channel, num_frames, height, width) noisy inputs tensor
344
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
345
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
346
+ return_dict (`bool`, *optional*, defaults to `True`):
347
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
348
+
349
+ Returns:
350
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
351
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
352
+ returning a tuple, the first element is the sample tensor.
353
+ """
354
+ # By default samples have to be at least a multiple of the overall upsampling factor.
356
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
356
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
357
+ # on the fly if necessary.
358
+ default_overall_up_factor = 2**self.num_upsamplers
359
+
360
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
361
+ forward_upsample_size = False
362
+ upsample_size = None
363
+
364
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
365
+ logger.info("Forward upsample size to force interpolation output size.")
366
+ forward_upsample_size = True
367
+
368
+ # prepare attention_mask
369
+ if attention_mask is not None:
370
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
371
+ attention_mask = attention_mask.unsqueeze(1)
372
+
373
+ # center input if necessary
374
+ if self.config.center_input_sample:
375
+ sample = 2 * sample - 1.0
376
+
377
+ # time
378
+ timesteps = timestep
379
+ if not torch.is_tensor(timesteps):
380
+ # This would be a good case for the `match` statement (Python 3.10+)
381
+ is_mps = sample.device.type == "mps"
382
+ if isinstance(timestep, float):
383
+ dtype = torch.float32 if is_mps else torch.float64
384
+ else:
385
+ dtype = torch.int32 if is_mps else torch.int64
386
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
387
+ elif len(timesteps.shape) == 0:
388
+ timesteps = timesteps[None].to(sample.device)
389
+
390
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
391
+ timesteps = timesteps.expand(sample.shape[0])
392
+
393
+ t_emb = self.time_proj(timesteps)
394
+
395
+ # timesteps does not contain any weights and will always return f32 tensors
396
+ # but time_embedding might actually be running in fp16. so we need to cast here.
397
+ # there might be better ways to encapsulate this.
398
+ t_emb = t_emb.to(dtype=self.dtype)
399
+ emb = self.time_embedding(t_emb)
400
+
401
+ if self.class_embedding is not None:
402
+ if class_labels is None:
403
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
404
+
405
+ if self.config.class_embed_type == "timestep":
406
+ class_labels = self.time_proj(class_labels)
407
+
408
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
409
+ emb = emb + class_emb
410
+
411
+ # pre-process
412
+ sample = self.conv_in(sample)
413
+
414
+ # down
415
+ down_block_res_samples = (sample,)
416
+ for downsample_block in self.down_blocks:
417
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
418
+ sample, res_samples = downsample_block(
419
+ hidden_states=sample,
420
+ temb=emb,
421
+ encoder_hidden_states=encoder_hidden_states,
422
+ attention_mask=attention_mask,
423
+ )
424
+ else:
425
+ sample, res_samples = downsample_block(
426
+ hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
427
+ )
428
+
429
+ down_block_res_samples += res_samples
430
+
431
+ # support controlnet
432
+ down_block_res_samples = list(down_block_res_samples)
433
+ if down_block_additional_residuals is not None:
434
+ for i, down_block_additional_residual in enumerate(down_block_additional_residuals):
435
+ if down_block_additional_residual.dim() == 4: # broadcast
436
+ down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
437
+ down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual
438
+
439
+ # mid
440
+ sample = self.mid_block(
441
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
442
+ )
443
+
444
+ # support controlnet
445
+ if mid_block_additional_residual is not None:
446
+ if mid_block_additional_residual.dim() == 4: # broadcast
447
+ mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2)
448
+ sample = sample + mid_block_additional_residual
449
+
450
+ # up
451
+ for i, upsample_block in enumerate(self.up_blocks):
452
+ is_final_block = i == len(self.up_blocks) - 1
453
+
454
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
455
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
456
+
457
+ # if we have not reached the final block and need to forward the
458
+ # upsample size, we do it here
459
+ if not is_final_block and forward_upsample_size:
460
+ upsample_size = down_block_res_samples[-1].shape[2:]
461
+
462
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
463
+ sample = upsample_block(
464
+ hidden_states=sample,
465
+ temb=emb,
466
+ res_hidden_states_tuple=res_samples,
467
+ encoder_hidden_states=encoder_hidden_states,
468
+ upsample_size=upsample_size,
469
+ attention_mask=attention_mask,
470
+ )
471
+ else:
472
+ sample = upsample_block(
473
+ hidden_states=sample,
474
+ temb=emb,
475
+ res_hidden_states_tuple=res_samples,
476
+ upsample_size=upsample_size,
477
+ encoder_hidden_states=encoder_hidden_states,
478
+ )
479
+
480
+ # post-process
481
+ sample = self.conv_norm_out(sample)
482
+ sample = self.conv_act(sample)
483
+ sample = self.conv_out(sample)
484
+
485
+ if not return_dict:
486
+ return (sample,)
487
+
488
+ return UNet3DConditionOutput(sample=sample)
489
+
490
+ def load_state_dict(self, state_dict, strict=True):
491
+ # If the loaded checkpoint's in_channels or out_channels are different from config
492
+ temp_state_dict = copy.deepcopy(state_dict)
493
+ if temp_state_dict["conv_in.weight"].shape[1] != self.config.in_channels:
494
+ del temp_state_dict["conv_in.weight"]
495
+ del temp_state_dict["conv_in.bias"]
496
+ if temp_state_dict["conv_out.weight"].shape[0] != self.config.out_channels:
497
+ del temp_state_dict["conv_out.weight"]
498
+ del temp_state_dict["conv_out.bias"]
499
+
500
+ # If the loaded checkpoint's cross_attention_dim is different from config
501
+ keys_to_remove = []
502
+ for key in temp_state_dict:
503
+ if "audio_cross_attn.attn.to_k." in key or "audio_cross_attn.attn.to_v." in key:
504
+ if temp_state_dict[key].shape[1] != self.config.cross_attention_dim:
505
+ keys_to_remove.append(key)
506
+
507
+ for key in keys_to_remove:
508
+ del temp_state_dict[key]
509
+
510
+ return super().load_state_dict(state_dict=temp_state_dict, strict=strict)
511
+
512
+ @classmethod
513
+ def from_pretrained(cls, model_config: dict, ckpt_path: str, device="cpu"):
514
+ unet = cls.from_config(model_config).to(device)
515
+ if ckpt_path != "":
516
+ zero_rank_log(logger, f"Load from checkpoint: {ckpt_path}")
517
+ ckpt = torch.load(ckpt_path, map_location=device)
518
+ if "global_step" in ckpt:
519
+ zero_rank_log(logger, f"resume from global_step: {ckpt['global_step']}")
520
+ resume_global_step = ckpt["global_step"]
521
+ else:
522
+ resume_global_step = 0
523
+ state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
524
+ unet.load_state_dict(state_dict, strict=False)
525
+ else:
526
+ resume_global_step = 0
527
+
528
+ return unet, resume_global_step
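The `set_attention_slice` docstring above describes how `"auto"`, `"max"`, or an integer is turned into one slice size per sliceable attention layer. The same resolution logic, extracted into a standalone sketch for clarity (not a replacement for the method itself):

from typing import List, Union

def resolve_slice_size(slice_size: Union[str, int, List[int]], sliceable_head_dims: List[int]) -> List[int]:
    num_layers = len(sliceable_head_dims)
    if slice_size == "auto":
        # half of each head dim: a speed/memory trade-off
        resolved = [dim // 2 for dim in sliceable_head_dims]
    elif slice_size == "max":
        # one slice at a time saves the most memory
        resolved = num_layers * [1]
    elif isinstance(slice_size, int):
        resolved = num_layers * [slice_size]
    else:
        resolved = list(slice_size)
    if len(resolved) != num_layers:
        raise ValueError(f"expected {num_layers} slice sizes, got {len(resolved)}")
    for size, dim in zip(resolved, sliceable_head_dims):
        if size > dim:
            raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
    return resolved

print(resolve_slice_size("auto", [8, 8, 16]))  # [4, 4, 8]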
latentsync/models/unet_blocks.py ADDED
@@ -0,0 +1,903 @@
1
+ # Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet_blocks.py
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from .attention import Transformer3DModel
7
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
8
+ from .motion_module import get_motion_module
9
+
10
+
11
+ def get_down_block(
12
+ down_block_type,
13
+ num_layers,
14
+ in_channels,
15
+ out_channels,
16
+ temb_channels,
17
+ add_downsample,
18
+ resnet_eps,
19
+ resnet_act_fn,
20
+ attn_num_head_channels,
21
+ resnet_groups=None,
22
+ cross_attention_dim=None,
23
+ downsample_padding=None,
24
+ dual_cross_attention=False,
25
+ use_linear_projection=False,
26
+ only_cross_attention=False,
27
+ upcast_attention=False,
28
+ resnet_time_scale_shift="default",
29
+ unet_use_cross_frame_attention=False,
30
+ unet_use_temporal_attention=False,
31
+ use_inflated_groupnorm=False,
32
+ use_motion_module=None,
33
+ motion_module_type=None,
34
+ motion_module_kwargs=None,
35
+ add_audio_layer=False,
36
+ audio_condition_method="cross_attn",
37
+ custom_audio_layer=False,
38
+ ):
39
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
40
+ if down_block_type == "DownBlock3D":
41
+ return DownBlock3D(
42
+ num_layers=num_layers,
43
+ in_channels=in_channels,
44
+ out_channels=out_channels,
45
+ temb_channels=temb_channels,
46
+ add_downsample=add_downsample,
47
+ resnet_eps=resnet_eps,
48
+ resnet_act_fn=resnet_act_fn,
49
+ resnet_groups=resnet_groups,
50
+ downsample_padding=downsample_padding,
51
+ resnet_time_scale_shift=resnet_time_scale_shift,
52
+ use_inflated_groupnorm=use_inflated_groupnorm,
53
+ use_motion_module=use_motion_module,
54
+ motion_module_type=motion_module_type,
55
+ motion_module_kwargs=motion_module_kwargs,
56
+ )
57
+ elif down_block_type == "CrossAttnDownBlock3D":
58
+ if cross_attention_dim is None:
59
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
60
+ return CrossAttnDownBlock3D(
61
+ num_layers=num_layers,
62
+ in_channels=in_channels,
63
+ out_channels=out_channels,
64
+ temb_channels=temb_channels,
65
+ add_downsample=add_downsample,
66
+ resnet_eps=resnet_eps,
67
+ resnet_act_fn=resnet_act_fn,
68
+ resnet_groups=resnet_groups,
69
+ downsample_padding=downsample_padding,
70
+ cross_attention_dim=cross_attention_dim,
71
+ attn_num_head_channels=attn_num_head_channels,
72
+ dual_cross_attention=dual_cross_attention,
73
+ use_linear_projection=use_linear_projection,
74
+ only_cross_attention=only_cross_attention,
75
+ upcast_attention=upcast_attention,
76
+ resnet_time_scale_shift=resnet_time_scale_shift,
77
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
78
+ unet_use_temporal_attention=unet_use_temporal_attention,
79
+ use_inflated_groupnorm=use_inflated_groupnorm,
80
+ use_motion_module=use_motion_module,
81
+ motion_module_type=motion_module_type,
82
+ motion_module_kwargs=motion_module_kwargs,
83
+ add_audio_layer=add_audio_layer,
84
+ audio_condition_method=audio_condition_method,
85
+ custom_audio_layer=custom_audio_layer,
86
+ )
87
+ raise ValueError(f"{down_block_type} does not exist.")
88
+
89
+
90
+ def get_up_block(
91
+ up_block_type,
92
+ num_layers,
93
+ in_channels,
94
+ out_channels,
95
+ prev_output_channel,
96
+ temb_channels,
97
+ add_upsample,
98
+ resnet_eps,
99
+ resnet_act_fn,
100
+ attn_num_head_channels,
101
+ resnet_groups=None,
102
+ cross_attention_dim=None,
103
+ dual_cross_attention=False,
104
+ use_linear_projection=False,
105
+ only_cross_attention=False,
106
+ upcast_attention=False,
107
+ resnet_time_scale_shift="default",
108
+ unet_use_cross_frame_attention=False,
109
+ unet_use_temporal_attention=False,
110
+ use_inflated_groupnorm=False,
111
+ use_motion_module=None,
112
+ motion_module_type=None,
113
+ motion_module_kwargs=None,
114
+ add_audio_layer=False,
115
+ audio_condition_method="cross_attn",
116
+ custom_audio_layer=False,
117
+ ):
118
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
119
+ if up_block_type == "UpBlock3D":
120
+ return UpBlock3D(
121
+ num_layers=num_layers,
122
+ in_channels=in_channels,
123
+ out_channels=out_channels,
124
+ prev_output_channel=prev_output_channel,
125
+ temb_channels=temb_channels,
126
+ add_upsample=add_upsample,
127
+ resnet_eps=resnet_eps,
128
+ resnet_act_fn=resnet_act_fn,
129
+ resnet_groups=resnet_groups,
130
+ resnet_time_scale_shift=resnet_time_scale_shift,
131
+ use_inflated_groupnorm=use_inflated_groupnorm,
132
+ use_motion_module=use_motion_module,
133
+ motion_module_type=motion_module_type,
134
+ motion_module_kwargs=motion_module_kwargs,
135
+ )
136
+ elif up_block_type == "CrossAttnUpBlock3D":
137
+ if cross_attention_dim is None:
138
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
139
+ return CrossAttnUpBlock3D(
140
+ num_layers=num_layers,
141
+ in_channels=in_channels,
142
+ out_channels=out_channels,
143
+ prev_output_channel=prev_output_channel,
144
+ temb_channels=temb_channels,
145
+ add_upsample=add_upsample,
146
+ resnet_eps=resnet_eps,
147
+ resnet_act_fn=resnet_act_fn,
148
+ resnet_groups=resnet_groups,
149
+ cross_attention_dim=cross_attention_dim,
150
+ attn_num_head_channels=attn_num_head_channels,
151
+ dual_cross_attention=dual_cross_attention,
152
+ use_linear_projection=use_linear_projection,
153
+ only_cross_attention=only_cross_attention,
154
+ upcast_attention=upcast_attention,
155
+ resnet_time_scale_shift=resnet_time_scale_shift,
156
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
157
+ unet_use_temporal_attention=unet_use_temporal_attention,
158
+ use_inflated_groupnorm=use_inflated_groupnorm,
159
+ use_motion_module=use_motion_module,
160
+ motion_module_type=motion_module_type,
161
+ motion_module_kwargs=motion_module_kwargs,
162
+ add_audio_layer=add_audio_layer,
163
+ audio_condition_method=audio_condition_method,
164
+ custom_audio_layer=custom_audio_layer,
165
+ )
166
+ raise ValueError(f"{up_block_type} does not exist.")
167
+
168
+
169
+ class UNetMidBlock3DCrossAttn(nn.Module):
170
+ def __init__(
171
+ self,
172
+ in_channels: int,
173
+ temb_channels: int,
174
+ dropout: float = 0.0,
175
+ num_layers: int = 1,
176
+ resnet_eps: float = 1e-6,
177
+ resnet_time_scale_shift: str = "default",
178
+ resnet_act_fn: str = "swish",
179
+ resnet_groups: int = 32,
180
+ resnet_pre_norm: bool = True,
181
+ attn_num_head_channels=1,
182
+ output_scale_factor=1.0,
183
+ cross_attention_dim=1280,
184
+ dual_cross_attention=False,
185
+ use_linear_projection=False,
186
+ upcast_attention=False,
187
+ unet_use_cross_frame_attention=False,
188
+ unet_use_temporal_attention=False,
189
+ use_inflated_groupnorm=False,
190
+ use_motion_module=None,
191
+ motion_module_type=None,
192
+ motion_module_kwargs=None,
193
+ add_audio_layer=False,
194
+ audio_condition_method="cross_attn",
195
+ custom_audio_layer: bool = False,
196
+ ):
197
+ super().__init__()
198
+
199
+ self.has_cross_attention = True
200
+ self.attn_num_head_channels = attn_num_head_channels
201
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
202
+
203
+ # there is always at least one resnet
204
+ resnets = [
205
+ ResnetBlock3D(
206
+ in_channels=in_channels,
207
+ out_channels=in_channels,
208
+ temb_channels=temb_channels,
209
+ eps=resnet_eps,
210
+ groups=resnet_groups,
211
+ dropout=dropout,
212
+ time_embedding_norm=resnet_time_scale_shift,
213
+ non_linearity=resnet_act_fn,
214
+ output_scale_factor=output_scale_factor,
215
+ pre_norm=resnet_pre_norm,
216
+ use_inflated_groupnorm=use_inflated_groupnorm,
217
+ )
218
+ ]
219
+ attentions = []
220
+ audio_attentions = []
221
+ motion_modules = []
222
+
223
+ for _ in range(num_layers):
224
+ if dual_cross_attention:
225
+ raise NotImplementedError
226
+ attentions.append(
227
+ Transformer3DModel(
228
+ attn_num_head_channels,
229
+ in_channels // attn_num_head_channels,
230
+ in_channels=in_channels,
231
+ num_layers=1,
232
+ cross_attention_dim=cross_attention_dim,
233
+ norm_num_groups=resnet_groups,
234
+ use_linear_projection=use_linear_projection,
235
+ upcast_attention=upcast_attention,
236
+ use_motion_module=use_motion_module,
237
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
238
+ unet_use_temporal_attention=unet_use_temporal_attention,
239
+ add_audio_layer=add_audio_layer,
240
+ audio_condition_method=audio_condition_method,
241
+ )
242
+ )
243
+ audio_attentions.append(
244
+ Transformer3DModel(
245
+ attn_num_head_channels,
246
+ in_channels // attn_num_head_channels,
247
+ in_channels=in_channels,
248
+ num_layers=1,
249
+ cross_attention_dim=cross_attention_dim,
250
+ norm_num_groups=resnet_groups,
251
+ use_linear_projection=use_linear_projection,
252
+ upcast_attention=upcast_attention,
253
+ use_motion_module=use_motion_module,
254
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
255
+ unet_use_temporal_attention=unet_use_temporal_attention,
256
+ add_audio_layer=add_audio_layer,
257
+ audio_condition_method=audio_condition_method,
258
+ custom_audio_layer=True,
259
+ )
260
+ if custom_audio_layer
261
+ else None
262
+ )
263
+ motion_modules.append(
264
+ get_motion_module(
265
+ in_channels=in_channels,
266
+ motion_module_type=motion_module_type,
267
+ motion_module_kwargs=motion_module_kwargs,
268
+ )
269
+ if use_motion_module
270
+ else None
271
+ )
272
+ resnets.append(
273
+ ResnetBlock3D(
274
+ in_channels=in_channels,
275
+ out_channels=in_channels,
276
+ temb_channels=temb_channels,
277
+ eps=resnet_eps,
278
+ groups=resnet_groups,
279
+ dropout=dropout,
280
+ time_embedding_norm=resnet_time_scale_shift,
281
+ non_linearity=resnet_act_fn,
282
+ output_scale_factor=output_scale_factor,
283
+ pre_norm=resnet_pre_norm,
284
+ use_inflated_groupnorm=use_inflated_groupnorm,
285
+ )
286
+ )
287
+
288
+ self.attentions = nn.ModuleList(attentions)
289
+ self.audio_attentions = nn.ModuleList(audio_attentions)
290
+ self.resnets = nn.ModuleList(resnets)
291
+ self.motion_modules = nn.ModuleList(motion_modules)
292
+
293
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
294
+ hidden_states = self.resnets[0](hidden_states, temb)
295
+ for attn, audio_attn, resnet, motion_module in zip(
296
+ self.attentions, self.audio_attentions, self.resnets[1:], self.motion_modules
297
+ ):
298
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
299
+ hidden_states = (
300
+ audio_attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
301
+ if audio_attn is not None
302
+ else hidden_states
303
+ )
304
+ hidden_states = (
305
+ motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
306
+ if motion_module is not None
307
+ else hidden_states
308
+ )
309
+ hidden_states = resnet(hidden_states, temb)
310
+
311
+ return hidden_states
312
+
313
+
314
+ class CrossAttnDownBlock3D(nn.Module):
315
+ def __init__(
316
+ self,
317
+ in_channels: int,
318
+ out_channels: int,
319
+ temb_channels: int,
320
+ dropout: float = 0.0,
321
+ num_layers: int = 1,
322
+ resnet_eps: float = 1e-6,
323
+ resnet_time_scale_shift: str = "default",
324
+ resnet_act_fn: str = "swish",
325
+ resnet_groups: int = 32,
326
+ resnet_pre_norm: bool = True,
327
+ attn_num_head_channels=1,
328
+ cross_attention_dim=1280,
329
+ output_scale_factor=1.0,
330
+ downsample_padding=1,
331
+ add_downsample=True,
332
+ dual_cross_attention=False,
333
+ use_linear_projection=False,
334
+ only_cross_attention=False,
335
+ upcast_attention=False,
336
+ unet_use_cross_frame_attention=False,
337
+ unet_use_temporal_attention=False,
338
+ use_inflated_groupnorm=False,
339
+ use_motion_module=None,
340
+ motion_module_type=None,
341
+ motion_module_kwargs=None,
342
+ add_audio_layer=False,
343
+ audio_condition_method="cross_attn",
344
+ custom_audio_layer: bool = False,
345
+ ):
346
+ super().__init__()
347
+ resnets = []
348
+ attentions = []
349
+ audio_attentions = []
350
+ motion_modules = []
351
+
352
+ self.has_cross_attention = True
353
+ self.attn_num_head_channels = attn_num_head_channels
354
+
355
+ for i in range(num_layers):
356
+ in_channels = in_channels if i == 0 else out_channels
357
+ resnets.append(
358
+ ResnetBlock3D(
359
+ in_channels=in_channels,
360
+ out_channels=out_channels,
361
+ temb_channels=temb_channels,
362
+ eps=resnet_eps,
363
+ groups=resnet_groups,
364
+ dropout=dropout,
365
+ time_embedding_norm=resnet_time_scale_shift,
366
+ non_linearity=resnet_act_fn,
367
+ output_scale_factor=output_scale_factor,
368
+ pre_norm=resnet_pre_norm,
369
+ use_inflated_groupnorm=use_inflated_groupnorm,
370
+ )
371
+ )
372
+ if dual_cross_attention:
373
+ raise NotImplementedError
374
+ attentions.append(
375
+ Transformer3DModel(
376
+ attn_num_head_channels,
377
+ out_channels // attn_num_head_channels,
378
+ in_channels=out_channels,
379
+ num_layers=1,
380
+ cross_attention_dim=cross_attention_dim,
381
+ norm_num_groups=resnet_groups,
382
+ use_linear_projection=use_linear_projection,
383
+ only_cross_attention=only_cross_attention,
384
+ upcast_attention=upcast_attention,
385
+ use_motion_module=use_motion_module,
386
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
387
+ unet_use_temporal_attention=unet_use_temporal_attention,
388
+ add_audio_layer=add_audio_layer,
389
+ audio_condition_method=audio_condition_method,
390
+ )
391
+ )
392
+ audio_attentions.append(
393
+ Transformer3DModel(
394
+ attn_num_head_channels,
395
+ out_channels // attn_num_head_channels,
396
+ in_channels=out_channels,
397
+ num_layers=1,
398
+ cross_attention_dim=cross_attention_dim,
399
+ norm_num_groups=resnet_groups,
400
+ use_linear_projection=use_linear_projection,
401
+ only_cross_attention=only_cross_attention,
402
+ upcast_attention=upcast_attention,
403
+ use_motion_module=use_motion_module,
404
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
405
+ unet_use_temporal_attention=unet_use_temporal_attention,
406
+ add_audio_layer=add_audio_layer,
407
+ audio_condition_method=audio_condition_method,
408
+ custom_audio_layer=True,
409
+ )
410
+ if custom_audio_layer
411
+ else None
412
+ )
413
+ motion_modules.append(
414
+ get_motion_module(
415
+ in_channels=out_channels,
416
+ motion_module_type=motion_module_type,
417
+ motion_module_kwargs=motion_module_kwargs,
418
+ )
419
+ if use_motion_module
420
+ else None
421
+ )
422
+
423
+ self.attentions = nn.ModuleList(attentions)
424
+ self.audio_attentions = nn.ModuleList(audio_attentions)
425
+ self.resnets = nn.ModuleList(resnets)
426
+ self.motion_modules = nn.ModuleList(motion_modules)
427
+
428
+ if add_downsample:
429
+ self.downsamplers = nn.ModuleList(
430
+ [
431
+ Downsample3D(
432
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
433
+ )
434
+ ]
435
+ )
436
+ else:
437
+ self.downsamplers = None
438
+
439
+ self.gradient_checkpointing = False
440
+
441
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
442
+ output_states = ()
443
+
444
+ for resnet, attn, audio_attn, motion_module in zip(
445
+ self.resnets, self.attentions, self.audio_attentions, self.motion_modules
446
+ ):
447
+ if self.training and self.gradient_checkpointing:
448
+
449
+ def create_custom_forward(module, return_dict=None):
450
+ def custom_forward(*inputs):
451
+ if return_dict is not None:
452
+ return module(*inputs, return_dict=return_dict)
453
+ else:
454
+ return module(*inputs)
455
+
456
+ return custom_forward
457
+
458
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
459
+ hidden_states = torch.utils.checkpoint.checkpoint(
460
+ create_custom_forward(attn, return_dict=False),
461
+ hidden_states,
462
+ encoder_hidden_states,
463
+ )[0]
464
+ if motion_module is not None:
465
+ hidden_states = torch.utils.checkpoint.checkpoint(
466
+ create_custom_forward(motion_module),
467
+ hidden_states.requires_grad_(),
468
+ temb,
469
+ encoder_hidden_states,
470
+ )
471
+
472
+ else:
473
+ hidden_states = resnet(hidden_states, temb)
474
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
475
+
476
+ hidden_states = (
477
+ audio_attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
478
+ if audio_attn is not None
479
+ else hidden_states
480
+ )
481
+
482
+ # add motion module
483
+ hidden_states = (
484
+ motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
485
+ if motion_module is not None
486
+ else hidden_states
487
+ )
488
+
489
+ output_states += (hidden_states,)
490
+
491
+ if self.downsamplers is not None:
492
+ for downsampler in self.downsamplers:
493
+ hidden_states = downsampler(hidden_states)
494
+
495
+ output_states += (hidden_states,)
496
+
497
+ return hidden_states, output_states
498
+
499
+
500
+ class DownBlock3D(nn.Module):
501
+ def __init__(
502
+ self,
503
+ in_channels: int,
504
+ out_channels: int,
505
+ temb_channels: int,
506
+ dropout: float = 0.0,
507
+ num_layers: int = 1,
508
+ resnet_eps: float = 1e-6,
509
+ resnet_time_scale_shift: str = "default",
510
+ resnet_act_fn: str = "swish",
511
+ resnet_groups: int = 32,
512
+ resnet_pre_norm: bool = True,
513
+ output_scale_factor=1.0,
514
+ add_downsample=True,
515
+ downsample_padding=1,
516
+ use_inflated_groupnorm=False,
517
+ use_motion_module=None,
518
+ motion_module_type=None,
519
+ motion_module_kwargs=None,
520
+ ):
521
+ super().__init__()
522
+ resnets = []
523
+ motion_modules = []
524
+
525
+ for i in range(num_layers):
526
+ in_channels = in_channels if i == 0 else out_channels
527
+ resnets.append(
528
+ ResnetBlock3D(
529
+ in_channels=in_channels,
530
+ out_channels=out_channels,
531
+ temb_channels=temb_channels,
532
+ eps=resnet_eps,
533
+ groups=resnet_groups,
534
+ dropout=dropout,
535
+ time_embedding_norm=resnet_time_scale_shift,
536
+ non_linearity=resnet_act_fn,
537
+ output_scale_factor=output_scale_factor,
538
+ pre_norm=resnet_pre_norm,
539
+ use_inflated_groupnorm=use_inflated_groupnorm,
540
+ )
541
+ )
542
+ motion_modules.append(
543
+ get_motion_module(
544
+ in_channels=out_channels,
545
+ motion_module_type=motion_module_type,
546
+ motion_module_kwargs=motion_module_kwargs,
547
+ )
548
+ if use_motion_module
549
+ else None
550
+ )
551
+
552
+ self.resnets = nn.ModuleList(resnets)
553
+ self.motion_modules = nn.ModuleList(motion_modules)
554
+
555
+ if add_downsample:
556
+ self.downsamplers = nn.ModuleList(
557
+ [
558
+ Downsample3D(
559
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
560
+ )
561
+ ]
562
+ )
563
+ else:
564
+ self.downsamplers = None
565
+
566
+ self.gradient_checkpointing = False
567
+
568
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
569
+ output_states = ()
570
+
571
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
572
+ if self.training and self.gradient_checkpointing:
573
+
574
+ def create_custom_forward(module):
575
+ def custom_forward(*inputs):
576
+ return module(*inputs)
577
+
578
+ return custom_forward
579
+
580
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
581
+ if motion_module is not None:
582
+ hidden_states = torch.utils.checkpoint.checkpoint(
583
+ create_custom_forward(motion_module),
584
+ hidden_states.requires_grad_(),
585
+ temb,
586
+ encoder_hidden_states,
587
+ )
588
+ else:
589
+ hidden_states = resnet(hidden_states, temb)
590
+
591
+ # add motion module
592
+ hidden_states = (
593
+ motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
594
+ if motion_module is not None
595
+ else hidden_states
596
+ )
597
+
598
+ output_states += (hidden_states,)
599
+
600
+ if self.downsamplers is not None:
601
+ for downsampler in self.downsamplers:
602
+ hidden_states = downsampler(hidden_states)
603
+
604
+ output_states += (hidden_states,)
605
+
606
+ return hidden_states, output_states
607
+
608
+
609
+ class CrossAttnUpBlock3D(nn.Module):
610
+ def __init__(
611
+ self,
612
+ in_channels: int,
613
+ out_channels: int,
614
+ prev_output_channel: int,
615
+ temb_channels: int,
616
+ dropout: float = 0.0,
617
+ num_layers: int = 1,
618
+ resnet_eps: float = 1e-6,
619
+ resnet_time_scale_shift: str = "default",
620
+ resnet_act_fn: str = "swish",
621
+ resnet_groups: int = 32,
622
+ resnet_pre_norm: bool = True,
623
+ attn_num_head_channels=1,
624
+ cross_attention_dim=1280,
625
+ output_scale_factor=1.0,
626
+ add_upsample=True,
627
+ dual_cross_attention=False,
628
+ use_linear_projection=False,
629
+ only_cross_attention=False,
630
+ upcast_attention=False,
631
+ unet_use_cross_frame_attention=False,
632
+ unet_use_temporal_attention=False,
633
+ use_inflated_groupnorm=False,
634
+ use_motion_module=None,
635
+ motion_module_type=None,
636
+ motion_module_kwargs=None,
637
+ add_audio_layer=False,
638
+ audio_condition_method="cross_attn",
639
+ custom_audio_layer=False,
640
+ ):
641
+ super().__init__()
642
+ resnets = []
643
+ attentions = []
644
+ audio_attentions = []
645
+ motion_modules = []
646
+
647
+ self.has_cross_attention = True
648
+ self.attn_num_head_channels = attn_num_head_channels
649
+
650
+ for i in range(num_layers):
651
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
652
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
653
+
654
+ resnets.append(
655
+ ResnetBlock3D(
656
+ in_channels=resnet_in_channels + res_skip_channels,
657
+ out_channels=out_channels,
658
+ temb_channels=temb_channels,
659
+ eps=resnet_eps,
660
+ groups=resnet_groups,
661
+ dropout=dropout,
662
+ time_embedding_norm=resnet_time_scale_shift,
663
+ non_linearity=resnet_act_fn,
664
+ output_scale_factor=output_scale_factor,
665
+ pre_norm=resnet_pre_norm,
666
+ use_inflated_groupnorm=use_inflated_groupnorm,
667
+ )
668
+ )
669
+ if dual_cross_attention:
670
+ raise NotImplementedError
671
+ attentions.append(
672
+ Transformer3DModel(
673
+ attn_num_head_channels,
674
+ out_channels // attn_num_head_channels,
675
+ in_channels=out_channels,
676
+ num_layers=1,
677
+ cross_attention_dim=cross_attention_dim,
678
+ norm_num_groups=resnet_groups,
679
+ use_linear_projection=use_linear_projection,
680
+ only_cross_attention=only_cross_attention,
681
+ upcast_attention=upcast_attention,
682
+ use_motion_module=use_motion_module,
683
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
684
+ unet_use_temporal_attention=unet_use_temporal_attention,
685
+ add_audio_layer=add_audio_layer,
686
+ audio_condition_method=audio_condition_method,
687
+ )
688
+ )
689
+ audio_attentions.append(
690
+ Transformer3DModel(
691
+ attn_num_head_channels,
692
+ out_channels // attn_num_head_channels,
693
+ in_channels=out_channels,
694
+ num_layers=1,
695
+ cross_attention_dim=cross_attention_dim,
696
+ norm_num_groups=resnet_groups,
697
+ use_linear_projection=use_linear_projection,
698
+ only_cross_attention=only_cross_attention,
699
+ upcast_attention=upcast_attention,
700
+ use_motion_module=use_motion_module,
701
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
702
+ unet_use_temporal_attention=unet_use_temporal_attention,
703
+ add_audio_layer=add_audio_layer,
704
+ audio_condition_method=audio_condition_method,
705
+ custom_audio_layer=True,
706
+ )
707
+ if custom_audio_layer
708
+ else None
709
+ )
710
+ motion_modules.append(
711
+ get_motion_module(
712
+ in_channels=out_channels,
713
+ motion_module_type=motion_module_type,
714
+ motion_module_kwargs=motion_module_kwargs,
715
+ )
716
+ if use_motion_module
717
+ else None
718
+ )
719
+
720
+ self.attentions = nn.ModuleList(attentions)
721
+ self.audio_attentions = nn.ModuleList(audio_attentions)
722
+ self.resnets = nn.ModuleList(resnets)
723
+ self.motion_modules = nn.ModuleList(motion_modules)
724
+
725
+ if add_upsample:
726
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
727
+ else:
728
+ self.upsamplers = None
729
+
730
+ self.gradient_checkpointing = False
731
+
732
+ def forward(
733
+ self,
734
+ hidden_states,
735
+ res_hidden_states_tuple,
736
+ temb=None,
737
+ encoder_hidden_states=None,
738
+ upsample_size=None,
739
+ attention_mask=None,
740
+ ):
741
+ for resnet, attn, audio_attn, motion_module in zip(
742
+ self.resnets, self.attentions, self.audio_attentions, self.motion_modules
743
+ ):
744
+ # pop res hidden states
745
+ res_hidden_states = res_hidden_states_tuple[-1]
746
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
747
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
748
+
749
+ if self.training and self.gradient_checkpointing:
750
+
751
+ def create_custom_forward(module, return_dict=None):
752
+ def custom_forward(*inputs):
753
+ if return_dict is not None:
754
+ return module(*inputs, return_dict=return_dict)
755
+ else:
756
+ return module(*inputs)
757
+
758
+ return custom_forward
759
+
760
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
761
+ hidden_states = torch.utils.checkpoint.checkpoint(
762
+ create_custom_forward(attn, return_dict=False),
763
+ hidden_states,
764
+ encoder_hidden_states,
765
+ )[0]
766
+ if motion_module is not None:
767
+ hidden_states = torch.utils.checkpoint.checkpoint(
768
+ create_custom_forward(motion_module),
769
+ hidden_states.requires_grad_(),
770
+ temb,
771
+ encoder_hidden_states,
772
+ )
773
+
774
+ else:
775
+ hidden_states = resnet(hidden_states, temb)
776
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
777
+ hidden_states = (
778
+ audio_attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
779
+ if audio_attn is not None
780
+ else hidden_states
781
+ )
782
+
783
+ # add motion module
784
+ hidden_states = (
785
+ motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
786
+ if motion_module is not None
787
+ else hidden_states
788
+ )
789
+
790
+ if self.upsamplers is not None:
791
+ for upsampler in self.upsamplers:
792
+ hidden_states = upsampler(hidden_states, upsample_size)
793
+
794
+ return hidden_states
795
+
796
+
797
+ class UpBlock3D(nn.Module):
798
+ def __init__(
799
+ self,
800
+ in_channels: int,
801
+ prev_output_channel: int,
802
+ out_channels: int,
803
+ temb_channels: int,
804
+ dropout: float = 0.0,
805
+ num_layers: int = 1,
806
+ resnet_eps: float = 1e-6,
807
+ resnet_time_scale_shift: str = "default",
808
+ resnet_act_fn: str = "swish",
809
+ resnet_groups: int = 32,
810
+ resnet_pre_norm: bool = True,
811
+ output_scale_factor=1.0,
812
+ add_upsample=True,
813
+ use_inflated_groupnorm=False,
814
+ use_motion_module=None,
815
+ motion_module_type=None,
816
+ motion_module_kwargs=None,
817
+ ):
818
+ super().__init__()
819
+ resnets = []
820
+ motion_modules = []
821
+
822
+ for i in range(num_layers):
823
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
824
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
825
+
826
+ resnets.append(
827
+ ResnetBlock3D(
828
+ in_channels=resnet_in_channels + res_skip_channels,
829
+ out_channels=out_channels,
830
+ temb_channels=temb_channels,
831
+ eps=resnet_eps,
832
+ groups=resnet_groups,
833
+ dropout=dropout,
834
+ time_embedding_norm=resnet_time_scale_shift,
835
+ non_linearity=resnet_act_fn,
836
+ output_scale_factor=output_scale_factor,
837
+ pre_norm=resnet_pre_norm,
838
+ use_inflated_groupnorm=use_inflated_groupnorm,
839
+ )
840
+ )
841
+ motion_modules.append(
842
+ get_motion_module(
843
+ in_channels=out_channels,
844
+ motion_module_type=motion_module_type,
845
+ motion_module_kwargs=motion_module_kwargs,
846
+ )
847
+ if use_motion_module
848
+ else None
849
+ )
850
+
851
+ self.resnets = nn.ModuleList(resnets)
852
+ self.motion_modules = nn.ModuleList(motion_modules)
853
+
854
+ if add_upsample:
855
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
856
+ else:
857
+ self.upsamplers = None
858
+
859
+ self.gradient_checkpointing = False
860
+
861
+ def forward(
862
+ self,
863
+ hidden_states,
864
+ res_hidden_states_tuple,
865
+ temb=None,
866
+ upsample_size=None,
867
+ encoder_hidden_states=None,
868
+ ):
869
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
870
+ # pop res hidden states
871
+ res_hidden_states = res_hidden_states_tuple[-1]
872
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
873
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
874
+
875
+ if self.training and self.gradient_checkpointing:
876
+
877
+ def create_custom_forward(module):
878
+ def custom_forward(*inputs):
879
+ return module(*inputs)
880
+
881
+ return custom_forward
882
+
883
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
884
+ if motion_module is not None:
885
+ hidden_states = torch.utils.checkpoint.checkpoint(
886
+ create_custom_forward(motion_module),
887
+ hidden_states.requires_grad_(),
888
+ temb,
889
+ encoder_hidden_states,
890
+ )
891
+ else:
892
+ hidden_states = resnet(hidden_states, temb)
893
+ hidden_states = (
894
+ motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
895
+ if motion_module is not None
896
+ else hidden_states
897
+ )
898
+
899
+ if self.upsamplers is not None:
900
+ for upsampler in self.upsamplers:
901
+ hidden_states = upsampler(hidden_states, upsample_size)
902
+
903
+ return hidden_states
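Note on the blocks above (an editorial aside, not part of this commit): every down/up block uses the same gradient-checkpointing idiom. Each resnet, attention, and motion module is wrapped in a small `create_custom_forward` closure and run through `torch.utils.checkpoint.checkpoint` when `self.gradient_checkpointing` is set during training, so intermediate activations are recomputed in the backward pass instead of stored. A minimal, self-contained sketch of that idiom with a stand-in module (`use_reentrant=False` assumes a recent PyTorch):

```python
# Editorial sketch only; `block` is a stand-in, not one of the repository's modules.
import torch
import torch.nn as nn

def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
x = torch.randn(2, 64, requires_grad=True)

# Activations inside `block` are recomputed during backward instead of being kept,
# trading extra compute for lower memory, the same trade-off the 3D UNet blocks make.
y = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
y.sum().backward()
```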
latentsync/models/utils.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ def zero_module(module):
16
+ # Zero out the parameters of a module and return it.
17
+ for p in module.parameters():
18
+ p.detach().zero_()
19
+ return module
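Note (not part of this commit): `zero_module` is the usual trick for grafting new layers onto a pretrained network. Zeroing the parameters of the output projection of a newly added branch (for example a motion or audio attention layer) makes that branch a no-op at initialization, so training starts from the behavior of the pretrained 2D UNet. A hedged sketch; `proj_out` below is an illustrative layer, not one from this repository:

```python
# Illustrative only: zero-init the output projection of a new residual branch.
import torch
import torch.nn as nn

def zero_module(module):  # same helper as defined above
    for p in module.parameters():
        p.detach().zero_()
    return module

proj_out = zero_module(nn.Conv2d(320, 320, kernel_size=1))  # hypothetical new layer

x = torch.randn(1, 320, 8, 8)
residual = proj_out(x)                 # all zeros at initialization
print(torch.count_nonzero(residual))   # tensor(0)
```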
latentsync/pipelines/lipsync_pipeline.py ADDED
@@ -0,0 +1,470 @@
1
+ # Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/pipelines/pipeline_animation.py
2
+
3
+ import inspect
4
+ import os
5
+ import shutil
6
+ from typing import Callable, List, Optional, Union
7
+ import subprocess
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torchvision
12
+
13
+ from diffusers.utils import is_accelerate_available
14
+ from packaging import version
15
+
16
+ from diffusers.configuration_utils import FrozenDict
17
+ from diffusers.models import AutoencoderKL
18
+ from diffusers.pipeline_utils import DiffusionPipeline
19
+ from diffusers.schedulers import (
20
+ DDIMScheduler,
21
+ DPMSolverMultistepScheduler,
22
+ EulerAncestralDiscreteScheduler,
23
+ EulerDiscreteScheduler,
24
+ LMSDiscreteScheduler,
25
+ PNDMScheduler,
26
+ )
27
+ from diffusers.utils import deprecate, logging
28
+
29
+ from einops import rearrange
30
+
31
+ from ..models.unet import UNet3DConditionModel
32
+ from ..utils.image_processor import ImageProcessor
33
+ from ..utils.util import read_video, read_audio, write_video
34
+ from ..whisper.audio2feature import Audio2Feature
35
+ import tqdm
36
+ import soundfile as sf
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
41
+ class LipsyncPipeline(DiffusionPipeline):
42
+ _optional_components = []
43
+
44
+ def __init__(
45
+ self,
46
+ vae: AutoencoderKL,
47
+ audio_encoder: Audio2Feature,
48
+ unet: UNet3DConditionModel,
49
+ scheduler: Union[
50
+ DDIMScheduler,
51
+ PNDMScheduler,
52
+ LMSDiscreteScheduler,
53
+ EulerDiscreteScheduler,
54
+ EulerAncestralDiscreteScheduler,
55
+ DPMSolverMultistepScheduler,
56
+ ],
57
+ ):
58
+ super().__init__()
59
+
60
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
61
+ deprecation_message = (
62
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
63
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
64
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
65
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
66
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
67
+ " file"
68
+ )
69
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
70
+ new_config = dict(scheduler.config)
71
+ new_config["steps_offset"] = 1
72
+ scheduler._internal_dict = FrozenDict(new_config)
73
+
74
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
75
+ deprecation_message = (
76
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
77
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
78
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
79
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
80
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
81
+ )
82
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
83
+ new_config = dict(scheduler.config)
84
+ new_config["clip_sample"] = False
85
+ scheduler._internal_dict = FrozenDict(new_config)
86
+
87
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
88
+ version.parse(unet.config._diffusers_version).base_version
89
+ ) < version.parse("0.9.0.dev0")
90
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
91
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
92
+ deprecation_message = (
93
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
94
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
95
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
96
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
97
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
98
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
99
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
100
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
101
+ " the `unet/config.json` file"
102
+ )
103
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
104
+ new_config = dict(unet.config)
105
+ new_config["sample_size"] = 64
106
+ unet._internal_dict = FrozenDict(new_config)
107
+
108
+ self.register_modules(
109
+ vae=vae,
110
+ audio_encoder=audio_encoder,
111
+ unet=unet,
112
+ scheduler=scheduler,
113
+ )
114
+
115
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
116
+
117
+ self.set_progress_bar_config(desc="Steps")
118
+
119
+ def enable_vae_slicing(self):
120
+ self.vae.enable_slicing()
121
+
122
+ def disable_vae_slicing(self):
123
+ self.vae.disable_slicing()
124
+
125
+ def enable_sequential_cpu_offload(self, gpu_id=0):
126
+ if is_accelerate_available():
127
+ from accelerate import cpu_offload
128
+ else:
129
+ raise ImportError("Please install accelerate via `pip install accelerate`")
130
+
131
+ device = torch.device(f"cuda:{gpu_id}")
132
+
133
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
134
+ if cpu_offloaded_model is not None:
135
+ cpu_offload(cpu_offloaded_model, device)
136
+
137
+ @property
138
+ def _execution_device(self):
139
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
140
+ return self.device
141
+ for module in self.unet.modules():
142
+ if (
143
+ hasattr(module, "_hf_hook")
144
+ and hasattr(module._hf_hook, "execution_device")
145
+ and module._hf_hook.execution_device is not None
146
+ ):
147
+ return torch.device(module._hf_hook.execution_device)
148
+ return self.device
149
+
150
+ def decode_latents(self, latents):
151
+ latents = latents / self.vae.config.scaling_factor + self.vae.config.shift_factor
152
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
153
+ decoded_latents = self.vae.decode(latents).sample
154
+ return decoded_latents
155
+
156
+ def prepare_extra_step_kwargs(self, generator, eta):
157
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
158
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
159
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
160
+ # and should be between [0, 1]
161
+
162
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
163
+ extra_step_kwargs = {}
164
+ if accepts_eta:
165
+ extra_step_kwargs["eta"] = eta
166
+
167
+ # check if the scheduler accepts generator
168
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
169
+ if accepts_generator:
170
+ extra_step_kwargs["generator"] = generator
171
+ return extra_step_kwargs
172
+
173
+ def check_inputs(self, height, width, callback_steps):
174
+ assert height == width, "Height and width must be equal"
175
+
176
+ if height % 8 != 0 or width % 8 != 0:
177
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
178
+
179
+ if (callback_steps is None) or (
180
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
181
+ ):
182
+ raise ValueError(
183
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
184
+ f" {type(callback_steps)}."
185
+ )
186
+
187
+ def prepare_latents(self, batch_size, num_frames, num_channels_latents, height, width, dtype, device, generator):
188
+ shape = (
189
+ batch_size,
190
+ num_channels_latents,
191
+ 1,
192
+ height // self.vae_scale_factor,
193
+ width // self.vae_scale_factor,
194
+ )
195
+ rand_device = "cpu" if device.type == "mps" else device
196
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
197
+ latents = latents.repeat(1, 1, num_frames, 1, 1)
198
+
199
+ # scale the initial noise by the standard deviation required by the scheduler
200
+ latents = latents * self.scheduler.init_noise_sigma
201
+ return latents
202
+
203
+ def prepare_mask_latents(
204
+ self, mask, masked_image, height, width, dtype, device, generator, do_classifier_free_guidance
205
+ ):
206
+ # resize the mask to latents shape as we concatenate the mask to the latents
207
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
208
+ # and half precision
209
+ mask = torch.nn.functional.interpolate(
210
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
211
+ )
212
+ masked_image = masked_image.to(device=device, dtype=dtype)
213
+
214
+ # encode the mask image into latents space so we can concatenate it to the latents
215
+ masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
216
+ masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
217
+
218
+ # aligning device to prevent device errors when concating it with the latent model input
219
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
220
+ mask = mask.to(device=device, dtype=dtype)
221
+
222
+ # assume batch size = 1
223
+ mask = rearrange(mask, "f c h w -> 1 c f h w")
224
+ masked_image_latents = rearrange(masked_image_latents, "f c h w -> 1 c f h w")
225
+
226
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
227
+ masked_image_latents = (
228
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
229
+ )
230
+ return mask, masked_image_latents
231
+
232
+ def prepare_image_latents(self, images, device, dtype, generator, do_classifier_free_guidance):
233
+ images = images.to(device=device, dtype=dtype)
234
+ image_latents = self.vae.encode(images).latent_dist.sample(generator=generator)
235
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
236
+ image_latents = rearrange(image_latents, "f c h w -> 1 c f h w")
237
+ image_latents = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
238
+
239
+ return image_latents
240
+
241
+ def set_progress_bar_config(self, **kwargs):
242
+ if not hasattr(self, "_progress_bar_config"):
243
+ self._progress_bar_config = {}
244
+ self._progress_bar_config.update(kwargs)
245
+
246
+ @staticmethod
247
+ def paste_surrounding_pixels_back(decoded_latents, pixel_values, masks, device, weight_dtype):
248
+ # Paste the surrounding pixels back, because we only want to change the mouth region
249
+ pixel_values = pixel_values.to(device=device, dtype=weight_dtype)
250
+ masks = masks.to(device=device, dtype=weight_dtype)
251
+ combined_pixel_values = decoded_latents * masks + pixel_values * (1 - masks)
252
+ return combined_pixel_values
253
+
254
+ @staticmethod
255
+ def pixel_values_to_images(pixel_values: torch.Tensor):
256
+ pixel_values = rearrange(pixel_values, "f c h w -> f h w c")
257
+ pixel_values = (pixel_values / 2 + 0.5).clamp(0, 1)
258
+ images = (pixel_values * 255).to(torch.uint8)
259
+ images = images.cpu().numpy()
260
+ return images
261
+
262
+ def affine_transform_video(self, video_path):
263
+ video_frames = read_video(video_path, use_decord=False)
264
+ faces = []
265
+ boxes = []
266
+ affine_matrices = []
267
+ print(f"Affine transforming {len(video_frames)} faces...")
268
+ for frame in tqdm.tqdm(video_frames):
269
+ face, box, affine_matrix = self.image_processor.affine_transform(frame)
270
+ faces.append(face)
271
+ boxes.append(box)
272
+ affine_matrices.append(affine_matrix)
273
+
274
+ faces = torch.stack(faces)
275
+ return faces, video_frames, boxes, affine_matrices
276
+
277
+ def restore_video(self, faces, video_frames, boxes, affine_matrices):
278
+ video_frames = video_frames[: faces.shape[0]]
279
+ out_frames = []
280
+ for index, face in enumerate(faces):
281
+ x1, y1, x2, y2 = boxes[index]
282
+ height = int(y2 - y1)
283
+ width = int(x2 - x1)
284
+ face = torchvision.transforms.functional.resize(face, size=(height, width), antialias=True)
285
+ face = rearrange(face, "c h w -> h w c")
286
+ face = (face / 2 + 0.5).clamp(0, 1)
287
+ face = (face * 255).to(torch.uint8).cpu().numpy()
288
+ out_frame = self.image_processor.restorer.restore_img(video_frames[index], face, affine_matrices[index])
289
+ out_frames.append(out_frame)
290
+ return np.stack(out_frames, axis=0)
291
+
292
+ @torch.no_grad()
293
+ def __call__(
294
+ self,
295
+ video_path: str,
296
+ audio_path: str,
297
+ video_out_path: str,
298
+ video_mask_path: str = None,
299
+ num_frames: int = 16,
300
+ video_fps: int = 25,
301
+ audio_sample_rate: int = 16000,
302
+ height: Optional[int] = None,
303
+ width: Optional[int] = None,
304
+ num_inference_steps: int = 20,
305
+ guidance_scale: float = 1.5,
306
+ weight_dtype: Optional[torch.dtype] = torch.float16,
307
+ eta: float = 0.0,
308
+ mask: str = "fix_mask",
309
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
310
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
311
+ callback_steps: Optional[int] = 1,
312
+ **kwargs,
313
+ ):
314
+ is_train = self.unet.training
315
+ self.unet.eval()
316
+
317
+ # 0. Define call parameters
318
+ batch_size = 1
319
+ device = self._execution_device
320
+ self.image_processor = ImageProcessor(height, mask=mask, device="cuda")
321
+ self.set_progress_bar_config(desc=f"Sample frames: {num_frames}")
322
+
323
+ video_frames, original_video_frames, boxes, affine_matrices = self.affine_transform_video(video_path)
324
+ audio_samples = read_audio(audio_path)
325
+
326
+ # 1. Default height and width to unet
327
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
328
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
329
+
330
+ # 2. Check inputs
331
+ self.check_inputs(height, width, callback_steps)
332
+
333
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
334
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
335
+ # corresponds to doing no classifier free guidance.
336
+ do_classifier_free_guidance = guidance_scale > 1.0
337
+
338
+ # 3. set timesteps
339
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
340
+ timesteps = self.scheduler.timesteps
341
+
342
+ # 4. Prepare extra step kwargs.
343
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
344
+
345
+ self.video_fps = video_fps
346
+
347
+ if self.unet.add_audio_layer:
348
+ whisper_feature = self.audio_encoder.audio2feat(audio_path)
349
+ whisper_chunks = self.audio_encoder.feature2chunks(feature_array=whisper_feature, fps=video_fps)
350
+
351
+ num_inferences = min(len(video_frames), len(whisper_chunks)) // num_frames
352
+ else:
353
+ num_inferences = len(video_frames) // num_frames
354
+
355
+ synced_video_frames = []
356
+ masked_video_frames = []
357
+
358
+ num_channels_latents = self.vae.config.latent_channels
359
+
360
+ # Prepare latent variables
361
+ all_latents = self.prepare_latents(
362
+ batch_size,
363
+ num_frames * num_inferences,
364
+ num_channels_latents,
365
+ height,
366
+ width,
367
+ weight_dtype,
368
+ device,
369
+ generator,
370
+ )
371
+
372
+ for i in tqdm.tqdm(range(num_inferences), desc="Doing inference..."):
373
+ if self.unet.add_audio_layer:
374
+ audio_embeds = torch.stack(whisper_chunks[i * num_frames : (i + 1) * num_frames])
375
+ audio_embeds = audio_embeds.to(device, dtype=weight_dtype)
376
+ if do_classifier_free_guidance:
377
+ empty_audio_embeds = torch.zeros_like(audio_embeds)
378
+ audio_embeds = torch.cat([empty_audio_embeds, audio_embeds])
379
+ else:
380
+ audio_embeds = None
381
+ inference_video_frames = video_frames[i * num_frames : (i + 1) * num_frames]
382
+ latents = all_latents[:, :, i * num_frames : (i + 1) * num_frames]
383
+ pixel_values, masked_pixel_values, masks = self.image_processor.prepare_masks_and_masked_images(
384
+ inference_video_frames, affine_transform=False
385
+ )
386
+
387
+ # 7. Prepare mask latent variables
388
+ mask_latents, masked_image_latents = self.prepare_mask_latents(
389
+ masks,
390
+ masked_pixel_values,
391
+ height,
392
+ width,
393
+ weight_dtype,
394
+ device,
395
+ generator,
396
+ do_classifier_free_guidance,
397
+ )
398
+
399
+ # 8. Prepare image latents
400
+ image_latents = self.prepare_image_latents(
401
+ pixel_values,
402
+ device,
403
+ weight_dtype,
404
+ generator,
405
+ do_classifier_free_guidance,
406
+ )
407
+
408
+ # 9. Denoising loop
409
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
410
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
411
+ for j, t in enumerate(timesteps):
412
+ # expand the latents if we are doing classifier free guidance
413
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
414
+
415
+ # concat latents, mask, masked_image_latents in the channel dimension
416
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
417
+ latent_model_input = torch.cat(
418
+ [latent_model_input, mask_latents, masked_image_latents, image_latents], dim=1
419
+ )
420
+
421
+ # predict the noise residual
422
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=audio_embeds).sample
423
+
424
+ # perform guidance
425
+ if do_classifier_free_guidance:
426
+ noise_pred_uncond, noise_pred_audio = noise_pred.chunk(2)
427
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_audio - noise_pred_uncond)
428
+
429
+ # compute the previous noisy sample x_t -> x_t-1
430
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
431
+
432
+ # call the callback, if provided
433
+ if j == len(timesteps) - 1 or ((j + 1) > num_warmup_steps and (j + 1) % self.scheduler.order == 0):
434
+ progress_bar.update()
435
+ if callback is not None and j % callback_steps == 0:
436
+ callback(j, t, latents)
437
+
438
+ # Recover the pixel values
439
+ decoded_latents = self.decode_latents(latents)
440
+ decoded_latents = self.paste_surrounding_pixels_back(
441
+ decoded_latents, pixel_values, 1 - masks, device, weight_dtype
442
+ )
443
+ synced_video_frames.append(decoded_latents)
444
+ masked_video_frames.append(masked_pixel_values)
445
+
446
+ synced_video_frames = self.restore_video(
447
+ torch.cat(synced_video_frames), original_video_frames, boxes, affine_matrices
448
+ )
449
+ masked_video_frames = self.restore_video(
450
+ torch.cat(masked_video_frames), original_video_frames, boxes, affine_matrices
451
+ )
452
+
453
+ audio_samples_remain_length = int(synced_video_frames.shape[0] / video_fps * audio_sample_rate)
454
+ audio_samples = audio_samples[:audio_samples_remain_length].cpu().numpy()
455
+
456
+ if is_train:
457
+ self.unet.train()
458
+
459
+ temp_dir = "temp"
460
+ if os.path.exists(temp_dir):
461
+ shutil.rmtree(temp_dir)
462
+ os.makedirs(temp_dir, exist_ok=True)
463
+
464
+ write_video(os.path.join(temp_dir, "video.mp4"), synced_video_frames, fps=25)
465
+ # write_video(video_mask_path, masked_video_frames, fps=25)
466
+
467
+ sf.write(os.path.join(temp_dir, "audio.wav"), audio_samples, audio_sample_rate)
468
+
469
+ command = f"ffmpeg -y -loglevel error -nostdin -i {os.path.join(temp_dir, 'video.mp4')} -i {os.path.join(temp_dir, 'audio.wav')} -c:v libx264 -c:a aac -q:v 0 -q:a 0 {video_out_path}"
470
+ subprocess.run(command, shell=True)
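Note (not part of this commit): when `guidance_scale > 1.0` the pipeline doubles the batch with empty audio embeddings, so each UNet call returns an unconditional and an audio-conditioned noise prediction; the two are then combined with standard classifier-free guidance. A self-contained illustration of that combination step (shapes are arbitrary here):

```python
# Editorial illustration of the guidance step in the denoising loop above.
import torch

guidance_scale = 1.5
noise_pred = torch.randn(2, 4, 16, 32, 32)  # [unconditional, audio-conditioned] stacked on dim 0

noise_pred_uncond, noise_pred_audio = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_audio - noise_pred_uncond)
print(noise_pred.shape)  # torch.Size([1, 4, 16, 32, 32])
```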
latentsync/trepa/__init__.py ADDED
@@ -0,0 +1,64 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ import torch.nn as nn
18
+ from einops import rearrange
19
+ from .third_party.VideoMAEv2.utils import load_videomae_model
20
+
21
+
22
+ class TREPALoss:
23
+ def __init__(
24
+ self,
25
+ device="cuda",
26
+ ckpt_path="/mnt/bn/maliva-gen-ai-v2/chunyu.li/checkpoints/vit_g_hybrid_pt_1200e_ssv2_ft.pth",
27
+ ):
28
+ self.model = load_videomae_model(device, ckpt_path).eval().to(dtype=torch.float16)
29
+ self.model.requires_grad_(False)
30
+ self.bce_loss = nn.BCELoss()
31
+
32
+ def __call__(self, videos_fake, videos_real, loss_type="mse"):
33
+ batch_size = videos_fake.shape[0]
34
+ num_frames = videos_fake.shape[2]
35
+ videos_fake = rearrange(videos_fake.clone(), "b c f h w -> (b f) c h w")
36
+ videos_real = rearrange(videos_real.clone(), "b c f h w -> (b f) c h w")
37
+
38
+ videos_fake = F.interpolate(videos_fake, size=(224, 224), mode="bilinear")
39
+ videos_real = F.interpolate(videos_real, size=(224, 224), mode="bilinear")
40
+
41
+ videos_fake = rearrange(videos_fake, "(b f) c h w -> b c f h w", f=num_frames)
42
+ videos_real = rearrange(videos_real, "(b f) c h w -> b c f h w", f=num_frames)
43
+
44
+ # The input pixel range is [-1, 1], but the model expects pixels in [0, 1]
45
+ videos_fake = (videos_fake / 2 + 0.5).clamp(0, 1)
46
+ videos_real = (videos_real / 2 + 0.5).clamp(0, 1)
47
+
48
+ feats_fake = self.model.forward_features(videos_fake)
49
+ feats_real = self.model.forward_features(videos_real)
50
+
51
+ feats_fake = F.normalize(feats_fake, p=2, dim=1)
52
+ feats_real = F.normalize(feats_real, p=2, dim=1)
53
+
54
+ return F.mse_loss(feats_fake, feats_real)
55
+
56
+
57
+ if __name__ == "__main__":
58
+ # input shape: (b, c, f, h, w)
59
+ videos_fake = torch.randn(2, 3, 16, 256, 256, requires_grad=True).to(device="cuda", dtype=torch.float16)
60
+ videos_real = torch.randn(2, 3, 16, 256, 256, requires_grad=True).to(device="cuda", dtype=torch.float16)
61
+
62
+ trepa_loss = TREPALoss(device="cuda")
63
+ loss = trepa_loss(videos_fake, videos_real)
64
+ print(loss)
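Note (not part of this commit): because both feature sets are L2-normalized before `F.mse_loss`, the TREPA loss is, up to a constant factor, a cosine distance between VideoMAE features of the real and generated clips. A quick check of the underlying identity ||u - v||^2 = 2 - 2 cos(u, v) for unit vectors:

```python
# Sanity check of the identity behind the normalized-feature MSE (1408 is the
# ViT-giant embedding dimension used by the backbone).
import torch
import torch.nn.functional as F

u = F.normalize(torch.randn(4, 1408), p=2, dim=1)
v = F.normalize(torch.randn(4, 1408), p=2, dim=1)

lhs = ((u - v) ** 2).sum(dim=1)
rhs = 2 - 2 * F.cosine_similarity(u, v, dim=1)
print(torch.allclose(lhs, rhs, atol=1e-5))  # True
```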
latentsync/trepa/third_party/VideoMAEv2/__init__.py ADDED
File without changes
latentsync/trepa/third_party/VideoMAEv2/utils.py ADDED
@@ -0,0 +1,81 @@
1
+ import os
2
+ import torch
3
+ import requests
4
+ from tqdm import tqdm
5
+ from torchvision import transforms
6
+ from .videomaev2_finetune import vit_giant_patch14_224
7
+
8
+ def to_normalized_float_tensor(vid):
9
+ return vid.permute(3, 0, 1, 2).to(torch.float32) / 255
10
+
11
+
12
+ # NOTE: for those functions, which generally expect mini-batches, we keep them
13
+ # as non-minibatch so that they are applied as if they were 4d (thus image).
14
+ # this way, we only apply the transformation in the spatial domain
15
+ def resize(vid, size, interpolation='bilinear'):
16
+ # NOTE: using bilinear interpolation because we don't work on minibatches
17
+ # at this level
18
+ scale = None
19
+ if isinstance(size, int):
20
+ scale = float(size) / min(vid.shape[-2:])
21
+ size = None
22
+ return torch.nn.functional.interpolate(
23
+ vid,
24
+ size=size,
25
+ scale_factor=scale,
26
+ mode=interpolation,
27
+ align_corners=False)
28
+
29
+
30
+ class ToFloatTensorInZeroOne(object):
31
+ def __call__(self, vid):
32
+ return to_normalized_float_tensor(vid)
33
+
34
+
35
+ class Resize(object):
36
+ def __init__(self, size):
37
+ self.size = size
38
+ def __call__(self, vid):
39
+ return resize(vid, self.size)
40
+
41
+ def preprocess_videomae(videos):
42
+ transform = transforms.Compose(
43
+ [ToFloatTensorInZeroOne(),
44
+ Resize((224, 224))])
45
+ return torch.stack([transform(f) for f in torch.from_numpy(videos)])
46
+
47
+
48
+ def load_videomae_model(device, ckpt_path=None):
49
+ if ckpt_path is None:
50
+ current_dir = os.path.dirname(os.path.abspath(__file__))
51
+ ckpt_path = os.path.join(current_dir, 'vit_g_hybrid_pt_1200e_ssv2_ft.pth')
52
+
53
+ if not os.path.exists(ckpt_path):
54
+ # download the ckpt to the path
55
+ ckpt_url = 'https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/internvideo/videomaev2/vit_g_hybrid_pt_1200e_ssv2_ft.pth'
56
+ response = requests.get(ckpt_url, stream=True, allow_redirects=True)
57
+ total_size = int(response.headers.get("content-length", 0))
58
+ block_size = 1024
59
+
60
+ with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
61
+ with open(ckpt_path, "wb") as fw:
62
+ for data in response.iter_content(block_size):
63
+ progress_bar.update(len(data))
64
+ fw.write(data)
65
+
66
+ model = vit_giant_patch14_224(
67
+ img_size=224,
68
+ pretrained=False,
69
+ num_classes=174,
70
+ all_frames=16,
71
+ tubelet_size=2,
72
+ drop_path_rate=0.3,
73
+ use_mean_pooling=True)
74
+
75
+ ckpt = torch.load(ckpt_path, map_location='cpu')
76
+ for model_key in ['model', 'module']:
77
+ if model_key in ckpt:
78
+ ckpt = ckpt[model_key]
79
+ break
80
+ model.load_state_dict(ckpt)
81
+ return model.to(device)
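Note (not part of this commit): a hedged usage sketch for the helpers above. It assumes the `latentsync` package is importable and that the checkpoint either exists locally or can be downloaded (the file is several GB); the input shape follows `preprocess_videomae`, i.e. uint8 clips of shape (num_videos, frames, H, W, C):

```python
# Editorial sketch; the checkpoint download and GPU are assumptions, not requirements of this file.
import numpy as np
import torch

from latentsync.trepa.third_party.VideoMAEv2.utils import load_videomae_model, preprocess_videomae

videos = np.random.randint(0, 256, size=(1, 16, 256, 256, 3), dtype=np.uint8)
clip = preprocess_videomae(videos)                 # -> (1, 3, 16, 224, 224), float in [0, 1]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_videomae_model(device).eval()
with torch.no_grad():
    feats = model.forward_features(clip.to(device))
print(feats.shape)                                  # (1, 1408) with mean pooling
```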
latentsync/trepa/third_party/VideoMAEv2/videomaev2_finetune.py ADDED
@@ -0,0 +1,539 @@
1
+ # --------------------------------------------------------
2
+ # Based on BEiT, timm, DINO and DeiT code bases
3
+ # https://github.com/microsoft/unilm/tree/master/beit
4
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
5
+ # https://github.com/facebookresearch/deit
6
+ # https://github.com/facebookresearch/dino
7
+ # --------------------------------------------------------'
8
+ from functools import partial
9
+
10
+ import math
11
+ import warnings
12
+ import numpy as np
13
+ import collections.abc
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import torch.utils.checkpoint as cp
18
+ from itertools import repeat
19
+
20
+
21
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
22
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
23
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
24
+ def norm_cdf(x):
25
+ # Computes standard normal cumulative distribution function
26
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
27
+
28
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
29
+ warnings.warn(
30
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
31
+ "The distribution of values may be incorrect.",
32
+ stacklevel=2,
33
+ )
34
+
35
+ with torch.no_grad():
36
+ # Values are generated by using a truncated uniform distribution and
37
+ # then using the inverse CDF for the normal distribution.
38
+ # Get upper and lower cdf values
39
+ l = norm_cdf((a - mean) / std)
40
+ u = norm_cdf((b - mean) / std)
41
+
42
+ # Uniformly fill tensor with values from [l, u], then translate to
43
+ # [2l-1, 2u-1].
44
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
45
+
46
+ # Use inverse cdf transform for normal distribution to get truncated
47
+ # standard normal
48
+ tensor.erfinv_()
49
+
50
+ # Transform to proper mean, std
51
+ tensor.mul_(std * math.sqrt(2.0))
52
+ tensor.add_(mean)
53
+
54
+ # Clamp to ensure it's in the proper range
55
+ tensor.clamp_(min=a, max=b)
56
+ return tensor
57
+
58
+
59
+ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
60
+ r"""Fills the input Tensor with values drawn from a truncated
61
+ normal distribution. The values are effectively drawn from the
62
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
63
+ with values outside :math:`[a, b]` redrawn until they are within
64
+ the bounds. The method used for generating the random values works
65
+ best when :math:`a \leq \text{mean} \leq b`.
66
+ Args:
67
+ tensor: an n-dimensional `torch.Tensor`
68
+ mean: the mean of the normal distribution
69
+ std: the standard deviation of the normal distribution
70
+ a: the minimum cutoff value
71
+ b: the maximum cutoff value
72
+ Examples:
73
+ >>> w = torch.empty(3, 5)
74
+ >>> nn.init.trunc_normal_(w)
75
+ """
76
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
77
+
78
+
79
+ def _ntuple(n):
80
+ def parse(x):
81
+ if isinstance(x, collections.abc.Iterable):
82
+ return x
83
+ return tuple(repeat(x, n))
84
+
85
+ return parse
86
+
87
+
88
+ to_2tuple = _ntuple(2)
89
+
90
+
91
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
92
+ """
93
+ Adapted from timm codebase
94
+ """
95
+ if drop_prob == 0.0 or not training:
96
+ return x
97
+ keep_prob = 1 - drop_prob
98
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
99
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
100
+ random_tensor.floor_() # binarize
101
+ output = x.div(keep_prob) * random_tensor
102
+ return output
103
+
104
+
105
+ def _cfg(url="", **kwargs):
106
+ return {
107
+ "url": url,
108
+ "num_classes": 400,
109
+ "input_size": (3, 224, 224),
110
+ "pool_size": None,
111
+ "crop_pct": 0.9,
112
+ "interpolation": "bicubic",
113
+ "mean": (0.5, 0.5, 0.5),
114
+ "std": (0.5, 0.5, 0.5),
115
+ **kwargs,
116
+ }
117
+
118
+
119
+ class DropPath(nn.Module):
120
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
121
+
122
+ def __init__(self, drop_prob=None):
123
+ super(DropPath, self).__init__()
124
+ self.drop_prob = drop_prob
125
+
126
+ def forward(self, x):
127
+ return drop_path(x, self.drop_prob, self.training)
128
+
129
+ def extra_repr(self) -> str:
130
+ return "p={}".format(self.drop_prob)
131
+
132
+
133
+ class Mlp(nn.Module):
134
+
135
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
136
+ super().__init__()
137
+ out_features = out_features or in_features
138
+ hidden_features = hidden_features or in_features
139
+ self.fc1 = nn.Linear(in_features, hidden_features)
140
+ self.act = act_layer()
141
+ self.fc2 = nn.Linear(hidden_features, out_features)
142
+ self.drop = nn.Dropout(drop)
143
+
144
+ def forward(self, x):
145
+ x = self.fc1(x)
146
+ x = self.act(x)
147
+ # x = self.drop(x)
148
+ # the dropout above is commented out to match the original BERT implementation
149
+ x = self.fc2(x)
150
+ x = self.drop(x)
151
+ return x
152
+
153
+
154
+ class CosAttention(nn.Module):
155
+
156
+ def __init__(
157
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, attn_head_dim=None
158
+ ):
159
+ super().__init__()
160
+ self.num_heads = num_heads
161
+ head_dim = dim // num_heads
162
+ if attn_head_dim is not None:
163
+ head_dim = attn_head_dim
164
+ all_head_dim = head_dim * self.num_heads
165
+ # self.scale = qk_scale or head_dim**-0.5
166
+ # DO NOT RENAME [self.scale] (for no weight decay)
167
+ if qk_scale is None:
168
+ self.scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
169
+ else:
170
+ self.scale = qk_scale
171
+
172
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
173
+ if qkv_bias:
174
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
175
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
176
+ else:
177
+ self.q_bias = None
178
+ self.v_bias = None
179
+
180
+ self.attn_drop = nn.Dropout(attn_drop)
181
+ self.proj = nn.Linear(all_head_dim, dim)
182
+ self.proj_drop = nn.Dropout(proj_drop)
183
+
184
+ def forward(self, x):
185
+ B, N, C = x.shape
186
+ qkv_bias = None
187
+ if self.q_bias is not None:
188
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
189
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
190
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
191
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
192
+
193
+ attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
194
+
195
+ # torch.log(torch.tensor(1. / 0.01)) = 4.6052
196
+ logit_scale = torch.clamp(self.scale, max=4.6052).exp()
197
+
198
+ attn = attn * logit_scale
199
+
200
+ attn = attn.softmax(dim=-1)
201
+ attn = self.attn_drop(attn)
202
+
203
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
204
+
205
+ x = self.proj(x)
206
+ x = self.proj_drop(x)
207
+ return x
208
+
209
+
210
+ class Attention(nn.Module):
211
+
212
+ def __init__(
213
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, attn_head_dim=None
214
+ ):
215
+ super().__init__()
216
+ self.num_heads = num_heads
217
+ head_dim = dim // num_heads
218
+ if attn_head_dim is not None:
219
+ head_dim = attn_head_dim
220
+ all_head_dim = head_dim * self.num_heads
221
+ self.scale = qk_scale or head_dim**-0.5
222
+
223
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
224
+ if qkv_bias:
225
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
226
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
227
+ else:
228
+ self.q_bias = None
229
+ self.v_bias = None
230
+
231
+ self.attn_drop = nn.Dropout(attn_drop)
232
+ self.proj = nn.Linear(all_head_dim, dim)
233
+ self.proj_drop = nn.Dropout(proj_drop)
234
+
235
+ def forward(self, x):
236
+ B, N, C = x.shape
237
+ qkv_bias = None
238
+ if self.q_bias is not None:
239
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
240
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
241
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
242
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
243
+
244
+ q = q * self.scale
245
+ attn = q @ k.transpose(-2, -1)
246
+
247
+ attn = attn.softmax(dim=-1)
248
+ attn = self.attn_drop(attn)
249
+
250
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
251
+
252
+ x = self.proj(x)
253
+ x = self.proj_drop(x)
254
+ return x
255
+
256
+
257
+ class Block(nn.Module):
258
+
259
+ def __init__(
260
+ self,
261
+ dim,
262
+ num_heads,
263
+ mlp_ratio=4.0,
264
+ qkv_bias=False,
265
+ qk_scale=None,
266
+ drop=0.0,
267
+ attn_drop=0.0,
268
+ drop_path=0.0,
269
+ init_values=None,
270
+ act_layer=nn.GELU,
271
+ norm_layer=nn.LayerNorm,
272
+ attn_head_dim=None,
273
+ cos_attn=False,
274
+ ):
275
+ super().__init__()
276
+ self.norm1 = norm_layer(dim)
277
+ if cos_attn:
278
+ self.attn = CosAttention(
279
+ dim,
280
+ num_heads=num_heads,
281
+ qkv_bias=qkv_bias,
282
+ qk_scale=qk_scale,
283
+ attn_drop=attn_drop,
284
+ proj_drop=drop,
285
+ attn_head_dim=attn_head_dim,
286
+ )
287
+ else:
288
+ self.attn = Attention(
289
+ dim,
290
+ num_heads=num_heads,
291
+ qkv_bias=qkv_bias,
292
+ qk_scale=qk_scale,
293
+ attn_drop=attn_drop,
294
+ proj_drop=drop,
295
+ attn_head_dim=attn_head_dim,
296
+ )
297
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
298
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
299
+ self.norm2 = norm_layer(dim)
300
+ mlp_hidden_dim = int(dim * mlp_ratio)
301
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
302
+
303
+ if init_values > 0:
304
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
305
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
306
+ else:
307
+ self.gamma_1, self.gamma_2 = None, None
308
+
309
+ def forward(self, x):
310
+ if self.gamma_1 is None:
311
+ x = x + self.drop_path(self.attn(self.norm1(x)))
312
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
313
+ else:
314
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
315
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
316
+ return x
317
+
318
+
319
+ class PatchEmbed(nn.Module):
320
+ """Image to Patch Embedding"""
321
+
322
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2):
323
+ super().__init__()
324
+ img_size = to_2tuple(img_size)
325
+ patch_size = to_2tuple(patch_size)
326
+ num_spatial_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
327
+ num_patches = num_spatial_patches * (num_frames // tubelet_size)
328
+
329
+ self.img_size = img_size
330
+ self.tubelet_size = tubelet_size
331
+ self.patch_size = patch_size
332
+ self.num_patches = num_patches
333
+ self.proj = nn.Conv3d(
334
+ in_channels=in_chans,
335
+ out_channels=embed_dim,
336
+ kernel_size=(self.tubelet_size, patch_size[0], patch_size[1]),
337
+ stride=(self.tubelet_size, patch_size[0], patch_size[1]),
338
+ )
339
+
340
+ def forward(self, x, **kwargs):
341
+ B, C, T, H, W = x.shape
342
+ assert (
343
+ H == self.img_size[0] and W == self.img_size[1]
344
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
345
+ # b, c, l -> b, l, c
346
+ # [1, 1408, 8, 16, 16] -> [1, 1408, 2048] -> [1, 2048, 1408]
347
+ x = self.proj(x).flatten(2).transpose(1, 2)
348
+ return x
349
+
350
+
351
+ # sin-cos position encoding
352
+ # https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
353
+ def get_sinusoid_encoding_table(n_position, d_hid):
354
+ """Sinusoid position encoding table"""
355
+
356
+ # TODO: make it with torch instead of numpy
357
+ def get_position_angle_vec(position):
358
+ return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
359
+
360
+ sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
361
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
362
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
363
+
364
+ return torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)
365
+
366
+
367
+ class VisionTransformer(nn.Module):
368
+ """Vision Transformer with support for patch or hybrid CNN input stage"""
369
+
370
+ def __init__(
371
+ self,
372
+ img_size=224,
373
+ patch_size=16,
374
+ in_chans=3,
375
+ num_classes=1000,
376
+ embed_dim=768,
377
+ depth=12,
378
+ num_heads=12,
379
+ mlp_ratio=4.0,
380
+ qkv_bias=False,
381
+ qk_scale=None,
382
+ drop_rate=0.0,
383
+ attn_drop_rate=0.0,
384
+ drop_path_rate=0.0,
385
+ head_drop_rate=0.0,
386
+ norm_layer=nn.LayerNorm,
387
+ init_values=0.0,
388
+ use_learnable_pos_emb=False,
389
+ init_scale=0.0,
390
+ all_frames=16,
391
+ tubelet_size=2,
392
+ use_mean_pooling=True,
393
+ with_cp=False,
394
+ cos_attn=False,
395
+ ):
396
+ super().__init__()
397
+ self.num_classes = num_classes
398
+ # num_features for consistency with other models
399
+ self.num_features = self.embed_dim = embed_dim
400
+ self.tubelet_size = tubelet_size
401
+ self.patch_embed = PatchEmbed(
402
+ img_size=img_size,
403
+ patch_size=patch_size,
404
+ in_chans=in_chans,
405
+ embed_dim=embed_dim,
406
+ num_frames=all_frames,
407
+ tubelet_size=tubelet_size,
408
+ )
409
+ num_patches = self.patch_embed.num_patches
410
+ self.with_cp = with_cp
411
+
412
+ if use_learnable_pos_emb:
413
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
414
+ else:
415
+ # fixed sine-cosine positional embeddings
416
+ self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)
417
+
418
+ self.pos_drop = nn.Dropout(p=drop_rate)
419
+
420
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
421
+ self.blocks = nn.ModuleList(
422
+ [
423
+ Block(
424
+ dim=embed_dim,
425
+ num_heads=num_heads,
426
+ mlp_ratio=mlp_ratio,
427
+ qkv_bias=qkv_bias,
428
+ qk_scale=qk_scale,
429
+ drop=drop_rate,
430
+ attn_drop=attn_drop_rate,
431
+ drop_path=dpr[i],
432
+ norm_layer=norm_layer,
433
+ init_values=init_values,
434
+ cos_attn=cos_attn,
435
+ )
436
+ for i in range(depth)
437
+ ]
438
+ )
439
+ self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
440
+ self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
441
+ self.head_dropout = nn.Dropout(head_drop_rate)
442
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
443
+
444
+ if use_learnable_pos_emb:
445
+ trunc_normal_(self.pos_embed, std=0.02)
446
+
447
+ self.apply(self._init_weights)
448
+
449
+ self.head.weight.data.mul_(init_scale)
450
+ self.head.bias.data.mul_(init_scale)
451
+ self.num_frames = all_frames
452
+
453
+ def _init_weights(self, m):
454
+ if isinstance(m, nn.Linear):
455
+ trunc_normal_(m.weight, std=0.02)
456
+ if isinstance(m, nn.Linear) and m.bias is not None:
457
+ nn.init.constant_(m.bias, 0)
458
+ elif isinstance(m, nn.LayerNorm):
459
+ nn.init.constant_(m.bias, 0)
460
+ nn.init.constant_(m.weight, 1.0)
461
+
462
+ def get_num_layers(self):
463
+ return len(self.blocks)
464
+
465
+ @torch.jit.ignore
466
+ def no_weight_decay(self):
467
+ return {"pos_embed", "cls_token"}
468
+
469
+ def get_classifier(self):
470
+ return self.head
471
+
472
+ def reset_classifier(self, num_classes, global_pool=""):
473
+ self.num_classes = num_classes
474
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
475
+
476
+ def interpolate_pos_encoding(self, t):
477
+ T = 8
478
+ t0 = t // self.tubelet_size
479
+ if T == t0:
480
+ return self.pos_embed
481
+ dim = self.pos_embed.shape[-1]
482
+ patch_pos_embed = self.pos_embed.permute(0, 2, 1).reshape(1, dim, 8, 16, 16)
483
+ # we add a small number to avoid floating point error in the interpolation
484
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
485
+ t0 = t0 + 0.1
486
+ patch_pos_embed = nn.functional.interpolate(
487
+ patch_pos_embed,
488
+ scale_factor=(t0 / T, 1, 1),
489
+ mode="trilinear",
490
+ )
491
+ assert int(t0) == patch_pos_embed.shape[-3]
492
+ patch_pos_embed = patch_pos_embed.reshape(1, dim, -1).permute(0, 2, 1)
493
+ return patch_pos_embed
494
+
495
+ def forward_features(self, x):
496
+ # [1, 3, 16, 224, 224]
497
+ B = x.size(0)
498
+ T = x.size(2)
499
+
500
+ # [1, 2048, 1408]
501
+ x = self.patch_embed(x)
502
+
503
+ if self.pos_embed is not None:
504
+ x = x + self.interpolate_pos_encoding(T).expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
505
+ x = self.pos_drop(x)
506
+
507
+ for blk in self.blocks:
508
+ if self.with_cp:
509
+ x = cp.checkpoint(blk, x)
510
+ else:
511
+ x = blk(x)
512
+
513
+ # return self.fc_norm(x)
514
+
515
+ if self.fc_norm is not None:
516
+ return self.fc_norm(x.mean(1))
517
+ else:
518
+ return self.norm(x[:, 0])
519
+
520
+ def forward(self, x):
521
+ x = self.forward_features(x)
522
+ x = self.head_dropout(x)
523
+ x = self.head(x)
524
+ return x
525
+
526
+
527
+ def vit_giant_patch14_224(pretrained=False, **kwargs):
528
+ model = VisionTransformer(
529
+ patch_size=14,
530
+ embed_dim=1408,
531
+ depth=40,
532
+ num_heads=16,
533
+ mlp_ratio=48 / 11,
534
+ qkv_bias=True,
535
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
536
+ **kwargs,
537
+ )
538
+ model.default_cfg = _cfg()
539
+ return model
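As a quick orientation for readers of this diff, the sketch below is not part of the commit: it shows how the giant fine-tuned backbone defined above could be instantiated and run on a 16-frame clip. The constructor defaults of `VisionTransformer` (e.g. `num_classes`, `all_frames`) are not visible here, so they are assumed.

    import torch
    from latentsync.trepa.third_party.VideoMAEv2.videomaev2_finetune import vit_giant_patch14_224

    # Hypothetical usage; constructor defaults are assumed from the surrounding code.
    model = vit_giant_patch14_224().eval()
    clip = torch.randn(1, 3, 16, 224, 224)  # [B, C, T, H, W], matching the comment in forward_features
    with torch.no_grad():
        out = model(clip)  # pooled 1408-d features, or class logits depending on num_classes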
latentsync/trepa/third_party/VideoMAEv2/videomaev2_pretrain.py ADDED
@@ -0,0 +1,469 @@
1
+ # --------------------------------------------------------
2
+ # Based on BEiT, timm, DINO and DeiT code bases
3
+ # https://github.com/microsoft/unilm/tree/master/beit
4
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
5
+ # https://github.com/facebookresearch/deit
6
+ # https://github.com/facebookresearch/dino
7
+ # --------------------------------------------------------'
8
+ from functools import partial
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.utils.checkpoint as cp
13
+
14
+ from .videomaev2_finetune import (
15
+ Block,
16
+ PatchEmbed,
17
+ _cfg,
18
+ get_sinusoid_encoding_table,
19
+ )
20
+
21
+ from .videomaev2_finetune import trunc_normal_ as __call_trunc_normal_
22
+
23
+ def trunc_normal_(tensor, mean=0., std=1.):
24
+ __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
25
+
26
+
27
+ class PretrainVisionTransformerEncoder(nn.Module):
28
+ """ Vision Transformer with support for patch or hybrid CNN input stage
29
+ """
30
+
31
+ def __init__(self,
32
+ img_size=224,
33
+ patch_size=16,
34
+ in_chans=3,
35
+ num_classes=0,
36
+ embed_dim=768,
37
+ depth=12,
38
+ num_heads=12,
39
+ mlp_ratio=4.,
40
+ qkv_bias=False,
41
+ qk_scale=None,
42
+ drop_rate=0.,
43
+ attn_drop_rate=0.,
44
+ drop_path_rate=0.,
45
+ norm_layer=nn.LayerNorm,
46
+ init_values=None,
47
+ tubelet_size=2,
48
+ use_learnable_pos_emb=False,
49
+ with_cp=False,
50
+ all_frames=16,
51
+ cos_attn=False):
52
+ super().__init__()
53
+ self.num_classes = num_classes
54
+ # num_features for consistency with other models
55
+ self.num_features = self.embed_dim = embed_dim
56
+ self.patch_embed = PatchEmbed(
57
+ img_size=img_size,
58
+ patch_size=patch_size,
59
+ in_chans=in_chans,
60
+ embed_dim=embed_dim,
61
+ num_frames=all_frames,
62
+ tubelet_size=tubelet_size)
63
+ num_patches = self.patch_embed.num_patches
64
+ self.with_cp = with_cp
65
+
66
+ if use_learnable_pos_emb:
67
+ self.pos_embed = nn.Parameter(
68
+ torch.zeros(1, num_patches + 1, embed_dim))
69
+ else:
70
+ # sine-cosine positional embeddings
71
+ self.pos_embed = get_sinusoid_encoding_table(
72
+ num_patches, embed_dim)
73
+
74
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
75
+ ] # stochastic depth decay rule
76
+ self.blocks = nn.ModuleList([
77
+ Block(
78
+ dim=embed_dim,
79
+ num_heads=num_heads,
80
+ mlp_ratio=mlp_ratio,
81
+ qkv_bias=qkv_bias,
82
+ qk_scale=qk_scale,
83
+ drop=drop_rate,
84
+ attn_drop=attn_drop_rate,
85
+ drop_path=dpr[i],
86
+ norm_layer=norm_layer,
87
+ init_values=init_values,
88
+ cos_attn=cos_attn) for i in range(depth)
89
+ ])
90
+ self.norm = norm_layer(embed_dim)
91
+ self.head = nn.Linear(
92
+ embed_dim, num_classes) if num_classes > 0 else nn.Identity()
93
+
94
+ if use_learnable_pos_emb:
95
+ trunc_normal_(self.pos_embed, std=.02)
96
+
97
+ self.apply(self._init_weights)
98
+
99
+ def _init_weights(self, m):
100
+ if isinstance(m, nn.Linear):
101
+ nn.init.xavier_uniform_(m.weight)
102
+ if isinstance(m, nn.Linear) and m.bias is not None:
103
+ nn.init.constant_(m.bias, 0)
104
+ elif isinstance(m, nn.LayerNorm):
105
+ nn.init.constant_(m.bias, 0)
106
+ nn.init.constant_(m.weight, 1.0)
107
+
108
+ def get_num_layers(self):
109
+ return len(self.blocks)
110
+
111
+ @torch.jit.ignore
112
+ def no_weight_decay(self):
113
+ return {'pos_embed', 'cls_token'}
114
+
115
+ def get_classifier(self):
116
+ return self.head
117
+
118
+ def reset_classifier(self, num_classes, global_pool=''):
119
+ self.num_classes = num_classes
120
+ self.head = nn.Linear(
121
+ self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
122
+
123
+ def forward_features(self, x, mask):
124
+ x = self.patch_embed(x)
125
+
126
+ x = x + self.pos_embed.type_as(x).to(x.device).clone().detach()
127
+
128
+ B, _, C = x.shape
129
+ x_vis = x[~mask].reshape(B, -1, C) # ~mask means visible
130
+
131
+ for blk in self.blocks:
132
+ if self.with_cp:
133
+ x_vis = cp.checkpoint(blk, x_vis)
134
+ else:
135
+ x_vis = blk(x_vis)
136
+
137
+ x_vis = self.norm(x_vis)
138
+ return x_vis
139
+
140
+ def forward(self, x, mask):
141
+ x = self.forward_features(x, mask)
142
+ x = self.head(x)
143
+ return x
144
+
145
+
146
+ class PretrainVisionTransformerDecoder(nn.Module):
147
+ """ Vision Transformer with support for patch or hybrid CNN input stage
148
+ """
149
+
150
+ def __init__(self,
151
+ patch_size=16,
152
+ num_classes=768,
153
+ embed_dim=768,
154
+ depth=12,
155
+ num_heads=12,
156
+ mlp_ratio=4.,
157
+ qkv_bias=False,
158
+ qk_scale=None,
159
+ drop_rate=0.,
160
+ attn_drop_rate=0.,
161
+ drop_path_rate=0.,
162
+ norm_layer=nn.LayerNorm,
163
+ init_values=None,
164
+ num_patches=196,
165
+ tubelet_size=2,
166
+ with_cp=False,
167
+ cos_attn=False):
168
+ super().__init__()
169
+ self.num_classes = num_classes
170
+ assert num_classes == 3 * tubelet_size * patch_size**2
171
+ # num_features for consistency with other models
172
+ self.num_features = self.embed_dim = embed_dim
173
+ self.patch_size = patch_size
174
+ self.with_cp = with_cp
175
+
176
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
177
+ ] # stochastic depth decay rule
178
+ self.blocks = nn.ModuleList([
179
+ Block(
180
+ dim=embed_dim,
181
+ num_heads=num_heads,
182
+ mlp_ratio=mlp_ratio,
183
+ qkv_bias=qkv_bias,
184
+ qk_scale=qk_scale,
185
+ drop=drop_rate,
186
+ attn_drop=attn_drop_rate,
187
+ drop_path=dpr[i],
188
+ norm_layer=norm_layer,
189
+ init_values=init_values,
190
+ cos_attn=cos_attn) for i in range(depth)
191
+ ])
192
+ self.norm = norm_layer(embed_dim)
193
+ self.head = nn.Linear(
194
+ embed_dim, num_classes) if num_classes > 0 else nn.Identity()
195
+
196
+ self.apply(self._init_weights)
197
+
198
+ def _init_weights(self, m):
199
+ if isinstance(m, nn.Linear):
200
+ nn.init.xavier_uniform_(m.weight)
201
+ if isinstance(m, nn.Linear) and m.bias is not None:
202
+ nn.init.constant_(m.bias, 0)
203
+ elif isinstance(m, nn.LayerNorm):
204
+ nn.init.constant_(m.bias, 0)
205
+ nn.init.constant_(m.weight, 1.0)
206
+
207
+ def get_num_layers(self):
208
+ return len(self.blocks)
209
+
210
+ @torch.jit.ignore
211
+ def no_weight_decay(self):
212
+ return {'pos_embed', 'cls_token'}
213
+
214
+ def get_classifier(self):
215
+ return self.head
216
+
217
+ def reset_classifier(self, num_classes, global_pool=''):
218
+ self.num_classes = num_classes
219
+ self.head = nn.Linear(
220
+ self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
221
+
222
+ def forward(self, x, return_token_num):
223
+ for blk in self.blocks:
224
+ if self.with_cp:
225
+ x = cp.checkpoint(blk, x)
226
+ else:
227
+ x = blk(x)
228
+
229
+ if return_token_num > 0:
230
+ # only return the mask tokens predict pixels
231
+ x = self.head(self.norm(x[:, -return_token_num:]))
232
+ else:
233
+ # [B, N, 3*16^2]
234
+ x = self.head(self.norm(x))
235
+ return x
236
+
237
+
238
+ class PretrainVisionTransformer(nn.Module):
239
+ """ Vision Transformer with support for patch or hybrid CNN input stage
240
+ """
241
+
242
+ def __init__(
243
+ self,
244
+ img_size=224,
245
+ patch_size=16,
246
+ encoder_in_chans=3,
247
+ encoder_num_classes=0,
248
+ encoder_embed_dim=768,
249
+ encoder_depth=12,
250
+ encoder_num_heads=12,
251
+ decoder_num_classes=1536, # decoder_num_classes=768
252
+ decoder_embed_dim=512,
253
+ decoder_depth=8,
254
+ decoder_num_heads=8,
255
+ mlp_ratio=4.,
256
+ qkv_bias=False,
257
+ qk_scale=None,
258
+ drop_rate=0.,
259
+ attn_drop_rate=0.,
260
+ drop_path_rate=0.,
261
+ norm_layer=nn.LayerNorm,
262
+ init_values=0.,
263
+ use_learnable_pos_emb=False,
264
+ tubelet_size=2,
265
+ num_classes=0, # avoid the error from create_fn in timm
266
+ in_chans=0, # avoid the error from create_fn in timm
267
+ with_cp=False,
268
+ all_frames=16,
269
+ cos_attn=False,
270
+ ):
271
+ super().__init__()
272
+ self.encoder = PretrainVisionTransformerEncoder(
273
+ img_size=img_size,
274
+ patch_size=patch_size,
275
+ in_chans=encoder_in_chans,
276
+ num_classes=encoder_num_classes,
277
+ embed_dim=encoder_embed_dim,
278
+ depth=encoder_depth,
279
+ num_heads=encoder_num_heads,
280
+ mlp_ratio=mlp_ratio,
281
+ qkv_bias=qkv_bias,
282
+ qk_scale=qk_scale,
283
+ drop_rate=drop_rate,
284
+ attn_drop_rate=attn_drop_rate,
285
+ drop_path_rate=drop_path_rate,
286
+ norm_layer=norm_layer,
287
+ init_values=init_values,
288
+ tubelet_size=tubelet_size,
289
+ use_learnable_pos_emb=use_learnable_pos_emb,
290
+ with_cp=with_cp,
291
+ all_frames=all_frames,
292
+ cos_attn=cos_attn)
293
+
294
+ self.decoder = PretrainVisionTransformerDecoder(
295
+ patch_size=patch_size,
296
+ num_patches=self.encoder.patch_embed.num_patches,
297
+ num_classes=decoder_num_classes,
298
+ embed_dim=decoder_embed_dim,
299
+ depth=decoder_depth,
300
+ num_heads=decoder_num_heads,
301
+ mlp_ratio=mlp_ratio,
302
+ qkv_bias=qkv_bias,
303
+ qk_scale=qk_scale,
304
+ drop_rate=drop_rate,
305
+ attn_drop_rate=attn_drop_rate,
306
+ drop_path_rate=drop_path_rate,
307
+ norm_layer=norm_layer,
308
+ init_values=init_values,
309
+ tubelet_size=tubelet_size,
310
+ with_cp=with_cp,
311
+ cos_attn=cos_attn)
312
+
313
+ self.encoder_to_decoder = nn.Linear(
314
+ encoder_embed_dim, decoder_embed_dim, bias=False)
315
+
316
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
317
+
318
+ self.pos_embed = get_sinusoid_encoding_table(
319
+ self.encoder.patch_embed.num_patches, decoder_embed_dim)
320
+
321
+ trunc_normal_(self.mask_token, std=.02)
322
+
323
+ def _init_weights(self, m):
324
+ if isinstance(m, nn.Linear):
325
+ nn.init.xavier_uniform_(m.weight)
326
+ if isinstance(m, nn.Linear) and m.bias is not None:
327
+ nn.init.constant_(m.bias, 0)
328
+ elif isinstance(m, nn.LayerNorm):
329
+ nn.init.constant_(m.bias, 0)
330
+ nn.init.constant_(m.weight, 1.0)
331
+
332
+ def get_num_layers(self):
333
+ return len(self.blocks)
334
+
335
+ @torch.jit.ignore
336
+ def no_weight_decay(self):
337
+ return {'pos_embed', 'cls_token', 'mask_token'}
338
+
339
+ def forward(self, x, mask, decode_mask=None):
340
+ decode_vis = mask if decode_mask is None else ~decode_mask
341
+
342
+ x_vis = self.encoder(x, mask) # [B, N_vis, C_e]
343
+ x_vis = self.encoder_to_decoder(x_vis) # [B, N_vis, C_d]
344
+ B, N_vis, C = x_vis.shape
345
+
346
+ # we don't unshuffle the correct visible token order,
347
+ # but shuffle the pos embedding accordingly.
348
+ expand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(
349
+ x.device).clone().detach()
350
+ pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C)
351
+ pos_emd_mask = expand_pos_embed[decode_vis].reshape(B, -1, C)
352
+
353
+ # [B, N, C_d]
354
+ x_full = torch.cat(
355
+ [x_vis + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1)
356
+ # NOTE: if N_mask==0, the shape of x is [B, N_mask, 3 * 16 * 16]
357
+ x = self.decoder(x_full, pos_emd_mask.shape[1])
358
+
359
+ return x
360
+
361
+
362
+ def pretrain_videomae_small_patch16_224(pretrained=False, **kwargs):
363
+ model = PretrainVisionTransformer(
364
+ img_size=224,
365
+ patch_size=16,
366
+ encoder_embed_dim=384,
367
+ encoder_depth=12,
368
+ encoder_num_heads=6,
369
+ encoder_num_classes=0,
370
+ decoder_num_classes=1536, # 16 * 16 * 3 * 2
371
+ decoder_embed_dim=192,
372
+ decoder_num_heads=3,
373
+ mlp_ratio=4,
374
+ qkv_bias=True,
375
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
376
+ **kwargs)
377
+ model.default_cfg = _cfg()
378
+ if pretrained:
379
+ checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
380
+ model.load_state_dict(checkpoint["model"])
381
+ return model
382
+
383
+
384
+ def pretrain_videomae_base_patch16_224(pretrained=False, **kwargs):
385
+ model = PretrainVisionTransformer(
386
+ img_size=224,
387
+ patch_size=16,
388
+ encoder_embed_dim=768,
389
+ encoder_depth=12,
390
+ encoder_num_heads=12,
391
+ encoder_num_classes=0,
392
+ decoder_num_classes=1536, # 16 * 16 * 3 * 2
393
+ decoder_embed_dim=384,
394
+ decoder_num_heads=6,
395
+ mlp_ratio=4,
396
+ qkv_bias=True,
397
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
398
+ **kwargs)
399
+ model.default_cfg = _cfg()
400
+ if pretrained:
401
+ checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
402
+ model.load_state_dict(checkpoint["model"])
403
+ return model
404
+
405
+
406
+ def pretrain_videomae_large_patch16_224(pretrained=False, **kwargs):
407
+ model = PretrainVisionTransformer(
408
+ img_size=224,
409
+ patch_size=16,
410
+ encoder_embed_dim=1024,
411
+ encoder_depth=24,
412
+ encoder_num_heads=16,
413
+ encoder_num_classes=0,
414
+ decoder_num_classes=1536, # 16 * 16 * 3 * 2
415
+ decoder_embed_dim=512,
416
+ decoder_num_heads=8,
417
+ mlp_ratio=4,
418
+ qkv_bias=True,
419
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
420
+ **kwargs)
421
+ model.default_cfg = _cfg()
422
+ if pretrained:
423
+ checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
424
+ model.load_state_dict(checkpoint["model"])
425
+ return model
426
+
427
+
428
+ def pretrain_videomae_huge_patch16_224(pretrained=False, **kwargs):
429
+ model = PretrainVisionTransformer(
430
+ img_size=224,
431
+ patch_size=16,
432
+ encoder_embed_dim=1280,
433
+ encoder_depth=32,
434
+ encoder_num_heads=16,
435
+ encoder_num_classes=0,
436
+ decoder_num_classes=1536, # 16 * 16 * 3 * 2
437
+ decoder_embed_dim=512,
438
+ decoder_num_heads=8,
439
+ mlp_ratio=4,
440
+ qkv_bias=True,
441
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
442
+ **kwargs)
443
+ model.default_cfg = _cfg()
444
+ if pretrained:
445
+ checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
446
+ model.load_state_dict(checkpoint["model"])
447
+ return model
448
+
449
+
450
+ def pretrain_videomae_giant_patch14_224(pretrained=False, **kwargs):
451
+ model = PretrainVisionTransformer(
452
+ img_size=224,
453
+ patch_size=14,
454
+ encoder_embed_dim=1408,
455
+ encoder_depth=40,
456
+ encoder_num_heads=16,
457
+ encoder_num_classes=0,
458
+ decoder_num_classes=1176, # 14 * 14 * 3 * 2,
459
+ decoder_embed_dim=512,
460
+ decoder_num_heads=8,
461
+ mlp_ratio=48 / 11,
462
+ qkv_bias=True,
463
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
464
+ **kwargs)
465
+ model.default_cfg = _cfg()
466
+ if pretrained:
467
+ checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
468
+ model.load_state_dict(checkpoint["model"])
469
+ return model
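For readers unfamiliar with the masked-autoencoding interface above, here is a minimal sketch (not part of the commit) of a forward pass with a boolean tube mask. The 1568-token count assumes 16 frames at 224x224 with patch size 16 and tubelet size 2, as in the defaults shown above.

    import torch
    from latentsync.trepa.third_party.VideoMAEv2.videomaev2_pretrain import pretrain_videomae_small_patch16_224

    model = pretrain_videomae_small_patch16_224()
    video = torch.randn(2, 3, 16, 224, 224)               # [B, C, T, H, W]
    num_patches = model.encoder.patch_embed.num_patches   # 8 * 14 * 14 = 1568
    mask = torch.zeros(2, num_patches, dtype=torch.bool)
    mask[:, num_patches // 2:] = True                     # True marks masked (to-be-reconstructed) tubelets
    pred = model(video, mask)                             # [B, N_masked, 3 * tubelet_size * patch_size**2]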
latentsync/trepa/third_party/__init__.py ADDED
File without changes
latentsync/trepa/utils/__init__.py ADDED
File without changes
latentsync/trepa/utils/data_utils.py ADDED
@@ -0,0 +1,321 @@
1
+ import os
2
+ import math
3
+ import os.path as osp
4
+ import random
5
+ import pickle
6
+ import warnings
7
+
8
+ import glob
9
+ import numpy as np
10
+ from PIL import Image
11
+
12
+ import torch
13
+ import torch.utils.data as data
14
+ import torch.nn.functional as F
15
+ import torch.distributed as dist
16
+ from torchvision.datasets.video_utils import VideoClips
17
+
18
+ IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']
19
+ VID_EXTENSIONS = ['.avi', '.mp4', '.webm', '.mov', '.mkv', '.m4v']
20
+
21
+
22
+ def get_dataloader(data_path, image_folder, resolution=128, sequence_length=16, sample_every_n_frames=1,
23
+ batch_size=16, num_workers=8):
24
+ data = VideoData(data_path, image_folder, resolution, sequence_length, sample_every_n_frames, batch_size, num_workers)
25
+ loader = data._dataloader()
26
+ return loader
27
+
28
+
29
+ def is_image_file(filename):
30
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
31
+
32
+
33
+ def get_parent_dir(path):
34
+ return osp.basename(osp.dirname(path))
35
+
36
+
37
+ def preprocess(video, resolution, sequence_length=None, in_channels=3, sample_every_n_frames=1):
38
+ # video: THWC, {0, ..., 255}
39
+ assert in_channels == 3
40
+ video = video.permute(0, 3, 1, 2).float() / 255. # TCHW
41
+ t, c, h, w = video.shape
42
+
43
+ # temporal crop
44
+ if sequence_length is not None:
45
+ assert sequence_length <= t
46
+ video = video[:sequence_length]
47
+
48
+ # skip frames
49
+ if sample_every_n_frames > 1:
50
+ video = video[::sample_every_n_frames]
51
+
52
+ # scale shorter side to resolution
53
+ scale = resolution / min(h, w)
54
+ if h < w:
55
+ target_size = (resolution, math.ceil(w * scale))
56
+ else:
57
+ target_size = (math.ceil(h * scale), resolution)
58
+ video = F.interpolate(video, size=target_size, mode='bilinear',
59
+ align_corners=False, antialias=True)
60
+
61
+ # center crop
62
+ t, c, h, w = video.shape
63
+ w_start = (w - resolution) // 2
64
+ h_start = (h - resolution) // 2
65
+ video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution]
66
+ video = video.permute(1, 0, 2, 3).contiguous() # CTHW
67
+
68
+ return {'video': video}
69
+
70
+
71
+ def preprocess_image(image):
72
+ # [0, 1] => [-1, 1]
73
+ img = torch.from_numpy(image)
74
+ return img
75
+
76
+
77
+ class VideoData(data.Dataset):
78
+ """ Class to create dataloaders for video datasets
79
+
80
+ Args:
81
+ data_path: Path to the folder with video frames or videos.
82
+ image_folder: If True, the data is stored as images in folders.
83
+ resolution: Resolution of the returned videos.
84
+ sequence_length: Length of extracted video sequences.
85
+ sample_every_n_frames: Sample every n frames from the video.
86
+ batch_size: Batch size.
87
+ num_workers: Number of workers for the dataloader.
88
+ shuffle: If True, shuffle the data.
89
+ """
90
+
91
+ def __init__(self, data_path: str, image_folder: bool, resolution: int, sequence_length: int,
92
+ sample_every_n_frames: int, batch_size: int, num_workers: int, shuffle: bool = True):
93
+ super().__init__()
94
+ self.data_path = data_path
95
+ self.image_folder = image_folder
96
+ self.resolution = resolution
97
+ self.sequence_length = sequence_length
98
+ self.sample_every_n_frames = sample_every_n_frames
99
+ self.batch_size = batch_size
100
+ self.num_workers = num_workers
101
+ self.shuffle = shuffle
102
+
103
+ def _dataset(self):
104
+ '''
105
+ Initializes and returns the dataset.
106
+ '''
107
+ if self.image_folder:
108
+ Dataset = FrameDataset
109
+ dataset = Dataset(self.data_path, self.sequence_length,
110
+ resolution=self.resolution, sample_every_n_frames=self.sample_every_n_frames)
111
+ else:
112
+ Dataset = VideoDataset
113
+ dataset = Dataset(self.data_path, self.sequence_length,
114
+ resolution=self.resolution, sample_every_n_frames=self.sample_every_n_frames)
115
+ return dataset
116
+
117
+ def _dataloader(self):
118
+ '''
119
+ Initializes and returns the dataloader.
120
+ '''
121
+ dataset = self._dataset()
122
+ if dist.is_initialized():
123
+ sampler = data.distributed.DistributedSampler(
124
+ dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank()
125
+ )
126
+ else:
127
+ sampler = None
128
+ dataloader = data.DataLoader(
129
+ dataset,
130
+ batch_size=self.batch_size,
131
+ num_workers=self.num_workers,
132
+ pin_memory=True,
133
+ sampler=sampler,
134
+ shuffle=sampler is None and self.shuffle is True
135
+ )
136
+ return dataloader
137
+
138
+
139
+ class VideoDataset(data.Dataset):
140
+ """
141
+ Generic dataset for videos files stored in folders.
142
+ Videos of the same class are expected to be stored in a single folder. Multiple folders can exist in the provided directory.
143
+ The class depends on `torchvision.datasets.video_utils.VideoClips` to load the videos.
144
+ Returns BCTHW videos in the range [0, 1].
145
+
146
+ Args:
147
+ data_folder: Path to the folder with corresponding videos stored.
148
+ sequence_length: Length of extracted video sequences.
149
+ resolution: Resolution of the returned videos.
150
+ sample_every_n_frames: Sample every n frames from the video.
151
+ """
152
+
153
+ def __init__(self, data_folder: str, sequence_length: int = 16, resolution: int = 128, sample_every_n_frames: int = 1):
154
+ super().__init__()
155
+ self.sequence_length = sequence_length
156
+ self.resolution = resolution
157
+ self.sample_every_n_frames = sample_every_n_frames
158
+
159
+ folder = data_folder
160
+ files = sum([glob.glob(osp.join(folder, '**', f'*{ext}'), recursive=True)
161
+ for ext in VID_EXTENSIONS], [])
162
+
163
+ warnings.filterwarnings('ignore')
164
+ cache_file = osp.join(folder, f"metadata_{sequence_length}.pkl")
165
+ if not osp.exists(cache_file):
166
+ clips = VideoClips(files, sequence_length, num_workers=4)
167
+ try:
168
+ pickle.dump(clips.metadata, open(cache_file, 'wb'))
169
+ except:
170
+ print(f"Failed to save metadata to {cache_file}")
171
+ else:
172
+ metadata = pickle.load(open(cache_file, 'rb'))
173
+ clips = VideoClips(files, sequence_length,
174
+ _precomputed_metadata=metadata)
175
+
176
+ self._clips = clips
177
+ # instead of uniformly sampling from all possible clips, we sample uniformly from all possible videos
178
+ self._clips.get_clip_location = self.get_random_clip_from_video
179
+
180
+ def get_random_clip_from_video(self, idx: int) -> tuple:
181
+ '''
182
+ Sample a random clip starting index from the video.
183
+
184
+ Args:
185
+ idx: Index of the video.
186
+ '''
187
+ # Note that some videos may not contain enough frames, we skip those videos here.
188
+ while self._clips.clips[idx].shape[0] <= 0:
189
+ idx += 1
190
+ n_clip = self._clips.clips[idx].shape[0]
191
+ clip_id = random.randint(0, n_clip - 1)
192
+ return idx, clip_id
193
+
194
+ def __len__(self):
195
+ return self._clips.num_videos()
196
+
197
+ def __getitem__(self, idx):
198
+ resolution = self.resolution
199
+ while True:
200
+ try:
201
+ video, _, _, idx = self._clips.get_clip(idx)
202
+ except Exception as e:
203
+ print(idx, e)
204
+ idx = (idx + 1) % self._clips.num_clips()
205
+ continue
206
+ break
207
+
208
+ return dict(**preprocess(video, resolution, sample_every_n_frames=self.sample_every_n_frames))
209
+
210
+
211
+ class FrameDataset(data.Dataset):
212
+ """
213
+ Generic dataset for videos stored as images. The loader iterates over all folders and subfolders
214
+ in the provided directory. Each leaf folder is assumed to contain frames from a single video.
215
+
216
+ Args:
217
+ data_folder: path to the folder with video frames. The folder
218
+ should contain folders with frames from each video.
219
+ sequence_length: length of extracted video sequences
220
+ resolution: resolution of the returned videos
221
+ sample_every_n_frames: sample every n frames from the video
222
+ """
223
+
224
+ def __init__(self, data_folder, sequence_length, resolution=64, sample_every_n_frames=1):
225
+ self.resolution = resolution
226
+ self.sequence_length = sequence_length
227
+ self.sample_every_n_frames = sample_every_n_frames
228
+ self.data_all = self.load_video_frames(data_folder)
229
+ self.video_num = len(self.data_all)
230
+
231
+ def __getitem__(self, index):
232
+ batch_data = self.getTensor(index)
233
+ return_list = {'video': batch_data}
234
+
235
+ return return_list
236
+
237
+ def load_video_frames(self, dataroot: str) -> list:
238
+ '''
239
+ Loads all the video frames under the dataroot and returns a list of frame-path lists, one per video.
240
+
241
+ Args:
242
+ dataroot: The root directory containing the video frames.
243
+
244
+ Returns:
245
+ A list of videos, where each entry is the sorted list of frame paths for one video.
246
+
247
+ '''
248
+ data_all = []
249
+ frame_list = os.walk(dataroot)
250
+ for _, meta in enumerate(frame_list):
251
+ root = meta[0]
252
+ try:
253
+ frames = sorted(meta[2], key=lambda item: int(item.split('.')[0].split('_')[-1]))
254
+ except:
255
+ print(meta[0], meta[2])
256
+ if len(frames) < max(0, self.sequence_length * self.sample_every_n_frames):
257
+ continue
258
+ frames = [
259
+ os.path.join(root, item) for item in frames
260
+ if is_image_file(item)
261
+ ]
262
+ if len(frames) > max(0, self.sequence_length * self.sample_every_n_frames):
263
+ data_all.append(frames)
264
+
265
+ return data_all
266
+
267
+ def getTensor(self, index: int) -> torch.Tensor:
268
+ '''
269
+ Returns a tensor of the video frames at the given index.
270
+
271
+ Args:
272
+ index: The index of the video frames to return.
273
+
274
+ Returns:
275
+ A CTHW tensor in the range `[0, 1]` of the video frames at the given index.
276
+
277
+ '''
278
+ video = self.data_all[index]
279
+ video_len = len(video)
280
+
281
+ # load the entire video when sequence_length = -1, while sample_every_n_frames has to be 1
282
+ if self.sequence_length == -1:
283
+ assert self.sample_every_n_frames == 1
284
+ start_idx = 0
285
+ end_idx = video_len
286
+ else:
287
+ n_frames_interval = self.sequence_length * self.sample_every_n_frames
288
+ start_idx = random.randint(0, video_len - n_frames_interval)
289
+ end_idx = start_idx + n_frames_interval
290
+ img = Image.open(video[0])
291
+ h, w = img.height, img.width
292
+
293
+ if h > w:
294
+ half = (h - w) // 2
295
+ cropsize = (0, half, w, half + w) # left, upper, right, lower
296
+ elif w > h:
297
+ half = (w - h) // 2
298
+ cropsize = (half, 0, half + h, h)
299
+
300
+ images = []
301
+ for i in range(start_idx, end_idx,
302
+ self.sample_every_n_frames):
303
+ path = video[i]
304
+ img = Image.open(path)
305
+
306
+ if h != w:
307
+ img = img.crop(cropsize)
308
+
309
+ img = img.resize(
310
+ (self.resolution, self.resolution),
311
+ Image.ANTIALIAS)
312
+ img = np.asarray(img, dtype=np.float32)
313
+ img /= 255.
314
+ img_tensor = preprocess_image(img).unsqueeze(0)
315
+ images.append(img_tensor)
316
+
317
+ video_clip = torch.cat(images).permute(3, 0, 1, 2)
318
+ return video_clip
319
+
320
+ def __len__(self):
321
+ return self.video_num
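A hedged usage sketch of the helpers above (not part of the commit); the directory path is a placeholder.

    loader = get_dataloader(
        "/path/to/videos",       # placeholder: folder containing .mp4/.avi/... files
        image_folder=False,      # True if each leaf folder holds the frames of one video
        resolution=224,
        sequence_length=16,
        batch_size=4,
        num_workers=2,
    )
    batch = next(iter(loader))
    video = batch["video"]       # float tensor in [0, 1], shape [B, C, T, H, W]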
latentsync/trepa/utils/metric_utils.py ADDED
@@ -0,0 +1,161 @@
1
+ # Adapted from https://github.com/universome/stylegan-v/blob/master/src/metrics/metric_utils.py
2
+ import os
3
+ import random
4
+ import torch
5
+ import pickle
6
+ import numpy as np
7
+
8
+ from typing import List, Tuple
9
+
10
+ def seed_everything(seed):
11
+ random.seed(seed)
12
+ os.environ['PYTHONHASHSEED'] = str(seed)
13
+ np.random.seed(seed)
14
+ torch.manual_seed(seed)
15
+ torch.cuda.manual_seed(seed)
16
+
17
+
18
+ class FeatureStats:
19
+ '''
20
+ Class to store statistics of features, including all features and mean/covariance.
21
+
22
+ Args:
23
+ capture_all: Whether to store all the features.
24
+ capture_mean_cov: Whether to store mean and covariance.
25
+ max_items: Maximum number of items to store.
26
+ '''
27
+ def __init__(self, capture_all: bool = False, capture_mean_cov: bool = False, max_items: int = None):
28
+ '''
29
+ '''
30
+ self.capture_all = capture_all
31
+ self.capture_mean_cov = capture_mean_cov
32
+ self.max_items = max_items
33
+ self.num_items = 0
34
+ self.num_features = None
35
+ self.all_features = None
36
+ self.raw_mean = None
37
+ self.raw_cov = None
38
+
39
+ def set_num_features(self, num_features: int):
40
+ '''
41
+ Set the number of feature dimensions.
42
+
43
+ Args:
44
+ num_features: Number of feature dimensions.
45
+ '''
46
+ if self.num_features is not None:
47
+ assert num_features == self.num_features
48
+ else:
49
+ self.num_features = num_features
50
+ self.all_features = []
51
+ self.raw_mean = np.zeros([num_features], dtype=np.float64)
52
+ self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
53
+
54
+ def is_full(self) -> bool:
55
+ '''
56
+ Check if the maximum number of samples is reached.
57
+
58
+ Returns:
59
+ True if the storage is full, False otherwise.
60
+ '''
61
+ return (self.max_items is not None) and (self.num_items >= self.max_items)
62
+
63
+ def append(self, x: np.ndarray):
64
+ '''
65
+ Add the newly computed features to the list. Update the mean and covariance.
66
+
67
+ Args:
68
+ x: New features to record.
69
+ '''
70
+ x = np.asarray(x, dtype=np.float32)
71
+ assert x.ndim == 2
72
+ if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
73
+ if self.num_items >= self.max_items:
74
+ return
75
+ x = x[:self.max_items - self.num_items]
76
+
77
+ self.set_num_features(x.shape[1])
78
+ self.num_items += x.shape[0]
79
+ if self.capture_all:
80
+ self.all_features.append(x)
81
+ if self.capture_mean_cov:
82
+ x64 = x.astype(np.float64)
83
+ self.raw_mean += x64.sum(axis=0)
84
+ self.raw_cov += x64.T @ x64
85
+
86
+ def append_torch(self, x: torch.Tensor, rank: int, num_gpus: int):
87
+ '''
88
+ Add the newly computed PyTorch features to the list. Update the mean and covariance.
89
+
90
+ Args:
91
+ x: New features to record.
92
+ rank: Rank of the current GPU.
93
+ num_gpus: Total number of GPUs.
94
+ '''
95
+ assert isinstance(x, torch.Tensor) and x.ndim == 2
96
+ assert 0 <= rank < num_gpus
97
+ if num_gpus > 1:
98
+ ys = []
99
+ for src in range(num_gpus):
100
+ y = x.clone()
101
+ torch.distributed.broadcast(y, src=src)
102
+ ys.append(y)
103
+ x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
104
+ self.append(x.cpu().numpy())
105
+
106
+ def get_all(self) -> np.ndarray:
107
+ '''
108
+ Get all the stored features as NumPy Array.
109
+
110
+ Returns:
111
+ Concatenation of the stored features.
112
+ '''
113
+ assert self.capture_all
114
+ return np.concatenate(self.all_features, axis=0)
115
+
116
+ def get_all_torch(self) -> torch.Tensor:
117
+ '''
118
+ Get all the stored features as PyTorch Tensor.
119
+
120
+ Returns:
121
+ Concatenation of the stored features.
122
+ '''
123
+ return torch.from_numpy(self.get_all())
124
+
125
+ def get_mean_cov(self) -> Tuple[np.ndarray, np.ndarray]:
126
+ '''
127
+ Get the mean and covariance of the stored features.
128
+
129
+ Returns:
130
+ Mean and covariance of the stored features.
131
+ '''
132
+ assert self.capture_mean_cov
133
+ mean = self.raw_mean / self.num_items
134
+ cov = self.raw_cov / self.num_items
135
+ cov = cov - np.outer(mean, mean)
136
+ return mean, cov
137
+
138
+ def save(self, pkl_file: str):
139
+ '''
140
+ Save the features and statistics to a pickle file.
141
+
142
+ Args:
143
+ pkl_file: Path to the pickle file.
144
+ '''
145
+ with open(pkl_file, 'wb') as f:
146
+ pickle.dump(self.__dict__, f)
147
+
148
+ @staticmethod
149
+ def load(pkl_file: str) -> 'FeatureStats':
150
+ '''
151
+ Load the features and statistics from a pickle file.
152
+
153
+ Args:
154
+ pkl_file: Path to the pickle file.
155
+ '''
156
+ with open(pkl_file, 'rb') as f:
157
+ s = pickle.load(f)
158
+ obj = FeatureStats(capture_all=s['capture_all'], max_items=s['max_items'])
159
+ obj.__dict__.update(s)
160
+ print('Loaded %d features from %s' % (obj.num_items, pkl_file))
161
+ return obj
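To make the accumulation API above concrete, a short sketch (not part of the commit) of collecting features and reading back statistics, as done in FID/FVD-style metrics; the 768-dimensional random features are placeholders for real network outputs.

    import numpy as np

    stats = FeatureStats(capture_all=True, capture_mean_cov=True, max_items=10_000)
    for _ in range(4):
        feats = np.random.randn(32, 768).astype(np.float32)  # stand-in for network features
        stats.append(feats)
    mean, cov = stats.get_mean_cov()   # running mean / covariance over all appended rows
    all_feats = stats.get_all()        # concatenated [num_items, 768] array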
latentsync/utils/affine_transform.py ADDED
@@ -0,0 +1,138 @@
1
+ # Adapted from https://github.com/guanjz20/StyleSync/blob/main/utils.py
2
+
3
+ import numpy as np
4
+ import cv2
5
+
6
+
7
+ def transformation_from_points(points1, points0, smooth=True, p_bias=None):
8
+ points2 = np.array(points0)
9
+ points2 = points2.astype(np.float64)
10
+ points1 = points1.astype(np.float64)
11
+ c1 = np.mean(points1, axis=0)
12
+ c2 = np.mean(points2, axis=0)
13
+ points1 -= c1
14
+ points2 -= c2
15
+ s1 = np.std(points1)
16
+ s2 = np.std(points2)
17
+ points1 /= s1
18
+ points2 /= s2
19
+ U, S, Vt = np.linalg.svd(np.matmul(points1.T, points2))
20
+ R = (np.matmul(U, Vt)).T
21
+ sR = (s2 / s1) * R
22
+ T = c2.reshape(2, 1) - (s2 / s1) * np.matmul(R, c1.reshape(2, 1))
23
+ M = np.concatenate((sR, T), axis=1)
24
+ if smooth:
25
+ bias = points2[2] - points1[2]
26
+ if p_bias is None:
27
+ p_bias = bias
28
+ else:
29
+ bias = p_bias * 0.2 + bias * 0.8
30
+ p_bias = bias
31
+ M[:, 2] = M[:, 2] + bias
32
+ return M, p_bias
33
+
34
+
35
+ class AlignRestore(object):
36
+ def __init__(self, align_points=3):
37
+ if align_points == 3:
38
+ self.upscale_factor = 1
39
+ self.crop_ratio = (2.8, 2.8)
40
+ self.face_template = np.array([[19 - 2, 30 - 10], [56 + 2, 30 - 10], [37.5, 45 - 5]])
41
+ self.face_template = self.face_template * 2.8
42
+ # self.face_size = (int(100 * self.crop_ratio[0]), int(100 * self.crop_ratio[1]))
43
+ self.face_size = (int(75 * self.crop_ratio[0]), int(100 * self.crop_ratio[1]))
44
+ self.p_bias = None
45
+
46
+ def process(self, img, lmk_align=None, smooth=True, align_points=3):
47
+ aligned_face, affine_matrix = self.align_warp_face(img, lmk_align, smooth)
48
+ restored_img = self.restore_img(img, aligned_face, affine_matrix)
49
+ cv2.imwrite("restored.jpg", restored_img)
50
+ cv2.imwrite("aligned.jpg", aligned_face)
51
+ return aligned_face, restored_img
52
+
53
+ def align_warp_face(self, img, lmks3, smooth=True, border_mode="constant"):
54
+ affine_matrix, self.p_bias = transformation_from_points(lmks3, self.face_template, smooth, self.p_bias)
55
+ if border_mode == "constant":
56
+ border_mode = cv2.BORDER_CONSTANT
57
+ elif border_mode == "reflect101":
58
+ border_mode = cv2.BORDER_REFLECT101
59
+ elif border_mode == "reflect":
60
+ border_mode = cv2.BORDER_REFLECT
61
+ cropped_face = cv2.warpAffine(
62
+ img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=[127, 127, 127]
63
+ )
64
+ return cropped_face, affine_matrix
65
+
66
+ def align_warp_face2(self, img, landmark, border_mode="constant"):
67
+ affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template)[0]
68
+ if border_mode == "constant":
69
+ border_mode = cv2.BORDER_CONSTANT
70
+ elif border_mode == "reflect101":
71
+ border_mode = cv2.BORDER_REFLECT101
72
+ elif border_mode == "reflect":
73
+ border_mode = cv2.BORDER_REFLECT
74
+ cropped_face = cv2.warpAffine(
75
+ img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)
76
+ )
77
+ return cropped_face, affine_matrix
78
+
79
+ def restore_img(self, input_img, face, affine_matrix):
80
+ h, w, _ = input_img.shape
81
+ h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
82
+ upsample_img = cv2.resize(input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
83
+ inverse_affine = cv2.invertAffineTransform(affine_matrix)
84
+ inverse_affine *= self.upscale_factor
85
+ if self.upscale_factor > 1:
86
+ extra_offset = 0.5 * self.upscale_factor
87
+ else:
88
+ extra_offset = 0
89
+ inverse_affine[:, 2] += extra_offset
90
+ inv_restored = cv2.warpAffine(face, inverse_affine, (w_up, h_up))
91
+ mask = np.ones((self.face_size[1], self.face_size[0]), dtype=np.float32)
92
+ inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
93
+ inv_mask_erosion = cv2.erode(
94
+ inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8)
95
+ )
96
+ pasted_face = inv_mask_erosion[:, :, None] * inv_restored
97
+ total_face_area = np.sum(inv_mask_erosion)
98
+ w_edge = int(total_face_area**0.5) // 20
99
+ erosion_radius = w_edge * 2
100
+ inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
101
+ blur_size = w_edge * 2
102
+ inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
103
+ inv_soft_mask = inv_soft_mask[:, :, None]
104
+ upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
105
+ if np.max(upsample_img) > 256:
106
+ upsample_img = upsample_img.astype(np.uint16)
107
+ else:
108
+ upsample_img = upsample_img.astype(np.uint8)
109
+ return upsample_img
110
+
111
+
112
+ class laplacianSmooth:
113
+ def __init__(self, smoothAlpha=0.3):
114
+ self.smoothAlpha = smoothAlpha
115
+ self.pts_last = None
116
+
117
+ def smooth(self, pts_cur):
118
+ if self.pts_last is None:
119
+ self.pts_last = pts_cur.copy()
120
+ return pts_cur.copy()
121
+ x1 = min(pts_cur[:, 0])
122
+ x2 = max(pts_cur[:, 0])
123
+ y1 = min(pts_cur[:, 1])
124
+ y2 = max(pts_cur[:, 1])
125
+ width = x2 - x1
126
+ pts_update = []
127
+ for i in range(len(pts_cur)):
128
+ x_new, y_new = pts_cur[i]
129
+ x_old, y_old = self.pts_last[i]
130
+ tmp = (x_new - x_old) ** 2 + (y_new - y_old) ** 2
131
+ w = np.exp(-tmp / (width * self.smoothAlpha))
132
+ x = x_old * w + x_new * (1 - w)
133
+ y = y_old * w + y_new * (1 - w)
134
+ pts_update.append([x, y])
135
+ pts_update = np.array(pts_update)
136
+ self.pts_last = pts_update.copy()
137
+
138
+ return pts_update
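A minimal sketch (not part of the commit) of aligning a face with `AlignRestore.align_warp_face` and pasting it back with `restore_img`; the image path and the three landmark points are made-up placeholders.

    import cv2
    import numpy as np

    restorer = AlignRestore()
    frame = cv2.imread("frame.jpg")  # placeholder input image (BGR)
    lmk3 = np.array([[210.0, 180.0], [290.0, 180.0], [250.0, 230.0]])  # mean left brow, right brow, nose points
    face, affine_matrix = restorer.align_warp_face(frame, lmks3=lmk3, smooth=True)
    restored = restorer.restore_img(frame, face, affine_matrix)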
latentsync/utils/audio.py ADDED
@@ -0,0 +1,194 @@
1
+ # Adapted from https://github.com/Rudrabha/Wav2Lip/blob/master/audio.py
2
+
3
+ import librosa
4
+ import librosa.filters
5
+ import numpy as np
6
+ from scipy import signal
7
+ from scipy.io import wavfile
8
+ from omegaconf import OmegaConf
9
+ import torch
10
+
11
+ audio_config_path = "configs/audio.yaml"
12
+
13
+ config = OmegaConf.load(audio_config_path)
14
+
15
+
16
+ def load_wav(path, sr):
17
+ return librosa.core.load(path, sr=sr)[0]
18
+
19
+
20
+ def save_wav(wav, path, sr):
21
+ wav *= 32767 / max(0.01, np.max(np.abs(wav)))
22
+ # proposed by @dsmiller
23
+ wavfile.write(path, sr, wav.astype(np.int16))
24
+
25
+
26
+ def save_wavenet_wav(wav, path, sr):
27
+ librosa.output.write_wav(path, wav, sr=sr)
28
+
29
+
30
+ def preemphasis(wav, k, preemphasize=True):
31
+ if preemphasize:
32
+ return signal.lfilter([1, -k], [1], wav)
33
+ return wav
34
+
35
+
36
+ def inv_preemphasis(wav, k, inv_preemphasize=True):
37
+ if inv_preemphasize:
38
+ return signal.lfilter([1], [1, -k], wav)
39
+ return wav
40
+
41
+
42
+ def get_hop_size():
43
+ hop_size = config.audio.hop_size
44
+ if hop_size is None:
45
+ assert config.audio.frame_shift_ms is not None
46
+ hop_size = int(config.audio.frame_shift_ms / 1000 * config.audio.sample_rate)
47
+ return hop_size
48
+
49
+
50
+ def linearspectrogram(wav):
51
+ D = _stft(preemphasis(wav, config.audio.preemphasis, config.audio.preemphasize))
52
+ S = _amp_to_db(np.abs(D)) - config.audio.ref_level_db
53
+
54
+ if config.audio.signal_normalization:
55
+ return _normalize(S)
56
+ return S
57
+
58
+
59
+ def melspectrogram(wav):
60
+ D = _stft(preemphasis(wav, config.audio.preemphasis, config.audio.preemphasize))
61
+ S = _amp_to_db(_linear_to_mel(np.abs(D))) - config.audio.ref_level_db
62
+
63
+ if config.audio.signal_normalization:
64
+ return _normalize(S)
65
+ return S
66
+
67
+
68
+ def _lws_processor():
69
+ import lws
70
+
71
+ return lws.lws(config.audio.n_fft, get_hop_size(), fftsize=config.audio.win_size, mode="speech")
72
+
73
+
74
+ def _stft(y):
75
+ if config.audio.use_lws:
76
+ return _lws_processor().stft(y).T
77
+ else:
78
+ return librosa.stft(y=y, n_fft=config.audio.n_fft, hop_length=get_hop_size(), win_length=config.audio.win_size)
79
+
80
+
81
+ ##########################################################
82
+ # Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
83
+ def num_frames(length, fsize, fshift):
84
+ """Compute number of time frames of spectrogram"""
85
+ pad = fsize - fshift
86
+ if length % fshift == 0:
87
+ M = (length + pad * 2 - fsize) // fshift + 1
88
+ else:
89
+ M = (length + pad * 2 - fsize) // fshift + 2
90
+ return M
91
+
92
+
93
+ def pad_lr(x, fsize, fshift):
94
+ """Compute left and right padding"""
95
+ M = num_frames(len(x), fsize, fshift)
96
+ pad = fsize - fshift
97
+ T = len(x) + 2 * pad
98
+ r = (M - 1) * fshift + fsize - T
99
+ return pad, pad + r
100
+
101
+
102
+ ##########################################################
103
+ # Librosa correct padding
104
+ def librosa_pad_lr(x, fsize, fshift):
105
+ return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
106
+
107
+
108
+ # Conversions
109
+ _mel_basis = None
110
+
111
+
112
+ def _linear_to_mel(spectogram):
113
+ global _mel_basis
114
+ if _mel_basis is None:
115
+ _mel_basis = _build_mel_basis()
116
+ return np.dot(_mel_basis, spectogram)
117
+
118
+
119
+ def _build_mel_basis():
120
+ assert config.audio.fmax <= config.audio.sample_rate // 2
121
+ return librosa.filters.mel(
122
+ sr=config.audio.sample_rate,
123
+ n_fft=config.audio.n_fft,
124
+ n_mels=config.audio.num_mels,
125
+ fmin=config.audio.fmin,
126
+ fmax=config.audio.fmax,
127
+ )
128
+
129
+
130
+ def _amp_to_db(x):
131
+ min_level = np.exp(config.audio.min_level_db / 20 * np.log(10))
132
+ return 20 * np.log10(np.maximum(min_level, x))
133
+
134
+
135
+ def _db_to_amp(x):
136
+ return np.power(10.0, (x) * 0.05)
137
+
138
+
139
+ def _normalize(S):
140
+ if config.audio.allow_clipping_in_normalization:
141
+ if config.audio.symmetric_mels:
142
+ return np.clip(
143
+ (2 * config.audio.max_abs_value) * ((S - config.audio.min_level_db) / (-config.audio.min_level_db))
144
+ - config.audio.max_abs_value,
145
+ -config.audio.max_abs_value,
146
+ config.audio.max_abs_value,
147
+ )
148
+ else:
149
+ return np.clip(
150
+ config.audio.max_abs_value * ((S - config.audio.min_level_db) / (-config.audio.min_level_db)),
151
+ 0,
152
+ config.audio.max_abs_value,
153
+ )
154
+
155
+ assert S.max() <= 0 and S.min() - config.audio.min_level_db >= 0
156
+ if config.audio.symmetric_mels:
157
+ return (2 * config.audio.max_abs_value) * (
158
+ (S - config.audio.min_level_db) / (-config.audio.min_level_db)
159
+ ) - config.audio.max_abs_value
160
+ else:
161
+ return config.audio.max_abs_value * ((S - config.audio.min_level_db) / (-config.audio.min_level_db))
162
+
163
+
164
+ def _denormalize(D):
165
+ if config.audio.allow_clipping_in_normalization:
166
+ if config.audio.symmetric_mels:
167
+ return (
168
+ (np.clip(D, -config.audio.max_abs_value, config.audio.max_abs_value) + config.audio.max_abs_value)
169
+ * -config.audio.min_level_db
170
+ / (2 * config.audio.max_abs_value)
171
+ ) + config.audio.min_level_db
172
+ else:
173
+ return (
174
+ np.clip(D, 0, config.audio.max_abs_value) * -config.audio.min_level_db / config.audio.max_abs_value
175
+ ) + config.audio.min_level_db
176
+
177
+ if config.audio.symmetric_mels:
178
+ return (
179
+ (D + config.audio.max_abs_value) * -config.audio.min_level_db / (2 * config.audio.max_abs_value)
180
+ ) + config.audio.min_level_db
181
+ else:
182
+ return (D * -config.audio.min_level_db / config.audio.max_abs_value) + config.audio.min_level_db
183
+
184
+
185
+ def get_melspec_overlap(audio_samples, melspec_length=52):
186
+ mel_spec_overlap = melspectrogram(audio_samples.numpy())
187
+ mel_spec_overlap = torch.from_numpy(mel_spec_overlap)
188
+ i = 0
189
+ mel_spec_overlap_list = []
190
+ while i + melspec_length < mel_spec_overlap.shape[1] - 3:
191
+ mel_spec_overlap_list.append(mel_spec_overlap[:, i : i + melspec_length].unsqueeze(0))
192
+ i += 3
193
+ mel_spec_overlap = torch.stack(mel_spec_overlap_list)
194
+ return mel_spec_overlap
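A two-line sketch (not part of the commit): importing this module reads `configs/audio.yaml`, so sample rate, mel count, and normalization all come from that file; the wav path below is a placeholder.

    wav = load_wav("speech.wav", sr=config.audio.sample_rate)  # placeholder path
    mel = melspectrogram(wav)                                  # [num_mels, num_frames], normalized per config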
latentsync/utils/av_reader.py ADDED
@@ -0,0 +1,157 @@
1
+ # We modified the original AVReader class of decord to solve the problem of memory leak.
2
+ # For more details, refer to: https://github.com/dmlc/decord/issues/208
3
+
4
+ import numpy as np
5
+ from decord.video_reader import VideoReader
6
+ from decord.audio_reader import AudioReader
7
+
8
+ from decord.ndarray import cpu
9
+ from decord import ndarray as _nd
10
+ from decord.bridge import bridge_out
11
+
12
+
13
+ class AVReader(object):
14
+ """Individual audio video reader with convenient indexing function.
15
+
16
+ Parameters
17
+ ----------
18
+ uri: str
19
+ Path of file.
20
+ ctx: decord.Context
21
+ The context to decode the file, can be decord.cpu() or decord.gpu().
22
+ sample_rate: int, default is 44100
23
+ Desired output sample rate of the audio, unchanged if `-1` is specified.
24
+ mono: bool, default is True
25
+ Desired output channel layout of the audio. `True` is mono layout. `False` is unchanged.
26
+ width : int, default is -1
27
+ Desired output width of the video, unchanged if `-1` is specified.
28
+ height : int, default is -1
29
+ Desired output height of the video, unchanged if `-1` is specified.
30
+ num_threads : int, default is 0
31
+ Number of decoding thread, auto if `0` is specified.
32
+ fault_tol : int, default is -1
33
+ The threshold of corrupted and recovered frames. This is to prevent silent fault
34
+ tolerance when for example 50% frames of a video cannot be decoded and duplicate
35
+ frames are returned. You may find the fault tolerant feature sweet in many cases,
36
+ but not for training models. Say `N = # recovered frames`
37
+ If `fault_tol` < 0, nothing will happen.
38
+ If 0 < `fault_tol` < 1.0, if N > `fault_tol * len(video)`, raise `DECORDLimitReachedError`.
39
+ If 1 < `fault_tol`, if N > `fault_tol`, raise `DECORDLimitReachedError`.
40
+ """
41
+
42
+ def __init__(
43
+ self, uri, ctx=cpu(0), sample_rate=44100, mono=True, width=-1, height=-1, num_threads=0, fault_tol=-1
44
+ ):
45
+ self.__audio_reader = AudioReader(uri, ctx, sample_rate, mono)
46
+ self.__audio_reader.add_padding()
47
+ if hasattr(uri, "read"):
48
+ uri.seek(0)
49
+ self.__video_reader = VideoReader(uri, ctx, width, height, num_threads, fault_tol)
50
+ self.__video_reader.seek(0)
51
+
52
+ def __len__(self):
53
+ """Get length of the video. Note that sometimes FFMPEG reports inaccurate number of frames,
54
+ we always follow what FFMPEG reports.
55
+ Returns
56
+ -------
57
+ int
58
+ The number of frames in the video file.
59
+ """
60
+ return len(self.__video_reader)
61
+
62
+ def __getitem__(self, idx):
63
+ """Get audio samples and video frame at `idx`.
64
+
65
+ Parameters
66
+ ----------
67
+ idx : int or slice
68
+ The frame index, can be negative which means it will index backwards,
69
+ or slice of frame indices.
70
+
71
+ Returns
72
+ -------
73
+ (ndarray/list of ndarray, ndarray)
74
+ First element is samples of shape CxS or a list of length N containing samples of shape CxS,
75
+ where N is the number of frames, C is the number of channels,
76
+ S is the number of samples of the corresponding frame.
77
+
78
+ Second element is Frame of shape HxWx3 or batch of image frames with shape NxHxWx3,
79
+ where N is the length of the slice.
80
+ """
81
+ assert self.__video_reader is not None and self.__audio_reader is not None
82
+ if isinstance(idx, slice):
83
+ return self.get_batch(range(*idx.indices(len(self.__video_reader))))
84
+ if idx < 0:
85
+ idx += len(self.__video_reader)
86
+ if idx >= len(self.__video_reader) or idx < 0:
87
+ raise IndexError("Index: {} out of bound: {}".format(idx, len(self.__video_reader)))
88
+ audio_start_idx, audio_end_idx = self.__video_reader.get_frame_timestamp(idx)
89
+ audio_start_idx = self.__audio_reader._time_to_sample(audio_start_idx)
90
+ audio_end_idx = self.__audio_reader._time_to_sample(audio_end_idx)
91
+ results = (self.__audio_reader[audio_start_idx:audio_end_idx], self.__video_reader[idx])
92
+ self.__video_reader.seek(0)
93
+ return results
94
+
95
+ def get_batch(self, indices):
96
+ """Get entire batch of audio samples and video frames.
97
+
98
+ Parameters
99
+ ----------
100
+ indices : list of integers
101
+ A list of frame indices. If negative indices detected, the indices will be indexed from backward
102
+ Returns
103
+ -------
104
+ (list of ndarray, ndarray)
105
+ First element is a list of length N containing samples of shape CxS,
106
+ where N is the number of frames, C is the number of channels,
107
+ S is the number of samples of the corresponding frame.
108
+
109
+ Second element is Frame of shape HxWx3 or batch of image frames with shape NxHxWx3,
110
+ where N is the length of the slice.
111
+
112
+ """
113
+ assert self.__video_reader is not None and self.__audio_reader is not None
114
+ indices = self._validate_indices(indices)
115
+ audio_arr = []
116
+ prev_video_idx = None
117
+ prev_audio_end_idx = None
118
+ for idx in list(indices):
119
+ frame_start_time, frame_end_time = self.__video_reader.get_frame_timestamp(idx)
120
+ # timestamp and sample conversion could have some error that could cause non-continuous audio
121
+ # we detect if retrieving continuous frame and make the audio continuous
122
+ if prev_video_idx and idx == prev_video_idx + 1:
123
+ audio_start_idx = prev_audio_end_idx
124
+ else:
125
+ audio_start_idx = self.__audio_reader._time_to_sample(frame_start_time)
126
+ audio_end_idx = self.__audio_reader._time_to_sample(frame_end_time)
127
+ audio_arr.append(self.__audio_reader[audio_start_idx:audio_end_idx])
128
+ prev_video_idx = idx
129
+ prev_audio_end_idx = audio_end_idx
130
+ results = (audio_arr, self.__video_reader.get_batch(indices))
131
+ self.__video_reader.seek(0)
132
+ return results
133
+
134
+ def _get_slice(self, sl):
135
+ audio_arr = np.empty(shape=(self.__audio_reader.shape()[0], 0), dtype="float32")
136
+ for idx in list(sl):
137
+ audio_start_idx, audio_end_idx = self.__video_reader.get_frame_timestamp(idx)
138
+ audio_start_idx = self.__audio_reader._time_to_sample(audio_start_idx)
139
+ audio_end_idx = self.__audio_reader._time_to_sample(audio_end_idx)
140
+ audio_arr = np.concatenate(
141
+ (audio_arr, self.__audio_reader[audio_start_idx:audio_end_idx].asnumpy()), axis=1
142
+ )
143
+ results = (bridge_out(_nd.array(audio_arr)), self.__video_reader.get_batch(sl))
144
+ self.__video_reader.seek(0)
145
+ return results
146
+
147
+ def _validate_indices(self, indices):
148
+ """Validate int64 integers and convert negative integers to positive by backward search"""
149
+ assert self.__video_reader is not None and self.__audio_reader is not None
150
+ indices = np.array(indices, dtype=np.int64)
151
+ # process negative indices
152
+ indices[indices < 0] += len(self.__video_reader)
153
+ if not (indices >= 0).all():
154
+ raise IndexError("Invalid negative indices: {}".format(indices[indices < 0] + len(self.__video_reader)))
155
+ if not (indices < len(self.__video_reader)).all():
156
+ raise IndexError("Out of bound indices: {}".format(indices[indices >= len(self.__video_reader)]))
157
+ return indices
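An illustrative sketch (not part of the commit) of reading frame-aligned audio and video with the patched reader; the file path and sample rate are placeholders.

    av = AVReader("example.mp4", sample_rate=16000)
    audio_chunks, frames = av.get_batch(list(range(16)))  # per-frame audio samples + 16 decoded frames
    print(len(audio_chunks), frames.shape)                # 16, (16, H, W, 3)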
latentsync/utils/image_processor.py ADDED
@@ -0,0 +1,342 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from torchvision import transforms
16
+ import cv2
17
+ from einops import rearrange
18
+ import mediapipe as mp
19
+ import torch
20
+ import numpy as np
21
+ from typing import Union
22
+ from .affine_transform import AlignRestore, laplacianSmooth
23
+ import face_alignment
24
+
25
+ """
26
+ If you are enlarging the image, you should prefer to use INTER_LINEAR or INTER_CUBIC interpolation. If you are shrinking the image, you should prefer to use INTER_AREA interpolation.
27
+ https://stackoverflow.com/questions/23853632/which-kind-of-interpolation-best-for-resizing-image
28
+ """
29
+
30
+
31
+ def load_fixed_mask(resolution: int) -> torch.Tensor:
32
+ mask_image = cv2.imread("latentsync/utils/mask.png")
33
+ mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB)
34
+ mask_image = cv2.resize(mask_image, (resolution, resolution), interpolation=cv2.INTER_AREA) / 255.0
35
+ mask_image = rearrange(torch.from_numpy(mask_image), "h w c -> c h w")
36
+ return mask_image
37
+
38
+
39
+ class ImageProcessor:
40
+ def __init__(self, resolution: int = 512, mask: str = "fix_mask", device: str = "cpu", mask_image=None):
41
+ self.resolution = resolution
42
+ self.resize = transforms.Resize(
43
+ (resolution, resolution), interpolation=transforms.InterpolationMode.BILINEAR, antialias=True
44
+ )
45
+ self.normalize = transforms.Normalize([0.5], [0.5], inplace=True)
46
+ self.mask = mask
47
+
48
+ if mask in ["mouth", "face", "eye"]:
49
+ self.face_mesh = mp.solutions.face_mesh.FaceMesh(static_image_mode=True) # Process single image
50
+ if mask == "fix_mask":
51
+ self.face_mesh = None
52
+ self.smoother = laplacianSmooth()
53
+ self.restorer = AlignRestore()
54
+
55
+ if mask_image is None:
56
+ self.mask_image = load_fixed_mask(resolution)
57
+ else:
58
+ self.mask_image = mask_image
59
+
60
+ if device != "cpu":
61
+ self.fa = face_alignment.FaceAlignment(
62
+ face_alignment.LandmarksType.TWO_D, flip_input=False, device=device
63
+ )
64
+ self.face_mesh = None
65
+ else:
66
+ # self.face_mesh = mp.solutions.face_mesh.FaceMesh(static_image_mode=True) # Process single image
67
+ self.face_mesh = None
68
+ self.fa = None
69
+
70
+ def detect_facial_landmarks(self, image: np.ndarray):
71
+ height, width, _ = image.shape
72
+ results = self.face_mesh.process(image)
73
+ if not results.multi_face_landmarks: # Face not detected
74
+ raise RuntimeError("Face not detected")
75
+ face_landmarks = results.multi_face_landmarks[0] # Only use the first face in the image
76
+ landmark_coordinates = [
77
+ (int(landmark.x * width), int(landmark.y * height)) for landmark in face_landmarks.landmark
78
+ ] # x means width, y means height
79
+ return landmark_coordinates
80
+
81
+ def preprocess_one_masked_image(self, image: torch.Tensor) -> np.ndarray:
82
+ image = self.resize(image)
83
+
84
+ if self.mask == "mouth" or self.mask == "face":
85
+ landmark_coordinates = self.detect_facial_landmarks(image)
86
+ if self.mask == "mouth":
87
+ surround_landmarks = mouth_surround_landmarks
88
+ else:
89
+ surround_landmarks = face_surround_landmarks
90
+
91
+ points = [landmark_coordinates[landmark] for landmark in surround_landmarks]
92
+ points = np.array(points)
93
+ mask = np.ones((self.resolution, self.resolution))
94
+ mask = cv2.fillPoly(mask, pts=[points], color=(0, 0, 0))
95
+ mask = torch.from_numpy(mask)
96
+ mask = mask.unsqueeze(0)
97
+ elif self.mask == "half":
98
+ mask = torch.ones((self.resolution, self.resolution))
99
+ height = mask.shape[0]
100
+ mask[height // 2 :, :] = 0
101
+ mask = mask.unsqueeze(0)
102
+ elif self.mask == "eye":
103
+ mask = torch.ones((self.resolution, self.resolution))
104
+ landmark_coordinates = self.detect_facial_landmarks(image)
105
+ y = landmark_coordinates[195][1]
106
+ mask[y:, :] = 0
107
+ mask = mask.unsqueeze(0)
108
+ else:
109
+ raise ValueError("Invalid mask type")
110
+
111
+ image = image.to(dtype=torch.float32)
112
+ pixel_values = self.normalize(image / 255.0)
113
+ masked_pixel_values = pixel_values * mask
114
+ mask = 1 - mask
115
+
116
+ return pixel_values, masked_pixel_values, mask
117
+
118
+ def affine_transform(self, image: torch.Tensor) -> np.ndarray:
119
+ # image = rearrange(image, "c h w-> h w c").numpy()
120
+ if self.fa is None:
121
+ landmark_coordinates = np.array(self.detect_facial_landmarks(image))
122
+ lm68 = mediapipe_lm478_to_face_alignment_lm68(landmark_coordinates)
123
+ else:
124
+ detected_faces = self.fa.get_landmarks(image)
125
+ if detected_faces is None:
126
+ raise RuntimeError("Face not detected")
127
+ lm68 = detected_faces[0]
128
+
129
+ points = self.smoother.smooth(lm68)
130
+ lmk3_ = np.zeros((3, 2))
131
+ lmk3_[0] = points[17:22].mean(0)
132
+ lmk3_[1] = points[22:27].mean(0)
133
+ lmk3_[2] = points[27:36].mean(0)
134
+ # print(lmk3_)
135
+ face, affine_matrix = self.restorer.align_warp_face(
136
+ image.copy(), lmks3=lmk3_, smooth=True, border_mode="constant"
137
+ )
138
+ box = [0, 0, face.shape[1], face.shape[0]] # x1, y1, x2, y2
139
+ face = cv2.resize(face, (self.resolution, self.resolution), interpolation=cv2.INTER_CUBIC)
140
+ face = rearrange(torch.from_numpy(face), "h w c -> c h w")
141
+ return face, box, affine_matrix
142
+
143
+ def preprocess_fixed_mask_image(self, image: torch.Tensor, affine_transform=False):
144
+ if affine_transform:
145
+ image, _, _ = self.affine_transform(image)
146
+ else:
147
+ image = self.resize(image)
148
+ pixel_values = self.normalize(image / 255.0)
149
+ masked_pixel_values = pixel_values * self.mask_image
150
+ return pixel_values, masked_pixel_values, self.mask_image[0:1]
151
+
152
+ def prepare_masks_and_masked_images(self, images: Union[torch.Tensor, np.ndarray], affine_transform=False):
153
+ if isinstance(images, np.ndarray):
154
+ images = torch.from_numpy(images)
155
+ if images.shape[3] == 3:
156
+ images = rearrange(images, "b h w c -> b c h w")
157
+ if self.mask == "fix_mask":
158
+ results = [self.preprocess_fixed_mask_image(image, affine_transform=affine_transform) for image in images]
159
+ else:
160
+ results = [self.preprocess_one_masked_image(image) for image in images]
161
+
162
+ pixel_values_list, masked_pixel_values_list, masks_list = list(zip(*results))
163
+ return torch.stack(pixel_values_list), torch.stack(masked_pixel_values_list), torch.stack(masks_list)
164
+
165
+ def process_images(self, images: Union[torch.Tensor, np.ndarray]):
166
+ if isinstance(images, np.ndarray):
167
+ images = torch.from_numpy(images)
168
+ if images.shape[3] == 3:
169
+ images = rearrange(images, "b h w c -> b c h w")
170
+ images = self.resize(images)
171
+ pixel_values = self.normalize(images / 255.0)
172
+ return pixel_values
173
+
174
+ def close(self):
175
+ if self.face_mesh is not None:
176
+ self.face_mesh.close()
177
+
178
+
179
+ def mediapipe_lm478_to_face_alignment_lm68(lm478, return_2d=True):
180
+ """
181
+ lm478: [B, 478, 3] or [478,3]
182
+ """
183
+ # lm478[..., 0] *= W
184
+ # lm478[..., 1] *= H
185
+ landmarks_extracted = []
186
+ for index in landmark_points_68:
187
+ x = lm478[index][0]
188
+ y = lm478[index][1]
189
+ landmarks_extracted.append((x, y))
190
+ return np.array(landmarks_extracted)
191
+
192
+
193
+ landmark_points_68 = [
194
+ 162,
195
+ 234,
196
+ 93,
197
+ 58,
198
+ 172,
199
+ 136,
200
+ 149,
201
+ 148,
202
+ 152,
203
+ 377,
204
+ 378,
205
+ 365,
206
+ 397,
207
+ 288,
208
+ 323,
209
+ 454,
210
+ 389,
211
+ 71,
212
+ 63,
213
+ 105,
214
+ 66,
215
+ 107,
216
+ 336,
217
+ 296,
218
+ 334,
219
+ 293,
220
+ 301,
221
+ 168,
222
+ 197,
223
+ 5,
224
+ 4,
225
+ 75,
226
+ 97,
227
+ 2,
228
+ 326,
229
+ 305,
230
+ 33,
231
+ 160,
232
+ 158,
233
+ 133,
234
+ 153,
235
+ 144,
236
+ 362,
237
+ 385,
238
+ 387,
239
+ 263,
240
+ 373,
241
+ 380,
242
+ 61,
243
+ 39,
244
+ 37,
245
+ 0,
246
+ 267,
247
+ 269,
248
+ 291,
249
+ 405,
250
+ 314,
251
+ 17,
252
+ 84,
253
+ 181,
254
+ 78,
255
+ 82,
256
+ 13,
257
+ 312,
258
+ 308,
259
+ 317,
260
+ 14,
261
+ 87,
262
+ ]
263
+
264
+
265
+ # Refer to https://storage.googleapis.com/mediapipe-assets/documentation/mediapipe_face_landmark_fullsize.png
266
+ mouth_surround_landmarks = [
267
+ 164,
268
+ 165,
269
+ 167,
270
+ 92,
271
+ 186,
272
+ 57,
273
+ 43,
274
+ 106,
275
+ 182,
276
+ 83,
277
+ 18,
278
+ 313,
279
+ 406,
280
+ 335,
281
+ 273,
282
+ 287,
283
+ 410,
284
+ 322,
285
+ 391,
286
+ 393,
287
+ ]
288
+
289
+ face_surround_landmarks = [
290
+ 152,
291
+ 377,
292
+ 400,
293
+ 378,
294
+ 379,
295
+ 365,
296
+ 397,
297
+ 288,
298
+ 435,
299
+ 433,
300
+ 411,
301
+ 425,
302
+ 423,
303
+ 327,
304
+ 326,
305
+ 94,
306
+ 97,
307
+ 98,
308
+ 203,
309
+ 205,
310
+ 187,
311
+ 213,
312
+ 215,
313
+ 58,
314
+ 172,
315
+ 136,
316
+ 150,
317
+ 149,
318
+ 176,
319
+ 148,
320
+ ]
321
+
322
+ if __name__ == "__main__":
323
+ image_processor = ImageProcessor(512, mask="fix_mask")
324
+ video = cv2.VideoCapture("/mnt/bn/maliva-gen-ai-v2/chunyu.li/HDTF/original/val/RD_Radio57_000.mp4")
325
+ while True:
326
+ ret, frame = video.read()
327
+ # if not ret:
328
+ # break
329
+
330
+ # cv2.imwrite("image.jpg", frame)
331
+
332
+ frame = rearrange(torch.Tensor(frame).type(torch.uint8), "h w c -> c h w")
333
+ # face, masked_face, _ = image_processor.preprocess_fixed_mask_image(frame, affine_transform=True)
334
+ face, _, _ = image_processor.affine_transform(frame)
335
+
336
+ break
337
+
338
+ face = (rearrange(face, "c h w -> h w c").detach().cpu().numpy()).astype(np.uint8)
339
+ cv2.imwrite("face.jpg", face)
340
+
341
+ # masked_face = (rearrange(masked_face, "c h w -> h w c").detach().cpu().numpy()).astype(np.uint8)
342
+ # cv2.imwrite("masked_face.jpg", masked_face)
latentsync/utils/mask.png ADDED
latentsync/utils/util.py ADDED
@@ -0,0 +1,365 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import imageio
17
+ import numpy as np
18
+ import json
19
+ from typing import Union
20
+ import matplotlib.pyplot as plt
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ import torchvision
26
+ import torch.distributed as dist
27
+ from torchvision import transforms
28
+
29
+ from tqdm import tqdm
30
+ from einops import rearrange
31
+ import cv2
32
+ from decord import AudioReader, VideoReader
33
+ import shutil
34
+ import subprocess
35
+
36
+
37
+ # Machine epsilon for a float32 (single precision)
38
+ eps = np.finfo(np.float32).eps
39
+
40
+
41
+ def read_json(filepath: str):
42
+ with open(filepath) as f:
43
+ json_dict = json.load(f)
44
+ return json_dict
45
+
46
+
47
+ def read_video(video_path: str, change_fps=True, use_decord=True):
48
+ if change_fps:
49
+ temp_dir = "temp"
50
+ if os.path.exists(temp_dir):
51
+ shutil.rmtree(temp_dir)
52
+ os.makedirs(temp_dir, exist_ok=True)
53
+ command = (
54
+ f"ffmpeg -loglevel error -y -nostdin -i {video_path} -r 25 -crf 18 {os.path.join(temp_dir, 'video.mp4')}"
55
+ )
56
+ subprocess.run(command, shell=True)
57
+ target_video_path = os.path.join(temp_dir, "video.mp4")
58
+ else:
59
+ target_video_path = video_path
60
+
61
+ if use_decord:
62
+ return read_video_decord(target_video_path)
63
+ else:
64
+ return read_video_cv2(target_video_path)
65
+
66
+
67
+ def read_video_decord(video_path: str):
68
+ vr = VideoReader(video_path)
69
+ video_frames = vr[:].asnumpy()
70
+ vr.seek(0)
71
+ return video_frames
72
+
73
+
74
+ def read_video_cv2(video_path: str):
75
+ # Open the video file
76
+ cap = cv2.VideoCapture(video_path)
77
+
78
+ # Check if the video was opened successfully
79
+ if not cap.isOpened():
80
+ print("Error: Could not open video.")
81
+ return np.array([])
82
+
83
+ frames = []
84
+
85
+ while True:
86
+ # Read a frame
87
+ ret, frame = cap.read()
88
+
89
+ # If frame is read correctly ret is True
90
+ if not ret:
91
+ break
92
+
93
+ # Convert BGR to RGB
94
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
95
+
96
+ frames.append(frame_rgb)
97
+
98
+ # Release the video capture object
99
+ cap.release()
100
+
101
+ return np.array(frames)
102
+
103
+
104
+ def read_audio(audio_path: str, audio_sample_rate: int = 16000):
105
+ if audio_path is None:
106
+ raise ValueError("Audio path is required.")
107
+ ar = AudioReader(audio_path, sample_rate=audio_sample_rate, mono=True)
108
+
109
+ # To access the audio samples
110
+ audio_samples = torch.from_numpy(ar[:].asnumpy())
111
+ audio_samples = audio_samples.squeeze(0)
112
+
113
+ return audio_samples
114
+
115
+
116
+ def write_video(video_output_path: str, video_frames: np.ndarray, fps: int):
117
+ height, width = video_frames[0].shape[:2]
118
+ out = cv2.VideoWriter(video_output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
119
+ # out = cv2.VideoWriter(video_output_path, cv2.VideoWriter_fourcc(*"vp09"), fps, (width, height))
120
+ for frame in video_frames:
121
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
122
+ out.write(frame)
123
+ out.release()
124
+
125
+
126
+ def init_dist(backend="nccl", **kwargs):
127
+ """Initializes distributed environment."""
128
+ rank = int(os.environ["RANK"])
129
+ num_gpus = torch.cuda.device_count()
130
+ if num_gpus == 0:
131
+ raise RuntimeError("No GPUs available for training.")
132
+ local_rank = rank % num_gpus
133
+ torch.cuda.set_device(local_rank)
134
+ dist.init_process_group(backend=backend, **kwargs)
135
+
136
+ return local_rank
137
+
138
+
139
+ def zero_rank_print(s):
140
+ if dist.is_initialized() and dist.get_rank() == 0:
141
+ print("### " + s)
142
+
143
+
144
+ def zero_rank_log(logger, message: str):
145
+ if dist.is_initialized() and dist.get_rank() == 0:
146
+ logger.info(message)
147
+
148
+
149
+ def make_audio_window(audio_embeddings: torch.Tensor, window_size: int):
150
+ audio_window = []
151
+ end_idx = audio_embeddings.shape[1] - window_size + 1
152
+ for i in range(end_idx):
153
+ audio_window.append(audio_embeddings[:, i : i + window_size, :])
154
+ audio_window = torch.stack(audio_window)
155
+ audio_window = rearrange(audio_window, "f b w d -> b f w d")
156
+ return audio_window
157
+
158
+
159
+ def check_video_fps(video_path: str):
160
+ cam = cv2.VideoCapture(video_path)
161
+ fps = cam.get(cv2.CAP_PROP_FPS)
162
+ if fps != 25:
163
+ raise ValueError(f"Video FPS is not 25, it is {fps}. Please convert the video to 25 FPS.")
164
+
165
+
166
+ def tailor_tensor_to_length(tensor: torch.Tensor, length: int):
167
+ if len(tensor) == length:
168
+ return tensor
169
+ elif len(tensor) > length:
170
+ return tensor[:length]
171
+ else:
172
+ return torch.cat([tensor, tensor[-1].repeat(length - len(tensor))])
173
+
174
+
175
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
176
+ videos = rearrange(videos, "b c f h w -> f b c h w")
177
+ outputs = []
178
+ for x in videos:
179
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
180
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
181
+ if rescale:
182
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
183
+ x = (x * 255).numpy().astype(np.uint8)
184
+ outputs.append(x)
185
+
186
+ os.makedirs(os.path.dirname(path), exist_ok=True)
187
+ imageio.mimsave(path, outputs, fps=fps)
188
+
189
+
190
+ def interpolate_features(features: torch.Tensor, output_len: int) -> torch.Tensor:
191
+ features = features.cpu().numpy()
192
+ input_len, num_features = features.shape
193
+
194
+ input_timesteps = np.linspace(0, 10, input_len)
195
+ output_timesteps = np.linspace(0, 10, output_len)
196
+ output_features = np.zeros((output_len, num_features))
197
+ for feat in range(num_features):
198
+ output_features[:, feat] = np.interp(output_timesteps, input_timesteps, features[:, feat])
199
+ return torch.from_numpy(output_features)
200
+
201
+
202
+ # DDIM Inversion
203
+ @torch.no_grad()
204
+ def init_prompt(prompt, pipeline):
205
+ uncond_input = pipeline.tokenizer(
206
+ [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, return_tensors="pt"
207
+ )
208
+ uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
209
+ text_input = pipeline.tokenizer(
210
+ [prompt],
211
+ padding="max_length",
212
+ max_length=pipeline.tokenizer.model_max_length,
213
+ truncation=True,
214
+ return_tensors="pt",
215
+ )
216
+ text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
217
+ context = torch.cat([uncond_embeddings, text_embeddings])
218
+
219
+ return context
220
+
221
+
222
+ def reversed_forward(ddim_scheduler, pred_noise, timesteps, x_t):
223
+ # Compute alphas, betas
224
+ alpha_prod_t = ddim_scheduler.alphas_cumprod[timesteps]
225
+ beta_prod_t = 1 - alpha_prod_t
226
+
227
+ # 3. compute predicted original sample from predicted noise also called
228
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
229
+ if ddim_scheduler.config.prediction_type == "epsilon":
230
+ beta_prod_t = beta_prod_t[:, None, None, None, None]
231
+ alpha_prod_t = alpha_prod_t[:, None, None, None, None]
232
+ pred_original_sample = (x_t - beta_prod_t ** (0.5) * pred_noise) / alpha_prod_t ** (0.5)
233
+ else:
234
+ raise NotImplementedError("This prediction type is not implemented yet")
235
+
236
+ # Clip "predicted x_0"
237
+ if ddim_scheduler.config.clip_sample:
238
+ pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
239
+ return pred_original_sample
240
+
241
+
242
+ def next_step(
243
+ model_output: Union[torch.FloatTensor, np.ndarray],
244
+ timestep: int,
245
+ sample: Union[torch.FloatTensor, np.ndarray],
246
+ ddim_scheduler,
247
+ ):
248
+ timestep, next_timestep = (
249
+ min(timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999),
250
+ timestep,
251
+ )
252
+ alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
253
+ alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
254
+ beta_prod_t = 1 - alpha_prod_t
255
+ next_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
256
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
257
+ next_sample = alpha_prod_t_next**0.5 * next_original_sample + next_sample_direction
258
+ return next_sample
259
+
260
+
261
+ def get_noise_pred_single(latents, t, context, unet):
262
+ noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
263
+ return noise_pred
264
+
265
+
266
+ @torch.no_grad()
267
+ def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
268
+ context = init_prompt(prompt, pipeline)
269
+ uncond_embeddings, cond_embeddings = context.chunk(2)
270
+ all_latent = [latent]
271
+ latent = latent.clone().detach()
272
+ for i in tqdm(range(num_inv_steps)):
273
+ t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
274
+ noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
275
+ latent = next_step(noise_pred, t, latent, ddim_scheduler)
276
+ all_latent.append(latent)
277
+ return all_latent
278
+
279
+
280
+ @torch.no_grad()
281
+ def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
282
+ ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
283
+ return ddim_latents
284
+
285
+
286
+ def plot_loss_chart(save_path: str, *args):
287
+ # Creating the plot
288
+ plt.figure()
289
+ for loss_line in args:
290
+ plt.plot(loss_line[1], loss_line[2], label=loss_line[0])
291
+ plt.xlabel("Step")
292
+ plt.ylabel("Loss")
293
+ plt.legend()
294
+
295
+ # Save the figure to a file
296
+ plt.savefig(save_path)
297
+
298
+ # Close the figure to free memory
299
+ plt.close()
300
+
301
+
302
+ CRED = "\033[91m"
303
+ CEND = "\033[0m"
304
+
305
+
306
+ def red_text(text: str):
307
+ return f"{CRED}{text}{CEND}"
308
+
309
+
310
+ log_loss = nn.BCELoss(reduction="none")
311
+
312
+
313
+ def cosine_loss(vision_embeds, audio_embeds, y):
314
+ sims = nn.functional.cosine_similarity(vision_embeds, audio_embeds)
315
+ # sims[sims!=sims] = 0 # remove nan
316
+ # sims = sims.clamp(0, 1)
317
+ loss = log_loss(sims.unsqueeze(1), y).squeeze()
318
+ return loss
319
+
320
+
321
+ def save_image(image, save_path):
322
+ # input size (C, H, W)
323
+ image = (image / 2 + 0.5).clamp(0, 1)
324
+ image = (image * 255).to(torch.uint8)
325
+ image = transforms.ToPILImage()(image)
326
+ # Save the image copy
327
+ image.save(save_path)
328
+
329
+ # Close the image file
330
+ image.close()
331
+
332
+
333
+ def gather_loss(loss, device):
334
+ # Sum the local loss across all processes
335
+ local_loss = loss.item()
336
+ global_loss = torch.tensor(local_loss, dtype=torch.float32).to(device)
337
+ dist.all_reduce(global_loss, op=dist.ReduceOp.SUM)
338
+
339
+ # Calculate the average loss across all processes
340
+ global_average_loss = global_loss.item() / dist.get_world_size()
341
+ return global_average_loss
342
+
343
+
344
+ def gather_video_paths_recursively(input_dir):
345
+ print(f"Recursively gathering video paths of {input_dir} ...")
346
+ paths = []
347
+ gather_video_paths(input_dir, paths)
348
+ return paths
349
+
350
+
351
+ def gather_video_paths(input_dir, paths):
352
+ for file in sorted(os.listdir(input_dir)):
353
+ if file.endswith(".mp4"):
354
+ filepath = os.path.join(input_dir, file)
355
+ paths.append(filepath)
356
+ elif os.path.isdir(os.path.join(input_dir, file)):
357
+ gather_video_paths(os.path.join(input_dir, file), paths)
358
+
359
+
360
+ def count_video_time(video_path):
361
+ video = cv2.VideoCapture(video_path)
362
+
363
+ frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
364
+ fps = video.get(cv2.CAP_PROP_FPS)
365
+ return frame_count / fps
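
util.py bundles the video/audio I/O used across training and inference. The following is a minimal sketch, not part of the diff, of how read_video, read_audio, write_video, and check_video_fps fit together; the paths are placeholders, and it assumes ffmpeg, decord, and OpenCV are installed.

```python
# Illustrative only: round-tripping a clip through the I/O helpers in latentsync/utils/util.py.
from latentsync.utils.util import read_video, read_audio, write_video, check_video_fps

video_path = "input.mp4"  # placeholder
audio_path = "input.wav"  # placeholder

# change_fps=True first re-encodes to 25 fps into a temp/ folder, matching the
# 25 fps assumption used throughout the repo.
frames = read_video(video_path, change_fps=True, use_decord=True)  # (num_frames, H, W, 3) RGB uint8
audio = read_audio(audio_path, audio_sample_rate=16000)            # 1-D float tensor at 16 kHz
print(frames.shape, audio.shape)

write_video("output.mp4", frames, fps=25)  # expects RGB frames, converts to BGR internally
check_video_fps("output.mp4")              # raises if the written file is not 25 fps
```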
latentsync/whisper/audio2feature.py ADDED
@@ -0,0 +1,166 @@
1
+ # Adapted from https://github.com/TMElyralab/MuseTalk/blob/main/musetalk/whisper/audio2feature.py
2
+
3
+ from .whisper import load_model
4
+ import numpy as np
5
+ import torch
6
+ import os
7
+
8
+
9
+ class Audio2Feature:
10
+ def __init__(
11
+ self,
12
+ model_path="checkpoints/whisper/tiny.pt",
13
+ device=None,
14
+ audio_embeds_cache_dir=None,
15
+ num_frames=16,
16
+ ):
17
+ self.model = load_model(model_path, device)
18
+ self.audio_embeds_cache_dir = audio_embeds_cache_dir
19
+ self.num_frames = num_frames
20
+ self.embedding_dim = self.model.dims.n_audio_state
21
+
22
+ def get_sliced_feature(self, feature_array, vid_idx, audio_feat_length=[2, 2], fps=25):
23
+ """
24
+ Get sliced features based on a given index
25
+ :param feature_array:
26
+ :param start_idx: the start index of the feature
27
+ :param audio_feat_length:
28
+ :return:
29
+ """
30
+ length = len(feature_array)
31
+ selected_feature = []
32
+ selected_idx = []
33
+
34
+ center_idx = int(vid_idx * 50 / fps)
35
+ left_idx = center_idx - audio_feat_length[0] * 2
36
+ right_idx = center_idx + (audio_feat_length[1] + 1) * 2
37
+
38
+ for idx in range(left_idx, right_idx):
39
+ idx = max(0, idx)
40
+ idx = min(length - 1, idx)
41
+ x = feature_array[idx]
42
+ selected_feature.append(x)
43
+ selected_idx.append(idx)
44
+
45
+ selected_feature = torch.cat(selected_feature, dim=0)
46
+ selected_feature = selected_feature.reshape(-1, self.embedding_dim) # 50*384
47
+ return selected_feature, selected_idx
48
+
49
+ def get_sliced_feature_sparse(self, feature_array, vid_idx, audio_feat_length=[2, 2], fps=25):
50
+ """
51
+ Get sliced features based on a given index
52
+ :param feature_array:
53
+ :param start_idx: the start index of the feature
54
+ :param audio_feat_length:
55
+ :return:
56
+ """
57
+ length = len(feature_array)
58
+ selected_feature = []
59
+ selected_idx = []
60
+
61
+ for dt in range(-audio_feat_length[0], audio_feat_length[1] + 1):
62
+ left_idx = int((vid_idx + dt) * 50 / fps)
63
+ if left_idx < 1 or left_idx > length - 1:
64
+ left_idx = max(0, left_idx)
65
+ left_idx = min(length - 1, left_idx)
66
+
67
+ x = feature_array[left_idx]
68
+ x = x[np.newaxis, :, :]
69
+ x = np.repeat(x, 2, axis=0)
70
+ selected_feature.append(x)
71
+ selected_idx.append(left_idx)
72
+ selected_idx.append(left_idx)
73
+ else:
74
+ x = feature_array[left_idx - 1 : left_idx + 1]
75
+ selected_feature.append(x)
76
+ selected_idx.append(left_idx - 1)
77
+ selected_idx.append(left_idx)
78
+ selected_feature = np.concatenate(selected_feature, axis=0)
79
+ selected_feature = selected_feature.reshape(-1, self.embedding_dim) # 50*384
80
+ selected_feature = torch.from_numpy(selected_feature)
81
+ return selected_feature, selected_idx
82
+
83
+ def feature2chunks(self, feature_array, fps, audio_feat_length=[2, 2]):
84
+ whisper_chunks = []
85
+ whisper_idx_multiplier = 50.0 / fps
86
+ i = 0
87
+ print(f"video in {fps} FPS, audio idx in 50FPS")
88
+
89
+ while True:
90
+ start_idx = int(i * whisper_idx_multiplier)
91
+ selected_feature, selected_idx = self.get_sliced_feature(
92
+ feature_array=feature_array, vid_idx=i, audio_feat_length=audio_feat_length, fps=fps
93
+ )
94
+ # print(f"i:{i},selected_idx {selected_idx}")
95
+ whisper_chunks.append(selected_feature)
96
+ i += 1
97
+ if start_idx > len(feature_array):
98
+ break
99
+
100
+ return whisper_chunks
101
+
102
+ def _audio2feat(self, audio_path: str):
103
+ # Run Whisper transcription to collect the per-segment encoder embeddings
104
+ result = self.model.transcribe(audio_path)
105
+ embed_list = []
106
+ for emb in result["segments"]:
107
+ encoder_embeddings = emb["encoder_embeddings"]
108
+ encoder_embeddings = encoder_embeddings.transpose(0, 2, 1, 3)
109
+ encoder_embeddings = encoder_embeddings.squeeze(0)
110
+ start_idx = int(emb["start"])
111
+ end_idx = int(emb["end"])
112
+ emb_end_idx = int((end_idx - start_idx) / 2)
113
+ embed_list.append(encoder_embeddings[:emb_end_idx])
114
+ concatenated_array = torch.from_numpy(np.concatenate(embed_list, axis=0))
115
+ return concatenated_array
116
+
117
+ def audio2feat(self, audio_path):
118
+ if self.audio_embeds_cache_dir == "" or self.audio_embeds_cache_dir is None:
119
+ return self._audio2feat(audio_path)
120
+
121
+ audio_embeds_cache_path = os.path.join(self.audio_embeds_cache_dir, os.path.basename(audio_path) + ".pt")
122
+
123
+ if os.path.isfile(audio_embeds_cache_path):
124
+ try:
125
+ audio_feat = torch.load(audio_embeds_cache_path)
126
+ except Exception as e:
127
+ print(f"{type(e).__name__} - {e} - {audio_embeds_cache_path}")
128
+ os.remove(audio_embeds_cache_path)
129
+ audio_feat = self._audio2feat(audio_path)
130
+ torch.save(audio_feat, audio_embeds_cache_path)
131
+ else:
132
+ audio_feat = self._audio2feat(audio_path)
133
+ torch.save(audio_feat, audio_embeds_cache_path)
134
+
135
+ return audio_feat
136
+
137
+ def crop_overlap_audio_window(self, audio_feat, start_index):
138
+ selected_feature_list = []
139
+ for i in range(start_index, start_index + self.num_frames):
140
+ selected_feature, selected_idx = self.get_sliced_feature(
141
+ feature_array=audio_feat, vid_idx=i, audio_feat_length=[2, 2], fps=25
142
+ )
143
+ selected_feature_list.append(selected_feature)
144
+ mel_overlap = torch.stack(selected_feature_list)
145
+ return mel_overlap
146
+
147
+
148
+ if __name__ == "__main__":
149
+ audio_encoder = Audio2Feature(model_path="checkpoints/whisper/tiny.pt")
150
+ audio_path = "assets/demo1_audio.wav"
151
+ array = audio_encoder.audio2feat(audio_path)
152
+ print(array.shape)
153
+ fps = 25
154
+ whisper_idx_multiplier = 50.0 / fps
155
+
156
+ i = 0
157
+ print(f"video in {fps} FPS, audio idx in 50FPS")
158
+ while True:
159
+ start_idx = int(i * whisper_idx_multiplier)
160
+ selected_feature, selected_idx = audio_encoder.get_sliced_feature(
161
+ feature_array=array, vid_idx=i, audio_feat_length=[2, 2], fps=fps
162
+ )
163
+ print(f"video idx {i},\t audio idx {selected_idx},\t shape {selected_feature.shape}")
164
+ i += 1
165
+ if start_idx > len(array):
166
+ break
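
One detail worth calling out: Whisper encoder features arrive at roughly 50 per second while the video runs at fps frames per second, so get_sliced_feature maps video frame vid_idx to audio feature index int(vid_idx * 50 / fps) and takes a small window around it. The sketch below, which is illustrative and not part of the diff, re-derives that window.

```python
# Illustrative re-derivation of the audio/video index alignment in Audio2Feature.get_sliced_feature.
def audio_window_indices(vid_idx: int, num_audio_feats: int, audio_feat_length=(2, 2), fps: int = 25):
    center_idx = int(vid_idx * 50 / fps)              # audio feature index aligned with this video frame
    left_idx = center_idx - audio_feat_length[0] * 2  # two audio feature steps per video frame at 25 fps
    right_idx = center_idx + (audio_feat_length[1] + 1) * 2
    # Clamp to the valid range, exactly as the loop in get_sliced_feature does.
    return [min(max(idx, 0), num_audio_feats - 1) for idx in range(left_idx, right_idx)]

# Video frame 10 at 25 fps centers on audio feature 20 and selects indices 16..25.
print(audio_window_indices(vid_idx=10, num_audio_feats=500))
```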
latentsync/whisper/whisper/__init__.py ADDED
@@ -0,0 +1,119 @@
1
+ import hashlib
2
+ import io
3
+ import os
4
+ import urllib
5
+ import warnings
6
+ from typing import List, Optional, Union
7
+
8
+ import torch
9
+ from tqdm import tqdm
10
+
11
+ from .audio import load_audio, log_mel_spectrogram, pad_or_trim
12
+ from .decoding import DecodingOptions, DecodingResult, decode, detect_language
13
+ from .model import Whisper, ModelDimensions
14
+ from .transcribe import transcribe
15
+
16
+
17
+ _MODELS = {
18
+ "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
19
+ "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
20
+ "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
21
+ "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
22
+ "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
23
+ "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
24
+ "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
25
+ "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
26
+ "large": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large.pt",
27
+ "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
28
+ "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
29
+ "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
30
+ }
31
+
32
+
33
+ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
34
+ os.makedirs(root, exist_ok=True)
35
+
36
+ expected_sha256 = url.split("/")[-2]
37
+ download_target = os.path.join(root, os.path.basename(url))
38
+
39
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
40
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
41
+
42
+ if os.path.isfile(download_target):
43
+ model_bytes = open(download_target, "rb").read()
44
+ if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
45
+ return model_bytes if in_memory else download_target
46
+ else:
47
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
48
+
49
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
50
+ with tqdm(
51
+ total=int(source.info().get("Content-Length")), ncols=80, unit="iB", unit_scale=True, unit_divisor=1024
52
+ ) as loop:
53
+ while True:
54
+ buffer = source.read(8192)
55
+ if not buffer:
56
+ break
57
+
58
+ output.write(buffer)
59
+ loop.update(len(buffer))
60
+
61
+ model_bytes = open(download_target, "rb").read()
62
+ if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
63
+ raise RuntimeError(
64
+ "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
65
+ )
66
+
67
+ return model_bytes if in_memory else download_target
68
+
69
+
70
+ def available_models() -> List[str]:
71
+ """Returns the names of available models"""
72
+ return list(_MODELS.keys())
73
+
74
+
75
+ def load_model(
76
+ name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False
77
+ ) -> Whisper:
78
+ """
79
+ Load a Whisper ASR model
80
+
81
+ Parameters
82
+ ----------
83
+ name : str
84
+ one of the official model names listed by `whisper.available_models()`, or
85
+ path to a model checkpoint containing the model dimensions and the model state_dict.
86
+ device : Union[str, torch.device]
87
+ the PyTorch device to put the model into
88
+ download_root: str
89
+ path to download the model files; by default, it uses "~/.cache/whisper"
90
+ in_memory: bool
91
+ whether to preload the model weights into host memory
92
+
93
+ Returns
94
+ -------
95
+ model : Whisper
96
+ The Whisper ASR model instance
97
+ """
98
+
99
+ if device is None:
100
+ device = "cuda" if torch.cuda.is_available() else "cpu"
101
+ if download_root is None:
102
+ download_root = os.getenv("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache", "whisper"))
103
+
104
+ if name in _MODELS:
105
+ checkpoint_file = _download(_MODELS[name], download_root, in_memory)
106
+ elif os.path.isfile(name):
107
+ checkpoint_file = open(name, "rb").read() if in_memory else name
108
+ else:
109
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
110
+
111
+ with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
112
+ checkpoint = torch.load(fp, map_location=device)
113
+ del checkpoint_file
114
+
115
+ dims = ModelDimensions(**checkpoint["dims"])
116
+ model = Whisper(dims)
117
+ model.load_state_dict(checkpoint["model_state_dict"])
118
+
119
+ return model.to(device)
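
For reference, and not as part of the diff, here is a minimal sketch of calling the vendored load_model with the local checkpoint path that Audio2Feature defaults to; passing an official model name such as "tiny" instead goes through the download-and-SHA256-check path above.

```python
# Illustrative only: loading the vendored Whisper model from a local checkpoint.
import torch

from latentsync.whisper.whisper import load_model

device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_model("checkpoints/whisper/tiny.pt", device=device)  # default path used by Audio2Feature
print(model.dims.n_audio_state)  # encoder embedding width consumed downstream
```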
latentsync/whisper/whisper/__main__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .transcribe import cli
2
+
3
+
4
+ cli()
latentsync/whisper/whisper/assets/gpt2/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
latentsync/whisper/whisper/assets/gpt2/special_tokens_map.json ADDED
@@ -0,0 +1 @@
1
+ {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
latentsync/whisper/whisper/assets/gpt2/tokenizer_config.json ADDED
@@ -0,0 +1 @@
1
+ {"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"}
latentsync/whisper/whisper/assets/gpt2/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
latentsync/whisper/whisper/assets/mel_filters.npz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd2cc75e70e36fcbdd8ffbc2499062f30094093e6bf2cbafa9859f59972b420b
3
+ size 2048
latentsync/whisper/whisper/assets/multilingual/added_tokens.json ADDED
@@ -0,0 +1 @@
1
+ {"<|endoftext|>": 50257}
latentsync/whisper/whisper/assets/multilingual/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
latentsync/whisper/whisper/assets/multilingual/special_tokens_map.json ADDED
@@ -0,0 +1 @@
1
+ {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
latentsync/whisper/whisper/assets/multilingual/tokenizer_config.json ADDED
@@ -0,0 +1 @@
1
+ {"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"}
latentsync/whisper/whisper/assets/multilingual/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
latentsync/whisper/whisper/audio.py ADDED
@@ -0,0 +1,125 @@
1
+ import os
2
+ from functools import lru_cache
3
+ from typing import Union
4
+
5
+ import ffmpeg
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ from .utils import exact_div
11
+
12
+ # hard-coded audio hyperparameters
13
+ SAMPLE_RATE = 16000
14
+ N_FFT = 400
15
+ N_MELS = 80
16
+ HOP_LENGTH = 160
17
+ CHUNK_LENGTH = 30
18
+ N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
19
+ N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
20
+
21
+
22
+ def load_audio(file: str, sr: int = SAMPLE_RATE):
23
+ """
24
+ Open an audio file and read as mono waveform, resampling as necessary
25
+
26
+ Parameters
27
+ ----------
28
+ file: str
29
+ The audio file to open
30
+
31
+ sr: int
32
+ The sample rate to resample the audio if necessary
33
+
34
+ Returns
35
+ -------
36
+ A NumPy array containing the audio waveform, in float32 dtype.
37
+ """
38
+ try:
39
+ # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
40
+ # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
41
+ out, _ = (
42
+ ffmpeg.input(file, threads=0)
43
+ .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
44
+ .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
45
+ )
46
+ except ffmpeg.Error as e:
47
+ raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
48
+
49
+ return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
50
+
51
+
52
+ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
53
+ """
54
+ Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
55
+ """
56
+ if torch.is_tensor(array):
57
+ if array.shape[axis] > length:
58
+ array = array.index_select(dim=axis, index=torch.arange(length))
59
+
60
+ if array.shape[axis] < length:
61
+ pad_widths = [(0, 0)] * array.ndim
62
+ pad_widths[axis] = (0, length - array.shape[axis])
63
+ array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
64
+ else:
65
+ if array.shape[axis] > length:
66
+ array = array.take(indices=range(length), axis=axis)
67
+
68
+ if array.shape[axis] < length:
69
+ pad_widths = [(0, 0)] * array.ndim
70
+ pad_widths[axis] = (0, length - array.shape[axis])
71
+ array = np.pad(array, pad_widths)
72
+
73
+ return array
74
+
75
+
76
+ @lru_cache(maxsize=None)
77
+ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
78
+ """
79
+ load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
80
+ Allows decoupling librosa dependency; saved using:
81
+
82
+ np.savez_compressed(
83
+ "mel_filters.npz",
84
+ mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
85
+ )
86
+ """
87
+ assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
88
+ with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
89
+ return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
90
+
91
+
92
+ def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
93
+ """
94
+ Compute the log-Mel spectrogram of an audio waveform
95
+
96
+ Parameters
97
+ ----------
98
+ audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
99
+ The path to an audio file, or a NumPy array or Tensor containing the audio waveform sampled at 16 kHz
100
+
101
+ n_mels: int
102
+ The number of Mel-frequency filters, only 80 is supported
103
+
104
+ Returns
105
+ -------
106
+ torch.Tensor, shape = (80, n_frames)
107
+ A Tensor that contains the Mel spectrogram
108
+ """
109
+ if not torch.is_tensor(audio):
110
+ if isinstance(audio, str):
111
+ audio = load_audio(audio)
112
+ audio = torch.from_numpy(audio)
113
+
114
+ window = torch.hann_window(N_FFT).to(audio.device)
115
+ stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
116
+
117
+ magnitudes = stft[:, :-1].abs() ** 2
118
+
119
+ filters = mel_filters(audio.device, n_mels)
120
+ mel_spec = filters @ magnitudes
121
+
122
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
123
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
124
+ log_spec = (log_spec + 4.0) / 4.0
125
+ return log_spec
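
As a minimal sketch, not part of the diff, the preprocessing chain in audio.py goes from a 16 kHz mono waveform to a fixed 30-second chunk to an (80, 3000) log-Mel spectrogram; the audio path below is a placeholder and ffmpeg must be on PATH for load_audio.

```python
# Illustrative only: waveform -> 30 s chunk -> log-Mel spectrogram using the helpers above.
from latentsync.whisper.whisper.audio import N_SAMPLES, load_audio, log_mel_spectrogram, pad_or_trim

waveform = load_audio("speech.wav")          # placeholder path; float32 NumPy array at 16 kHz
waveform = pad_or_trim(waveform, N_SAMPLES)  # exactly 30 s = 480000 samples
mel = log_mel_spectrogram(waveform)          # torch.Tensor of shape (80, 3000)
print(mel.shape)
```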
latentsync/whisper/whisper/decoding.py ADDED
@@ -0,0 +1,729 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import Tensor
8
+ from torch.distributions import Categorical
9
+
10
+ from .audio import CHUNK_LENGTH
11
+ from .tokenizer import Tokenizer, get_tokenizer
12
+ from .utils import compression_ratio
13
+
14
+ if TYPE_CHECKING:
15
+ from .model import Whisper
16
+
17
+
18
+ @torch.no_grad()
19
+ def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
20
+ """
21
+ Detect the spoken language in the audio and return it as a list of strings, along with the ids
22
+ of the most probable language tokens and the probability distribution over all language tokens.
23
+ This is performed outside the main decode loop in order to not interfere with kv-caching.
24
+
25
+ Returns
26
+ -------
27
+ language_tokens : Tensor, shape = (n_audio,)
28
+ ids of the most probable language tokens, which appears after the startoftranscript token.
29
+ language_probs : List[Dict[str, float]], length = n_audio
30
+ list of dictionaries containing the probability distribution over all languages.
31
+ """
32
+ if tokenizer is None:
33
+ tokenizer = get_tokenizer(model.is_multilingual)
34
+ if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
35
+ raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")
36
+
37
+ single = mel.ndim == 2
38
+ if single:
39
+ mel = mel.unsqueeze(0)
40
+
41
+ # skip encoder forward pass if already-encoded audio features were given
42
+ if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
43
+ mel = model.encoder(mel)
44
+
45
+ # forward pass using a single token, startoftranscript
46
+ n_audio = mel.shape[0]
47
+ x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
48
+ logits = model.logits(x, mel)[:, 0]
49
+
50
+ # collect detected languages; suppress all non-language tokens
51
+ mask = torch.ones(logits.shape[-1], dtype=torch.bool)
52
+ mask[list(tokenizer.all_language_tokens)] = False
53
+ logits[:, mask] = -np.inf
54
+ language_tokens = logits.argmax(dim=-1)
55
+ language_token_probs = logits.softmax(dim=-1).cpu()
56
+ language_probs = [
57
+ {
58
+ c: language_token_probs[i, j].item()
59
+ for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
60
+ }
61
+ for i in range(n_audio)
62
+ ]
63
+
64
+ if single:
65
+ language_tokens = language_tokens[0]
66
+ language_probs = language_probs[0]
67
+
68
+ return language_tokens, language_probs
69
+
70
+
71
+ @dataclass(frozen=True)
72
+ class DecodingOptions:
73
+ task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
74
+ language: Optional[str] = None # language that the audio is in; uses detected language if None
75
+
76
+ # sampling-related options
77
+ temperature: float = 0.0
78
+ sample_len: Optional[int] = None # maximum number of tokens to sample
79
+ best_of: Optional[int] = None # number of independent samples to collect, when t > 0
80
+ beam_size: Optional[int] = None # number of beams in beam search, when t == 0
81
+ patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424)
82
+
83
+ # options for ranking generations (either beams or best-of-N samples)
84
+ length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm
85
+
86
+ # prompt, prefix, and token suppression
87
+ prompt: Optional[Union[str, List[int]]] = None # text or tokens for the previous context
88
+ prefix: Optional[Union[str, List[int]]] = None # text or tokens to prefix the current context
89
+ suppress_blank: bool = True # this will suppress blank outputs
90
+
91
+ # list of tokens ids (or comma-separated token ids) to suppress
92
+ # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
93
+ suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
94
+
95
+ # timestamp sampling options
96
+ without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
97
+ max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this
98
+
99
+ # implementation details
100
+ fp16: bool = True # use fp16 for most of the calculation
101
+
102
+
103
+ @dataclass(frozen=True)
104
+ class DecodingResult:
105
+ audio_features: Tensor
106
+ language: str
107
+ encoder_embeddings: np.ndarray
108
+ decoder_embeddings: np.ndarray
109
+ language_probs: Optional[Dict[str, float]] = None
110
+ tokens: List[int] = field(default_factory=list)
111
+ text: str = ""
112
+ avg_logprob: float = np.nan
113
+ no_speech_prob: float = np.nan
114
+ temperature: float = np.nan
115
+ compression_ratio: float = np.nan
116
+
117
+
118
+ class Inference:
119
+ def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
120
+ """Perform a forward pass on the decoder and return per-token logits"""
121
+ raise NotImplementedError
122
+
123
+ def rearrange_kv_cache(self, source_indices) -> None:
124
+ """Update the key-value cache according to the updated beams"""
125
+ raise NotImplementedError
126
+
127
+ def cleanup_caching(self) -> None:
128
+ """Clean up any resources or hooks after decoding is finished"""
129
+ pass
130
+
131
+
132
+ class PyTorchInference(Inference):
133
+ def __init__(self, model: "Whisper", initial_token_length: int):
134
+ self.model: "Whisper" = model
135
+ self.initial_token_length = initial_token_length
136
+ self.kv_cache = {}
137
+ self.hooks = []
138
+
139
+ def logits(self, tokens: Tensor, audio_features: Tensor, include_embeddings=False) -> Tensor:
140
+ if not self.kv_cache:
141
+ self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
142
+
143
+ if tokens.shape[-1] > self.initial_token_length:
144
+ # only need to use the last token except in the first forward pass
145
+ tokens = tokens[:, -1:]
146
+
147
+ return_val = self.model.decoder(tokens, audio_features,
148
+ kv_cache=self.kv_cache, include_embeddings=include_embeddings)
149
+ return return_val
150
+
151
+ def cleanup_caching(self):
152
+ for hook in self.hooks:
153
+ hook.remove()
154
+
155
+ self.kv_cache = {}
156
+ self.hooks = []
157
+
158
+ def rearrange_kv_cache(self, source_indices):
159
+ for module, tensor in self.kv_cache.items():
160
+ # update the key/value cache to contain the selected sequences
161
+ self.kv_cache[module] = tensor[source_indices].detach()
162
+
163
+
164
+ class SequenceRanker:
165
+ def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
166
+ """
167
+ Given a list of groups of samples and their cumulative log probabilities,
168
+ return the indices of the samples in each group to select as the final result
169
+ """
170
+ raise NotImplementedError
171
+
172
+
173
+ class MaximumLikelihoodRanker(SequenceRanker):
174
+ """
175
+ Select the sample with the highest log probabilities, penalized using either
176
+ a simple length normalization or Google NMT paper's length penalty
177
+ """
178
+
179
+ def __init__(self, length_penalty: Optional[float]):
180
+ self.length_penalty = length_penalty
181
+
182
+ def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
183
+ def scores(logprobs, lengths):
184
+ result = []
185
+ for logprob, length in zip(logprobs, lengths):
186
+ if self.length_penalty is None:
187
+ penalty = length
188
+ else:
189
+ # from the Google NMT paper
190
+ penalty = ((5 + length) / 6) ** self.length_penalty
191
+ result.append(logprob / penalty)
192
+ return result
193
+
194
+ # get the sequence with the highest score
195
+ lengths = [[len(t) for t in s] for s in tokens]
196
+ return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
197
+
198
+
199
+ class TokenDecoder:
200
+ def reset(self):
201
+ """Initialize any stateful variables for decoding a new sequence"""
202
+
203
+ def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
204
+ """Specify how to select the next token, based on the current trace and logits
205
+
206
+ Parameters
207
+ ----------
208
+ tokens : Tensor, shape = (n_batch, current_sequence_length)
209
+ all tokens in the context so far, including the prefix and sot_sequence tokens
210
+
211
+ logits : Tensor, shape = (n_batch, vocab_size)
212
+ per-token logits of the probability distribution at the current step
213
+
214
+ sum_logprobs : Tensor, shape = (n_batch)
215
+ cumulative log probabilities for each sequence
216
+
217
+ Returns
218
+ -------
219
+ tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
220
+ the tokens, appended with the selected next token
221
+
222
+ completed : bool
223
+ True if all sequences have reached the end of text
224
+
225
+ """
226
+ raise NotImplementedError
227
+
228
+ def finalize(
229
+ self, tokens: Tensor, sum_logprobs: Tensor
230
+ ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
231
+ """Finalize search and return the final candidate sequences
232
+
233
+ Parameters
234
+ ----------
235
+ tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
236
+ all tokens in the context so far, including the prefix and sot_sequence
237
+
238
+ sum_logprobs : Tensor, shape = (n_audio, n_group)
239
+ cumulative log probabilities for each sequence
240
+
241
+ Returns
242
+ -------
243
+ tokens : Sequence[Sequence[Tensor]], length = n_audio
244
+ sequence of Tensors containing candidate token sequences, for each audio input
245
+
246
+ sum_logprobs : List[List[float]], length = n_audio
247
+ sequence of cumulative log probabilities corresponding to the above
248
+
249
+ """
250
+ raise NotImplementedError
251
+
252
+
253
+ class GreedyDecoder(TokenDecoder):
254
+ def __init__(self, temperature: float, eot: int):
255
+ self.temperature = temperature
256
+ self.eot = eot
257
+
258
+ def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
259
+ temperature = self.temperature
260
+ if temperature == 0:
261
+ next_tokens = logits.argmax(dim=-1)
262
+ else:
263
+ next_tokens = Categorical(logits=logits / temperature).sample()
264
+
265
+ logprobs = F.log_softmax(logits.float(), dim=-1)
266
+ current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
267
+ sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
268
+
269
+ next_tokens[tokens[:, -1] == self.eot] = self.eot
270
+ tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
271
+
272
+ completed = (tokens[:, -1] == self.eot).all()
273
+ return tokens, completed
274
+
275
+ def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
276
+ # make sure each sequence has at least one EOT token at the end
277
+ tokens = F.pad(tokens, (0, 1), value=self.eot)
278
+ return tokens, sum_logprobs.tolist()
279
+
280
+
281
+ class BeamSearchDecoder(TokenDecoder):
282
+ def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
283
+ self.beam_size = beam_size
284
+ self.eot = eot
285
+ self.inference = inference
286
+ self.patience = patience or 1.0
287
+ self.max_candidates: int = round(beam_size * self.patience)
288
+ self.finished_sequences = None
289
+
290
+ assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
291
+
292
+ def reset(self):
293
+ self.finished_sequences = None
294
+
295
+ def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
296
+ if tokens.shape[0] % self.beam_size != 0:
297
+ raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
298
+
299
+ n_audio = tokens.shape[0] // self.beam_size
300
+ if self.finished_sequences is None: # for the first update
301
+ self.finished_sequences = [{} for _ in range(n_audio)]
302
+
303
+ logprobs = F.log_softmax(logits.float(), dim=-1)
304
+ next_tokens, source_indices, finished_sequences = [], [], []
305
+ for i in range(n_audio):
306
+ scores, sources, finished = {}, {}, {}
307
+
308
+ # STEP 1: calculate the cumulative log probabilities for possible candidates
309
+ for j in range(self.beam_size):
310
+ idx = i * self.beam_size + j
311
+ prefix = tokens[idx].tolist()
312
+ for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
313
+ new_logprob = (sum_logprobs[idx] + logprob).item()
314
+ sequence = tuple(prefix + [token.item()])
315
+ scores[sequence] = new_logprob
316
+ sources[sequence] = idx
317
+
318
+ # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
319
+ saved = 0
320
+ for sequence in sorted(scores, key=scores.get, reverse=True):
321
+ if sequence[-1] == self.eot:
322
+ finished[sequence] = scores[sequence]
323
+ else:
324
+ sum_logprobs[len(next_tokens)] = scores[sequence]
325
+ next_tokens.append(sequence)
326
+ source_indices.append(sources[sequence])
327
+
328
+ saved += 1
329
+ if saved == self.beam_size:
330
+ break
331
+
332
+ finished_sequences.append(finished)
333
+
334
+ tokens = torch.tensor(next_tokens, device=tokens.device)
335
+ self.inference.rearrange_kv_cache(source_indices)
336
+
337
+ # add newly finished sequences to self.finished_sequences
338
+ assert len(self.finished_sequences) == len(finished_sequences)
339
+ for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
340
+ for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
341
+ if len(previously_finished) >= self.max_candidates:
342
+ break # the candidate list is full
343
+ previously_finished[seq] = newly_finished[seq]
344
+
345
+ # mark as completed if all audio has enough number of samples
346
+ completed = all(
347
+ len(sequences) >= self.max_candidates for sequences in self.finished_sequences
348
+ )
349
+ return tokens, completed
350
+
351
+ def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
352
+ # collect all finished sequences, including patience, and add unfinished ones if not enough
353
+ sum_logprobs = sum_logprobs.cpu()
354
+ for i, sequences in enumerate(self.finished_sequences):
355
+ if len(sequences) < self.beam_size: # when not enough sequences are finished
356
+ for j in list(np.argsort(sum_logprobs[i]))[::-1]:
357
+ sequence = preceding_tokens[i, j].tolist() + [self.eot]
358
+ sequences[tuple(sequence)] = sum_logprobs[i][j].item()
359
+ if len(sequences) >= self.beam_size:
360
+ break
361
+
362
+ tokens: List[List[Tensor]] = [
363
+ [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
364
+ ]
365
+ sum_logprobs: List[List[float]] = [
366
+ list(sequences.values()) for sequences in self.finished_sequences
367
+ ]
368
+ return tokens, sum_logprobs
369
+
370
+
371
+ class LogitFilter:
372
+ def apply(self, logits: Tensor, tokens: Tensor) -> None:
373
+ """Apply any filtering or masking to logits in-place
374
+
375
+ Parameters
376
+ ----------
377
+ logits : Tensor, shape = (n_batch, vocab_size)
378
+ per-token logits of the probability distribution at the current step
379
+
380
+ tokens : Tensor, shape = (n_batch, current_sequence_length)
381
+ all tokens in the context so far, including the prefix and sot_sequence tokens
382
+
383
+ """
384
+ raise NotImplementedError
385
+
386
+
387
+ class SuppressBlank(LogitFilter):
388
+ def __init__(self, tokenizer: Tokenizer, sample_begin: int):
389
+ self.tokenizer = tokenizer
390
+ self.sample_begin = sample_begin
391
+
392
+ def apply(self, logits: Tensor, tokens: Tensor):
393
+ if tokens.shape[1] == self.sample_begin:
394
+ logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
395
+
396
+
397
+ class SuppressTokens(LogitFilter):
398
+ def __init__(self, suppress_tokens: Sequence[int]):
399
+ self.suppress_tokens = list(suppress_tokens)
400
+
401
+ def apply(self, logits: Tensor, tokens: Tensor):
402
+ logits[:, self.suppress_tokens] = -np.inf
403
+
404
+
405
+ class ApplyTimestampRules(LogitFilter):
406
+ def __init__(
407
+ self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
408
+ ):
409
+ self.tokenizer = tokenizer
410
+ self.sample_begin = sample_begin
411
+ self.max_initial_timestamp_index = max_initial_timestamp_index
412
+
413
+ def apply(self, logits: Tensor, tokens: Tensor):
414
+ # suppress <|notimestamps|> which is handled by without_timestamps
415
+ if self.tokenizer.no_timestamps is not None:
416
+ logits[:, self.tokenizer.no_timestamps] = -np.inf
417
+
418
+ # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
419
+ for k in range(tokens.shape[0]):
420
+ seq = tokens[k, self.sample_begin :].tolist()
421
+ last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
422
+ penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
423
+
424
+ if last_was_timestamp:
425
+ if penultimate_was_timestamp: # has to be non-timestamp
426
+ logits[k, self.tokenizer.timestamp_begin :] = -np.inf
427
+ else: # cannot be normal text tokens
428
+ logits[k, : self.tokenizer.eot] = -np.inf
429
+
430
+ # apply the `max_initial_timestamp` option
431
+ if tokens.shape[1] == self.sample_begin and self.max_initial_timestamp_index is not None:
432
+ last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
433
+ logits[:, last_allowed + 1 :] = -np.inf
434
+
435
+ # if sum of probability over timestamps is above any other token, sample timestamp
436
+ logprobs = F.log_softmax(logits.float(), dim=-1)
437
+ for k in range(tokens.shape[0]):
438
+ timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
439
+ max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
440
+ if timestamp_logprob > max_text_token_logprob:
441
+ logits[k, : self.tokenizer.timestamp_begin] = -np.inf
442
+
443
+
444
+ class DecodingTask:
445
+ inference: Inference
446
+ sequence_ranker: SequenceRanker
447
+ decoder: TokenDecoder
448
+ logit_filters: List[LogitFilter]
449
+
450
+ def __init__(self, model: "Whisper", options: DecodingOptions):
451
+ self.model = model
452
+
453
+ language = options.language or "en"
454
+ tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
455
+ self.tokenizer: Tokenizer = tokenizer
456
+ self.options: DecodingOptions = self._verify_options(options)
457
+
458
+ self.n_group: int = options.beam_size or options.best_of or 1
459
+ self.n_ctx: int = model.dims.n_text_ctx
460
+ self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
461
+
462
+ self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
463
+ if self.options.without_timestamps:
464
+ self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
465
+
466
+ self.initial_tokens: Tuple[int] = self._get_initial_tokens()
467
+ self.sample_begin: int = len(self.initial_tokens)
468
+ self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
469
+
470
+ # inference: implements the forward pass through the decoder, including kv caching
471
+ self.inference = PyTorchInference(model, len(self.initial_tokens))
472
+
473
+ # sequence ranker: implements how to rank a group of sampled sequences
474
+ self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
475
+
476
+ # decoder: implements how to select the next tokens, given the autoregressive distribution
477
+ if options.beam_size is not None:
478
+ self.decoder = BeamSearchDecoder(
479
+ options.beam_size, tokenizer.eot, self.inference, options.patience
480
+ )
481
+ else:
482
+ self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
483
+
484
+ # logit filters: applies various rules to suppress or penalize certain tokens
485
+ self.logit_filters = []
486
+ if self.options.suppress_blank:
487
+ self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
488
+ if self.options.suppress_tokens:
489
+ self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
490
+ if not options.without_timestamps:
491
+ precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
492
+ max_initial_timestamp_index = None
493
+ if options.max_initial_timestamp:
494
+ max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
495
+ self.logit_filters.append(
496
+ ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
497
+ )
498
+
499
+ def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
500
+ if options.beam_size is not None and options.best_of is not None:
501
+ raise ValueError("beam_size and best_of can't be given together")
502
+ if options.temperature == 0:
503
+ if options.best_of is not None:
504
+ raise ValueError("best_of is not compatible with greedy sampling (T=0)")
505
+ if options.patience is not None and options.beam_size is None:
506
+ raise ValueError("patience requires beam_size to be given")
507
+ if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
508
+ raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
509
+
510
+ return options
511
+
512
+ def _get_initial_tokens(self) -> Tuple[int]:
513
+ tokens = list(self.sot_sequence)
514
+ prefix = self.options.prefix
515
+ prompt = self.options.prompt
516
+
517
+ if prefix:
518
+ prefix_tokens = (
519
+ self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
520
+ )
521
+ if self.sample_len is not None:
522
+ max_prefix_len = self.n_ctx // 2 - self.sample_len
523
+ prefix_tokens = prefix_tokens[-max_prefix_len:]
524
+ tokens = tokens + prefix_tokens
525
+
526
+ if prompt:
527
+ prompt_tokens = (
528
+ self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
529
+ )
530
+ tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
531
+
532
+ return tuple(tokens)
533
+
534
+ def _get_suppress_tokens(self) -> Tuple[int]:
535
+ suppress_tokens = self.options.suppress_tokens
536
+
537
+ if isinstance(suppress_tokens, str):
538
+ suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
539
+
540
+ if -1 in suppress_tokens:
541
+ suppress_tokens = [t for t in suppress_tokens if t >= 0]
542
+ suppress_tokens.extend(self.tokenizer.non_speech_tokens)
543
+ elif suppress_tokens is None or len(suppress_tokens) == 0:
544
+ suppress_tokens = [] # interpret empty string as an empty list
545
+ else:
546
+ assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
547
+
548
+ suppress_tokens.extend(
549
+ [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
550
+ )
551
+ if self.tokenizer.no_speech is not None:
552
+ # no-speech probability is collected separately
553
+ suppress_tokens.append(self.tokenizer.no_speech)
554
+
555
+ return tuple(sorted(set(suppress_tokens)))
556
+
557
+ def _get_audio_features(self, mel: Tensor, include_embeddings: bool = False):
558
+ if self.options.fp16:
559
+ mel = mel.half()
560
+
561
+ if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
562
+ # encoded audio features are given; skip audio encoding
563
+ audio_features = mel
564
+ else:
565
+ result = self.model.encoder(mel, include_embeddings)
566
+ if include_embeddings:
567
+ audio_features, embeddings = result
568
+ else:
569
+ audio_features = result
570
+
571
+ if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
572
+ raise TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
573
+
574
+ if include_embeddings:
575
+ return audio_features, embeddings
576
+ else:
577
+ return audio_features
578
+
579
+ def _detect_language(self, audio_features: Tensor, tokens: Tensor):
580
+ languages = [self.options.language] * audio_features.shape[0]
581
+ lang_probs = None
582
+
583
+ if self.options.language is None or self.options.task == "lang_id":
584
+ lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
585
+ languages = [max(probs, key=probs.get) for probs in lang_probs]
586
+ if self.options.language is None:
587
+ tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
588
+
589
+ return languages, lang_probs
590
+
591
+ def _main_loop(self, audio_features: Tensor, tokens: Tensor):
592
+ assert audio_features.shape[0] == tokens.shape[0]
593
+ n_batch = tokens.shape[0]
594
+ sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
595
+ no_speech_probs = [np.nan] * n_batch
596
+
597
+ try:
598
+ embeddings = []
599
+ for i in range(self.sample_len):
600
+ logits, token_embeddings = self.inference.logits(tokens, audio_features, include_embeddings=True)
601
+
602
+ if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
603
+ probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
604
+ no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
605
+
606
+ # now we need to consider the logits at the last token only
607
+ logits = logits[:, -1]
608
+ token_embeddings = token_embeddings[:, :, -1]
609
+
610
+ # Append embeddings together
611
+ embeddings.append(token_embeddings)
612
+
613
+ # apply the logit filters, e.g. for suppressing or applying a penalty to certain tokens
614
+ for logit_filter in self.logit_filters:
615
+ logit_filter.apply(logits, tokens)
616
+
617
+ # expand the tokens tensor with the selected next tokens
618
+ tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
619
+
620
+ if completed or tokens.shape[-1] > self.n_ctx:
621
+ break
622
+ finally:
623
+ if completed:
624
+ embeddings = embeddings[:-1]
625
+ embeddings = np.stack(embeddings, 2)
626
+ self.inference.cleanup_caching()
627
+
628
+ return tokens, sum_logprobs, no_speech_probs, embeddings
629
+
630
+ @torch.no_grad()
631
+ def run(self, mel: Tensor) -> List[DecodingResult]:
632
+ self.decoder.reset()
633
+ tokenizer: Tokenizer = self.tokenizer
634
+ n_audio: int = mel.shape[0]
635
+
636
+ # encoder forward pass
637
+ forward_pass: Tuple[Tensor, np.ndarray] = self._get_audio_features(mel, include_embeddings=True)
638
+ audio_features, encoder_embeddings = forward_pass
639
+ tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
640
+
641
+ # detect language if requested, overwriting the language token
642
+ languages, language_probs = self._detect_language(audio_features, tokens)
643
+ if self.options.task == "lang_id":
644
+ return [
645
+ DecodingResult(audio_features=features, language=language, language_probs=probs)
646
+ for features, language, probs in zip(audio_features, languages, language_probs)
647
+ ]
648
+
649
+ # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
650
+ audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
651
+ tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
652
+
653
+ # call the main sampling loop
654
+ tokens, sum_logprobs, no_speech_probs, decoder_embeddings = self._main_loop(audio_features, tokens)
655
+
656
+ # reshape the tensors to have (n_audio, n_group) as the first two dimensions
657
+ audio_features = audio_features[:: self.n_group]
658
+ no_speech_probs = no_speech_probs[:: self.n_group]
659
+ assert audio_features.shape[0] == len(no_speech_probs) == n_audio
660
+
661
+ tokens = tokens.reshape(n_audio, self.n_group, -1)
662
+ sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
663
+
664
+ # get the final candidates for each group, and slice between the first sampled token and EOT
665
+ tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
666
+ tokens: List[List[Tensor]] = [
667
+ [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
668
+ ]
669
+
670
+ # select the top-ranked sample in each group
671
+ selected = self.sequence_ranker.rank(tokens, sum_logprobs)
672
+ tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
673
+ texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
674
+
675
+ sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
676
+ avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
677
+
678
+ fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
679
+ if len(set(map(len, fields))) != 1:
680
+ raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
681
+
682
+ return [
683
+ DecodingResult(
684
+ audio_features=features,
685
+ language=language,
686
+ tokens=tokens,
687
+ text=text,
688
+ avg_logprob=avg_logprob,
689
+ no_speech_prob=no_speech_prob,
690
+ temperature=self.options.temperature,
691
+ compression_ratio=compression_ratio(text),
692
+ encoder_embeddings=encoder_embeddings,
693
+ decoder_embeddings=decoder_embeddings
694
+ )
695
+ for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
696
+ ]
697
+
698
+
699
+ @torch.no_grad()
700
+ def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
701
+ """
702
+ Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
703
+
704
+ Parameters
705
+ ----------
706
+ model: Whisper
707
+ the Whisper model instance
708
+
709
+ mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
710
+ A tensor containing the Mel spectrogram(s)
711
+
712
+ options: DecodingOptions
713
+ A dataclass that contains all necessary options for decoding 30-second segments
714
+
715
+ Returns
716
+ -------
717
+ result: Union[DecodingResult, List[DecodingResult]]
718
+ The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
719
+ """
720
+ single = mel.ndim == 2
721
+ if single:
722
+ mel = mel.unsqueeze(0)
723
+
724
+ result = DecodingTask(model, options).run(mel)
725
+
726
+ if single:
727
+ result = result[0]
728
+
729
+ return result
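
A minimal usage sketch for the decode() entry point above, not part of the diff itself. It assumes this fork's load_model, pad_or_trim, and log_mel_spectrogram helpers behave as in upstream Whisper; the checkpoint name and audio tensor are placeholders.

    import torch
    from latentsync.whisper.whisper import load_model
    from latentsync.whisper.whisper.audio import log_mel_spectrogram, pad_or_trim
    from latentsync.whisper.whisper.decoding import DecodingOptions, decode

    model = load_model("tiny")  # placeholder; this fork's load_model may expect a local checkpoint path
    audio = torch.randn(16000 * 30)  # stand-in for 30 s of 16 kHz audio
    mel = log_mel_spectrogram(pad_or_trim(audio)).to(model.device)  # shape (80, 3000)

    # greedy decoding without timestamps; fp16 only makes sense on GPU
    options = DecodingOptions(task="transcribe", language="en", without_timestamps=True, fp16=False)
    result = decode(model, mel, options)  # a 2-D mel yields a single DecodingResult
    print(result.text, result.avg_logprob)

Since decode is also bound as a method on Whisper (decode = decode_function in model.py below), model.decode(mel, options) is equivalent.
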
latentsync/whisper/whisper/model.py ADDED
@@ -0,0 +1,290 @@
1
+ from dataclasses import dataclass
2
+ from typing import Dict
3
+ from typing import Iterable, Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import Tensor
9
+ from torch import nn
10
+
11
+ from .transcribe import transcribe as transcribe_function
12
+ from .decoding import detect_language as detect_language_function, decode as decode_function
13
+
14
+
15
+ @dataclass
16
+ class ModelDimensions:
17
+ n_mels: int
18
+ n_audio_ctx: int
19
+ n_audio_state: int
20
+ n_audio_head: int
21
+ n_audio_layer: int
22
+ n_vocab: int
23
+ n_text_ctx: int
24
+ n_text_state: int
25
+ n_text_head: int
26
+ n_text_layer: int
27
+
28
+
29
+ class LayerNorm(nn.LayerNorm):
30
+ def forward(self, x: Tensor) -> Tensor:
31
+ return super().forward(x.float()).type(x.dtype)
32
+
33
+
34
+ class Linear(nn.Linear):
35
+ def forward(self, x: Tensor) -> Tensor:
36
+ return F.linear(
37
+ x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
38
+ )
39
+
40
+
41
+ class Conv1d(nn.Conv1d):
42
+ def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
43
+ return super()._conv_forward(
44
+ x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
45
+ )
46
+
47
+
48
+ def sinusoids(length, channels, max_timescale=10000):
49
+ """Returns sinusoids for positional embedding"""
50
+ assert channels % 2 == 0
51
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
52
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
53
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
54
+ return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
55
+
56
+
57
+ class MultiHeadAttention(nn.Module):
58
+ def __init__(self, n_state: int, n_head: int):
59
+ super().__init__()
60
+ self.n_head = n_head
61
+ self.query = Linear(n_state, n_state)
62
+ self.key = Linear(n_state, n_state, bias=False)
63
+ self.value = Linear(n_state, n_state)
64
+ self.out = Linear(n_state, n_state)
65
+
66
+ def forward(
67
+ self,
68
+ x: Tensor,
69
+ xa: Optional[Tensor] = None,
70
+ mask: Optional[Tensor] = None,
71
+ kv_cache: Optional[dict] = None,
72
+ ):
73
+ q = self.query(x)
74
+
75
+ if kv_cache is None or xa is None:
76
+ # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
77
+ # otherwise, perform key/value projections for self- or cross-attention as usual.
78
+ k = self.key(x if xa is None else xa)
79
+ v = self.value(x if xa is None else xa)
80
+ else:
81
+ # for cross-attention, calculate keys and values once and reuse in subsequent calls.
82
+ k = kv_cache.get(self.key, self.key(xa))
83
+ v = kv_cache.get(self.value, self.value(xa))
84
+
85
+ wv = self.qkv_attention(q, k, v, mask)
86
+ return self.out(wv)
87
+
88
+ def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
89
+ n_batch, n_ctx, n_state = q.shape
90
+ scale = (n_state // self.n_head) ** -0.25
91
+ q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
92
+ k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
93
+ v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
94
+
95
+ qk = q @ k
96
+ if mask is not None:
97
+ qk = qk + mask[:n_ctx, :n_ctx]
98
+
99
+ w = F.softmax(qk.float(), dim=-1).to(q.dtype)
100
+ return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
101
+
102
+
103
+ class ResidualAttentionBlock(nn.Module):
104
+ def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
105
+ super().__init__()
106
+
107
+ self.attn = MultiHeadAttention(n_state, n_head)
108
+ self.attn_ln = LayerNorm(n_state)
109
+
110
+ self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
111
+ self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
112
+
113
+ n_mlp = n_state * 4
114
+ self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
115
+ self.mlp_ln = LayerNorm(n_state)
116
+
117
+ def forward(
118
+ self,
119
+ x: Tensor,
120
+ xa: Optional[Tensor] = None,
121
+ mask: Optional[Tensor] = None,
122
+ kv_cache: Optional[dict] = None,
123
+ ):
124
+ x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
125
+ if self.cross_attn:
126
+ x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
127
+ x = x + self.mlp(self.mlp_ln(x))
128
+ return x
129
+
130
+
131
+ class AudioEncoder(nn.Module):
132
+ def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
133
+ super().__init__()
134
+ self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
135
+ self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
136
+ self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
137
+
138
+ self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
139
+ [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
140
+ )
141
+ self.ln_post = LayerNorm(n_state)
142
+
143
+ def forward(self, x: Tensor, include_embeddings: bool = False):
144
+ """
145
+ x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
146
+ the mel spectrogram of the audio
147
+ include_embeddings: bool
148
+ whether to also return the intermediate activations of each encoder block
149
+ """
150
+ x = F.gelu(self.conv1(x))
151
+ x = F.gelu(self.conv2(x))
152
+ x = x.permute(0, 2, 1)
153
+
154
+ assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
155
+ x = (x + self.positional_embedding).to(x.dtype)
156
+
157
+ if include_embeddings:
158
+ embeddings = [x.cpu().detach().numpy()]
159
+
160
+ for block in self.blocks:
161
+ x = block(x)
162
+ if include_embeddings:
163
+ embeddings.append(x.cpu().detach().numpy())
164
+
165
+ x = self.ln_post(x)
166
+
167
+ if include_embeddings:
168
+ embeddings = np.stack(embeddings, axis=1)
169
+ return x, embeddings
170
+ else:
171
+ return x
172
+
173
+
174
+ class TextDecoder(nn.Module):
175
+ def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
176
+ super().__init__()
177
+
178
+ self.token_embedding = nn.Embedding(n_vocab, n_state)
179
+ self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
180
+
181
+ self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
182
+ [ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
183
+ )
184
+ self.ln = LayerNorm(n_state)
185
+
186
+ mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
187
+ self.register_buffer("mask", mask, persistent=False)
188
+
189
+ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None, include_embeddings: bool = False):
190
+ """
191
+ x : torch.LongTensor, shape = (batch_size, <= n_ctx)
192
+ the text tokens
193
+ xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
194
+ the encoded audio features to be attended on
195
+ include_embeddings : bool
196
+ Whether to also return the intermediate activations of each decoder block
197
+ """
198
+ offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
199
+ x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
200
+ x = x.to(xa.dtype)
201
+
202
+ if include_embeddings:
203
+ embeddings = [x.cpu().detach().numpy()]
204
+
205
+ for block in self.blocks:
206
+ x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
207
+ if include_embeddings:
208
+ embeddings.append(x.cpu().detach().numpy())
209
+
210
+ x = self.ln(x)
211
+ logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
212
+
213
+ if include_embeddings:
214
+ embeddings = np.stack(embeddings, axis=1)
215
+ return logits, embeddings
216
+ else:
217
+ return logits
218
+
219
+
220
+ class Whisper(nn.Module):
221
+ def __init__(self, dims: ModelDimensions):
222
+ super().__init__()
223
+ self.dims = dims
224
+ self.encoder = AudioEncoder(
225
+ self.dims.n_mels,
226
+ self.dims.n_audio_ctx,
227
+ self.dims.n_audio_state,
228
+ self.dims.n_audio_head,
229
+ self.dims.n_audio_layer,
230
+ )
231
+ self.decoder = TextDecoder(
232
+ self.dims.n_vocab,
233
+ self.dims.n_text_ctx,
234
+ self.dims.n_text_state,
235
+ self.dims.n_text_head,
236
+ self.dims.n_text_layer,
237
+ )
238
+
239
+ def embed_audio(self, mel: torch.Tensor):
240
+ return self.encoder.forward(mel)
241
+
242
+ def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
243
+ return self.decoder.forward(tokens, audio_features)
244
+
245
+ def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
246
+ return self.decoder(tokens, self.encoder(mel))
247
+
248
+ @property
249
+ def device(self):
250
+ return next(self.parameters()).device
251
+
252
+ @property
253
+ def is_multilingual(self):
254
+ return self.dims.n_vocab == 51865
255
+
256
+ def install_kv_cache_hooks(self, cache: Optional[dict] = None):
257
+ """
258
+ The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
259
+ tensors calculated for the previous positions. This method returns a dictionary that stores
260
+ all caches, and the necessary hooks for the key and value projection modules that save the
261
+ intermediate tensors to be reused during later calculations.
262
+
263
+ Returns
264
+ -------
265
+ cache : Dict[nn.Module, torch.Tensor]
266
+ A dictionary object mapping the key/value projection modules to their caches
267
+ hooks : List[RemovableHandle]
268
+ List of PyTorch RemovableHandle objects that can be used to remove the installed hooks
269
+ """
270
+ cache = {**cache} if cache is not None else {}
271
+ hooks = []
272
+
273
+ def save_to_cache(module, _, output):
274
+ if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]:
275
+ cache[module] = output # save as-is, for the first token or cross attention
276
+ else:
277
+ cache[module] = torch.cat([cache[module], output], dim=1).detach()
278
+ return cache[module]
279
+
280
+ def install_hooks(layer: nn.Module):
281
+ if isinstance(layer, MultiHeadAttention):
282
+ hooks.append(layer.key.register_forward_hook(save_to_cache))
283
+ hooks.append(layer.value.register_forward_hook(save_to_cache))
284
+
285
+ self.decoder.apply(install_hooks)
286
+ return cache, hooks
287
+
288
+ detect_language = detect_language_function
289
+ transcribe = transcribe_function
290
+ decode = decode_function
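
A minimal sketch of driving the model classes above directly, not part of the diff. The dimensions are illustrative (tiny-like, not an official checkpoint config), the weights are random, and the token id 50258 is assumed to be the multilingual <|startoftranscript|> id.

    import torch
    from latentsync.whisper.whisper.model import ModelDimensions, Whisper

    dims = ModelDimensions(
        n_mels=80, n_audio_ctx=1500, n_audio_state=384, n_audio_head=6, n_audio_layer=4,
        n_vocab=51865, n_text_ctx=448, n_text_state=384, n_text_head=6, n_text_layer=4,
    )
    model = Whisper(dims).eval()

    mel = torch.randn(1, 80, 3000)      # one 30 s mel segment
    tokens = torch.tensor([[50258]])    # start-of-transcript token
    with torch.no_grad():
        audio_features = model.embed_audio(mel)           # (1, 1500, 384)
        logits = model.logits(tokens, audio_features)     # (1, 1, 51865)

    # kv-cache hooks let the decoder reuse keys/values across incremental calls
    cache, hooks = model.install_kv_cache_hooks()
    with torch.no_grad():
        _ = model.decoder(tokens, audio_features, kv_cache=cache)
    for hook in hooks:
        hook.remove()  # detach the hooks so the cache does not leak into later passes
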
latentsync/whisper/whisper/normalizers/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .basic import BasicTextNormalizer
2
+ from .english import EnglishTextNormalizer
latentsync/whisper/whisper/normalizers/basic.py ADDED
@@ -0,0 +1,71 @@
1
+ import re
2
+ import unicodedata
3
+
4
+ import regex
5
+
6
+ # non-ASCII letters that are not separated by "NFKD" normalization
7
+ ADDITIONAL_DIACRITICS = {
8
+ "œ": "oe",
9
+ "Œ": "OE",
10
+ "ø": "o",
11
+ "Ø": "O",
12
+ "æ": "ae",
13
+ "Æ": "AE",
14
+ "ß": "ss",
15
+ "ẞ": "SS",
16
+ "đ": "d",
17
+ "Đ": "D",
18
+ "ð": "d",
19
+ "Ð": "D",
20
+ "þ": "th",
21
+ "Þ": "th",
22
+ "ł": "l",
23
+ "Ł": "L",
24
+ }
25
+
26
+
27
+ def remove_symbols_and_diacritics(s: str, keep=""):
28
+ """
29
+ Replace any other markers, symbols, and punctuations with a space,
30
+ and drop any diacritics (category 'Mn' and some manual mappings)
31
+ """
32
+ return "".join(
33
+ c
34
+ if c in keep
35
+ else ADDITIONAL_DIACRITICS[c]
36
+ if c in ADDITIONAL_DIACRITICS
37
+ else ""
38
+ if unicodedata.category(c) == "Mn"
39
+ else " "
40
+ if unicodedata.category(c)[0] in "MSP"
41
+ else c
42
+ for c in unicodedata.normalize("NFKD", s)
43
+ )
44
+
45
+
46
+ def remove_symbols(s: str):
47
+ """
48
+ Replace any other markers, symbols, punctuations with a space, keeping diacritics
49
+ """
50
+ return "".join(
51
+ " " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s)
52
+ )
53
+
54
+
55
+ class BasicTextNormalizer:
56
+ def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
57
+ self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols
58
+ self.split_letters = split_letters
59
+
60
+ def __call__(self, s: str):
61
+ s = s.lower()
62
+ s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
63
+ s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
64
+ s = self.clean(s).lower()
65
+
66
+ if self.split_letters:
67
+ s = " ".join(regex.findall(r"\X", s, regex.U))
68
+
69
+ s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
70
+
71
+ return s
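
A short usage sketch for the normalizer above, not part of the diff; the expected output is approximate (surrounding whitespace aside).

    from latentsync.whisper.whisper.normalizers import BasicTextNormalizer

    normalizer = BasicTextNormalizer(remove_diacritics=True)
    print(normalizer("Hello, [noise] Señor Müller (laughs)!"))
    # -> roughly "hello senor muller": bracketed/parenthesized spans, punctuation, and diacritics removed

The companion EnglishTextNormalizer imported in __init__.py above additionally applies the British-to-American spelling map from english.json below.
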
latentsync/whisper/whisper/normalizers/english.json ADDED
@@ -0,0 +1,1742 @@
1
+ {
2
+ "accessorise": "accessorize",
3
+ "accessorised": "accessorized",
4
+ "accessorises": "accessorizes",
5
+ "accessorising": "accessorizing",
6
+ "acclimatisation": "acclimatization",
7
+ "acclimatise": "acclimatize",
8
+ "acclimatised": "acclimatized",
9
+ "acclimatises": "acclimatizes",
10
+ "acclimatising": "acclimatizing",
11
+ "accoutrements": "accouterments",
12
+ "aeon": "eon",
13
+ "aeons": "eons",
14
+ "aerogramme": "aerogram",
15
+ "aerogrammes": "aerograms",
16
+ "aeroplane": "airplane",
17
+ "aeroplanes": "airplanes",
18
+ "aesthete": "esthete",
19
+ "aesthetes": "esthetes",
20
+ "aesthetic": "esthetic",
21
+ "aesthetically": "esthetically",
22
+ "aesthetics": "esthetics",
23
+ "aetiology": "etiology",
24
+ "ageing": "aging",
25
+ "aggrandisement": "aggrandizement",
26
+ "agonise": "agonize",
27
+ "agonised": "agonized",
28
+ "agonises": "agonizes",
29
+ "agonising": "agonizing",
30
+ "agonisingly": "agonizingly",
31
+ "almanack": "almanac",
32
+ "almanacks": "almanacs",
33
+ "aluminium": "aluminum",
34
+ "amortisable": "amortizable",
35
+ "amortisation": "amortization",
36
+ "amortisations": "amortizations",
37
+ "amortise": "amortize",
38
+ "amortised": "amortized",
39
+ "amortises": "amortizes",
40
+ "amortising": "amortizing",
41
+ "amphitheatre": "amphitheater",
42
+ "amphitheatres": "amphitheaters",
43
+ "anaemia": "anemia",
44
+ "anaemic": "anemic",
45
+ "anaesthesia": "anesthesia",
46
+ "anaesthetic": "anesthetic",
47
+ "anaesthetics": "anesthetics",
48
+ "anaesthetise": "anesthetize",
49
+ "anaesthetised": "anesthetized",
50
+ "anaesthetises": "anesthetizes",
51
+ "anaesthetising": "anesthetizing",
52
+ "anaesthetist": "anesthetist",
53
+ "anaesthetists": "anesthetists",
54
+ "anaesthetize": "anesthetize",
55
+ "anaesthetized": "anesthetized",
56
+ "anaesthetizes": "anesthetizes",
57
+ "anaesthetizing": "anesthetizing",
58
+ "analogue": "analog",
59
+ "analogues": "analogs",
60
+ "analyse": "analyze",
61
+ "analysed": "analyzed",
62
+ "analyses": "analyzes",
63
+ "analysing": "analyzing",
64
+ "anglicise": "anglicize",
65
+ "anglicised": "anglicized",
66
+ "anglicises": "anglicizes",
67
+ "anglicising": "anglicizing",
68
+ "annualised": "annualized",
69
+ "antagonise": "antagonize",
70
+ "antagonised": "antagonized",
71
+ "antagonises": "antagonizes",
72
+ "antagonising": "antagonizing",
73
+ "apologise": "apologize",
74
+ "apologised": "apologized",
75
+ "apologises": "apologizes",
76
+ "apologising": "apologizing",
77
+ "appal": "appall",
78
+ "appals": "appalls",
79
+ "appetiser": "appetizer",
80
+ "appetisers": "appetizers",
81
+ "appetising": "appetizing",
82
+ "appetisingly": "appetizingly",
83
+ "arbour": "arbor",
84
+ "arbours": "arbors",
85
+ "archeological": "archaeological",
86
+ "archaeologically": "archeologically",
87
+ "archaeologist": "archeologist",
88
+ "archaeologists": "archeologists",
89
+ "archaeology": "archeology",
90
+ "ardour": "ardor",
91
+ "armour": "armor",
92
+ "armoured": "armored",
93
+ "armourer": "armorer",
94
+ "armourers": "armorers",
95
+ "armouries": "armories",
96
+ "armoury": "armory",
97
+ "artefact": "artifact",
98
+ "artefacts": "artifacts",
99
+ "authorise": "authorize",
100
+ "authorised": "authorized",
101
+ "authorises": "authorizes",
102
+ "authorising": "authorizing",
103
+ "axe": "ax",
104
+ "backpedalled": "backpedaled",
105
+ "backpedalling": "backpedaling",
106
+ "bannister": "banister",
107
+ "bannisters": "banisters",
108
+ "baptise": "baptize",
109
+ "baptised": "baptized",
110
+ "baptises": "baptizes",
111
+ "baptising": "baptizing",
112
+ "bastardise": "bastardize",
113
+ "bastardised": "bastardized",
114
+ "bastardises": "bastardizes",
115
+ "bastardising": "bastardizing",
116
+ "battleax": "battleaxe",
117
+ "baulk": "balk",
118
+ "baulked": "balked",
119
+ "baulking": "balking",
120
+ "baulks": "balks",
121
+ "bedevilled": "bedeviled",
122
+ "bedevilling": "bedeviling",
123
+ "behaviour": "behavior",
124
+ "behavioural": "behavioral",
125
+ "behaviourism": "behaviorism",
126
+ "behaviourist": "behaviorist",
127
+ "behaviourists": "behaviorists",
128
+ "behaviours": "behaviors",
129
+ "behove": "behoove",
130
+ "behoved": "behooved",
131
+ "behoves": "behooves",
132
+ "bejewelled": "bejeweled",
133
+ "belabour": "belabor",
134
+ "belaboured": "belabored",
135
+ "belabouring": "belaboring",
136
+ "belabours": "belabors",
137
+ "bevelled": "beveled",
138
+ "bevvies": "bevies",
139
+ "bevvy": "bevy",
140
+ "biassed": "biased",
141
+ "biassing": "biasing",
142
+ "bingeing": "binging",
143
+ "bougainvillaea": "bougainvillea",
144
+ "bougainvillaeas": "bougainvilleas",
145
+ "bowdlerise": "bowdlerize",
146
+ "bowdlerised": "bowdlerized",
147
+ "bowdlerises": "bowdlerizes",
148
+ "bowdlerising": "bowdlerizing",
149
+ "breathalyse": "breathalyze",
150
+ "breathalysed": "breathalyzed",
151
+ "breathalyser": "breathalyzer",
152
+ "breathalysers": "breathalyzers",
153
+ "breathalyses": "breathalyzes",
154
+ "breathalysing": "breathalyzing",
155
+ "brutalise": "brutalize",
156
+ "brutalised": "brutalized",
157
+ "brutalises": "brutalizes",
158
+ "brutalising": "brutalizing",
159
+ "busses": "buses",
160
+ "bussing": "busing",
161
+ "caesarean": "cesarean",
162
+ "caesareans": "cesareans",
163
+ "calibre": "caliber",
164
+ "calibres": "calibers",
165
+ "calliper": "caliper",
166
+ "callipers": "calipers",
167
+ "callisthenics": "calisthenics",
168
+ "canalise": "canalize",
169
+ "canalised": "canalized",
170
+ "canalises": "canalizes",
171
+ "canalising": "canalizing",
172
+ "cancelation": "cancellation",
173
+ "cancelations": "cancellations",
174
+ "cancelled": "canceled",
175
+ "cancelling": "canceling",
176
+ "candour": "candor",
177
+ "cannibalise": "cannibalize",
178
+ "cannibalised": "cannibalized",
179
+ "cannibalises": "cannibalizes",
180
+ "cannibalising": "cannibalizing",
181
+ "canonise": "canonize",
182
+ "canonised": "canonized",
183
+ "canonises": "canonizes",
184
+ "canonising": "canonizing",
185
+ "capitalise": "capitalize",
186
+ "capitalised": "capitalized",
187
+ "capitalises": "capitalizes",
188
+ "capitalising": "capitalizing",
189
+ "caramelise": "caramelize",
190
+ "caramelised": "caramelized",
191
+ "caramelises": "caramelizes",
192
+ "caramelising": "caramelizing",
193
+ "carbonise": "carbonize",
194
+ "carbonised": "carbonized",
195
+ "carbonises": "carbonizes",
196
+ "carbonising": "carbonizing",
197
+ "carolled": "caroled",
198
+ "carolling": "caroling",
199
+ "catalogue": "catalog",
200
+ "catalogued": "cataloged",
201
+ "catalogues": "catalogs",
202
+ "cataloguing": "cataloging",
203
+ "catalyse": "catalyze",
204
+ "catalysed": "catalyzed",
205
+ "catalyses": "catalyzes",
206
+ "catalysing": "catalyzing",
207
+ "categorise": "categorize",
208
+ "categorised": "categorized",
209
+ "categorises": "categorizes",
210
+ "categorising": "categorizing",
211
+ "cauterise": "cauterize",
212
+ "cauterised": "cauterized",
213
+ "cauterises": "cauterizes",
214
+ "cauterising": "cauterizing",
215
+ "cavilled": "caviled",
216
+ "cavilling": "caviling",
217
+ "centigramme": "centigram",
218
+ "centigrammes": "centigrams",
219
+ "centilitre": "centiliter",
220
+ "centilitres": "centiliters",
221
+ "centimetre": "centimeter",
222
+ "centimetres": "centimeters",
223
+ "centralise": "centralize",
224
+ "centralised": "centralized",
225
+ "centralises": "centralizes",
226
+ "centralising": "centralizing",
227
+ "centre": "center",
228
+ "centred": "centered",
229
+ "centrefold": "centerfold",
230
+ "centrefolds": "centerfolds",
231
+ "centrepiece": "centerpiece",
232
+ "centrepieces": "centerpieces",
233
+ "centres": "centers",
234
+ "channelled": "channeled",
235
+ "channelling": "channeling",
236
+ "characterise": "characterize",
237
+ "characterised": "characterized",
238
+ "characterises": "characterizes",
239
+ "characterising": "characterizing",
240
+ "cheque": "check",
241
+ "chequebook": "checkbook",
242
+ "chequebooks": "checkbooks",
243
+ "chequered": "checkered",
244
+ "cheques": "checks",
245
+ "chilli": "chili",
246
+ "chimaera": "chimera",
247
+ "chimaeras": "chimeras",
248
+ "chiselled": "chiseled",
249
+ "chiselling": "chiseling",
250
+ "circularise": "circularize",
251
+ "circularised": "circularized",
252
+ "circularises": "circularizes",
253
+ "circularising": "circularizing",
254
+ "civilise": "civilize",
255
+ "civilised": "civilized",
256
+ "civilises": "civilizes",
257
+ "civilising": "civilizing",
258
+ "clamour": "clamor",
259
+ "clamoured": "clamored",
260
+ "clamouring": "clamoring",
261
+ "clamours": "clamors",
262
+ "clangour": "clangor",
263
+ "clarinettist": "clarinetist",
264
+ "clarinettists": "clarinetists",
265
+ "collectivise": "collectivize",
266
+ "collectivised": "collectivized",
267
+ "collectivises": "collectivizes",
268
+ "collectivising": "collectivizing",
269
+ "colonisation": "colonization",
270
+ "colonise": "colonize",
271
+ "colonised": "colonized",
272
+ "coloniser": "colonizer",
273
+ "colonisers": "colonizers",
274
+ "colonises": "colonizes",
275
+ "colonising": "colonizing",
276
+ "colour": "color",
277
+ "colourant": "colorant",
278
+ "colourants": "colorants",
279
+ "coloured": "colored",
280
+ "coloureds": "coloreds",
281
+ "colourful": "colorful",
282
+ "colourfully": "colorfully",
283
+ "colouring": "coloring",
284
+ "colourize": "colorize",
285
+ "colourized": "colorized",
286
+ "colourizes": "colorizes",
287
+ "colourizing": "colorizing",
288
+ "colourless": "colorless",
289
+ "colours": "colors",
290
+ "commercialise": "commercialize",
291
+ "commercialised": "commercialized",
292
+ "commercialises": "commercializes",
293
+ "commercialising": "commercializing",
294
+ "compartmentalise": "compartmentalize",
295
+ "compartmentalised": "compartmentalized",
296
+ "compartmentalises": "compartmentalizes",
297
+ "compartmentalising": "compartmentalizing",
298
+ "computerise": "computerize",
299
+ "computerised": "computerized",
300
+ "computerises": "computerizes",
301
+ "computerising": "computerizing",
302
+ "conceptualise": "conceptualize",
303
+ "conceptualised": "conceptualized",
304
+ "conceptualises": "conceptualizes",
305
+ "conceptualising": "conceptualizing",
306
+ "connexion": "connection",
307
+ "connexions": "connections",
308
+ "contextualise": "contextualize",
309
+ "contextualised": "contextualized",
310
+ "contextualises": "contextualizes",
311
+ "contextualising": "contextualizing",
312
+ "cosier": "cozier",
313
+ "cosies": "cozies",
314
+ "cosiest": "coziest",
315
+ "cosily": "cozily",
316
+ "cosiness": "coziness",
317
+ "cosy": "cozy",
318
+ "councillor": "councilor",
319
+ "councillors": "councilors",
320
+ "counselled": "counseled",
321
+ "counselling": "counseling",
322
+ "counsellor": "counselor",
323
+ "counsellors": "counselors",
324
+ "crenelated": "crenellated",
325
+ "criminalise": "criminalize",
326
+ "criminalised": "criminalized",
327
+ "criminalises": "criminalizes",
328
+ "criminalising": "criminalizing",
329
+ "criticise": "criticize",
330
+ "criticised": "criticized",
331
+ "criticises": "criticizes",
332
+ "criticising": "criticizing",
333
+ "crueller": "crueler",
334
+ "cruellest": "cruelest",
335
+ "crystallisation": "crystallization",
336
+ "crystallise": "crystallize",
337
+ "crystallised": "crystallized",
338
+ "crystallises": "crystallizes",
339
+ "crystallising": "crystallizing",
340
+ "cudgelled": "cudgeled",
341
+ "cudgelling": "cudgeling",
342
+ "customise": "customize",
343
+ "customised": "customized",
344
+ "customises": "customizes",
345
+ "customising": "customizing",
346
+ "cypher": "cipher",
347
+ "cyphers": "ciphers",
348
+ "decentralisation": "decentralization",
349
+ "decentralise": "decentralize",
350
+ "decentralised": "decentralized",
351
+ "decentralises": "decentralizes",
352
+ "decentralising": "decentralizing",
353
+ "decriminalisation": "decriminalization",
354
+ "decriminalise": "decriminalize",
355
+ "decriminalised": "decriminalized",
356
+ "decriminalises": "decriminalizes",
357
+ "decriminalising": "decriminalizing",
358
+ "defence": "defense",
359
+ "defenceless": "defenseless",
360
+ "defences": "defenses",
361
+ "dehumanisation": "dehumanization",
362
+ "dehumanise": "dehumanize",
363
+ "dehumanised": "dehumanized",
364
+ "dehumanises": "dehumanizes",
365
+ "dehumanising": "dehumanizing",
366
+ "demeanour": "demeanor",
367
+ "demilitarisation": "demilitarization",
368
+ "demilitarise": "demilitarize",
369
+ "demilitarised": "demilitarized",
370
+ "demilitarises": "demilitarizes",
371
+ "demilitarising": "demilitarizing",
372
+ "demobilisation": "demobilization",
373
+ "demobilise": "demobilize",
374
+ "demobilised": "demobilized",
375
+ "demobilises": "demobilizes",
376
+ "demobilising": "demobilizing",
377
+ "democratisation": "democratization",
378
+ "democratise": "democratize",
379
+ "democratised": "democratized",
380
+ "democratises": "democratizes",
381
+ "democratising": "democratizing",
382
+ "demonise": "demonize",
383
+ "demonised": "demonized",
384
+ "demonises": "demonizes",
385
+ "demonising": "demonizing",
386
+ "demoralisation": "demoralization",
387
+ "demoralise": "demoralize",
388
+ "demoralised": "demoralized",
389
+ "demoralises": "demoralizes",
390
+ "demoralising": "demoralizing",
391
+ "denationalisation": "denationalization",
392
+ "denationalise": "denationalize",
393
+ "denationalised": "denationalized",
394
+ "denationalises": "denationalizes",
395
+ "denationalising": "denationalizing",
396
+ "deodorise": "deodorize",
397
+ "deodorised": "deodorized",
398
+ "deodorises": "deodorizes",
399
+ "deodorising": "deodorizing",
400
+ "depersonalise": "depersonalize",
401
+ "depersonalised": "depersonalized",
402
+ "depersonalises": "depersonalizes",
403
+ "depersonalising": "depersonalizing",
404
+ "deputise": "deputize",
405
+ "deputised": "deputized",
406
+ "deputises": "deputizes",
407
+ "deputising": "deputizing",
408
+ "desensitisation": "desensitization",
409
+ "desensitise": "desensitize",
410
+ "desensitised": "desensitized",
411
+ "desensitises": "desensitizes",
412
+ "desensitising": "desensitizing",
413
+ "destabilisation": "destabilization",
414
+ "destabilise": "destabilize",
415
+ "destabilised": "destabilized",
416
+ "destabilises": "destabilizes",
417
+ "destabilising": "destabilizing",
418
+ "dialled": "dialed",
419
+ "dialling": "dialing",
420
+ "dialogue": "dialog",
421
+ "dialogues": "dialogs",
422
+ "diarrhoea": "diarrhea",
423
+ "digitise": "digitize",
424
+ "digitised": "digitized",
425
+ "digitises": "digitizes",
426
+ "digitising": "digitizing",
427
+ "disc": "disk",
428
+ "discolour": "discolor",
429
+ "discoloured": "discolored",
430
+ "discolouring": "discoloring",
431
+ "discolours": "discolors",
432
+ "discs": "disks",
433
+ "disembowelled": "disemboweled",
434
+ "disembowelling": "disemboweling",
435
+ "disfavour": "disfavor",
436
+ "dishevelled": "disheveled",
437
+ "dishonour": "dishonor",
438
+ "dishonourable": "dishonorable",
439
+ "dishonourably": "dishonorably",
440
+ "dishonoured": "dishonored",
441
+ "dishonouring": "dishonoring",
442
+ "dishonours": "dishonors",
443
+ "disorganisation": "disorganization",
444
+ "disorganised": "disorganized",
445
+ "distil": "distill",
446
+ "distils": "distills",
447
+ "dramatisation": "dramatization",
448
+ "dramatisations": "dramatizations",
449
+ "dramatise": "dramatize",
450
+ "dramatised": "dramatized",
451
+ "dramatises": "dramatizes",
452
+ "dramatising": "dramatizing",
453
+ "draught": "draft",
454
+ "draughtboard": "draftboard",
455
+ "draughtboards": "draftboards",
456
+ "draughtier": "draftier",
457
+ "draughtiest": "draftiest",
458
+ "draughts": "drafts",
459
+ "draughtsman": "draftsman",
460
+ "draughtsmanship": "draftsmanship",
461
+ "draughtsmen": "draftsmen",
462
+ "draughtswoman": "draftswoman",
463
+ "draughtswomen": "draftswomen",
464
+ "draughty": "drafty",
465
+ "drivelled": "driveled",
466
+ "drivelling": "driveling",
467
+ "duelled": "dueled",
468
+ "duelling": "dueling",
469
+ "economise": "economize",
470
+ "economised": "economized",
471
+ "economises": "economizes",
472
+ "economising": "economizing",
473
+ "edoema": "edema",
474
+ "editorialise": "editorialize",
475
+ "editorialised": "editorialized",
476
+ "editorialises": "editorializes",
477
+ "editorialising": "editorializing",
478
+ "empathise": "empathize",
479
+ "empathised": "empathized",
480
+ "empathises": "empathizes",
481
+ "empathising": "empathizing",
482
+ "emphasise": "emphasize",
483
+ "emphasised": "emphasized",
484
+ "emphasises": "emphasizes",
485
+ "emphasising": "emphasizing",
486
+ "enamelled": "enameled",
487
+ "enamelling": "enameling",
488
+ "enamoured": "enamored",
489
+ "encyclopaedia": "encyclopedia",
490
+ "encyclopaedias": "encyclopedias",
491
+ "encyclopaedic": "encyclopedic",
492
+ "endeavour": "endeavor",
493
+ "endeavoured": "endeavored",
494
+ "endeavouring": "endeavoring",
495
+ "endeavours": "endeavors",
496
+ "energise": "energize",
497
+ "energised": "energized",
498
+ "energises": "energizes",
499
+ "energising": "energizing",
500
+ "enrol": "enroll",
501
+ "enrols": "enrolls",
502
+ "enthral": "enthrall",
503
+ "enthrals": "enthralls",
504
+ "epaulette": "epaulet",
505
+ "epaulettes": "epaulets",
506
+ "epicentre": "epicenter",
507
+ "epicentres": "epicenters",
508
+ "epilogue": "epilog",
509
+ "epilogues": "epilogs",
510
+ "epitomise": "epitomize",
511
+ "epitomised": "epitomized",
512
+ "epitomises": "epitomizes",
513
+ "epitomising": "epitomizing",
514
+ "equalisation": "equalization",
515
+ "equalise": "equalize",
516
+ "equalised": "equalized",
517
+ "equaliser": "equalizer",
518
+ "equalisers": "equalizers",
519
+ "equalises": "equalizes",
520
+ "equalising": "equalizing",
521
+ "eulogise": "eulogize",
522
+ "eulogised": "eulogized",
523
+ "eulogises": "eulogizes",
524
+ "eulogising": "eulogizing",
525
+ "evangelise": "evangelize",
526
+ "evangelised": "evangelized",
527
+ "evangelises": "evangelizes",
528
+ "evangelising": "evangelizing",
529
+ "exorcise": "exorcize",
530
+ "exorcised": "exorcized",
531
+ "exorcises": "exorcizes",
532
+ "exorcising": "exorcizing",
533
+ "extemporisation": "extemporization",
534
+ "extemporise": "extemporize",
535
+ "extemporised": "extemporized",
536
+ "extemporises": "extemporizes",
537
+ "extemporising": "extemporizing",
538
+ "externalisation": "externalization",
539
+ "externalisations": "externalizations",
540
+ "externalise": "externalize",
541
+ "externalised": "externalized",
542
+ "externalises": "externalizes",
543
+ "externalising": "externalizing",
544
+ "factorise": "factorize",
545
+ "factorised": "factorized",
546
+ "factorises": "factorizes",
547
+ "factorising": "factorizing",
548
+ "faecal": "fecal",
549
+ "faeces": "feces",
550
+ "familiarisation": "familiarization",
551
+ "familiarise": "familiarize",
552
+ "familiarised": "familiarized",
553
+ "familiarises": "familiarizes",
554
+ "familiarising": "familiarizing",
555
+ "fantasise": "fantasize",
556
+ "fantasised": "fantasized",
557
+ "fantasises": "fantasizes",
558
+ "fantasising": "fantasizing",
559
+ "favour": "favor",
560
+ "favourable": "favorable",
561
+ "favourably": "favorably",
562
+ "favoured": "favored",
563
+ "favouring": "favoring",
564
+ "favourite": "favorite",
565
+ "favourites": "favorites",
566
+ "favouritism": "favoritism",
567
+ "favours": "favors",
568
+ "feminise": "feminize",
569
+ "feminised": "feminized",
570
+ "feminises": "feminizes",
571
+ "feminising": "feminizing",
572
+ "fertilisation": "fertilization",
573
+ "fertilise": "fertilize",
574
+ "fertilised": "fertilized",
575
+ "fertiliser": "fertilizer",
576
+ "fertilisers": "fertilizers",
577
+ "fertilises": "fertilizes",
578
+ "fertilising": "fertilizing",
579
+ "fervour": "fervor",
580
+ "fibre": "fiber",
581
+ "fibreglass": "fiberglass",
582
+ "fibres": "fibers",
583
+ "fictionalisation": "fictionalization",
584
+ "fictionalisations": "fictionalizations",
585
+ "fictionalise": "fictionalize",
586
+ "fictionalised": "fictionalized",
587
+ "fictionalises": "fictionalizes",
588
+ "fictionalising": "fictionalizing",
589
+ "fillet": "filet",
590
+ "filleted": "fileted",
591
+ "filleting": "fileting",
592
+ "fillets": "filets",
593
+ "finalisation": "finalization",
594
+ "finalise": "finalize",
595
+ "finalised": "finalized",
596
+ "finalises": "finalizes",
597
+ "finalising": "finalizing",
598
+ "flautist": "flutist",
599
+ "flautists": "flutists",
600
+ "flavour": "flavor",
601
+ "flavoured": "flavored",
602
+ "flavouring": "flavoring",
603
+ "flavourings": "flavorings",
604
+ "flavourless": "flavorless",
605
+ "flavours": "flavors",
606
+ "flavoursome": "flavorsome",
607
+ "flyer / flier": "flier / flyer",
608
+ "foetal": "fetal",
609
+ "foetid": "fetid",
610
+ "foetus": "fetus",
611
+ "foetuses": "fetuses",
612
+ "formalisation": "formalization",
613
+ "formalise": "formalize",
614
+ "formalised": "formalized",
615
+ "formalises": "formalizes",
616
+ "formalising": "formalizing",
617
+ "fossilisation": "fossilization",
618
+ "fossilise": "fossilize",
619
+ "fossilised": "fossilized",
620
+ "fossilises": "fossilizes",
621
+ "fossilising": "fossilizing",
622
+ "fraternisation": "fraternization",
623
+ "fraternise": "fraternize",
624
+ "fraternised": "fraternized",
625
+ "fraternises": "fraternizes",
626
+ "fraternising": "fraternizing",
627
+ "fulfil": "fulfill",
628
+ "fulfilment": "fulfillment",
629
+ "fulfils": "fulfills",
630
+ "funnelled": "funneled",
631
+ "funnelling": "funneling",
632
+ "galvanise": "galvanize",
633
+ "galvanised": "galvanized",
634
+ "galvanises": "galvanizes",
635
+ "galvanising": "galvanizing",
636
+ "gambolled": "gamboled",
637
+ "gambolling": "gamboling",
638
+ "gaol": "jail",
639
+ "gaolbird": "jailbird",
640
+ "gaolbirds": "jailbirds",
641
+ "gaolbreak": "jailbreak",
642
+ "gaolbreaks": "jailbreaks",
643
+ "gaoled": "jailed",
644
+ "gaoler": "jailer",
645
+ "gaolers": "jailers",
646
+ "gaoling": "jailing",
647
+ "gaols": "jails",
648
+ "gasses": "gases",
649
+ "gage": "gauge",
650
+ "gaged": "gauged",
651
+ "gages": "gauges",
652
+ "gaging": "gauging",
653
+ "generalisation": "generalization",
654
+ "generalisations": "generalizations",
655
+ "generalise": "generalize",
656
+ "generalised": "generalized",
657
+ "generalises": "generalizes",
658
+ "generalising": "generalizing",
659
+ "ghettoise": "ghettoize",
660
+ "ghettoised": "ghettoized",
661
+ "ghettoises": "ghettoizes",
662
+ "ghettoising": "ghettoizing",
663
+ "gipsies": "gypsies",
664
+ "glamorise": "glamorize",
665
+ "glamorised": "glamorized",
666
+ "glamorises": "glamorizes",
667
+ "glamorising": "glamorizing",
668
+ "glamor": "glamour",
669
+ "globalisation": "globalization",
670
+ "globalise": "globalize",
671
+ "globalised": "globalized",
672
+ "globalises": "globalizes",
673
+ "globalising": "globalizing",
674
+ "glueing": "gluing",
675
+ "goitre": "goiter",
676
+ "goitres": "goiters",
677
+ "gonorrhoea": "gonorrhea",
678
+ "gramme": "gram",
679
+ "grammes": "grams",
680
+ "gravelled": "graveled",
681
+ "grey": "gray",
682
+ "greyed": "grayed",
683
+ "greying": "graying",
684
+ "greyish": "grayish",
685
+ "greyness": "grayness",
686
+ "greys": "grays",
687
+ "grovelled": "groveled",
688
+ "grovelling": "groveling",
689
+ "groyne": "groin",
690
+ "groynes": "groins",
691
+ "gruelling": "grueling",
692
+ "gruellingly": "gruelingly",
693
+ "gryphon": "griffin",
694
+ "gryphons": "griffins",
695
+ "gynaecological": "gynecological",
696
+ "gynaecologist": "gynecologist",
697
+ "gynaecologists": "gynecologists",
698
+ "gynaecology": "gynecology",
699
+ "haematological": "hematological",
700
+ "haematologist": "hematologist",
701
+ "haematologists": "hematologists",
702
+ "haematology": "hematology",
703
+ "haemoglobin": "hemoglobin",
704
+ "haemophilia": "hemophilia",
705
+ "haemophiliac": "hemophiliac",
706
+ "haemophiliacs": "hemophiliacs",
707
+ "haemorrhage": "hemorrhage",
708
+ "haemorrhaged": "hemorrhaged",
709
+ "haemorrhages": "hemorrhages",
710
+ "haemorrhaging": "hemorrhaging",
711
+ "haemorrhoids": "hemorrhoids",
712
+ "harbour": "harbor",
713
+ "harboured": "harbored",
714
+ "harbouring": "harboring",
715
+ "harbours": "harbors",
716
+ "harmonisation": "harmonization",
717
+ "harmonise": "harmonize",
718
+ "harmonised": "harmonized",
719
+ "harmonises": "harmonizes",
720
+ "harmonising": "harmonizing",
721
+ "homoeopath": "homeopath",
722
+ "homoeopathic": "homeopathic",
723
+ "homoeopaths": "homeopaths",
724
+ "homoeopathy": "homeopathy",
725
+ "homogenise": "homogenize",
726
+ "homogenised": "homogenized",
727
+ "homogenises": "homogenizes",
728
+ "homogenising": "homogenizing",
729
+ "honour": "honor",
730
+ "honourable": "honorable",
731
+ "honourably": "honorably",
732
+ "honoured": "honored",
733
+ "honouring": "honoring",
734
+ "honours": "honors",
735
+ "hospitalisation": "hospitalization",
736
+ "hospitalise": "hospitalize",
737
+ "hospitalised": "hospitalized",
738
+ "hospitalises": "hospitalizes",
739
+ "hospitalising": "hospitalizing",
740
+ "humanise": "humanize",
741
+ "humanised": "humanized",
742
+ "humanises": "humanizes",
743
+ "humanising": "humanizing",
744
+ "humour": "humor",
745
+ "humoured": "humored",
746
+ "humouring": "humoring",
747
+ "humourless": "humorless",
748
+ "humours": "humors",
749
+ "hybridise": "hybridize",
750
+ "hybridised": "hybridized",
751
+ "hybridises": "hybridizes",
752
+ "hybridising": "hybridizing",
753
+ "hypnotise": "hypnotize",
754
+ "hypnotised": "hypnotized",
755
+ "hypnotises": "hypnotizes",
756
+ "hypnotising": "hypnotizing",
757
+ "hypothesise": "hypothesize",
758
+ "hypothesised": "hypothesized",
759
+ "hypothesises": "hypothesizes",
760
+ "hypothesising": "hypothesizing",
761
+ "idealisation": "idealization",
762
+ "idealise": "idealize",
763
+ "idealised": "idealized",
764
+ "idealises": "idealizes",
765
+ "idealising": "idealizing",
766
+ "idolise": "idolize",
767
+ "idolised": "idolized",
768
+ "idolises": "idolizes",
769
+ "idolising": "idolizing",
770
+ "immobilisation": "immobilization",
771
+ "immobilise": "immobilize",
772
+ "immobilised": "immobilized",
773
+ "immobiliser": "immobilizer",
774
+ "immobilisers": "immobilizers",
775
+ "immobilises": "immobilizes",
776
+ "immobilising": "immobilizing",
777
+ "immortalise": "immortalize",
778
+ "immortalised": "immortalized",
779
+ "immortalises": "immortalizes",
780
+ "immortalising": "immortalizing",
781
+ "immunisation": "immunization",
782
+ "immunise": "immunize",
783
+ "immunised": "immunized",
784
+ "immunises": "immunizes",
785
+ "immunising": "immunizing",
786
+ "impanelled": "impaneled",
787
+ "impanelling": "impaneling",
788
+ "imperilled": "imperiled",
789
+ "imperilling": "imperiling",
790
+ "individualise": "individualize",
791
+ "individualised": "individualized",
792
+ "individualises": "individualizes",
793
+ "individualising": "individualizing",
794
+ "industrialise": "industrialize",
795
+ "industrialised": "industrialized",
796
+ "industrialises": "industrializes",
797
+ "industrialising": "industrializing",
798
+ "inflexion": "inflection",
799
+ "inflexions": "inflections",
800
+ "initialise": "initialize",
801
+ "initialised": "initialized",
802
+ "initialises": "initializes",
803
+ "initialising": "initializing",
804
+ "initialled": "initialed",
805
+ "initialling": "initialing",
806
+ "instal": "install",
807
+ "instalment": "installment",
808
+ "instalments": "installments",
809
+ "instals": "installs",
810
+ "instil": "instill",
811
+ "instils": "instills",
812
+ "institutionalisation": "institutionalization",
813
+ "institutionalise": "institutionalize",
814
+ "institutionalised": "institutionalized",
815
+ "institutionalises": "institutionalizes",
816
+ "institutionalising": "institutionalizing",
817
+ "intellectualise": "intellectualize",
818
+ "intellectualised": "intellectualized",
819
+ "intellectualises": "intellectualizes",
820
+ "intellectualising": "intellectualizing",
821
+ "internalisation": "internalization",
822
+ "internalise": "internalize",
823
+ "internalised": "internalized",
824
+ "internalises": "internalizes",
825
+ "internalising": "internalizing",
826
+ "internationalisation": "internationalization",
827
+ "internationalise": "internationalize",
828
+ "internationalised": "internationalized",
829
+ "internationalises": "internationalizes",
830
+ "internationalising": "internationalizing",
831
+ "ionisation": "ionization",
832
+ "ionise": "ionize",
833
+ "ionised": "ionized",
834
+ "ioniser": "ionizer",
835
+ "ionisers": "ionizers",
836
+ "ionises": "ionizes",
837
+ "ionising": "ionizing",
838
+ "italicise": "italicize",
839
+ "italicised": "italicized",
840
+ "italicises": "italicizes",
841
+ "italicising": "italicizing",
842
+ "itemise": "itemize",
843
+ "itemised": "itemized",
844
+ "itemises": "itemizes",
845
+ "itemising": "itemizing",
846
+ "jeopardise": "jeopardize",
847
+ "jeopardised": "jeopardized",
848
+ "jeopardises": "jeopardizes",
849
+ "jeopardising": "jeopardizing",
850
+ "jewelled": "jeweled",
851
+ "jeweller": "jeweler",
852
+ "jewellers": "jewelers",
853
+ "jewellery": "jewelry",
854
+ "judgement": "judgment",
855
+ "kilogramme": "kilogram",
856
+ "kilogrammes": "kilograms",
857
+ "kilometre": "kilometer",
858
+ "kilometres": "kilometers",
859
+ "labelled": "labeled",
860
+ "labelling": "labeling",
861
+ "labour": "labor",
862
+ "laboured": "labored",
863
+ "labourer": "laborer",
864
+ "labourers": "laborers",
865
+ "labouring": "laboring",
866
+ "labours": "labors",
867
+ "lacklustre": "lackluster",
868
+ "legalisation": "legalization",
869
+ "legalise": "legalize",
870
+ "legalised": "legalized",
871
+ "legalises": "legalizes",
872
+ "legalising": "legalizing",
873
+ "legitimise": "legitimize",
874
+ "legitimised": "legitimized",
875
+ "legitimises": "legitimizes",
876
+ "legitimising": "legitimizing",
877
+ "leukaemia": "leukemia",
878
+ "levelled": "leveled",
879
+ "leveller": "leveler",
880
+ "levellers": "levelers",
881
+ "levelling": "leveling",
882
+ "libelled": "libeled",
883
+ "libelling": "libeling",
884
+ "libellous": "libelous",
885
+ "liberalisation": "liberalization",
886
+ "liberalise": "liberalize",
887
+ "liberalised": "liberalized",
888
+ "liberalises": "liberalizes",
889
+ "liberalising": "liberalizing",
890
+ "licence": "license",
891
+ "licenced": "licensed",
892
+ "licences": "licenses",
893
+ "licencing": "licensing",
894
+ "likeable": "likable",
895
+ "lionisation": "lionization",
896
+ "lionise": "lionize",
897
+ "lionised": "lionized",
898
+ "lionises": "lionizes",
899
+ "lionising": "lionizing",
900
+ "liquidise": "liquidize",
901
+ "liquidised": "liquidized",
902
+ "liquidiser": "liquidizer",
903
+ "liquidisers": "liquidizers",
904
+ "liquidises": "liquidizes",
905
+ "liquidising": "liquidizing",
906
+ "litre": "liter",
907
+ "litres": "liters",
908
+ "localise": "localize",
909
+ "localised": "localized",
910
+ "localises": "localizes",
911
+ "localising": "localizing",
912
+ "louvre": "louver",
913
+ "louvred": "louvered",
914
+ "louvres": "louvers",
915
+ "lustre": "luster",
916
+ "magnetise": "magnetize",
917
+ "magnetised": "magnetized",
918
+ "magnetises": "magnetizes",
919
+ "magnetising": "magnetizing",
920
+ "manoeuvrability": "maneuverability",
921
+ "manoeuvrable": "maneuverable",
922
+ "manoeuvre": "maneuver",
923
+ "manoeuvred": "maneuvered",
924
+ "manoeuvres": "maneuvers",
925
+ "manoeuvring": "maneuvering",
926
+ "manoeuvrings": "maneuverings",
927
+ "marginalisation": "marginalization",
928
+ "marginalise": "marginalize",
929
+ "marginalised": "marginalized",
930
+ "marginalises": "marginalizes",
931
+ "marginalising": "marginalizing",
932
+ "marshalled": "marshaled",
933
+ "marshalling": "marshaling",
934
+ "marvelled": "marveled",
935
+ "marvelling": "marveling",
936
+ "marvellous": "marvelous",
937
+ "marvellously": "marvelously",
938
+ "materialisation": "materialization",
939
+ "materialise": "materialize",
940
+ "materialised": "materialized",
941
+ "materialises": "materializes",
942
+ "materialising": "materializing",
943
+ "maximisation": "maximization",
944
+ "maximise": "maximize",
945
+ "maximised": "maximized",
946
+ "maximises": "maximizes",
947
+ "maximising": "maximizing",
948
+ "meagre": "meager",
949
+ "mechanisation": "mechanization",
950
+ "mechanise": "mechanize",
951
+ "mechanised": "mechanized",
952
+ "mechanises": "mechanizes",
953
+ "mechanising": "mechanizing",
954
+ "mediaeval": "medieval",
955
+ "memorialise": "memorialize",
956
+ "memorialised": "memorialized",
957
+ "memorialises": "memorializes",
958
+ "memorialising": "memorializing",
959
+ "memorise": "memorize",
960
+ "memorised": "memorized",
961
+ "memorises": "memorizes",
962
+ "memorising": "memorizing",
963
+ "mesmerise": "mesmerize",
964
+ "mesmerised": "mesmerized",
965
+ "mesmerises": "mesmerizes",
966
+ "mesmerising": "mesmerizing",
967
+ "metabolise": "metabolize",
968
+ "metabolised": "metabolized",
969
+ "metabolises": "metabolizes",
970
+ "metabolising": "metabolizing",
971
+ "metre": "meter",
972
+ "metres": "meters",
973
+ "micrometre": "micrometer",
974
+ "micrometres": "micrometers",
975
+ "militarise": "militarize",
976
+ "militarised": "militarized",
977
+ "militarises": "militarizes",
978
+ "militarising": "militarizing",
979
+ "milligramme": "milligram",
980
+ "milligrammes": "milligrams",
981
+ "millilitre": "milliliter",
982
+ "millilitres": "milliliters",
983
+ "millimetre": "millimeter",
984
+ "millimetres": "millimeters",
985
+ "miniaturisation": "miniaturization",
986
+ "miniaturise": "miniaturize",
987
+ "miniaturised": "miniaturized",
988
+ "miniaturises": "miniaturizes",
989
+ "miniaturising": "miniaturizing",
990
+ "minibusses": "minibuses",
991
+ "minimise": "minimize",
992
+ "minimised": "minimized",
993
+ "minimises": "minimizes",
994
+ "minimising": "minimizing",
995
+ "misbehaviour": "misbehavior",
996
+ "misdemeanour": "misdemeanor",
997
+ "misdemeanours": "misdemeanors",
998
+ "misspelt": "misspelled",
999
+ "mitre": "miter",
1000
+ "mitres": "miters",
1001
+ "mobilisation": "mobilization",
1002
+ "mobilise": "mobilize",
1003
+ "mobilised": "mobilized",
1004
+ "mobilises": "mobilizes",
1005
+ "mobilising": "mobilizing",
1006
+ "modelled": "modeled",
1007
+ "modeller": "modeler",
1008
+ "modellers": "modelers",
1009
+ "modelling": "modeling",
1010
+ "modernise": "modernize",
1011
+ "modernised": "modernized",
1012
+ "modernises": "modernizes",
1013
+ "modernising": "modernizing",
1014
+ "moisturise": "moisturize",
1015
+ "moisturised": "moisturized",
1016
+ "moisturiser": "moisturizer",
1017
+ "moisturisers": "moisturizers",
1018
+ "moisturises": "moisturizes",
1019
+ "moisturising": "moisturizing",
1020
+ "monologue": "monolog",
1021
+ "monologues": "monologs",
1022
+ "monopolisation": "monopolization",
1023
+ "monopolise": "monopolize",
1024
+ "monopolised": "monopolized",
1025
+ "monopolises": "monopolizes",
1026
+ "monopolising": "monopolizing",
1027
+ "moralise": "moralize",
1028
+ "moralised": "moralized",
1029
+ "moralises": "moralizes",
1030
+ "moralising": "moralizing",
1031
+ "motorised": "motorized",
1032
+ "mould": "mold",
1033
+ "moulded": "molded",
1034
+ "moulder": "molder",
1035
+ "mouldered": "moldered",
1036
+ "mouldering": "moldering",
1037
+ "moulders": "molders",
1038
+ "mouldier": "moldier",
1039
+ "mouldiest": "moldiest",
1040
+ "moulding": "molding",
1041
+ "mouldings": "moldings",
1042
+ "moulds": "molds",
1043
+ "mouldy": "moldy",
1044
+ "moult": "molt",
1045
+ "moulted": "molted",
1046
+ "moulting": "molting",
1047
+ "moults": "molts",
1048
+ "moustache": "mustache",
1049
+ "moustached": "mustached",
1050
+ "moustaches": "mustaches",
1051
+ "moustachioed": "mustachioed",
1052
+ "multicoloured": "multicolored",
1053
+ "nationalisation": "nationalization",
1054
+ "nationalisations": "nationalizations",
1055
+ "nationalise": "nationalize",
1056
+ "nationalised": "nationalized",
1057
+ "nationalises": "nationalizes",
1058
+ "nationalising": "nationalizing",
1059
+ "naturalisation": "naturalization",
1060
+ "naturalise": "naturalize",
1061
+ "naturalised": "naturalized",
1062
+ "naturalises": "naturalizes",
1063
+ "naturalising": "naturalizing",
1064
+ "neighbour": "neighbor",
1065
+ "neighbourhood": "neighborhood",
1066
+ "neighbourhoods": "neighborhoods",
1067
+ "neighbouring": "neighboring",
1068
+ "neighbourliness": "neighborliness",
1069
+ "neighbourly": "neighborly",
1070
+ "neighbours": "neighbors",
1071
+ "neutralisation": "neutralization",
1072
+ "neutralise": "neutralize",
1073
+ "neutralised": "neutralized",
1074
+ "neutralises": "neutralizes",
1075
+ "neutralising": "neutralizing",
1076
+ "normalisation": "normalization",
1077
+ "normalise": "normalize",
1078
+ "normalised": "normalized",
1079
+ "normalises": "normalizes",
1080
+ "normalising": "normalizing",
1081
+ "odour": "odor",
1082
+ "odourless": "odorless",
1083
+ "odours": "odors",
1084
+ "oesophagus": "esophagus",
1085
+ "oesophaguses": "esophaguses",
1086
+ "oestrogen": "estrogen",
1087
+ "offence": "offense",
1088
+ "offences": "offenses",
1089
+ "omelette": "omelet",
1090
+ "omelettes": "omelets",
1091
+ "optimise": "optimize",
1092
+ "optimised": "optimized",
1093
+ "optimises": "optimizes",
1094
+ "optimising": "optimizing",
1095
+ "organisation": "organization",
1096
+ "organisational": "organizational",
1097
+ "organisations": "organizations",
1098
+ "organise": "organize",
1099
+ "organised": "organized",
1100
+ "organiser": "organizer",
1101
+ "organisers": "organizers",
1102
+ "organises": "organizes",
1103
+ "organising": "organizing",
1104
+ "orthopaedic": "orthopedic",
1105
+ "orthopaedics": "orthopedics",
1106
+ "ostracise": "ostracize",
1107
+ "ostracised": "ostracized",
1108
+ "ostracises": "ostracizes",
1109
+ "ostracising": "ostracizing",
1110
+ "outmanoeuvre": "outmaneuver",
1111
+ "outmanoeuvred": "outmaneuvered",
1112
+ "outmanoeuvres": "outmaneuvers",
1113
+ "outmanoeuvring": "outmaneuvering",
1114
+ "overemphasise": "overemphasize",
1115
+ "overemphasised": "overemphasized",
1116
+ "overemphasises": "overemphasizes",
1117
+ "overemphasising": "overemphasizing",
1118
+ "oxidisation": "oxidization",
1119
+ "oxidise": "oxidize",
1120
+ "oxidised": "oxidized",
1121
+ "oxidises": "oxidizes",
1122
+ "oxidising": "oxidizing",
1123
+ "paederast": "pederast",
1124
+ "paederasts": "pederasts",
1125
+ "paediatric": "pediatric",
1126
+ "paediatrician": "pediatrician",
1127
+ "paediatricians": "pediatricians",
1128
+ "paediatrics": "pediatrics",
1129
+ "paedophile": "pedophile",
1130
+ "paedophiles": "pedophiles",
1131
+ "paedophilia": "pedophilia",
1132
+ "palaeolithic": "paleolithic",
1133
+ "palaeontologist": "paleontologist",
1134
+ "palaeontologists": "paleontologists",
1135
+ "palaeontology": "paleontology",
1136
+ "panelled": "paneled",
1137
+ "panelling": "paneling",
1138
+ "panellist": "panelist",
1139
+ "panellists": "panelists",
1140
+ "paralyse": "paralyze",
1141
+ "paralysed": "paralyzed",
1142
+ "paralyses": "paralyzes",
1143
+ "paralysing": "paralyzing",
1144
+ "parcelled": "parceled",
1145
+ "parcelling": "parceling",
1146
+ "parlour": "parlor",
1147
+ "parlours": "parlors",
1148
+ "particularise": "particularize",
1149
+ "particularised": "particularized",
1150
+ "particularises": "particularizes",
1151
+ "particularising": "particularizing",
1152
+ "passivisation": "passivization",
1153
+ "passivise": "passivize",
1154
+ "passivised": "passivized",
1155
+ "passivises": "passivizes",
1156
+ "passivising": "passivizing",
1157
+ "pasteurisation": "pasteurization",
1158
+ "pasteurise": "pasteurize",
1159
+ "pasteurised": "pasteurized",
1160
+ "pasteurises": "pasteurizes",
1161
+ "pasteurising": "pasteurizing",
1162
+ "patronise": "patronize",
1163
+ "patronised": "patronized",
1164
+ "patronises": "patronizes",
1165
+ "patronising": "patronizing",
1166
+ "patronisingly": "patronizingly",
1167
+ "pedalled": "pedaled",
1168
+ "pedalling": "pedaling",
1169
+ "pedestrianisation": "pedestrianization",
1170
+ "pedestrianise": "pedestrianize",
1171
+ "pedestrianised": "pedestrianized",
1172
+ "pedestrianises": "pedestrianizes",
1173
+ "pedestrianising": "pedestrianizing",
1174
+ "penalise": "penalize",
1175
+ "penalised": "penalized",
1176
+ "penalises": "penalizes",
1177
+ "penalising": "penalizing",
1178
+ "pencilled": "penciled",
1179
+ "pencilling": "penciling",
1180
+ "personalise": "personalize",
1181
+ "personalised": "personalized",
1182
+ "personalises": "personalizes",
1183
+ "personalising": "personalizing",
1184
+ "pharmacopoeia": "pharmacopeia",
1185
+ "pharmacopoeias": "pharmacopeias",
1186
+ "philosophise": "philosophize",
1187
+ "philosophised": "philosophized",
1188
+ "philosophises": "philosophizes",
1189
+ "philosophising": "philosophizing",
1190
+ "philtre": "filter",
1191
+ "philtres": "filters",
1192
+ "phoney": "phony",
1193
+ "plagiarise": "plagiarize",
1194
+ "plagiarised": "plagiarized",
1195
+ "plagiarises": "plagiarizes",
1196
+ "plagiarising": "plagiarizing",
1197
+ "plough": "plow",
1198
+ "ploughed": "plowed",
1199
+ "ploughing": "plowing",
1200
+ "ploughman": "plowman",
1201
+ "ploughmen": "plowmen",
1202
+ "ploughs": "plows",
1203
+ "ploughshare": "plowshare",
1204
+ "ploughshares": "plowshares",
1205
+ "polarisation": "polarization",
1206
+ "polarise": "polarize",
1207
+ "polarised": "polarized",
1208
+ "polarises": "polarizes",
1209
+ "polarising": "polarizing",
1210
+ "politicisation": "politicization",
1211
+ "politicise": "politicize",
1212
+ "politicised": "politicized",
1213
+ "politicises": "politicizes",
1214
+ "politicising": "politicizing",
1215
+ "popularisation": "popularization",
1216
+ "popularise": "popularize",
1217
+ "popularised": "popularized",
1218
+ "popularises": "popularizes",
1219
+ "popularising": "popularizing",
1220
+ "pouffe": "pouf",
1221
+ "pouffes": "poufs",
1222
+ "practise": "practice",
1223
+ "practised": "practiced",
1224
+ "practises": "practices",
1225
+ "practising": "practicing",
1226
+ "praesidium": "presidium",
1227
+ "praesidiums": "presidiums",
1228
+ "pressurisation": "pressurization",
1229
+ "pressurise": "pressurize",
1230
+ "pressurised": "pressurized",
1231
+ "pressurises": "pressurizes",
1232
+ "pressurising": "pressurizing",
1233
+ "pretence": "pretense",
1234
+ "pretences": "pretenses",
1235
+ "primaeval": "primeval",
1236
+ "prioritisation": "prioritization",
1237
+ "prioritise": "prioritize",
1238
+ "prioritised": "prioritized",
1239
+ "prioritises": "prioritizes",
1240
+ "prioritising": "prioritizing",
1241
+ "privatisation": "privatization",
1242
+ "privatisations": "privatizations",
1243
+ "privatise": "privatize",
1244
+ "privatised": "privatized",
1245
+ "privatises": "privatizes",
1246
+ "privatising": "privatizing",
1247
+ "professionalisation": "professionalization",
1248
+ "professionalise": "professionalize",
1249
+ "professionalised": "professionalized",
1250
+ "professionalises": "professionalizes",
1251
+ "professionalising": "professionalizing",
1252
+ "programme": "program",
1253
+ "programmes": "programs",
1254
+ "prologue": "prolog",
1255
+ "prologues": "prologs",
1256
+ "propagandise": "propagandize",
1257
+ "propagandised": "propagandized",
1258
+ "propagandises": "propagandizes",
1259
+ "propagandising": "propagandizing",
1260
+ "proselytise": "proselytize",
1261
+ "proselytised": "proselytized",
1262
+ "proselytiser": "proselytizer",
1263
+ "proselytisers": "proselytizers",
1264
+ "proselytises": "proselytizes",
1265
+ "proselytising": "proselytizing",
1266
+ "psychoanalyse": "psychoanalyze",
1267
+ "psychoanalysed": "psychoanalyzed",
1268
+ "psychoanalyses": "psychoanalyzes",
1269
+ "psychoanalysing": "psychoanalyzing",
1270
+ "publicise": "publicize",
1271
+ "publicised": "publicized",
1272
+ "publicises": "publicizes",
1273
+ "publicising": "publicizing",
1274
+ "pulverisation": "pulverization",
1275
+ "pulverise": "pulverize",
1276
+ "pulverised": "pulverized",
1277
+ "pulverises": "pulverizes",
1278
+ "pulverising": "pulverizing",
1279
+ "pummelled": "pummel",
1280
+ "pummelling": "pummeled",
1281
+ "pyjama": "pajama",
1282
+ "pyjamas": "pajamas",
1283
+ "pzazz": "pizzazz",
1284
+ "quarrelled": "quarreled",
1285
+ "quarrelling": "quarreling",
1286
+ "radicalise": "radicalize",
1287
+ "radicalised": "radicalized",
1288
+ "radicalises": "radicalizes",
1289
+ "radicalising": "radicalizing",
1290
+ "rancour": "rancor",
1291
+ "randomise": "randomize",
1292
+ "randomised": "randomized",
1293
+ "randomises": "randomizes",
1294
+ "randomising": "randomizing",
1295
+ "rationalisation": "rationalization",
1296
+ "rationalisations": "rationalizations",
1297
+ "rationalise": "rationalize",
1298
+ "rationalised": "rationalized",
1299
+ "rationalises": "rationalizes",
1300
+ "rationalising": "rationalizing",
1301
+ "ravelled": "raveled",
1302
+ "ravelling": "raveling",
1303
+ "realisable": "realizable",
1304
+ "realisation": "realization",
1305
+ "realisations": "realizations",
1306
+ "realise": "realize",
1307
+ "realised": "realized",
1308
+ "realises": "realizes",
1309
+ "realising": "realizing",
1310
+ "recognisable": "recognizable",
1311
+ "recognisably": "recognizably",
1312
+ "recognisance": "recognizance",
1313
+ "recognise": "recognize",
1314
+ "recognised": "recognized",
1315
+ "recognises": "recognizes",
1316
+ "recognising": "recognizing",
1317
+ "reconnoitre": "reconnoiter",
1318
+ "reconnoitred": "reconnoitered",
1319
+ "reconnoitres": "reconnoiters",
1320
+ "reconnoitring": "reconnoitering",
1321
+ "refuelled": "refueled",
1322
+ "refuelling": "refueling",
1323
+ "regularisation": "regularization",
1324
+ "regularise": "regularize",
1325
+ "regularised": "regularized",
1326
+ "regularises": "regularizes",
1327
+ "regularising": "regularizing",
1328
+ "remodelled": "remodeled",
1329
+ "remodelling": "remodeling",
1330
+ "remould": "remold",
1331
+ "remoulded": "remolded",
1332
+ "remoulding": "remolding",
1333
+ "remoulds": "remolds",
1334
+ "reorganisation": "reorganization",
1335
+ "reorganisations": "reorganizations",
1336
+ "reorganise": "reorganize",
1337
+ "reorganised": "reorganized",
1338
+ "reorganises": "reorganizes",
1339
+ "reorganising": "reorganizing",
1340
+ "revelled": "reveled",
1341
+ "reveller": "reveler",
1342
+ "revellers": "revelers",
1343
+ "revelling": "reveling",
1344
+ "revitalise": "revitalize",
1345
+ "revitalised": "revitalized",
1346
+ "revitalises": "revitalizes",
1347
+ "revitalising": "revitalizing",
1348
+ "revolutionise": "revolutionize",
1349
+ "revolutionised": "revolutionized",
1350
+ "revolutionises": "revolutionizes",
1351
+ "revolutionising": "revolutionizing",
1352
+ "rhapsodise": "rhapsodize",
1353
+ "rhapsodised": "rhapsodized",
1354
+ "rhapsodises": "rhapsodizes",
1355
+ "rhapsodising": "rhapsodizing",
1356
+ "rigour": "rigor",
1357
+ "rigours": "rigors",
1358
+ "ritualised": "ritualized",
1359
+ "rivalled": "rivaled",
1360
+ "rivalling": "rivaling",
1361
+ "romanticise": "romanticize",
1362
+ "romanticised": "romanticized",
1363
+ "romanticises": "romanticizes",
1364
+ "romanticising": "romanticizing",
1365
+ "rumour": "rumor",
1366
+ "rumoured": "rumored",
1367
+ "rumours": "rumors",
1368
+ "sabre": "saber",
1369
+ "sabres": "sabers",
1370
+ "saltpetre": "saltpeter",
1371
+ "sanitise": "sanitize",
1372
+ "sanitised": "sanitized",
1373
+ "sanitises": "sanitizes",
1374
+ "sanitising": "sanitizing",
1375
+ "satirise": "satirize",
1376
+ "satirised": "satirized",
1377
+ "satirises": "satirizes",
1378
+ "satirising": "satirizing",
1379
+ "saviour": "savior",
1380
+ "saviours": "saviors",
1381
+ "savour": "savor",
1382
+ "savoured": "savored",
1383
+ "savouries": "savories",
1384
+ "savouring": "savoring",
1385
+ "savours": "savors",
1386
+ "savoury": "savory",
1387
+ "scandalise": "scandalize",
1388
+ "scandalised": "scandalized",
1389
+ "scandalises": "scandalizes",
1390
+ "scandalising": "scandalizing",
1391
+ "sceptic": "skeptic",
1392
+ "sceptical": "skeptical",
1393
+ "sceptically": "skeptically",
1394
+ "scepticism": "skepticism",
1395
+ "sceptics": "skeptics",
1396
+ "sceptre": "scepter",
1397
+ "sceptres": "scepters",
1398
+ "scrutinise": "scrutinize",
1399
+ "scrutinised": "scrutinized",
1400
+ "scrutinises": "scrutinizes",
1401
+ "scrutinising": "scrutinizing",
1402
+ "secularisation": "secularization",
1403
+ "secularise": "secularize",
1404
+ "secularised": "secularized",
1405
+ "secularises": "secularizes",
1406
+ "secularising": "secularizing",
1407
+ "sensationalise": "sensationalize",
1408
+ "sensationalised": "sensationalized",
1409
+ "sensationalises": "sensationalizes",
1410
+ "sensationalising": "sensationalizing",
1411
+ "sensitise": "sensitize",
1412
+ "sensitised": "sensitized",
1413
+ "sensitises": "sensitizes",
1414
+ "sensitising": "sensitizing",
1415
+ "sentimentalise": "sentimentalize",
1416
+ "sentimentalised": "sentimentalized",
1417
+ "sentimentalises": "sentimentalizes",
1418
+ "sentimentalising": "sentimentalizing",
1419
+ "sepulchre": "sepulcher",
1420
+ "sepulchres": "sepulchers",
1421
+ "serialisation": "serialization",
1422
+ "serialisations": "serializations",
1423
+ "serialise": "serialize",
1424
+ "serialised": "serialized",
1425
+ "serialises": "serializes",
1426
+ "serialising": "serializing",
1427
+ "sermonise": "sermonize",
1428
+ "sermonised": "sermonized",
1429
+ "sermonises": "sermonizes",
1430
+ "sermonising": "sermonizing",
1431
+ "sheikh": "sheik",
1432
+ "shovelled": "shoveled",
1433
+ "shovelling": "shoveling",
1434
+ "shrivelled": "shriveled",
1435
+ "shrivelling": "shriveling",
1436
+ "signalise": "signalize",
1437
+ "signalised": "signalized",
1438
+ "signalises": "signalizes",
1439
+ "signalising": "signalizing",
1440
+ "signalled": "signaled",
1441
+ "signalling": "signaling",
1442
+ "smoulder": "smolder",
1443
+ "smouldered": "smoldered",
1444
+ "smouldering": "smoldering",
1445
+ "smoulders": "smolders",
1446
+ "snivelled": "sniveled",
1447
+ "snivelling": "sniveling",
1448
+ "snorkelled": "snorkeled",
1449
+ "snorkelling": "snorkeling",
1450
+ "snowplough": "snowplow",
1451
+ "snowploughs": "snowplow",
1452
+ "socialisation": "socialization",
1453
+ "socialise": "socialize",
1454
+ "socialised": "socialized",
1455
+ "socialises": "socializes",
1456
+ "socialising": "socializing",
1457
+ "sodomise": "sodomize",
1458
+ "sodomised": "sodomized",
1459
+ "sodomises": "sodomizes",
1460
+ "sodomising": "sodomizing",
1461
+ "solemnise": "solemnize",
1462
+ "solemnised": "solemnized",
1463
+ "solemnises": "solemnizes",
1464
+ "solemnising": "solemnizing",
1465
+ "sombre": "somber",
1466
+ "specialisation": "specialization",
1467
+ "specialisations": "specializations",
1468
+ "specialise": "specialize",
1469
+ "specialised": "specialized",
1470
+ "specialises": "specializes",
1471
+ "specialising": "specializing",
1472
+ "spectre": "specter",
1473
+ "spectres": "specters",
1474
+ "spiralled": "spiraled",
1475
+ "spiralling": "spiraling",
1476
+ "splendour": "splendor",
1477
+ "splendours": "splendors",
1478
+ "squirrelled": "squirreled",
1479
+ "squirrelling": "squirreling",
1480
+ "stabilisation": "stabilization",
1481
+ "stabilise": "stabilize",
1482
+ "stabilised": "stabilized",
1483
+ "stabiliser": "stabilizer",
1484
+ "stabilisers": "stabilizers",
1485
+ "stabilises": "stabilizes",
1486
+ "stabilising": "stabilizing",
1487
+ "standardisation": "standardization",
1488
+ "standardise": "standardize",
1489
+ "standardised": "standardized",
1490
+ "standardises": "standardizes",
1491
+ "standardising": "standardizing",
1492
+ "stencilled": "stenciled",
1493
+ "stencilling": "stenciling",
1494
+ "sterilisation": "sterilization",
1495
+ "sterilisations": "sterilizations",
1496
+ "sterilise": "sterilize",
1497
+ "sterilised": "sterilized",
1498
+ "steriliser": "sterilizer",
1499
+ "sterilisers": "sterilizers",
1500
+ "sterilises": "sterilizes",
1501
+ "sterilising": "sterilizing",
1502
+ "stigmatisation": "stigmatization",
1503
+ "stigmatise": "stigmatize",
1504
+ "stigmatised": "stigmatized",
1505
+ "stigmatises": "stigmatizes",
1506
+ "stigmatising": "stigmatizing",
1507
+ "storey": "story",
1508
+ "storeys": "stories",
1509
+ "subsidisation": "subsidization",
1510
+ "subsidise": "subsidize",
1511
+ "subsidised": "subsidized",
1512
+ "subsidiser": "subsidizer",
1513
+ "subsidisers": "subsidizers",
1514
+ "subsidises": "subsidizes",
1515
+ "subsidising": "subsidizing",
1516
+ "succour": "succor",
1517
+ "succoured": "succored",
1518
+ "succouring": "succoring",
1519
+ "succours": "succors",
1520
+ "sulphate": "sulfate",
1521
+ "sulphates": "sulfates",
1522
+ "sulphide": "sulfide",
1523
+ "sulphides": "sulfides",
1524
+ "sulphur": "sulfur",
1525
+ "sulphurous": "sulfurous",
1526
+ "summarise": "summarize",
1527
+ "summarised": "summarized",
1528
+ "summarises": "summarizes",
1529
+ "summarising": "summarizing",
1530
+ "swivelled": "swiveled",
1531
+ "swivelling": "swiveling",
1532
+ "symbolise": "symbolize",
1533
+ "symbolised": "symbolized",
1534
+ "symbolises": "symbolizes",
1535
+ "symbolising": "symbolizing",
1536
+ "sympathise": "sympathize",
1537
+ "sympathised": "sympathized",
1538
+ "sympathiser": "sympathizer",
1539
+ "sympathisers": "sympathizers",
1540
+ "sympathises": "sympathizes",
1541
+ "sympathising": "sympathizing",
1542
+ "synchronisation": "synchronization",
1543
+ "synchronise": "synchronize",
1544
+ "synchronised": "synchronized",
1545
+ "synchronises": "synchronizes",
1546
+ "synchronising": "synchronizing",
1547
+ "synthesise": "synthesize",
1548
+ "synthesised": "synthesized",
1549
+ "synthesiser": "synthesizer",
1550
+ "synthesisers": "synthesizers",
1551
+ "synthesises": "synthesizes",
1552
+ "synthesising": "synthesizing",
1553
+ "syphon": "siphon",
1554
+ "syphoned": "siphoned",
1555
+ "syphoning": "siphoning",
1556
+ "syphons": "siphons",
1557
+ "systematisation": "systematization",
1558
+ "systematise": "systematize",
1559
+ "systematised": "systematized",
1560
+ "systematises": "systematizes",
1561
+ "systematising": "systematizing",
1562
+ "tantalise": "tantalize",
1563
+ "tantalised": "tantalized",
1564
+ "tantalises": "tantalizes",
1565
+ "tantalising": "tantalizing",
1566
+ "tantalisingly": "tantalizingly",
1567
+ "tasselled": "tasseled",
1568
+ "technicolour": "technicolor",
1569
+ "temporise": "temporize",
1570
+ "temporised": "temporized",
1571
+ "temporises": "temporizes",
1572
+ "temporising": "temporizing",
1573
+ "tenderise": "tenderize",
1574
+ "tenderised": "tenderized",
1575
+ "tenderises": "tenderizes",
1576
+ "tenderising": "tenderizing",
1577
+ "terrorise": "terrorize",
1578
+ "terrorised": "terrorized",
1579
+ "terrorises": "terrorizes",
1580
+ "terrorising": "terrorizing",
1581
+ "theatre": "theater",
1582
+ "theatregoer": "theatergoer",
1583
+ "theatregoers": "theatergoers",
1584
+ "theatres": "theaters",
1585
+ "theorise": "theorize",
1586
+ "theorised": "theorized",
1587
+ "theorises": "theorizes",
1588
+ "theorising": "theorizing",
1589
+ "tonne": "ton",
1590
+ "tonnes": "tons",
1591
+ "towelled": "toweled",
1592
+ "towelling": "toweling",
1593
+ "toxaemia": "toxemia",
1594
+ "tranquillise": "tranquilize",
1595
+ "tranquillised": "tranquilized",
1596
+ "tranquilliser": "tranquilizer",
1597
+ "tranquillisers": "tranquilizers",
1598
+ "tranquillises": "tranquilizes",
1599
+ "tranquillising": "tranquilizing",
1600
+ "tranquillity": "tranquility",
1601
+ "tranquillize": "tranquilize",
1602
+ "tranquillized": "tranquilized",
1603
+ "tranquillizer": "tranquilizer",
1604
+ "tranquillizers": "tranquilizers",
1605
+ "tranquillizes": "tranquilizes",
1606
+ "tranquillizing": "tranquilizing",
1607
+ "tranquilly": "tranquility",
1608
+ "transistorised": "transistorized",
1609
+ "traumatise": "traumatize",
1610
+ "traumatised": "traumatized",
1611
+ "traumatises": "traumatizes",
1612
+ "traumatising": "traumatizing",
1613
+ "travelled": "traveled",
1614
+ "traveller": "traveler",
1615
+ "travellers": "travelers",
1616
+ "travelling": "traveling",
1617
+ "travelog": "travelogue",
1618
+ "travelogs": "travelogues",
1619
+ "trialled": "trialed",
1620
+ "trialling": "trialing",
1621
+ "tricolour": "tricolor",
1622
+ "tricolours": "tricolors",
1623
+ "trivialise": "trivialize",
1624
+ "trivialised": "trivialized",
1625
+ "trivialises": "trivializes",
1626
+ "trivialising": "trivializing",
1627
+ "tumour": "tumor",
1628
+ "tumours": "tumors",
1629
+ "tunnelled": "tunneled",
1630
+ "tunnelling": "tunneling",
1631
+ "tyrannise": "tyrannize",
1632
+ "tyrannised": "tyrannized",
1633
+ "tyrannises": "tyrannizes",
1634
+ "tyrannising": "tyrannizing",
1635
+ "tyre": "tire",
1636
+ "tyres": "tires",
1637
+ "unauthorised": "unauthorized",
1638
+ "uncivilised": "uncivilized",
1639
+ "underutilised": "underutilized",
1640
+ "unequalled": "unequaled",
1641
+ "unfavourable": "unfavorable",
1642
+ "unfavourably": "unfavorably",
1643
+ "unionisation": "unionization",
1644
+ "unionise": "unionize",
1645
+ "unionised": "unionized",
1646
+ "unionises": "unionizes",
1647
+ "unionising": "unionizing",
1648
+ "unorganised": "unorganized",
1649
+ "unravelled": "unraveled",
1650
+ "unravelling": "unraveling",
1651
+ "unrecognisable": "unrecognizable",
1652
+ "unrecognised": "unrecognized",
1653
+ "unrivalled": "unrivaled",
1654
+ "unsavoury": "unsavory",
1655
+ "untrammelled": "untrammeled",
1656
+ "urbanisation": "urbanization",
1657
+ "urbanise": "urbanize",
1658
+ "urbanised": "urbanized",
1659
+ "urbanises": "urbanizes",
1660
+ "urbanising": "urbanizing",
1661
+ "utilisable": "utilizable",
1662
+ "utilisation": "utilization",
1663
+ "utilise": "utilize",
1664
+ "utilised": "utilized",
1665
+ "utilises": "utilizes",
1666
+ "utilising": "utilizing",
1667
+ "valour": "valor",
1668
+ "vandalise": "vandalize",
1669
+ "vandalised": "vandalized",
1670
+ "vandalises": "vandalizes",
1671
+ "vandalising": "vandalizing",
1672
+ "vaporisation": "vaporization",
1673
+ "vaporise": "vaporize",
1674
+ "vaporised": "vaporized",
1675
+ "vaporises": "vaporizes",
1676
+ "vaporising": "vaporizing",
1677
+ "vapour": "vapor",
1678
+ "vapours": "vapors",
1679
+ "verbalise": "verbalize",
1680
+ "verbalised": "verbalized",
1681
+ "verbalises": "verbalizes",
1682
+ "verbalising": "verbalizing",
1683
+ "victimisation": "victimization",
1684
+ "victimise": "victimize",
1685
+ "victimised": "victimized",
1686
+ "victimises": "victimizes",
1687
+ "victimising": "victimizing",
1688
+ "videodisc": "videodisk",
1689
+ "videodiscs": "videodisks",
1690
+ "vigour": "vigor",
1691
+ "visualisation": "visualization",
1692
+ "visualisations": "visualizations",
1693
+ "visualise": "visualize",
1694
+ "visualised": "visualized",
1695
+ "visualises": "visualizes",
1696
+ "visualising": "visualizing",
1697
+ "vocalisation": "vocalization",
1698
+ "vocalisations": "vocalizations",
1699
+ "vocalise": "vocalize",
1700
+ "vocalised": "vocalized",
1701
+ "vocalises": "vocalizes",
1702
+ "vocalising": "vocalizing",
1703
+ "vulcanised": "vulcanized",
1704
+ "vulgarisation": "vulgarization",
1705
+ "vulgarise": "vulgarize",
1706
+ "vulgarised": "vulgarized",
1707
+ "vulgarises": "vulgarizes",
1708
+ "vulgarising": "vulgarizing",
1709
+ "waggon": "wagon",
1710
+ "waggons": "wagons",
1711
+ "watercolour": "watercolor",
1712
+ "watercolours": "watercolors",
1713
+ "weaselled": "weaseled",
1714
+ "weaselling": "weaseling",
1715
+ "westernisation": "westernization",
1716
+ "westernise": "westernize",
1717
+ "westernised": "westernized",
1718
+ "westernises": "westernizes",
1719
+ "westernising": "westernizing",
1720
+ "womanise": "womanize",
1721
+ "womanised": "womanized",
1722
+ "womaniser": "womanizer",
1723
+ "womanisers": "womanizers",
1724
+ "womanises": "womanizes",
1725
+ "womanising": "womanizing",
1726
+ "woollen": "woolen",
1727
+ "woollens": "woolens",
1728
+ "woollies": "woolies",
1729
+ "woolly": "wooly",
1730
+ "worshipped": "worshiped",
1731
+ "worshipping": "worshiping",
1732
+ "worshipper": "worshiper",
1733
+ "yodelled": "yodeled",
1734
+ "yodelling": "yodeling",
1735
+ "yoghourt": "yogurt",
1736
+ "yoghourts": "yogurts",
1737
+ "yoghurt": "yogurt",
1738
+ "yoghurts": "yogurts",
1739
+ "mhm": "hmm",
1740
+ "mm": "hmm",
1741
+ "mmm": "hmm"
1742
+ }
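The mapping above ends here; it is consumed one whitespace-separated word at a time by the EnglishSpellingNormalizer added in english.py below. A minimal, hedged sketch of that lookup (assuming english.json is readable from the current directory; only words with an entry are rewritten):

import json

# Load the British -> American spelling table shown above.
with open("english.json", encoding="utf-8") as f:
    spelling_map = json.load(f)

def americanize(text: str) -> str:
    # Replace each whitespace-separated word that has a mapping entry;
    # anything else passes through unchanged.
    return " ".join(spelling_map.get(word, word) for word in text.split())

print(americanize("we organised the theatre programme"))
# -> "we organized the theater program"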
latentsync/whisper/whisper/normalizers/english.py ADDED
@@ -0,0 +1,543 @@
1
+ import json
2
+ import os
3
+ import re
4
+ from fractions import Fraction
5
+ from typing import Iterator, List, Match, Optional, Union
6
+
7
+ from more_itertools import windowed
8
+
9
+ from .basic import remove_symbols_and_diacritics
10
+
11
+
12
+ class EnglishNumberNormalizer:
13
+ """
14
+ Convert any spelled-out numbers into arabic numbers, while handling:
15
+
16
+ - remove any commas
17
+ - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc.
18
+ - spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars`
19
+ - spell out `one` and `ones`
20
+ - interpret successive single-digit numbers as nominal: `one oh one` -> `101`
21
+ """
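# Illustrative input/output pairs for the rules above (informal, derived from the logic below):
#     "twenty one"            -> "21"
#     "one oh one"            -> "101"
#     "two and a half"        -> "2.5"
#     "one hundred and first" -> "101st"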
22
+
23
+ def __init__(self):
24
+ super().__init__()
25
+
26
+ self.zeros = {"o", "oh", "zero"}
27
+ self.ones = {
28
+ name: i
29
+ for i, name in enumerate(
30
+ [
31
+ "one",
32
+ "two",
33
+ "three",
34
+ "four",
35
+ "five",
36
+ "six",
37
+ "seven",
38
+ "eight",
39
+ "nine",
40
+ "ten",
41
+ "eleven",
42
+ "twelve",
43
+ "thirteen",
44
+ "fourteen",
45
+ "fifteen",
46
+ "sixteen",
47
+ "seventeen",
48
+ "eighteen",
49
+ "nineteen",
50
+ ],
51
+ start=1,
52
+ )
53
+ }
54
+ self.ones_plural = {
55
+ "sixes" if name == "six" else name + "s": (value, "s")
56
+ for name, value in self.ones.items()
57
+ }
58
+ self.ones_ordinal = {
59
+ "zeroth": (0, "th"),
60
+ "first": (1, "st"),
61
+ "second": (2, "nd"),
62
+ "third": (3, "rd"),
63
+ "fifth": (5, "th"),
64
+ "twelfth": (12, "th"),
65
+ **{
66
+ name + ("h" if name.endswith("t") else "th"): (value, "th")
67
+ for name, value in self.ones.items()
68
+ if value > 3 and value != 5 and value != 12
69
+ },
70
+ }
71
+ self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}
72
+
73
+ self.tens = {
74
+ "twenty": 20,
75
+ "thirty": 30,
76
+ "forty": 40,
77
+ "fifty": 50,
78
+ "sixty": 60,
79
+ "seventy": 70,
80
+ "eighty": 80,
81
+ "ninety": 90,
82
+ }
83
+ self.tens_plural = {
84
+ name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
85
+ }
86
+ self.tens_ordinal = {
87
+ name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items()
88
+ }
89
+ self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
90
+
91
+ self.multipliers = {
92
+ "hundred": 100,
93
+ "thousand": 1_000,
94
+ "million": 1_000_000,
95
+ "billion": 1_000_000_000,
96
+ "trillion": 1_000_000_000_000,
97
+ "quadrillion": 1_000_000_000_000_000,
98
+ "quintillion": 1_000_000_000_000_000_000,
99
+ "sextillion": 1_000_000_000_000_000_000_000,
100
+ "septillion": 1_000_000_000_000_000_000_000_000,
101
+ "octillion": 1_000_000_000_000_000_000_000_000_000,
102
+ "nonillion": 1_000_000_000_000_000_000_000_000_000_000,
103
+ "decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
104
+ }
105
+ self.multipliers_plural = {
106
+ name + "s": (value, "s") for name, value in self.multipliers.items()
107
+ }
108
+ self.multipliers_ordinal = {
109
+ name + "th": (value, "th") for name, value in self.multipliers.items()
110
+ }
111
+ self.multipliers_suffixed = {**self.multipliers_plural, **self.multipliers_ordinal}
112
+ self.decimals = {*self.ones, *self.tens, *self.zeros}
113
+
114
+ self.preceding_prefixers = {
115
+ "minus": "-",
116
+ "negative": "-",
117
+ "plus": "+",
118
+ "positive": "+",
119
+ }
120
+ self.following_prefixers = {
121
+ "pound": "£",
122
+ "pounds": "£",
123
+ "euro": "€",
124
+ "euros": "€",
125
+ "dollar": "$",
126
+ "dollars": "$",
127
+ "cent": "¢",
128
+ "cents": "¢",
129
+ }
130
+ self.prefixes = set(
131
+ list(self.preceding_prefixers.values()) + list(self.following_prefixers.values())
132
+ )
133
+ self.suffixers = {
134
+ "per": {"cent": "%"},
135
+ "percent": "%",
136
+ }
137
+ self.specials = {"and", "double", "triple", "point"}
138
+
139
+ self.words = set(
140
+ [
141
+ key
142
+ for mapping in [
143
+ self.zeros,
144
+ self.ones,
145
+ self.ones_suffixed,
146
+ self.tens,
147
+ self.tens_suffixed,
148
+ self.multipliers,
149
+ self.multipliers_suffixed,
150
+ self.preceding_prefixers,
151
+ self.following_prefixers,
152
+ self.suffixers,
153
+ self.specials,
154
+ ]
155
+ for key in mapping
156
+ ]
157
+ )
158
+ self.literal_words = {"one", "ones"}
159
+
160
+ def process_words(self, words: List[str]) -> Iterator[str]:
161
+ prefix: Optional[str] = None
162
+ value: Optional[Union[str, int]] = None
163
+ skip = False
164
+
165
+ def to_fraction(s: str):
166
+ try:
167
+ return Fraction(s)
168
+ except ValueError:
169
+ return None
170
+
171
+ def output(result: Union[str, int]):
172
+ nonlocal prefix, value
173
+ result = str(result)
174
+ if prefix is not None:
175
+ result = prefix + result
176
+ value = None
177
+ prefix = None
178
+ return result
179
+
180
+ if len(words) == 0:
181
+ return
182
+
183
+ for prev, current, next in windowed([None] + words + [None], 3):
184
+ if skip:
185
+ skip = False
186
+ continue
187
+
188
+ next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
189
+ has_prefix = current[0] in self.prefixes
190
+ current_without_prefix = current[1:] if has_prefix else current
191
+ if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
192
+ # arabic numbers (potentially with signs and fractions)
193
+ f = to_fraction(current_without_prefix)
194
+ assert f is not None
195
+ if value is not None:
196
+ if isinstance(value, str) and value.endswith("."):
197
+ # concatenate decimals / ip address components
198
+ value = str(value) + str(current)
199
+ continue
200
+ else:
201
+ yield output(value)
202
+
203
+ prefix = current[0] if has_prefix else prefix
204
+ if f.denominator == 1:
205
+ value = f.numerator # store integers as int
206
+ else:
207
+ value = current_without_prefix
208
+ elif current not in self.words:
209
+ # non-numeric words
210
+ if value is not None:
211
+ yield output(value)
212
+ yield output(current)
213
+ elif current in self.zeros:
214
+ value = str(value or "") + "0"
215
+ elif current in self.ones:
216
+ ones = self.ones[current]
217
+
218
+ if value is None:
219
+ value = ones
220
+ elif isinstance(value, str) or prev in self.ones:
221
+ if prev in self.tens and ones < 10: # replace the last zero with the digit
222
+ assert value[-1] == "0"
223
+ value = value[:-1] + str(ones)
224
+ else:
225
+ value = str(value) + str(ones)
226
+ elif ones < 10:
227
+ if value % 10 == 0:
228
+ value += ones
229
+ else:
230
+ value = str(value) + str(ones)
231
+ else: # eleven to nineteen
232
+ if value % 100 == 0:
233
+ value += ones
234
+ else:
235
+ value = str(value) + str(ones)
236
+ elif current in self.ones_suffixed:
237
+ # ordinal or cardinal; yield the number right away
238
+ ones, suffix = self.ones_suffixed[current]
239
+ if value is None:
240
+ yield output(str(ones) + suffix)
241
+ elif isinstance(value, str) or prev in self.ones:
242
+ if prev in self.tens and ones < 10:
243
+ assert value[-1] == "0"
244
+ yield output(value[:-1] + str(ones) + suffix)
245
+ else:
246
+ yield output(str(value) + str(ones) + suffix)
247
+ elif ones < 10:
248
+ if value % 10 == 0:
249
+ yield output(str(value + ones) + suffix)
250
+ else:
251
+ yield output(str(value) + str(ones) + suffix)
252
+ else: # eleven to nineteen
253
+ if value % 100 == 0:
254
+ yield output(str(value + ones) + suffix)
255
+ else:
256
+ yield output(str(value) + str(ones) + suffix)
257
+ value = None
258
+ elif current in self.tens:
259
+ tens = self.tens[current]
260
+ if value is None:
261
+ value = tens
262
+ elif isinstance(value, str):
263
+ value = str(value) + str(tens)
264
+ else:
265
+ if value % 100 == 0:
266
+ value += tens
267
+ else:
268
+ value = str(value) + str(tens)
269
+ elif current in self.tens_suffixed:
270
+ # ordinal or cardinal; yield the number right away
271
+ tens, suffix = self.tens_suffixed[current]
272
+ if value is None:
273
+ yield output(str(tens) + suffix)
274
+ elif isinstance(value, str):
275
+ yield output(str(value) + str(tens) + suffix)
276
+ else:
277
+ if value % 100 == 0:
278
+ yield output(str(value + tens) + suffix)
279
+ else:
280
+ yield output(str(value) + str(tens) + suffix)
281
+ elif current in self.multipliers:
282
+ multiplier = self.multipliers[current]
283
+ if value is None:
284
+ value = multiplier
285
+ elif isinstance(value, str) or value == 0:
286
+ f = to_fraction(value)
287
+ p = f * multiplier if f is not None else None
288
+ if f is not None and p.denominator == 1:
289
+ value = p.numerator
290
+ else:
291
+ yield output(value)
292
+ value = multiplier
293
+ else:
294
+ before = value // 1000 * 1000
295
+ residual = value % 1000
296
+ value = before + residual * multiplier
297
+ elif current in self.multipliers_suffixed:
298
+ multiplier, suffix = self.multipliers_suffixed[current]
299
+ if value is None:
300
+ yield output(str(multiplier) + suffix)
301
+ elif isinstance(value, str):
302
+ f = to_fraction(value)
303
+ p = f * multiplier if f is not None else None
304
+ if f is not None and p.denominator == 1:
305
+ yield output(str(p.numerator) + suffix)
306
+ else:
307
+ yield output(value)
308
+ yield output(str(multiplier) + suffix)
309
+ else: # int
310
+ before = value // 1000 * 1000
311
+ residual = value % 1000
312
+ value = before + residual * multiplier
313
+ yield output(str(value) + suffix)
314
+ value = None
315
+ elif current in self.preceding_prefixers:
316
+ # apply prefix (positive, minus, etc.) if it precedes a number
317
+ if value is not None:
318
+ yield output(value)
319
+
320
+ if next in self.words or next_is_numeric:
321
+ prefix = self.preceding_prefixers[current]
322
+ else:
323
+ yield output(current)
324
+ elif current in self.following_prefixers:
325
+ # apply prefix (dollars, cents, etc.) only after a number
326
+ if value is not None:
327
+ prefix = self.following_prefixers[current]
328
+ yield output(value)
329
+ else:
330
+ yield output(current)
331
+ elif current in self.suffixers:
332
+ # apply suffix symbols (percent -> '%')
333
+ if value is not None:
334
+ suffix = self.suffixers[current]
335
+ if isinstance(suffix, dict):
336
+ if next in suffix:
337
+ yield output(str(value) + suffix[next])
338
+ skip = True
339
+ else:
340
+ yield output(value)
341
+ yield output(current)
342
+ else:
343
+ yield output(str(value) + suffix)
344
+ else:
345
+ yield output(current)
346
+ elif current in self.specials:
347
+ if next not in self.words and not next_is_numeric:
348
+ # apply special handling only if the next word can be numeric
349
+ if value is not None:
350
+ yield output(value)
351
+ yield output(current)
352
+ elif current == "and":
353
+ # ignore "and" after hundreds, thousands, etc.
354
+ if prev not in self.multipliers:
355
+ if value is not None:
356
+ yield output(value)
357
+ yield output(current)
358
+ elif current == "double" or current == "triple":
359
+ if next in self.ones or next in self.zeros:
360
+ repeats = 2 if current == "double" else 3
361
+ ones = self.ones.get(next, 0)
362
+ value = str(value or "") + str(ones) * repeats
363
+ skip = True
364
+ else:
365
+ if value is not None:
366
+ yield output(value)
367
+ yield output(current)
368
+ elif current == "point":
369
+ if next in self.decimals or next_is_numeric:
370
+ value = str(value or "") + "."
371
+ else:
372
+ # should all have been covered at this point
373
+ raise ValueError(f"Unexpected token: {current}")
374
+ else:
375
+ # all should have been covered at this point
376
+ raise ValueError(f"Unexpected token: {current}")
377
+
378
+ if value is not None:
379
+ yield output(value)
380
+
381
+ def preprocess(self, s: str):
382
+ # replace "<number> and a half" with "<number> point five"
383
+ results = []
384
+
385
+ segments = re.split(r"\band\s+a\s+half\b", s)
386
+ for i, segment in enumerate(segments):
387
+ if len(segment.strip()) == 0:
388
+ continue
389
+ if i == len(segments) - 1:
390
+ results.append(segment)
391
+ else:
392
+ results.append(segment)
393
+ last_word = segment.rsplit(maxsplit=2)[-1]
394
+ if last_word in self.decimals or last_word in self.multipliers:
395
+ results.append("point five")
396
+ else:
397
+ results.append("and a half")
398
+
399
+ s = " ".join(results)
400
+
401
+ # put a space at number/letter boundary
402
+ s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
403
+ s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)
404
+
405
+ # but remove spaces which could be a suffix
406
+ s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)
407
+
408
+ return s
409
+
410
+ def postprocess(self, s: str):
411
+ def combine_cents(m: Match):
412
+ try:
413
+ currency = m.group(1)
414
+ integer = m.group(2)
415
+ cents = int(m.group(3))
416
+ return f"{currency}{integer}.{cents:02d}"
417
+ except ValueError:
418
+ return m.string
419
+
420
+ def extract_cents(m: Match):
421
+ try:
422
+ return f"¢{int(m.group(1))}"
423
+ except ValueError:
424
+ return m.string
425
+
426
+ # apply currency postprocessing; "$2 and ¢7" -> "$2.07"
427
+ s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
428
+ s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s)
429
+
430
+ # write "one(s)" instead of "1(s)", just for the readability
431
+ s = re.sub(r"\b1(s?)\b", r"one\1", s)
432
+
433
+ return s
434
+
435
+ def __call__(self, s: str):
436
+ s = self.preprocess(s)
437
+ s = " ".join(word for word in self.process_words(s.split()) if word is not None)
438
+ s = self.postprocess(s)
439
+
440
+ return s
441
+
442
+
443
+ class EnglishSpellingNormalizer:
444
+ """
445
+ Applies British-American spelling mappings as listed in [1].
446
+
447
+ [1] https://www.tysto.com/uk-us-spelling-list.html
448
+ """
449
+
450
+ def __init__(self):
451
+ mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
452
+ self.mapping = json.load(open(mapping_path))
453
+
454
+ def __call__(self, s: str):
455
+ return " ".join(self.mapping.get(word, word) for word in s.split())
456
+
457
+
458
+ class EnglishTextNormalizer:
459
+ def __init__(self):
460
+ self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
461
+ self.replacers = {
462
+ # common contractions
463
+ r"\bwon't\b": "will not",
464
+ r"\bcan't\b": "can not",
465
+ r"\blet's\b": "let us",
466
+ r"\bain't\b": "aint",
467
+ r"\by'all\b": "you all",
468
+ r"\bwanna\b": "want to",
469
+ r"\bgotta\b": "got to",
470
+ r"\bgonna\b": "going to",
471
+ r"\bi'ma\b": "i am going to",
472
+ r"\bimma\b": "i am going to",
473
+ r"\bwoulda\b": "would have",
474
+ r"\bcoulda\b": "could have",
475
+ r"\bshoulda\b": "should have",
476
+ r"\bma'am\b": "madam",
477
+ # contractions in titles/prefixes
478
+ r"\bmr\b": "mister ",
479
+ r"\bmrs\b": "missus ",
480
+ r"\bst\b": "saint ",
481
+ r"\bdr\b": "doctor ",
482
+ r"\bprof\b": "professor ",
483
+ r"\bcapt\b": "captain ",
484
+ r"\bgov\b": "governor ",
485
+ r"\bald\b": "alderman ",
486
+ r"\bgen\b": "general ",
487
+ r"\bsen\b": "senator ",
488
+ r"\brep\b": "representative ",
489
+ r"\bpres\b": "president ",
490
+ r"\brev\b": "reverend ",
491
+ r"\bhon\b": "honorable ",
492
+ r"\basst\b": "assistant ",
493
+ r"\bassoc\b": "associate ",
494
+ r"\blt\b": "lieutenant ",
495
+ r"\bcol\b": "colonel ",
496
+ r"\bjr\b": "junior ",
497
+ r"\bsr\b": "senior ",
498
+ r"\besq\b": "esquire ",
499
+ # perfect tenses; ideally this should handle any past participle, but that's harder.
500
+ r"'d been\b": " had been",
501
+ r"'s been\b": " has been",
502
+ r"'d gone\b": " had gone",
503
+ r"'s gone\b": " has gone",
504
+ r"'d done\b": " had done", # "'s done" is ambiguous
505
+ r"'s got\b": " has got",
506
+ # general contractions
507
+ r"n't\b": " not",
508
+ r"'re\b": " are",
509
+ r"'s\b": " is",
510
+ r"'d\b": " would",
511
+ r"'ll\b": " will",
512
+ r"'t\b": " not",
513
+ r"'ve\b": " have",
514
+ r"'m\b": " am",
515
+ }
516
+ self.standardize_numbers = EnglishNumberNormalizer()
517
+ self.standardize_spellings = EnglishSpellingNormalizer()
518
+
519
+ def __call__(self, s: str):
520
+ s = s.lower()
521
+
522
+ s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parentheses
523
+ s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
524
+ s = re.sub(self.ignore_patterns, "", s)
525
+ s = re.sub(r"\s+'", "'", s) # standardize when there's a space before an apostrophe
526
+
527
+ for pattern, replacement in self.replacers.items():
528
+ s = re.sub(pattern, replacement, s)
529
+
530
+ s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits
531
+ s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers
532
+ s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep some symbols for numerics
533
+
534
+ s = self.standardize_numbers(s)
535
+ s = self.standardize_spellings(s)
536
+
537
+ # now remove prefix/suffix symbols that are not preceded/followed by numbers
538
+ s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
539
+ s = re.sub(r"([^0-9])%", r"\1 ", s)
540
+
541
+ s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
542
+
543
+ return s
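Taken together, the classes above compose into EnglishTextNormalizer. A hedged usage sketch (the import path assumes the repository root is on sys.path; the exact output depends on the mapping file, but should look roughly as shown):

from latentsync.whisper.whisper.normalizers.english import EnglishTextNormalizer

normalizer = EnglishTextNormalizer()

# Contractions are expanded, titles spelled out, numbers converted to digits,
# currency amounts combined, and British spellings mapped to American ones.
print(normalizer("Mr. Smith realised he owed twenty one dollars and fifty cents"))
# -> roughly: "mister smith realized he owed $21.50"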
latentsync/whisper/whisper/tokenizer.py ADDED
@@ -0,0 +1,331 @@
1
+ import os
2
+ from dataclasses import dataclass
3
+ from functools import lru_cache
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from transformers import GPT2TokenizerFast
9
+
10
+ LANGUAGES = {
11
+ "en": "english",
12
+ "zh": "chinese",
13
+ "de": "german",
14
+ "es": "spanish",
15
+ "ru": "russian",
16
+ "ko": "korean",
17
+ "fr": "french",
18
+ "ja": "japanese",
19
+ "pt": "portuguese",
20
+ "tr": "turkish",
21
+ "pl": "polish",
22
+ "ca": "catalan",
23
+ "nl": "dutch",
24
+ "ar": "arabic",
25
+ "sv": "swedish",
26
+ "it": "italian",
27
+ "id": "indonesian",
28
+ "hi": "hindi",
29
+ "fi": "finnish",
30
+ "vi": "vietnamese",
31
+ "iw": "hebrew",
32
+ "uk": "ukrainian",
33
+ "el": "greek",
34
+ "ms": "malay",
35
+ "cs": "czech",
36
+ "ro": "romanian",
37
+ "da": "danish",
38
+ "hu": "hungarian",
39
+ "ta": "tamil",
40
+ "no": "norwegian",
41
+ "th": "thai",
42
+ "ur": "urdu",
43
+ "hr": "croatian",
44
+ "bg": "bulgarian",
45
+ "lt": "lithuanian",
46
+ "la": "latin",
47
+ "mi": "maori",
48
+ "ml": "malayalam",
49
+ "cy": "welsh",
50
+ "sk": "slovak",
51
+ "te": "telugu",
52
+ "fa": "persian",
53
+ "lv": "latvian",
54
+ "bn": "bengali",
55
+ "sr": "serbian",
56
+ "az": "azerbaijani",
57
+ "sl": "slovenian",
58
+ "kn": "kannada",
59
+ "et": "estonian",
60
+ "mk": "macedonian",
61
+ "br": "breton",
62
+ "eu": "basque",
63
+ "is": "icelandic",
64
+ "hy": "armenian",
65
+ "ne": "nepali",
66
+ "mn": "mongolian",
67
+ "bs": "bosnian",
68
+ "kk": "kazakh",
69
+ "sq": "albanian",
70
+ "sw": "swahili",
71
+ "gl": "galician",
72
+ "mr": "marathi",
73
+ "pa": "punjabi",
74
+ "si": "sinhala",
75
+ "km": "khmer",
76
+ "sn": "shona",
77
+ "yo": "yoruba",
78
+ "so": "somali",
79
+ "af": "afrikaans",
80
+ "oc": "occitan",
81
+ "ka": "georgian",
82
+ "be": "belarusian",
83
+ "tg": "tajik",
84
+ "sd": "sindhi",
85
+ "gu": "gujarati",
86
+ "am": "amharic",
87
+ "yi": "yiddish",
88
+ "lo": "lao",
89
+ "uz": "uzbek",
90
+ "fo": "faroese",
91
+ "ht": "haitian creole",
92
+ "ps": "pashto",
93
+ "tk": "turkmen",
94
+ "nn": "nynorsk",
95
+ "mt": "maltese",
96
+ "sa": "sanskrit",
97
+ "lb": "luxembourgish",
98
+ "my": "myanmar",
99
+ "bo": "tibetan",
100
+ "tl": "tagalog",
101
+ "mg": "malagasy",
102
+ "as": "assamese",
103
+ "tt": "tatar",
104
+ "haw": "hawaiian",
105
+ "ln": "lingala",
106
+ "ha": "hausa",
107
+ "ba": "bashkir",
108
+ "jw": "javanese",
109
+ "su": "sundanese",
110
+ }
111
+
112
+ # language code lookup by name, with a few language aliases
113
+ TO_LANGUAGE_CODE = {
114
+ **{language: code for code, language in LANGUAGES.items()},
115
+ "burmese": "my",
116
+ "valencian": "ca",
117
+ "flemish": "nl",
118
+ "haitian": "ht",
119
+ "letzeburgesch": "lb",
120
+ "pushto": "ps",
121
+ "panjabi": "pa",
122
+ "moldavian": "ro",
123
+ "moldovan": "ro",
124
+ "sinhalese": "si",
125
+ "castilian": "es",
126
+ }
127
+
128
+
129
+ @dataclass(frozen=True)
130
+ class Tokenizer:
131
+ """A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""
132
+
133
+ tokenizer: "GPT2TokenizerFast"
134
+ language: Optional[str]
135
+ sot_sequence: Tuple[int]
136
+
137
+ def encode(self, text, **kwargs):
138
+ return self.tokenizer.encode(text, **kwargs)
139
+
140
+ def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
141
+ return self.tokenizer.decode(token_ids, **kwargs)
142
+
143
+ def decode_with_timestamps(self, tokens) -> str:
144
+ """
145
+ Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
146
+ This method decodes given tokens with timestamp tokens annotated, e.g. "<|1.08|>".
147
+ """
148
+ outputs = [[]]
149
+ for token in tokens:
150
+ if token >= self.timestamp_begin:
151
+ timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
152
+ outputs.append(timestamp)
153
+ outputs.append([])
154
+ else:
155
+ outputs[-1].append(token)
156
+ outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
157
+ return "".join(outputs)
158
+
159
+ @property
160
+ @lru_cache()
161
+ def eot(self) -> int:
162
+ return self.tokenizer.eos_token_id
163
+
164
+ @property
165
+ @lru_cache()
166
+ def sot(self) -> int:
167
+ return self._get_single_token_id("<|startoftranscript|>")
168
+
169
+ @property
170
+ @lru_cache()
171
+ def sot_lm(self) -> int:
172
+ return self._get_single_token_id("<|startoflm|>")
173
+
174
+ @property
175
+ @lru_cache()
176
+ def sot_prev(self) -> int:
177
+ return self._get_single_token_id("<|startofprev|>")
178
+
179
+ @property
180
+ @lru_cache()
181
+ def no_speech(self) -> int:
182
+ return self._get_single_token_id("<|nospeech|>")
183
+
184
+ @property
185
+ @lru_cache()
186
+ def no_timestamps(self) -> int:
187
+ return self._get_single_token_id("<|notimestamps|>")
188
+
189
+ @property
190
+ @lru_cache()
191
+ def timestamp_begin(self) -> int:
192
+ return self.tokenizer.all_special_ids[-1] + 1
193
+
194
+ @property
195
+ @lru_cache()
196
+ def language_token(self) -> int:
197
+ """Returns the token id corresponding to the value of the `language` field"""
198
+ if self.language is None:
199
+ raise ValueError("This tokenizer does not have a language token configured")
200
+
201
+ additional_tokens = dict(
202
+ zip(
203
+ self.tokenizer.additional_special_tokens,
204
+ self.tokenizer.additional_special_tokens_ids,
205
+ )
206
+ )
207
+ candidate = f"<|{self.language}|>"
208
+ if candidate in additional_tokens:
209
+ return additional_tokens[candidate]
210
+
211
+ raise KeyError(f"Language {self.language} not found in tokenizer.")
212
+
213
+ @property
214
+ @lru_cache()
215
+ def all_language_tokens(self) -> Tuple[int]:
216
+ result = []
217
+ for token, token_id in zip(
218
+ self.tokenizer.additional_special_tokens,
219
+ self.tokenizer.additional_special_tokens_ids,
220
+ ):
221
+ if token.strip("<|>") in LANGUAGES:
222
+ result.append(token_id)
223
+ return tuple(result)
224
+
225
+ @property
226
+ @lru_cache()
227
+ def all_language_codes(self) -> Tuple[str]:
228
+ return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
229
+
230
+ @property
231
+ @lru_cache()
232
+ def sot_sequence_including_notimestamps(self) -> Tuple[int]:
233
+ return tuple(list(self.sot_sequence) + [self.no_timestamps])
234
+
235
+ @property
236
+ @lru_cache()
237
+ def non_speech_tokens(self) -> Tuple[int]:
238
+ """
239
+ Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
240
+ annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
241
+
242
+ - ♪♪♪
243
+ - ( SPEAKING FOREIGN LANGUAGE )
244
+ - [DAVID] Hey there,
245
+
246
+ keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
247
+ """
248
+ symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
249
+ symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
250
+
251
+ # symbols that may be a single token or multiple tokens depending on the tokenizer.
252
+ # In case they're multiple tokens, suppress the first token, which is safe because:
253
+ # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
254
+ # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
255
+ miscellaneous = set("♩♪♫♬♭♮♯")
256
+ assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
257
+
258
+ # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
259
+ result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
260
+ for symbol in symbols + list(miscellaneous):
261
+ for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
262
+ if len(tokens) == 1 or symbol in miscellaneous:
263
+ result.add(tokens[0])
264
+
265
+ return tuple(sorted(result))
266
+
267
+ def _get_single_token_id(self, text) -> int:
268
+ tokens = self.tokenizer.encode(text)
269
+ assert len(tokens) == 1, f"{text} is not encoded as a single token"
270
+ return tokens[0]
271
+
272
+
273
+ @lru_cache(maxsize=None)
274
+ def build_tokenizer(name: str = "gpt2"):
275
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
276
+ path = os.path.join(os.path.dirname(__file__), "assets", name)
277
+ tokenizer = GPT2TokenizerFast.from_pretrained(path)
278
+
279
+ specials = [
280
+ "<|startoftranscript|>",
281
+ *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
282
+ "<|translate|>",
283
+ "<|transcribe|>",
284
+ "<|startoflm|>",
285
+ "<|startofprev|>",
286
+ "<|nospeech|>",
287
+ "<|notimestamps|>",
288
+ ]
289
+
290
+ tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
291
+ return tokenizer
292
+
293
+
294
+ @lru_cache(maxsize=None)
295
+ def get_tokenizer(
296
+ multilingual: bool,
297
+ *,
298
+ task: Optional[str] = None, # Literal["transcribe", "translate", None]
299
+ language: Optional[str] = None,
300
+ ) -> Tokenizer:
301
+ if language is not None:
302
+ language = language.lower()
303
+ if language not in LANGUAGES:
304
+ if language in TO_LANGUAGE_CODE:
305
+ language = TO_LANGUAGE_CODE[language]
306
+ else:
307
+ raise ValueError(f"Unsupported language: {language}")
308
+
309
+ if multilingual:
310
+ tokenizer_name = "multilingual"
311
+ task = task or "transcribe"
312
+ language = language or "en"
313
+ else:
314
+ tokenizer_name = "gpt2"
315
+ task = None
316
+ language = None
317
+
318
+ tokenizer = build_tokenizer(name=tokenizer_name)
319
+ all_special_ids: List[int] = tokenizer.all_special_ids
320
+ sot: int = all_special_ids[1]
321
+ translate: int = all_special_ids[-6]
322
+ transcribe: int = all_special_ids[-5]
323
+
324
+ langs = tuple(LANGUAGES.keys())
325
+ sot_sequence = [sot]
326
+ if language is not None:
327
+ sot_sequence.append(sot + 1 + langs.index(language))
328
+ if task is not None:
329
+ sot_sequence.append(transcribe if task == "transcribe" else translate)
330
+
331
+ return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))
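For orientation, a minimal usage sketch of the tokenizer module added above (illustrative only, not part of the diff; it assumes the bundled assets under whisper/assets/ are in place and the package is importable as latentsync.whisper.whisper):

from latentsync.whisper.whisper.tokenizer import get_tokenizer

# "english" is normalized to the code "en" via TO_LANGUAGE_CODE
tokenizer = get_tokenizer(multilingual=True, task="transcribe", language="english")
print(tokenizer.language)      # "en"
print(tokenizer.sot_sequence)  # token ids for <|startoftranscript|>, <|en|>, <|transcribe|>
print(tokenizer.decode(tokenizer.encode("hello world")))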
latentsync/whisper/whisper/transcribe.py ADDED
@@ -0,0 +1,207 @@
1
+ import argparse
2
+ import os
3
+ import warnings
4
+ from typing import List, Optional, Tuple, Union, TYPE_CHECKING
5
+
6
+ import numpy as np
7
+ import torch
8
+ import tqdm
9
+
10
+ from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
11
+ from .decoding import DecodingOptions, DecodingResult
12
+ from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
13
+ from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt
14
+
15
+ if TYPE_CHECKING:
16
+ from .model import Whisper
17
+
18
+
19
+ def transcribe(
20
+ model: "Whisper",
21
+ audio: Union[str, np.ndarray, torch.Tensor],
22
+ *,
23
+ verbose: Optional[bool] = None,
24
+ temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
25
+ compression_ratio_threshold: Optional[float] = 2.4,
26
+ logprob_threshold: Optional[float] = -1.0,
27
+ no_speech_threshold: Optional[float] = 0.6,
28
+ condition_on_previous_text: bool = True,
29
+ force_extraction: bool = False,
30
+ **decode_options,
31
+ ):
32
+ """
33
+ Transcribe an audio file using Whisper
34
+
35
+ Parameters
36
+ ----------
37
+ model: Whisper
38
+ The Whisper model instance
39
+
40
+ audio: Union[str, np.ndarray, torch.Tensor]
41
+ The path to the audio file to open, or the audio waveform
42
+
43
+ verbose: bool
44
+ Whether to display the text being decoded to the console. If True, displays all the details.
45
+ If False, displays minimal details. If None, does not display anything
46
+
47
+ temperature: Union[float, Tuple[float, ...]]
48
+ Temperature for sampling. It can be a tuple of temperatures, which will be successively used
49
+ upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
50
+
51
+ compression_ratio_threshold: float
52
+ If the gzip compression ratio is above this value, treat as failed
53
+
54
+ logprob_threshold: float
55
+ If the average log probability over sampled tokens is below this value, treat as failed
56
+
57
+ no_speech_threshold: float
58
+ If the no_speech probability is higher than this value AND the average log probability
59
+ over sampled tokens is below `logprob_threshold`, consider the segment as silent
60
+
61
+ condition_on_previous_text: bool
62
+ if True, the previous output of the model is provided as a prompt for the next window;
63
+ disabling may make the text inconsistent across windows, but the model becomes less prone to
64
+ getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
65
+
66
+ decode_options: dict
67
+ Keyword arguments to construct `DecodingOptions` instances
68
+
69
+ Returns
70
+ -------
71
+ A dictionary with a single key "segments": a list of dicts holding the start/end mel-frame indices of
72
+ each 30-second window together with its Whisper encoder embeddings ("encoder_embeddings").
73
+ """
74
+ dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
75
+ if model.device == torch.device("cpu"):
76
+ if torch.cuda.is_available():
77
+ warnings.warn("Performing inference on CPU when CUDA is available")
78
+ if dtype == torch.float16:
79
+ warnings.warn("FP16 is not supported on CPU; using FP32 instead")
80
+ dtype = torch.float32
81
+
82
+ if dtype == torch.float32:
83
+ decode_options["fp16"] = False
84
+
85
+ mel = log_mel_spectrogram(audio)
86
+
87
+ all_segments = []
88
+ def add_segment(
89
+ *, start: float, end: float, encoder_embeddings
90
+ ):
91
+
92
+ all_segments.append(
93
+ {
94
+ "start": start,
95
+ "end": end,
96
+ "encoder_embeddings":encoder_embeddings,
97
+ }
98
+ )
99
+ # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
100
+ num_frames = mel.shape[-1]
101
+ seek = 0
102
+ previous_seek_value = seek
103
+ sample_skip = 3000 # chunk size in mel frames (N_FRAMES, i.e. 30 seconds of audio)
104
+ with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
105
+ while seek < num_frames:
106
+ # seek is the index of the first mel frame of the current chunk
107
+ end_seek = min(seek + sample_skip, num_frames)
108
+ segment = pad_or_trim(mel[:,seek:seek+sample_skip], N_FRAMES).to(model.device).to(dtype)
109
+
110
+ single = segment.ndim == 2
111
+ if single:
112
+ segment = segment.unsqueeze(0)
113
+ if dtype == torch.float16:
114
+ segment = segment.half()
115
+ audio_features, embeddings = model.encoder(segment, include_embeddings = True)
116
+
117
+ encoder_embeddings = embeddings
118
+ #print(f"encoder_embeddings shape {encoder_embeddings.shape}")
119
+ add_segment(
120
+ start=seek,
121
+ end=end_seek,
122
+ #text_tokens=tokens,
123
+ #result=result,
124
+ encoder_embeddings=encoder_embeddings,
125
+ )
126
+ seek+=sample_skip
127
+
128
+ return dict(segments=all_segments)
129
+
130
+
131
+ def cli():
132
+ from . import available_models
133
+
134
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
135
+ parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
136
+ parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
137
+ parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
138
+ parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
139
+ parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
140
+ parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
141
+
142
+ parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
143
+ parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
144
+
145
+ parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
146
+ parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
147
+ parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
148
+ parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
149
+ parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
150
+
151
+ parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
152
+ parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
153
+ parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
154
+ parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
155
+
156
+ parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
157
+ parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
158
+ parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
159
+ parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
160
+ parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supersedes MKL_NUM_THREADS/OMP_NUM_THREADS")
161
+
162
+ args = parser.parse_args().__dict__
163
+ model_name: str = args.pop("model")
164
+ model_dir: str = args.pop("model_dir")
165
+ output_dir: str = args.pop("output_dir")
166
+ device: str = args.pop("device")
167
+ os.makedirs(output_dir, exist_ok=True)
168
+
169
+ if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
170
+ if args["language"] is not None:
171
+ warnings.warn(f"{model_name} is an English-only model but received '{args['language']}'; using English instead.")
172
+ args["language"] = "en"
173
+
174
+ temperature = args.pop("temperature")
175
+ temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
176
+ if temperature_increment_on_fallback is not None:
177
+ temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
178
+ else:
179
+ temperature = [temperature]
180
+
181
+ threads = args.pop("threads")
182
+ if threads > 0:
183
+ torch.set_num_threads(threads)
184
+
185
+ from . import load_model
186
+ model = load_model(model_name, device=device, download_root=model_dir)
187
+
188
+ for audio_path in args.pop("audio"):
189
+ result = transcribe(model, audio_path, temperature=temperature, **args)
190
+
191
+ audio_basename = os.path.basename(audio_path)
192
+
193
+ # save TXT
194
+ with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
195
+ write_txt(result["segments"], file=txt)
196
+
197
+ # save VTT
198
+ with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt:
199
+ write_vtt(result["segments"], file=vtt)
200
+
201
+ # save SRT
202
+ with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
203
+ write_srt(result["segments"], file=srt)
204
+
205
+
206
+ if __name__ == '__main__':
207
+ cli()
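Note that this is a modified transcribe(): rather than decoding text, it slices the mel spectrogram into 30-second windows and collects Whisper encoder embeddings per window. A hedged usage sketch (the checkpoint name and audio path are placeholders; load_model is assumed to be re-exported by the package __init__, as the cli() above implies):

import torch
from latentsync.whisper.whisper import load_model
from latentsync.whisper.whisper.transcribe import transcribe

device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_model("tiny", device=device)
result = transcribe(model, "example.wav")  # "example.wav" is a placeholder path
for segment in result["segments"]:
    # start/end are mel-frame indices (100 frames per second), not seconds
    print(segment["start"], segment["end"], segment["encoder_embeddings"].shape)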
latentsync/whisper/whisper/utils.py ADDED
@@ -0,0 +1,87 @@
1
+ import zlib
2
+ from typing import Iterator, TextIO
3
+
4
+
5
+ def exact_div(x, y):
6
+ assert x % y == 0
7
+ return x // y
8
+
9
+
10
+ def str2bool(string):
11
+ str2val = {"True": True, "False": False}
12
+ if string in str2val:
13
+ return str2val[string]
14
+ else:
15
+ raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
16
+
17
+
18
+ def optional_int(string):
19
+ return None if string == "None" else int(string)
20
+
21
+
22
+ def optional_float(string):
23
+ return None if string == "None" else float(string)
24
+
25
+
26
+ def compression_ratio(text) -> float:
27
+ return len(text) / len(zlib.compress(text.encode("utf-8")))
28
+
29
+
30
+ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
31
+ assert seconds >= 0, "non-negative timestamp expected"
32
+ milliseconds = round(seconds * 1000.0)
33
+
34
+ hours = milliseconds // 3_600_000
35
+ milliseconds -= hours * 3_600_000
36
+
37
+ minutes = milliseconds // 60_000
38
+ milliseconds -= minutes * 60_000
39
+
40
+ seconds = milliseconds // 1_000
41
+ milliseconds -= seconds * 1_000
42
+
43
+ hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
44
+ return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
45
+
46
+
47
+ def write_txt(transcript: Iterator[dict], file: TextIO):
48
+ for segment in transcript:
49
+ print(segment['text'].strip(), file=file, flush=True)
50
+
51
+
52
+ def write_vtt(transcript: Iterator[dict], file: TextIO):
53
+ print("WEBVTT\n", file=file)
54
+ for segment in transcript:
55
+ print(
56
+ f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
57
+ f"{segment['text'].strip().replace('-->', '->')}\n",
58
+ file=file,
59
+ flush=True,
60
+ )
61
+
62
+
63
+ def write_srt(transcript: Iterator[dict], file: TextIO):
64
+ """
65
+ Write a transcript to a file in SRT format.
66
+
67
+ Example usage:
68
+ from pathlib import Path
69
+ from whisper.utils import write_srt
70
+
71
+ result = transcribe(model, audio_path, temperature=temperature, **args)
72
+
73
+ # save SRT
74
+ audio_basename = Path(audio_path).stem
75
+ with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
76
+ write_srt(result["segments"], file=srt)
77
+ """
78
+ for i, segment in enumerate(transcript, start=1):
79
+ # write srt lines
80
+ print(
81
+ f"{i}\n"
82
+ f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
83
+ f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
84
+ f"{segment['text'].strip().replace('-->', '->')}\n",
85
+ file=file,
86
+ flush=True,
87
+ )
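A quick illustration of the helpers above, using hypothetical values:

from latentsync.whisper.whisper.utils import format_timestamp, str2bool, compression_ratio

print(format_timestamp(3723.5))  # "01:02:03.500"
print(format_timestamp(5.25, always_include_hours=True, decimal_marker=","))  # "00:00:05,250"
print(str2bool("True"))          # True
# repetitive text compresses well, so a high ratio flags degenerate transcriptions
print(compression_ratio("la la la " * 50))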
preprocess/affine_transform.py ADDED
@@ -0,0 +1,137 @@
1
+ # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from latentsync.utils.util import read_video, write_video
16
+ from latentsync.utils.image_processor import ImageProcessor
17
+ import torch
18
+ from einops import rearrange
19
+ import os
20
+ import tqdm
21
+ import subprocess
22
+ from multiprocessing import Process
23
+ import shutil
24
+
25
+ paths = []
26
+
27
+
28
+ def gather_video_paths(input_dir, output_dir):
29
+ for video in sorted(os.listdir(input_dir)):
30
+ if video.endswith(".mp4"):
31
+ video_input = os.path.join(input_dir, video)
32
+ video_output = os.path.join(output_dir, video)
33
+ if os.path.isfile(video_output):
34
+ continue
35
+ paths.append((video_input, video_output))
36
+ elif os.path.isdir(os.path.join(input_dir, video)):
37
+ gather_video_paths(os.path.join(input_dir, video), os.path.join(output_dir, video))
38
+
39
+
40
+ class FaceDetector:
41
+ def __init__(self, resolution: int = 512, device: str = "cpu"):
42
+ self.image_processor = ImageProcessor(resolution, "fix_mask", device)
43
+
44
+ def affine_transform_video(self, video_path):
45
+ video_frames = read_video(video_path, change_fps=False)
46
+ results = []
47
+ for frame in video_frames:
48
+ frame, _, _ = self.image_processor.affine_transform(frame)
49
+ results.append(frame)
50
+ results = torch.stack(results)
51
+
52
+ results = rearrange(results, "f c h w -> f h w c").numpy()
53
+ return results
54
+
55
+ def close(self):
56
+ self.image_processor.close()
57
+
58
+
59
+ def combine_video_audio(video_frames, video_input_path, video_output_path, process_temp_dir):
60
+ video_name = os.path.basename(video_input_path)[:-4]
61
+ audio_temp = os.path.join(process_temp_dir, f"{video_name}_temp.wav")
62
+ video_temp = os.path.join(process_temp_dir, f"{video_name}_temp.mp4")
63
+
64
+ write_video(video_temp, video_frames, fps=25)
65
+
66
+ command = f"ffmpeg -y -loglevel error -i {video_input_path} -q:a 0 -map a {audio_temp}"
67
+ subprocess.run(command, shell=True)
68
+
69
+ os.makedirs(os.path.dirname(video_output_path), exist_ok=True)
70
+ command = f"ffmpeg -y -loglevel error -i {video_temp} -i {audio_temp} -c:v libx264 -c:a aac -map 0:v -map 1:a -q:v 0 -q:a 0 {video_output_path}"
71
+ subprocess.run(command, shell=True)
72
+
73
+ os.remove(audio_temp)
74
+ os.remove(video_temp)
75
+
76
+
77
+ def func(paths, process_temp_dir, device_id, resolution):
78
+ os.makedirs(process_temp_dir, exist_ok=True)
79
+ face_detector = FaceDetector(resolution, f"cuda:{device_id}")
80
+
81
+ for video_input, video_output in paths:
82
+ if os.path.isfile(video_output):
83
+ continue
84
+ try:
85
+ video_frames = face_detector.affine_transform_video(video_input)
86
+ except Exception as e: # Handle the case where no face is detected
87
+ print(f"Exception: {e} - {video_input}")
88
+ continue
89
+
90
+ os.makedirs(os.path.dirname(video_output), exist_ok=True)
91
+ combine_video_audio(video_frames, video_input, video_output, process_temp_dir)
92
+ print(f"Saved: {video_output}")
93
+
94
+ face_detector.close()
95
+
96
+
97
+ def split(a, n):
98
+ k, m = divmod(len(a), n)
99
+ return (a[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n))
100
+
101
+
102
+ def affine_transform_multi_gpus(input_dir, output_dir, temp_dir, resolution, num_workers):
103
+ print(f"Recursively gathering video paths of {input_dir} ...")
104
+ gather_video_paths(input_dir, output_dir)
105
+ num_devices = torch.cuda.device_count()
106
+ if num_devices == 0:
107
+ raise RuntimeError("No GPUs found")
108
+
109
+ if os.path.exists(temp_dir):
110
+ shutil.rmtree(temp_dir)
111
+ os.makedirs(temp_dir, exist_ok=True)
112
+
113
+ split_paths = list(split(paths, num_workers * num_devices))
114
+
115
+ processes = []
116
+
117
+ for i in range(num_devices):
118
+ for j in range(num_workers):
119
+ process_index = i * num_workers + j
120
+ process = Process(
121
+ target=func, args=(split_paths[process_index], os.path.join(temp_dir, f"process_{i}"), i, resolution)
122
+ )
123
+ process.start()
124
+ processes.append(process)
125
+
126
+ for process in processes:
127
+ process.join()
128
+
129
+
130
+ if __name__ == "__main__":
131
+ input_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/avatars/resampled/train"
132
+ output_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/avatars/affine_transformed/train"
133
+ temp_dir = "temp"
134
+ resolution = 256
135
+ num_workers = 10 # How many processes per device
136
+
137
+ affine_transform_multi_gpus(input_dir, output_dir, temp_dir, resolution, num_workers)
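The __main__ block above hard-codes internal dataset paths; a hedged sketch of running the same preprocessing on your own data (directory names are placeholders, the repo root is assumed to be on PYTHONPATH, and at least one CUDA device is required):

from preprocess.affine_transform import affine_transform_multi_gpus

affine_transform_multi_gpus(
    input_dir="data/raw_videos",           # directory tree of .mp4 files (placeholder)
    output_dir="data/affine_transformed",  # mirrors the input tree (placeholder)
    temp_dir="temp",                       # scratch space, deleted and recreated on each run
    resolution=256,
    num_workers=2,                         # worker processes per GPU
)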