This view is limited to 50 files because it contains too many changes.
- latentsync/data/syncnet_dataset.py +153 -0
- latentsync/data/unet_dataset.py +164 -0
- latentsync/models/attention.py +492 -0
- latentsync/models/motion_module.py +332 -0
- latentsync/models/resnet.py +234 -0
- latentsync/models/syncnet.py +233 -0
- latentsync/models/syncnet_wav2lip.py +90 -0
- latentsync/models/unet.py +528 -0
- latentsync/models/unet_blocks.py +903 -0
- latentsync/models/utils.py +19 -0
- latentsync/pipelines/lipsync_pipeline.py +470 -0
- latentsync/trepa/__init__.py +64 -0
- latentsync/trepa/third_party/VideoMAEv2/__init__.py +0 -0
- latentsync/trepa/third_party/VideoMAEv2/utils.py +81 -0
- latentsync/trepa/third_party/VideoMAEv2/videomaev2_finetune.py +539 -0
- latentsync/trepa/third_party/VideoMAEv2/videomaev2_pretrain.py +469 -0
- latentsync/trepa/third_party/__init__.py +0 -0
- latentsync/trepa/utils/__init__.py +0 -0
- latentsync/trepa/utils/data_utils.py +321 -0
- latentsync/trepa/utils/metric_utils.py +161 -0
- latentsync/utils/affine_transform.py +138 -0
- latentsync/utils/audio.py +194 -0
- latentsync/utils/av_reader.py +157 -0
- latentsync/utils/image_processor.py +342 -0
- latentsync/utils/mask.png +0 -0
- latentsync/utils/util.py +365 -0
- latentsync/whisper/audio2feature.py +166 -0
- latentsync/whisper/whisper/__init__.py +119 -0
- latentsync/whisper/whisper/__main__.py +4 -0
- latentsync/whisper/whisper/assets/gpt2/merges.txt +0 -0
- latentsync/whisper/whisper/assets/gpt2/special_tokens_map.json +1 -0
- latentsync/whisper/whisper/assets/gpt2/tokenizer_config.json +1 -0
- latentsync/whisper/whisper/assets/gpt2/vocab.json +0 -0
- latentsync/whisper/whisper/assets/mel_filters.npz +3 -0
- latentsync/whisper/whisper/assets/multilingual/added_tokens.json +1 -0
- latentsync/whisper/whisper/assets/multilingual/merges.txt +0 -0
- latentsync/whisper/whisper/assets/multilingual/special_tokens_map.json +1 -0
- latentsync/whisper/whisper/assets/multilingual/tokenizer_config.json +1 -0
- latentsync/whisper/whisper/assets/multilingual/vocab.json +0 -0
- latentsync/whisper/whisper/audio.py +125 -0
- latentsync/whisper/whisper/decoding.py +729 -0
- latentsync/whisper/whisper/model.py +290 -0
- latentsync/whisper/whisper/normalizers/__init__.py +2 -0
- latentsync/whisper/whisper/normalizers/basic.py +71 -0
- latentsync/whisper/whisper/normalizers/english.json +1742 -0
- latentsync/whisper/whisper/normalizers/english.py +543 -0
- latentsync/whisper/whisper/tokenizer.py +331 -0
- latentsync/whisper/whisper/transcribe.py +207 -0
- latentsync/whisper/whisper/utils.py +87 -0
- preprocess/affine_transform.py +137 -0
latentsync/data/syncnet_dataset.py
ADDED
@@ -0,0 +1,153 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import numpy as np
from torch.utils.data import Dataset
import torch
import random
from ..utils.util import gather_video_paths_recursively
from ..utils.image_processor import ImageProcessor
from ..utils.audio import melspectrogram
import math

from decord import AudioReader, VideoReader, cpu


class SyncNetDataset(Dataset):
    def __init__(self, data_dir: str, fileslist: str, config):
        if fileslist != "":
            with open(fileslist) as file:
                self.video_paths = [line.rstrip() for line in file]
        elif data_dir != "":
            self.video_paths = gather_video_paths_recursively(data_dir)
        else:
            raise ValueError("data_dir and fileslist cannot be both empty")

        self.resolution = config.data.resolution
        self.num_frames = config.data.num_frames

        self.mel_window_length = math.ceil(self.num_frames / 5 * 16)

        self.audio_sample_rate = config.data.audio_sample_rate
        self.video_fps = config.data.video_fps
        self.audio_samples_length = int(
            config.data.audio_sample_rate // config.data.video_fps * config.data.num_frames
        )
        self.image_processor = ImageProcessor(resolution=config.data.resolution, mask="half")
        self.audio_mel_cache_dir = config.data.audio_mel_cache_dir
        os.makedirs(self.audio_mel_cache_dir, exist_ok=True)

    def __len__(self):
        return len(self.video_paths)

    def read_audio(self, video_path: str):
        ar = AudioReader(video_path, ctx=cpu(self.worker_id), sample_rate=self.audio_sample_rate)
        original_mel = melspectrogram(ar[:].asnumpy().squeeze(0))
        return torch.from_numpy(original_mel)

    def crop_audio_window(self, original_mel, start_index):
        start_idx = int(80.0 * (start_index / float(self.video_fps)))
        end_idx = start_idx + self.mel_window_length
        return original_mel[:, start_idx:end_idx].unsqueeze(0)

    def get_frames(self, video_reader: VideoReader):
        total_num_frames = len(video_reader)

        start_idx = random.randint(0, total_num_frames - self.num_frames)
        frames_index = np.arange(start_idx, start_idx + self.num_frames, dtype=int)

        while True:
            wrong_start_idx = random.randint(0, total_num_frames - self.num_frames)
            # wrong_start_idx = random.randint(
            #     max(0, start_idx - 25), min(total_num_frames - self.num_frames, start_idx + 25)
            # )
            if wrong_start_idx == start_idx:
                continue
            # if wrong_start_idx >= start_idx - self.num_frames and wrong_start_idx <= start_idx + self.num_frames:
            #     continue
            wrong_frames_index = np.arange(wrong_start_idx, wrong_start_idx + self.num_frames, dtype=int)
            break

        frames = video_reader.get_batch(frames_index).asnumpy()
        wrong_frames = video_reader.get_batch(wrong_frames_index).asnumpy()

        return frames, wrong_frames, start_idx

    def worker_init_fn(self, worker_id):
        # Initialize the face mesh object in each worker process,
        # because the face mesh object cannot be called in subprocesses
        self.worker_id = worker_id
        # setattr(self, f"image_processor_{worker_id}", ImageProcessor(self.resolution, self.mask))

    def __getitem__(self, idx):
        # image_processor = getattr(self, f"image_processor_{self.worker_id}")
        while True:
            try:
                idx = random.randint(0, len(self) - 1)

                # Get video file path
                video_path = self.video_paths[idx]

                vr = VideoReader(video_path, ctx=cpu(self.worker_id))

                if len(vr) < 2 * self.num_frames:
                    continue

                frames, wrong_frames, start_idx = self.get_frames(vr)

                mel_cache_path = os.path.join(
                    self.audio_mel_cache_dir, os.path.basename(video_path).replace(".mp4", "_mel.pt")
                )

                if os.path.isfile(mel_cache_path):
                    try:
                        original_mel = torch.load(mel_cache_path)
                    except Exception as e:
                        print(f"{type(e).__name__} - {e} - {mel_cache_path}")
                        os.remove(mel_cache_path)
                        original_mel = self.read_audio(video_path)
                        torch.save(original_mel, mel_cache_path)
                else:
                    original_mel = self.read_audio(video_path)
                    torch.save(original_mel, mel_cache_path)

                mel = self.crop_audio_window(original_mel, start_idx)

                if mel.shape[-1] != self.mel_window_length:
                    continue

                if random.choice([True, False]):
                    y = torch.ones(1).float()
                    chosen_frames = frames
                else:
                    y = torch.zeros(1).float()
                    chosen_frames = wrong_frames

                chosen_frames = self.image_processor.process_images(chosen_frames)
                # chosen_frames, _, _ = image_processor.prepare_masks_and_masked_images(
                #     chosen_frames, affine_transform=True
                # )

                vr.seek(0)  # avoid memory leak
                break

            except Exception as e:  # Handle the exception of face not detected
                print(f"{type(e).__name__} - {e} - {video_path}")
                if "vr" in locals():
                    vr.seek(0)  # avoid memory leak

        sample = dict(frames=chosen_frames, audio_samples=mel, y=y)

        return sample
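
For orientation, here is a minimal usage sketch (not part of the diff) showing how a dataset like SyncNetDataset is typically driven by a PyTorch DataLoader. The OmegaConf-style config and its concrete values are illustrative assumptions based only on the attributes that __init__ reads above; the file list path is hypothetical.

# Hypothetical usage sketch: the config fields mirror what SyncNetDataset.__init__
# reads, but the concrete values and the fileslist path are illustrative only.
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from latentsync.data.syncnet_dataset import SyncNetDataset

config = OmegaConf.create(
    {
        "data": {
            "resolution": 256,
            "num_frames": 16,
            "audio_sample_rate": 16000,
            "video_fps": 25,
            "audio_mel_cache_dir": "/tmp/audio_mel_cache",  # illustrative path
        }
    }
)

dataset = SyncNetDataset(data_dir="", fileslist="train_fileslist.txt", config=config)

# worker_init_fn is what assigns each worker process its worker_id, which the
# dataset passes to the decord readers, so it must be forwarded to the loader.
# With num_workers=0 it would never run and worker_id would have to be set by hand.
loader = DataLoader(
    dataset,
    batch_size=8,
    num_workers=4,
    worker_init_fn=dataset.worker_init_fn,
)

for sample in loader:
    frames, mel, y = sample["frames"], sample["audio_samples"], sample["y"]
    break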
latentsync/data/unet_dataset.py
ADDED
@@ -0,0 +1,164 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import numpy as np
from torch.utils.data import Dataset
import torch
import random
import cv2
from ..utils.image_processor import ImageProcessor, load_fixed_mask
from ..utils.audio import melspectrogram
from decord import AudioReader, VideoReader, cpu


class UNetDataset(Dataset):
    def __init__(self, train_data_dir: str, config):
        if config.data.train_fileslist != "":
            with open(config.data.train_fileslist) as file:
                self.video_paths = [line.rstrip() for line in file]
        elif train_data_dir != "":
            self.video_paths = []
            for file in os.listdir(train_data_dir):
                if file.endswith(".mp4"):
                    self.video_paths.append(os.path.join(train_data_dir, file))
        else:
            raise ValueError("data_dir and fileslist cannot be both empty")

        self.resolution = config.data.resolution
        self.num_frames = config.data.num_frames

        if self.num_frames == 16:
            self.mel_window_length = 52
        elif self.num_frames == 5:
            self.mel_window_length = 16
        else:
            raise NotImplementedError("Only support 16 and 5 frames now")

        self.audio_sample_rate = config.data.audio_sample_rate
        self.video_fps = config.data.video_fps
        self.mask = config.data.mask
        self.mask_image = load_fixed_mask(self.resolution)
        self.load_audio_data = config.model.add_audio_layer and config.run.use_syncnet
        self.audio_mel_cache_dir = config.data.audio_mel_cache_dir
        os.makedirs(self.audio_mel_cache_dir, exist_ok=True)

    def __len__(self):
        return len(self.video_paths)

    def read_audio(self, video_path: str):
        ar = AudioReader(video_path, ctx=cpu(self.worker_id), sample_rate=self.audio_sample_rate)
        original_mel = melspectrogram(ar[:].asnumpy().squeeze(0))
        return torch.from_numpy(original_mel)

    def crop_audio_window(self, original_mel, start_index):
        start_idx = int(80.0 * (start_index / float(self.video_fps)))
        end_idx = start_idx + self.mel_window_length
        return original_mel[:, start_idx:end_idx].unsqueeze(0)

    def get_frames(self, video_reader: VideoReader):
        total_num_frames = len(video_reader)

        start_idx = random.randint(self.num_frames // 2, total_num_frames - self.num_frames - self.num_frames // 2)
        frames_index = np.arange(start_idx, start_idx + self.num_frames, dtype=int)

        while True:
            wrong_start_idx = random.randint(0, total_num_frames - self.num_frames)
            if wrong_start_idx > start_idx - self.num_frames and wrong_start_idx < start_idx + self.num_frames:
                continue
            wrong_frames_index = np.arange(wrong_start_idx, wrong_start_idx + self.num_frames, dtype=int)
            break

        frames = video_reader.get_batch(frames_index).asnumpy()
        wrong_frames = video_reader.get_batch(wrong_frames_index).asnumpy()

        return frames, wrong_frames, start_idx

    def worker_init_fn(self, worker_id):
        # Initialize the face mesh object in each worker process,
        # because the face mesh object cannot be called in subprocesses
        self.worker_id = worker_id
        setattr(
            self,
            f"image_processor_{worker_id}",
            ImageProcessor(self.resolution, self.mask, mask_image=self.mask_image),
        )

    def __getitem__(self, idx):
        image_processor = getattr(self, f"image_processor_{self.worker_id}")
        while True:
            try:
                idx = random.randint(0, len(self) - 1)

                # Get video file path
                video_path = self.video_paths[idx]

                vr = VideoReader(video_path, ctx=cpu(self.worker_id))

                if len(vr) < 3 * self.num_frames:
                    continue

                continuous_frames, ref_frames, start_idx = self.get_frames(vr)

                if self.load_audio_data:
                    mel_cache_path = os.path.join(
                        self.audio_mel_cache_dir, os.path.basename(video_path).replace(".mp4", "_mel.pt")
                    )

                    if os.path.isfile(mel_cache_path):
                        try:
                            original_mel = torch.load(mel_cache_path)
                        except Exception as e:
                            print(f"{type(e).__name__} - {e} - {mel_cache_path}")
                            os.remove(mel_cache_path)
                            original_mel = self.read_audio(video_path)
                            torch.save(original_mel, mel_cache_path)
                    else:
                        original_mel = self.read_audio(video_path)
                        torch.save(original_mel, mel_cache_path)

                    mel = self.crop_audio_window(original_mel, start_idx)

                    if mel.shape[-1] != self.mel_window_length:
                        continue
                else:
                    mel = []

                gt, masked_gt, mask = image_processor.prepare_masks_and_masked_images(
                    continuous_frames, affine_transform=False
                )

                if self.mask == "fix_mask":
                    ref, _, _ = image_processor.prepare_masks_and_masked_images(ref_frames, affine_transform=False)
                else:
                    ref = image_processor.process_images(ref_frames)
                vr.seek(0)  # avoid memory leak
                break

            except Exception as e:  # Handle the exception of face not detected
                print(f"{type(e).__name__} - {e} - {video_path}")
                if "vr" in locals():
                    vr.seek(0)  # avoid memory leak

        sample = dict(
            gt=gt,
            masked_gt=masked_gt,
            ref=ref,
            mel=mel,
            mask=mask,
            video_path=video_path,
            start_idx=start_idx,
        )

        return sample
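
The two hard-coded mel window lengths above line up with crop_audio_window's assumption of 80 mel frames per second of audio (the 80.0 factor), given the 25 fps video rate that makes the constants work out. A small sanity check of that arithmetic (illustrative, not part of the diff) is sketched below.

# Sanity check of the hard-coded mel_window_length values, assuming 25 fps video
# and 80 mel frames per second of audio (the factor used in crop_audio_window).
import math


def mel_window_length(num_frames: int, video_fps: float = 25.0, mel_fps: float = 80.0) -> int:
    # A clip of num_frames video frames covers num_frames / video_fps seconds,
    # i.e. that many seconds times mel_fps mel frames, rounded up.
    return math.ceil(num_frames / video_fps * mel_fps)


assert mel_window_length(16) == 52  # matches the num_frames == 16 branch
assert mel_window_length(5) == 16   # matches the num_frames == 5 branch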
latentsync/models/attention.py
ADDED
@@ -0,0 +1,492 @@
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm

from einops import rearrange, repeat
from .utils import zero_module


@dataclass
class Transformer3DModelOutput(BaseOutput):
    sample: torch.FloatTensor


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class Transformer3DModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        use_motion_module: bool = False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        add_audio_layer=False,
        audio_condition_method="cross_attn",
        custom_audio_layer: bool = False,
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        # Define input layers
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        if use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        if not custom_audio_layer:
            # Define transformers blocks
            self.transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        inner_dim,
                        num_attention_heads,
                        attention_head_dim,
                        dropout=dropout,
                        cross_attention_dim=cross_attention_dim,
                        activation_fn=activation_fn,
                        num_embeds_ada_norm=num_embeds_ada_norm,
                        attention_bias=attention_bias,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                        use_motion_module=use_motion_module,
                        unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                        unet_use_temporal_attention=unet_use_temporal_attention,
                        add_audio_layer=add_audio_layer,
                        custom_audio_layer=custom_audio_layer,
                        audio_condition_method=audio_condition_method,
                    )
                    for d in range(num_layers)
                ]
            )
        else:
            self.transformer_blocks = nn.ModuleList(
                [
                    AudioTransformerBlock(
                        inner_dim,
                        num_attention_heads,
                        attention_head_dim,
                        dropout=dropout,
                        cross_attention_dim=cross_attention_dim,
                        activation_fn=activation_fn,
                        num_embeds_ada_norm=num_embeds_ada_norm,
                        attention_bias=attention_bias,
                        only_cross_attention=only_cross_attention,
                        upcast_attention=upcast_attention,
                        use_motion_module=use_motion_module,
                        unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                        unet_use_temporal_attention=unet_use_temporal_attention,
                        add_audio_layer=add_audio_layer,
                    )
                    for d in range(num_layers)
                ]
            )

        # 4. Define output layers
        if use_linear_projection:
            self.proj_out = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

        if custom_audio_layer:
            self.proj_out = zero_module(self.proj_out)

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
        # Input
        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        video_length = hidden_states.shape[2]
        hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")

        # No need to do this for audio input, because different audio samples are independent
        # encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)

        batch, channel, height, weight = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
        else:
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
            hidden_states = self.proj_in(hidden_states)

        # Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                video_length=video_length,
            )

        # Output
        if not self.use_linear_projection:
            hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual

        output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
        if not return_dict:
            return (output,)

        return Transformer3DModelOutput(sample=output)


class BasicTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        use_motion_module: bool = False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        add_audio_layer=False,
        custom_audio_layer=False,
        audio_condition_method="cross_attn",
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm = num_embeds_ada_norm is not None
        self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
        self.unet_use_temporal_attention = unet_use_temporal_attention
        self.use_motion_module = use_motion_module
        self.add_audio_layer = add_audio_layer

        # SC-Attn
        assert unet_use_cross_frame_attention is not None
        if unet_use_cross_frame_attention:
            raise NotImplementedError("SparseCausalAttention2D not implemented yet.")
        else:
            self.attn1 = CrossAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

        # Cross-Attn
        if add_audio_layer and audio_condition_method == "cross_attn" and not custom_audio_layer:
            self.audio_cross_attn = AudioCrossAttn(
                dim=dim,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                dropout=dropout,
                attention_bias=attention_bias,
                upcast_attention=upcast_attention,
                num_embeds_ada_norm=num_embeds_ada_norm,
                use_ada_layer_norm=self.use_ada_layer_norm,
                zero_proj_out=False,
            )
        else:
            self.audio_cross_attn = None

        # Feed-forward
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.norm3 = nn.LayerNorm(dim)

        # Temp-Attn
        assert unet_use_temporal_attention is not None
        if unet_use_temporal_attention:
            self.attn_temp = CrossAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
            nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
            self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        if not is_xformers_available():
            print("Here is how to install it")
            raise ModuleNotFoundError(
                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                " xformers",
                name="xformers",
            )
        elif not torch.cuda.is_available():
            raise ValueError(
                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
                " available for GPU "
            )
        else:
            try:
                # Make sure we can run the memory efficient attention
                _ = xformers.ops.memory_efficient_attention(
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                )
            except Exception as e:
                raise e
        self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        if self.audio_cross_attn is not None:
            self.audio_cross_attn.attn._use_memory_efficient_attention_xformers = (
                use_memory_efficient_attention_xformers
            )
        # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers

    def forward(
        self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None
    ):
        # SparseCausal-Attention
        norm_hidden_states = (
            self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
        )

        # if self.only_cross_attention:
        #     hidden_states = (
        #         self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
        #     )
        # else:
        #     hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states

        # pdb.set_trace()
        if self.unet_use_cross_frame_attention:
            hidden_states = (
                self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length)
                + hidden_states
            )
        else:
            hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states

        if self.audio_cross_attn is not None and encoder_hidden_states is not None:
            hidden_states = self.audio_cross_attn(
                hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
            )

        # Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        # Temporal-Attention
        if self.unet_use_temporal_attention:
            d = hidden_states.shape[1]
            hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
            norm_hidden_states = (
                self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
            )
            hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
            hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)

        return hidden_states


class AudioTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        use_motion_module: bool = False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        add_audio_layer=False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm = num_embeds_ada_norm is not None
        self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
        self.unet_use_temporal_attention = unet_use_temporal_attention
        self.use_motion_module = use_motion_module
        self.add_audio_layer = add_audio_layer

        # SC-Attn
        assert unet_use_cross_frame_attention is not None
        if unet_use_cross_frame_attention:
            raise NotImplementedError("SparseCausalAttention2D not implemented yet.")
        else:
            self.attn1 = CrossAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

        self.audio_cross_attn = AudioCrossAttn(
            dim=dim,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            dropout=dropout,
            attention_bias=attention_bias,
            upcast_attention=upcast_attention,
            num_embeds_ada_norm=num_embeds_ada_norm,
            use_ada_layer_norm=self.use_ada_layer_norm,
            zero_proj_out=False,
        )

        # Feed-forward
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.norm3 = nn.LayerNorm(dim)

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        if not is_xformers_available():
            print("Here is how to install it")
            raise ModuleNotFoundError(
                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                " xformers",
                name="xformers",
            )
        elif not torch.cuda.is_available():
            raise ValueError(
                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
                " available for GPU "
            )
        else:
            try:
                # Make sure we can run the memory efficient attention
                _ = xformers.ops.memory_efficient_attention(
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                )
            except Exception as e:
                raise e
        self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        if self.audio_cross_attn is not None:
            self.audio_cross_attn.attn._use_memory_efficient_attention_xformers = (
                use_memory_efficient_attention_xformers
            )
        # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers

    def forward(
        self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None
    ):
        # SparseCausal-Attention
        norm_hidden_states = (
            self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
        )

        # pdb.set_trace()
        if self.unet_use_cross_frame_attention:
            hidden_states = (
                self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length)
                + hidden_states
            )
        else:
            hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states

        if self.audio_cross_attn is not None and encoder_hidden_states is not None:
            hidden_states = self.audio_cross_attn(
                hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
            )

        # Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        return hidden_states


class AudioCrossAttn(nn.Module):
    def __init__(
        self,
        dim,
        cross_attention_dim,
        num_attention_heads,
        attention_head_dim,
        dropout,
        attention_bias,
        upcast_attention,
        num_embeds_ada_norm,
        use_ada_layer_norm,
        zero_proj_out=False,
    ):
        super().__init__()

        self.norm = AdaLayerNorm(dim, num_embeds_ada_norm) if use_ada_layer_norm else nn.LayerNorm(dim)
        self.attn = CrossAttention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            upcast_attention=upcast_attention,
        )

        if zero_proj_out:
            self.proj_out = zero_module(nn.Linear(dim, dim))

        self.zero_proj_out = zero_proj_out
        self.use_ada_layer_norm = use_ada_layer_norm

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None):
        previous_hidden_states = hidden_states
        hidden_states = self.norm(hidden_states, timestep) if self.use_ada_layer_norm else self.norm(hidden_states)

        if encoder_hidden_states.dim() == 4:
            encoder_hidden_states = rearrange(encoder_hidden_states, "b f n d -> (b f) n d")

        hidden_states = self.attn(
            hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
        )

        if self.zero_proj_out:
            hidden_states = self.proj_out(hidden_states)
        return hidden_states + previous_hidden_states
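
The central shape convention in this file is that the frame axis is folded into the batch axis before the 2D attention blocks run, and the per-frame audio embeddings are folded the same way inside AudioCrossAttn, so each frame attends only to its own audio. A shape-only sketch of that convention (illustrative, not part of the diff; all tensor sizes are arbitrary) follows.

# Shape-only sketch of the frame-folding trick used by Transformer3DModel and
# AudioCrossAttn: video frames become extra batch entries, so 2D attention runs
# per frame, and per-frame audio tokens are folded identically.
import torch
from einops import rearrange

b, c, f, h, w = 2, 4, 16, 32, 32   # batch, channels, frames, height, width
n, d = 50, 384                     # audio tokens per frame, audio embedding dim

video = torch.randn(b, c, f, h, w)
audio = torch.randn(b, f, n, d)

video_2d = rearrange(video, "b c f h w -> (b f) c h w")   # (32, 4, 32, 32)
audio_2d = rearrange(audio, "b f n d -> (b f) n d")       # (32, 50, 384)

# ... per-frame spatial self-attention / audio cross-attention would run here ...

video_out = rearrange(video_2d, "(b f) c h w -> b c f h w", f=f)
assert video_out.shape == video.shape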
latentsync/models/motion_module.py
ADDED
@@ -0,0 +1,332 @@
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py

# Actually we don't use the motion module in the final version of LatentSync.
# When we started the project, we used the AnimateDiff codebase and tried the motion module,
# but the results were poor, and we decided to leave the code here for possible future use.

from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention import CrossAttention, FeedForward

from einops import rearrange, repeat
import math
from .utils import zero_module


@dataclass
class TemporalTransformer3DModelOutput(BaseOutput):
    sample: torch.FloatTensor


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
    if motion_module_type == "Vanilla":
        return VanillaTemporalModule(
            in_channels=in_channels,
            **motion_module_kwargs,
        )
    else:
        raise ValueError


class VanillaTemporalModule(nn.Module):
    def __init__(
        self,
        in_channels,
        num_attention_heads=8,
        num_transformer_block=2,
        attention_block_types=("Temporal_Self", "Temporal_Self"),
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
        temporal_attention_dim_div=1,
        zero_initialize=True,
    ):
        super().__init__()

        self.temporal_transformer = TemporalTransformer3DModel(
            in_channels=in_channels,
            num_attention_heads=num_attention_heads,
            attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
            num_layers=num_transformer_block,
            attention_block_types=attention_block_types,
            cross_frame_attention_mode=cross_frame_attention_mode,
            temporal_position_encoding=temporal_position_encoding,
            temporal_position_encoding_max_len=temporal_position_encoding_max_len,
        )

        if zero_initialize:
            self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)

    def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
        hidden_states = input_tensor
        hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)

        output = hidden_states
        return output


class TemporalTransformer3DModel(nn.Module):
    def __init__(
        self,
        in_channels,
        num_attention_heads,
        attention_head_dim,
        num_layers,
        attention_block_types=(
            "Temporal_Self",
            "Temporal_Self",
        ),
        dropout=0.0,
        norm_num_groups=32,
        cross_attention_dim=768,
        activation_fn="geglu",
        attention_bias=False,
        upcast_attention=False,
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
    ):
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                TemporalTransformerBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    attention_block_types=attention_block_types,
                    dropout=dropout,
                    norm_num_groups=norm_num_groups,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    upcast_attention=upcast_attention,
                    cross_frame_attention_mode=cross_frame_attention_mode,
                    temporal_position_encoding=temporal_position_encoding,
                    temporal_position_encoding_max_len=temporal_position_encoding_max_len,
                )
                for d in range(num_layers)
            ]
        )
        self.proj_out = nn.Linear(inner_dim, in_channels)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        video_length = hidden_states.shape[2]
        hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")

        batch, channel, height, weight = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
        hidden_states = self.proj_in(hidden_states)

        # Transformer Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length
            )

        # output
        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual
        output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)

        return output


class TemporalTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_attention_heads,
        attention_head_dim,
        attention_block_types=(
            "Temporal_Self",
            "Temporal_Self",
        ),
        dropout=0.0,
        norm_num_groups=32,
        cross_attention_dim=768,
        activation_fn="geglu",
        attention_bias=False,
        upcast_attention=False,
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
    ):
        super().__init__()

        attention_blocks = []
        norms = []

        for block_name in attention_block_types:
            attention_blocks.append(
                VersatileAttention(
                    attention_mode=block_name.split("_")[0],
                    cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
                    query_dim=dim,
                    heads=num_attention_heads,
                    dim_head=attention_head_dim,
                    dropout=dropout,
                    bias=attention_bias,
                    upcast_attention=upcast_attention,
                    cross_frame_attention_mode=cross_frame_attention_mode,
                    temporal_position_encoding=temporal_position_encoding,
                    temporal_position_encoding_max_len=temporal_position_encoding_max_len,
                )
            )
            norms.append(nn.LayerNorm(dim))

        self.attention_blocks = nn.ModuleList(attention_blocks)
        self.norms = nn.ModuleList(norms)

        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.ff_norm = nn.LayerNorm(dim)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
        for attention_block, norm in zip(self.attention_blocks, self.norms):
            norm_hidden_states = norm(hidden_states)
            hidden_states = (
                attention_block(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
                    video_length=video_length,
                )
                + hidden_states
            )

        hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states

        output = hidden_states
        return output


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.0, max_len=24):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


class VersatileAttention(CrossAttention):
    def __init__(
        self,
        attention_mode=None,
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        assert attention_mode == "Temporal"

        self.attention_mode = attention_mode
        self.is_cross_attention = kwargs["cross_attention_dim"] is not None

        self.pos_encoder = (
            PositionalEncoding(kwargs["query_dim"], dropout=0.0, max_len=temporal_position_encoding_max_len)
            if (temporal_position_encoding and attention_mode == "Temporal")
            else None
        )

    def extra_repr(self):
        return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
        batch_size, sequence_length, _ = hidden_states.shape

        if self.attention_mode == "Temporal":
            d = hidden_states.shape[1]
            hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)

            if self.pos_encoder is not None:
                hidden_states = self.pos_encoder(hidden_states)

            encoder_hidden_states = (
                repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
                if encoder_hidden_states is not None
                else encoder_hidden_states
            )
        else:
            raise NotImplementedError

        # encoder_hidden_states = encoder_hidden_states

        if self.group_norm is not None:
            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = self.to_q(hidden_states)
        dim = query.shape[-1]
        query = self.reshape_heads_to_batch_dim(query)

        if self.added_kv_proj_dim is not None:
            raise NotImplementedError

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)

        key = self.reshape_heads_to_batch_dim(key)
        value = self.reshape_heads_to_batch_dim(value)

        if attention_mask is not None:
            if attention_mask.shape[-1] != query.shape[1]:
                target_length = query.shape[1]
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)

        # attention, what we cannot get enough of
        if self._use_memory_efficient_attention_xformers:
            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
            # Some versions of xformers return output in fp32, cast it back to the dtype of the input
            hidden_states = hidden_states.to(query.dtype)
        else:
            if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                hidden_states = self._attention(query, key, value, attention_mask)
            else:
                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)

        # dropout
        hidden_states = self.to_out[1](hidden_states)

        if self.attention_mode == "Temporal":
            hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)

        return hidden_states
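
PositionalEncoding above is the standard sinusoidal table, stored as a (1, max_len, d_model) buffer and added to however many frame positions the input has, so the temporal attention can tell frames apart. A minimal check of that behaviour (illustrative, not part of the diff; the tensor sizes are arbitrary) is sketched below.

# Illustrative check of the sinusoidal PositionalEncoding defined above.
import torch

from latentsync.models.motion_module import PositionalEncoding

pe = PositionalEncoding(d_model=64, dropout=0.0, max_len=24)
x = torch.zeros(3, 16, 64)   # (batch * spatial, frames, channels), as used in VersatileAttention

out = pe(x)
assert out.shape == (3, 16, 64)
# With zero input and zero dropout, the output is exactly the first 16 rows of the table.
assert torch.allclose(out[0], pe.pe[0, :16])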
latentsync/models/resnet.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange


class InflatedConv3d(nn.Conv2d):
    def forward(self, x):
        video_length = x.shape[2]

        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = super().forward(x)
        x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)

        return x


class InflatedGroupNorm(nn.GroupNorm):
    def forward(self, x):
        video_length = x.shape[2]

        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = super().forward(x)
        x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)

        return x


class Upsample3D(nn.Module):
    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            raise NotImplementedError
        elif use_conv:
            self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, hidden_states, output_size=None):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            raise NotImplementedError

        # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # if self.use_conv:
        #     if self.name == "conv":
        #         hidden_states = self.conv(hidden_states)
        #     else:
        #         hidden_states = self.Conv2d_0(hidden_states)
        hidden_states = self.conv(hidden_states)

        return hidden_states


class Downsample3D(nn.Module):
    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            raise NotImplementedError

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            raise NotImplementedError

        assert hidden_states.shape[1] == self.channels
        hidden_states = self.conv(hidden_states)

        return hidden_states


class ResnetBlock3D(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        output_scale_factor=1.0,
        use_in_shortcut=None,
        use_inflated_groupnorm=False,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        assert use_inflated_groupnorm != None
        if use_inflated_groupnorm:
            self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            time_emb_proj_out_channels = out_channels
            # if self.time_embedding_norm == "default":
            #     time_emb_proj_out_channels = out_channels
            # elif self.time_embedding_norm == "scale_shift":
            #     time_emb_proj_out_channels = out_channels * 2
            # else:
            #     raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")

            self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
        else:
            self.time_emb_proj = None

        if self.time_embedding_norm == "scale_shift":
            self.double_len_linear = torch.nn.Linear(time_emb_proj_out_channels, 2 * time_emb_proj_out_channels)
        else:
            self.double_len_linear = None

        if use_inflated_groupnorm:
            self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        else:
            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()

        self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, input_tensor, temb):
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            if temb.dim() == 2:
                # input (1, 1280)
                temb = self.time_emb_proj(self.nonlinearity(temb))
                temb = temb[:, :, None, None, None]  # unsqueeze
            else:
                # input (1, 1280, 16)
                temb = temb.permute(0, 2, 1)
                temb = self.time_emb_proj(self.nonlinearity(temb))
                if self.double_len_linear is not None:
                    temb = self.double_len_linear(self.nonlinearity(temb))
                temb = temb.permute(0, 2, 1)
                temb = temb[:, :, :, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor


class Mish(torch.nn.Module):
    def forward(self, hidden_states):
        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
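The inflated layers above are the core device of this file: 2D convolution and GroupNorm are applied to 5D video tensors by folding the frame axis into the batch, so pretrained 2D weights can be reused per frame. A small shape-check sketch follows, assuming the package is importable as latentsync; sizes are illustrative:

import torch
from latentsync.models.resnet import InflatedConv3d, ResnetBlock3D

video = torch.randn(1, 4, 16, 32, 32)      # (batch, channels, frames, height, width)

conv = InflatedConv3d(4, 64, kernel_size=3, padding=1)
feat = conv(video)                          # applied per frame -> (1, 64, 16, 32, 32)

block = ResnetBlock3D(in_channels=64, out_channels=64, temb_channels=512, use_inflated_groupnorm=True)
temb = torch.randn(1, 512)                  # (batch, temb_channels) time embedding
out = block(feat, temb)                     # shape preserved: (1, 64, 16, 32, 32)
print(out.shape)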
latentsync/models/syncnet.py
ADDED
@@ -0,0 +1,233 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn
from einops import rearrange
from torch.nn import functional as F
from ..utils.util import cosine_loss

import torch.nn as nn
import torch.nn.functional as F

from diffusers.models.attention import CrossAttention, FeedForward
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange


class SyncNet(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.audio_encoder = DownEncoder2D(
            in_channels=config["audio_encoder"]["in_channels"],
            block_out_channels=config["audio_encoder"]["block_out_channels"],
            downsample_factors=config["audio_encoder"]["downsample_factors"],
            dropout=config["audio_encoder"]["dropout"],
            attn_blocks=config["audio_encoder"]["attn_blocks"],
        )

        self.visual_encoder = DownEncoder2D(
            in_channels=config["visual_encoder"]["in_channels"],
            block_out_channels=config["visual_encoder"]["block_out_channels"],
            downsample_factors=config["visual_encoder"]["downsample_factors"],
            dropout=config["visual_encoder"]["dropout"],
            attn_blocks=config["visual_encoder"]["attn_blocks"],
        )

        self.eval()

    def forward(self, image_sequences, audio_sequences):
        vision_embeds = self.visual_encoder(image_sequences)  # (b, c, 1, 1)
        audio_embeds = self.audio_encoder(audio_sequences)  # (b, c, 1, 1)

        vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1)  # (b, c)
        audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1)  # (b, c)

        # Make them unit vectors
        vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
        audio_embeds = F.normalize(audio_embeds, p=2, dim=1)

        return vision_embeds, audio_embeds


class ResnetBlock2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        eps: float = 1e-6,
        act_fn: str = "silu",
        downsample_factor=2,
    ):
        super().__init__()

        self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if act_fn == "relu":
            self.act_fn = nn.ReLU()
        elif act_fn == "silu":
            self.act_fn = nn.SiLU()

        if in_channels != out_channels:
            self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.conv_shortcut = None

        if isinstance(downsample_factor, list):
            downsample_factor = tuple(downsample_factor)

        if downsample_factor == 1:
            self.downsample_conv = None
        else:
            self.downsample_conv = nn.Conv2d(
                out_channels, out_channels, kernel_size=3, stride=downsample_factor, padding=0
            )
            self.pad = (0, 1, 0, 1)
            if isinstance(downsample_factor, tuple):
                if downsample_factor[0] == 1:
                    self.pad = (0, 1, 1, 1)  # The padding order is from back to front
                elif downsample_factor[1] == 1:
                    self.pad = (1, 1, 0, 1)

    def forward(self, input_tensor):
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.act_fn(hidden_states)

        hidden_states = self.conv1(hidden_states)
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.act_fn(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        hidden_states += input_tensor

        if self.downsample_conv is not None:
            hidden_states = F.pad(hidden_states, self.pad, mode="constant", value=0)
            hidden_states = self.downsample_conv(hidden_states)

        return hidden_states


class AttentionBlock2D(nn.Module):
    def __init__(self, query_dim, norm_num_groups=32, dropout=0.0):
        super().__init__()
        if not is_xformers_available():
            raise ModuleNotFoundError(
                "You have to install xformers to enable memory efficient attetion", name="xformers"
            )
        # inner_dim = dim_head * heads
        self.norm1 = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=query_dim, eps=1e-6, affine=True)
        self.norm2 = nn.LayerNorm(query_dim)
        self.norm3 = nn.LayerNorm(query_dim)

        self.ff = FeedForward(query_dim, dropout=dropout, activation_fn="geglu")

        self.conv_in = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
        self.conv_out = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)

        self.attn = CrossAttention(query_dim=query_dim, heads=8, dim_head=query_dim // 8, dropout=dropout, bias=True)
        self.attn._use_memory_efficient_attention_xformers = True

    def forward(self, hidden_states):
        assert hidden_states.dim() == 4, f"Expected hidden_states to have ndim=4, but got ndim={hidden_states.dim()}."

        batch, channel, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.conv_in(hidden_states)
        hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c")

        norm_hidden_states = self.norm2(hidden_states)
        hidden_states = self.attn(norm_hidden_states, attention_mask=None) + hidden_states
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=height, w=width)
        hidden_states = self.conv_out(hidden_states)

        hidden_states = hidden_states + residual
        return hidden_states


class DownEncoder2D(nn.Module):
    def __init__(
        self,
        in_channels=4 * 16,
        block_out_channels=[64, 128, 256, 256],
        downsample_factors=[2, 2, 2, 2],
        layers_per_block=2,
        norm_num_groups=32,
        attn_blocks=[1, 1, 1, 1],
        dropout: float = 0.0,
        act_fn="silu",
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        # in
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)

        # down
        self.down_blocks = nn.ModuleList([])

        output_channels = block_out_channels[0]
        for i, block_out_channel in enumerate(block_out_channels):
            input_channels = output_channels
            output_channels = block_out_channel
            # is_final_block = i == len(block_out_channels) - 1

            down_block = ResnetBlock2D(
                in_channels=input_channels,
                out_channels=output_channels,
                downsample_factor=downsample_factors[i],
                norm_num_groups=norm_num_groups,
                dropout=dropout,
                act_fn=act_fn,
            )

            self.down_blocks.append(down_block)

            if attn_blocks[i] == 1:
                attention_block = AttentionBlock2D(query_dim=output_channels, dropout=dropout)
                self.down_blocks.append(attention_block)

        # out
        self.norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.act_fn_out = nn.ReLU()

    def forward(self, hidden_states):
        hidden_states = self.conv_in(hidden_states)

        # down
        for down_block in self.down_blocks:
            hidden_states = down_block(hidden_states)

        # post-process
        hidden_states = self.norm_out(hidden_states)
        hidden_states = self.act_fn_out(hidden_states)

        return hidden_states
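SyncNet is driven entirely by the config dict handed to its constructor, and its forward pass returns L2-normalized visual and audio embeddings whose dot product is the sync score. A hedged construction sketch follows; the channel counts, downsample factors, and input sizes below are invented for illustration (the real values live in the project's SyncNet config) and are merely chosen so that both encoders reduce to a 1x1 spatial map and matching 256-dim embeddings:

import torch
from latentsync.models.syncnet import SyncNet   # assumes the repo is importable

config = {
    "visual_encoder": {
        "in_channels": 48,                        # e.g. 16 stacked frames x 3 channels (illustrative)
        "block_out_channels": [32, 64, 128, 256, 256],
        "downsample_factors": [2, 2, 2, 2, 2],    # 32x32 input -> 1x1
        "attn_blocks": [0, 0, 0, 0, 0],           # 1 would insert AttentionBlock2D (needs xformers)
        "dropout": 0.0,
    },
    "audio_encoder": {
        "in_channels": 1,
        "block_out_channels": [32, 64, 128, 256],
        "downsample_factors": [(4, 2), (4, 2), (5, 2), (1, 2)],   # 80x16 mel input -> 1x1
        "attn_blocks": [0, 0, 0, 0],
        "dropout": 0.0,
    },
}

syncnet = SyncNet(config)
frames = torch.randn(2, 48, 32, 32)               # (b, c, h, w)
mels = torch.randn(2, 1, 80, 16)                  # (b, 1, mel bins, steps)
vision_embeds, audio_embeds = syncnet(frames, mels)
sync_score = (vision_embeds * audio_embeds).sum(dim=1)   # cosine similarity in [-1, 1]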
latentsync/models/syncnet_wav2lip.py
ADDED
@@ -0,0 +1,90 @@
# Adapted from https://github.com/primepake/wav2lip_288x288/blob/master/models/syncnetv2.py
# The code here is for ablation study.

from torch import nn
from torch.nn import functional as F


class SyncNetWav2Lip(nn.Module):
    def __init__(self, act_fn="leaky"):
        super().__init__()

        # input image sequences: (15, 128, 256)
        self.visual_encoder = nn.Sequential(
            Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3, act_fn=act_fn),  # (128, 256)
            Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1, act_fn=act_fn),  # (126, 127)
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(64, 128, kernel_size=3, stride=2, padding=1, act_fn=act_fn),  # (63, 64)
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(128, 256, kernel_size=3, stride=3, padding=1, act_fn=act_fn),  # (21, 22)
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(256, 512, kernel_size=3, stride=2, padding=1, act_fn=act_fn),  # (11, 11)
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, act_fn=act_fn),  # (6, 6)
            Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1, act_fn="relu"),  # (3, 3)
            Conv2d(1024, 1024, kernel_size=3, stride=1, padding=0, act_fn="relu"),  # (1, 1)
            Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, act_fn="relu"),
        )

        # input audio sequences: (1, 80, 16)
        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1, act_fn=act_fn),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1, act_fn=act_fn),  # (27, 16)
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(64, 128, kernel_size=3, stride=3, padding=1, act_fn=act_fn),  # (9, 6)
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1, act_fn=act_fn),  # (3, 3)
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(256, 512, kernel_size=3, stride=1, padding=1, act_fn=act_fn),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
            Conv2d(512, 1024, kernel_size=3, stride=1, padding=0, act_fn="relu"),  # (1, 1)
            Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, act_fn="relu"),
        )

    def forward(self, image_sequences, audio_sequences):
        vision_embeds = self.visual_encoder(image_sequences)  # (b, c, 1, 1)
        audio_embeds = self.audio_encoder(audio_sequences)  # (b, c, 1, 1)

        vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1)  # (b, c)
        audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1)  # (b, c)

        # Make them unit vectors
        vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
        audio_embeds = F.normalize(audio_embeds, p=2, dim=1)

        return vision_embeds, audio_embeds


class Conv2d(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, act_fn="relu", *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(nn.Conv2d(cin, cout, kernel_size, stride, padding), nn.BatchNorm2d(cout))
        if act_fn == "relu":
            self.act_fn = nn.ReLU()
        elif act_fn == "tanh":
            self.act_fn = nn.Tanh()
        elif act_fn == "silu":
            self.act_fn = nn.SiLU()
        elif act_fn == "leaky":
            self.act_fn = nn.LeakyReLU(0.2, inplace=True)

        self.residual = residual

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out += x
        return self.act_fn(out)
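This ablation variant keeps Wav2Lip's fixed input geometry: a 15-channel image window (5 RGB frames stacked on the channel axis) at 128x256, and an 80x16 mel chunk. A quick shape check, assuming the repo is importable; the batch size is illustrative:

import torch
from latentsync.models.syncnet_wav2lip import SyncNetWav2Lip

model = SyncNetWav2Lip(act_fn="leaky")
frames = torch.randn(2, 15, 128, 256)    # 5 stacked RGB frames
mels = torch.randn(2, 1, 80, 16)         # mel-spectrogram window
vision_embeds, audio_embeds = model(frames, mels)
print(vision_embeds.shape, audio_embeds.shape)   # both torch.Size([2, 1024])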
latentsync/models/unet.py
ADDED
@@ -0,0 +1,528 @@
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet.py

from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import copy

import torch
import torch.nn as nn
import torch.utils.checkpoint

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin
from diffusers import UNet2DConditionModel
from diffusers.utils import BaseOutput, logging
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from .unet_blocks import (
    CrossAttnDownBlock3D,
    CrossAttnUpBlock3D,
    DownBlock3D,
    UNetMidBlock3DCrossAttn,
    UpBlock3D,
    get_down_block,
    get_up_block,
)
from .resnet import InflatedConv3d, InflatedGroupNorm

from ..utils.util import zero_rank_log
from einops import rearrange
from .utils import zero_module


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class UNet3DConditionOutput(BaseOutput):
    sample: torch.FloatTensor


class UNet3DConditionModel(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D",
        ),
        mid_block_type: str = "UNetMidBlock3DCrossAttn",
        up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        use_inflated_groupnorm=False,
        # Additional
        use_motion_module=False,
        motion_module_resolutions=(1, 2, 4, 8),
        motion_module_mid_block=False,
        motion_module_decoder_only=False,
        motion_module_type=None,
        motion_module_kwargs={},
        unet_use_cross_frame_attention=False,
        unet_use_temporal_attention=False,
        add_audio_layer=False,
        audio_condition_method: str = "cross_attn",
        custom_audio_layer=False,
    ):
        super().__init__()

        self.sample_size = sample_size
        time_embed_dim = block_out_channels[0] * 4
        self.use_motion_module = use_motion_module
        self.add_audio_layer = add_audio_layer

        self.conv_in = zero_module(InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)))

        # time
        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        else:
            self.class_embedding = None

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            res = 2**i
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                unet_use_temporal_attention=unet_use_temporal_attention,
                use_inflated_groupnorm=use_inflated_groupnorm,
                use_motion_module=use_motion_module
                and (res in motion_module_resolutions)
                and (not motion_module_decoder_only),
                motion_module_type=motion_module_type,
                motion_module_kwargs=motion_module_kwargs,
                add_audio_layer=add_audio_layer,
                audio_condition_method=audio_condition_method,
                custom_audio_layer=custom_audio_layer,
            )
            self.down_blocks.append(down_block)

        # mid
        if mid_block_type == "UNetMidBlock3DCrossAttn":
            self.mid_block = UNetMidBlock3DCrossAttn(
                in_channels=block_out_channels[-1],
                temb_channels=time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,
                unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                unet_use_temporal_attention=unet_use_temporal_attention,
                use_inflated_groupnorm=use_inflated_groupnorm,
                use_motion_module=use_motion_module and motion_module_mid_block,
                motion_module_type=motion_module_type,
                motion_module_kwargs=motion_module_kwargs,
                add_audio_layer=add_audio_layer,
                audio_condition_method=audio_condition_method,
                custom_audio_layer=custom_audio_layer,
            )
        else:
            raise ValueError(f"unknown mid_block_type : {mid_block_type}")

        # count how many layers upsample the videos
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_attention_head_dim = list(reversed(attention_head_dim))
        only_cross_attention = list(reversed(only_cross_attention))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            res = 2 ** (3 - i)
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=reversed_attention_head_dim[i],
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                unet_use_temporal_attention=unet_use_temporal_attention,
                use_inflated_groupnorm=use_inflated_groupnorm,
                use_motion_module=use_motion_module and (res in motion_module_resolutions),
                motion_module_type=motion_module_type,
                motion_module_kwargs=motion_module_kwargs,
                add_audio_layer=add_audio_layer,
                audio_condition_method=audio_condition_method,
                custom_audio_layer=custom_audio_layer,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        if use_inflated_groupnorm:
            self.conv_norm_out = InflatedGroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
            )
        else:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
            )
        self.conv_act = nn.SiLU()

        self.conv_out = zero_module(InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1))

    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_slicable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_slicable_dims(module)

        num_slicable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_slicable_layers * [1]

        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
            module.gradient_checkpointing = value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        # support controlnet
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet3DConditionOutput, Tuple]:
        r"""
        Args:
            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # pre-process
        sample = self.conv_in(sample)

        # down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                )
            else:
                sample, res_samples = downsample_block(
                    hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
                )

            down_block_res_samples += res_samples

        # support controlnet
        down_block_res_samples = list(down_block_res_samples)
        if down_block_additional_residuals is not None:
            for i, down_block_additional_residual in enumerate(down_block_additional_residuals):
                if down_block_additional_residual.dim() == 4:  # boardcast
                    down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
                down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual

        # mid
        sample = self.mid_block(
            sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
        )

        # support controlnet
        if mid_block_additional_residual is not None:
            if mid_block_additional_residual.dim() == 4:  # boardcast
                mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2)
            sample = sample + mid_block_additional_residual

        # up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    encoder_hidden_states=encoder_hidden_states,
                )

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet3DConditionOutput(sample=sample)

    def load_state_dict(self, state_dict, strict=True):
        # If the loaded checkpoint's in_channels or out_channels are different from config
        temp_state_dict = copy.deepcopy(state_dict)
        if temp_state_dict["conv_in.weight"].shape[1] != self.config.in_channels:
            del temp_state_dict["conv_in.weight"]
            del temp_state_dict["conv_in.bias"]
        if temp_state_dict["conv_out.weight"].shape[0] != self.config.out_channels:
            del temp_state_dict["conv_out.weight"]
            del temp_state_dict["conv_out.bias"]

        # If the loaded checkpoint's cross_attention_dim is different from config
        keys_to_remove = []
        for key in temp_state_dict:
            if "audio_cross_attn.attn.to_k." in key or "audio_cross_attn.attn.to_v." in key:
                if temp_state_dict[key].shape[1] != self.config.cross_attention_dim:
                    keys_to_remove.append(key)

        for key in keys_to_remove:
            del temp_state_dict[key]

        return super().load_state_dict(state_dict=temp_state_dict, strict=strict)

    @classmethod
    def from_pretrained(cls, model_config: dict, ckpt_path: str, device="cpu"):
        unet = cls.from_config(model_config).to(device)
        if ckpt_path != "":
            zero_rank_log(logger, f"Load from checkpoint: {ckpt_path}")
            ckpt = torch.load(ckpt_path, map_location=device)
            if "global_step" in ckpt:
                zero_rank_log(logger, f"resume from global_step: {ckpt['global_step']}")
                resume_global_step = ckpt["global_step"]
            else:
                resume_global_step = 0
            state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
            unet.load_state_dict(state_dict, strict=False)
        else:
            resume_global_step = 0

        return unet, resume_global_step
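Note that from_pretrained here is the project's own loader (a config dict plus a checkpoint path), not the diffusers hub loader: it returns a (unet, resume_global_step) pair, and the overridden load_state_dict drops conv_in/conv_out and audio cross-attention projection weights whose shapes disagree with the config. A hedged usage sketch, assuming the OmegaConf-style config the training and inference scripts use; the YAML and checkpoint paths below are illustrative placeholders, not verified repo paths:

import torch
from omegaconf import OmegaConf
from latentsync.models.unet import UNet3DConditionModel   # assumes the repo is importable

config = OmegaConf.load("configs/unet/unet_config.yaml")   # hypothetical path
unet, resume_global_step = UNet3DConditionModel.from_pretrained(
    OmegaConf.to_container(config.model),        # dict of constructor kwargs for the UNet
    ckpt_path="checkpoints/latentsync_unet.pt",  # hypothetical path; "" keeps random init
    device="cpu",
)
unet.eval()
print("resumed at global step", resume_global_step)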
latentsync/models/unet_blocks.py
ADDED
@@ -0,0 +1,903 @@
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet_blocks.py

import torch
from torch import nn

from .attention import Transformer3DModel
from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
from .motion_module import get_motion_module


def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    unet_use_cross_frame_attention=False,
    unet_use_temporal_attention=False,
    use_inflated_groupnorm=False,
    use_motion_module=None,
    motion_module_type=None,
    motion_module_kwargs=None,
    add_audio_layer=False,
    audio_condition_method="cross_attn",
    custom_audio_layer=False,
):
    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
    if down_block_type == "DownBlock3D":
        return DownBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
            use_inflated_groupnorm=use_inflated_groupnorm,
            use_motion_module=use_motion_module,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )
    elif down_block_type == "CrossAttnDownBlock3D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
        return CrossAttnDownBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            unet_use_cross_frame_attention=unet_use_cross_frame_attention,
            unet_use_temporal_attention=unet_use_temporal_attention,
            use_inflated_groupnorm=use_inflated_groupnorm,
            use_motion_module=use_motion_module,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
            add_audio_layer=add_audio_layer,
            audio_condition_method=audio_condition_method,
            custom_audio_layer=custom_audio_layer,
        )
    raise ValueError(f"{down_block_type} does not exist.")


def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    unet_use_cross_frame_attention=False,
    unet_use_temporal_attention=False,
    use_inflated_groupnorm=False,
    use_motion_module=None,
    motion_module_type=None,
    motion_module_kwargs=None,
    add_audio_layer=False,
    audio_condition_method="cross_attn",
    custom_audio_layer=False,
):
    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
    if up_block_type == "UpBlock3D":
        return UpBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
            use_inflated_groupnorm=use_inflated_groupnorm,
            use_motion_module=use_motion_module,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )
    elif up_block_type == "CrossAttnUpBlock3D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
        return CrossAttnUpBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            unet_use_cross_frame_attention=unet_use_cross_frame_attention,
            unet_use_temporal_attention=unet_use_temporal_attention,
            use_inflated_groupnorm=use_inflated_groupnorm,
            use_motion_module=use_motion_module,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
            add_audio_layer=add_audio_layer,
            audio_condition_method=audio_condition_method,
            custom_audio_layer=custom_audio_layer,
        )
    raise ValueError(f"{up_block_type} does not exist.")


class UNetMidBlock3DCrossAttn(nn.Module):
    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        dual_cross_attention=False,
        use_linear_projection=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=False,
        unet_use_temporal_attention=False,
        use_inflated_groupnorm=False,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
        add_audio_layer=False,
        audio_condition_method="cross_attn",
        custom_audio_layer: bool = False,
    ):
        super().__init__()

        self.has_cross_attention = True
|
200 |
+
self.attn_num_head_channels = attn_num_head_channels
|
201 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
202 |
+
|
203 |
+
# there is always at least one resnet
|
204 |
+
resnets = [
|
205 |
+
ResnetBlock3D(
|
206 |
+
in_channels=in_channels,
|
207 |
+
out_channels=in_channels,
|
208 |
+
temb_channels=temb_channels,
|
209 |
+
eps=resnet_eps,
|
210 |
+
groups=resnet_groups,
|
211 |
+
dropout=dropout,
|
212 |
+
time_embedding_norm=resnet_time_scale_shift,
|
213 |
+
non_linearity=resnet_act_fn,
|
214 |
+
output_scale_factor=output_scale_factor,
|
215 |
+
pre_norm=resnet_pre_norm,
|
216 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
217 |
+
)
|
218 |
+
]
|
219 |
+
attentions = []
|
220 |
+
audio_attentions = []
|
221 |
+
motion_modules = []
|
222 |
+
|
223 |
+
for _ in range(num_layers):
|
224 |
+
if dual_cross_attention:
|
225 |
+
raise NotImplementedError
|
226 |
+
attentions.append(
|
227 |
+
Transformer3DModel(
|
228 |
+
attn_num_head_channels,
|
229 |
+
in_channels // attn_num_head_channels,
|
230 |
+
in_channels=in_channels,
|
231 |
+
num_layers=1,
|
232 |
+
cross_attention_dim=cross_attention_dim,
|
233 |
+
norm_num_groups=resnet_groups,
|
234 |
+
use_linear_projection=use_linear_projection,
|
235 |
+
upcast_attention=upcast_attention,
|
236 |
+
use_motion_module=use_motion_module,
|
237 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
238 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
239 |
+
add_audio_layer=add_audio_layer,
|
240 |
+
audio_condition_method=audio_condition_method,
|
241 |
+
)
|
242 |
+
)
|
243 |
+
audio_attentions.append(
|
244 |
+
Transformer3DModel(
|
245 |
+
attn_num_head_channels,
|
246 |
+
in_channels // attn_num_head_channels,
|
247 |
+
in_channels=in_channels,
|
248 |
+
num_layers=1,
|
249 |
+
cross_attention_dim=cross_attention_dim,
|
250 |
+
norm_num_groups=resnet_groups,
|
251 |
+
use_linear_projection=use_linear_projection,
|
252 |
+
upcast_attention=upcast_attention,
|
253 |
+
use_motion_module=use_motion_module,
|
254 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
255 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
256 |
+
add_audio_layer=add_audio_layer,
|
257 |
+
audio_condition_method=audio_condition_method,
|
258 |
+
custom_audio_layer=True,
|
259 |
+
)
|
260 |
+
if custom_audio_layer
|
261 |
+
else None
|
262 |
+
)
|
263 |
+
motion_modules.append(
|
264 |
+
get_motion_module(
|
265 |
+
in_channels=in_channels,
|
266 |
+
motion_module_type=motion_module_type,
|
267 |
+
motion_module_kwargs=motion_module_kwargs,
|
268 |
+
)
|
269 |
+
if use_motion_module
|
270 |
+
else None
|
271 |
+
)
|
272 |
+
resnets.append(
|
273 |
+
ResnetBlock3D(
|
274 |
+
in_channels=in_channels,
|
275 |
+
out_channels=in_channels,
|
276 |
+
temb_channels=temb_channels,
|
277 |
+
eps=resnet_eps,
|
278 |
+
groups=resnet_groups,
|
279 |
+
dropout=dropout,
|
280 |
+
time_embedding_norm=resnet_time_scale_shift,
|
281 |
+
non_linearity=resnet_act_fn,
|
282 |
+
output_scale_factor=output_scale_factor,
|
283 |
+
pre_norm=resnet_pre_norm,
|
284 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
285 |
+
)
|
286 |
+
)
|
287 |
+
|
288 |
+
self.attentions = nn.ModuleList(attentions)
|
289 |
+
self.audio_attentions = nn.ModuleList(audio_attentions)
|
290 |
+
self.resnets = nn.ModuleList(resnets)
|
291 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
292 |
+
|
293 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
294 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
295 |
+
for attn, audio_attn, resnet, motion_module in zip(
|
296 |
+
self.attentions, self.audio_attentions, self.resnets[1:], self.motion_modules
|
297 |
+
):
|
298 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
299 |
+
hidden_states = (
|
300 |
+
audio_attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
301 |
+
if audio_attn is not None
|
302 |
+
else hidden_states
|
303 |
+
)
|
304 |
+
hidden_states = (
|
305 |
+
motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
|
306 |
+
if motion_module is not None
|
307 |
+
else hidden_states
|
308 |
+
)
|
309 |
+
hidden_states = resnet(hidden_states, temb)
|
310 |
+
|
311 |
+
return hidden_states
|
312 |
+
|
313 |
+
|
314 |
+
class CrossAttnDownBlock3D(nn.Module):
|
315 |
+
def __init__(
|
316 |
+
self,
|
317 |
+
in_channels: int,
|
318 |
+
out_channels: int,
|
319 |
+
temb_channels: int,
|
320 |
+
dropout: float = 0.0,
|
321 |
+
num_layers: int = 1,
|
322 |
+
resnet_eps: float = 1e-6,
|
323 |
+
resnet_time_scale_shift: str = "default",
|
324 |
+
resnet_act_fn: str = "swish",
|
325 |
+
resnet_groups: int = 32,
|
326 |
+
resnet_pre_norm: bool = True,
|
327 |
+
attn_num_head_channels=1,
|
328 |
+
cross_attention_dim=1280,
|
329 |
+
output_scale_factor=1.0,
|
330 |
+
downsample_padding=1,
|
331 |
+
add_downsample=True,
|
332 |
+
dual_cross_attention=False,
|
333 |
+
use_linear_projection=False,
|
334 |
+
only_cross_attention=False,
|
335 |
+
upcast_attention=False,
|
336 |
+
unet_use_cross_frame_attention=False,
|
337 |
+
unet_use_temporal_attention=False,
|
338 |
+
use_inflated_groupnorm=False,
|
339 |
+
use_motion_module=None,
|
340 |
+
motion_module_type=None,
|
341 |
+
motion_module_kwargs=None,
|
342 |
+
add_audio_layer=False,
|
343 |
+
audio_condition_method="cross_attn",
|
344 |
+
custom_audio_layer: bool = False,
|
345 |
+
):
|
346 |
+
super().__init__()
|
347 |
+
resnets = []
|
348 |
+
attentions = []
|
349 |
+
audio_attentions = []
|
350 |
+
motion_modules = []
|
351 |
+
|
352 |
+
self.has_cross_attention = True
|
353 |
+
self.attn_num_head_channels = attn_num_head_channels
|
354 |
+
|
355 |
+
for i in range(num_layers):
|
356 |
+
in_channels = in_channels if i == 0 else out_channels
|
357 |
+
resnets.append(
|
358 |
+
ResnetBlock3D(
|
359 |
+
in_channels=in_channels,
|
360 |
+
out_channels=out_channels,
|
361 |
+
temb_channels=temb_channels,
|
362 |
+
eps=resnet_eps,
|
363 |
+
groups=resnet_groups,
|
364 |
+
dropout=dropout,
|
365 |
+
time_embedding_norm=resnet_time_scale_shift,
|
366 |
+
non_linearity=resnet_act_fn,
|
367 |
+
output_scale_factor=output_scale_factor,
|
368 |
+
pre_norm=resnet_pre_norm,
|
369 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
370 |
+
)
|
371 |
+
)
|
372 |
+
if dual_cross_attention:
|
373 |
+
raise NotImplementedError
|
374 |
+
attentions.append(
|
375 |
+
Transformer3DModel(
|
376 |
+
attn_num_head_channels,
|
377 |
+
out_channels // attn_num_head_channels,
|
378 |
+
in_channels=out_channels,
|
379 |
+
num_layers=1,
|
380 |
+
cross_attention_dim=cross_attention_dim,
|
381 |
+
norm_num_groups=resnet_groups,
|
382 |
+
use_linear_projection=use_linear_projection,
|
383 |
+
only_cross_attention=only_cross_attention,
|
384 |
+
upcast_attention=upcast_attention,
|
385 |
+
use_motion_module=use_motion_module,
|
386 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
387 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
388 |
+
add_audio_layer=add_audio_layer,
|
389 |
+
audio_condition_method=audio_condition_method,
|
390 |
+
)
|
391 |
+
)
|
392 |
+
audio_attentions.append(
|
393 |
+
Transformer3DModel(
|
394 |
+
attn_num_head_channels,
|
395 |
+
out_channels // attn_num_head_channels,
|
396 |
+
in_channels=out_channels,
|
397 |
+
num_layers=1,
|
398 |
+
cross_attention_dim=cross_attention_dim,
|
399 |
+
norm_num_groups=resnet_groups,
|
400 |
+
use_linear_projection=use_linear_projection,
|
401 |
+
only_cross_attention=only_cross_attention,
|
402 |
+
upcast_attention=upcast_attention,
|
403 |
+
use_motion_module=use_motion_module,
|
404 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
405 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
406 |
+
add_audio_layer=add_audio_layer,
|
407 |
+
audio_condition_method=audio_condition_method,
|
408 |
+
custom_audio_layer=True,
|
409 |
+
)
|
410 |
+
if custom_audio_layer
|
411 |
+
else None
|
412 |
+
)
|
413 |
+
motion_modules.append(
|
414 |
+
get_motion_module(
|
415 |
+
in_channels=out_channels,
|
416 |
+
motion_module_type=motion_module_type,
|
417 |
+
motion_module_kwargs=motion_module_kwargs,
|
418 |
+
)
|
419 |
+
if use_motion_module
|
420 |
+
else None
|
421 |
+
)
|
422 |
+
|
423 |
+
self.attentions = nn.ModuleList(attentions)
|
424 |
+
self.audio_attentions = nn.ModuleList(audio_attentions)
|
425 |
+
self.resnets = nn.ModuleList(resnets)
|
426 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
427 |
+
|
428 |
+
if add_downsample:
|
429 |
+
self.downsamplers = nn.ModuleList(
|
430 |
+
[
|
431 |
+
Downsample3D(
|
432 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
433 |
+
)
|
434 |
+
]
|
435 |
+
)
|
436 |
+
else:
|
437 |
+
self.downsamplers = None
|
438 |
+
|
439 |
+
self.gradient_checkpointing = False
|
440 |
+
|
441 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
442 |
+
output_states = ()
|
443 |
+
|
444 |
+
for resnet, attn, audio_attn, motion_module in zip(
|
445 |
+
self.resnets, self.attentions, self.audio_attentions, self.motion_modules
|
446 |
+
):
|
447 |
+
if self.training and self.gradient_checkpointing:
|
448 |
+
|
449 |
+
def create_custom_forward(module, return_dict=None):
|
450 |
+
def custom_forward(*inputs):
|
451 |
+
if return_dict is not None:
|
452 |
+
return module(*inputs, return_dict=return_dict)
|
453 |
+
else:
|
454 |
+
return module(*inputs)
|
455 |
+
|
456 |
+
return custom_forward
|
457 |
+
|
458 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
459 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
460 |
+
create_custom_forward(attn, return_dict=False),
|
461 |
+
hidden_states,
|
462 |
+
encoder_hidden_states,
|
463 |
+
)[0]
|
464 |
+
if motion_module is not None:
|
465 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
466 |
+
create_custom_forward(motion_module),
|
467 |
+
hidden_states.requires_grad_(),
|
468 |
+
temb,
|
469 |
+
encoder_hidden_states,
|
470 |
+
)
|
471 |
+
|
472 |
+
else:
|
473 |
+
hidden_states = resnet(hidden_states, temb)
|
474 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
475 |
+
|
476 |
+
hidden_states = (
|
477 |
+
audio_attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
478 |
+
if audio_attn is not None
|
479 |
+
else hidden_states
|
480 |
+
)
|
481 |
+
|
482 |
+
# add motion module
|
483 |
+
hidden_states = (
|
484 |
+
motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
|
485 |
+
if motion_module is not None
|
486 |
+
else hidden_states
|
487 |
+
)
|
488 |
+
|
489 |
+
output_states += (hidden_states,)
|
490 |
+
|
491 |
+
if self.downsamplers is not None:
|
492 |
+
for downsampler in self.downsamplers:
|
493 |
+
hidden_states = downsampler(hidden_states)
|
494 |
+
|
495 |
+
output_states += (hidden_states,)
|
496 |
+
|
497 |
+
return hidden_states, output_states
|
498 |
+
|
499 |
+
|
500 |
+
class DownBlock3D(nn.Module):
|
501 |
+
def __init__(
|
502 |
+
self,
|
503 |
+
in_channels: int,
|
504 |
+
out_channels: int,
|
505 |
+
temb_channels: int,
|
506 |
+
dropout: float = 0.0,
|
507 |
+
num_layers: int = 1,
|
508 |
+
resnet_eps: float = 1e-6,
|
509 |
+
resnet_time_scale_shift: str = "default",
|
510 |
+
resnet_act_fn: str = "swish",
|
511 |
+
resnet_groups: int = 32,
|
512 |
+
resnet_pre_norm: bool = True,
|
513 |
+
output_scale_factor=1.0,
|
514 |
+
add_downsample=True,
|
515 |
+
downsample_padding=1,
|
516 |
+
use_inflated_groupnorm=False,
|
517 |
+
use_motion_module=None,
|
518 |
+
motion_module_type=None,
|
519 |
+
motion_module_kwargs=None,
|
520 |
+
):
|
521 |
+
super().__init__()
|
522 |
+
resnets = []
|
523 |
+
motion_modules = []
|
524 |
+
|
525 |
+
for i in range(num_layers):
|
526 |
+
in_channels = in_channels if i == 0 else out_channels
|
527 |
+
resnets.append(
|
528 |
+
ResnetBlock3D(
|
529 |
+
in_channels=in_channels,
|
530 |
+
out_channels=out_channels,
|
531 |
+
temb_channels=temb_channels,
|
532 |
+
eps=resnet_eps,
|
533 |
+
groups=resnet_groups,
|
534 |
+
dropout=dropout,
|
535 |
+
time_embedding_norm=resnet_time_scale_shift,
|
536 |
+
non_linearity=resnet_act_fn,
|
537 |
+
output_scale_factor=output_scale_factor,
|
538 |
+
pre_norm=resnet_pre_norm,
|
539 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
540 |
+
)
|
541 |
+
)
|
542 |
+
motion_modules.append(
|
543 |
+
get_motion_module(
|
544 |
+
in_channels=out_channels,
|
545 |
+
motion_module_type=motion_module_type,
|
546 |
+
motion_module_kwargs=motion_module_kwargs,
|
547 |
+
)
|
548 |
+
if use_motion_module
|
549 |
+
else None
|
550 |
+
)
|
551 |
+
|
552 |
+
self.resnets = nn.ModuleList(resnets)
|
553 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
554 |
+
|
555 |
+
if add_downsample:
|
556 |
+
self.downsamplers = nn.ModuleList(
|
557 |
+
[
|
558 |
+
Downsample3D(
|
559 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
560 |
+
)
|
561 |
+
]
|
562 |
+
)
|
563 |
+
else:
|
564 |
+
self.downsamplers = None
|
565 |
+
|
566 |
+
self.gradient_checkpointing = False
|
567 |
+
|
568 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
569 |
+
output_states = ()
|
570 |
+
|
571 |
+
for resnet, motion_module in zip(self.resnets, self.motion_modules):
|
572 |
+
if self.training and self.gradient_checkpointing:
|
573 |
+
|
574 |
+
def create_custom_forward(module):
|
575 |
+
def custom_forward(*inputs):
|
576 |
+
return module(*inputs)
|
577 |
+
|
578 |
+
return custom_forward
|
579 |
+
|
580 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
581 |
+
if motion_module is not None:
|
582 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
583 |
+
create_custom_forward(motion_module),
|
584 |
+
hidden_states.requires_grad_(),
|
585 |
+
temb,
|
586 |
+
encoder_hidden_states,
|
587 |
+
)
|
588 |
+
else:
|
589 |
+
hidden_states = resnet(hidden_states, temb)
|
590 |
+
|
591 |
+
# add motion module
|
592 |
+
hidden_states = (
|
593 |
+
motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
|
594 |
+
if motion_module is not None
|
595 |
+
else hidden_states
|
596 |
+
)
|
597 |
+
|
598 |
+
output_states += (hidden_states,)
|
599 |
+
|
600 |
+
if self.downsamplers is not None:
|
601 |
+
for downsampler in self.downsamplers:
|
602 |
+
hidden_states = downsampler(hidden_states)
|
603 |
+
|
604 |
+
output_states += (hidden_states,)
|
605 |
+
|
606 |
+
return hidden_states, output_states
|
607 |
+
|
608 |
+
|
609 |
+
class CrossAttnUpBlock3D(nn.Module):
|
610 |
+
def __init__(
|
611 |
+
self,
|
612 |
+
in_channels: int,
|
613 |
+
out_channels: int,
|
614 |
+
prev_output_channel: int,
|
615 |
+
temb_channels: int,
|
616 |
+
dropout: float = 0.0,
|
617 |
+
num_layers: int = 1,
|
618 |
+
resnet_eps: float = 1e-6,
|
619 |
+
resnet_time_scale_shift: str = "default",
|
620 |
+
resnet_act_fn: str = "swish",
|
621 |
+
resnet_groups: int = 32,
|
622 |
+
resnet_pre_norm: bool = True,
|
623 |
+
attn_num_head_channels=1,
|
624 |
+
cross_attention_dim=1280,
|
625 |
+
output_scale_factor=1.0,
|
626 |
+
add_upsample=True,
|
627 |
+
dual_cross_attention=False,
|
628 |
+
use_linear_projection=False,
|
629 |
+
only_cross_attention=False,
|
630 |
+
upcast_attention=False,
|
631 |
+
unet_use_cross_frame_attention=False,
|
632 |
+
unet_use_temporal_attention=False,
|
633 |
+
use_inflated_groupnorm=False,
|
634 |
+
use_motion_module=None,
|
635 |
+
motion_module_type=None,
|
636 |
+
motion_module_kwargs=None,
|
637 |
+
add_audio_layer=False,
|
638 |
+
audio_condition_method="cross_attn",
|
639 |
+
custom_audio_layer=False,
|
640 |
+
):
|
641 |
+
super().__init__()
|
642 |
+
resnets = []
|
643 |
+
attentions = []
|
644 |
+
audio_attentions = []
|
645 |
+
motion_modules = []
|
646 |
+
|
647 |
+
self.has_cross_attention = True
|
648 |
+
self.attn_num_head_channels = attn_num_head_channels
|
649 |
+
|
650 |
+
for i in range(num_layers):
|
651 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
652 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
653 |
+
|
654 |
+
resnets.append(
|
655 |
+
ResnetBlock3D(
|
656 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
657 |
+
out_channels=out_channels,
|
658 |
+
temb_channels=temb_channels,
|
659 |
+
eps=resnet_eps,
|
660 |
+
groups=resnet_groups,
|
661 |
+
dropout=dropout,
|
662 |
+
time_embedding_norm=resnet_time_scale_shift,
|
663 |
+
non_linearity=resnet_act_fn,
|
664 |
+
output_scale_factor=output_scale_factor,
|
665 |
+
pre_norm=resnet_pre_norm,
|
666 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
667 |
+
)
|
668 |
+
)
|
669 |
+
if dual_cross_attention:
|
670 |
+
raise NotImplementedError
|
671 |
+
attentions.append(
|
672 |
+
Transformer3DModel(
|
673 |
+
attn_num_head_channels,
|
674 |
+
out_channels // attn_num_head_channels,
|
675 |
+
in_channels=out_channels,
|
676 |
+
num_layers=1,
|
677 |
+
cross_attention_dim=cross_attention_dim,
|
678 |
+
norm_num_groups=resnet_groups,
|
679 |
+
use_linear_projection=use_linear_projection,
|
680 |
+
only_cross_attention=only_cross_attention,
|
681 |
+
upcast_attention=upcast_attention,
|
682 |
+
use_motion_module=use_motion_module,
|
683 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
684 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
685 |
+
add_audio_layer=add_audio_layer,
|
686 |
+
audio_condition_method=audio_condition_method,
|
687 |
+
)
|
688 |
+
)
|
689 |
+
audio_attentions.append(
|
690 |
+
Transformer3DModel(
|
691 |
+
attn_num_head_channels,
|
692 |
+
out_channels // attn_num_head_channels,
|
693 |
+
in_channels=out_channels,
|
694 |
+
num_layers=1,
|
695 |
+
cross_attention_dim=cross_attention_dim,
|
696 |
+
norm_num_groups=resnet_groups,
|
697 |
+
use_linear_projection=use_linear_projection,
|
698 |
+
only_cross_attention=only_cross_attention,
|
699 |
+
upcast_attention=upcast_attention,
|
700 |
+
use_motion_module=use_motion_module,
|
701 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
702 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
703 |
+
add_audio_layer=add_audio_layer,
|
704 |
+
audio_condition_method=audio_condition_method,
|
705 |
+
custom_audio_layer=True,
|
706 |
+
)
|
707 |
+
if custom_audio_layer
|
708 |
+
else None
|
709 |
+
)
|
710 |
+
motion_modules.append(
|
711 |
+
get_motion_module(
|
712 |
+
in_channels=out_channels,
|
713 |
+
motion_module_type=motion_module_type,
|
714 |
+
motion_module_kwargs=motion_module_kwargs,
|
715 |
+
)
|
716 |
+
if use_motion_module
|
717 |
+
else None
|
718 |
+
)
|
719 |
+
|
720 |
+
self.attentions = nn.ModuleList(attentions)
|
721 |
+
self.audio_attentions = nn.ModuleList(audio_attentions)
|
722 |
+
self.resnets = nn.ModuleList(resnets)
|
723 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
724 |
+
|
725 |
+
if add_upsample:
|
726 |
+
self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
|
727 |
+
else:
|
728 |
+
self.upsamplers = None
|
729 |
+
|
730 |
+
self.gradient_checkpointing = False
|
731 |
+
|
732 |
+
def forward(
|
733 |
+
self,
|
734 |
+
hidden_states,
|
735 |
+
res_hidden_states_tuple,
|
736 |
+
temb=None,
|
737 |
+
encoder_hidden_states=None,
|
738 |
+
upsample_size=None,
|
739 |
+
attention_mask=None,
|
740 |
+
):
|
741 |
+
for resnet, attn, audio_attn, motion_module in zip(
|
742 |
+
self.resnets, self.attentions, self.audio_attentions, self.motion_modules
|
743 |
+
):
|
744 |
+
# pop res hidden states
|
745 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
746 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
747 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
748 |
+
|
749 |
+
if self.training and self.gradient_checkpointing:
|
750 |
+
|
751 |
+
def create_custom_forward(module, return_dict=None):
|
752 |
+
def custom_forward(*inputs):
|
753 |
+
if return_dict is not None:
|
754 |
+
return module(*inputs, return_dict=return_dict)
|
755 |
+
else:
|
756 |
+
return module(*inputs)
|
757 |
+
|
758 |
+
return custom_forward
|
759 |
+
|
760 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
761 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
762 |
+
create_custom_forward(attn, return_dict=False),
|
763 |
+
hidden_states,
|
764 |
+
encoder_hidden_states,
|
765 |
+
)[0]
|
766 |
+
if motion_module is not None:
|
767 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
768 |
+
create_custom_forward(motion_module),
|
769 |
+
hidden_states.requires_grad_(),
|
770 |
+
temb,
|
771 |
+
encoder_hidden_states,
|
772 |
+
)
|
773 |
+
|
774 |
+
else:
|
775 |
+
hidden_states = resnet(hidden_states, temb)
|
776 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
777 |
+
hidden_states = (
|
778 |
+
audio_attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
779 |
+
if audio_attn is not None
|
780 |
+
else hidden_states
|
781 |
+
)
|
782 |
+
|
783 |
+
# add motion module
|
784 |
+
hidden_states = (
|
785 |
+
motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
|
786 |
+
if motion_module is not None
|
787 |
+
else hidden_states
|
788 |
+
)
|
789 |
+
|
790 |
+
if self.upsamplers is not None:
|
791 |
+
for upsampler in self.upsamplers:
|
792 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
793 |
+
|
794 |
+
return hidden_states
|
795 |
+
|
796 |
+
|
797 |
+
class UpBlock3D(nn.Module):
|
798 |
+
def __init__(
|
799 |
+
self,
|
800 |
+
in_channels: int,
|
801 |
+
prev_output_channel: int,
|
802 |
+
out_channels: int,
|
803 |
+
temb_channels: int,
|
804 |
+
dropout: float = 0.0,
|
805 |
+
num_layers: int = 1,
|
806 |
+
resnet_eps: float = 1e-6,
|
807 |
+
resnet_time_scale_shift: str = "default",
|
808 |
+
resnet_act_fn: str = "swish",
|
809 |
+
resnet_groups: int = 32,
|
810 |
+
resnet_pre_norm: bool = True,
|
811 |
+
output_scale_factor=1.0,
|
812 |
+
add_upsample=True,
|
813 |
+
use_inflated_groupnorm=False,
|
814 |
+
use_motion_module=None,
|
815 |
+
motion_module_type=None,
|
816 |
+
motion_module_kwargs=None,
|
817 |
+
):
|
818 |
+
super().__init__()
|
819 |
+
resnets = []
|
820 |
+
motion_modules = []
|
821 |
+
|
822 |
+
for i in range(num_layers):
|
823 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
824 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
825 |
+
|
826 |
+
resnets.append(
|
827 |
+
ResnetBlock3D(
|
828 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
829 |
+
out_channels=out_channels,
|
830 |
+
temb_channels=temb_channels,
|
831 |
+
eps=resnet_eps,
|
832 |
+
groups=resnet_groups,
|
833 |
+
dropout=dropout,
|
834 |
+
time_embedding_norm=resnet_time_scale_shift,
|
835 |
+
non_linearity=resnet_act_fn,
|
836 |
+
output_scale_factor=output_scale_factor,
|
837 |
+
pre_norm=resnet_pre_norm,
|
838 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
839 |
+
)
|
840 |
+
)
|
841 |
+
motion_modules.append(
|
842 |
+
get_motion_module(
|
843 |
+
in_channels=out_channels,
|
844 |
+
motion_module_type=motion_module_type,
|
845 |
+
motion_module_kwargs=motion_module_kwargs,
|
846 |
+
)
|
847 |
+
if use_motion_module
|
848 |
+
else None
|
849 |
+
)
|
850 |
+
|
851 |
+
self.resnets = nn.ModuleList(resnets)
|
852 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
853 |
+
|
854 |
+
if add_upsample:
|
855 |
+
self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
|
856 |
+
else:
|
857 |
+
self.upsamplers = None
|
858 |
+
|
859 |
+
self.gradient_checkpointing = False
|
860 |
+
|
861 |
+
def forward(
|
862 |
+
self,
|
863 |
+
hidden_states,
|
864 |
+
res_hidden_states_tuple,
|
865 |
+
temb=None,
|
866 |
+
upsample_size=None,
|
867 |
+
encoder_hidden_states=None,
|
868 |
+
):
|
869 |
+
for resnet, motion_module in zip(self.resnets, self.motion_modules):
|
870 |
+
# pop res hidden states
|
871 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
872 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
873 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
874 |
+
|
875 |
+
if self.training and self.gradient_checkpointing:
|
876 |
+
|
877 |
+
def create_custom_forward(module):
|
878 |
+
def custom_forward(*inputs):
|
879 |
+
return module(*inputs)
|
880 |
+
|
881 |
+
return custom_forward
|
882 |
+
|
883 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
884 |
+
if motion_module is not None:
|
885 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
886 |
+
create_custom_forward(motion_module),
|
887 |
+
hidden_states.requires_grad_(),
|
888 |
+
temb,
|
889 |
+
encoder_hidden_states,
|
890 |
+
)
|
891 |
+
else:
|
892 |
+
hidden_states = resnet(hidden_states, temb)
|
893 |
+
hidden_states = (
|
894 |
+
motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
|
895 |
+
if motion_module is not None
|
896 |
+
else hidden_states
|
897 |
+
)
|
898 |
+
|
899 |
+
if self.upsamplers is not None:
|
900 |
+
for upsampler in self.upsamplers:
|
901 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
902 |
+
|
903 |
+
return hidden_states
|
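Note: these block factories are normally driven by the UNet3DConditionModel constructor rather than called directly. The following is a minimal sketch, not code from this diff, of how get_down_block could be exercised on its own; the channel widths, head count, and cross-attention dimension are illustrative assumptions, not values taken from the repository's configs.

# Minimal sketch (assumed values) of building one down block the way the UNet would.
from latentsync.models.unet_blocks import get_down_block

down_block = get_down_block(
    down_block_type="CrossAttnDownBlock3D",
    num_layers=2,
    in_channels=320,          # assumed channel width
    out_channels=640,         # assumed channel width
    temb_channels=1280,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    resnet_groups=32,
    downsample_padding=1,
    cross_attention_dim=384,  # assumed audio-embedding width
    attn_num_head_channels=8,
    add_audio_layer=True,
)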
latentsync/models/utils.py
ADDED
@@ -0,0 +1,19 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

def zero_module(module):
    # Zero out the parameters of a module and return it.
    for p in module.parameters():
        p.detach().zero_()
    return module
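zero_module is the usual trick for initializing a newly added projection so that a fresh branch starts as a no-op and only contributes once training moves its weights. A minimal usage sketch follows; the conv layer here is hypothetical, not a layer from this repository.

import torch.nn as nn
from latentsync.models.utils import zero_module

# Hypothetical output projection that starts at zero, so the branch it feeds
# adds nothing to the residual stream at initialization.
proj_out = zero_module(nn.Conv2d(320, 320, kernel_size=1))
assert all((p == 0).all() for p in proj_out.parameters())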
latentsync/pipelines/lipsync_pipeline.py
ADDED
@@ -0,0 +1,470 @@
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/pipelines/pipeline_animation.py

import inspect
import os
import shutil
from typing import Callable, List, Optional, Union
import subprocess

import numpy as np
import torch
import torchvision

from diffusers.utils import is_accelerate_available
from packaging import version

from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)
from diffusers.utils import deprecate, logging

from einops import rearrange

from ..models.unet import UNet3DConditionModel
from ..utils.image_processor import ImageProcessor
from ..utils.util import read_video, read_audio, write_video
from ..whisper.audio2feature import Audio2Feature
import tqdm
import soundfile as sf

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class LipsyncPipeline(DiffusionPipeline):
    _optional_components = []

    def __init__(
        self,
        vae: AutoencoderKL,
        audio_encoder: Audio2Feature,
        unet: UNet3DConditionModel,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
    ):
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            audio_encoder=audio_encoder,
            unet=unet,
            scheduler=scheduler,
        )

        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

        self.set_progress_bar_config(desc="Steps")

    def enable_vae_slicing(self):
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        self.vae.disable_slicing()

    def enable_sequential_cpu_offload(self, gpu_id=0):
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    def _execution_device(self):
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def decode_latents(self, latents):
        latents = latents / self.vae.config.scaling_factor + self.vae.config.shift_factor
        latents = rearrange(latents, "b c f h w -> (b f) c h w")
        decoded_latents = self.vae.decode(latents).sample
        return decoded_latents

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(self, height, width, callback_steps):
        assert height == width, "Height and width must be equal"

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    def prepare_latents(self, batch_size, num_frames, num_channels_latents, height, width, dtype, device, generator):
        shape = (
            batch_size,
            num_channels_latents,
            1,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        rand_device = "cpu" if device.type == "mps" else device
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
        latents = latents.repeat(1, 1, num_frames, 1, 1)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def prepare_mask_latents(
        self, mask, masked_image, height, width, dtype, device, generator, do_classifier_free_guidance
    ):
        # resize the mask to latents shape as we concatenate the mask to the latents
        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
        # and half precision
        mask = torch.nn.functional.interpolate(
            mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
        )
        masked_image = masked_image.to(device=device, dtype=dtype)

        # encode the mask image into latents space so we can concatenate it to the latents
        masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
        masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor

        # aligning device to prevent device errors when concating it with the latent model input
        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
        mask = mask.to(device=device, dtype=dtype)

        # assume batch size = 1
        mask = rearrange(mask, "f c h w -> 1 c f h w")
        masked_image_latents = rearrange(masked_image_latents, "f c h w -> 1 c f h w")

        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
        masked_image_latents = (
            torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
        )
        return mask, masked_image_latents

    def prepare_image_latents(self, images, device, dtype, generator, do_classifier_free_guidance):
        images = images.to(device=device, dtype=dtype)
        image_latents = self.vae.encode(images).latent_dist.sample(generator=generator)
        image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
        image_latents = rearrange(image_latents, "f c h w -> 1 c f h w")
        image_latents = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents

        return image_latents

    def set_progress_bar_config(self, **kwargs):
        if not hasattr(self, "_progress_bar_config"):
            self._progress_bar_config = {}
        self._progress_bar_config.update(kwargs)

    @staticmethod
    def paste_surrounding_pixels_back(decoded_latents, pixel_values, masks, device, weight_dtype):
        # Paste the surrounding pixels back, because we only want to change the mouth region
        pixel_values = pixel_values.to(device=device, dtype=weight_dtype)
        masks = masks.to(device=device, dtype=weight_dtype)
        combined_pixel_values = decoded_latents * masks + pixel_values * (1 - masks)
        return combined_pixel_values

    @staticmethod
    def pixel_values_to_images(pixel_values: torch.Tensor):
        pixel_values = rearrange(pixel_values, "f c h w -> f h w c")
        pixel_values = (pixel_values / 2 + 0.5).clamp(0, 1)
        images = (pixel_values * 255).to(torch.uint8)
        images = images.cpu().numpy()
        return images

    def affine_transform_video(self, video_path):
        video_frames = read_video(video_path, use_decord=False)
        faces = []
        boxes = []
        affine_matrices = []
        print(f"Affine transforming {len(video_frames)} faces...")
        for frame in tqdm.tqdm(video_frames):
            face, box, affine_matrix = self.image_processor.affine_transform(frame)
            faces.append(face)
            boxes.append(box)
            affine_matrices.append(affine_matrix)

        faces = torch.stack(faces)
        return faces, video_frames, boxes, affine_matrices

    def restore_video(self, faces, video_frames, boxes, affine_matrices):
        video_frames = video_frames[: faces.shape[0]]
        out_frames = []
        for index, face in enumerate(faces):
            x1, y1, x2, y2 = boxes[index]
            height = int(y2 - y1)
            width = int(x2 - x1)
            face = torchvision.transforms.functional.resize(face, size=(height, width), antialias=True)
            face = rearrange(face, "c h w -> h w c")
            face = (face / 2 + 0.5).clamp(0, 1)
            face = (face * 255).to(torch.uint8).cpu().numpy()
            out_frame = self.image_processor.restorer.restore_img(video_frames[index], face, affine_matrices[index])
            out_frames.append(out_frame)
        return np.stack(out_frames, axis=0)

    @torch.no_grad()
    def __call__(
        self,
        video_path: str,
        audio_path: str,
        video_out_path: str,
        video_mask_path: str = None,
        num_frames: int = 16,
        video_fps: int = 25,
        audio_sample_rate: int = 16000,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 20,
        guidance_scale: float = 1.5,
        weight_dtype: Optional[torch.dtype] = torch.float16,
        eta: float = 0.0,
        mask: str = "fix_mask",
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        is_train = self.unet.training
        self.unet.eval()

        # 0. Define call parameters
        batch_size = 1
        device = self._execution_device
        self.image_processor = ImageProcessor(height, mask=mask, device="cuda")
        self.set_progress_bar_config(desc=f"Sample frames: {num_frames}")

        video_frames, original_video_frames, boxes, affine_matrices = self.affine_transform_video(video_path)
        audio_samples = read_audio(audio_path)

        # 1. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 2. Check inputs
        self.check_inputs(height, width, callback_steps)

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 4. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        self.video_fps = video_fps

        if self.unet.add_audio_layer:
            whisper_feature = self.audio_encoder.audio2feat(audio_path)
            whisper_chunks = self.audio_encoder.feature2chunks(feature_array=whisper_feature, fps=video_fps)

            num_inferences = min(len(video_frames), len(whisper_chunks)) // num_frames
        else:
            num_inferences = len(video_frames) // num_frames

        synced_video_frames = []
        masked_video_frames = []

        num_channels_latents = self.vae.config.latent_channels

        # Prepare latent variables
        all_latents = self.prepare_latents(
            batch_size,
            num_frames * num_inferences,
            num_channels_latents,
            height,
            width,
            weight_dtype,
            device,
            generator,
        )

        for i in tqdm.tqdm(range(num_inferences), desc="Doing inference..."):
            if self.unet.add_audio_layer:
                audio_embeds = torch.stack(whisper_chunks[i * num_frames : (i + 1) * num_frames])
                audio_embeds = audio_embeds.to(device, dtype=weight_dtype)
                if do_classifier_free_guidance:
                    empty_audio_embeds = torch.zeros_like(audio_embeds)
                    audio_embeds = torch.cat([empty_audio_embeds, audio_embeds])
            else:
                audio_embeds = None
            inference_video_frames = video_frames[i * num_frames : (i + 1) * num_frames]
            latents = all_latents[:, :, i * num_frames : (i + 1) * num_frames]
            pixel_values, masked_pixel_values, masks = self.image_processor.prepare_masks_and_masked_images(
                inference_video_frames, affine_transform=False
            )

            # 7. Prepare mask latent variables
            mask_latents, masked_image_latents = self.prepare_mask_latents(
                masks,
                masked_pixel_values,
                height,
                width,
                weight_dtype,
                device,
                generator,
                do_classifier_free_guidance,
            )

            # 8. Prepare image latents
            image_latents = self.prepare_image_latents(
                pixel_values,
                device,
                weight_dtype,
                generator,
                do_classifier_free_guidance,
            )

            # 9. Denoising loop
            num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
            with self.progress_bar(total=num_inference_steps) as progress_bar:
                for j, t in enumerate(timesteps):
                    # expand the latents if we are doing classifier free guidance
                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

                    # concat latents, mask, masked_image_latents in the channel dimension
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                    latent_model_input = torch.cat(
                        [latent_model_input, mask_latents, masked_image_latents, image_latents], dim=1
                    )

                    # predict the noise residual
                    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=audio_embeds).sample

                    # perform guidance
                    if do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_audio = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_audio - noise_pred_uncond)

                    # compute the previous noisy sample x_t -> x_t-1
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                    # call the callback, if provided
                    if j == len(timesteps) - 1 or ((j + 1) > num_warmup_steps and (j + 1) % self.scheduler.order == 0):
                        progress_bar.update()
                        if callback is not None and j % callback_steps == 0:
                            callback(j, t, latents)

            # Recover the pixel values
            decoded_latents = self.decode_latents(latents)
            decoded_latents = self.paste_surrounding_pixels_back(
                decoded_latents, pixel_values, 1 - masks, device, weight_dtype
            )
            synced_video_frames.append(decoded_latents)
            masked_video_frames.append(masked_pixel_values)

        synced_video_frames = self.restore_video(
            torch.cat(synced_video_frames), original_video_frames, boxes, affine_matrices
        )
        masked_video_frames = self.restore_video(
            torch.cat(masked_video_frames), original_video_frames, boxes, affine_matrices
        )

        audio_samples_remain_length = int(synced_video_frames.shape[0] / video_fps * audio_sample_rate)
        audio_samples = audio_samples[:audio_samples_remain_length].cpu().numpy()

        if is_train:
            self.unet.train()

        temp_dir = "temp"
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
        os.makedirs(temp_dir, exist_ok=True)

        write_video(os.path.join(temp_dir, "video.mp4"), synced_video_frames, fps=25)
        # write_video(video_mask_path, masked_video_frames, fps=25)

        sf.write(os.path.join(temp_dir, "audio.wav"), audio_samples, audio_sample_rate)

        command = f"ffmpeg -y -loglevel error -nostdin -i {os.path.join(temp_dir, 'video.mp4')} -i {os.path.join(temp_dir, 'audio.wav')} -c:v libx264 -c:a aac -q:v 0 -q:a 0 {video_out_path}"
        subprocess.run(command, shell=True)
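For reference, a typical invocation of LipsyncPipeline mirrors the __call__ signature above. This is a minimal sketch, not taken from this diff: the input/output paths are hypothetical, and it assumes the vae, audio_encoder, unet, and scheduler components have already been loaded the way the repository's inference scripts do elsewhere.

import torch
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline

# Assumes `vae`, `audio_encoder`, `unet`, and `scheduler` are already constructed.
pipeline = LipsyncPipeline(vae=vae, audio_encoder=audio_encoder, unet=unet, scheduler=scheduler).to("cuda")

pipeline(
    video_path="assets/demo_video.mp4",   # hypothetical input paths
    audio_path="assets/demo_audio.wav",
    video_out_path="video_out.mp4",
    num_frames=16,
    height=256,                           # assumed face-crop resolution
    width=256,
    num_inference_steps=20,
    guidance_scale=1.5,
    weight_dtype=torch.float16,
)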
latentsync/trepa/__init__.py
ADDED
@@ -0,0 +1,64 @@
1 + # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
2 + #
3 + # Licensed under the Apache License, Version 2.0 (the "License");
4 + # you may not use this file except in compliance with the License.
5 + # You may obtain a copy of the License at
6 + #
7 + #     http://www.apache.org/licenses/LICENSE-2.0
8 + #
9 + # Unless required by applicable law or agreed to in writing, software
10 + # distributed under the License is distributed on an "AS IS" BASIS,
11 + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 + # See the License for the specific language governing permissions and
13 + # limitations under the License.
14 +
15 + import torch
16 + import torch.nn.functional as F
17 + import torch.nn as nn
18 + from einops import rearrange
19 + from .third_party.VideoMAEv2.utils import load_videomae_model
20 +
21 +
22 + class TREPALoss:
23 +     def __init__(
24 +         self,
25 +         device="cuda",
26 +         ckpt_path="/mnt/bn/maliva-gen-ai-v2/chunyu.li/checkpoints/vit_g_hybrid_pt_1200e_ssv2_ft.pth",
27 +     ):
28 +         self.model = load_videomae_model(device, ckpt_path).eval().to(dtype=torch.float16)
29 +         self.model.requires_grad_(False)
30 +         self.bce_loss = nn.BCELoss()
31 +
32 +     def __call__(self, videos_fake, videos_real, loss_type="mse"):
33 +         batch_size = videos_fake.shape[0]
34 +         num_frames = videos_fake.shape[2]
35 +         videos_fake = rearrange(videos_fake.clone(), "b c f h w -> (b f) c h w")
36 +         videos_real = rearrange(videos_real.clone(), "b c f h w -> (b f) c h w")
37 +
38 +         videos_fake = F.interpolate(videos_fake, size=(224, 224), mode="bilinear")
39 +         videos_real = F.interpolate(videos_real, size=(224, 224), mode="bilinear")
40 +
41 +         videos_fake = rearrange(videos_fake, "(b f) c h w -> b c f h w", f=num_frames)
42 +         videos_real = rearrange(videos_real, "(b f) c h w -> b c f h w", f=num_frames)
43 +
44 +         # Because input pixel range is [-1, 1], and model expects pixel range to be [0, 1]
45 +         videos_fake = (videos_fake / 2 + 0.5).clamp(0, 1)
46 +         videos_real = (videos_real / 2 + 0.5).clamp(0, 1)
47 +
48 +         feats_fake = self.model.forward_features(videos_fake)
49 +         feats_real = self.model.forward_features(videos_real)
50 +
51 +         feats_fake = F.normalize(feats_fake, p=2, dim=1)
52 +         feats_real = F.normalize(feats_real, p=2, dim=1)
53 +
54 +         return F.mse_loss(feats_fake, feats_real)
55 +
56 +
57 + if __name__ == "__main__":
58 +     # input shape: (b, c, f, h, w)
59 +     videos_fake = torch.randn(2, 3, 16, 256, 256, requires_grad=True).to(device="cuda", dtype=torch.float16)
60 +     videos_real = torch.randn(2, 3, 16, 256, 256, requires_grad=True).to(device="cuda", dtype=torch.float16)
61 +
62 +     trepa_loss = TREPALoss(device="cuda")
63 +     loss = trepa_loss(videos_fake, videos_real)
64 +     print(loss)
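A minimal usage sketch for TREPALoss inside a training step (hypothetical tensors and loss weight, assuming videos shaped (b, c, f, h, w) in [-1, 1] as above; not part of the diff):

    import torch
    import torch.nn.functional as F
    from latentsync.trepa import TREPALoss

    # ckpt_path=None lets load_videomae_model fetch the VideoMAEv2 giant checkpoint itself.
    trepa_loss_fn = TREPALoss(device="cuda", ckpt_path=None)

    def combined_loss(pred_videos, target_videos, trepa_weight=10.0):
        # trepa_weight is an illustrative value, not the repository's setting.
        pixel_loss = F.l1_loss(pred_videos.float(), target_videos.float())
        perceptual_loss = trepa_loss_fn(pred_videos, target_videos)
        return pixel_loss + trepa_weight * perceptual_loss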
latentsync/trepa/third_party/VideoMAEv2/__init__.py
ADDED
File without changes
|
latentsync/trepa/third_party/VideoMAEv2/utils.py
ADDED
@@ -0,0 +1,81 @@
1 + import os
2 + import torch
3 + import requests
4 + from tqdm import tqdm
5 + from torchvision import transforms
6 + from .videomaev2_finetune import vit_giant_patch14_224
7 +
8 + def to_normalized_float_tensor(vid):
9 +     return vid.permute(3, 0, 1, 2).to(torch.float32) / 255
10 +
11 +
12 + # NOTE: for those functions, which generally expect mini-batches, we keep them
13 + # as non-minibatch so that they are applied as if they were 4d (thus image).
14 + # this way, we only apply the transformation in the spatial domain
15 + def resize(vid, size, interpolation='bilinear'):
16 +     # NOTE: using bilinear interpolation because we don't work on minibatches
17 +     # at this level
18 +     scale = None
19 +     if isinstance(size, int):
20 +         scale = float(size) / min(vid.shape[-2:])
21 +         size = None
22 +     return torch.nn.functional.interpolate(
23 +         vid,
24 +         size=size,
25 +         scale_factor=scale,
26 +         mode=interpolation,
27 +         align_corners=False)
28 +
29 +
30 + class ToFloatTensorInZeroOne(object):
31 +     def __call__(self, vid):
32 +         return to_normalized_float_tensor(vid)
33 +
34 +
35 + class Resize(object):
36 +     def __init__(self, size):
37 +         self.size = size
38 +     def __call__(self, vid):
39 +         return resize(vid, self.size)
40 +
41 + def preprocess_videomae(videos):
42 +     transform = transforms.Compose(
43 +         [ToFloatTensorInZeroOne(),
44 +          Resize((224, 224))])
45 +     return torch.stack([transform(f) for f in torch.from_numpy(videos)])
46 +
47 +
48 + def load_videomae_model(device, ckpt_path=None):
49 +     if ckpt_path is None:
50 +         current_dir = os.path.dirname(os.path.abspath(__file__))
51 +         ckpt_path = os.path.join(current_dir, 'vit_g_hybrid_pt_1200e_ssv2_ft.pth')
52 +
53 +     if not os.path.exists(ckpt_path):
54 +         # download the ckpt to the path
55 +         ckpt_url = 'https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/internvideo/videomaev2/vit_g_hybrid_pt_1200e_ssv2_ft.pth'
56 +         response = requests.get(ckpt_url, stream=True, allow_redirects=True)
57 +         total_size = int(response.headers.get("content-length", 0))
58 +         block_size = 1024
59 +
60 +         with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
61 +             with open(ckpt_path, "wb") as fw:
62 +                 for data in response.iter_content(block_size):
63 +                     progress_bar.update(len(data))
64 +                     fw.write(data)
65 +
66 +     model = vit_giant_patch14_224(
67 +         img_size=224,
68 +         pretrained=False,
69 +         num_classes=174,
70 +         all_frames=16,
71 +         tubelet_size=2,
72 +         drop_path_rate=0.3,
73 +         use_mean_pooling=True)
74 +
75 +     ckpt = torch.load(ckpt_path, map_location='cpu')
76 +     for model_key in ['model', 'module']:
77 +         if model_key in ckpt:
78 +             ckpt = ckpt[model_key]
79 +             break
80 +     model.load_state_dict(ckpt)
81 +     return model.to(device)
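A small sketch of standalone feature extraction with these helpers (dummy uint8 frames; assumes each clip is stored frames-height-width-channels, which is what to_normalized_float_tensor expects; downloading the checkpoint happens on first use):

    import numpy as np
    import torch

    from latentsync.trepa.third_party.VideoMAEv2.utils import (
        load_videomae_model,
        preprocess_videomae,
    )

    # 2 clips x 16 RGB frames, uint8, THWC per clip
    videos = np.random.randint(0, 256, size=(2, 16, 256, 256, 3), dtype=np.uint8)

    clips = preprocess_videomae(videos)       # (2, 3, 16, 224, 224), floats in [0, 1]
    model = load_videomae_model("cuda").eval()
    with torch.no_grad():
        feats = model.forward_features(clips.to("cuda"))  # (2, 1408) pooled features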
latentsync/trepa/third_party/VideoMAEv2/videomaev2_finetune.py
ADDED
@@ -0,0 +1,539 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# Based on BEiT, timm, DINO and DeiT code bases
|
3 |
+
# https://github.com/microsoft/unilm/tree/master/beit
|
4 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
5 |
+
# https://github.com/facebookresearch/deit
|
6 |
+
# https://github.com/facebookresearch/dino
|
7 |
+
# --------------------------------------------------------'
|
8 |
+
from functools import partial
|
9 |
+
|
10 |
+
import math
|
11 |
+
import warnings
|
12 |
+
import numpy as np
|
13 |
+
import collections.abc
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
import torch.utils.checkpoint as cp
|
18 |
+
from itertools import repeat
|
19 |
+
|
20 |
+
|
21 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
22 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
23 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
24 |
+
def norm_cdf(x):
|
25 |
+
# Computes standard normal cumulative distribution function
|
26 |
+
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
27 |
+
|
28 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
29 |
+
warnings.warn(
|
30 |
+
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
31 |
+
"The distribution of values may be incorrect.",
|
32 |
+
stacklevel=2,
|
33 |
+
)
|
34 |
+
|
35 |
+
with torch.no_grad():
|
36 |
+
# Values are generated by using a truncated uniform distribution and
|
37 |
+
# then using the inverse CDF for the normal distribution.
|
38 |
+
# Get upper and lower cdf values
|
39 |
+
l = norm_cdf((a - mean) / std)
|
40 |
+
u = norm_cdf((b - mean) / std)
|
41 |
+
|
42 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
43 |
+
# [2l-1, 2u-1].
|
44 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
45 |
+
|
46 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
47 |
+
# standard normal
|
48 |
+
tensor.erfinv_()
|
49 |
+
|
50 |
+
# Transform to proper mean, std
|
51 |
+
tensor.mul_(std * math.sqrt(2.0))
|
52 |
+
tensor.add_(mean)
|
53 |
+
|
54 |
+
# Clamp to ensure it's in the proper range
|
55 |
+
tensor.clamp_(min=a, max=b)
|
56 |
+
return tensor
|
57 |
+
|
58 |
+
|
59 |
+
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
|
60 |
+
r"""Fills the input Tensor with values drawn from a truncated
|
61 |
+
normal distribution. The values are effectively drawn from the
|
62 |
+
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
63 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
64 |
+
the bounds. The method used for generating the random values works
|
65 |
+
best when :math:`a \leq \text{mean} \leq b`.
|
66 |
+
Args:
|
67 |
+
tensor: an n-dimensional `torch.Tensor`
|
68 |
+
mean: the mean of the normal distribution
|
69 |
+
std: the standard deviation of the normal distribution
|
70 |
+
a: the minimum cutoff value
|
71 |
+
b: the maximum cutoff value
|
72 |
+
Examples:
|
73 |
+
>>> w = torch.empty(3, 5)
|
74 |
+
>>> nn.init.trunc_normal_(w)
|
75 |
+
"""
|
76 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
77 |
+
|
78 |
+
|
79 |
+
def _ntuple(n):
|
80 |
+
def parse(x):
|
81 |
+
if isinstance(x, collections.abc.Iterable):
|
82 |
+
return x
|
83 |
+
return tuple(repeat(x, n))
|
84 |
+
|
85 |
+
return parse
|
86 |
+
|
87 |
+
|
88 |
+
to_2tuple = _ntuple(2)
|
89 |
+
|
90 |
+
|
91 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
92 |
+
"""
|
93 |
+
Adapted from timm codebase
|
94 |
+
"""
|
95 |
+
if drop_prob == 0.0 or not training:
|
96 |
+
return x
|
97 |
+
keep_prob = 1 - drop_prob
|
98 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
99 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
100 |
+
random_tensor.floor_() # binarize
|
101 |
+
output = x.div(keep_prob) * random_tensor
|
102 |
+
return output
|
103 |
+
|
104 |
+
|
105 |
+
def _cfg(url="", **kwargs):
|
106 |
+
return {
|
107 |
+
"url": url,
|
108 |
+
"num_classes": 400,
|
109 |
+
"input_size": (3, 224, 224),
|
110 |
+
"pool_size": None,
|
111 |
+
"crop_pct": 0.9,
|
112 |
+
"interpolation": "bicubic",
|
113 |
+
"mean": (0.5, 0.5, 0.5),
|
114 |
+
"std": (0.5, 0.5, 0.5),
|
115 |
+
**kwargs,
|
116 |
+
}
|
117 |
+
|
118 |
+
|
119 |
+
class DropPath(nn.Module):
|
120 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
121 |
+
|
122 |
+
def __init__(self, drop_prob=None):
|
123 |
+
super(DropPath, self).__init__()
|
124 |
+
self.drop_prob = drop_prob
|
125 |
+
|
126 |
+
def forward(self, x):
|
127 |
+
return drop_path(x, self.drop_prob, self.training)
|
128 |
+
|
129 |
+
def extra_repr(self) -> str:
|
130 |
+
return "p={}".format(self.drop_prob)
|
131 |
+
|
132 |
+
|
133 |
+
class Mlp(nn.Module):
|
134 |
+
|
135 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
|
136 |
+
super().__init__()
|
137 |
+
out_features = out_features or in_features
|
138 |
+
hidden_features = hidden_features or in_features
|
139 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
140 |
+
self.act = act_layer()
|
141 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
142 |
+
self.drop = nn.Dropout(drop)
|
143 |
+
|
144 |
+
def forward(self, x):
|
145 |
+
x = self.fc1(x)
|
146 |
+
x = self.act(x)
|
147 |
+
# x = self.drop(x)
|
148 |
+
# commented out, following the original BERT implementation
|
149 |
+
x = self.fc2(x)
|
150 |
+
x = self.drop(x)
|
151 |
+
return x
|
152 |
+
|
153 |
+
|
154 |
+
class CosAttention(nn.Module):
|
155 |
+
|
156 |
+
def __init__(
|
157 |
+
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, attn_head_dim=None
|
158 |
+
):
|
159 |
+
super().__init__()
|
160 |
+
self.num_heads = num_heads
|
161 |
+
head_dim = dim // num_heads
|
162 |
+
if attn_head_dim is not None:
|
163 |
+
head_dim = attn_head_dim
|
164 |
+
all_head_dim = head_dim * self.num_heads
|
165 |
+
# self.scale = qk_scale or head_dim**-0.5
|
166 |
+
# DO NOT RENAME [self.scale] (for no weight decay)
|
167 |
+
if qk_scale is None:
|
168 |
+
self.scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
|
169 |
+
else:
|
170 |
+
self.scale = qk_scale
|
171 |
+
|
172 |
+
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
173 |
+
if qkv_bias:
|
174 |
+
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
175 |
+
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
176 |
+
else:
|
177 |
+
self.q_bias = None
|
178 |
+
self.v_bias = None
|
179 |
+
|
180 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
181 |
+
self.proj = nn.Linear(all_head_dim, dim)
|
182 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
183 |
+
|
184 |
+
def forward(self, x):
|
185 |
+
B, N, C = x.shape
|
186 |
+
qkv_bias = None
|
187 |
+
if self.q_bias is not None:
|
188 |
+
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
189 |
+
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
190 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
191 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
192 |
+
|
193 |
+
attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
|
194 |
+
|
195 |
+
# torch.log(torch.tensor(1. / 0.01)) = 4.6052
|
196 |
+
logit_scale = torch.clamp(self.scale, max=4.6052).exp()
|
197 |
+
|
198 |
+
attn = attn * logit_scale
|
199 |
+
|
200 |
+
attn = attn.softmax(dim=-1)
|
201 |
+
attn = self.attn_drop(attn)
|
202 |
+
|
203 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
204 |
+
|
205 |
+
x = self.proj(x)
|
206 |
+
x = self.proj_drop(x)
|
207 |
+
return x
|
208 |
+
|
209 |
+
|
210 |
+
class Attention(nn.Module):
|
211 |
+
|
212 |
+
def __init__(
|
213 |
+
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, attn_head_dim=None
|
214 |
+
):
|
215 |
+
super().__init__()
|
216 |
+
self.num_heads = num_heads
|
217 |
+
head_dim = dim // num_heads
|
218 |
+
if attn_head_dim is not None:
|
219 |
+
head_dim = attn_head_dim
|
220 |
+
all_head_dim = head_dim * self.num_heads
|
221 |
+
self.scale = qk_scale or head_dim**-0.5
|
222 |
+
|
223 |
+
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
224 |
+
if qkv_bias:
|
225 |
+
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
226 |
+
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
227 |
+
else:
|
228 |
+
self.q_bias = None
|
229 |
+
self.v_bias = None
|
230 |
+
|
231 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
232 |
+
self.proj = nn.Linear(all_head_dim, dim)
|
233 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
234 |
+
|
235 |
+
def forward(self, x):
|
236 |
+
B, N, C = x.shape
|
237 |
+
qkv_bias = None
|
238 |
+
if self.q_bias is not None:
|
239 |
+
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
240 |
+
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
241 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
242 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
243 |
+
|
244 |
+
q = q * self.scale
|
245 |
+
attn = q @ k.transpose(-2, -1)
|
246 |
+
|
247 |
+
attn = attn.softmax(dim=-1)
|
248 |
+
attn = self.attn_drop(attn)
|
249 |
+
|
250 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
251 |
+
|
252 |
+
x = self.proj(x)
|
253 |
+
x = self.proj_drop(x)
|
254 |
+
return x
|
255 |
+
|
256 |
+
|
257 |
+
class Block(nn.Module):
|
258 |
+
|
259 |
+
def __init__(
|
260 |
+
self,
|
261 |
+
dim,
|
262 |
+
num_heads,
|
263 |
+
mlp_ratio=4.0,
|
264 |
+
qkv_bias=False,
|
265 |
+
qk_scale=None,
|
266 |
+
drop=0.0,
|
267 |
+
attn_drop=0.0,
|
268 |
+
drop_path=0.0,
|
269 |
+
init_values=None,
|
270 |
+
act_layer=nn.GELU,
|
271 |
+
norm_layer=nn.LayerNorm,
|
272 |
+
attn_head_dim=None,
|
273 |
+
cos_attn=False,
|
274 |
+
):
|
275 |
+
super().__init__()
|
276 |
+
self.norm1 = norm_layer(dim)
|
277 |
+
if cos_attn:
|
278 |
+
self.attn = CosAttention(
|
279 |
+
dim,
|
280 |
+
num_heads=num_heads,
|
281 |
+
qkv_bias=qkv_bias,
|
282 |
+
qk_scale=qk_scale,
|
283 |
+
attn_drop=attn_drop,
|
284 |
+
proj_drop=drop,
|
285 |
+
attn_head_dim=attn_head_dim,
|
286 |
+
)
|
287 |
+
else:
|
288 |
+
self.attn = Attention(
|
289 |
+
dim,
|
290 |
+
num_heads=num_heads,
|
291 |
+
qkv_bias=qkv_bias,
|
292 |
+
qk_scale=qk_scale,
|
293 |
+
attn_drop=attn_drop,
|
294 |
+
proj_drop=drop,
|
295 |
+
attn_head_dim=attn_head_dim,
|
296 |
+
)
|
297 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
298 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
299 |
+
self.norm2 = norm_layer(dim)
|
300 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
301 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
302 |
+
|
303 |
+
if init_values is not None and init_values > 0:
|
304 |
+
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
|
305 |
+
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
|
306 |
+
else:
|
307 |
+
self.gamma_1, self.gamma_2 = None, None
|
308 |
+
|
309 |
+
def forward(self, x):
|
310 |
+
if self.gamma_1 is None:
|
311 |
+
x = x + self.drop_path(self.attn(self.norm1(x)))
|
312 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
313 |
+
else:
|
314 |
+
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
|
315 |
+
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
316 |
+
return x
|
317 |
+
|
318 |
+
|
319 |
+
class PatchEmbed(nn.Module):
|
320 |
+
"""Image to Patch Embedding"""
|
321 |
+
|
322 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2):
|
323 |
+
super().__init__()
|
324 |
+
img_size = to_2tuple(img_size)
|
325 |
+
patch_size = to_2tuple(patch_size)
|
326 |
+
num_spatial_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
|
327 |
+
num_patches = num_spatial_patches * (num_frames // tubelet_size)
|
328 |
+
|
329 |
+
self.img_size = img_size
|
330 |
+
self.tubelet_size = tubelet_size
|
331 |
+
self.patch_size = patch_size
|
332 |
+
self.num_patches = num_patches
|
333 |
+
self.proj = nn.Conv3d(
|
334 |
+
in_channels=in_chans,
|
335 |
+
out_channels=embed_dim,
|
336 |
+
kernel_size=(self.tubelet_size, patch_size[0], patch_size[1]),
|
337 |
+
stride=(self.tubelet_size, patch_size[0], patch_size[1]),
|
338 |
+
)
|
339 |
+
|
340 |
+
def forward(self, x, **kwargs):
|
341 |
+
B, C, T, H, W = x.shape
|
342 |
+
assert (
|
343 |
+
H == self.img_size[0] and W == self.img_size[1]
|
344 |
+
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
345 |
+
# b, c, l -> b, l, c
|
346 |
+
# [1, 1408, 8, 16, 16] -> [1, 1408, 2048] -> [1, 2048, 1408]
|
347 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
348 |
+
return x
|
349 |
+
|
350 |
+
|
351 |
+
# sin-cos position encoding
|
352 |
+
# https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
|
353 |
+
def get_sinusoid_encoding_table(n_position, d_hid):
|
354 |
+
"""Sinusoid position encoding table"""
|
355 |
+
|
356 |
+
# TODO: make it with torch instead of numpy
|
357 |
+
def get_position_angle_vec(position):
|
358 |
+
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
|
359 |
+
|
360 |
+
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
361 |
+
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
362 |
+
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
363 |
+
|
364 |
+
return torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)
|
365 |
+
|
366 |
+
|
367 |
+
class VisionTransformer(nn.Module):
|
368 |
+
"""Vision Transformer with support for patch or hybrid CNN input stage"""
|
369 |
+
|
370 |
+
def __init__(
|
371 |
+
self,
|
372 |
+
img_size=224,
|
373 |
+
patch_size=16,
|
374 |
+
in_chans=3,
|
375 |
+
num_classes=1000,
|
376 |
+
embed_dim=768,
|
377 |
+
depth=12,
|
378 |
+
num_heads=12,
|
379 |
+
mlp_ratio=4.0,
|
380 |
+
qkv_bias=False,
|
381 |
+
qk_scale=None,
|
382 |
+
drop_rate=0.0,
|
383 |
+
attn_drop_rate=0.0,
|
384 |
+
drop_path_rate=0.0,
|
385 |
+
head_drop_rate=0.0,
|
386 |
+
norm_layer=nn.LayerNorm,
|
387 |
+
init_values=0.0,
|
388 |
+
use_learnable_pos_emb=False,
|
389 |
+
init_scale=0.0,
|
390 |
+
all_frames=16,
|
391 |
+
tubelet_size=2,
|
392 |
+
use_mean_pooling=True,
|
393 |
+
with_cp=False,
|
394 |
+
cos_attn=False,
|
395 |
+
):
|
396 |
+
super().__init__()
|
397 |
+
self.num_classes = num_classes
|
398 |
+
# num_features for consistency with other models
|
399 |
+
self.num_features = self.embed_dim = embed_dim
|
400 |
+
self.tubelet_size = tubelet_size
|
401 |
+
self.patch_embed = PatchEmbed(
|
402 |
+
img_size=img_size,
|
403 |
+
patch_size=patch_size,
|
404 |
+
in_chans=in_chans,
|
405 |
+
embed_dim=embed_dim,
|
406 |
+
num_frames=all_frames,
|
407 |
+
tubelet_size=tubelet_size,
|
408 |
+
)
|
409 |
+
num_patches = self.patch_embed.num_patches
|
410 |
+
self.with_cp = with_cp
|
411 |
+
|
412 |
+
if use_learnable_pos_emb:
|
413 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
414 |
+
else:
|
415 |
+
# fixed sine-cosine positional embeddings are used instead
|
416 |
+
self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)
|
417 |
+
|
418 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
419 |
+
|
420 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
421 |
+
self.blocks = nn.ModuleList(
|
422 |
+
[
|
423 |
+
Block(
|
424 |
+
dim=embed_dim,
|
425 |
+
num_heads=num_heads,
|
426 |
+
mlp_ratio=mlp_ratio,
|
427 |
+
qkv_bias=qkv_bias,
|
428 |
+
qk_scale=qk_scale,
|
429 |
+
drop=drop_rate,
|
430 |
+
attn_drop=attn_drop_rate,
|
431 |
+
drop_path=dpr[i],
|
432 |
+
norm_layer=norm_layer,
|
433 |
+
init_values=init_values,
|
434 |
+
cos_attn=cos_attn,
|
435 |
+
)
|
436 |
+
for i in range(depth)
|
437 |
+
]
|
438 |
+
)
|
439 |
+
self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
|
440 |
+
self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
|
441 |
+
self.head_dropout = nn.Dropout(head_drop_rate)
|
442 |
+
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
443 |
+
|
444 |
+
if use_learnable_pos_emb:
|
445 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
446 |
+
|
447 |
+
self.apply(self._init_weights)
|
448 |
+
|
449 |
+
self.head.weight.data.mul_(init_scale)
|
450 |
+
self.head.bias.data.mul_(init_scale)
|
451 |
+
self.num_frames = all_frames
|
452 |
+
|
453 |
+
def _init_weights(self, m):
|
454 |
+
if isinstance(m, nn.Linear):
|
455 |
+
trunc_normal_(m.weight, std=0.02)
|
456 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
457 |
+
nn.init.constant_(m.bias, 0)
|
458 |
+
elif isinstance(m, nn.LayerNorm):
|
459 |
+
nn.init.constant_(m.bias, 0)
|
460 |
+
nn.init.constant_(m.weight, 1.0)
|
461 |
+
|
462 |
+
def get_num_layers(self):
|
463 |
+
return len(self.blocks)
|
464 |
+
|
465 |
+
@torch.jit.ignore
|
466 |
+
def no_weight_decay(self):
|
467 |
+
return {"pos_embed", "cls_token"}
|
468 |
+
|
469 |
+
def get_classifier(self):
|
470 |
+
return self.head
|
471 |
+
|
472 |
+
def reset_classifier(self, num_classes, global_pool=""):
|
473 |
+
self.num_classes = num_classes
|
474 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
475 |
+
|
476 |
+
def interpolate_pos_encoding(self, t):
|
477 |
+
T = 8
|
478 |
+
t0 = t // self.tubelet_size
|
479 |
+
if T == t0:
|
480 |
+
return self.pos_embed
|
481 |
+
dim = self.pos_embed.shape[-1]
|
482 |
+
patch_pos_embed = self.pos_embed.permute(0, 2, 1).reshape(1, dim, 8, 16, 16)
|
483 |
+
# we add a small number to avoid floating point error in the interpolation
|
484 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
485 |
+
t0 = t0 + 0.1
|
486 |
+
patch_pos_embed = nn.functional.interpolate(
|
487 |
+
patch_pos_embed,
|
488 |
+
scale_factor=(t0 / T, 1, 1),
|
489 |
+
mode="trilinear",
|
490 |
+
)
|
491 |
+
assert int(t0) == patch_pos_embed.shape[-3]
|
492 |
+
patch_pos_embed = patch_pos_embed.reshape(1, dim, -1).permute(0, 2, 1)
|
493 |
+
return patch_pos_embed
|
494 |
+
|
495 |
+
def forward_features(self, x):
|
496 |
+
# [1, 3, 16, 224, 224]
|
497 |
+
B = x.size(0)
|
498 |
+
T = x.size(2)
|
499 |
+
|
500 |
+
# [1, 2048, 1408]
|
501 |
+
x = self.patch_embed(x)
|
502 |
+
|
503 |
+
if self.pos_embed is not None:
|
504 |
+
x = x + self.interpolate_pos_encoding(T).expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
|
505 |
+
x = self.pos_drop(x)
|
506 |
+
|
507 |
+
for blk in self.blocks:
|
508 |
+
if self.with_cp:
|
509 |
+
x = cp.checkpoint(blk, x)
|
510 |
+
else:
|
511 |
+
x = blk(x)
|
512 |
+
|
513 |
+
# return self.fc_norm(x)
|
514 |
+
|
515 |
+
if self.fc_norm is not None:
|
516 |
+
return self.fc_norm(x.mean(1))
|
517 |
+
else:
|
518 |
+
return self.norm(x[:, 0])
|
519 |
+
|
520 |
+
def forward(self, x):
|
521 |
+
x = self.forward_features(x)
|
522 |
+
x = self.head_dropout(x)
|
523 |
+
x = self.head(x)
|
524 |
+
return x
|
525 |
+
|
526 |
+
|
527 |
+
def vit_giant_patch14_224(pretrained=False, **kwargs):
|
528 |
+
model = VisionTransformer(
|
529 |
+
patch_size=14,
|
530 |
+
embed_dim=1408,
|
531 |
+
depth=40,
|
532 |
+
num_heads=16,
|
533 |
+
mlp_ratio=48 / 11,
|
534 |
+
qkv_bias=True,
|
535 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
536 |
+
**kwargs,
|
537 |
+
)
|
538 |
+
model.default_cfg = _cfg()
|
539 |
+
return model
|
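A quick sketch of using the model factory defined above directly (randomly initialized here, roughly a billion parameters, so this is only a shape check; mirrors the arguments load_videomae_model passes in utils.py):

    import torch

    from latentsync.trepa.third_party.VideoMAEv2.videomaev2_finetune import vit_giant_patch14_224

    model = vit_giant_patch14_224(
        img_size=224,
        num_classes=174,
        all_frames=16,
        tubelet_size=2,
        use_mean_pooling=True,
    ).eval()

    clip = torch.randn(1, 3, 16, 224, 224)  # (B, C, T, H, W)
    with torch.no_grad():
        feats = model.forward_features(clip)  # (1, 1408) mean-pooled token features
        logits = model(clip)                  # (1, 174) classification head output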
latentsync/trepa/third_party/VideoMAEv2/videomaev2_pretrain.py
ADDED
@@ -0,0 +1,469 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# Based on BEiT, timm, DINO and DeiT code bases
|
3 |
+
# https://github.com/microsoft/unilm/tree/master/beit
|
4 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
5 |
+
# https://github.com/facebookresearch/deit
|
6 |
+
# https://github.com/facebookresearch/dino
|
7 |
+
# --------------------------------------------------------'
|
8 |
+
from functools import partial
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.utils.checkpoint as cp
|
13 |
+
|
14 |
+
from .videomaev2_finetune import (
|
15 |
+
Block,
|
16 |
+
PatchEmbed,
|
17 |
+
_cfg,
|
18 |
+
get_sinusoid_encoding_table,
|
19 |
+
)
|
20 |
+
|
21 |
+
from .videomaev2_finetune import trunc_normal_ as __call_trunc_normal_
|
22 |
+
|
23 |
+
def trunc_normal_(tensor, mean=0., std=1.):
|
24 |
+
__call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
|
25 |
+
|
26 |
+
|
27 |
+
class PretrainVisionTransformerEncoder(nn.Module):
|
28 |
+
""" Vision Transformer with support for patch or hybrid CNN input stage
|
29 |
+
"""
|
30 |
+
|
31 |
+
def __init__(self,
|
32 |
+
img_size=224,
|
33 |
+
patch_size=16,
|
34 |
+
in_chans=3,
|
35 |
+
num_classes=0,
|
36 |
+
embed_dim=768,
|
37 |
+
depth=12,
|
38 |
+
num_heads=12,
|
39 |
+
mlp_ratio=4.,
|
40 |
+
qkv_bias=False,
|
41 |
+
qk_scale=None,
|
42 |
+
drop_rate=0.,
|
43 |
+
attn_drop_rate=0.,
|
44 |
+
drop_path_rate=0.,
|
45 |
+
norm_layer=nn.LayerNorm,
|
46 |
+
init_values=None,
|
47 |
+
tubelet_size=2,
|
48 |
+
use_learnable_pos_emb=False,
|
49 |
+
with_cp=False,
|
50 |
+
all_frames=16,
|
51 |
+
cos_attn=False):
|
52 |
+
super().__init__()
|
53 |
+
self.num_classes = num_classes
|
54 |
+
# num_features for consistency with other models
|
55 |
+
self.num_features = self.embed_dim = embed_dim
|
56 |
+
self.patch_embed = PatchEmbed(
|
57 |
+
img_size=img_size,
|
58 |
+
patch_size=patch_size,
|
59 |
+
in_chans=in_chans,
|
60 |
+
embed_dim=embed_dim,
|
61 |
+
num_frames=all_frames,
|
62 |
+
tubelet_size=tubelet_size)
|
63 |
+
num_patches = self.patch_embed.num_patches
|
64 |
+
self.with_cp = with_cp
|
65 |
+
|
66 |
+
if use_learnable_pos_emb:
|
67 |
+
self.pos_embed = nn.Parameter(
|
68 |
+
torch.zeros(1, num_patches + 1, embed_dim))
|
69 |
+
else:
|
70 |
+
# sine-cosine positional embeddings
|
71 |
+
self.pos_embed = get_sinusoid_encoding_table(
|
72 |
+
num_patches, embed_dim)
|
73 |
+
|
74 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
75 |
+
] # stochastic depth decay rule
|
76 |
+
self.blocks = nn.ModuleList([
|
77 |
+
Block(
|
78 |
+
dim=embed_dim,
|
79 |
+
num_heads=num_heads,
|
80 |
+
mlp_ratio=mlp_ratio,
|
81 |
+
qkv_bias=qkv_bias,
|
82 |
+
qk_scale=qk_scale,
|
83 |
+
drop=drop_rate,
|
84 |
+
attn_drop=attn_drop_rate,
|
85 |
+
drop_path=dpr[i],
|
86 |
+
norm_layer=norm_layer,
|
87 |
+
init_values=init_values,
|
88 |
+
cos_attn=cos_attn) for i in range(depth)
|
89 |
+
])
|
90 |
+
self.norm = norm_layer(embed_dim)
|
91 |
+
self.head = nn.Linear(
|
92 |
+
embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
93 |
+
|
94 |
+
if use_learnable_pos_emb:
|
95 |
+
trunc_normal_(self.pos_embed, std=.02)
|
96 |
+
|
97 |
+
self.apply(self._init_weights)
|
98 |
+
|
99 |
+
def _init_weights(self, m):
|
100 |
+
if isinstance(m, nn.Linear):
|
101 |
+
nn.init.xavier_uniform_(m.weight)
|
102 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
103 |
+
nn.init.constant_(m.bias, 0)
|
104 |
+
elif isinstance(m, nn.LayerNorm):
|
105 |
+
nn.init.constant_(m.bias, 0)
|
106 |
+
nn.init.constant_(m.weight, 1.0)
|
107 |
+
|
108 |
+
def get_num_layers(self):
|
109 |
+
return len(self.blocks)
|
110 |
+
|
111 |
+
@torch.jit.ignore
|
112 |
+
def no_weight_decay(self):
|
113 |
+
return {'pos_embed', 'cls_token'}
|
114 |
+
|
115 |
+
def get_classifier(self):
|
116 |
+
return self.head
|
117 |
+
|
118 |
+
def reset_classifier(self, num_classes, global_pool=''):
|
119 |
+
self.num_classes = num_classes
|
120 |
+
self.head = nn.Linear(
|
121 |
+
self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
122 |
+
|
123 |
+
def forward_features(self, x, mask):
|
124 |
+
x = self.patch_embed(x)
|
125 |
+
|
126 |
+
x = x + self.pos_embed.type_as(x).to(x.device).clone().detach()
|
127 |
+
|
128 |
+
B, _, C = x.shape
|
129 |
+
x_vis = x[~mask].reshape(B, -1, C) # ~mask means visible
|
130 |
+
|
131 |
+
for blk in self.blocks:
|
132 |
+
if self.with_cp:
|
133 |
+
x_vis = cp.checkpoint(blk, x_vis)
|
134 |
+
else:
|
135 |
+
x_vis = blk(x_vis)
|
136 |
+
|
137 |
+
x_vis = self.norm(x_vis)
|
138 |
+
return x_vis
|
139 |
+
|
140 |
+
def forward(self, x, mask):
|
141 |
+
x = self.forward_features(x, mask)
|
142 |
+
x = self.head(x)
|
143 |
+
return x
|
144 |
+
|
145 |
+
|
146 |
+
class PretrainVisionTransformerDecoder(nn.Module):
|
147 |
+
""" Vision Transformer with support for patch or hybrid CNN input stage
|
148 |
+
"""
|
149 |
+
|
150 |
+
def __init__(self,
|
151 |
+
patch_size=16,
|
152 |
+
num_classes=768,
|
153 |
+
embed_dim=768,
|
154 |
+
depth=12,
|
155 |
+
num_heads=12,
|
156 |
+
mlp_ratio=4.,
|
157 |
+
qkv_bias=False,
|
158 |
+
qk_scale=None,
|
159 |
+
drop_rate=0.,
|
160 |
+
attn_drop_rate=0.,
|
161 |
+
drop_path_rate=0.,
|
162 |
+
norm_layer=nn.LayerNorm,
|
163 |
+
init_values=None,
|
164 |
+
num_patches=196,
|
165 |
+
tubelet_size=2,
|
166 |
+
with_cp=False,
|
167 |
+
cos_attn=False):
|
168 |
+
super().__init__()
|
169 |
+
self.num_classes = num_classes
|
170 |
+
assert num_classes == 3 * tubelet_size * patch_size**2
|
171 |
+
# num_features for consistency with other models
|
172 |
+
self.num_features = self.embed_dim = embed_dim
|
173 |
+
self.patch_size = patch_size
|
174 |
+
self.with_cp = with_cp
|
175 |
+
|
176 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
177 |
+
] # stochastic depth decay rule
|
178 |
+
self.blocks = nn.ModuleList([
|
179 |
+
Block(
|
180 |
+
dim=embed_dim,
|
181 |
+
num_heads=num_heads,
|
182 |
+
mlp_ratio=mlp_ratio,
|
183 |
+
qkv_bias=qkv_bias,
|
184 |
+
qk_scale=qk_scale,
|
185 |
+
drop=drop_rate,
|
186 |
+
attn_drop=attn_drop_rate,
|
187 |
+
drop_path=dpr[i],
|
188 |
+
norm_layer=norm_layer,
|
189 |
+
init_values=init_values,
|
190 |
+
cos_attn=cos_attn) for i in range(depth)
|
191 |
+
])
|
192 |
+
self.norm = norm_layer(embed_dim)
|
193 |
+
self.head = nn.Linear(
|
194 |
+
embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
195 |
+
|
196 |
+
self.apply(self._init_weights)
|
197 |
+
|
198 |
+
def _init_weights(self, m):
|
199 |
+
if isinstance(m, nn.Linear):
|
200 |
+
nn.init.xavier_uniform_(m.weight)
|
201 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
202 |
+
nn.init.constant_(m.bias, 0)
|
203 |
+
elif isinstance(m, nn.LayerNorm):
|
204 |
+
nn.init.constant_(m.bias, 0)
|
205 |
+
nn.init.constant_(m.weight, 1.0)
|
206 |
+
|
207 |
+
def get_num_layers(self):
|
208 |
+
return len(self.blocks)
|
209 |
+
|
210 |
+
@torch.jit.ignore
|
211 |
+
def no_weight_decay(self):
|
212 |
+
return {'pos_embed', 'cls_token'}
|
213 |
+
|
214 |
+
def get_classifier(self):
|
215 |
+
return self.head
|
216 |
+
|
217 |
+
def reset_classifier(self, num_classes, global_pool=''):
|
218 |
+
self.num_classes = num_classes
|
219 |
+
self.head = nn.Linear(
|
220 |
+
self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
221 |
+
|
222 |
+
def forward(self, x, return_token_num):
|
223 |
+
for blk in self.blocks:
|
224 |
+
if self.with_cp:
|
225 |
+
x = cp.checkpoint(blk, x)
|
226 |
+
else:
|
227 |
+
x = blk(x)
|
228 |
+
|
229 |
+
if return_token_num > 0:
|
230 |
+
# only return the mask tokens predict pixels
|
231 |
+
x = self.head(self.norm(x[:, -return_token_num:]))
|
232 |
+
else:
|
233 |
+
# [B, N, 3*16^2]
|
234 |
+
x = self.head(self.norm(x))
|
235 |
+
return x
|
236 |
+
|
237 |
+
|
238 |
+
class PretrainVisionTransformer(nn.Module):
|
239 |
+
""" Vision Transformer with support for patch or hybrid CNN input stage
|
240 |
+
"""
|
241 |
+
|
242 |
+
def __init__(
|
243 |
+
self,
|
244 |
+
img_size=224,
|
245 |
+
patch_size=16,
|
246 |
+
encoder_in_chans=3,
|
247 |
+
encoder_num_classes=0,
|
248 |
+
encoder_embed_dim=768,
|
249 |
+
encoder_depth=12,
|
250 |
+
encoder_num_heads=12,
|
251 |
+
decoder_num_classes=1536, # decoder_num_classes=768
|
252 |
+
decoder_embed_dim=512,
|
253 |
+
decoder_depth=8,
|
254 |
+
decoder_num_heads=8,
|
255 |
+
mlp_ratio=4.,
|
256 |
+
qkv_bias=False,
|
257 |
+
qk_scale=None,
|
258 |
+
drop_rate=0.,
|
259 |
+
attn_drop_rate=0.,
|
260 |
+
drop_path_rate=0.,
|
261 |
+
norm_layer=nn.LayerNorm,
|
262 |
+
init_values=0.,
|
263 |
+
use_learnable_pos_emb=False,
|
264 |
+
tubelet_size=2,
|
265 |
+
num_classes=0, # avoid the error from create_fn in timm
|
266 |
+
in_chans=0, # avoid the error from create_fn in timm
|
267 |
+
with_cp=False,
|
268 |
+
all_frames=16,
|
269 |
+
cos_attn=False,
|
270 |
+
):
|
271 |
+
super().__init__()
|
272 |
+
self.encoder = PretrainVisionTransformerEncoder(
|
273 |
+
img_size=img_size,
|
274 |
+
patch_size=patch_size,
|
275 |
+
in_chans=encoder_in_chans,
|
276 |
+
num_classes=encoder_num_classes,
|
277 |
+
embed_dim=encoder_embed_dim,
|
278 |
+
depth=encoder_depth,
|
279 |
+
num_heads=encoder_num_heads,
|
280 |
+
mlp_ratio=mlp_ratio,
|
281 |
+
qkv_bias=qkv_bias,
|
282 |
+
qk_scale=qk_scale,
|
283 |
+
drop_rate=drop_rate,
|
284 |
+
attn_drop_rate=attn_drop_rate,
|
285 |
+
drop_path_rate=drop_path_rate,
|
286 |
+
norm_layer=norm_layer,
|
287 |
+
init_values=init_values,
|
288 |
+
tubelet_size=tubelet_size,
|
289 |
+
use_learnable_pos_emb=use_learnable_pos_emb,
|
290 |
+
with_cp=with_cp,
|
291 |
+
all_frames=all_frames,
|
292 |
+
cos_attn=cos_attn)
|
293 |
+
|
294 |
+
self.decoder = PretrainVisionTransformerDecoder(
|
295 |
+
patch_size=patch_size,
|
296 |
+
num_patches=self.encoder.patch_embed.num_patches,
|
297 |
+
num_classes=decoder_num_classes,
|
298 |
+
embed_dim=decoder_embed_dim,
|
299 |
+
depth=decoder_depth,
|
300 |
+
num_heads=decoder_num_heads,
|
301 |
+
mlp_ratio=mlp_ratio,
|
302 |
+
qkv_bias=qkv_bias,
|
303 |
+
qk_scale=qk_scale,
|
304 |
+
drop_rate=drop_rate,
|
305 |
+
attn_drop_rate=attn_drop_rate,
|
306 |
+
drop_path_rate=drop_path_rate,
|
307 |
+
norm_layer=norm_layer,
|
308 |
+
init_values=init_values,
|
309 |
+
tubelet_size=tubelet_size,
|
310 |
+
with_cp=with_cp,
|
311 |
+
cos_attn=cos_attn)
|
312 |
+
|
313 |
+
self.encoder_to_decoder = nn.Linear(
|
314 |
+
encoder_embed_dim, decoder_embed_dim, bias=False)
|
315 |
+
|
316 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
|
317 |
+
|
318 |
+
self.pos_embed = get_sinusoid_encoding_table(
|
319 |
+
self.encoder.patch_embed.num_patches, decoder_embed_dim)
|
320 |
+
|
321 |
+
trunc_normal_(self.mask_token, std=.02)
|
322 |
+
|
323 |
+
def _init_weights(self, m):
|
324 |
+
if isinstance(m, nn.Linear):
|
325 |
+
nn.init.xavier_uniform_(m.weight)
|
326 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
327 |
+
nn.init.constant_(m.bias, 0)
|
328 |
+
elif isinstance(m, nn.LayerNorm):
|
329 |
+
nn.init.constant_(m.bias, 0)
|
330 |
+
nn.init.constant_(m.weight, 1.0)
|
331 |
+
|
332 |
+
def get_num_layers(self):
|
333 |
+
return len(self.blocks)
|
334 |
+
|
335 |
+
@torch.jit.ignore
|
336 |
+
def no_weight_decay(self):
|
337 |
+
return {'pos_embed', 'cls_token', 'mask_token'}
|
338 |
+
|
339 |
+
def forward(self, x, mask, decode_mask=None):
|
340 |
+
decode_vis = mask if decode_mask is None else ~decode_mask
|
341 |
+
|
342 |
+
x_vis = self.encoder(x, mask) # [B, N_vis, C_e]
|
343 |
+
x_vis = self.encoder_to_decoder(x_vis) # [B, N_vis, C_d]
|
344 |
+
B, N_vis, C = x_vis.shape
|
345 |
+
|
346 |
+
# we don't unshuffle the correct visible token order,
|
347 |
+
# but shuffle the pos embedding accordingly.
|
348 |
+
expand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(
|
349 |
+
x.device).clone().detach()
|
350 |
+
pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C)
|
351 |
+
pos_emd_mask = expand_pos_embed[decode_vis].reshape(B, -1, C)
|
352 |
+
|
353 |
+
# [B, N, C_d]
|
354 |
+
x_full = torch.cat(
|
355 |
+
[x_vis + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1)
|
356 |
+
# NOTE: if N_mask==0, the shape of x is [B, N_mask, 3 * 16 * 16]
|
357 |
+
x = self.decoder(x_full, pos_emd_mask.shape[1])
|
358 |
+
|
359 |
+
return x
|
360 |
+
|
361 |
+
|
362 |
+
def pretrain_videomae_small_patch16_224(pretrained=False, **kwargs):
|
363 |
+
model = PretrainVisionTransformer(
|
364 |
+
img_size=224,
|
365 |
+
patch_size=16,
|
366 |
+
encoder_embed_dim=384,
|
367 |
+
encoder_depth=12,
|
368 |
+
encoder_num_heads=6,
|
369 |
+
encoder_num_classes=0,
|
370 |
+
decoder_num_classes=1536, # 16 * 16 * 3 * 2
|
371 |
+
decoder_embed_dim=192,
|
372 |
+
decoder_num_heads=3,
|
373 |
+
mlp_ratio=4,
|
374 |
+
qkv_bias=True,
|
375 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
376 |
+
**kwargs)
|
377 |
+
model.default_cfg = _cfg()
|
378 |
+
if pretrained:
|
379 |
+
checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
|
380 |
+
model.load_state_dict(checkpoint["model"])
|
381 |
+
return model
|
382 |
+
|
383 |
+
|
384 |
+
def pretrain_videomae_base_patch16_224(pretrained=False, **kwargs):
|
385 |
+
model = PretrainVisionTransformer(
|
386 |
+
img_size=224,
|
387 |
+
patch_size=16,
|
388 |
+
encoder_embed_dim=768,
|
389 |
+
encoder_depth=12,
|
390 |
+
encoder_num_heads=12,
|
391 |
+
encoder_num_classes=0,
|
392 |
+
decoder_num_classes=1536, # 16 * 16 * 3 * 2
|
393 |
+
decoder_embed_dim=384,
|
394 |
+
decoder_num_heads=6,
|
395 |
+
mlp_ratio=4,
|
396 |
+
qkv_bias=True,
|
397 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
398 |
+
**kwargs)
|
399 |
+
model.default_cfg = _cfg()
|
400 |
+
if pretrained:
|
401 |
+
checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
|
402 |
+
model.load_state_dict(checkpoint["model"])
|
403 |
+
return model
|
404 |
+
|
405 |
+
|
406 |
+
def pretrain_videomae_large_patch16_224(pretrained=False, **kwargs):
|
407 |
+
model = PretrainVisionTransformer(
|
408 |
+
img_size=224,
|
409 |
+
patch_size=16,
|
410 |
+
encoder_embed_dim=1024,
|
411 |
+
encoder_depth=24,
|
412 |
+
encoder_num_heads=16,
|
413 |
+
encoder_num_classes=0,
|
414 |
+
decoder_num_classes=1536, # 16 * 16 * 3 * 2
|
415 |
+
decoder_embed_dim=512,
|
416 |
+
decoder_num_heads=8,
|
417 |
+
mlp_ratio=4,
|
418 |
+
qkv_bias=True,
|
419 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
420 |
+
**kwargs)
|
421 |
+
model.default_cfg = _cfg()
|
422 |
+
if pretrained:
|
423 |
+
checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
|
424 |
+
model.load_state_dict(checkpoint["model"])
|
425 |
+
return model
|
426 |
+
|
427 |
+
|
428 |
+
def pretrain_videomae_huge_patch16_224(pretrained=False, **kwargs):
|
429 |
+
model = PretrainVisionTransformer(
|
430 |
+
img_size=224,
|
431 |
+
patch_size=16,
|
432 |
+
encoder_embed_dim=1280,
|
433 |
+
encoder_depth=32,
|
434 |
+
encoder_num_heads=16,
|
435 |
+
encoder_num_classes=0,
|
436 |
+
decoder_num_classes=1536, # 16 * 16 * 3 * 2
|
437 |
+
decoder_embed_dim=512,
|
438 |
+
decoder_num_heads=8,
|
439 |
+
mlp_ratio=4,
|
440 |
+
qkv_bias=True,
|
441 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
442 |
+
**kwargs)
|
443 |
+
model.default_cfg = _cfg()
|
444 |
+
if pretrained:
|
445 |
+
checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
|
446 |
+
model.load_state_dict(checkpoint["model"])
|
447 |
+
return model
|
448 |
+
|
449 |
+
|
450 |
+
def pretrain_videomae_giant_patch14_224(pretrained=False, **kwargs):
|
451 |
+
model = PretrainVisionTransformer(
|
452 |
+
img_size=224,
|
453 |
+
patch_size=14,
|
454 |
+
encoder_embed_dim=1408,
|
455 |
+
encoder_depth=40,
|
456 |
+
encoder_num_heads=16,
|
457 |
+
encoder_num_classes=0,
|
458 |
+
decoder_num_classes=1176, # 14 * 14 * 3 * 2,
|
459 |
+
decoder_embed_dim=512,
|
460 |
+
decoder_num_heads=8,
|
461 |
+
mlp_ratio=48 / 11,
|
462 |
+
qkv_bias=True,
|
463 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
464 |
+
**kwargs)
|
465 |
+
model.default_cfg = _cfg()
|
466 |
+
if pretrained:
|
467 |
+
checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
|
468 |
+
model.load_state_dict(checkpoint["model"])
|
469 |
+
return model
|
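A hypothetical sketch of the masked-autoencoding forward pass; the boolean tube mask layout (True = masked token, False = visible) is inferred from the encoder's x[~mask] indexing and is not documented in the diff:

    import torch

    from latentsync.trepa.third_party.VideoMAEv2.videomaev2_pretrain import (
        pretrain_videomae_small_patch16_224,
    )

    model = pretrain_videomae_small_patch16_224().eval()
    clip = torch.randn(1, 3, 16, 224, 224)                # (B, C, T, H, W)
    num_patches = model.encoder.patch_embed.num_patches   # 8 * 14 * 14 = 1568 tubelet tokens

    # Mask 90% of the tokens at random; only visible tokens go through the encoder.
    mask = torch.zeros(1, num_patches, dtype=torch.bool)
    mask[:, torch.randperm(num_patches)[: int(0.9 * num_patches)]] = True

    with torch.no_grad():
        pred = model(clip, mask)  # (1, num_masked_tokens, 1536) reconstructed pixel patches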
latentsync/trepa/third_party/__init__.py
ADDED
File without changes
|
latentsync/trepa/utils/__init__.py
ADDED
File without changes
|
latentsync/trepa/utils/data_utils.py
ADDED
@@ -0,0 +1,321 @@
1 |
+
import os
|
2 |
+
import math
|
3 |
+
import os.path as osp
|
4 |
+
import random
|
5 |
+
import pickle
|
6 |
+
import warnings
|
7 |
+
|
8 |
+
import glob
|
9 |
+
import numpy as np
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.utils.data as data
|
14 |
+
import torch.nn.functional as F
|
15 |
+
import torch.distributed as dist
|
16 |
+
from torchvision.datasets.video_utils import VideoClips
|
17 |
+
|
18 |
+
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']
|
19 |
+
VID_EXTENSIONS = ['.avi', '.mp4', '.webm', '.mov', '.mkv', '.m4v']
|
20 |
+
|
21 |
+
|
22 |
+
def get_dataloader(data_path, image_folder, resolution=128, sequence_length=16, sample_every_n_frames=1,
|
23 |
+
batch_size=16, num_workers=8):
|
24 |
+
data = VideoData(data_path, image_folder, resolution, sequence_length, sample_every_n_frames, batch_size, num_workers)
|
25 |
+
loader = data._dataloader()
|
26 |
+
return loader
|
27 |
+
|
28 |
+
|
29 |
+
def is_image_file(filename):
|
30 |
+
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
|
31 |
+
|
32 |
+
|
33 |
+
def get_parent_dir(path):
|
34 |
+
return osp.basename(osp.dirname(path))
|
35 |
+
|
36 |
+
|
37 |
+
def preprocess(video, resolution, sequence_length=None, in_channels=3, sample_every_n_frames=1):
|
38 |
+
# video: THWC, {0, ..., 255}
|
39 |
+
assert in_channels == 3
|
40 |
+
video = video.permute(0, 3, 1, 2).float() / 255. # TCHW
|
41 |
+
t, c, h, w = video.shape
|
42 |
+
|
43 |
+
# temporal crop
|
44 |
+
if sequence_length is not None:
|
45 |
+
assert sequence_length <= t
|
46 |
+
video = video[:sequence_length]
|
47 |
+
|
48 |
+
# skip frames
|
49 |
+
if sample_every_n_frames > 1:
|
50 |
+
video = video[::sample_every_n_frames]
|
51 |
+
|
52 |
+
# scale shorter side to resolution
|
53 |
+
scale = resolution / min(h, w)
|
54 |
+
if h < w:
|
55 |
+
target_size = (resolution, math.ceil(w * scale))
|
56 |
+
else:
|
57 |
+
target_size = (math.ceil(h * scale), resolution)
|
58 |
+
video = F.interpolate(video, size=target_size, mode='bilinear',
|
59 |
+
align_corners=False, antialias=True)
|
60 |
+
|
61 |
+
# center crop
|
62 |
+
t, c, h, w = video.shape
|
63 |
+
w_start = (w - resolution) // 2
|
64 |
+
h_start = (h - resolution) // 2
|
65 |
+
video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution]
|
66 |
+
video = video.permute(1, 0, 2, 3).contiguous() # CTHW
|
67 |
+
|
68 |
+
return {'video': video}
|
69 |
+
|
70 |
+
|
71 |
+
def preprocess_image(image):
|
72 |
+
# [0, 1] => [-1, 1]
|
73 |
+
img = torch.from_numpy(image)
|
74 |
+
return img
|
75 |
+
|
76 |
+
|
77 |
+
class VideoData(data.Dataset):
|
78 |
+
""" Class to create dataloaders for video datasets
|
79 |
+
|
80 |
+
Args:
|
81 |
+
data_path: Path to the folder with video frames or videos.
|
82 |
+
image_folder: If True, the data is stored as images in folders.
|
83 |
+
resolution: Resolution of the returned videos.
|
84 |
+
sequence_length: Length of extracted video sequences.
|
85 |
+
sample_every_n_frames: Sample every n frames from the video.
|
86 |
+
batch_size: Batch size.
|
87 |
+
num_workers: Number of workers for the dataloader.
|
88 |
+
shuffle: If True, shuffle the data.
|
89 |
+
"""
|
90 |
+
|
91 |
+
def __init__(self, data_path: str, image_folder: bool, resolution: int, sequence_length: int,
|
92 |
+
sample_every_n_frames: int, batch_size: int, num_workers: int, shuffle: bool = True):
|
93 |
+
super().__init__()
|
94 |
+
self.data_path = data_path
|
95 |
+
self.image_folder = image_folder
|
96 |
+
self.resolution = resolution
|
97 |
+
self.sequence_length = sequence_length
|
98 |
+
self.sample_every_n_frames = sample_every_n_frames
|
99 |
+
self.batch_size = batch_size
|
100 |
+
self.num_workers = num_workers
|
101 |
+
self.shuffle = shuffle
|
102 |
+
|
103 |
+
def _dataset(self):
|
104 |
+
'''
|
105 |
+
Initializes and return the dataset.
|
106 |
+
'''
|
107 |
+
if self.image_folder:
|
108 |
+
Dataset = FrameDataset
|
109 |
+
dataset = Dataset(self.data_path, self.sequence_length,
|
110 |
+
resolution=self.resolution, sample_every_n_frames=self.sample_every_n_frames)
|
111 |
+
else:
|
112 |
+
Dataset = VideoDataset
|
113 |
+
dataset = Dataset(self.data_path, self.sequence_length,
|
114 |
+
resolution=self.resolution, sample_every_n_frames=self.sample_every_n_frames)
|
115 |
+
return dataset
|
116 |
+
|
117 |
+
def _dataloader(self):
|
118 |
+
'''
|
119 |
+
Initializes and returns the dataloader.
|
120 |
+
'''
|
121 |
+
dataset = self._dataset()
|
122 |
+
if dist.is_initialized():
|
123 |
+
sampler = data.distributed.DistributedSampler(
|
124 |
+
dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank()
|
125 |
+
)
|
126 |
+
else:
|
127 |
+
sampler = None
|
128 |
+
dataloader = data.DataLoader(
|
129 |
+
dataset,
|
130 |
+
batch_size=self.batch_size,
|
131 |
+
num_workers=self.num_workers,
|
132 |
+
pin_memory=True,
|
133 |
+
sampler=sampler,
|
134 |
+
shuffle=sampler is None and self.shuffle is True
|
135 |
+
)
|
136 |
+
return dataloader
|
137 |
+
|
138 |
+
|
139 |
+
class VideoDataset(data.Dataset):
|
140 |
+
"""
|
141 |
+
Generic dataset for videos files stored in folders.
|
142 |
+
Videos of the same class are expected to be stored in a single folder. Multiple folders can exist in the provided directory.
|
143 |
+
The class depends on `torchvision.datasets.video_utils.VideoClips` to load the videos.
|
144 |
+
Returns BCTHW videos in the range [0, 1].
|
145 |
+
|
146 |
+
Args:
|
147 |
+
data_folder: Path to the folder with corresponding videos stored.
|
148 |
+
sequence_length: Length of extracted video sequences.
|
149 |
+
resolution: Resolution of the returned videos.
|
150 |
+
sample_every_n_frames: Sample every n frames from the video.
|
151 |
+
"""
|
152 |
+
|
153 |
+
def __init__(self, data_folder: str, sequence_length: int = 16, resolution: int = 128, sample_every_n_frames: int = 1):
|
154 |
+
super().__init__()
|
155 |
+
self.sequence_length = sequence_length
|
156 |
+
self.resolution = resolution
|
157 |
+
self.sample_every_n_frames = sample_every_n_frames
|
158 |
+
|
159 |
+
folder = data_folder
|
160 |
+
files = sum([glob.glob(osp.join(folder, '**', f'*{ext}'), recursive=True)
|
161 |
+
for ext in VID_EXTENSIONS], [])
|
162 |
+
|
163 |
+
warnings.filterwarnings('ignore')
|
164 |
+
cache_file = osp.join(folder, f"metadata_{sequence_length}.pkl")
|
165 |
+
if not osp.exists(cache_file):
|
166 |
+
clips = VideoClips(files, sequence_length, num_workers=4)
|
167 |
+
try:
|
168 |
+
pickle.dump(clips.metadata, open(cache_file, 'wb'))
|
169 |
+
except Exception:
|
170 |
+
print(f"Failed to save metadata to {cache_file}")
|
171 |
+
else:
|
172 |
+
metadata = pickle.load(open(cache_file, 'rb'))
|
173 |
+
clips = VideoClips(files, sequence_length,
|
174 |
+
_precomputed_metadata=metadata)
|
175 |
+
|
176 |
+
self._clips = clips
|
177 |
+
# instead of uniformly sampling from all possible clips, we sample uniformly from all possible videos
|
178 |
+
self._clips.get_clip_location = self.get_random_clip_from_video
|
179 |
+
|
180 |
+
def get_random_clip_from_video(self, idx: int) -> tuple:
|
181 |
+
'''
|
182 |
+
Sample a random clip starting index from the video.
|
183 |
+
|
184 |
+
Args:
|
185 |
+
idx: Index of the video.
|
186 |
+
'''
|
187 |
+
# Note that some videos may not contain enough frames, we skip those videos here.
|
188 |
+
while self._clips.clips[idx].shape[0] <= 0:
|
189 |
+
idx += 1
|
190 |
+
n_clip = self._clips.clips[idx].shape[0]
|
191 |
+
clip_id = random.randint(0, n_clip - 1)
|
192 |
+
return idx, clip_id
|
193 |
+
|
194 |
+
def __len__(self):
|
195 |
+
return self._clips.num_videos()
|
196 |
+
|
197 |
+
def __getitem__(self, idx):
|
198 |
+
resolution = self.resolution
|
199 |
+
while True:
|
200 |
+
try:
|
201 |
+
video, _, _, idx = self._clips.get_clip(idx)
|
202 |
+
except Exception as e:
|
203 |
+
print(idx, e)
|
204 |
+
idx = (idx + 1) % self._clips.num_clips()
|
205 |
+
continue
|
206 |
+
break
|
207 |
+
|
208 |
+
return dict(**preprocess(video, resolution, sample_every_n_frames=self.sample_every_n_frames))
|
209 |
+
|
210 |
+
|
211 |
+
class FrameDataset(data.Dataset):
|
212 |
+
"""
|
213 |
+
Generic dataset for videos stored as images. The loader iterates over all the folders and subfolders
|
214 |
+
in the provided directory. Each leaf folder is assumed to contain frames from a single video.
|
215 |
+
|
216 |
+
Args:
|
217 |
+
data_folder: path to the folder with video frames. The folder
|
218 |
+
should contain folders with frames from each video.
|
219 |
+
sequence_length: length of extracted video sequences
|
220 |
+
resolution: resolution of the returned videos
|
221 |
+
sample_every_n_frames: sample every n frames from the video
|
222 |
+
"""
|
223 |
+
|
224 |
+
def __init__(self, data_folder, sequence_length, resolution=64, sample_every_n_frames=1):
|
225 |
+
self.resolution = resolution
|
226 |
+
self.sequence_length = sequence_length
|
227 |
+
self.sample_every_n_frames = sample_every_n_frames
|
228 |
+
self.data_all = self.load_video_frames(data_folder)
|
229 |
+
self.video_num = len(self.data_all)
|
230 |
+
|
231 |
+
def __getitem__(self, index):
|
232 |
+
batch_data = self.getTensor(index)
|
233 |
+
return_list = {'video': batch_data}
|
234 |
+
|
235 |
+
return return_list
|
236 |
+
|
237 |
+
def load_video_frames(self, dataroot: str) -> list:
|
238 |
+
'''
|
239 |
+
Loads all the video frames under the dataroot and returns a list of all the video frames.
|
240 |
+
|
241 |
+
Args:
|
242 |
+
dataroot: The root directory containing the video frames.
|
243 |
+
|
244 |
+
Returns:
|
245 |
+
A list of all the video frames.
|
246 |
+
|
247 |
+
'''
|
248 |
+
data_all = []
|
249 |
+
frame_list = os.walk(dataroot)
|
250 |
+
for _, meta in enumerate(frame_list):
|
251 |
+
root = meta[0]
|
252 |
+
try:
|
253 |
+
frames = sorted(meta[2], key=lambda item: int(item.split('.')[0].split('_')[-1]))
|
254 |
+
except:
|
255 |
+
print(meta[0], meta[2])
|
256 |
+
if len(frames) < max(0, self.sequence_length * self.sample_every_n_frames):
|
257 |
+
continue
|
258 |
+
frames = [
|
259 |
+
os.path.join(root, item) for item in frames
|
260 |
+
if is_image_file(item)
|
261 |
+
]
|
262 |
+
if len(frames) > max(0, self.sequence_length * self.sample_every_n_frames):
|
263 |
+
data_all.append(frames)
|
264 |
+
|
265 |
+
return data_all
|
266 |
+
|
267 |
+
def getTensor(self, index: int) -> torch.Tensor:
|
268 |
+
'''
|
269 |
+
Returns a tensor of the video frames at the given index.
|
270 |
+
|
271 |
+
Args:
|
272 |
+
index: The index of the video frames to return.
|
273 |
+
|
274 |
+
Returns:
|
275 |
+
A BCTHW tensor in the range `[0, 1]` of the video frames at the given index.
|
276 |
+
|
277 |
+
'''
|
278 |
+
video = self.data_all[index]
|
279 |
+
video_len = len(video)
|
280 |
+
|
281 |
+
# load the entire video when sequence_length == -1, in which case sample_every_n_frames has to be 1
|
282 |
+
if self.sequence_length == -1:
|
283 |
+
assert self.sample_every_n_frames == 1
|
284 |
+
start_idx = 0
|
285 |
+
end_idx = video_len
|
286 |
+
else:
|
287 |
+
n_frames_interval = self.sequence_length * self.sample_every_n_frames
|
288 |
+
start_idx = random.randint(0, video_len - n_frames_interval)
|
289 |
+
end_idx = start_idx + n_frames_interval
|
290 |
+
img = Image.open(video[0])
|
291 |
+
h, w = img.height, img.width
|
292 |
+
|
293 |
+
if h > w:
|
294 |
+
half = (h - w) // 2
|
295 |
+
cropsize = (0, half, w, half + w) # left, upper, right, lower
|
296 |
+
elif w > h:
|
297 |
+
half = (w - h) // 2
|
298 |
+
cropsize = (half, 0, half + h, h)
|
299 |
+
|
300 |
+
images = []
|
301 |
+
for i in range(start_idx, end_idx,
|
302 |
+
self.sample_every_n_frames):
|
303 |
+
path = video[i]
|
304 |
+
img = Image.open(path)
|
305 |
+
|
306 |
+
if h != w:
|
307 |
+
img = img.crop(cropsize)
|
308 |
+
|
309 |
+
img = img.resize(
|
310 |
+
(self.resolution, self.resolution),
|
311 |
+
Image.LANCZOS)  # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the equivalent filter
|
312 |
+
img = np.asarray(img, dtype=np.float32)
|
313 |
+
img /= 255.
|
314 |
+
img_tensor = preprocess_image(img).unsqueeze(0)
|
315 |
+
images.append(img_tensor)
|
316 |
+
|
317 |
+
video_clip = torch.cat(images).permute(3, 0, 1, 2)
|
318 |
+
return video_clip
|
319 |
+
|
320 |
+
def __len__(self):
|
321 |
+
return self.video_num
|
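A minimal usage sketch for the FrameDataset above, assuming the classes in this file are importable and that frames live under a hypothetical data/frames/<video_name>/frame_000.jpg layout:

from torch.utils.data import DataLoader

dataset = FrameDataset("data/frames", sequence_length=16, resolution=64, sample_every_n_frames=1)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

for batch in loader:
    video = batch["video"]   # (B, C, T, H, W), values in [0, 1] per the getTensor docstring
    print(video.shape)
    break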
latentsync/trepa/utils/metric_utils.py
ADDED
@@ -0,0 +1,161 @@
|
1 |
+
# Adapted from https://github.com/universome/stylegan-v/blob/master/src/metrics/metric_utils.py
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import torch
|
5 |
+
import pickle
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from typing import List, Tuple
|
9 |
+
|
10 |
+
def seed_everything(seed):
|
11 |
+
random.seed(seed)
|
12 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
13 |
+
np.random.seed(seed)
|
14 |
+
torch.manual_seed(seed)
|
15 |
+
torch.cuda.manual_seed(seed)
|
16 |
+
|
17 |
+
|
18 |
+
class FeatureStats:
|
19 |
+
'''
|
20 |
+
Class to store statistics of features, including all features and mean/covariance.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
capture_all: Whether to store all the features.
|
24 |
+
capture_mean_cov: Whether to store mean and covariance.
|
25 |
+
max_items: Maximum number of items to store.
|
26 |
+
'''
|
27 |
+
def __init__(self, capture_all: bool = False, capture_mean_cov: bool = False, max_items: int = None):
|
28 |
+
'''
|
29 |
+
'''
|
30 |
+
self.capture_all = capture_all
|
31 |
+
self.capture_mean_cov = capture_mean_cov
|
32 |
+
self.max_items = max_items
|
33 |
+
self.num_items = 0
|
34 |
+
self.num_features = None
|
35 |
+
self.all_features = None
|
36 |
+
self.raw_mean = None
|
37 |
+
self.raw_cov = None
|
38 |
+
|
39 |
+
def set_num_features(self, num_features: int):
|
40 |
+
'''
|
41 |
+
Set the number of feature dimensions.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
num_features: Number of feature dimensions.
|
45 |
+
'''
|
46 |
+
if self.num_features is not None:
|
47 |
+
assert num_features == self.num_features
|
48 |
+
else:
|
49 |
+
self.num_features = num_features
|
50 |
+
self.all_features = []
|
51 |
+
self.raw_mean = np.zeros([num_features], dtype=np.float64)
|
52 |
+
self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
|
53 |
+
|
54 |
+
def is_full(self) -> bool:
|
55 |
+
'''
|
56 |
+
Check if the maximum number of samples is reached.
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
True if the storage is full, False otherwise.
|
60 |
+
'''
|
61 |
+
return (self.max_items is not None) and (self.num_items >= self.max_items)
|
62 |
+
|
63 |
+
def append(self, x: np.ndarray):
|
64 |
+
'''
|
65 |
+
Add the newly computed features to the list. Update the mean and covariance.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
x: New features to record.
|
69 |
+
'''
|
70 |
+
x = np.asarray(x, dtype=np.float32)
|
71 |
+
assert x.ndim == 2
|
72 |
+
if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
|
73 |
+
if self.num_items >= self.max_items:
|
74 |
+
return
|
75 |
+
x = x[:self.max_items - self.num_items]
|
76 |
+
|
77 |
+
self.set_num_features(x.shape[1])
|
78 |
+
self.num_items += x.shape[0]
|
79 |
+
if self.capture_all:
|
80 |
+
self.all_features.append(x)
|
81 |
+
if self.capture_mean_cov:
|
82 |
+
x64 = x.astype(np.float64)
|
83 |
+
self.raw_mean += x64.sum(axis=0)
|
84 |
+
self.raw_cov += x64.T @ x64
|
85 |
+
|
86 |
+
def append_torch(self, x: torch.Tensor, rank: int, num_gpus: int):
|
87 |
+
'''
|
88 |
+
Add the newly computed PyTorch features to the list. Update the mean and covariance.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
x: New features to record.
|
92 |
+
rank: Rank of the current GPU.
|
93 |
+
num_gpus: Total number of GPUs.
|
94 |
+
'''
|
95 |
+
assert isinstance(x, torch.Tensor) and x.ndim == 2
|
96 |
+
assert 0 <= rank < num_gpus
|
97 |
+
if num_gpus > 1:
|
98 |
+
ys = []
|
99 |
+
for src in range(num_gpus):
|
100 |
+
y = x.clone()
|
101 |
+
torch.distributed.broadcast(y, src=src)
|
102 |
+
ys.append(y)
|
103 |
+
x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
|
104 |
+
self.append(x.cpu().numpy())
|
105 |
+
|
106 |
+
def get_all(self) -> np.ndarray:
|
107 |
+
'''
|
108 |
+
Get all the stored features as NumPy Array.
|
109 |
+
|
110 |
+
Returns:
|
111 |
+
Concatenation of the stored features.
|
112 |
+
'''
|
113 |
+
assert self.capture_all
|
114 |
+
return np.concatenate(self.all_features, axis=0)
|
115 |
+
|
116 |
+
def get_all_torch(self) -> torch.Tensor:
|
117 |
+
'''
|
118 |
+
Get all the stored features as PyTorch Tensor.
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
Concatenation of the stored features.
|
122 |
+
'''
|
123 |
+
return torch.from_numpy(self.get_all())
|
124 |
+
|
125 |
+
def get_mean_cov(self) -> Tuple[np.ndarray, np.ndarray]:
|
126 |
+
'''
|
127 |
+
Get the mean and covariance of the stored features.
|
128 |
+
|
129 |
+
Returns:
|
130 |
+
Mean and covariance of the stored features.
|
131 |
+
'''
|
132 |
+
assert self.capture_mean_cov
|
133 |
+
mean = self.raw_mean / self.num_items
|
134 |
+
cov = self.raw_cov / self.num_items
|
135 |
+
cov = cov - np.outer(mean, mean)
|
136 |
+
return mean, cov
|
137 |
+
|
138 |
+
def save(self, pkl_file: str):
|
139 |
+
'''
|
140 |
+
Save the features and statistics to a pickle file.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
pkl_file: Path to the pickle file.
|
144 |
+
'''
|
145 |
+
with open(pkl_file, 'wb') as f:
|
146 |
+
pickle.dump(self.__dict__, f)
|
147 |
+
|
148 |
+
@staticmethod
|
149 |
+
def load(pkl_file: str) -> 'FeatureStats':
|
150 |
+
'''
|
151 |
+
Load the features and statistics from a pickle file.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
pkl_file: Path to the pickle file.
|
155 |
+
'''
|
156 |
+
with open(pkl_file, 'rb') as f:
|
157 |
+
s = pickle.load(f)
|
158 |
+
obj = FeatureStats(capture_all=s['capture_all'], max_items=s['max_items'])
|
159 |
+
obj.__dict__.update(s)
|
160 |
+
print('Loaded %d features from %s' % (obj.num_items, pkl_file))
|
161 |
+
return obj
|
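A hedged sketch of how FeatureStats can be used to compare two feature sets with a Fréchet-style distance; the random arrays stand in for real network activations, and scipy is assumed to be available:

import numpy as np
import scipy.linalg

def frechet_distance(stats_a, stats_b):
    mu_a, cov_a = stats_a.get_mean_cov()
    mu_b, cov_b = stats_b.get_mean_cov()
    diff = mu_a - mu_b
    covmean, _ = scipy.linalg.sqrtm(cov_a @ cov_b, disp=False)
    return float(diff @ diff + np.trace(cov_a + cov_b - 2 * covmean.real))

stats_real = FeatureStats(capture_mean_cov=True, max_items=10_000)
stats_fake = FeatureStats(capture_mean_cov=True, max_items=10_000)
stats_real.append(np.random.randn(256, 768))   # placeholder "real" features
stats_fake.append(np.random.randn(256, 768))   # placeholder "generated" features
print(frechet_distance(stats_real, stats_fake))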
latentsync/utils/affine_transform.py
ADDED
@@ -0,0 +1,138 @@
|
1 |
+
# Adapted from https://github.com/guanjz20/StyleSync/blob/main/utils.py
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import cv2
|
5 |
+
|
6 |
+
|
7 |
+
def transformation_from_points(points1, points0, smooth=True, p_bias=None):
|
8 |
+
points2 = np.array(points0)
|
9 |
+
points2 = points2.astype(np.float64)
|
10 |
+
points1 = points1.astype(np.float64)
|
11 |
+
c1 = np.mean(points1, axis=0)
|
12 |
+
c2 = np.mean(points2, axis=0)
|
13 |
+
points1 -= c1
|
14 |
+
points2 -= c2
|
15 |
+
s1 = np.std(points1)
|
16 |
+
s2 = np.std(points2)
|
17 |
+
points1 /= s1
|
18 |
+
points2 /= s2
|
19 |
+
U, S, Vt = np.linalg.svd(np.matmul(points1.T, points2))
|
20 |
+
R = (np.matmul(U, Vt)).T
|
21 |
+
sR = (s2 / s1) * R
|
22 |
+
T = c2.reshape(2, 1) - (s2 / s1) * np.matmul(R, c1.reshape(2, 1))
|
23 |
+
M = np.concatenate((sR, T), axis=1)
|
24 |
+
if smooth:
|
25 |
+
bias = points2[2] - points1[2]
|
26 |
+
if p_bias is None:
|
27 |
+
p_bias = bias
|
28 |
+
else:
|
29 |
+
bias = p_bias * 0.2 + bias * 0.8
|
30 |
+
p_bias = bias
|
31 |
+
M[:, 2] = M[:, 2] + bias
|
32 |
+
return M, p_bias
|
33 |
+
|
34 |
+
|
35 |
+
class AlignRestore(object):
|
36 |
+
def __init__(self, align_points=3):
|
37 |
+
if align_points == 3:
|
38 |
+
self.upscale_factor = 1
|
39 |
+
self.crop_ratio = (2.8, 2.8)
|
40 |
+
self.face_template = np.array([[19 - 2, 30 - 10], [56 + 2, 30 - 10], [37.5, 45 - 5]])
|
41 |
+
self.face_template = self.face_template * 2.8
|
42 |
+
# self.face_size = (int(100 * self.crop_ratio[0]), int(100 * self.crop_ratio[1]))
|
43 |
+
self.face_size = (int(75 * self.crop_ratio[0]), int(100 * self.crop_ratio[1]))
|
44 |
+
self.p_bias = None
|
45 |
+
|
46 |
+
def process(self, img, lmk_align=None, smooth=True, align_points=3):
|
47 |
+
aligned_face, affine_matrix = self.align_warp_face(img, lmk_align, smooth)
|
48 |
+
restored_img = self.restore_img(img, aligned_face, affine_matrix)
|
49 |
+
cv2.imwrite("restored.jpg", restored_img)
|
50 |
+
cv2.imwrite("aligned.jpg", aligned_face)
|
51 |
+
return aligned_face, restored_img
|
52 |
+
|
53 |
+
def align_warp_face(self, img, lmks3, smooth=True, border_mode="constant"):
|
54 |
+
affine_matrix, self.p_bias = transformation_from_points(lmks3, self.face_template, smooth, self.p_bias)
|
55 |
+
if border_mode == "constant":
|
56 |
+
border_mode = cv2.BORDER_CONSTANT
|
57 |
+
elif border_mode == "reflect101":
|
58 |
+
border_mode = cv2.BORDER_REFLECT101
|
59 |
+
elif border_mode == "reflect":
|
60 |
+
border_mode = cv2.BORDER_REFLECT
|
61 |
+
cropped_face = cv2.warpAffine(
|
62 |
+
img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=[127, 127, 127]
|
63 |
+
)
|
64 |
+
return cropped_face, affine_matrix
|
65 |
+
|
66 |
+
def align_warp_face2(self, img, landmark, border_mode="constant"):
|
67 |
+
affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template)[0]
|
68 |
+
if border_mode == "constant":
|
69 |
+
border_mode = cv2.BORDER_CONSTANT
|
70 |
+
elif border_mode == "reflect101":
|
71 |
+
border_mode = cv2.BORDER_REFLECT101
|
72 |
+
elif border_mode == "reflect":
|
73 |
+
border_mode = cv2.BORDER_REFLECT
|
74 |
+
cropped_face = cv2.warpAffine(
|
75 |
+
img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)
|
76 |
+
)
|
77 |
+
return cropped_face, affine_matrix
|
78 |
+
|
79 |
+
def restore_img(self, input_img, face, affine_matrix):
|
80 |
+
h, w, _ = input_img.shape
|
81 |
+
h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
|
82 |
+
upsample_img = cv2.resize(input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
|
83 |
+
inverse_affine = cv2.invertAffineTransform(affine_matrix)
|
84 |
+
inverse_affine *= self.upscale_factor
|
85 |
+
if self.upscale_factor > 1:
|
86 |
+
extra_offset = 0.5 * self.upscale_factor
|
87 |
+
else:
|
88 |
+
extra_offset = 0
|
89 |
+
inverse_affine[:, 2] += extra_offset
|
90 |
+
inv_restored = cv2.warpAffine(face, inverse_affine, (w_up, h_up))
|
91 |
+
mask = np.ones((self.face_size[1], self.face_size[0]), dtype=np.float32)
|
92 |
+
inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
|
93 |
+
inv_mask_erosion = cv2.erode(
|
94 |
+
inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8)
|
95 |
+
)
|
96 |
+
pasted_face = inv_mask_erosion[:, :, None] * inv_restored
|
97 |
+
total_face_area = np.sum(inv_mask_erosion)
|
98 |
+
w_edge = int(total_face_area**0.5) // 20
|
99 |
+
erosion_radius = w_edge * 2
|
100 |
+
inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
|
101 |
+
blur_size = w_edge * 2
|
102 |
+
inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
|
103 |
+
inv_soft_mask = inv_soft_mask[:, :, None]
|
104 |
+
upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
|
105 |
+
if np.max(upsample_img) > 256:
|
106 |
+
upsample_img = upsample_img.astype(np.uint16)
|
107 |
+
else:
|
108 |
+
upsample_img = upsample_img.astype(np.uint8)
|
109 |
+
return upsample_img
|
110 |
+
|
111 |
+
|
112 |
+
class laplacianSmooth:
|
113 |
+
def __init__(self, smoothAlpha=0.3):
|
114 |
+
self.smoothAlpha = smoothAlpha
|
115 |
+
self.pts_last = None
|
116 |
+
|
117 |
+
def smooth(self, pts_cur):
|
118 |
+
if self.pts_last is None:
|
119 |
+
self.pts_last = pts_cur.copy()
|
120 |
+
return pts_cur.copy()
|
121 |
+
x1 = min(pts_cur[:, 0])
|
122 |
+
x2 = max(pts_cur[:, 0])
|
123 |
+
y1 = min(pts_cur[:, 1])
|
124 |
+
y2 = max(pts_cur[:, 1])
|
125 |
+
width = x2 - x1
|
126 |
+
pts_update = []
|
127 |
+
for i in range(len(pts_cur)):
|
128 |
+
x_new, y_new = pts_cur[i]
|
129 |
+
x_old, y_old = self.pts_last[i]
|
130 |
+
tmp = (x_new - x_old) ** 2 + (y_new - y_old) ** 2
|
131 |
+
w = np.exp(-tmp / (width * self.smoothAlpha))
|
132 |
+
x = x_old * w + x_new * (1 - w)
|
133 |
+
y = y_old * w + y_new * (1 - w)
|
134 |
+
pts_update.append([x, y])
|
135 |
+
pts_update = np.array(pts_update)
|
136 |
+
self.pts_last = pts_update.copy()
|
137 |
+
|
138 |
+
return pts_update
|
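A hedged usage sketch of the helpers above; the image path and the 3x2 landmark array are placeholders for a real frame and real detector output:

import cv2
import numpy as np

restorer = AlignRestore(align_points=3)
smoother = laplacianSmooth(smoothAlpha=0.3)

img = cv2.imread("example_frame.jpg")                               # hypothetical input frame
lmk3 = np.array([[220.0, 180.0], [300.0, 180.0], [260.0, 230.0]])   # placeholder 3-point landmarks
lmk3 = smoother.smooth(lmk3)                                        # temporal smoothing across frames

face, affine_matrix = restorer.align_warp_face(img, lmk3, smooth=True)
restored = restorer.restore_img(img, face, affine_matrix)
cv2.imwrite("face_crop.jpg", face)
cv2.imwrite("restored_frame.jpg", restored)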
latentsync/utils/audio.py
ADDED
@@ -0,0 +1,194 @@
|
1 |
+
# Adapted from https://github.com/Rudrabha/Wav2Lip/blob/master/audio.py
|
2 |
+
|
3 |
+
import librosa
|
4 |
+
import librosa.filters
|
5 |
+
import numpy as np
|
6 |
+
from scipy import signal
|
7 |
+
from scipy.io import wavfile
|
8 |
+
from omegaconf import OmegaConf
|
9 |
+
import torch
|
10 |
+
|
11 |
+
audio_config_path = "configs/audio.yaml"
|
12 |
+
|
13 |
+
config = OmegaConf.load(audio_config_path)
|
14 |
+
|
15 |
+
|
16 |
+
def load_wav(path, sr):
|
17 |
+
return librosa.core.load(path, sr=sr)[0]
|
18 |
+
|
19 |
+
|
20 |
+
def save_wav(wav, path, sr):
|
21 |
+
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
|
22 |
+
# proposed by @dsmiller
|
23 |
+
wavfile.write(path, sr, wav.astype(np.int16))
|
24 |
+
|
25 |
+
|
26 |
+
def save_wavenet_wav(wav, path, sr):
|
27 |
+
librosa.output.write_wav(path, wav, sr=sr)
|
28 |
+
|
29 |
+
|
30 |
+
def preemphasis(wav, k, preemphasize=True):
|
31 |
+
if preemphasize:
|
32 |
+
return signal.lfilter([1, -k], [1], wav)
|
33 |
+
return wav
|
34 |
+
|
35 |
+
|
36 |
+
def inv_preemphasis(wav, k, inv_preemphasize=True):
|
37 |
+
if inv_preemphasize:
|
38 |
+
return signal.lfilter([1], [1, -k], wav)
|
39 |
+
return wav
|
40 |
+
|
41 |
+
|
42 |
+
def get_hop_size():
|
43 |
+
hop_size = config.audio.hop_size
|
44 |
+
if hop_size is None:
|
45 |
+
assert config.audio.frame_shift_ms is not None
|
46 |
+
hop_size = int(config.audio.frame_shift_ms / 1000 * config.audio.sample_rate)
|
47 |
+
return hop_size
|
48 |
+
|
49 |
+
|
50 |
+
def linearspectrogram(wav):
|
51 |
+
D = _stft(preemphasis(wav, config.audio.preemphasis, config.audio.preemphasize))
|
52 |
+
S = _amp_to_db(np.abs(D)) - config.audio.ref_level_db
|
53 |
+
|
54 |
+
if config.audio.signal_normalization:
|
55 |
+
return _normalize(S)
|
56 |
+
return S
|
57 |
+
|
58 |
+
|
59 |
+
def melspectrogram(wav):
|
60 |
+
D = _stft(preemphasis(wav, config.audio.preemphasis, config.audio.preemphasize))
|
61 |
+
S = _amp_to_db(_linear_to_mel(np.abs(D))) - config.audio.ref_level_db
|
62 |
+
|
63 |
+
if config.audio.signal_normalization:
|
64 |
+
return _normalize(S)
|
65 |
+
return S
|
66 |
+
|
67 |
+
|
68 |
+
def _lws_processor():
|
69 |
+
import lws
|
70 |
+
|
71 |
+
return lws.lws(config.audio.n_fft, get_hop_size(), fftsize=config.audio.win_size, mode="speech")
|
72 |
+
|
73 |
+
|
74 |
+
def _stft(y):
|
75 |
+
if config.audio.use_lws:
|
76 |
+
return _lws_processor().stft(y).T  # _lws_processor() takes no arguments
|
77 |
+
else:
|
78 |
+
return librosa.stft(y=y, n_fft=config.audio.n_fft, hop_length=get_hop_size(), win_length=config.audio.win_size)
|
79 |
+
|
80 |
+
|
81 |
+
##########################################################
|
82 |
+
# Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
|
83 |
+
def num_frames(length, fsize, fshift):
|
84 |
+
"""Compute number of time frames of spectrogram"""
|
85 |
+
pad = fsize - fshift
|
86 |
+
if length % fshift == 0:
|
87 |
+
M = (length + pad * 2 - fsize) // fshift + 1
|
88 |
+
else:
|
89 |
+
M = (length + pad * 2 - fsize) // fshift + 2
|
90 |
+
return M
|
91 |
+
|
92 |
+
|
93 |
+
def pad_lr(x, fsize, fshift):
|
94 |
+
"""Compute left and right padding"""
|
95 |
+
M = num_frames(len(x), fsize, fshift)
|
96 |
+
pad = fsize - fshift
|
97 |
+
T = len(x) + 2 * pad
|
98 |
+
r = (M - 1) * fshift + fsize - T
|
99 |
+
return pad, pad + r
|
100 |
+
|
101 |
+
|
102 |
+
##########################################################
|
103 |
+
# Librosa correct padding
|
104 |
+
def librosa_pad_lr(x, fsize, fshift):
|
105 |
+
return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
|
106 |
+
|
107 |
+
|
108 |
+
# Conversions
|
109 |
+
_mel_basis = None
|
110 |
+
|
111 |
+
|
112 |
+
def _linear_to_mel(spectogram):
|
113 |
+
global _mel_basis
|
114 |
+
if _mel_basis is None:
|
115 |
+
_mel_basis = _build_mel_basis()
|
116 |
+
return np.dot(_mel_basis, spectogram)
|
117 |
+
|
118 |
+
|
119 |
+
def _build_mel_basis():
|
120 |
+
assert config.audio.fmax <= config.audio.sample_rate // 2
|
121 |
+
return librosa.filters.mel(
|
122 |
+
sr=config.audio.sample_rate,
|
123 |
+
n_fft=config.audio.n_fft,
|
124 |
+
n_mels=config.audio.num_mels,
|
125 |
+
fmin=config.audio.fmin,
|
126 |
+
fmax=config.audio.fmax,
|
127 |
+
)
|
128 |
+
|
129 |
+
|
130 |
+
def _amp_to_db(x):
|
131 |
+
min_level = np.exp(config.audio.min_level_db / 20 * np.log(10))
|
132 |
+
return 20 * np.log10(np.maximum(min_level, x))
|
133 |
+
|
134 |
+
|
135 |
+
def _db_to_amp(x):
|
136 |
+
return np.power(10.0, (x) * 0.05)
|
137 |
+
|
138 |
+
|
139 |
+
def _normalize(S):
|
140 |
+
if config.audio.allow_clipping_in_normalization:
|
141 |
+
if config.audio.symmetric_mels:
|
142 |
+
return np.clip(
|
143 |
+
(2 * config.audio.max_abs_value) * ((S - config.audio.min_level_db) / (-config.audio.min_level_db))
|
144 |
+
- config.audio.max_abs_value,
|
145 |
+
-config.audio.max_abs_value,
|
146 |
+
config.audio.max_abs_value,
|
147 |
+
)
|
148 |
+
else:
|
149 |
+
return np.clip(
|
150 |
+
config.audio.max_abs_value * ((S - config.audio.min_level_db) / (-config.audio.min_level_db)),
|
151 |
+
0,
|
152 |
+
config.audio.max_abs_value,
|
153 |
+
)
|
154 |
+
|
155 |
+
assert S.max() <= 0 and S.min() - config.audio.min_level_db >= 0
|
156 |
+
if config.audio.symmetric_mels:
|
157 |
+
return (2 * config.audio.max_abs_value) * (
|
158 |
+
(S - config.audio.min_level_db) / (-config.audio.min_level_db)
|
159 |
+
) - config.audio.max_abs_value
|
160 |
+
else:
|
161 |
+
return config.audio.max_abs_value * ((S - config.audio.min_level_db) / (-config.audio.min_level_db))
|
162 |
+
|
163 |
+
|
164 |
+
def _denormalize(D):
|
165 |
+
if config.audio.allow_clipping_in_normalization:
|
166 |
+
if config.audio.symmetric_mels:
|
167 |
+
return (
|
168 |
+
(np.clip(D, -config.audio.max_abs_value, config.audio.max_abs_value) + config.audio.max_abs_value)
|
169 |
+
* -config.audio.min_level_db
|
170 |
+
/ (2 * config.audio.max_abs_value)
|
171 |
+
) + config.audio.min_level_db
|
172 |
+
else:
|
173 |
+
return (
|
174 |
+
np.clip(D, 0, config.audio.max_abs_value) * -config.audio.min_level_db / config.audio.max_abs_value
|
175 |
+
) + config.audio.min_level_db
|
176 |
+
|
177 |
+
if config.audio.symmetric_mels:
|
178 |
+
return (
|
179 |
+
(D + config.audio.max_abs_value) * -config.audio.min_level_db / (2 * config.audio.max_abs_value)
|
180 |
+
) + config.audio.min_level_db
|
181 |
+
else:
|
182 |
+
return (D * -config.audio.min_level_db / config.audio.max_abs_value) + config.audio.min_level_db
|
183 |
+
|
184 |
+
|
185 |
+
def get_melspec_overlap(audio_samples, melspec_length=52):
|
186 |
+
mel_spec_overlap = melspectrogram(audio_samples.numpy())
|
187 |
+
mel_spec_overlap = torch.from_numpy(mel_spec_overlap)
|
188 |
+
i = 0
|
189 |
+
mel_spec_overlap_list = []
|
190 |
+
while i + melspec_length < mel_spec_overlap.shape[1] - 3:
|
191 |
+
mel_spec_overlap_list.append(mel_spec_overlap[:, i : i + melspec_length].unsqueeze(0))
|
192 |
+
i += 3
|
193 |
+
mel_spec_overlap = torch.stack(mel_spec_overlap_list)
|
194 |
+
return mel_spec_overlap
|
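A hedged usage sketch; this module loads configs/audio.yaml at import time, so it has to run from the repository root, and "speech.wav" is a placeholder file name:

wav = load_wav("speech.wav", sr=config.audio.sample_rate)
mel = melspectrogram(wav)     # ndarray of shape (num_mels, num_frames)
print(mel.shape)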
latentsync/utils/av_reader.py
ADDED
@@ -0,0 +1,157 @@
|
1 |
+
# We modified the original AVReader class of decord to fix a memory leak.
|
2 |
+
# For more details, refer to: https://github.com/dmlc/decord/issues/208
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
from decord.video_reader import VideoReader
|
6 |
+
from decord.audio_reader import AudioReader
|
7 |
+
|
8 |
+
from decord.ndarray import cpu
|
9 |
+
from decord import ndarray as _nd
|
10 |
+
from decord.bridge import bridge_out
|
11 |
+
|
12 |
+
|
13 |
+
class AVReader(object):
|
14 |
+
"""Individual audio video reader with convenient indexing function.
|
15 |
+
|
16 |
+
Parameters
|
17 |
+
----------
|
18 |
+
uri: str
|
19 |
+
Path of file.
|
20 |
+
ctx: decord.Context
|
21 |
+
The context to decode the file, can be decord.cpu() or decord.gpu().
|
22 |
+
sample_rate: int, default is -1
|
23 |
+
Desired output sample rate of the audio, unchanged if `-1` is specified.
|
24 |
+
mono: bool, default is True
|
25 |
+
Desired output channel layout of the audio. `True` is mono layout. `False` is unchanged.
|
26 |
+
width : int, default is -1
|
27 |
+
Desired output width of the video, unchanged if `-1` is specified.
|
28 |
+
height : int, default is -1
|
29 |
+
Desired output height of the video, unchanged if `-1` is specified.
|
30 |
+
num_threads : int, default is 0
|
31 |
+
Number of decoding thread, auto if `0` is specified.
|
32 |
+
fault_tol : int, default is -1
|
33 |
+
The threshold of corrupted and recovered frames. This is to prevent silent fault
|
34 |
+
tolerance when for example 50% frames of a video cannot be decoded and duplicate
|
35 |
+
frames are returned. You may find the fault tolerant feature sweet in many cases,
|
36 |
+
but not for training models. Say `N = # recovered frames`
|
37 |
+
If `fault_tol` < 0, nothing will happen.
|
38 |
+
If 0 < `fault_tol` < 1.0, if N > `fault_tol * len(video)`, raise `DECORDLimitReachedError`.
|
39 |
+
If 1 < `fault_tol`, if N > `fault_tol`, raise `DECORDLimitReachedError`.
|
40 |
+
"""
|
41 |
+
|
42 |
+
def __init__(
|
43 |
+
self, uri, ctx=cpu(0), sample_rate=44100, mono=True, width=-1, height=-1, num_threads=0, fault_tol=-1
|
44 |
+
):
|
45 |
+
self.__audio_reader = AudioReader(uri, ctx, sample_rate, mono)
|
46 |
+
self.__audio_reader.add_padding()
|
47 |
+
if hasattr(uri, "read"):
|
48 |
+
uri.seek(0)
|
49 |
+
self.__video_reader = VideoReader(uri, ctx, width, height, num_threads, fault_tol)
|
50 |
+
self.__video_reader.seek(0)
|
51 |
+
|
52 |
+
def __len__(self):
|
53 |
+
"""Get length of the video. Note that sometimes FFMPEG reports inaccurate number of frames,
|
54 |
+
we always follow what FFMPEG reports.
|
55 |
+
Returns
|
56 |
+
-------
|
57 |
+
int
|
58 |
+
The number of frames in the video file.
|
59 |
+
"""
|
60 |
+
return len(self.__video_reader)
|
61 |
+
|
62 |
+
def __getitem__(self, idx):
|
63 |
+
"""Get audio samples and video frame at `idx`.
|
64 |
+
|
65 |
+
Parameters
|
66 |
+
----------
|
67 |
+
idx : int or slice
|
68 |
+
The frame index, can be negative which means it will index backwards,
|
69 |
+
or slice of frame indices.
|
70 |
+
|
71 |
+
Returns
|
72 |
+
-------
|
73 |
+
(ndarray/list of ndarray, ndarray)
|
74 |
+
First element is samples of shape CxS or a list of length N containing samples of shape CxS,
|
75 |
+
where N is the number of frames, C is the number of channels,
|
76 |
+
S is the number of samples of the corresponding frame.
|
77 |
+
|
78 |
+
Second element is Frame of shape HxWx3 or batch of image frames with shape NxHxWx3,
|
79 |
+
where N is the length of the slice.
|
80 |
+
"""
|
81 |
+
assert self.__video_reader is not None and self.__audio_reader is not None
|
82 |
+
if isinstance(idx, slice):
|
83 |
+
return self.get_batch(range(*idx.indices(len(self.__video_reader))))
|
84 |
+
if idx < 0:
|
85 |
+
idx += len(self.__video_reader)
|
86 |
+
if idx >= len(self.__video_reader) or idx < 0:
|
87 |
+
raise IndexError("Index: {} out of bound: {}".format(idx, len(self.__video_reader)))
|
88 |
+
audio_start_idx, audio_end_idx = self.__video_reader.get_frame_timestamp(idx)
|
89 |
+
audio_start_idx = self.__audio_reader._time_to_sample(audio_start_idx)
|
90 |
+
audio_end_idx = self.__audio_reader._time_to_sample(audio_end_idx)
|
91 |
+
results = (self.__audio_reader[audio_start_idx:audio_end_idx], self.__video_reader[idx])
|
92 |
+
self.__video_reader.seek(0)
|
93 |
+
return results
|
94 |
+
|
95 |
+
def get_batch(self, indices):
|
96 |
+
"""Get entire batch of audio samples and video frames.
|
97 |
+
|
98 |
+
Parameters
|
99 |
+
----------
|
100 |
+
indices : list of integers
|
101 |
+
A list of frame indices. If negative indices detected, the indices will be indexed from backward
|
102 |
+
Returns
|
103 |
+
-------
|
104 |
+
(list of ndarray, ndarray)
|
105 |
+
First element is a list of length N containing samples of shape CxS,
|
106 |
+
where N is the number of frames, C is the number of channels,
|
107 |
+
S is the number of samples of the corresponding frame.
|
108 |
+
|
109 |
+
Second element is Frame of shape HxWx3 or batch of image frames with shape NxHxWx3,
|
110 |
+
where N is the length of the slice.
|
111 |
+
|
112 |
+
"""
|
113 |
+
assert self.__video_reader is not None and self.__audio_reader is not None
|
114 |
+
indices = self._validate_indices(indices)
|
115 |
+
audio_arr = []
|
116 |
+
prev_video_idx = None
|
117 |
+
prev_audio_end_idx = None
|
118 |
+
for idx in list(indices):
|
119 |
+
frame_start_time, frame_end_time = self.__video_reader.get_frame_timestamp(idx)
|
120 |
+
# timestamp and sample conversion could have some error that could cause non-continuous audio
|
121 |
+
# we detect if retrieving continuous frame and make the audio continuous
|
122 |
+
if prev_video_idx and idx == prev_video_idx + 1:
|
123 |
+
audio_start_idx = prev_audio_end_idx
|
124 |
+
else:
|
125 |
+
audio_start_idx = self.__audio_reader._time_to_sample(frame_start_time)
|
126 |
+
audio_end_idx = self.__audio_reader._time_to_sample(frame_end_time)
|
127 |
+
audio_arr.append(self.__audio_reader[audio_start_idx:audio_end_idx])
|
128 |
+
prev_video_idx = idx
|
129 |
+
prev_audio_end_idx = audio_end_idx
|
130 |
+
results = (audio_arr, self.__video_reader.get_batch(indices))
|
131 |
+
self.__video_reader.seek(0)
|
132 |
+
return results
|
133 |
+
|
134 |
+
def _get_slice(self, sl):
|
135 |
+
audio_arr = np.empty(shape=(self.__audio_reader.shape()[0], 0), dtype="float32")
|
136 |
+
for idx in list(sl):
|
137 |
+
audio_start_idx, audio_end_idx = self.__video_reader.get_frame_timestamp(idx)
|
138 |
+
audio_start_idx = self.__audio_reader._time_to_sample(audio_start_idx)
|
139 |
+
audio_end_idx = self.__audio_reader._time_to_sample(audio_end_idx)
|
140 |
+
audio_arr = np.concatenate(
|
141 |
+
(audio_arr, self.__audio_reader[audio_start_idx:audio_end_idx].asnumpy()), axis=1
|
142 |
+
)
|
143 |
+
results = (bridge_out(_nd.array(audio_arr)), self.__video_reader.get_batch(sl))
|
144 |
+
self.__video_reader.seek(0)
|
145 |
+
return results
|
146 |
+
|
147 |
+
def _validate_indices(self, indices):
|
148 |
+
"""Validate int64 integers and convert negative integers to positive by backward search"""
|
149 |
+
assert self.__video_reader is not None and self.__audio_reader is not None
|
150 |
+
indices = np.array(indices, dtype=np.int64)
|
151 |
+
# process negative indices
|
152 |
+
indices[indices < 0] += len(self.__video_reader)
|
153 |
+
if not (indices >= 0).all():
|
154 |
+
raise IndexError("Invalid negative indices: {}".format(indices[indices < 0] + len(self.__video_reader)))
|
155 |
+
if not (indices < len(self.__video_reader)).all():
|
156 |
+
raise IndexError("Out of bound indices: {}".format(indices[indices >= len(self.__video_reader)]))
|
157 |
+
return indices
|
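A hedged usage sketch of the patched AVReader; "example.mp4" is a placeholder path:

av = AVReader("example.mp4", sample_rate=16000)
audio_chunks, frames = av.get_batch(range(16))   # per-frame audio chunks + 16 video frames
print(len(audio_chunks), frames.shape)           # 16, (16, H, W, 3)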
latentsync/utils/image_processor.py
ADDED
@@ -0,0 +1,342 @@
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from torchvision import transforms
|
16 |
+
import cv2
|
17 |
+
from einops import rearrange
|
18 |
+
import mediapipe as mp
|
19 |
+
import torch
|
20 |
+
import numpy as np
|
21 |
+
from typing import Union
|
22 |
+
from .affine_transform import AlignRestore, laplacianSmooth
|
23 |
+
import face_alignment
|
24 |
+
|
25 |
+
"""
|
26 |
+
If you are enlarging the image, you should prefer to use INTER_LINEAR or INTER_CUBIC interpolation. If you are shrinking the image, you should prefer to use INTER_AREA interpolation.
|
27 |
+
https://stackoverflow.com/questions/23853632/which-kind-of-interpolation-best-for-resizing-image
|
28 |
+
"""
|
29 |
+
|
30 |
+
|
31 |
+
def load_fixed_mask(resolution: int) -> torch.Tensor:
|
32 |
+
mask_image = cv2.imread("latentsync/utils/mask.png")
|
33 |
+
mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB)
|
34 |
+
mask_image = cv2.resize(mask_image, (resolution, resolution), interpolation=cv2.INTER_AREA) / 255.0
|
35 |
+
mask_image = rearrange(torch.from_numpy(mask_image), "h w c -> c h w")
|
36 |
+
return mask_image
|
37 |
+
|
38 |
+
|
39 |
+
class ImageProcessor:
|
40 |
+
def __init__(self, resolution: int = 512, mask: str = "fix_mask", device: str = "cpu", mask_image=None):
|
41 |
+
self.resolution = resolution
|
42 |
+
self.resize = transforms.Resize(
|
43 |
+
(resolution, resolution), interpolation=transforms.InterpolationMode.BILINEAR, antialias=True
|
44 |
+
)
|
45 |
+
self.normalize = transforms.Normalize([0.5], [0.5], inplace=True)
|
46 |
+
self.mask = mask
|
47 |
+
|
48 |
+
if mask in ["mouth", "face", "eye"]:
|
49 |
+
self.face_mesh = mp.solutions.face_mesh.FaceMesh(static_image_mode=True) # Process single image
|
50 |
+
if mask == "fix_mask":
|
51 |
+
self.face_mesh = None
|
52 |
+
self.smoother = laplacianSmooth()
|
53 |
+
self.restorer = AlignRestore()
|
54 |
+
|
55 |
+
if mask_image is None:
|
56 |
+
self.mask_image = load_fixed_mask(resolution)
|
57 |
+
else:
|
58 |
+
self.mask_image = mask_image
|
59 |
+
|
60 |
+
if device != "cpu":
|
61 |
+
self.fa = face_alignment.FaceAlignment(
|
62 |
+
face_alignment.LandmarksType.TWO_D, flip_input=False, device=device
|
63 |
+
)
|
64 |
+
self.face_mesh = None
|
65 |
+
else:
|
66 |
+
# self.face_mesh = mp.solutions.face_mesh.FaceMesh(static_image_mode=True) # Process single image
|
67 |
+
self.face_mesh = None
|
68 |
+
self.fa = None
|
69 |
+
|
70 |
+
def detect_facial_landmarks(self, image: np.ndarray):
|
71 |
+
height, width, _ = image.shape
|
72 |
+
results = self.face_mesh.process(image)
|
73 |
+
if not results.multi_face_landmarks: # Face not detected
|
74 |
+
raise RuntimeError("Face not detected")
|
75 |
+
face_landmarks = results.multi_face_landmarks[0] # Only use the first face in the image
|
76 |
+
landmark_coordinates = [
|
77 |
+
(int(landmark.x * width), int(landmark.y * height)) for landmark in face_landmarks.landmark
|
78 |
+
] # x means width, y means height
|
79 |
+
return landmark_coordinates
|
80 |
+
|
81 |
+
def preprocess_one_masked_image(self, image: torch.Tensor) -> np.ndarray:
|
82 |
+
image = self.resize(image)
|
83 |
+
|
84 |
+
if self.mask == "mouth" or self.mask == "face":
|
85 |
+
landmark_coordinates = self.detect_facial_landmarks(image)
|
86 |
+
if self.mask == "mouth":
|
87 |
+
surround_landmarks = mouth_surround_landmarks
|
88 |
+
else:
|
89 |
+
surround_landmarks = face_surround_landmarks
|
90 |
+
|
91 |
+
points = [landmark_coordinates[landmark] for landmark in surround_landmarks]
|
92 |
+
points = np.array(points)
|
93 |
+
mask = np.ones((self.resolution, self.resolution))
|
94 |
+
mask = cv2.fillPoly(mask, pts=[points], color=(0, 0, 0))
|
95 |
+
mask = torch.from_numpy(mask)
|
96 |
+
mask = mask.unsqueeze(0)
|
97 |
+
elif self.mask == "half":
|
98 |
+
mask = torch.ones((self.resolution, self.resolution))
|
99 |
+
height = mask.shape[0]
|
100 |
+
mask[height // 2 :, :] = 0
|
101 |
+
mask = mask.unsqueeze(0)
|
102 |
+
elif self.mask == "eye":
|
103 |
+
mask = torch.ones((self.resolution, self.resolution))
|
104 |
+
landmark_coordinates = self.detect_facial_landmarks(image)
|
105 |
+
y = landmark_coordinates[195][1]
|
106 |
+
mask[y:, :] = 0
|
107 |
+
mask = mask.unsqueeze(0)
|
108 |
+
else:
|
109 |
+
raise ValueError("Invalid mask type")
|
110 |
+
|
111 |
+
image = image.to(dtype=torch.float32)
|
112 |
+
pixel_values = self.normalize(image / 255.0)
|
113 |
+
masked_pixel_values = pixel_values * mask
|
114 |
+
mask = 1 - mask
|
115 |
+
|
116 |
+
return pixel_values, masked_pixel_values, mask
|
117 |
+
|
118 |
+
def affine_transform(self, image: torch.Tensor) -> np.ndarray:
|
119 |
+
# image = rearrange(image, "c h w-> h w c").numpy()
|
120 |
+
if self.fa is None:
|
121 |
+
landmark_coordinates = np.array(self.detect_facial_landmarks(image))
|
122 |
+
lm68 = mediapipe_lm478_to_face_alignment_lm68(landmark_coordinates)
|
123 |
+
else:
|
124 |
+
detected_faces = self.fa.get_landmarks(image)
|
125 |
+
if detected_faces is None:
|
126 |
+
raise RuntimeError("Face not detected")
|
127 |
+
lm68 = detected_faces[0]
|
128 |
+
|
129 |
+
points = self.smoother.smooth(lm68)
|
130 |
+
lmk3_ = np.zeros((3, 2))
|
131 |
+
lmk3_[0] = points[17:22].mean(0)
|
132 |
+
lmk3_[1] = points[22:27].mean(0)
|
133 |
+
lmk3_[2] = points[27:36].mean(0)
|
134 |
+
# print(lmk3_)
|
135 |
+
face, affine_matrix = self.restorer.align_warp_face(
|
136 |
+
image.copy(), lmks3=lmk3_, smooth=True, border_mode="constant"
|
137 |
+
)
|
138 |
+
box = [0, 0, face.shape[1], face.shape[0]] # x1, y1, x2, y2
|
139 |
+
face = cv2.resize(face, (self.resolution, self.resolution), interpolation=cv2.INTER_CUBIC)
|
140 |
+
face = rearrange(torch.from_numpy(face), "h w c -> c h w")
|
141 |
+
return face, box, affine_matrix
|
142 |
+
|
143 |
+
def preprocess_fixed_mask_image(self, image: torch.Tensor, affine_transform=False):
|
144 |
+
if affine_transform:
|
145 |
+
image, _, _ = self.affine_transform(image)
|
146 |
+
else:
|
147 |
+
image = self.resize(image)
|
148 |
+
pixel_values = self.normalize(image / 255.0)
|
149 |
+
masked_pixel_values = pixel_values * self.mask_image
|
150 |
+
return pixel_values, masked_pixel_values, self.mask_image[0:1]
|
151 |
+
|
152 |
+
def prepare_masks_and_masked_images(self, images: Union[torch.Tensor, np.ndarray], affine_transform=False):
|
153 |
+
if isinstance(images, np.ndarray):
|
154 |
+
images = torch.from_numpy(images)
|
155 |
+
if images.shape[3] == 3:
|
156 |
+
images = rearrange(images, "b h w c -> b c h w")
|
157 |
+
if self.mask == "fix_mask":
|
158 |
+
results = [self.preprocess_fixed_mask_image(image, affine_transform=affine_transform) for image in images]
|
159 |
+
else:
|
160 |
+
results = [self.preprocess_one_masked_image(image) for image in images]
|
161 |
+
|
162 |
+
pixel_values_list, masked_pixel_values_list, masks_list = list(zip(*results))
|
163 |
+
return torch.stack(pixel_values_list), torch.stack(masked_pixel_values_list), torch.stack(masks_list)
|
164 |
+
|
165 |
+
def process_images(self, images: Union[torch.Tensor, np.ndarray]):
|
166 |
+
if isinstance(images, np.ndarray):
|
167 |
+
images = torch.from_numpy(images)
|
168 |
+
if images.shape[3] == 3:
|
169 |
+
images = rearrange(images, "b h w c -> b c h w")
|
170 |
+
images = self.resize(images)
|
171 |
+
pixel_values = self.normalize(images / 255.0)
|
172 |
+
return pixel_values
|
173 |
+
|
174 |
+
def close(self):
|
175 |
+
if self.face_mesh is not None:
|
176 |
+
self.face_mesh.close()
|
177 |
+
|
178 |
+
|
179 |
+
def mediapipe_lm478_to_face_alignment_lm68(lm478, return_2d=True):
|
180 |
+
"""
|
181 |
+
lm478: [B, 478, 3] or [478,3]
|
182 |
+
"""
|
183 |
+
# lm478[..., 0] *= W
|
184 |
+
# lm478[..., 1] *= H
|
185 |
+
landmarks_extracted = []
|
186 |
+
for index in landmark_points_68:
|
187 |
+
x = lm478[index][0]
|
188 |
+
y = lm478[index][1]
|
189 |
+
landmarks_extracted.append((x, y))
|
190 |
+
return np.array(landmarks_extracted)
|
191 |
+
|
192 |
+
|
193 |
+
landmark_points_68 = [
|
194 |
+
162,
|
195 |
+
234,
|
196 |
+
93,
|
197 |
+
58,
|
198 |
+
172,
|
199 |
+
136,
|
200 |
+
149,
|
201 |
+
148,
|
202 |
+
152,
|
203 |
+
377,
|
204 |
+
378,
|
205 |
+
365,
|
206 |
+
397,
|
207 |
+
288,
|
208 |
+
323,
|
209 |
+
454,
|
210 |
+
389,
|
211 |
+
71,
|
212 |
+
63,
|
213 |
+
105,
|
214 |
+
66,
|
215 |
+
107,
|
216 |
+
336,
|
217 |
+
296,
|
218 |
+
334,
|
219 |
+
293,
|
220 |
+
301,
|
221 |
+
168,
|
222 |
+
197,
|
223 |
+
5,
|
224 |
+
4,
|
225 |
+
75,
|
226 |
+
97,
|
227 |
+
2,
|
228 |
+
326,
|
229 |
+
305,
|
230 |
+
33,
|
231 |
+
160,
|
232 |
+
158,
|
233 |
+
133,
|
234 |
+
153,
|
235 |
+
144,
|
236 |
+
362,
|
237 |
+
385,
|
238 |
+
387,
|
239 |
+
263,
|
240 |
+
373,
|
241 |
+
380,
|
242 |
+
61,
|
243 |
+
39,
|
244 |
+
37,
|
245 |
+
0,
|
246 |
+
267,
|
247 |
+
269,
|
248 |
+
291,
|
249 |
+
405,
|
250 |
+
314,
|
251 |
+
17,
|
252 |
+
84,
|
253 |
+
181,
|
254 |
+
78,
|
255 |
+
82,
|
256 |
+
13,
|
257 |
+
312,
|
258 |
+
308,
|
259 |
+
317,
|
260 |
+
14,
|
261 |
+
87,
|
262 |
+
]
|
263 |
+
|
264 |
+
|
265 |
+
# Refer to https://storage.googleapis.com/mediapipe-assets/documentation/mediapipe_face_landmark_fullsize.png
|
266 |
+
mouth_surround_landmarks = [
|
267 |
+
164,
|
268 |
+
165,
|
269 |
+
167,
|
270 |
+
92,
|
271 |
+
186,
|
272 |
+
57,
|
273 |
+
43,
|
274 |
+
106,
|
275 |
+
182,
|
276 |
+
83,
|
277 |
+
18,
|
278 |
+
313,
|
279 |
+
406,
|
280 |
+
335,
|
281 |
+
273,
|
282 |
+
287,
|
283 |
+
410,
|
284 |
+
322,
|
285 |
+
391,
|
286 |
+
393,
|
287 |
+
]
|
288 |
+
|
289 |
+
face_surround_landmarks = [
|
290 |
+
152,
|
291 |
+
377,
|
292 |
+
400,
|
293 |
+
378,
|
294 |
+
379,
|
295 |
+
365,
|
296 |
+
397,
|
297 |
+
288,
|
298 |
+
435,
|
299 |
+
433,
|
300 |
+
411,
|
301 |
+
425,
|
302 |
+
423,
|
303 |
+
327,
|
304 |
+
326,
|
305 |
+
94,
|
306 |
+
97,
|
307 |
+
98,
|
308 |
+
203,
|
309 |
+
205,
|
310 |
+
187,
|
311 |
+
213,
|
312 |
+
215,
|
313 |
+
58,
|
314 |
+
172,
|
315 |
+
136,
|
316 |
+
150,
|
317 |
+
149,
|
318 |
+
176,
|
319 |
+
148,
|
320 |
+
]
|
321 |
+
|
322 |
+
if __name__ == "__main__":
|
323 |
+
image_processor = ImageProcessor(512, mask="fix_mask")
|
324 |
+
video = cv2.VideoCapture("/mnt/bn/maliva-gen-ai-v2/chunyu.li/HDTF/original/val/RD_Radio57_000.mp4")
|
325 |
+
while True:
|
326 |
+
ret, frame = video.read()
|
327 |
+
# if not ret:
|
328 |
+
# break
|
329 |
+
|
330 |
+
# cv2.imwrite("image.jpg", frame)
|
331 |
+
|
332 |
+
frame = rearrange(torch.Tensor(frame).type(torch.uint8), "h w c -> c h w")
|
333 |
+
# face, masked_face, _ = image_processor.preprocess_fixed_mask_image(frame, affine_transform=True)
|
334 |
+
face, _, _ = image_processor.affine_transform(frame)
|
335 |
+
|
336 |
+
break
|
337 |
+
|
338 |
+
face = (rearrange(face, "c h w -> h w c").detach().cpu().numpy()).astype(np.uint8)
|
339 |
+
cv2.imwrite("face.jpg", face)
|
340 |
+
|
341 |
+
# masked_face = (rearrange(masked_face, "c h w -> h w c").detach().cpu().numpy()).astype(np.uint8)
|
342 |
+
# cv2.imwrite("masked_face.jpg", masked_face)
|
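A hedged usage sketch of ImageProcessor for preparing masked training inputs; the zero array stands in for real THWC uint8 frames (e.g. from read_video), the script must run from the repository root so that latentsync/utils/mask.png resolves, and device="cpu" skips the face_alignment model, so affine_transform must stay False here:

import numpy as np

processor = ImageProcessor(resolution=256, mask="fix_mask", device="cpu")
frames = np.zeros((16, 480, 640, 3), dtype=np.uint8)   # placeholder THWC frames
pixel_values, masked_pixel_values, masks = processor.prepare_masks_and_masked_images(
    frames, affine_transform=False
)
print(pixel_values.shape, masked_pixel_values.shape, masks.shape)
processor.close()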
latentsync/utils/mask.png
ADDED
latentsync/utils/util.py
ADDED
@@ -0,0 +1,365 @@
|
1 |
+
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
import imageio
|
17 |
+
import numpy as np
|
18 |
+
import json
|
19 |
+
from typing import Union
|
20 |
+
import matplotlib.pyplot as plt
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import torch.nn as nn
|
24 |
+
import torch.nn.functional as F
|
25 |
+
import torchvision
|
26 |
+
import torch.distributed as dist
|
27 |
+
from torchvision import transforms
|
28 |
+
|
29 |
+
from tqdm import tqdm
|
30 |
+
from einops import rearrange
|
31 |
+
import cv2
|
32 |
+
from decord import AudioReader, VideoReader
|
33 |
+
import shutil
|
34 |
+
import subprocess
|
35 |
+
|
36 |
+
|
37 |
+
# Machine epsilon for a float32 (single precision)
|
38 |
+
eps = np.finfo(np.float32).eps
|
39 |
+
|
40 |
+
|
41 |
+
def read_json(filepath: str):
|
42 |
+
with open(filepath) as f:
|
43 |
+
json_dict = json.load(f)
|
44 |
+
return json_dict
|
45 |
+
|
46 |
+
|
47 |
+
def read_video(video_path: str, change_fps=True, use_decord=True):
|
48 |
+
if change_fps:
|
49 |
+
temp_dir = "temp"
|
50 |
+
if os.path.exists(temp_dir):
|
51 |
+
shutil.rmtree(temp_dir)
|
52 |
+
os.makedirs(temp_dir, exist_ok=True)
|
53 |
+
command = (
|
54 |
+
f"ffmpeg -loglevel error -y -nostdin -i {video_path} -r 25 -crf 18 {os.path.join(temp_dir, 'video.mp4')}"
|
55 |
+
)
|
56 |
+
subprocess.run(command, shell=True)
|
57 |
+
target_video_path = os.path.join(temp_dir, "video.mp4")
|
58 |
+
else:
|
59 |
+
target_video_path = video_path
|
60 |
+
|
61 |
+
if use_decord:
|
62 |
+
return read_video_decord(target_video_path)
|
63 |
+
else:
|
64 |
+
return read_video_cv2(target_video_path)
|
65 |
+
|
66 |
+
|
67 |
+
def read_video_decord(video_path: str):
|
68 |
+
vr = VideoReader(video_path)
|
69 |
+
video_frames = vr[:].asnumpy()
|
70 |
+
vr.seek(0)
|
71 |
+
return video_frames
|
72 |
+
|
73 |
+
|
74 |
+
def read_video_cv2(video_path: str):
|
75 |
+
# Open the video file
|
76 |
+
cap = cv2.VideoCapture(video_path)
|
77 |
+
|
78 |
+
# Check if the video was opened successfully
|
79 |
+
if not cap.isOpened():
|
80 |
+
print("Error: Could not open video.")
|
81 |
+
return np.array([])
|
82 |
+
|
83 |
+
frames = []
|
84 |
+
|
85 |
+
while True:
|
86 |
+
# Read a frame
|
87 |
+
ret, frame = cap.read()
|
88 |
+
|
89 |
+
# If frame is read correctly ret is True
|
90 |
+
if not ret:
|
91 |
+
break
|
92 |
+
|
93 |
+
# Convert BGR to RGB
|
94 |
+
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
95 |
+
|
96 |
+
frames.append(frame_rgb)
|
97 |
+
|
98 |
+
# Release the video capture object
|
99 |
+
cap.release()
|
100 |
+
|
101 |
+
return np.array(frames)
|
102 |
+
|
103 |
+
|
104 |
+
def read_audio(audio_path: str, audio_sample_rate: int = 16000):
|
105 |
+
if audio_path is None:
|
106 |
+
raise ValueError("Audio path is required.")
|
107 |
+
ar = AudioReader(audio_path, sample_rate=audio_sample_rate, mono=True)
|
108 |
+
|
109 |
+
# To access the audio samples
|
110 |
+
audio_samples = torch.from_numpy(ar[:].asnumpy())
|
111 |
+
audio_samples = audio_samples.squeeze(0)
|
112 |
+
|
113 |
+
return audio_samples
|
114 |
+
|
115 |
+
|
116 |
+
def write_video(video_output_path: str, video_frames: np.ndarray, fps: int):
|
117 |
+
height, width = video_frames[0].shape[:2]
|
118 |
+
out = cv2.VideoWriter(video_output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
|
119 |
+
# out = cv2.VideoWriter(video_output_path, cv2.VideoWriter_fourcc(*"vp09"), fps, (width, height))
|
120 |
+
for frame in video_frames:
|
121 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
122 |
+
out.write(frame)
|
123 |
+
out.release()
|
124 |
+
|
125 |
+
|
126 |
+
def init_dist(backend="nccl", **kwargs):
|
127 |
+
"""Initializes distributed environment."""
|
128 |
+
rank = int(os.environ["RANK"])
|
129 |
+
num_gpus = torch.cuda.device_count()
|
130 |
+
if num_gpus == 0:
|
131 |
+
raise RuntimeError("No GPUs available for training.")
|
132 |
+
local_rank = rank % num_gpus
|
133 |
+
torch.cuda.set_device(local_rank)
|
134 |
+
dist.init_process_group(backend=backend, **kwargs)
|
135 |
+
|
136 |
+
return local_rank
|
137 |
+
|
138 |
+
|
139 |
+
def zero_rank_print(s):
|
140 |
+
if dist.is_initialized() and dist.get_rank() == 0:
|
141 |
+
print("### " + s)
|
142 |
+
|
143 |
+
|
144 |
+
def zero_rank_log(logger, message: str):
|
145 |
+
if dist.is_initialized() and dist.get_rank() == 0:
|
146 |
+
logger.info(message)
|
147 |
+
|
148 |
+
|
149 |
+
def make_audio_window(audio_embeddings: torch.Tensor, window_size: int):
|
150 |
+
audio_window = []
|
151 |
+
end_idx = audio_embeddings.shape[1] - window_size + 1
|
152 |
+
for i in range(end_idx):
|
153 |
+
audio_window.append(audio_embeddings[:, i : i + window_size, :])
|
154 |
+
audio_window = torch.stack(audio_window)
|
155 |
+
audio_window = rearrange(audio_window, "f b w d -> b f w d")
|
156 |
+
return audio_window
|
157 |
+
|
158 |
+
|
159 |
+
def check_video_fps(video_path: str):
|
160 |
+
cam = cv2.VideoCapture(video_path)
|
161 |
+
fps = cam.get(cv2.CAP_PROP_FPS)
|
162 |
+
if fps != 25:
|
163 |
+
raise ValueError(f"Video FPS is not 25, it is {fps}. Please convert the video to 25 FPS.")
|
164 |
+
|
165 |
+
|
166 |
+
def tailor_tensor_to_length(tensor: torch.Tensor, length: int):
|
167 |
+
if len(tensor) == length:
|
168 |
+
return tensor
|
169 |
+
elif len(tensor) > length:
|
170 |
+
return tensor[:length]
|
171 |
+
else:
|
172 |
+
return torch.cat([tensor, tensor[-1].repeat(length - len(tensor))])
|
173 |
+
|
174 |
+
|
175 |
+
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
|
176 |
+
videos = rearrange(videos, "b c f h w -> f b c h w")
|
177 |
+
outputs = []
|
178 |
+
for x in videos:
|
179 |
+
x = torchvision.utils.make_grid(x, nrow=n_rows)
|
180 |
+
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
181 |
+
if rescale:
|
182 |
+
x = (x + 1.0) / 2.0 # -1,1 -> 0,1
|
183 |
+
x = (x * 255).numpy().astype(np.uint8)
|
184 |
+
outputs.append(x)
|
185 |
+
|
186 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
187 |
+
imageio.mimsave(path, outputs, fps=fps)
|
188 |
+
|
189 |
+
|
190 |
+
def interpolate_features(features: torch.Tensor, output_len: int) -> torch.Tensor:
|
191 |
+
features = features.cpu().numpy()
|
192 |
+
input_len, num_features = features.shape
|
193 |
+
|
194 |
+
input_timesteps = np.linspace(0, 10, input_len)
|
195 |
+
output_timesteps = np.linspace(0, 10, output_len)
|
196 |
+
output_features = np.zeros((output_len, num_features))
|
197 |
+
for feat in range(num_features):
|
198 |
+
output_features[:, feat] = np.interp(output_timesteps, input_timesteps, features[:, feat])
|
199 |
+
return torch.from_numpy(output_features)
|
200 |
+
|
201 |
+
|
202 |
+
# DDIM Inversion
|
203 |
+
@torch.no_grad()
|
204 |
+
def init_prompt(prompt, pipeline):
|
205 |
+
uncond_input = pipeline.tokenizer(
|
206 |
+
[""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, return_tensors="pt"
|
207 |
+
)
|
208 |
+
uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
|
209 |
+
text_input = pipeline.tokenizer(
|
210 |
+
[prompt],
|
211 |
+
padding="max_length",
|
212 |
+
max_length=pipeline.tokenizer.model_max_length,
|
213 |
+
truncation=True,
|
214 |
+
return_tensors="pt",
|
215 |
+
)
|
216 |
+
text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
|
217 |
+
context = torch.cat([uncond_embeddings, text_embeddings])
|
218 |
+
|
219 |
+
return context
|
220 |
+
|
221 |
+
|
222 |
+
def reversed_forward(ddim_scheduler, pred_noise, timesteps, x_t):
|
223 |
+
# Compute alphas, betas
|
224 |
+
alpha_prod_t = ddim_scheduler.alphas_cumprod[timesteps]
|
225 |
+
beta_prod_t = 1 - alpha_prod_t
|
226 |
+
|
227 |
+
# 3. compute predicted original sample from predicted noise also called
|
228 |
+
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
229 |
+
    if ddim_scheduler.config.prediction_type == "epsilon":
        beta_prod_t = beta_prod_t[:, None, None, None, None]
        alpha_prod_t = alpha_prod_t[:, None, None, None, None]
        pred_original_sample = (x_t - beta_prod_t ** (0.5) * pred_noise) / alpha_prod_t ** (0.5)
    else:
        raise NotImplementedError("This prediction type is not implemented yet")

    # Clip "predicted x_0"
    if ddim_scheduler.config.clip_sample:
        pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
    return pred_original_sample


def next_step(
    model_output: Union[torch.FloatTensor, np.ndarray],
    timestep: int,
    sample: Union[torch.FloatTensor, np.ndarray],
    ddim_scheduler,
):
    timestep, next_timestep = (
        min(timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999),
        timestep,
    )
    alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
    alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
    beta_prod_t = 1 - alpha_prod_t
    next_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
    next_sample = alpha_prod_t_next**0.5 * next_original_sample + next_sample_direction
    return next_sample


def get_noise_pred_single(latents, t, context, unet):
    noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
    return noise_pred


@torch.no_grad()
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
    context = init_prompt(prompt, pipeline)
    uncond_embeddings, cond_embeddings = context.chunk(2)
    all_latent = [latent]
    latent = latent.clone().detach()
    for i in tqdm(range(num_inv_steps)):
        t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
        noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
        latent = next_step(noise_pred, t, latent, ddim_scheduler)
        all_latent.append(latent)
    return all_latent


@torch.no_grad()
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
    ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
    return ddim_latents


def plot_loss_chart(save_path: str, *args):
    # Creating the plot
    plt.figure()
    for loss_line in args:
        plt.plot(loss_line[1], loss_line[2], label=loss_line[0])
    plt.xlabel("Step")
    plt.ylabel("Loss")
    plt.legend()

    # Save the figure to a file
    plt.savefig(save_path)

    # Close the figure to free memory
    plt.close()


CRED = "\033[91m"
CEND = "\033[0m"


def red_text(text: str):
    return f"{CRED}{text}{CEND}"


log_loss = nn.BCELoss(reduction="none")


def cosine_loss(vision_embeds, audio_embeds, y):
    sims = nn.functional.cosine_similarity(vision_embeds, audio_embeds)
    # sims[sims!=sims] = 0 # remove nan
    # sims = sims.clamp(0, 1)
    loss = log_loss(sims.unsqueeze(1), y).squeeze()
    return loss


def save_image(image, save_path):
    # input size (C, H, W)
    image = (image / 2 + 0.5).clamp(0, 1)
    image = (image * 255).to(torch.uint8)
    image = transforms.ToPILImage()(image)
    # Save the image copy
    image.save(save_path)

    # Close the image file
    image.close()


def gather_loss(loss, device):
    # Sum the local loss across all processes
    local_loss = loss.item()
    global_loss = torch.tensor(local_loss, dtype=torch.float32).to(device)
    dist.all_reduce(global_loss, op=dist.ReduceOp.SUM)

    # Calculate the average loss across all processes
    global_average_loss = global_loss.item() / dist.get_world_size()
    return global_average_loss


def gather_video_paths_recursively(input_dir):
    print(f"Recursively gathering video paths of {input_dir} ...")
    paths = []
    gather_video_paths(input_dir, paths)
    return paths


def gather_video_paths(input_dir, paths):
    for file in sorted(os.listdir(input_dir)):
        if file.endswith(".mp4"):
            filepath = os.path.join(input_dir, file)
            paths.append(filepath)
        elif os.path.isdir(os.path.join(input_dir, file)):
            gather_video_paths(os.path.join(input_dir, file), paths)


def count_video_time(video_path):
    video = cv2.VideoCapture(video_path)

    frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
    fps = video.get(cv2.CAP_PROP_FPS)
    return frame_count / fps
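Usage note: the DDIM helpers above (next_step, ddim_loop, ddim_inversion) run the scheduler backwards, mapping a clean latent to the noise that reproduces it. Below is a minimal sketch of next_step on dummy tensors, assuming a diffusers DDIMScheduler (the kind of scheduler object these helpers expect, with config.num_train_timesteps, num_inference_steps, alphas_cumprod, and final_alpha_cumprod); shapes and step count are illustrative, not the project's defaults.

import torch
from diffusers import DDIMScheduler  # assumed dependency; any scheduler with the attributes above would do
from latentsync.utils.util import next_step  # helper defined above

scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50)  # populates scheduler.timesteps and scheduler.num_inference_steps

sample = torch.randn(1, 4, 32, 32)       # dummy latent
model_output = torch.randn_like(sample)  # dummy predicted noise
t = int(scheduler.timesteps[-1])         # start inverting from the smallest timestep

inverted = next_step(model_output, t, sample, scheduler)
print(inverted.shape)  # same shape as the input sample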
latentsync/whisper/audio2feature.py
ADDED
@@ -0,0 +1,166 @@
# Adapted from https://github.com/TMElyralab/MuseTalk/blob/main/musetalk/whisper/audio2feature.py

from .whisper import load_model
import numpy as np
import torch
import os


class Audio2Feature:
    def __init__(
        self,
        model_path="checkpoints/whisper/tiny.pt",
        device=None,
        audio_embeds_cache_dir=None,
        num_frames=16,
    ):
        self.model = load_model(model_path, device)
        self.audio_embeds_cache_dir = audio_embeds_cache_dir
        self.num_frames = num_frames
        self.embedding_dim = self.model.dims.n_audio_state

    def get_sliced_feature(self, feature_array, vid_idx, audio_feat_length=[2, 2], fps=25):
        """
        Get sliced features based on a given index
        :param feature_array:
        :param start_idx: the start index of the feature
        :param audio_feat_length:
        :return:
        """
        length = len(feature_array)
        selected_feature = []
        selected_idx = []

        center_idx = int(vid_idx * 50 / fps)
        left_idx = center_idx - audio_feat_length[0] * 2
        right_idx = center_idx + (audio_feat_length[1] + 1) * 2

        for idx in range(left_idx, right_idx):
            idx = max(0, idx)
            idx = min(length - 1, idx)
            x = feature_array[idx]
            selected_feature.append(x)
            selected_idx.append(idx)

        selected_feature = torch.cat(selected_feature, dim=0)
        selected_feature = selected_feature.reshape(-1, self.embedding_dim)  # 50*384
        return selected_feature, selected_idx

    def get_sliced_feature_sparse(self, feature_array, vid_idx, audio_feat_length=[2, 2], fps=25):
        """
        Get sliced features based on a given index
        :param feature_array:
        :param start_idx: the start index of the feature
        :param audio_feat_length:
        :return:
        """
        length = len(feature_array)
        selected_feature = []
        selected_idx = []

        for dt in range(-audio_feat_length[0], audio_feat_length[1] + 1):
            left_idx = int((vid_idx + dt) * 50 / fps)
            if left_idx < 1 or left_idx > length - 1:
                left_idx = max(0, left_idx)
                left_idx = min(length - 1, left_idx)

                x = feature_array[left_idx]
                x = x[np.newaxis, :, :]
                x = np.repeat(x, 2, axis=0)
                selected_feature.append(x)
                selected_idx.append(left_idx)
                selected_idx.append(left_idx)
            else:
                x = feature_array[left_idx - 1 : left_idx + 1]
                selected_feature.append(x)
                selected_idx.append(left_idx - 1)
                selected_idx.append(left_idx)
        selected_feature = np.concatenate(selected_feature, axis=0)
        selected_feature = selected_feature.reshape(-1, self.embedding_dim)  # 50*384
        selected_feature = torch.from_numpy(selected_feature)
        return selected_feature, selected_idx

    def feature2chunks(self, feature_array, fps, audio_feat_length=[2, 2]):
        whisper_chunks = []
        whisper_idx_multiplier = 50.0 / fps
        i = 0
        print(f"video in {fps} FPS, audio idx in 50FPS")

        while True:
            start_idx = int(i * whisper_idx_multiplier)
            selected_feature, selected_idx = self.get_sliced_feature(
                feature_array=feature_array, vid_idx=i, audio_feat_length=audio_feat_length, fps=fps
            )
            # print(f"i:{i},selected_idx {selected_idx}")
            whisper_chunks.append(selected_feature)
            i += 1
            if start_idx > len(feature_array):
                break

        return whisper_chunks

    def _audio2feat(self, audio_path: str):
        # get the sample rate of the audio
        result = self.model.transcribe(audio_path)
        embed_list = []
        for emb in result["segments"]:
            encoder_embeddings = emb["encoder_embeddings"]
            encoder_embeddings = encoder_embeddings.transpose(0, 2, 1, 3)
            encoder_embeddings = encoder_embeddings.squeeze(0)
            start_idx = int(emb["start"])
            end_idx = int(emb["end"])
            emb_end_idx = int((end_idx - start_idx) / 2)
            embed_list.append(encoder_embeddings[:emb_end_idx])
        concatenated_array = torch.from_numpy(np.concatenate(embed_list, axis=0))
        return concatenated_array

    def audio2feat(self, audio_path):
        if self.audio_embeds_cache_dir == "" or self.audio_embeds_cache_dir is None:
            return self._audio2feat(audio_path)

        audio_embeds_cache_path = os.path.join(self.audio_embeds_cache_dir, os.path.basename(audio_path) + ".pt")

        if os.path.isfile(audio_embeds_cache_path):
            try:
                audio_feat = torch.load(audio_embeds_cache_path)
            except Exception as e:
                print(f"{type(e).__name__} - {e} - {audio_embeds_cache_path}")
                os.remove(audio_embeds_cache_path)
                audio_feat = self._audio2feat(audio_path)
                torch.save(audio_feat, audio_embeds_cache_path)
        else:
            audio_feat = self._audio2feat(audio_path)
            torch.save(audio_feat, audio_embeds_cache_path)

        return audio_feat

    def crop_overlap_audio_window(self, audio_feat, start_index):
        selected_feature_list = []
        for i in range(start_index, start_index + self.num_frames):
            selected_feature, selected_idx = self.get_sliced_feature(
                feature_array=audio_feat, vid_idx=i, audio_feat_length=[2, 2], fps=25
            )
            selected_feature_list.append(selected_feature)
        mel_overlap = torch.stack(selected_feature_list)
        return mel_overlap


if __name__ == "__main__":
    audio_encoder = Audio2Feature(model_path="checkpoints/whisper/tiny.pt")
    audio_path = "assets/demo1_audio.wav"
    array = audio_encoder.audio2feat(audio_path)
    print(array.shape)
    fps = 25
    whisper_idx_multiplier = 50.0 / fps

    i = 0
    print(f"video in {fps} FPS, audio idx in 50FPS")
    while True:
        start_idx = int(i * whisper_idx_multiplier)
        selected_feature, selected_idx = audio_encoder.get_sliced_feature(
            feature_array=array, vid_idx=i, audio_feat_length=[2, 2], fps=fps
        )
        print(f"video idx {i},\t audio idx {selected_idx},\t shape {selected_feature.shape}")
        i += 1
        if start_idx > len(array):
            break
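Usage note: beyond the __main__ demo above, a short sketch of slicing the full-clip whisper embeddings into overlapping per-frame windows with crop_overlap_audio_window; the checkpoint and audio paths are illustrative, and audio_embeds_cache_dir is optional (when set, embeddings are cached as .pt files in that directory).

from latentsync.whisper.audio2feature import Audio2Feature

audio_encoder = Audio2Feature(
    model_path="checkpoints/whisper/tiny.pt",  # illustrative checkpoint location
    audio_embeds_cache_dir=None,               # set a directory to cache embeddings per audio file
    num_frames=16,
)
audio_feat = audio_encoder.audio2feat("assets/demo1_audio.wav")              # embeddings for the whole clip
window = audio_encoder.crop_overlap_audio_window(audio_feat, start_index=0)  # window for frames 0..15
print(window.shape)  # one stack of sliced embeddings per video frame in the 16-frame window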
latentsync/whisper/whisper/__init__.py
ADDED
@@ -0,0 +1,119 @@
import hashlib
import io
import os
import urllib
import warnings
from typing import List, Optional, Union

import torch
from tqdm import tqdm

from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import Whisper, ModelDimensions
from .transcribe import transcribe


_MODELS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
    "large": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large.pt",
    "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
    "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
    "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
}


def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
    os.makedirs(root, exist_ok=True)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, os.path.basename(url))

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        model_bytes = open(download_target, "rb").read()
        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
            return model_bytes if in_memory else download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(
            total=int(source.info().get("Content-Length")), ncols=80, unit="iB", unit_scale=True, unit_divisor=1024
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    model_bytes = open(download_target, "rb").read()
    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
        )

    return model_bytes if in_memory else download_target


def available_models() -> List[str]:
    """Returns the names of available models"""
    return list(_MODELS.keys())


def load_model(
    name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False
) -> Whisper:
    """
    Load a Whisper ASR model

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into
    download_root: str
        path to download the model files; by default, it uses "~/.cache/whisper"
    in_memory: bool
        whether to preload the model weights into host memory

    Returns
    -------
    model : Whisper
        The Whisper ASR model instance
    """

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        download_root = os.getenv("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache", "whisper"))

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
    elif os.path.isfile(name):
        checkpoint_file = open(name, "rb").read() if in_memory else name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
        checkpoint = torch.load(fp, map_location=device)
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    return model.to(device)
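Usage note: a minimal sketch of loading one of the checkpoints listed in _MODELS and transcribing a clip; the model name and audio path are illustrative. In this modified copy of whisper, each segment of the transcription result also carries the encoder activations that Audio2Feature reads.

from latentsync.whisper.whisper import load_model

model = load_model("tiny")                           # downloads to ~/.cache/whisper unless download_root is given
result = model.transcribe("assets/demo1_audio.wav")  # illustrative path
print(list(result.keys()))                           # transcription output plus per-segment metadata
print(result["segments"][0]["encoder_embeddings"].shape)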
latentsync/whisper/whisper/__main__.py
ADDED
@@ -0,0 +1,4 @@
from .transcribe import cli


cli()
latentsync/whisper/whisper/assets/gpt2/merges.txt
ADDED
The diff for this file is too large to render.
latentsync/whisper/whisper/assets/gpt2/special_tokens_map.json
ADDED
@@ -0,0 +1 @@
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
latentsync/whisper/whisper/assets/gpt2/tokenizer_config.json
ADDED
@@ -0,0 +1 @@
{"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"}
latentsync/whisper/whisper/assets/gpt2/vocab.json
ADDED
The diff for this file is too large to render.
latentsync/whisper/whisper/assets/mel_filters.npz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dd2cc75e70e36fcbdd8ffbc2499062f30094093e6bf2cbafa9859f59972b420b
size 2048
latentsync/whisper/whisper/assets/multilingual/added_tokens.json
ADDED
@@ -0,0 +1 @@
{"<|endoftext|>": 50257}
latentsync/whisper/whisper/assets/multilingual/merges.txt
ADDED
The diff for this file is too large to render.
latentsync/whisper/whisper/assets/multilingual/special_tokens_map.json
ADDED
@@ -0,0 +1 @@
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
latentsync/whisper/whisper/assets/multilingual/tokenizer_config.json
ADDED
@@ -0,0 +1 @@
{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"}
latentsync/whisper/whisper/assets/multilingual/vocab.json
ADDED
The diff for this file is too large to render.
latentsync/whisper/whisper/audio.py
ADDED
@@ -0,0 +1,125 @@
import os
from functools import lru_cache
from typing import Union

import ffmpeg
import numpy as np
import torch
import torch.nn.functional as F

from .utils import exact_div

# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 400
N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000: number of frames in a mel spectrogram input


def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Open an audio file and read as mono waveform, resampling as necessary

    Parameters
    ----------
    file: str
        The audio file to open

    sr: int
        The sample rate to resample the audio if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.
    """
    try:
        # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
        # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
        out, _ = (
            ffmpeg.input(file, threads=0)
            .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
            .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
        )
    except ffmpeg.Error as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0


def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
    """
    if torch.is_tensor(array):
        if array.shape[axis] > length:
            array = array.index_select(dim=axis, index=torch.arange(length))

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
    else:
        if array.shape[axis] > length:
            array = array.take(indices=range(length), axis=axis)

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = np.pad(array, pad_widths)

    return array


@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
    """
    load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling librosa dependency; saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
        )
    """
    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
    with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)


def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
    """
    Compute the log-Mel spectrogram of

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz

    n_mels: int
        The number of Mel-frequency filters, only 80 is supported

    Returns
    -------
    torch.Tensor, shape = (80, n_frames)
        A Tensor that contains the Mel spectrogram
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)

    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)

    magnitudes = stft[:, :-1].abs() ** 2

    filters = mel_filters(audio.device, n_mels)
    mel_spec = filters @ magnitudes

    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec
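Usage note: a small sketch of the preprocessing chain above, assuming an ffmpeg binary on PATH and the ffmpeg-python package installed; the audio path is illustrative.

import torch
from latentsync.whisper.whisper.audio import load_audio, pad_or_trim, log_mel_spectrogram

waveform = load_audio("assets/demo1_audio.wav")     # float32 mono waveform at 16 kHz
waveform = pad_or_trim(torch.from_numpy(waveform))  # exactly CHUNK_LENGTH (30 s) of samples
mel = log_mel_spectrogram(waveform)                 # (80, 3000) normalized log-Mel features
print(mel.shape)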
latentsync/whisper/whisper/decoding.py
ADDED
@@ -0,0 +1,729 @@
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch import Tensor
|
8 |
+
from torch.distributions import Categorical
|
9 |
+
|
10 |
+
from .audio import CHUNK_LENGTH
|
11 |
+
from .tokenizer import Tokenizer, get_tokenizer
|
12 |
+
from .utils import compression_ratio
|
13 |
+
|
14 |
+
if TYPE_CHECKING:
|
15 |
+
from .model import Whisper
|
16 |
+
|
17 |
+
|
18 |
+
@torch.no_grad()
|
19 |
+
def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
|
20 |
+
"""
|
21 |
+
Detect the spoken language in the audio, and return them as list of strings, along with the ids
|
22 |
+
of the most probable language tokens and the probability distribution over all language tokens.
|
23 |
+
This is performed outside the main decode loop in order to not interfere with kv-caching.
|
24 |
+
|
25 |
+
Returns
|
26 |
+
-------
|
27 |
+
language_tokens : Tensor, shape = (n_audio,)
|
28 |
+
ids of the most probable language tokens, which appears after the startoftranscript token.
|
29 |
+
language_probs : List[Dict[str, float]], length = n_audio
|
30 |
+
list of dictionaries containing the probability distribution over all languages.
|
31 |
+
"""
|
32 |
+
if tokenizer is None:
|
33 |
+
tokenizer = get_tokenizer(model.is_multilingual)
|
34 |
+
if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
|
35 |
+
raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")
|
36 |
+
|
37 |
+
single = mel.ndim == 2
|
38 |
+
if single:
|
39 |
+
mel = mel.unsqueeze(0)
|
40 |
+
|
41 |
+
# skip encoder forward pass if already-encoded audio features were given
|
42 |
+
if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
|
43 |
+
mel = model.encoder(mel)
|
44 |
+
|
45 |
+
# forward pass using a single token, startoftranscript
|
46 |
+
n_audio = mel.shape[0]
|
47 |
+
x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
|
48 |
+
logits = model.logits(x, mel)[:, 0]
|
49 |
+
|
50 |
+
# collect detected languages; suppress all non-language tokens
|
51 |
+
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
52 |
+
mask[list(tokenizer.all_language_tokens)] = False
|
53 |
+
logits[:, mask] = -np.inf
|
54 |
+
language_tokens = logits.argmax(dim=-1)
|
55 |
+
language_token_probs = logits.softmax(dim=-1).cpu()
|
56 |
+
language_probs = [
|
57 |
+
{
|
58 |
+
c: language_token_probs[i, j].item()
|
59 |
+
for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
|
60 |
+
}
|
61 |
+
for i in range(n_audio)
|
62 |
+
]
|
63 |
+
|
64 |
+
if single:
|
65 |
+
language_tokens = language_tokens[0]
|
66 |
+
language_probs = language_probs[0]
|
67 |
+
|
68 |
+
return language_tokens, language_probs
|
69 |
+
|
70 |
+
|
71 |
+
@dataclass(frozen=True)
|
72 |
+
class DecodingOptions:
|
73 |
+
task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
|
74 |
+
language: Optional[str] = None # language that the audio is in; uses detected language if None
|
75 |
+
|
76 |
+
# sampling-related options
|
77 |
+
temperature: float = 0.0
|
78 |
+
sample_len: Optional[int] = None # maximum number of tokens to sample
|
79 |
+
best_of: Optional[int] = None # number of independent samples to collect, when t > 0
|
80 |
+
beam_size: Optional[int] = None # number of beams in beam search, when t == 0
|
81 |
+
patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424)
|
82 |
+
|
83 |
+
# options for ranking generations (either beams or best-of-N samples)
|
84 |
+
length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm
|
85 |
+
|
86 |
+
# prompt, prefix, and token suppression
|
87 |
+
prompt: Optional[Union[str, List[int]]] = None # text or tokens for the previous context
|
88 |
+
prefix: Optional[Union[str, List[int]]] = None # text or tokens to prefix the current context
|
89 |
+
suppress_blank: bool = True # this will suppress blank outputs
|
90 |
+
|
91 |
+
# list of tokens ids (or comma-separated token ids) to suppress
|
92 |
+
# "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
|
93 |
+
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
|
94 |
+
|
95 |
+
# timestamp sampling options
|
96 |
+
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
|
97 |
+
max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this
|
98 |
+
|
99 |
+
# implementation details
|
100 |
+
fp16: bool = True # use fp16 for most of the calculation
|
101 |
+
|
102 |
+
|
103 |
+
@dataclass(frozen=True)
|
104 |
+
class DecodingResult:
|
105 |
+
audio_features: Tensor
|
106 |
+
language: str
|
107 |
+
encoder_embeddings: np.ndarray
|
108 |
+
decoder_embeddings: np.ndarray
|
109 |
+
language_probs: Optional[Dict[str, float]] = None
|
110 |
+
tokens: List[int] = field(default_factory=list)
|
111 |
+
text: str = ""
|
112 |
+
avg_logprob: float = np.nan
|
113 |
+
no_speech_prob: float = np.nan
|
114 |
+
temperature: float = np.nan
|
115 |
+
compression_ratio: float = np.nan
|
116 |
+
|
117 |
+
|
118 |
+
class Inference:
|
119 |
+
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
|
120 |
+
"""Perform a forward pass on the decoder and return per-token logits"""
|
121 |
+
raise NotImplementedError
|
122 |
+
|
123 |
+
def rearrange_kv_cache(self, source_indices) -> None:
|
124 |
+
"""Update the key-value cache according to the updated beams"""
|
125 |
+
raise NotImplementedError
|
126 |
+
|
127 |
+
def cleanup_caching(self) -> None:
|
128 |
+
"""Clean up any resources or hooks after decoding is finished"""
|
129 |
+
pass
|
130 |
+
|
131 |
+
|
132 |
+
class PyTorchInference(Inference):
|
133 |
+
def __init__(self, model: "Whisper", initial_token_length: int):
|
134 |
+
self.model: "Whisper" = model
|
135 |
+
self.initial_token_length = initial_token_length
|
136 |
+
self.kv_cache = {}
|
137 |
+
self.hooks = []
|
138 |
+
|
139 |
+
def logits(self, tokens: Tensor, audio_features: Tensor, include_embeddings=False) -> Tensor:
|
140 |
+
if not self.kv_cache:
|
141 |
+
self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
|
142 |
+
|
143 |
+
if tokens.shape[-1] > self.initial_token_length:
|
144 |
+
# only need to use the last token except in the first forward pass
|
145 |
+
tokens = tokens[:, -1:]
|
146 |
+
|
147 |
+
return_val = self.model.decoder(tokens, audio_features,
|
148 |
+
kv_cache=self.kv_cache, include_embeddings=include_embeddings)
|
149 |
+
return return_val
|
150 |
+
|
151 |
+
def cleanup_caching(self):
|
152 |
+
for hook in self.hooks:
|
153 |
+
hook.remove()
|
154 |
+
|
155 |
+
self.kv_cache = {}
|
156 |
+
self.hooks = []
|
157 |
+
|
158 |
+
def rearrange_kv_cache(self, source_indices):
|
159 |
+
for module, tensor in self.kv_cache.items():
|
160 |
+
# update the key/value cache to contain the selected sequences
|
161 |
+
self.kv_cache[module] = tensor[source_indices].detach()
|
162 |
+
|
163 |
+
|
164 |
+
class SequenceRanker:
|
165 |
+
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
|
166 |
+
"""
|
167 |
+
Given a list of groups of samples and their cumulative log probabilities,
|
168 |
+
return the indices of the samples in each group to select as the final result
|
169 |
+
"""
|
170 |
+
raise NotImplementedError
|
171 |
+
|
172 |
+
|
173 |
+
class MaximumLikelihoodRanker(SequenceRanker):
|
174 |
+
"""
|
175 |
+
Select the sample with the highest log probabilities, penalized using either
|
176 |
+
a simple length normalization or Google NMT paper's length penalty
|
177 |
+
"""
|
178 |
+
|
179 |
+
def __init__(self, length_penalty: Optional[float]):
|
180 |
+
self.length_penalty = length_penalty
|
181 |
+
|
182 |
+
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
|
183 |
+
def scores(logprobs, lengths):
|
184 |
+
result = []
|
185 |
+
for logprob, length in zip(logprobs, lengths):
|
186 |
+
if self.length_penalty is None:
|
187 |
+
penalty = length
|
188 |
+
else:
|
189 |
+
# from the Google NMT paper
|
190 |
+
penalty = ((5 + length) / 6) ** self.length_penalty
|
191 |
+
result.append(logprob / penalty)
|
192 |
+
return result
|
193 |
+
|
194 |
+
# get the sequence with the highest score
|
195 |
+
lengths = [[len(t) for t in s] for s in tokens]
|
196 |
+
return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
|
197 |
+
|
198 |
+
|
199 |
+
class TokenDecoder:
|
200 |
+
def reset(self):
|
201 |
+
"""Initialize any stateful variables for decoding a new sequence"""
|
202 |
+
|
203 |
+
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
|
204 |
+
"""Specify how to select the next token, based on the current trace and logits
|
205 |
+
|
206 |
+
Parameters
|
207 |
+
----------
|
208 |
+
tokens : Tensor, shape = (n_batch, current_sequence_length)
|
209 |
+
all tokens in the context so far, including the prefix and sot_sequence tokens
|
210 |
+
|
211 |
+
logits : Tensor, shape = (n_batch, vocab_size)
|
212 |
+
per-token logits of the probability distribution at the current step
|
213 |
+
|
214 |
+
sum_logprobs : Tensor, shape = (n_batch)
|
215 |
+
cumulative log probabilities for each sequence
|
216 |
+
|
217 |
+
Returns
|
218 |
+
-------
|
219 |
+
tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
|
220 |
+
the tokens, appended with the selected next token
|
221 |
+
|
222 |
+
completed : bool
|
223 |
+
True if all sequences has reached the end of text
|
224 |
+
|
225 |
+
"""
|
226 |
+
raise NotImplementedError
|
227 |
+
|
228 |
+
def finalize(
|
229 |
+
self, tokens: Tensor, sum_logprobs: Tensor
|
230 |
+
) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
|
231 |
+
"""Finalize search and return the final candidate sequences
|
232 |
+
|
233 |
+
Parameters
|
234 |
+
----------
|
235 |
+
tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
|
236 |
+
all tokens in the context so far, including the prefix and sot_sequence
|
237 |
+
|
238 |
+
sum_logprobs : Tensor, shape = (n_audio, n_group)
|
239 |
+
cumulative log probabilities for each sequence
|
240 |
+
|
241 |
+
Returns
|
242 |
+
-------
|
243 |
+
tokens : Sequence[Sequence[Tensor]], length = n_audio
|
244 |
+
sequence of Tensors containing candidate token sequences, for each audio input
|
245 |
+
|
246 |
+
sum_logprobs : List[List[float]], length = n_audio
|
247 |
+
sequence of cumulative log probabilities corresponding to the above
|
248 |
+
|
249 |
+
"""
|
250 |
+
raise NotImplementedError
|
251 |
+
|
252 |
+
|
253 |
+
class GreedyDecoder(TokenDecoder):
|
254 |
+
def __init__(self, temperature: float, eot: int):
|
255 |
+
self.temperature = temperature
|
256 |
+
self.eot = eot
|
257 |
+
|
258 |
+
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
|
259 |
+
temperature = self.temperature
|
260 |
+
if temperature == 0:
|
261 |
+
next_tokens = logits.argmax(dim=-1)
|
262 |
+
else:
|
263 |
+
next_tokens = Categorical(logits=logits / temperature).sample()
|
264 |
+
|
265 |
+
logprobs = F.log_softmax(logits.float(), dim=-1)
|
266 |
+
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
|
267 |
+
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
|
268 |
+
|
269 |
+
next_tokens[tokens[:, -1] == self.eot] = self.eot
|
270 |
+
tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
|
271 |
+
|
272 |
+
completed = (tokens[:, -1] == self.eot).all()
|
273 |
+
return tokens, completed
|
274 |
+
|
275 |
+
def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
|
276 |
+
# make sure each sequence has at least one EOT token at the end
|
277 |
+
tokens = F.pad(tokens, (0, 1), value=self.eot)
|
278 |
+
return tokens, sum_logprobs.tolist()
|
279 |
+
|
280 |
+
|
281 |
+
class BeamSearchDecoder(TokenDecoder):
|
282 |
+
def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
|
283 |
+
self.beam_size = beam_size
|
284 |
+
self.eot = eot
|
285 |
+
self.inference = inference
|
286 |
+
self.patience = patience or 1.0
|
287 |
+
self.max_candidates: int = round(beam_size * self.patience)
|
288 |
+
self.finished_sequences = None
|
289 |
+
|
290 |
+
assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
|
291 |
+
|
292 |
+
def reset(self):
|
293 |
+
self.finished_sequences = None
|
294 |
+
|
295 |
+
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
|
296 |
+
if tokens.shape[0] % self.beam_size != 0:
|
297 |
+
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
|
298 |
+
|
299 |
+
n_audio = tokens.shape[0] // self.beam_size
|
300 |
+
if self.finished_sequences is None: # for the first update
|
301 |
+
self.finished_sequences = [{} for _ in range(n_audio)]
|
302 |
+
|
303 |
+
logprobs = F.log_softmax(logits.float(), dim=-1)
|
304 |
+
next_tokens, source_indices, finished_sequences = [], [], []
|
305 |
+
for i in range(n_audio):
|
306 |
+
scores, sources, finished = {}, {}, {}
|
307 |
+
|
308 |
+
# STEP 1: calculate the cumulative log probabilities for possible candidates
|
309 |
+
for j in range(self.beam_size):
|
310 |
+
idx = i * self.beam_size + j
|
311 |
+
prefix = tokens[idx].tolist()
|
312 |
+
for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
|
313 |
+
new_logprob = (sum_logprobs[idx] + logprob).item()
|
314 |
+
sequence = tuple(prefix + [token.item()])
|
315 |
+
scores[sequence] = new_logprob
|
316 |
+
sources[sequence] = idx
|
317 |
+
|
318 |
+
# STEP 2: rank the candidates and keep the top beam_size sequences for each audio
|
319 |
+
saved = 0
|
320 |
+
for sequence in sorted(scores, key=scores.get, reverse=True):
|
321 |
+
if sequence[-1] == self.eot:
|
322 |
+
finished[sequence] = scores[sequence]
|
323 |
+
else:
|
324 |
+
sum_logprobs[len(next_tokens)] = scores[sequence]
|
325 |
+
next_tokens.append(sequence)
|
326 |
+
source_indices.append(sources[sequence])
|
327 |
+
|
328 |
+
saved += 1
|
329 |
+
if saved == self.beam_size:
|
330 |
+
break
|
331 |
+
|
332 |
+
finished_sequences.append(finished)
|
333 |
+
|
334 |
+
tokens = torch.tensor(next_tokens, device=tokens.device)
|
335 |
+
self.inference.rearrange_kv_cache(source_indices)
|
336 |
+
|
337 |
+
# add newly finished sequences to self.finished_sequences
|
338 |
+
assert len(self.finished_sequences) == len(finished_sequences)
|
339 |
+
for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
|
340 |
+
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
|
341 |
+
if len(previously_finished) >= self.max_candidates:
|
342 |
+
break # the candidate list is full
|
343 |
+
previously_finished[seq] = newly_finished[seq]
|
344 |
+
|
345 |
+
# mark as completed if all audio has enough number of samples
|
346 |
+
completed = all(
|
347 |
+
len(sequences) >= self.max_candidates for sequences in self.finished_sequences
|
348 |
+
)
|
349 |
+
return tokens, completed
|
350 |
+
|
351 |
+
def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
|
352 |
+
# collect all finished sequences, including patience, and add unfinished ones if not enough
|
353 |
+
sum_logprobs = sum_logprobs.cpu()
|
354 |
+
for i, sequences in enumerate(self.finished_sequences):
|
355 |
+
if len(sequences) < self.beam_size: # when not enough sequences are finished
|
356 |
+
for j in list(np.argsort(sum_logprobs[i]))[::-1]:
|
357 |
+
sequence = preceding_tokens[i, j].tolist() + [self.eot]
|
358 |
+
sequences[tuple(sequence)] = sum_logprobs[i][j].item()
|
359 |
+
if len(sequences) >= self.beam_size:
|
360 |
+
break
|
361 |
+
|
362 |
+
tokens: List[List[Tensor]] = [
|
363 |
+
[torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
|
364 |
+
]
|
365 |
+
sum_logprobs: List[List[float]] = [
|
366 |
+
list(sequences.values()) for sequences in self.finished_sequences
|
367 |
+
]
|
368 |
+
return tokens, sum_logprobs
|
369 |
+
|
370 |
+
|
371 |
+
class LogitFilter:
|
372 |
+
def apply(self, logits: Tensor, tokens: Tensor) -> None:
|
373 |
+
"""Apply any filtering or masking to logits in-place
|
374 |
+
|
375 |
+
Parameters
|
376 |
+
----------
|
377 |
+
logits : Tensor, shape = (n_batch, vocab_size)
|
378 |
+
per-token logits of the probability distribution at the current step
|
379 |
+
|
380 |
+
tokens : Tensor, shape = (n_batch, current_sequence_length)
|
381 |
+
all tokens in the context so far, including the prefix and sot_sequence tokens
|
382 |
+
|
383 |
+
"""
|
384 |
+
raise NotImplementedError
|
385 |
+
|
386 |
+
|
387 |
+
class SuppressBlank(LogitFilter):
|
388 |
+
def __init__(self, tokenizer: Tokenizer, sample_begin: int):
|
389 |
+
self.tokenizer = tokenizer
|
390 |
+
self.sample_begin = sample_begin
|
391 |
+
|
392 |
+
def apply(self, logits: Tensor, tokens: Tensor):
|
393 |
+
if tokens.shape[1] == self.sample_begin:
|
394 |
+
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
|
395 |
+
|
396 |
+
|
397 |
+
class SuppressTokens(LogitFilter):
|
398 |
+
def __init__(self, suppress_tokens: Sequence[int]):
|
399 |
+
self.suppress_tokens = list(suppress_tokens)
|
400 |
+
|
401 |
+
def apply(self, logits: Tensor, tokens: Tensor):
|
402 |
+
logits[:, self.suppress_tokens] = -np.inf
|
403 |
+
|
404 |
+
|
405 |
+
class ApplyTimestampRules(LogitFilter):
|
406 |
+
def __init__(
|
407 |
+
self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
|
408 |
+
):
|
409 |
+
self.tokenizer = tokenizer
|
410 |
+
self.sample_begin = sample_begin
|
411 |
+
self.max_initial_timestamp_index = max_initial_timestamp_index
|
412 |
+
|
413 |
+
def apply(self, logits: Tensor, tokens: Tensor):
|
414 |
+
# suppress <|notimestamps|> which is handled by without_timestamps
|
415 |
+
if self.tokenizer.no_timestamps is not None:
|
416 |
+
logits[:, self.tokenizer.no_timestamps] = -np.inf
|
417 |
+
|
418 |
+
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
|
419 |
+
for k in range(tokens.shape[0]):
|
420 |
+
seq = [t for t in tokens[k, self.sample_begin :].tolist()]
|
421 |
+
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
|
422 |
+
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
|
423 |
+
|
424 |
+
if last_was_timestamp:
|
425 |
+
if penultimate_was_timestamp: # has to be non-timestamp
|
426 |
+
logits[k, self.tokenizer.timestamp_begin :] = -np.inf
|
427 |
+
else: # cannot be normal text tokens
|
428 |
+
logits[k, : self.tokenizer.eot] = -np.inf
|
429 |
+
|
430 |
+
# apply the `max_initial_timestamp` option
|
431 |
+
if tokens.shape[1] == self.sample_begin and self.max_initial_timestamp_index is not None:
|
432 |
+
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
|
433 |
+
logits[:, last_allowed + 1 :] = -np.inf
|
434 |
+
|
435 |
+
# if sum of probability over timestamps is above any other token, sample timestamp
|
436 |
+
logprobs = F.log_softmax(logits.float(), dim=-1)
|
437 |
+
for k in range(tokens.shape[0]):
|
438 |
+
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
|
439 |
+
max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
|
440 |
+
if timestamp_logprob > max_text_token_logprob:
|
441 |
+
logits[k, : self.tokenizer.timestamp_begin] = -np.inf
|
442 |
+
|
443 |
+
|
444 |
+
class DecodingTask:
|
445 |
+
inference: Inference
|
446 |
+
sequence_ranker: SequenceRanker
|
447 |
+
decoder: TokenDecoder
|
448 |
+
logit_filters: List[LogitFilter]
|
449 |
+
|
450 |
+
def __init__(self, model: "Whisper", options: DecodingOptions):
|
451 |
+
self.model = model
|
452 |
+
|
453 |
+
language = options.language or "en"
|
454 |
+
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
|
455 |
+
self.tokenizer: Tokenizer = tokenizer
|
456 |
+
self.options: DecodingOptions = self._verify_options(options)
|
457 |
+
|
458 |
+
self.n_group: int = options.beam_size or options.best_of or 1
|
459 |
+
self.n_ctx: int = model.dims.n_text_ctx
|
460 |
+
self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
|
461 |
+
|
462 |
+
self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
|
463 |
+
if self.options.without_timestamps:
|
464 |
+
self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
|
465 |
+
|
466 |
+
self.initial_tokens: Tuple[int] = self._get_initial_tokens()
|
467 |
+
self.sample_begin: int = len(self.initial_tokens)
|
468 |
+
self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
|
469 |
+
|
470 |
+
# inference: implements the forward pass through the decoder, including kv caching
|
471 |
+
self.inference = PyTorchInference(model, len(self.initial_tokens))
|
472 |
+
|
473 |
+
# sequence ranker: implements how to rank a group of sampled sequences
|
474 |
+
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
|
475 |
+
|
476 |
+
# decoder: implements how to select the next tokens, given the autoregressive distribution
|
477 |
+
if options.beam_size is not None:
|
478 |
+
self.decoder = BeamSearchDecoder(
|
479 |
+
options.beam_size, tokenizer.eot, self.inference, options.patience
|
480 |
+
)
|
481 |
+
else:
|
482 |
+
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
|
483 |
+
|
484 |
+
# logit filters: applies various rules to suppress or penalize certain tokens
|
485 |
+
self.logit_filters = []
|
486 |
+
if self.options.suppress_blank:
|
487 |
+
self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
|
488 |
+
if self.options.suppress_tokens:
|
489 |
+
self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
|
490 |
+
if not options.without_timestamps:
|
491 |
+
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
|
492 |
+
max_initial_timestamp_index = None
|
493 |
+
if options.max_initial_timestamp:
|
494 |
+
max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
|
495 |
+
self.logit_filters.append(
|
496 |
+
ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
|
497 |
+
)
|
498 |
+
|
499 |
+
def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
|
500 |
+
if options.beam_size is not None and options.best_of is not None:
|
501 |
+
raise ValueError("beam_size and best_of can't be given together")
|
502 |
+
if options.temperature == 0:
|
503 |
+
if options.best_of is not None:
|
504 |
+
raise ValueError("best_of with greedy sampling (T=0) is not compatible")
|
505 |
+
if options.patience is not None and options.beam_size is None:
|
506 |
+
raise ValueError("patience requires beam_size to be given")
|
507 |
+
if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
|
508 |
+
raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
|
509 |
+
|
510 |
+
return options
|
511 |
+
|
512 |
+
def _get_initial_tokens(self) -> Tuple[int]:
|
513 |
+
tokens = list(self.sot_sequence)
|
514 |
+
prefix = self.options.prefix
|
515 |
+
prompt = self.options.prompt
|
516 |
+
|
517 |
+
if prefix:
|
518 |
+
prefix_tokens = (
|
519 |
+
self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
|
520 |
+
)
|
521 |
+
if self.sample_len is not None:
|
522 |
+
max_prefix_len = self.n_ctx // 2 - self.sample_len
|
523 |
+
prefix_tokens = prefix_tokens[-max_prefix_len:]
|
524 |
+
tokens = tokens + prefix_tokens
|
525 |
+
|
526 |
+
if prompt:
|
527 |
+
prompt_tokens = (
|
528 |
+
self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
|
529 |
+
)
|
530 |
+
tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
|
531 |
+
|
532 |
+
return tuple(tokens)
|
533 |
+
|
534 |
+
def _get_suppress_tokens(self) -> Tuple[int]:
|
535 |
+
suppress_tokens = self.options.suppress_tokens
|
536 |
+
|
537 |
+
if isinstance(suppress_tokens, str):
|
538 |
+
suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
|
539 |
+
|
540 |
+
if -1 in suppress_tokens:
|
541 |
+
suppress_tokens = [t for t in suppress_tokens if t >= 0]
|
542 |
+
suppress_tokens.extend(self.tokenizer.non_speech_tokens)
|
543 |
+
elif suppress_tokens is None or len(suppress_tokens) == 0:
|
544 |
+
suppress_tokens = [] # interpret empty string as an empty list
|
545 |
+
else:
|
546 |
+
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
|
547 |
+
|
548 |
+
suppress_tokens.extend(
|
549 |
+
[self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
|
550 |
+
)
|
551 |
+
if self.tokenizer.no_speech is not None:
|
552 |
+
# no-speech probability is collected separately
|
553 |
+
suppress_tokens.append(self.tokenizer.no_speech)
|
554 |
+
|
555 |
+
return tuple(sorted(set(suppress_tokens)))
|
556 |
+
|
557 |
+
def _get_audio_features(self, mel: Tensor, include_embeddings: bool = False):
|
558 |
+
if self.options.fp16:
|
559 |
+
mel = mel.half()
|
560 |
+
|
561 |
+
if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
|
562 |
+
# encoded audio features are given; skip audio encoding
|
563 |
+
audio_features = mel
|
564 |
+
else:
|
565 |
+
result = self.model.encoder(mel, include_embeddings)
|
566 |
+
if include_embeddings:
|
567 |
+
audio_features, embeddings = result
|
568 |
+
else:
|
569 |
+
audio_features = result
|
570 |
+
|
571 |
+
if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
|
572 |
+
return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
|
573 |
+
|
574 |
+
if include_embeddings:
|
575 |
+
return audio_features, embeddings
|
576 |
+
else:
|
577 |
+
return audio_features
|
578 |
+
|
579 |
+
def _detect_language(self, audio_features: Tensor, tokens: Tensor):
|
580 |
+
languages = [self.options.language] * audio_features.shape[0]
|
581 |
+
lang_probs = None
|
582 |
+
|
583 |
+
if self.options.language is None or self.options.task == "lang_id":
|
584 |
+
lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
|
585 |
+
languages = [max(probs, key=probs.get) for probs in lang_probs]
|
586 |
+
if self.options.language is None:
|
587 |
+
tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
|
588 |
+
|
589 |
+
return languages, lang_probs
|
590 |
+
|
591 |
+
def _main_loop(self, audio_features: Tensor, tokens: Tensor):
|
592 |
+
assert audio_features.shape[0] == tokens.shape[0]
|
593 |
+
n_batch = tokens.shape[0]
|
594 |
+
sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
|
595 |
+
        no_speech_probs = [np.nan] * n_batch

        try:
            embeddings = []
            for i in range(self.sample_len):
                logits, token_embeddings = self.inference.logits(tokens, audio_features, include_embeddings=True)

                if i == 0 and self.tokenizer.no_speech is not None:  # save no_speech_probs
                    probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

                # now we need to consider the logits at the last token only
                logits = logits[:, -1]
                token_embeddings = token_embeddings[:, :, -1]

                # Append embeddings together
                embeddings.append(token_embeddings)

                # apply the logit filters, e.g. for suppressing or applying penalty to
                for logit_filter in self.logit_filters:
                    logit_filter.apply(logits, tokens)

                # expand the tokens tensor with the selected next tokens
                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)

                if completed or tokens.shape[-1] > self.n_ctx:
                    break
        finally:
            if completed:
                embeddings = embeddings[:-1]
            embeddings = np.stack(embeddings, 2)
            self.inference.cleanup_caching()

        return tokens, sum_logprobs, no_speech_probs, embeddings

    @torch.no_grad()
    def run(self, mel: Tensor) -> List[DecodingResult]:
        self.decoder.reset()
        tokenizer: Tokenizer = self.tokenizer
        n_audio: int = mel.shape[0]

        # encoder forward pass
        forward_pass: Tuple[Tensor, np.ndarray] = self._get_audio_features(mel, include_embeddings=True)
        audio_features, encoder_embeddings = forward_pass
        tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)

        # detect language if requested, overwriting the language token
        languages, language_probs = self._detect_language(audio_features, tokens)
        if self.options.task == "lang_id":
            return [
                DecodingResult(audio_features=features, language=language, language_probs=probs)
                for features, language, probs in zip(audio_features, languages, language_probs)
            ]

        # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
        audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
        tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)

        # call the main sampling loop
        tokens, sum_logprobs, no_speech_probs, decoder_embeddings = self._main_loop(audio_features, tokens)

        # reshape the tensors to have (n_audio, n_group) as the first two dimensions
        audio_features = audio_features[:: self.n_group]
        no_speech_probs = no_speech_probs[:: self.n_group]
        assert audio_features.shape[0] == len(no_speech_probs) == n_audio

        tokens = tokens.reshape(n_audio, self.n_group, -1)
        sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)

        # get the final candidates for each group, and slice between the first sampled token and EOT
        tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
        tokens: List[List[Tensor]] = [
            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
        ]

        # select the top-ranked sample in each group
        selected = self.sequence_ranker.rank(tokens, sum_logprobs)
        tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
        texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]

        sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
        avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]

        fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
        if len(set(map(len, fields))) != 1:
            raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")

        return [
            DecodingResult(
                audio_features=features,
                language=language,
                tokens=tokens,
                text=text,
                avg_logprob=avg_logprob,
                no_speech_prob=no_speech_prob,
                temperature=self.options.temperature,
                compression_ratio=compression_ratio(text),
                encoder_embeddings=encoder_embeddings,
                decoder_embeddings=decoder_embeddings
            )
            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
        ]


@torch.no_grad()
def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
    """
    Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).

    Parameters
    ----------
    model: Whisper
        the Whisper model instance

    mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
        A tensor containing the Mel spectrogram(s)

    options: DecodingOptions
        A dataclass that contains all necessary options for decoding 30-second segments

    Returns
    -------
    result: Union[DecodingResult, List[DecodingResult]]
        The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
    """
    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    result = DecodingTask(model, options).run(mel)

    if single:
        result = result[0]

    return result
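A minimal usage sketch for the `decode` entry point above (not part of the diff; it assumes this vendored package exposes the same `load_model`, `load_audio`, `pad_or_trim`, and `log_mel_spectrogram` helpers as upstream Whisper):

# Hedged sketch: the helper names below are assumed to match upstream Whisper.
import torch
from latentsync.whisper.whisper import load_model
from latentsync.whisper.whisper.audio import load_audio, pad_or_trim, log_mel_spectrogram
from latentsync.whisper.whisper.decoding import DecodingOptions, decode

model = load_model("tiny")                              # a Whisper nn.Module
audio = pad_or_trim(load_audio("sample.wav"))           # one 30-second window
mel = log_mel_spectrogram(audio).to(model.device)       # shape (80, 3000)

result = decode(model, mel, DecodingOptions(task="transcribe", fp16=False))
print(result.text)
print(result.decoder_embeddings.shape)                  # per-layer decoder embeddings added by this fork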
latentsync/whisper/whisper/model.py
ADDED
@@ -0,0 +1,290 @@
from dataclasses import dataclass
from typing import Dict
from typing import Iterable, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn

from .transcribe import transcribe as transcribe_function
from .decoding import detect_language as detect_language_function, decode as decode_function


@dataclass
class ModelDimensions:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int


class LayerNorm(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:
        return super().forward(x.float()).type(x.dtype)


class Linear(nn.Linear):
    def forward(self, x: Tensor) -> Tensor:
        return F.linear(
            x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
        )


class Conv1d(nn.Conv1d):
    def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
        return super()._conv_forward(
            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
        )


def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)


class MultiHeadAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        q = self.query(x)

        if kv_cache is None or xa is None:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache.get(self.key, self.key(xa))
            v = kv_cache.get(self.value, self.value(xa))

        wv = self.qkv_attention(q, k, v, mask)
        return self.out(wv)

    def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]

        w = F.softmax(qk.float(), dim=-1).to(q.dtype)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
        x = x + self.mlp(self.mlp_ln(x))
        return x


class AudioEncoder(nn.Module):
    def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)

    def forward(self, x: Tensor, include_embeddings: bool = False):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        include_embeddings: bool
            whether to include intermediate steps in the output
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        x = (x + self.positional_embedding).to(x.dtype)

        if include_embeddings:
            embeddings = [x.cpu().detach().numpy()]

        for block in self.blocks:
            x = block(x)
            if include_embeddings:
                embeddings.append(x.cpu().detach().numpy())

        x = self.ln_post(x)

        if include_embeddings:
            embeddings = np.stack(embeddings, axis=1)
            return x, embeddings
        else:
            return x


class TextDecoder(nn.Module):
    def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
        )
        self.ln = LayerNorm(n_state)

        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None, include_embeddings: bool = False):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
            the encoded audio features to be attended on
        include_embeddings : bool
            Whether to include intermediate values in the output to this function
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
        x = x.to(xa.dtype)

        if include_embeddings:
            embeddings = [x.cpu().detach().numpy()]

        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
            if include_embeddings:
                embeddings.append(x.cpu().detach().numpy())

        x = self.ln(x)
        logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()

        if include_embeddings:
            embeddings = np.stack(embeddings, axis=1)
            return logits, embeddings
        else:
            return logits


class Whisper(nn.Module):
    def __init__(self, dims: ModelDimensions):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )

    def embed_audio(self, mel: torch.Tensor):
        return self.encoder.forward(mel)

    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
        return self.decoder.forward(tokens, audio_features)

    def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
        return self.decoder(tokens, self.encoder(mel))

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        return self.dims.n_vocab == 51865

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
        tensors calculated for the previous positions. This method returns a dictionary that stores
        all caches, and the necessary hooks for the key and value projection modules that save the
        intermediate tensors to be reused during later calculations.

        Returns
        -------
        cache : Dict[nn.Module, torch.Tensor]
            A dictionary object mapping the key/value projection modules to its cache
        hooks : List[RemovableHandle]
            List of PyTorch RemovableHandle objects to stop the hooks to be called
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]:
                cache[module] = output  # save as-is, for the first token or cross attention
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function
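A short sketch (not part of the diff) of how the pieces above fit together: build a Whisper from ModelDimensions, encode one mel window, and run a single decoder step with the kv-cache hooks installed. The dimensions are illustrative (roughly the "tiny" configuration) and the weights are random, so the outputs are meaningless; the point is the call pattern.

import torch

dims = ModelDimensions(
    n_mels=80, n_audio_ctx=1500, n_audio_state=384, n_audio_head=6, n_audio_layer=4,
    n_vocab=51865, n_text_ctx=448, n_text_state=384, n_text_head=6, n_text_layer=4,
)
model = Whisper(dims).eval()

# 3000 mel frames -> 1500 audio positions after the stride-2 conv in AudioEncoder
mel = torch.randn(1, dims.n_mels, 2 * dims.n_audio_ctx)
audio_features = model.embed_audio(mel)          # (1, n_audio_ctx, n_audio_state)

# the forward hooks store each step's key/value projections in `cache`
cache, hooks = model.install_kv_cache_hooks()
tokens = torch.tensor([[50258]])                 # start-of-transcript token id in the multilingual vocab
logits = model.decoder(tokens, audio_features, kv_cache=cache)   # (1, 1, n_vocab)
for hook in hooks:
    hook.remove()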
latentsync/whisper/whisper/normalizers/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .basic import BasicTextNormalizer
from .english import EnglishTextNormalizer
latentsync/whisper/whisper/normalizers/basic.py
ADDED
@@ -0,0 +1,71 @@
import re
import unicodedata

import regex

# non-ASCII letters that are not separated by "NFKD" normalization
ADDITIONAL_DIACRITICS = {
    "œ": "oe",
    "Œ": "OE",
    "ø": "o",
    "Ø": "O",
    "æ": "ae",
    "Æ": "AE",
    "ß": "ss",
    "ẞ": "SS",
    "đ": "d",
    "Đ": "D",
    "ð": "d",
    "Ð": "D",
    "þ": "th",
    "Þ": "th",
    "ł": "l",
    "Ł": "L",
}


def remove_symbols_and_diacritics(s: str, keep=""):
    """
    Replace any other markers, symbols, and punctuations with a space,
    and drop any diacritics (category 'Mn' and some manual mappings)
    """
    return "".join(
        c
        if c in keep
        else ADDITIONAL_DIACRITICS[c]
        if c in ADDITIONAL_DIACRITICS
        else ""
        if unicodedata.category(c) == "Mn"
        else " "
        if unicodedata.category(c)[0] in "MSP"
        else c
        for c in unicodedata.normalize("NFKD", s)
    )


def remove_symbols(s: str):
    """
    Replace any other markers, symbols, punctuations with a space, keeping diacritics
    """
    return "".join(
        " " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s)
    )


class BasicTextNormalizer:
    def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
        self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols
        self.split_letters = split_letters

    def __call__(self, s: str):
        s = s.lower()
        s = re.sub(r"[<\[][^>\]]*[>\]]", "", s)  # remove words between brackets
        s = re.sub(r"\(([^)]+?)\)", "", s)  # remove words between parenthesis
        s = self.clean(s).lower()

        if self.split_letters:
            s = " ".join(regex.findall(r"\X", s, regex.U))

        s = re.sub(r"\s+", " ", s)  # replace any successive whitespace characters with a space

        return s
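A quick sketch (not from the diff) of how BasicTextNormalizer is typically applied when comparing transcripts, e.g. for WER-style evaluation:

normalizer = BasicTextNormalizer(remove_diacritics=True)
# lowercases, strips accents and punctuation, and drops bracketed/parenthesised spans
print(normalizer("Héllo, [noise] (laughs) Wörld!"))   # roughly "hello world"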
latentsync/whisper/whisper/normalizers/english.json
ADDED
@@ -0,0 +1,1742 @@
1 |
+
{
|
2 |
+
"accessorise": "accessorize",
|
3 |
+
"accessorised": "accessorized",
|
4 |
+
"accessorises": "accessorizes",
|
5 |
+
"accessorising": "accessorizing",
|
6 |
+
"acclimatisation": "acclimatization",
|
7 |
+
"acclimatise": "acclimatize",
|
8 |
+
"acclimatised": "acclimatized",
|
9 |
+
"acclimatises": "acclimatizes",
|
10 |
+
"acclimatising": "acclimatizing",
|
11 |
+
"accoutrements": "accouterments",
|
12 |
+
"aeon": "eon",
|
13 |
+
"aeons": "eons",
|
14 |
+
"aerogramme": "aerogram",
|
15 |
+
"aerogrammes": "aerograms",
|
16 |
+
"aeroplane": "airplane",
|
17 |
+
"aeroplanes": "airplanes",
|
18 |
+
"aesthete": "esthete",
|
19 |
+
"aesthetes": "esthetes",
|
20 |
+
"aesthetic": "esthetic",
|
21 |
+
"aesthetically": "esthetically",
|
22 |
+
"aesthetics": "esthetics",
|
23 |
+
"aetiology": "etiology",
|
24 |
+
"ageing": "aging",
|
25 |
+
"aggrandisement": "aggrandizement",
|
26 |
+
"agonise": "agonize",
|
27 |
+
"agonised": "agonized",
|
28 |
+
"agonises": "agonizes",
|
29 |
+
"agonising": "agonizing",
|
30 |
+
"agonisingly": "agonizingly",
|
31 |
+
"almanack": "almanac",
|
32 |
+
"almanacks": "almanacs",
|
33 |
+
"aluminium": "aluminum",
|
34 |
+
"amortisable": "amortizable",
|
35 |
+
"amortisation": "amortization",
|
36 |
+
"amortisations": "amortizations",
|
37 |
+
"amortise": "amortize",
|
38 |
+
"amortised": "amortized",
|
39 |
+
"amortises": "amortizes",
|
40 |
+
"amortising": "amortizing",
|
41 |
+
"amphitheatre": "amphitheater",
|
42 |
+
"amphitheatres": "amphitheaters",
|
43 |
+
"anaemia": "anemia",
|
44 |
+
"anaemic": "anemic",
|
45 |
+
"anaesthesia": "anesthesia",
|
46 |
+
"anaesthetic": "anesthetic",
|
47 |
+
"anaesthetics": "anesthetics",
|
48 |
+
"anaesthetise": "anesthetize",
|
49 |
+
"anaesthetised": "anesthetized",
|
50 |
+
"anaesthetises": "anesthetizes",
|
51 |
+
"anaesthetising": "anesthetizing",
|
52 |
+
"anaesthetist": "anesthetist",
|
53 |
+
"anaesthetists": "anesthetists",
|
54 |
+
"anaesthetize": "anesthetize",
|
55 |
+
"anaesthetized": "anesthetized",
|
56 |
+
"anaesthetizes": "anesthetizes",
|
57 |
+
"anaesthetizing": "anesthetizing",
|
58 |
+
"analogue": "analog",
|
59 |
+
"analogues": "analogs",
|
60 |
+
"analyse": "analyze",
|
61 |
+
"analysed": "analyzed",
|
62 |
+
"analyses": "analyzes",
|
63 |
+
"analysing": "analyzing",
|
64 |
+
"anglicise": "anglicize",
|
65 |
+
"anglicised": "anglicized",
|
66 |
+
"anglicises": "anglicizes",
|
67 |
+
"anglicising": "anglicizing",
|
68 |
+
"annualised": "annualized",
|
69 |
+
"antagonise": "antagonize",
|
70 |
+
"antagonised": "antagonized",
|
71 |
+
"antagonises": "antagonizes",
|
72 |
+
"antagonising": "antagonizing",
|
73 |
+
"apologise": "apologize",
|
74 |
+
"apologised": "apologized",
|
75 |
+
"apologises": "apologizes",
|
76 |
+
"apologising": "apologizing",
|
77 |
+
"appal": "appall",
|
78 |
+
"appals": "appalls",
|
79 |
+
"appetiser": "appetizer",
|
80 |
+
"appetisers": "appetizers",
|
81 |
+
"appetising": "appetizing",
|
82 |
+
"appetisingly": "appetizingly",
|
83 |
+
"arbour": "arbor",
|
84 |
+
"arbours": "arbors",
|
85 |
+
"archeological": "archaeological",
|
86 |
+
"archaeologically": "archeologically",
|
87 |
+
"archaeologist": "archeologist",
|
88 |
+
"archaeologists": "archeologists",
|
89 |
+
"archaeology": "archeology</span>",
|
90 |
+
"ardour": "ardor",
|
91 |
+
"armour": "armor",
|
92 |
+
"armoured": "armored",
|
93 |
+
"armourer": "armorer",
|
94 |
+
"armourers": "armorers",
|
95 |
+
"armouries": "armories",
|
96 |
+
"armoury": "armory",
|
97 |
+
"artefact": "artifact",
|
98 |
+
"artefacts": "artifacts",
|
99 |
+
"authorise": "authorize",
|
100 |
+
"authorised": "authorized",
|
101 |
+
"authorises": "authorizes",
|
102 |
+
"authorising": "authorizing",
|
103 |
+
"axe": "ax",
|
104 |
+
"backpedalled": "backpedaled",
|
105 |
+
"backpedalling": "backpedaling",
|
106 |
+
"bannister": "banister",
|
107 |
+
"bannisters": "banisters",
|
108 |
+
"baptise": "baptize",
|
109 |
+
"baptised": "baptized",
|
110 |
+
"baptises": "baptizes",
|
111 |
+
"baptising": "baptizing",
|
112 |
+
"bastardise": "bastardize",
|
113 |
+
"bastardised": "bastardized",
|
114 |
+
"bastardises": "bastardizes",
|
115 |
+
"bastardising": "bastardizing",
|
116 |
+
"battleax": "battleaxe",
|
117 |
+
"baulk": "balk",
|
118 |
+
"baulked": "balked",
|
119 |
+
"baulking": "balking",
|
120 |
+
"baulks": "balks",
|
121 |
+
"bedevilled": "bedeviled",
|
122 |
+
"bedevilling": "bedeviling",
|
123 |
+
"behaviour": "behavior",
|
124 |
+
"behavioural": "behavioral",
|
125 |
+
"behaviourism": "behaviorism",
|
126 |
+
"behaviourist": "behaviorist",
|
127 |
+
"behaviourists": "behaviorists",
|
128 |
+
"behaviours": "behaviors",
|
129 |
+
"behove": "behoove",
|
130 |
+
"behoved": "behooved",
|
131 |
+
"behoves": "behooves",
|
132 |
+
"bejewelled": "bejeweled",
|
133 |
+
"belabour": "belabor",
|
134 |
+
"belaboured": "belabored",
|
135 |
+
"belabouring": "belaboring",
|
136 |
+
"belabours": "belabors",
|
137 |
+
"bevelled": "beveled",
|
138 |
+
"bevvies": "bevies",
|
139 |
+
"bevvy": "bevy",
|
140 |
+
"biassed": "biased",
|
141 |
+
"biassing": "biasing",
|
142 |
+
"bingeing": "binging",
|
143 |
+
"bougainvillaea": "bougainvillea",
|
144 |
+
"bougainvillaeas": "bougainvilleas",
|
145 |
+
"bowdlerise": "bowdlerize",
|
146 |
+
"bowdlerised": "bowdlerized",
|
147 |
+
"bowdlerises": "bowdlerizes",
|
148 |
+
"bowdlerising": "bowdlerizing",
|
149 |
+
"breathalyse": "breathalyze",
|
150 |
+
"breathalysed": "breathalyzed",
|
151 |
+
"breathalyser": "breathalyzer",
|
152 |
+
"breathalysers": "breathalyzers",
|
153 |
+
"breathalyses": "breathalyzes",
|
154 |
+
"breathalysing": "breathalyzing",
|
155 |
+
"brutalise": "brutalize",
|
156 |
+
"brutalised": "brutalized",
|
157 |
+
"brutalises": "brutalizes",
|
158 |
+
"brutalising": "brutalizing",
|
159 |
+
"busses": "buses",
|
160 |
+
"bussing": "busing",
|
161 |
+
"caesarean": "cesarean",
|
162 |
+
"caesareans": "cesareans",
|
163 |
+
"calibre": "caliber",
|
164 |
+
"calibres": "calibers",
|
165 |
+
"calliper": "caliper",
|
166 |
+
"callipers": "calipers",
|
167 |
+
"callisthenics": "calisthenics",
|
168 |
+
"canalise": "canalize",
|
169 |
+
"canalised": "canalized",
|
170 |
+
"canalises": "canalizes",
|
171 |
+
"canalising": "canalizing",
|
172 |
+
"cancelation": "cancellation",
|
173 |
+
"cancelations": "cancellations",
|
174 |
+
"cancelled": "canceled",
|
175 |
+
"cancelling": "canceling",
|
176 |
+
"candour": "candor",
|
177 |
+
"cannibalise": "cannibalize",
|
178 |
+
"cannibalised": "cannibalized",
|
179 |
+
"cannibalises": "cannibalizes",
|
180 |
+
"cannibalising": "cannibalizing",
|
181 |
+
"canonise": "canonize",
|
182 |
+
"canonised": "canonized",
|
183 |
+
"canonises": "canonizes",
|
184 |
+
"canonising": "canonizing",
|
185 |
+
"capitalise": "capitalize",
|
186 |
+
"capitalised": "capitalized",
|
187 |
+
"capitalises": "capitalizes",
|
188 |
+
"capitalising": "capitalizing",
|
189 |
+
"caramelise": "caramelize",
|
190 |
+
"caramelised": "caramelized",
|
191 |
+
"caramelises": "caramelizes",
|
192 |
+
"caramelising": "caramelizing",
|
193 |
+
"carbonise": "carbonize",
|
194 |
+
"carbonised": "carbonized",
|
195 |
+
"carbonises": "carbonizes",
|
196 |
+
"carbonising": "carbonizing",
|
197 |
+
"carolled": "caroled",
|
198 |
+
"carolling": "caroling",
|
199 |
+
"catalogue": "catalog",
|
200 |
+
"catalogued": "cataloged",
|
201 |
+
"catalogues": "catalogs",
|
202 |
+
"cataloguing": "cataloging",
|
203 |
+
"catalyse": "catalyze",
|
204 |
+
"catalysed": "catalyzed",
|
205 |
+
"catalyses": "catalyzes",
|
206 |
+
"catalysing": "catalyzing",
|
207 |
+
"categorise": "categorize",
|
208 |
+
"categorised": "categorized",
|
209 |
+
"categorises": "categorizes",
|
210 |
+
"categorising": "categorizing",
|
211 |
+
"cauterise": "cauterize",
|
212 |
+
"cauterised": "cauterized",
|
213 |
+
"cauterises": "cauterizes",
|
214 |
+
"cauterising": "cauterizing",
|
215 |
+
"cavilled": "caviled",
|
216 |
+
"cavilling": "caviling",
|
217 |
+
"centigramme": "centigram",
|
218 |
+
"centigrammes": "centigrams",
|
219 |
+
"centilitre": "centiliter",
|
220 |
+
"centilitres": "centiliters",
|
221 |
+
"centimetre": "centimeter",
|
222 |
+
"centimetres": "centimeters",
|
223 |
+
"centralise": "centralize",
|
224 |
+
"centralised": "centralized",
|
225 |
+
"centralises": "centralizes",
|
226 |
+
"centralising": "centralizing",
|
227 |
+
"centre": "center",
|
228 |
+
"centred": "centered",
|
229 |
+
"centrefold": "centerfold",
|
230 |
+
"centrefolds": "centerfolds",
|
231 |
+
"centrepiece": "centerpiece",
|
232 |
+
"centrepieces": "centerpieces",
|
233 |
+
"centres": "centers",
|
234 |
+
"channelled": "channeled",
|
235 |
+
"channelling": "channeling",
|
236 |
+
"characterise": "characterize",
|
237 |
+
"characterised": "characterized",
|
238 |
+
"characterises": "characterizes",
|
239 |
+
"characterising": "characterizing",
|
240 |
+
"cheque": "check",
|
241 |
+
"chequebook": "checkbook",
|
242 |
+
"chequebooks": "checkbooks",
|
243 |
+
"chequered": "checkered",
|
244 |
+
"cheques": "checks",
|
245 |
+
"chilli": "chili",
|
246 |
+
"chimaera": "chimera",
|
247 |
+
"chimaeras": "chimeras",
|
248 |
+
"chiselled": "chiseled",
|
249 |
+
"chiselling": "chiseling",
|
250 |
+
"circularise": "circularize",
|
251 |
+
"circularised": "circularized",
|
252 |
+
"circularises": "circularizes",
|
253 |
+
"circularising": "circularizing",
|
254 |
+
"civilise": "civilize",
|
255 |
+
"civilised": "civilized",
|
256 |
+
"civilises": "civilizes",
|
257 |
+
"civilising": "civilizing",
|
258 |
+
"clamour": "clamor",
|
259 |
+
"clamoured": "clamored",
|
260 |
+
"clamouring": "clamoring",
|
261 |
+
"clamours": "clamors",
|
262 |
+
"clangour": "clangor",
|
263 |
+
"clarinettist": "clarinetist",
|
264 |
+
"clarinettists": "clarinetists",
|
265 |
+
"collectivise": "collectivize",
|
266 |
+
"collectivised": "collectivized",
|
267 |
+
"collectivises": "collectivizes",
|
268 |
+
"collectivising": "collectivizing",
|
269 |
+
"colonisation": "colonization",
|
270 |
+
"colonise": "colonize",
|
271 |
+
"colonised": "colonized",
|
272 |
+
"coloniser": "colonizer",
|
273 |
+
"colonisers": "colonizers",
|
274 |
+
"colonises": "colonizes",
|
275 |
+
"colonising": "colonizing",
|
276 |
+
"colour": "color",
|
277 |
+
"colourant": "colorant",
|
278 |
+
"colourants": "colorants",
|
279 |
+
"coloured": "colored",
|
280 |
+
"coloureds": "coloreds",
|
281 |
+
"colourful": "colorful",
|
282 |
+
"colourfully": "colorfully",
|
283 |
+
"colouring": "coloring",
|
284 |
+
"colourize": "colorize",
|
285 |
+
"colourized": "colorized",
|
286 |
+
"colourizes": "colorizes",
|
287 |
+
"colourizing": "colorizing",
|
288 |
+
"colourless": "colorless",
|
289 |
+
"colours": "colors",
|
290 |
+
"commercialise": "commercialize",
|
291 |
+
"commercialised": "commercialized",
|
292 |
+
"commercialises": "commercializes",
|
293 |
+
"commercialising": "commercializing",
|
294 |
+
"compartmentalise": "compartmentalize",
|
295 |
+
"compartmentalised": "compartmentalized",
|
296 |
+
"compartmentalises": "compartmentalizes",
|
297 |
+
"compartmentalising": "compartmentalizing",
|
298 |
+
"computerise": "computerize",
|
299 |
+
"computerised": "computerized",
|
300 |
+
"computerises": "computerizes",
|
301 |
+
"computerising": "computerizing",
|
302 |
+
"conceptualise": "conceptualize",
|
303 |
+
"conceptualised": "conceptualized",
|
304 |
+
"conceptualises": "conceptualizes",
|
305 |
+
"conceptualising": "conceptualizing",
|
306 |
+
"connexion": "connection",
|
307 |
+
"connexions": "connections",
|
308 |
+
"contextualise": "contextualize",
|
309 |
+
"contextualised": "contextualized",
|
310 |
+
"contextualises": "contextualizes",
|
311 |
+
"contextualising": "contextualizing",
|
312 |
+
"cosier": "cozier",
|
313 |
+
"cosies": "cozies",
|
314 |
+
"cosiest": "coziest",
|
315 |
+
"cosily": "cozily",
|
316 |
+
"cosiness": "coziness",
|
317 |
+
"cosy": "cozy",
|
318 |
+
"councillor": "councilor",
|
319 |
+
"councillors": "councilors",
|
320 |
+
"counselled": "counseled",
|
321 |
+
"counselling": "counseling",
|
322 |
+
"counsellor": "counselor",
|
323 |
+
"counsellors": "counselors",
|
324 |
+
"crenelated": "crenellated",
|
325 |
+
"criminalise": "criminalize",
|
326 |
+
"criminalised": "criminalized",
|
327 |
+
"criminalises": "criminalizes",
|
328 |
+
"criminalising": "criminalizing",
|
329 |
+
"criticise": "criticize",
|
330 |
+
"criticised": "criticized",
|
331 |
+
"criticises": "criticizes",
|
332 |
+
"criticising": "criticizing",
|
333 |
+
"crueller": "crueler",
|
334 |
+
"cruellest": "cruelest",
|
335 |
+
"crystallisation": "crystallization",
|
336 |
+
"crystallise": "crystallize",
|
337 |
+
"crystallised": "crystallized",
|
338 |
+
"crystallises": "crystallizes",
|
339 |
+
"crystallising": "crystallizing",
|
340 |
+
"cudgelled": "cudgeled",
|
341 |
+
"cudgelling": "cudgeling",
|
342 |
+
"customise": "customize",
|
343 |
+
"customised": "customized",
|
344 |
+
"customises": "customizes",
|
345 |
+
"customising": "customizing",
|
346 |
+
"cypher": "cipher",
|
347 |
+
"cyphers": "ciphers",
|
348 |
+
"decentralisation": "decentralization",
|
349 |
+
"decentralise": "decentralize",
|
350 |
+
"decentralised": "decentralized",
|
351 |
+
"decentralises": "decentralizes",
|
352 |
+
"decentralising": "decentralizing",
|
353 |
+
"decriminalisation": "decriminalization",
|
354 |
+
"decriminalise": "decriminalize",
|
355 |
+
"decriminalised": "decriminalized",
|
356 |
+
"decriminalises": "decriminalizes",
|
357 |
+
"decriminalising": "decriminalizing",
|
358 |
+
"defence": "defense",
|
359 |
+
"defenceless": "defenseless",
|
360 |
+
"defences": "defenses",
|
361 |
+
"dehumanisation": "dehumanization",
|
362 |
+
"dehumanise": "dehumanize",
|
363 |
+
"dehumanised": "dehumanized",
|
364 |
+
"dehumanises": "dehumanizes",
|
365 |
+
"dehumanising": "dehumanizing",
|
366 |
+
"demeanour": "demeanor",
|
367 |
+
"demilitarisation": "demilitarization",
|
368 |
+
"demilitarise": "demilitarize",
|
369 |
+
"demilitarised": "demilitarized",
|
370 |
+
"demilitarises": "demilitarizes",
|
371 |
+
"demilitarising": "demilitarizing",
|
372 |
+
"demobilisation": "demobilization",
|
373 |
+
"demobilise": "demobilize",
|
374 |
+
"demobilised": "demobilized",
|
375 |
+
"demobilises": "demobilizes",
|
376 |
+
"demobilising": "demobilizing",
|
377 |
+
"democratisation": "democratization",
|
378 |
+
"democratise": "democratize",
|
379 |
+
"democratised": "democratized",
|
380 |
+
"democratises": "democratizes",
|
381 |
+
"democratising": "democratizing",
|
382 |
+
"demonise": "demonize",
|
383 |
+
"demonised": "demonized",
|
384 |
+
"demonises": "demonizes",
|
385 |
+
"demonising": "demonizing",
|
386 |
+
"demoralisation": "demoralization",
|
387 |
+
"demoralise": "demoralize",
|
388 |
+
"demoralised": "demoralized",
|
389 |
+
"demoralises": "demoralizes",
|
390 |
+
"demoralising": "demoralizing",
|
391 |
+
"denationalisation": "denationalization",
|
392 |
+
"denationalise": "denationalize",
|
393 |
+
"denationalised": "denationalized",
|
394 |
+
"denationalises": "denationalizes",
|
395 |
+
"denationalising": "denationalizing",
|
396 |
+
"deodorise": "deodorize",
|
397 |
+
"deodorised": "deodorized",
|
398 |
+
"deodorises": "deodorizes",
|
399 |
+
"deodorising": "deodorizing",
|
400 |
+
"depersonalise": "depersonalize",
|
401 |
+
"depersonalised": "depersonalized",
|
402 |
+
"depersonalises": "depersonalizes",
|
403 |
+
"depersonalising": "depersonalizing",
|
404 |
+
"deputise": "deputize",
|
405 |
+
"deputised": "deputized",
|
406 |
+
"deputises": "deputizes",
|
407 |
+
"deputising": "deputizing",
|
408 |
+
"desensitisation": "desensitization",
|
409 |
+
"desensitise": "desensitize",
|
410 |
+
"desensitised": "desensitized",
|
411 |
+
"desensitises": "desensitizes",
|
412 |
+
"desensitising": "desensitizing",
|
413 |
+
"destabilisation": "destabilization",
|
414 |
+
"destabilise": "destabilize",
|
415 |
+
"destabilised": "destabilized",
|
416 |
+
"destabilises": "destabilizes",
|
417 |
+
"destabilising": "destabilizing",
|
418 |
+
"dialled": "dialed",
|
419 |
+
"dialling": "dialing",
|
420 |
+
"dialogue": "dialog",
|
421 |
+
"dialogues": "dialogs",
|
422 |
+
"diarrhoea": "diarrhea",
|
423 |
+
"digitise": "digitize",
|
424 |
+
"digitised": "digitized",
|
425 |
+
"digitises": "digitizes",
|
426 |
+
"digitising": "digitizing",
|
427 |
+
"disc": "disk",
|
428 |
+
"discolour": "discolor",
|
429 |
+
"discoloured": "discolored",
|
430 |
+
"discolouring": "discoloring",
|
431 |
+
"discolours": "discolors",
|
432 |
+
"discs": "disks",
|
433 |
+
"disembowelled": "disemboweled",
|
434 |
+
"disembowelling": "disemboweling",
|
435 |
+
"disfavour": "disfavor",
|
436 |
+
"dishevelled": "disheveled",
|
437 |
+
"dishonour": "dishonor",
|
438 |
+
"dishonourable": "dishonorable",
|
439 |
+
"dishonourably": "dishonorably",
|
440 |
+
"dishonoured": "dishonored",
|
441 |
+
"dishonouring": "dishonoring",
|
442 |
+
"dishonours": "dishonors",
|
443 |
+
"disorganisation": "disorganization",
|
444 |
+
"disorganised": "disorganized",
|
445 |
+
"distil": "distill",
|
446 |
+
"distils": "distills",
|
447 |
+
"dramatisation": "dramatization",
|
448 |
+
"dramatisations": "dramatizations",
|
449 |
+
"dramatise": "dramatize",
|
450 |
+
"dramatised": "dramatized",
|
451 |
+
"dramatises": "dramatizes",
|
452 |
+
"dramatising": "dramatizing",
|
453 |
+
"draught": "draft",
|
454 |
+
"draughtboard": "draftboard",
|
455 |
+
"draughtboards": "draftboards",
|
456 |
+
"draughtier": "draftier",
|
457 |
+
"draughtiest": "draftiest",
|
458 |
+
"draughts": "drafts",
|
459 |
+
"draughtsman": "draftsman",
|
460 |
+
"draughtsmanship": "draftsmanship",
|
461 |
+
"draughtsmen": "draftsmen",
|
462 |
+
"draughtswoman": "draftswoman",
|
463 |
+
"draughtswomen": "draftswomen",
|
464 |
+
"draughty": "drafty",
|
465 |
+
"drivelled": "driveled",
|
466 |
+
"drivelling": "driveling",
|
467 |
+
"duelled": "dueled",
|
468 |
+
"duelling": "dueling",
|
469 |
+
"economise": "economize",
|
470 |
+
"economised": "economized",
|
471 |
+
"economises": "economizes",
|
472 |
+
"economising": "economizing",
|
473 |
+
"edoema": "edema",
|
474 |
+
"editorialise": "editorialize",
|
475 |
+
"editorialised": "editorialized",
|
476 |
+
"editorialises": "editorializes",
|
477 |
+
"editorialising": "editorializing",
|
478 |
+
"empathise": "empathize",
|
479 |
+
"empathised": "empathized",
|
480 |
+
"empathises": "empathizes",
|
481 |
+
"empathising": "empathizing",
|
482 |
+
"emphasise": "emphasize",
|
483 |
+
"emphasised": "emphasized",
|
484 |
+
"emphasises": "emphasizes",
|
485 |
+
"emphasising": "emphasizing",
|
486 |
+
"enamelled": "enameled",
|
487 |
+
"enamelling": "enameling",
|
488 |
+
"enamoured": "enamored",
|
489 |
+
"encyclopaedia": "encyclopedia",
|
490 |
+
"encyclopaedias": "encyclopedias",
|
491 |
+
"encyclopaedic": "encyclopedic",
|
492 |
+
"endeavour": "endeavor",
|
493 |
+
"endeavoured": "endeavored",
|
494 |
+
"endeavouring": "endeavoring",
|
495 |
+
"endeavours": "endeavors",
|
496 |
+
"energise": "energize",
|
497 |
+
"energised": "energized",
|
498 |
+
"energises": "energizes",
|
499 |
+
"energising": "energizing",
|
500 |
+
"enrol": "enroll",
|
501 |
+
"enrols": "enrolls",
|
502 |
+
"enthral": "enthrall",
|
503 |
+
"enthrals": "enthralls",
|
504 |
+
"epaulette": "epaulet",
|
505 |
+
"epaulettes": "epaulets",
|
506 |
+
"epicentre": "epicenter",
|
507 |
+
"epicentres": "epicenters",
|
508 |
+
"epilogue": "epilog",
|
509 |
+
"epilogues": "epilogs",
|
510 |
+
"epitomise": "epitomize",
|
511 |
+
"epitomised": "epitomized",
|
512 |
+
"epitomises": "epitomizes",
|
513 |
+
"epitomising": "epitomizing",
|
514 |
+
"equalisation": "equalization",
|
515 |
+
"equalise": "equalize",
|
516 |
+
"equalised": "equalized",
|
517 |
+
"equaliser": "equalizer",
|
518 |
+
"equalisers": "equalizers",
|
519 |
+
"equalises": "equalizes",
|
520 |
+
"equalising": "equalizing",
|
521 |
+
"eulogise": "eulogize",
|
522 |
+
"eulogised": "eulogized",
|
523 |
+
"eulogises": "eulogizes",
|
524 |
+
"eulogising": "eulogizing",
|
525 |
+
"evangelise": "evangelize",
|
526 |
+
"evangelised": "evangelized",
|
527 |
+
"evangelises": "evangelizes",
|
528 |
+
"evangelising": "evangelizing",
|
529 |
+
"exorcise": "exorcize",
|
530 |
+
"exorcised": "exorcized",
|
531 |
+
"exorcises": "exorcizes",
|
532 |
+
"exorcising": "exorcizing",
|
533 |
+
"extemporisation": "extemporization",
|
534 |
+
"extemporise": "extemporize",
|
535 |
+
"extemporised": "extemporized",
|
536 |
+
"extemporises": "extemporizes",
|
537 |
+
"extemporising": "extemporizing",
|
538 |
+
"externalisation": "externalization",
|
539 |
+
"externalisations": "externalizations",
|
540 |
+
"externalise": "externalize",
|
541 |
+
"externalised": "externalized",
|
542 |
+
"externalises": "externalizes",
|
543 |
+
"externalising": "externalizing",
|
544 |
+
"factorise": "factorize",
|
545 |
+
"factorised": "factorized",
|
546 |
+
"factorises": "factorizes",
|
547 |
+
"factorising": "factorizing",
|
548 |
+
"faecal": "fecal",
|
549 |
+
"faeces": "feces",
|
550 |
+
"familiarisation": "familiarization",
|
551 |
+
"familiarise": "familiarize",
|
552 |
+
"familiarised": "familiarized",
|
553 |
+
"familiarises": "familiarizes",
|
554 |
+
"familiarising": "familiarizing",
|
555 |
+
"fantasise": "fantasize",
|
556 |
+
"fantasised": "fantasized",
|
557 |
+
"fantasises": "fantasizes",
|
558 |
+
"fantasising": "fantasizing",
|
559 |
+
"favour": "favor",
|
560 |
+
"favourable": "favorable",
|
561 |
+
"favourably": "favorably",
|
562 |
+
"favoured": "favored",
|
563 |
+
"favouring": "favoring",
|
564 |
+
"favourite": "favorite",
|
565 |
+
"favourites": "favorites",
|
566 |
+
"favouritism": "favoritism",
|
567 |
+
"favours": "favors",
|
568 |
+
"feminise": "feminize",
|
569 |
+
"feminised": "feminized",
|
570 |
+
"feminises": "feminizes",
|
571 |
+
"feminising": "feminizing",
|
572 |
+
"fertilisation": "fertilization",
|
573 |
+
"fertilise": "fertilize",
|
574 |
+
"fertilised": "fertilized",
|
575 |
+
"fertiliser": "fertilizer",
|
576 |
+
"fertilisers": "fertilizers",
|
577 |
+
"fertilises": "fertilizes",
|
578 |
+
"fertilising": "fertilizing",
|
579 |
+
"fervour": "fervor",
|
580 |
+
"fibre": "fiber",
|
581 |
+
"fibreglass": "fiberglass",
|
582 |
+
"fibres": "fibers",
|
583 |
+
"fictionalisation": "fictionalization",
|
584 |
+
"fictionalisations": "fictionalizations",
|
585 |
+
"fictionalise": "fictionalize",
|
586 |
+
"fictionalised": "fictionalized",
|
587 |
+
"fictionalises": "fictionalizes",
|
588 |
+
"fictionalising": "fictionalizing",
|
589 |
+
"fillet": "filet",
|
590 |
+
"filleted": "fileted",
|
591 |
+
"filleting": "fileting",
|
592 |
+
"fillets": "filets",
|
593 |
+
"finalisation": "finalization",
|
594 |
+
"finalise": "finalize",
|
595 |
+
"finalised": "finalized",
|
596 |
+
"finalises": "finalizes",
|
597 |
+
"finalising": "finalizing",
|
598 |
+
"flautist": "flutist",
|
599 |
+
"flautists": "flutists",
|
600 |
+
"flavour": "flavor",
|
601 |
+
"flavoured": "flavored",
|
602 |
+
"flavouring": "flavoring",
|
603 |
+
"flavourings": "flavorings",
|
604 |
+
"flavourless": "flavorless",
|
605 |
+
"flavours": "flavors",
|
606 |
+
"flavoursome": "flavorsome",
|
607 |
+
"flyer / flier": "flier / flyer",
|
608 |
+
"foetal": "fetal",
|
609 |
+
"foetid": "fetid",
|
610 |
+
"foetus": "fetus",
|
611 |
+
"foetuses": "fetuses",
|
612 |
+
"formalisation": "formalization",
|
613 |
+
"formalise": "formalize",
|
614 |
+
"formalised": "formalized",
|
615 |
+
"formalises": "formalizes",
|
616 |
+
"formalising": "formalizing",
|
617 |
+
"fossilisation": "fossilization",
|
618 |
+
"fossilise": "fossilize",
|
619 |
+
"fossilised": "fossilized",
|
620 |
+
"fossilises": "fossilizes",
|
621 |
+
"fossilising": "fossilizing",
|
622 |
+
"fraternisation": "fraternization",
|
623 |
+
"fraternise": "fraternize",
|
624 |
+
"fraternised": "fraternized",
|
625 |
+
"fraternises": "fraternizes",
|
626 |
+
"fraternising": "fraternizing",
|
627 |
+
"fulfil": "fulfill",
|
628 |
+
"fulfilment": "fulfillment",
|
629 |
+
"fulfils": "fulfills",
|
630 |
+
"funnelled": "funneled",
|
631 |
+
"funnelling": "funneling",
|
632 |
+
"galvanise": "galvanize",
|
633 |
+
"galvanised": "galvanized",
|
634 |
+
"galvanises": "galvanizes",
|
635 |
+
"galvanising": "galvanizing",
|
636 |
+
"gambolled": "gamboled",
|
637 |
+
"gambolling": "gamboling",
|
638 |
+
"gaol": "jail",
|
639 |
+
"gaolbird": "jailbird",
|
640 |
+
"gaolbirds": "jailbirds",
|
641 |
+
"gaolbreak": "jailbreak",
|
642 |
+
"gaolbreaks": "jailbreaks",
|
643 |
+
"gaoled": "jailed",
|
644 |
+
"gaoler": "jailer",
|
645 |
+
"gaolers": "jailers",
|
646 |
+
"gaoling": "jailing",
|
647 |
+
"gaols": "jails",
|
648 |
+
"gasses": "gases",
|
649 |
+
"gage": "gauge",
|
650 |
+
"gaged": "gauged",
|
651 |
+
"gages": "gauges",
|
652 |
+
"gaging": "gauging",
|
653 |
+
"generalisation": "generalization",
|
654 |
+
"generalisations": "generalizations",
|
655 |
+
"generalise": "generalize",
|
656 |
+
"generalised": "generalized",
|
657 |
+
"generalises": "generalizes",
|
658 |
+
"generalising": "generalizing",
|
659 |
+
"ghettoise": "ghettoize",
|
660 |
+
"ghettoised": "ghettoized",
|
661 |
+
"ghettoises": "ghettoizes",
|
662 |
+
"ghettoising": "ghettoizing",
|
663 |
+
"gipsies": "gypsies",
|
664 |
+
"glamorise": "glamorize",
|
665 |
+
"glamorised": "glamorized",
|
666 |
+
"glamorises": "glamorizes",
|
667 |
+
"glamorising": "glamorizing",
|
668 |
+
"glamor": "glamour",
|
669 |
+
"globalisation": "globalization",
|
670 |
+
"globalise": "globalize",
|
671 |
+
"globalised": "globalized",
|
672 |
+
"globalises": "globalizes",
|
673 |
+
"globalising": "globalizing",
|
674 |
+
"glueing": "gluing",
|
675 |
+
"goitre": "goiter",
|
676 |
+
"goitres": "goiters",
|
677 |
+
"gonorrhoea": "gonorrhea",
|
678 |
+
"gramme": "gram",
|
679 |
+
"grammes": "grams",
|
680 |
+
"gravelled": "graveled",
|
681 |
+
"grey": "gray",
|
682 |
+
"greyed": "grayed",
|
683 |
+
"greying": "graying",
|
684 |
+
"greyish": "grayish",
|
685 |
+
"greyness": "grayness",
|
686 |
+
"greys": "grays",
|
687 |
+
"grovelled": "groveled",
|
688 |
+
"grovelling": "groveling",
|
689 |
+
"groyne": "groin",
|
690 |
+
"groynes": "groins",
|
691 |
+
"gruelling": "grueling",
|
692 |
+
"gruellingly": "gruelingly",
|
693 |
+
"gryphon": "griffin",
|
694 |
+
"gryphons": "griffins",
|
695 |
+
"gynaecological": "gynecological",
|
696 |
+
"gynaecologist": "gynecologist",
|
697 |
+
"gynaecologists": "gynecologists",
|
698 |
+
"gynaecology": "gynecology",
|
699 |
+
"haematological": "hematological",
|
700 |
+
"haematologist": "hematologist",
|
701 |
+
"haematologists": "hematologists",
|
702 |
+
"haematology": "hematology",
|
703 |
+
"haemoglobin": "hemoglobin",
|
704 |
+
"haemophilia": "hemophilia",
|
705 |
+
"haemophiliac": "hemophiliac",
|
706 |
+
"haemophiliacs": "hemophiliacs",
|
707 |
+
"haemorrhage": "hemorrhage",
|
708 |
+
"haemorrhaged": "hemorrhaged",
|
709 |
+
"haemorrhages": "hemorrhages",
|
710 |
+
"haemorrhaging": "hemorrhaging",
|
711 |
+
"haemorrhoids": "hemorrhoids",
|
712 |
+
"harbour": "harbor",
|
713 |
+
"harboured": "harbored",
|
714 |
+
"harbouring": "harboring",
|
715 |
+
"harbours": "harbors",
|
716 |
+
"harmonisation": "harmonization",
|
717 |
+
"harmonise": "harmonize",
|
718 |
+
"harmonised": "harmonized",
|
719 |
+
"harmonises": "harmonizes",
|
720 |
+
"harmonising": "harmonizing",
|
721 |
+
"homoeopath": "homeopath",
|
722 |
+
"homoeopathic": "homeopathic",
|
723 |
+
"homoeopaths": "homeopaths",
|
724 |
+
"homoeopathy": "homeopathy",
|
725 |
+
"homogenise": "homogenize",
|
726 |
+
"homogenised": "homogenized",
|
727 |
+
"homogenises": "homogenizes",
|
728 |
+
"homogenising": "homogenizing",
|
729 |
+
"honour": "honor",
|
730 |
+
"honourable": "honorable",
|
731 |
+
"honourably": "honorably",
|
732 |
+
"honoured": "honored",
|
733 |
+
"honouring": "honoring",
|
734 |
+
"honours": "honors",
|
735 |
+
"hospitalisation": "hospitalization",
|
736 |
+
"hospitalise": "hospitalize",
|
737 |
+
"hospitalised": "hospitalized",
|
738 |
+
"hospitalises": "hospitalizes",
|
739 |
+
"hospitalising": "hospitalizing",
|
740 |
+
"humanise": "humanize",
|
741 |
+
"humanised": "humanized",
|
742 |
+
"humanises": "humanizes",
|
743 |
+
"humanising": "humanizing",
|
744 |
+
"humour": "humor",
|
745 |
+
"humoured": "humored",
|
746 |
+
"humouring": "humoring",
|
747 |
+
"humourless": "humorless",
|
748 |
+
"humours": "humors",
|
    "hybridise": "hybridize",
    "hybridised": "hybridized",
    "hybridises": "hybridizes",
    "hybridising": "hybridizing",
    "hypnotise": "hypnotize",
    "hypnotised": "hypnotized",
    "hypnotises": "hypnotizes",
    "hypnotising": "hypnotizing",
    "hypothesise": "hypothesize",
    "hypothesised": "hypothesized",
    "hypothesises": "hypothesizes",
    "hypothesising": "hypothesizing",
    "idealisation": "idealization",
    "idealise": "idealize",
    "idealised": "idealized",
    "idealises": "idealizes",
    "idealising": "idealizing",
    "idolise": "idolize",
    "idolised": "idolized",
    "idolises": "idolizes",
    "idolising": "idolizing",
    "immobilisation": "immobilization",
    "immobilise": "immobilize",
    "immobilised": "immobilized",
    "immobiliser": "immobilizer",
    "immobilisers": "immobilizers",
    "immobilises": "immobilizes",
    "immobilising": "immobilizing",
    "immortalise": "immortalize",
    "immortalised": "immortalized",
    "immortalises": "immortalizes",
    "immortalising": "immortalizing",
    "immunisation": "immunization",
    "immunise": "immunize",
    "immunised": "immunized",
    "immunises": "immunizes",
    "immunising": "immunizing",
    "impanelled": "impaneled",
    "impanelling": "impaneling",
    "imperilled": "imperiled",
    "imperilling": "imperiling",
    "individualise": "individualize",
    "individualised": "individualized",
    "individualises": "individualizes",
    "individualising": "individualizing",
    "industrialise": "industrialize",
    "industrialised": "industrialized",
    "industrialises": "industrializes",
    "industrialising": "industrializing",
    "inflexion": "inflection",
    "inflexions": "inflections",
    "initialise": "initialize",
    "initialised": "initialized",
    "initialises": "initializes",
    "initialising": "initializing",
    "initialled": "initialed",
    "initialling": "initialing",
    "instal": "install",
    "instalment": "installment",
    "instalments": "installments",
    "instals": "installs",
    "instil": "instill",
    "instils": "instills",
    "institutionalisation": "institutionalization",
    "institutionalise": "institutionalize",
    "institutionalised": "institutionalized",
    "institutionalises": "institutionalizes",
    "institutionalising": "institutionalizing",
    "intellectualise": "intellectualize",
    "intellectualised": "intellectualized",
    "intellectualises": "intellectualizes",
    "intellectualising": "intellectualizing",
    "internalisation": "internalization",
    "internalise": "internalize",
    "internalised": "internalized",
    "internalises": "internalizes",
    "internalising": "internalizing",
    "internationalisation": "internationalization",
    "internationalise": "internationalize",
    "internationalised": "internationalized",
    "internationalises": "internationalizes",
    "internationalising": "internationalizing",
    "ionisation": "ionization",
    "ionise": "ionize",
    "ionised": "ionized",
    "ioniser": "ionizer",
    "ionisers": "ionizers",
    "ionises": "ionizes",
    "ionising": "ionizing",
    "italicise": "italicize",
    "italicised": "italicized",
    "italicises": "italicizes",
    "italicising": "italicizing",
    "itemise": "itemize",
    "itemised": "itemized",
    "itemises": "itemizes",
    "itemising": "itemizing",
    "jeopardise": "jeopardize",
    "jeopardised": "jeopardized",
    "jeopardises": "jeopardizes",
    "jeopardising": "jeopardizing",
    "jewelled": "jeweled",
    "jeweller": "jeweler",
    "jewellers": "jewelers",
    "jewellery": "jewelry",
    "judgement": "judgment",
    "kilogramme": "kilogram",
    "kilogrammes": "kilograms",
    "kilometre": "kilometer",
    "kilometres": "kilometers",
    "labelled": "labeled",
    "labelling": "labeling",
    "labour": "labor",
    "laboured": "labored",
    "labourer": "laborer",
    "labourers": "laborers",
    "labouring": "laboring",
    "labours": "labors",
    "lacklustre": "lackluster",
    "legalisation": "legalization",
    "legalise": "legalize",
    "legalised": "legalized",
    "legalises": "legalizes",
    "legalising": "legalizing",
    "legitimise": "legitimize",
    "legitimised": "legitimized",
    "legitimises": "legitimizes",
    "legitimising": "legitimizing",
    "leukaemia": "leukemia",
    "levelled": "leveled",
    "leveller": "leveler",
    "levellers": "levelers",
    "levelling": "leveling",
    "libelled": "libeled",
    "libelling": "libeling",
    "libellous": "libelous",
    "liberalisation": "liberalization",
    "liberalise": "liberalize",
    "liberalised": "liberalized",
    "liberalises": "liberalizes",
    "liberalising": "liberalizing",
    "licence": "license",
    "licenced": "licensed",
    "licences": "licenses",
    "licencing": "licensing",
    "likeable": "likable",
    "lionisation": "lionization",
    "lionise": "lionize",
    "lionised": "lionized",
    "lionises": "lionizes",
    "lionising": "lionizing",
    "liquidise": "liquidize",
    "liquidised": "liquidized",
    "liquidiser": "liquidizer",
    "liquidisers": "liquidizers",
    "liquidises": "liquidizes",
    "liquidising": "liquidizing",
    "litre": "liter",
    "litres": "liters",
    "localise": "localize",
    "localised": "localized",
    "localises": "localizes",
    "localising": "localizing",
    "louvre": "louver",
    "louvred": "louvered",
    "louvres": "louvers",
    "lustre": "luster",
    "magnetise": "magnetize",
    "magnetised": "magnetized",
    "magnetises": "magnetizes",
    "magnetising": "magnetizing",
    "manoeuvrability": "maneuverability",
    "manoeuvrable": "maneuverable",
    "manoeuvre": "maneuver",
    "manoeuvred": "maneuvered",
    "manoeuvres": "maneuvers",
    "manoeuvring": "maneuvering",
    "manoeuvrings": "maneuverings",
    "marginalisation": "marginalization",
    "marginalise": "marginalize",
    "marginalised": "marginalized",
    "marginalises": "marginalizes",
    "marginalising": "marginalizing",
    "marshalled": "marshaled",
    "marshalling": "marshaling",
    "marvelled": "marveled",
    "marvelling": "marveling",
    "marvellous": "marvelous",
    "marvellously": "marvelously",
    "materialisation": "materialization",
    "materialise": "materialize",
    "materialised": "materialized",
    "materialises": "materializes",
    "materialising": "materializing",
    "maximisation": "maximization",
    "maximise": "maximize",
    "maximised": "maximized",
    "maximises": "maximizes",
    "maximising": "maximizing",
    "meagre": "meager",
    "mechanisation": "mechanization",
    "mechanise": "mechanize",
    "mechanised": "mechanized",
    "mechanises": "mechanizes",
    "mechanising": "mechanizing",
    "mediaeval": "medieval",
    "memorialise": "memorialize",
    "memorialised": "memorialized",
    "memorialises": "memorializes",
    "memorialising": "memorializing",
    "memorise": "memorize",
    "memorised": "memorized",
    "memorises": "memorizes",
    "memorising": "memorizing",
    "mesmerise": "mesmerize",
    "mesmerised": "mesmerized",
    "mesmerises": "mesmerizes",
    "mesmerising": "mesmerizing",
    "metabolise": "metabolize",
    "metabolised": "metabolized",
    "metabolises": "metabolizes",
    "metabolising": "metabolizing",
    "metre": "meter",
    "metres": "meters",
    "micrometre": "micrometer",
    "micrometres": "micrometers",
    "militarise": "militarize",
    "militarised": "militarized",
    "militarises": "militarizes",
    "militarising": "militarizing",
    "milligramme": "milligram",
    "milligrammes": "milligrams",
    "millilitre": "milliliter",
    "millilitres": "milliliters",
    "millimetre": "millimeter",
    "millimetres": "millimeters",
    "miniaturisation": "miniaturization",
    "miniaturise": "miniaturize",
    "miniaturised": "miniaturized",
    "miniaturises": "miniaturizes",
    "miniaturising": "miniaturizing",
    "minibusses": "minibuses",
    "minimise": "minimize",
    "minimised": "minimized",
    "minimises": "minimizes",
    "minimising": "minimizing",
    "misbehaviour": "misbehavior",
    "misdemeanour": "misdemeanor",
    "misdemeanours": "misdemeanors",
    "misspelt": "misspelled",
    "mitre": "miter",
    "mitres": "miters",
    "mobilisation": "mobilization",
    "mobilise": "mobilize",
    "mobilised": "mobilized",
    "mobilises": "mobilizes",
    "mobilising": "mobilizing",
    "modelled": "modeled",
    "modeller": "modeler",
    "modellers": "modelers",
    "modelling": "modeling",
    "modernise": "modernize",
    "modernised": "modernized",
    "modernises": "modernizes",
    "modernising": "modernizing",
    "moisturise": "moisturize",
    "moisturised": "moisturized",
    "moisturiser": "moisturizer",
    "moisturisers": "moisturizers",
    "moisturises": "moisturizes",
    "moisturising": "moisturizing",
    "monologue": "monolog",
    "monologues": "monologs",
    "monopolisation": "monopolization",
    "monopolise": "monopolize",
    "monopolised": "monopolized",
    "monopolises": "monopolizes",
    "monopolising": "monopolizing",
    "moralise": "moralize",
    "moralised": "moralized",
    "moralises": "moralizes",
    "moralising": "moralizing",
    "motorised": "motorized",
    "mould": "mold",
    "moulded": "molded",
    "moulder": "molder",
    "mouldered": "moldered",
    "mouldering": "moldering",
    "moulders": "molders",
    "mouldier": "moldier",
    "mouldiest": "moldiest",
    "moulding": "molding",
    "mouldings": "moldings",
    "moulds": "molds",
    "mouldy": "moldy",
    "moult": "molt",
    "moulted": "molted",
    "moulting": "molting",
    "moults": "molts",
    "moustache": "mustache",
    "moustached": "mustached",
    "moustaches": "mustaches",
    "moustachioed": "mustachioed",
    "multicoloured": "multicolored",
    "nationalisation": "nationalization",
    "nationalisations": "nationalizations",
    "nationalise": "nationalize",
    "nationalised": "nationalized",
    "nationalises": "nationalizes",
    "nationalising": "nationalizing",
    "naturalisation": "naturalization",
    "naturalise": "naturalize",
    "naturalised": "naturalized",
    "naturalises": "naturalizes",
    "naturalising": "naturalizing",
    "neighbour": "neighbor",
    "neighbourhood": "neighborhood",
    "neighbourhoods": "neighborhoods",
    "neighbouring": "neighboring",
    "neighbourliness": "neighborliness",
    "neighbourly": "neighborly",
    "neighbours": "neighbors",
    "neutralisation": "neutralization",
    "neutralise": "neutralize",
    "neutralised": "neutralized",
    "neutralises": "neutralizes",
    "neutralising": "neutralizing",
    "normalisation": "normalization",
    "normalise": "normalize",
    "normalised": "normalized",
    "normalises": "normalizes",
    "normalising": "normalizing",
    "odour": "odor",
    "odourless": "odorless",
    "odours": "odors",
    "oesophagus": "esophagus",
    "oesophaguses": "esophaguses",
    "oestrogen": "estrogen",
    "offence": "offense",
    "offences": "offenses",
    "omelette": "omelet",
    "omelettes": "omelets",
    "optimise": "optimize",
    "optimised": "optimized",
    "optimises": "optimizes",
    "optimising": "optimizing",
    "organisation": "organization",
    "organisational": "organizational",
    "organisations": "organizations",
    "organise": "organize",
    "organised": "organized",
    "organiser": "organizer",
    "organisers": "organizers",
    "organises": "organizes",
    "organising": "organizing",
    "orthopaedic": "orthopedic",
    "orthopaedics": "orthopedics",
    "ostracise": "ostracize",
    "ostracised": "ostracized",
    "ostracises": "ostracizes",
    "ostracising": "ostracizing",
    "outmanoeuvre": "outmaneuver",
    "outmanoeuvred": "outmaneuvered",
    "outmanoeuvres": "outmaneuvers",
    "outmanoeuvring": "outmaneuvering",
    "overemphasise": "overemphasize",
    "overemphasised": "overemphasized",
    "overemphasises": "overemphasizes",
    "overemphasising": "overemphasizing",
    "oxidisation": "oxidization",
    "oxidise": "oxidize",
    "oxidised": "oxidized",
    "oxidises": "oxidizes",
    "oxidising": "oxidizing",
    "paederast": "pederast",
    "paederasts": "pederasts",
    "paediatric": "pediatric",
    "paediatrician": "pediatrician",
    "paediatricians": "pediatricians",
    "paediatrics": "pediatrics",
    "paedophile": "pedophile",
    "paedophiles": "pedophiles",
    "paedophilia": "pedophilia",
    "palaeolithic": "paleolithic",
    "palaeontologist": "paleontologist",
    "palaeontologists": "paleontologists",
    "palaeontology": "paleontology",
    "panelled": "paneled",
    "panelling": "paneling",
    "panellist": "panelist",
    "panellists": "panelists",
    "paralyse": "paralyze",
    "paralysed": "paralyzed",
    "paralyses": "paralyzes",
    "paralysing": "paralyzing",
    "parcelled": "parceled",
    "parcelling": "parceling",
    "parlour": "parlor",
    "parlours": "parlors",
    "particularise": "particularize",
    "particularised": "particularized",
    "particularises": "particularizes",
    "particularising": "particularizing",
    "passivisation": "passivization",
    "passivise": "passivize",
    "passivised": "passivized",
    "passivises": "passivizes",
    "passivising": "passivizing",
    "pasteurisation": "pasteurization",
    "pasteurise": "pasteurize",
    "pasteurised": "pasteurized",
    "pasteurises": "pasteurizes",
    "pasteurising": "pasteurizing",
    "patronise": "patronize",
    "patronised": "patronized",
    "patronises": "patronizes",
    "patronising": "patronizing",
    "patronisingly": "patronizingly",
    "pedalled": "pedaled",
    "pedalling": "pedaling",
    "pedestrianisation": "pedestrianization",
    "pedestrianise": "pedestrianize",
    "pedestrianised": "pedestrianized",
    "pedestrianises": "pedestrianizes",
    "pedestrianising": "pedestrianizing",
    "penalise": "penalize",
    "penalised": "penalized",
    "penalises": "penalizes",
    "penalising": "penalizing",
    "pencilled": "penciled",
    "pencilling": "penciling",
    "personalise": "personalize",
    "personalised": "personalized",
    "personalises": "personalizes",
    "personalising": "personalizing",
    "pharmacopoeia": "pharmacopeia",
    "pharmacopoeias": "pharmacopeias",
    "philosophise": "philosophize",
    "philosophised": "philosophized",
    "philosophises": "philosophizes",
    "philosophising": "philosophizing",
    "philtre": "filter",
    "philtres": "filters",
    "phoney": "phony",
    "plagiarise": "plagiarize",
    "plagiarised": "plagiarized",
    "plagiarises": "plagiarizes",
    "plagiarising": "plagiarizing",
    "plough": "plow",
    "ploughed": "plowed",
    "ploughing": "plowing",
    "ploughman": "plowman",
    "ploughmen": "plowmen",
    "ploughs": "plows",
    "ploughshare": "plowshare",
    "ploughshares": "plowshares",
    "polarisation": "polarization",
    "polarise": "polarize",
    "polarised": "polarized",
    "polarises": "polarizes",
    "polarising": "polarizing",
    "politicisation": "politicization",
    "politicise": "politicize",
    "politicised": "politicized",
    "politicises": "politicizes",
    "politicising": "politicizing",
    "popularisation": "popularization",
    "popularise": "popularize",
    "popularised": "popularized",
    "popularises": "popularizes",
    "popularising": "popularizing",
    "pouffe": "pouf",
    "pouffes": "poufs",
    "practise": "practice",
    "practised": "practiced",
    "practises": "practices",
    "practising": "practicing",
    "praesidium": "presidium",
    "praesidiums": "presidiums",
    "pressurisation": "pressurization",
    "pressurise": "pressurize",
    "pressurised": "pressurized",
    "pressurises": "pressurizes",
    "pressurising": "pressurizing",
    "pretence": "pretense",
    "pretences": "pretenses",
    "primaeval": "primeval",
    "prioritisation": "prioritization",
    "prioritise": "prioritize",
    "prioritised": "prioritized",
    "prioritises": "prioritizes",
    "prioritising": "prioritizing",
    "privatisation": "privatization",
    "privatisations": "privatizations",
    "privatise": "privatize",
    "privatised": "privatized",
    "privatises": "privatizes",
    "privatising": "privatizing",
    "professionalisation": "professionalization",
    "professionalise": "professionalize",
    "professionalised": "professionalized",
    "professionalises": "professionalizes",
    "professionalising": "professionalizing",
    "programme": "program",
    "programmes": "programs",
    "prologue": "prolog",
    "prologues": "prologs",
    "propagandise": "propagandize",
    "propagandised": "propagandized",
    "propagandises": "propagandizes",
    "propagandising": "propagandizing",
    "proselytise": "proselytize",
    "proselytised": "proselytized",
    "proselytiser": "proselytizer",
    "proselytisers": "proselytizers",
    "proselytises": "proselytizes",
    "proselytising": "proselytizing",
    "psychoanalyse": "psychoanalyze",
    "psychoanalysed": "psychoanalyzed",
    "psychoanalyses": "psychoanalyzes",
    "psychoanalysing": "psychoanalyzing",
    "publicise": "publicize",
    "publicised": "publicized",
    "publicises": "publicizes",
    "publicising": "publicizing",
    "pulverisation": "pulverization",
    "pulverise": "pulverize",
    "pulverised": "pulverized",
    "pulverises": "pulverizes",
    "pulverising": "pulverizing",
    "pummelled": "pummel",
    "pummelling": "pummeled",
    "pyjama": "pajama",
    "pyjamas": "pajamas",
    "pzazz": "pizzazz",
    "quarrelled": "quarreled",
    "quarrelling": "quarreling",
    "radicalise": "radicalize",
    "radicalised": "radicalized",
    "radicalises": "radicalizes",
    "radicalising": "radicalizing",
    "rancour": "rancor",
    "randomise": "randomize",
    "randomised": "randomized",
    "randomises": "randomizes",
    "randomising": "randomizing",
    "rationalisation": "rationalization",
    "rationalisations": "rationalizations",
    "rationalise": "rationalize",
    "rationalised": "rationalized",
    "rationalises": "rationalizes",
    "rationalising": "rationalizing",
    "ravelled": "raveled",
    "ravelling": "raveling",
    "realisable": "realizable",
    "realisation": "realization",
    "realisations": "realizations",
    "realise": "realize",
    "realised": "realized",
    "realises": "realizes",
    "realising": "realizing",
    "recognisable": "recognizable",
    "recognisably": "recognizably",
    "recognisance": "recognizance",
    "recognise": "recognize",
    "recognised": "recognized",
    "recognises": "recognizes",
    "recognising": "recognizing",
    "reconnoitre": "reconnoiter",
    "reconnoitred": "reconnoitered",
    "reconnoitres": "reconnoiters",
    "reconnoitring": "reconnoitering",
    "refuelled": "refueled",
    "refuelling": "refueling",
    "regularisation": "regularization",
    "regularise": "regularize",
    "regularised": "regularized",
    "regularises": "regularizes",
    "regularising": "regularizing",
    "remodelled": "remodeled",
    "remodelling": "remodeling",
    "remould": "remold",
    "remoulded": "remolded",
    "remoulding": "remolding",
    "remoulds": "remolds",
    "reorganisation": "reorganization",
    "reorganisations": "reorganizations",
    "reorganise": "reorganize",
    "reorganised": "reorganized",
    "reorganises": "reorganizes",
    "reorganising": "reorganizing",
    "revelled": "reveled",
    "reveller": "reveler",
    "revellers": "revelers",
    "revelling": "reveling",
    "revitalise": "revitalize",
    "revitalised": "revitalized",
    "revitalises": "revitalizes",
    "revitalising": "revitalizing",
    "revolutionise": "revolutionize",
    "revolutionised": "revolutionized",
    "revolutionises": "revolutionizes",
    "revolutionising": "revolutionizing",
    "rhapsodise": "rhapsodize",
    "rhapsodised": "rhapsodized",
    "rhapsodises": "rhapsodizes",
    "rhapsodising": "rhapsodizing",
    "rigour": "rigor",
    "rigours": "rigors",
    "ritualised": "ritualized",
    "rivalled": "rivaled",
    "rivalling": "rivaling",
    "romanticise": "romanticize",
    "romanticised": "romanticized",
    "romanticises": "romanticizes",
    "romanticising": "romanticizing",
    "rumour": "rumor",
    "rumoured": "rumored",
    "rumours": "rumors",
    "sabre": "saber",
    "sabres": "sabers",
    "saltpetre": "saltpeter",
    "sanitise": "sanitize",
    "sanitised": "sanitized",
    "sanitises": "sanitizes",
    "sanitising": "sanitizing",
    "satirise": "satirize",
    "satirised": "satirized",
    "satirises": "satirizes",
    "satirising": "satirizing",
    "saviour": "savior",
    "saviours": "saviors",
    "savour": "savor",
    "savoured": "savored",
    "savouries": "savories",
    "savouring": "savoring",
    "savours": "savors",
    "savoury": "savory",
    "scandalise": "scandalize",
    "scandalised": "scandalized",
    "scandalises": "scandalizes",
    "scandalising": "scandalizing",
    "sceptic": "skeptic",
    "sceptical": "skeptical",
    "sceptically": "skeptically",
    "scepticism": "skepticism",
    "sceptics": "skeptics",
    "sceptre": "scepter",
    "sceptres": "scepters",
    "scrutinise": "scrutinize",
    "scrutinised": "scrutinized",
    "scrutinises": "scrutinizes",
    "scrutinising": "scrutinizing",
    "secularisation": "secularization",
    "secularise": "secularize",
    "secularised": "secularized",
    "secularises": "secularizes",
    "secularising": "secularizing",
    "sensationalise": "sensationalize",
    "sensationalised": "sensationalized",
    "sensationalises": "sensationalizes",
    "sensationalising": "sensationalizing",
    "sensitise": "sensitize",
    "sensitised": "sensitized",
    "sensitises": "sensitizes",
    "sensitising": "sensitizing",
    "sentimentalise": "sentimentalize",
    "sentimentalised": "sentimentalized",
    "sentimentalises": "sentimentalizes",
    "sentimentalising": "sentimentalizing",
    "sepulchre": "sepulcher",
    "sepulchres": "sepulchers",
    "serialisation": "serialization",
    "serialisations": "serializations",
    "serialise": "serialize",
    "serialised": "serialized",
    "serialises": "serializes",
    "serialising": "serializing",
    "sermonise": "sermonize",
    "sermonised": "sermonized",
    "sermonises": "sermonizes",
    "sermonising": "sermonizing",
    "sheikh": "sheik",
    "shovelled": "shoveled",
    "shovelling": "shoveling",
    "shrivelled": "shriveled",
    "shrivelling": "shriveling",
    "signalise": "signalize",
    "signalised": "signalized",
    "signalises": "signalizes",
    "signalising": "signalizing",
    "signalled": "signaled",
    "signalling": "signaling",
    "smoulder": "smolder",
    "smouldered": "smoldered",
    "smouldering": "smoldering",
    "smoulders": "smolders",
    "snivelled": "sniveled",
    "snivelling": "sniveling",
    "snorkelled": "snorkeled",
    "snorkelling": "snorkeling",
    "snowplough": "snowplow",
    "snowploughs": "snowplow",
    "socialisation": "socialization",
    "socialise": "socialize",
    "socialised": "socialized",
    "socialises": "socializes",
    "socialising": "socializing",
    "sodomise": "sodomize",
    "sodomised": "sodomized",
    "sodomises": "sodomizes",
    "sodomising": "sodomizing",
    "solemnise": "solemnize",
    "solemnised": "solemnized",
    "solemnises": "solemnizes",
    "solemnising": "solemnizing",
    "sombre": "somber",
    "specialisation": "specialization",
    "specialisations": "specializations",
    "specialise": "specialize",
    "specialised": "specialized",
    "specialises": "specializes",
    "specialising": "specializing",
    "spectre": "specter",
    "spectres": "specters",
    "spiralled": "spiraled",
    "spiralling": "spiraling",
    "splendour": "splendor",
    "splendours": "splendors",
    "squirrelled": "squirreled",
    "squirrelling": "squirreling",
    "stabilisation": "stabilization",
    "stabilise": "stabilize",
    "stabilised": "stabilized",
    "stabiliser": "stabilizer",
    "stabilisers": "stabilizers",
    "stabilises": "stabilizes",
    "stabilising": "stabilizing",
    "standardisation": "standardization",
    "standardise": "standardize",
    "standardised": "standardized",
    "standardises": "standardizes",
    "standardising": "standardizing",
    "stencilled": "stenciled",
    "stencilling": "stenciling",
    "sterilisation": "sterilization",
    "sterilisations": "sterilizations",
    "sterilise": "sterilize",
    "sterilised": "sterilized",
    "steriliser": "sterilizer",
    "sterilisers": "sterilizers",
    "sterilises": "sterilizes",
    "sterilising": "sterilizing",
    "stigmatisation": "stigmatization",
    "stigmatise": "stigmatize",
    "stigmatised": "stigmatized",
    "stigmatises": "stigmatizes",
    "stigmatising": "stigmatizing",
    "storey": "story",
    "storeys": "stories",
    "subsidisation": "subsidization",
    "subsidise": "subsidize",
    "subsidised": "subsidized",
    "subsidiser": "subsidizer",
    "subsidisers": "subsidizers",
    "subsidises": "subsidizes",
    "subsidising": "subsidizing",
    "succour": "succor",
    "succoured": "succored",
    "succouring": "succoring",
    "succours": "succors",
    "sulphate": "sulfate",
    "sulphates": "sulfates",
    "sulphide": "sulfide",
    "sulphides": "sulfides",
    "sulphur": "sulfur",
    "sulphurous": "sulfurous",
    "summarise": "summarize",
    "summarised": "summarized",
    "summarises": "summarizes",
    "summarising": "summarizing",
    "swivelled": "swiveled",
    "swivelling": "swiveling",
    "symbolise": "symbolize",
    "symbolised": "symbolized",
    "symbolises": "symbolizes",
    "symbolising": "symbolizing",
    "sympathise": "sympathize",
    "sympathised": "sympathized",
    "sympathiser": "sympathizer",
    "sympathisers": "sympathizers",
    "sympathises": "sympathizes",
    "sympathising": "sympathizing",
    "synchronisation": "synchronization",
    "synchronise": "synchronize",
    "synchronised": "synchronized",
    "synchronises": "synchronizes",
    "synchronising": "synchronizing",
    "synthesise": "synthesize",
    "synthesised": "synthesized",
    "synthesiser": "synthesizer",
    "synthesisers": "synthesizers",
    "synthesises": "synthesizes",
    "synthesising": "synthesizing",
    "syphon": "siphon",
    "syphoned": "siphoned",
    "syphoning": "siphoning",
    "syphons": "siphons",
    "systematisation": "systematization",
    "systematise": "systematize",
    "systematised": "systematized",
    "systematises": "systematizes",
    "systematising": "systematizing",
    "tantalise": "tantalize",
    "tantalised": "tantalized",
    "tantalises": "tantalizes",
    "tantalising": "tantalizing",
    "tantalisingly": "tantalizingly",
    "tasselled": "tasseled",
    "technicolour": "technicolor",
    "temporise": "temporize",
    "temporised": "temporized",
    "temporises": "temporizes",
    "temporising": "temporizing",
    "tenderise": "tenderize",
    "tenderised": "tenderized",
    "tenderises": "tenderizes",
    "tenderising": "tenderizing",
    "terrorise": "terrorize",
    "terrorised": "terrorized",
    "terrorises": "terrorizes",
    "terrorising": "terrorizing",
    "theatre": "theater",
    "theatregoer": "theatergoer",
    "theatregoers": "theatergoers",
    "theatres": "theaters",
    "theorise": "theorize",
    "theorised": "theorized",
    "theorises": "theorizes",
    "theorising": "theorizing",
    "tonne": "ton",
    "tonnes": "tons",
    "towelled": "toweled",
    "towelling": "toweling",
    "toxaemia": "toxemia",
    "tranquillise": "tranquilize",
    "tranquillised": "tranquilized",
    "tranquilliser": "tranquilizer",
    "tranquillisers": "tranquilizers",
    "tranquillises": "tranquilizes",
    "tranquillising": "tranquilizing",
    "tranquillity": "tranquility",
    "tranquillize": "tranquilize",
    "tranquillized": "tranquilized",
    "tranquillizer": "tranquilizer",
    "tranquillizers": "tranquilizers",
    "tranquillizes": "tranquilizes",
    "tranquillizing": "tranquilizing",
    "tranquilly": "tranquility",
    "transistorised": "transistorized",
    "traumatise": "traumatize",
    "traumatised": "traumatized",
    "traumatises": "traumatizes",
    "traumatising": "traumatizing",
    "travelled": "traveled",
    "traveller": "traveler",
    "travellers": "travelers",
    "travelling": "traveling",
    "travelog": "travelogue",
    "travelogs": "travelogues",
    "trialled": "trialed",
    "trialling": "trialing",
    "tricolour": "tricolor",
    "tricolours": "tricolors",
    "trivialise": "trivialize",
    "trivialised": "trivialized",
    "trivialises": "trivializes",
    "trivialising": "trivializing",
    "tumour": "tumor",
    "tumours": "tumors",
    "tunnelled": "tunneled",
    "tunnelling": "tunneling",
    "tyrannise": "tyrannize",
    "tyrannised": "tyrannized",
    "tyrannises": "tyrannizes",
    "tyrannising": "tyrannizing",
    "tyre": "tire",
    "tyres": "tires",
    "unauthorised": "unauthorized",
    "uncivilised": "uncivilized",
    "underutilised": "underutilized",
    "unequalled": "unequaled",
    "unfavourable": "unfavorable",
    "unfavourably": "unfavorably",
    "unionisation": "unionization",
    "unionise": "unionize",
    "unionised": "unionized",
    "unionises": "unionizes",
    "unionising": "unionizing",
    "unorganised": "unorganized",
    "unravelled": "unraveled",
    "unravelling": "unraveling",
    "unrecognisable": "unrecognizable",
    "unrecognised": "unrecognized",
    "unrivalled": "unrivaled",
    "unsavoury": "unsavory",
    "untrammelled": "untrammeled",
    "urbanisation": "urbanization",
    "urbanise": "urbanize",
    "urbanised": "urbanized",
    "urbanises": "urbanizes",
    "urbanising": "urbanizing",
    "utilisable": "utilizable",
    "utilisation": "utilization",
    "utilise": "utilize",
    "utilised": "utilized",
    "utilises": "utilizes",
    "utilising": "utilizing",
    "valour": "valor",
    "vandalise": "vandalize",
    "vandalised": "vandalized",
    "vandalises": "vandalizes",
    "vandalising": "vandalizing",
    "vaporisation": "vaporization",
    "vaporise": "vaporize",
    "vaporised": "vaporized",
    "vaporises": "vaporizes",
    "vaporising": "vaporizing",
    "vapour": "vapor",
    "vapours": "vapors",
    "verbalise": "verbalize",
    "verbalised": "verbalized",
    "verbalises": "verbalizes",
    "verbalising": "verbalizing",
    "victimisation": "victimization",
    "victimise": "victimize",
    "victimised": "victimized",
    "victimises": "victimizes",
    "victimising": "victimizing",
    "videodisc": "videodisk",
    "videodiscs": "videodisks",
    "vigour": "vigor",
    "visualisation": "visualization",
    "visualisations": "visualizations",
    "visualise": "visualize",
    "visualised": "visualized",
    "visualises": "visualizes",
    "visualising": "visualizing",
    "vocalisation": "vocalization",
    "vocalisations": "vocalizations",
    "vocalise": "vocalize",
    "vocalised": "vocalized",
    "vocalises": "vocalizes",
    "vocalising": "vocalizing",
    "vulcanised": "vulcanized",
    "vulgarisation": "vulgarization",
    "vulgarise": "vulgarize",
    "vulgarised": "vulgarized",
    "vulgarises": "vulgarizes",
    "vulgarising": "vulgarizing",
    "waggon": "wagon",
    "waggons": "wagons",
    "watercolour": "watercolor",
    "watercolours": "watercolors",
    "weaselled": "weaseled",
    "weaselling": "weaseling",
    "westernisation": "westernization",
    "westernise": "westernize",
    "westernised": "westernized",
    "westernises": "westernizes",
    "westernising": "westernizing",
    "womanise": "womanize",
    "womanised": "womanized",
    "womaniser": "womanizer",
    "womanisers": "womanizers",
    "womanises": "womanizes",
    "womanising": "womanizing",
    "woollen": "woolen",
    "woollens": "woolens",
    "woollies": "woolies",
    "woolly": "wooly",
    "worshipped": "worshiped",
    "worshipping": "worshiping",
    "worshipper": "worshiper",
    "yodelled": "yodeled",
    "yodelling": "yodeling",
    "yoghourt": "yogurt",
    "yoghourts": "yogurts",
    "yoghurt": "yogurt",
    "yoghurts": "yogurts",
    "mhm": "hmm",
    "mm": "hmm",
    "mmm": "hmm"
}
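For illustration, here is a minimal sketch of how a word-level mapping like the one above can be applied to a lowercased transcript; the file path is an assumption about where english.json sits in a checkout, and the same lookup is what EnglishSpellingNormalizer in english.py below performs.

# Hedged sketch; the path below is an assumption, adjust it to your checkout.
import json

with open("latentsync/whisper/whisper/normalizers/english.json") as f:
    spelling_map = json.load(f)  # e.g. {"hybridise": "hybridize", ...}

def americanize(text: str) -> str:
    # Word-by-word lookup; words not in the mapping pass through unchanged.
    return " ".join(spelling_map.get(word, word) for word in text.lower().split())

print(americanize("they realise the colour of the theatre"))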
latentsync/whisper/whisper/normalizers/english.py
ADDED
@@ -0,0 +1,543 @@
import json
import os
import re
from fractions import Fraction
from typing import Iterator, List, Match, Optional, Union

from more_itertools import windowed

from .basic import remove_symbols_and_diacritics


class EnglishNumberNormalizer:
    """
    Convert any spelled-out numbers into arabic numbers, while handling:

    - remove any commas
    - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc.
    - spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars`
    - spell out `one` and `ones`
    - interpret successive single-digit numbers as nominal: `one oh one` -> `101`
    """

    def __init__(self):
        super().__init__()

        self.zeros = {"o", "oh", "zero"}
        self.ones = {
            name: i
            for i, name in enumerate(
                [
                    "one",
                    "two",
                    "three",
                    "four",
                    "five",
                    "six",
                    "seven",
                    "eight",
                    "nine",
                    "ten",
                    "eleven",
                    "twelve",
                    "thirteen",
                    "fourteen",
                    "fifteen",
                    "sixteen",
                    "seventeen",
                    "eighteen",
                    "nineteen",
                ],
                start=1,
            )
        }
        self.ones_plural = {
            "sixes" if name == "six" else name + "s": (value, "s")
            for name, value in self.ones.items()
        }
        self.ones_ordinal = {
            "zeroth": (0, "th"),
            "first": (1, "st"),
            "second": (2, "nd"),
            "third": (3, "rd"),
            "fifth": (5, "th"),
            "twelfth": (12, "th"),
            **{
                name + ("h" if name.endswith("t") else "th"): (value, "th")
                for name, value in self.ones.items()
                if value > 3 and value != 5 and value != 12
            },
        }
        self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}

        self.tens = {
            "twenty": 20,
            "thirty": 30,
            "forty": 40,
            "fifty": 50,
            "sixty": 60,
            "seventy": 70,
            "eighty": 80,
            "ninety": 90,
        }
        self.tens_plural = {
            name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
        }
        self.tens_ordinal = {
            name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items()
        }
        self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}

        self.multipliers = {
            "hundred": 100,
            "thousand": 1_000,
            "million": 1_000_000,
            "billion": 1_000_000_000,
            "trillion": 1_000_000_000_000,
            "quadrillion": 1_000_000_000_000_000,
            "quintillion": 1_000_000_000_000_000_000,
            "sextillion": 1_000_000_000_000_000_000_000,
            "septillion": 1_000_000_000_000_000_000_000_000,
            "octillion": 1_000_000_000_000_000_000_000_000_000,
            "nonillion": 1_000_000_000_000_000_000_000_000_000_000,
            "decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
        }
        self.multipliers_plural = {
            name + "s": (value, "s") for name, value in self.multipliers.items()
        }
        self.multipliers_ordinal = {
            name + "th": (value, "th") for name, value in self.multipliers.items()
        }
        self.multipliers_suffixed = {**self.multipliers_plural, **self.multipliers_ordinal}
        self.decimals = {*self.ones, *self.tens, *self.zeros}

        self.preceding_prefixers = {
            "minus": "-",
            "negative": "-",
            "plus": "+",
            "positive": "+",
        }
        self.following_prefixers = {
            "pound": "£",
            "pounds": "£",
            "euro": "€",
            "euros": "€",
            "dollar": "$",
            "dollars": "$",
            "cent": "¢",
            "cents": "¢",
        }
        self.prefixes = set(
            list(self.preceding_prefixers.values()) + list(self.following_prefixers.values())
        )
        self.suffixers = {
            "per": {"cent": "%"},
            "percent": "%",
        }
        self.specials = {"and", "double", "triple", "point"}

        self.words = set(
            [
                key
                for mapping in [
                    self.zeros,
                    self.ones,
                    self.ones_suffixed,
                    self.tens,
                    self.tens_suffixed,
                    self.multipliers,
                    self.multipliers_suffixed,
                    self.preceding_prefixers,
                    self.following_prefixers,
                    self.suffixers,
                    self.specials,
                ]
                for key in mapping
            ]
        )
        self.literal_words = {"one", "ones"}

    def process_words(self, words: List[str]) -> Iterator[str]:
        prefix: Optional[str] = None
        value: Optional[Union[str, int]] = None
        skip = False

        def to_fraction(s: str):
            try:
                return Fraction(s)
            except ValueError:
                return None

        def output(result: Union[str, int]):
            nonlocal prefix, value
            result = str(result)
            if prefix is not None:
                result = prefix + result
            value = None
            prefix = None
            return result

        if len(words) == 0:
            return

        for prev, current, next in windowed([None] + words + [None], 3):
            if skip:
                skip = False
                continue

            next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
            has_prefix = current[0] in self.prefixes
            current_without_prefix = current[1:] if has_prefix else current
            if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
                # arabic numbers (potentially with signs and fractions)
                f = to_fraction(current_without_prefix)
                assert f is not None
                if value is not None:
                    if isinstance(value, str) and value.endswith("."):
                        # concatenate decimals / ip address components
                        value = str(value) + str(current)
                        continue
                    else:
                        yield output(value)

                prefix = current[0] if has_prefix else prefix
                if f.denominator == 1:
                    value = f.numerator  # store integers as int
                else:
                    value = current_without_prefix
            elif current not in self.words:
                # non-numeric words
                if value is not None:
                    yield output(value)
                yield output(current)
            elif current in self.zeros:
                value = str(value or "") + "0"
            elif current in self.ones:
                ones = self.ones[current]

                if value is None:
                    value = ones
                elif isinstance(value, str) or prev in self.ones:
                    if prev in self.tens and ones < 10:  # replace the last zero with the digit
                        assert value[-1] == "0"
                        value = value[:-1] + str(ones)
                    else:
                        value = str(value) + str(ones)
                elif ones < 10:
                    if value % 10 == 0:
                        value += ones
                    else:
                        value = str(value) + str(ones)
                else:  # eleven to nineteen
                    if value % 100 == 0:
                        value += ones
                    else:
                        value = str(value) + str(ones)
            elif current in self.ones_suffixed:
                # ordinal or cardinal; yield the number right away
                ones, suffix = self.ones_suffixed[current]
                if value is None:
                    yield output(str(ones) + suffix)
                elif isinstance(value, str) or prev in self.ones:
                    if prev in self.tens and ones < 10:
                        assert value[-1] == "0"
                        yield output(value[:-1] + str(ones) + suffix)
                    else:
                        yield output(str(value) + str(ones) + suffix)
                elif ones < 10:
                    if value % 10 == 0:
                        yield output(str(value + ones) + suffix)
                    else:
                        yield output(str(value) + str(ones) + suffix)
                else:  # eleven to nineteen
                    if value % 100 == 0:
                        yield output(str(value + ones) + suffix)
                    else:
                        yield output(str(value) + str(ones) + suffix)
                value = None
            elif current in self.tens:
                tens = self.tens[current]
                if value is None:
                    value = tens
                elif isinstance(value, str):
                    value = str(value) + str(tens)
                else:
                    if value % 100 == 0:
                        value += tens
                    else:
                        value = str(value) + str(tens)
            elif current in self.tens_suffixed:
                # ordinal or cardinal; yield the number right away
                tens, suffix = self.tens_suffixed[current]
                if value is None:
                    yield output(str(tens) + suffix)
                elif isinstance(value, str):
                    yield output(str(value) + str(tens) + suffix)
                else:
                    if value % 100 == 0:
                        yield output(str(value + tens) + suffix)
                    else:
                        yield output(str(value) + str(tens) + suffix)
            elif current in self.multipliers:
                multiplier = self.multipliers[current]
                if value is None:
                    value = multiplier
                elif isinstance(value, str) or value == 0:
                    f = to_fraction(value)
                    p = f * multiplier if f is not None else None
                    if f is not None and p.denominator == 1:
                        value = p.numerator
                    else:
                        yield output(value)
                        value = multiplier
                else:
                    before = value // 1000 * 1000
                    residual = value % 1000
                    value = before + residual * multiplier
            elif current in self.multipliers_suffixed:
                multiplier, suffix = self.multipliers_suffixed[current]
                if value is None:
                    yield output(str(multiplier) + suffix)
                elif isinstance(value, str):
                    f = to_fraction(value)
                    p = f * multiplier if f is not None else None
                    if f is not None and p.denominator == 1:
                        yield output(str(p.numerator) + suffix)
                    else:
                        yield output(value)
                        yield output(str(multiplier) + suffix)
                else:  # int
                    before = value // 1000 * 1000
                    residual = value % 1000
                    value = before + residual * multiplier
                    yield output(str(value) + suffix)
                value = None
            elif current in self.preceding_prefixers:
                # apply prefix (positive, minus, etc.) if it precedes a number
                if value is not None:
                    yield output(value)

                if next in self.words or next_is_numeric:
                    prefix = self.preceding_prefixers[current]
                else:
                    yield output(current)
            elif current in self.following_prefixers:
                # apply prefix (dollars, cents, etc.) only after a number
                if value is not None:
                    prefix = self.following_prefixers[current]
                    yield output(value)
                else:
                    yield output(current)
            elif current in self.suffixers:
                # apply suffix symbols (percent -> '%')
                if value is not None:
                    suffix = self.suffixers[current]
                    if isinstance(suffix, dict):
                        if next in suffix:
                            yield output(str(value) + suffix[next])
                            skip = True
                        else:
                            yield output(value)
                            yield output(current)
                    else:
                        yield output(str(value) + suffix)
                else:
                    yield output(current)
            elif current in self.specials:
                if next not in self.words and not next_is_numeric:
                    # apply special handling only if the next word can be numeric
                    if value is not None:
                        yield output(value)
                    yield output(current)
                elif current == "and":
                    # ignore "and" after hundreds, thousands, etc.
                    if prev not in self.multipliers:
                        if value is not None:
                            yield output(value)
                        yield output(current)
                elif current == "double" or current == "triple":
                    if next in self.ones or next in self.zeros:
                        repeats = 2 if current == "double" else 3
                        ones = self.ones.get(next, 0)
                        value = str(value or "") + str(ones) * repeats
                        skip = True
                    else:
                        if value is not None:
                            yield output(value)
                        yield output(current)
                elif current == "point":
                    if next in self.decimals or next_is_numeric:
                        value = str(value or "") + "."
                else:
                    # should all have been covered at this point
                    raise ValueError(f"Unexpected token: {current}")
            else:
                # all should have been covered at this point
                raise ValueError(f"Unexpected token: {current}")

        if value is not None:
            yield output(value)

    def preprocess(self, s: str):
        # replace "<number> and a half" with "<number> point five"
        results = []

        segments = re.split(r"\band\s+a\s+half\b", s)
        for i, segment in enumerate(segments):
            if len(segment.strip()) == 0:
                continue
            if i == len(segments) - 1:
                results.append(segment)
            else:
                results.append(segment)
                last_word = segment.rsplit(maxsplit=2)[-1]
                if last_word in self.decimals or last_word in self.multipliers:
                    results.append("point five")
                else:
                    results.append("and a half")

        s = " ".join(results)

        # put a space at number/letter boundary
        s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
        s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)

        # but remove spaces which could be a suffix
        s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)

        return s

    def postprocess(self, s: str):
        def combine_cents(m: Match):
            try:
                currency = m.group(1)
                integer = m.group(2)
                cents = int(m.group(3))
                return f"{currency}{integer}.{cents:02d}"
            except ValueError:
                return m.string

        def extract_cents(m: Match):
            try:
                return f"¢{int(m.group(1))}"
            except ValueError:
                return m.string

        # apply currency postprocessing; "$2 and ¢7" -> "$2.07"
        s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
        s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s)

        # write "one(s)" instead of "1(s)", just for the readability
        s = re.sub(r"\b1(s?)\b", r"one\1", s)

        return s

    def __call__(self, s: str):
        s = self.preprocess(s)
        s = " ".join(word for word in self.process_words(s.split()) if word is not None)
        s = self.postprocess(s)

        return s


class EnglishSpellingNormalizer:
    """
    Applies British-American spelling mappings as listed in [1].

    [1] https://www.tysto.com/uk-us-spelling-list.html
    """

    def __init__(self):
        mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
        self.mapping = json.load(open(mapping_path))

    def __call__(self, s: str):
        return " ".join(self.mapping.get(word, word) for word in s.split())


class EnglishTextNormalizer:
    def __init__(self):
        self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
        self.replacers = {
            # common contractions
            r"\bwon't\b": "will not",
            r"\bcan't\b": "can not",
            r"\blet's\b": "let us",
            r"\bain't\b": "aint",
            r"\by'all\b": "you all",
            r"\bwanna\b": "want to",
            r"\bgotta\b": "got to",
            r"\bgonna\b": "going to",
            r"\bi'ma\b": "i am going to",
            r"\bimma\b": "i am going to",
            r"\bwoulda\b": "would have",
            r"\bcoulda\b": "could have",
            r"\bshoulda\b": "should have",
            r"\bma'am\b": "madam",
            # contractions in titles/prefixes
            r"\bmr\b": "mister ",
            r"\bmrs\b": "missus ",
            r"\bst\b": "saint ",
            r"\bdr\b": "doctor ",
            r"\bprof\b": "professor ",
            r"\bcapt\b": "captain ",
            r"\bgov\b": "governor ",
            r"\bald\b": "alderman ",
            r"\bgen\b": "general ",
            r"\bsen\b": "senator ",
            r"\brep\b": "representative ",
            r"\bpres\b": "president ",
            r"\brev\b": "reverend ",
            r"\bhon\b": "honorable ",
            r"\basst\b": "assistant ",
            r"\bassoc\b": "associate ",
            r"\blt\b": "lieutenant ",
            r"\bcol\b": "colonel ",
            r"\bjr\b": "junior ",
            r"\bsr\b": "senior ",
            r"\besq\b": "esquire ",
            # perfect tenses, ideally it should be any past participles, but it's harder..
            r"'d been\b": " had been",
            r"'s been\b": " has been",
            r"'d gone\b": " had gone",
            r"'s gone\b": " has gone",
            r"'d done\b": " had done",  # "'s done" is ambiguous
            r"'s got\b": " has got",
            # general contractions
            r"n't\b": " not",
            r"'re\b": " are",
            r"'s\b": " is",
            r"'d\b": " would",
            r"'ll\b": " will",
            r"'t\b": " not",
            r"'ve\b": " have",
            r"'m\b": " am",
        }
        self.standardize_numbers = EnglishNumberNormalizer()
        self.standardize_spellings = EnglishSpellingNormalizer()

    def __call__(self, s: str):
        s = s.lower()

        s = re.sub(r"[<\[][^>\]]*[>\]]", "", s)  # remove words between brackets
        s = re.sub(r"\(([^)]+?)\)", "", s)  # remove words between parentheses
        s = re.sub(self.ignore_patterns, "", s)
        s = re.sub(r"\s+'", "'", s)  # standardize when there's a space before an apostrophe

        for pattern, replacement in self.replacers.items():
            s = re.sub(pattern, replacement, s)

        s = re.sub(r"(\d),(\d)", r"\1\2", s)  # remove commas between digits
        s = re.sub(r"\.([^0-9]|$)", r" \1", s)  # remove periods not followed by numbers
        s = remove_symbols_and_diacritics(s, keep=".%$¢€£")  # keep some symbols for numerics

        s = self.standardize_numbers(s)
        s = self.standardize_spellings(s)

        # now remove prefix/suffix symbols that are not preceded/followed by numbers
        s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
        s = re.sub(r"([^0-9])%", r"\1 ", s)

        s = re.sub(r"\s+", " ", s)  # replace any successive whitespace characters with a space

        return s
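As a quick usage sketch of the normalizers defined above (assuming the module is importable under this path and that the optional more_itertools dependency is installed; the sample sentence is made up):

# Hedged usage sketch; import path and sample text are assumptions.
from latentsync.whisper.whisper.normalizers.english import EnglishTextNormalizer

normalizer = EnglishTextNormalizer()

# Lowercases, strips bracketed/filler tokens, expands contractions, converts
# spelled-out numbers to digits, and maps British spellings to American ones
# via english.json.
print(normalizer("Mr. Smith couldn't summarise the twenty five colour charts"))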
latentsync/whisper/whisper/tokenizer.py
ADDED
@@ -0,0 +1,331 @@
import os
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from transformers import GPT2TokenizerFast

LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "iw": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
}

# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
}


@dataclass(frozen=True)
class Tokenizer:
    """A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""

    tokenizer: "GPT2TokenizerFast"
    language: Optional[str]
    sot_sequence: Tuple[int]

    def encode(self, text, **kwargs):
        return self.tokenizer.encode(text, **kwargs)

    def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
        return self.tokenizer.decode(token_ids, **kwargs)

    def decode_with_timestamps(self, tokens) -> str:
        """
        Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
        This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
        """
        outputs = [[]]
        for token in tokens:
            if token >= self.timestamp_begin:
                timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
                outputs.append(timestamp)
                outputs.append([])
            else:
                outputs[-1].append(token)
        outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
        return "".join(outputs)

    @property
    @lru_cache()
    def eot(self) -> int:
        return self.tokenizer.eos_token_id

    @property
    @lru_cache()
    def sot(self) -> int:
        return self._get_single_token_id("<|startoftranscript|>")

    @property
    @lru_cache()
    def sot_lm(self) -> int:
        return self._get_single_token_id("<|startoflm|>")

    @property
    @lru_cache()
    def sot_prev(self) -> int:
        return self._get_single_token_id("<|startofprev|>")

    @property
    @lru_cache()
    def no_speech(self) -> int:
        return self._get_single_token_id("<|nospeech|>")

    @property
    @lru_cache()
    def no_timestamps(self) -> int:
        return self._get_single_token_id("<|notimestamps|>")

    @property
    @lru_cache()
    def timestamp_begin(self) -> int:
        return self.tokenizer.all_special_ids[-1] + 1

    @property
    @lru_cache()
    def language_token(self) -> int:
        """Returns the token id corresponding to the value of the `language` field"""
        if self.language is None:
            raise ValueError(f"This tokenizer does not have language token configured")

        additional_tokens = dict(
            zip(
                self.tokenizer.additional_special_tokens,
                self.tokenizer.additional_special_tokens_ids,
            )
        )
        candidate = f"<|{self.language}|>"
        if candidate in additional_tokens:
            return additional_tokens[candidate]

        raise KeyError(f"Language {self.language} not found in tokenizer.")

    @property
    @lru_cache()
    def all_language_tokens(self) -> Tuple[int]:
        result = []
        for token, token_id in zip(
            self.tokenizer.additional_special_tokens,
            self.tokenizer.additional_special_tokens_ids,
        ):
            if token.strip("<|>") in LANGUAGES:
                result.append(token_id)
        return tuple(result)

    @property
    @lru_cache()
    def all_language_codes(self) -> Tuple[str]:
        return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)

    @property
    @lru_cache()
    def sot_sequence_including_notimestamps(self) -> Tuple[int]:
        return tuple(list(self.sot_sequence) + [self.no_timestamps])

    @property
    @lru_cache()
    def non_speech_tokens(self) -> Tuple[int]:
        """
        Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
        annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.

        - ♪♪♪
        - ( SPEAKING FOREIGN LANGUAGE )
        - [DAVID] Hey there,

        keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
        """
        symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
        symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()

        # symbols that may be a single token or multiple tokens depending on the tokenizer.
        # In case they're multiple tokens, suppress the first token, which is safe because:
        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
        result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
        for symbol in symbols + list(miscellaneous):
            for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
                if len(tokens) == 1 or symbol in miscellaneous:
                    result.add(tokens[0])

        return tuple(sorted(result))

    def _get_single_token_id(self, text) -> int:
        tokens = self.tokenizer.encode(text)
        assert len(tokens) == 1, f"{text} is not encoded as a single token"
        return tokens[0]


@lru_cache(maxsize=None)
def build_tokenizer(name: str = "gpt2"):
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    path = os.path.join(os.path.dirname(__file__), "assets", name)
    tokenizer = GPT2TokenizerFast.from_pretrained(path)

    specials = [
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
    ]

    tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
    return tokenizer


@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
    language: Optional[str] = None,
) -> Tokenizer:
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language in TO_LANGUAGE_CODE:
                language = TO_LANGUAGE_CODE[language]
            else:
                raise ValueError(f"Unsupported language: {language}")

    if multilingual:
        tokenizer_name = "multilingual"
        task = task or "transcribe"
        language = language or "en"
    else:
        tokenizer_name = "gpt2"
        task = None
        language = None

    tokenizer = build_tokenizer(name=tokenizer_name)
    all_special_ids: List[int] = tokenizer.all_special_ids
    sot: int = all_special_ids[1]
    translate: int = all_special_ids[-6]
    transcribe: int = all_special_ids[-5]

    langs = tuple(LANGUAGES.keys())
    sot_sequence = [sot]
    if language is not None:
        sot_sequence.append(sot + 1 + langs.index(language))
    if task is not None:
        sot_sequence.append(transcribe if task == "transcribe" else translate)

    return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))
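A short usage sketch for `get_tokenizer` above (illustrative only; the dotted import path is assumed from this repo's directory layout, and the example text is arbitrary):

    from latentsync.whisper.whisper.tokenizer import get_tokenizer

    tokenizer = get_tokenizer(multilingual=True, task="transcribe", language="en")
    ids = tokenizer.encode("hello world")   # GPT-2 BPE ids from the bundled assets/ vocab
    print(tokenizer.decode(ids))            # "hello world"
    print(tokenizer.sot_sequence)           # ids of <|startoftranscript|>, <|en|>, <|transcribe|>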
latentsync/whisper/whisper/transcribe.py
ADDED
@@ -0,0 +1,207 @@
import argparse
import os
import warnings
from typing import List, Optional, Tuple, Union, TYPE_CHECKING

import numpy as np
import torch
import tqdm

from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
from .decoding import DecodingOptions, DecodingResult
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt

if TYPE_CHECKING:
    from .model import Whisper


def transcribe(
    model: "Whisper",
    audio: Union[str, np.ndarray, torch.Tensor],
    *,
    verbose: Optional[bool] = None,
    temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
    no_speech_threshold: Optional[float] = 0.6,
    condition_on_previous_text: bool = True,
    force_extraction: bool = False,
    **decode_options,
):
    """
    Extract Whisper encoder features from an audio file (a trimmed-down variant of Whisper's
    `transcribe` that skips text decoding)

    Parameters
    ----------
    model: Whisper
        The Whisper model instance

    audio: Union[str, np.ndarray, torch.Tensor]
        The path to the audio file to open, or the audio waveform

    verbose: bool
        Whether to display the text being decoded to the console. If True, displays all the details,
        If False, displays minimal details. If None, does not display anything

    temperature: Union[float, Tuple[float, ...]]
        Temperature for sampling. It can be a tuple of temperatures, which will be successively used
        upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.

    compression_ratio_threshold: float
        If the gzip compression ratio is above this value, treat as failed

    logprob_threshold: float
        If the average log probability over sampled tokens is below this value, treat as failed

    no_speech_threshold: float
        If the no_speech probability is higher than this value AND the average log probability
        over sampled tokens is below `logprob_threshold`, consider the segment as silent

    condition_on_previous_text: bool
        if True, the previous output of the model is provided as a prompt for the next window;
        disabling may make the text inconsistent across windows, but the model becomes less prone to
        getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.

    decode_options: dict
        Keyword arguments to construct `DecodingOptions` instances

    Returns
    -------
    A dictionary with segment-level details ("segments"). In this trimmed-down version, each
    segment records only the window start/end (in mel frames) and the Whisper encoder embeddings
    for that window; no text is decoded.
    """
    dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
    if model.device == torch.device("cpu"):
        if torch.cuda.is_available():
            warnings.warn("Performing inference on CPU when CUDA is available")
        if dtype == torch.float16:
            warnings.warn("FP16 is not supported on CPU; using FP32 instead")
            dtype = torch.float32

    if dtype == torch.float32:
        decode_options["fp16"] = False

    mel = log_mel_spectrogram(audio)

    all_segments = []

    def add_segment(*, start: float, end: float, encoder_embeddings):
        all_segments.append(
            {
                "start": start,
                "end": end,
                "encoder_embeddings": encoder_embeddings,
            }
        )

    # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
    num_frames = mel.shape[-1]
    seek = 0
    previous_seek_value = seek
    sample_skip = 3000  # mel frames per window (30 s of audio at 100 mel frames per second)
    with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
        while seek < num_frames:
            # `seek` is the mel-frame index at which the current window starts
            end_seek = min(seek + sample_skip, num_frames)
            segment = pad_or_trim(mel[:, seek : seek + sample_skip], N_FRAMES).to(model.device).to(dtype)

            single = segment.ndim == 2
            if single:
                segment = segment.unsqueeze(0)
            if dtype == torch.float16:
                segment = segment.half()
            audio_features, embeddings = model.encoder(segment, include_embeddings=True)

            encoder_embeddings = embeddings
            # print(f"encoder_embeddings shape {encoder_embeddings.shape}")
            add_segment(
                start=seek,
                end=end_seek,
                # text_tokens=tokens,
                # result=result,
                encoder_embeddings=encoder_embeddings,
            )
            seek += sample_skip

    return dict(segments=all_segments)


def cli():
    from . import available_models

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
    parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
    parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
    parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
    parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")

    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
    parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")

    parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
    parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
    parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
    parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
    parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")

    parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
    parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
    parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
    parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")

    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
    parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
    parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supersedes MKL_NUM_THREADS/OMP_NUM_THREADS")

    args = parser.parse_args().__dict__
    model_name: str = args.pop("model")
    model_dir: str = args.pop("model_dir")
    output_dir: str = args.pop("output_dir")
    device: str = args.pop("device")
    os.makedirs(output_dir, exist_ok=True)

    if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
        if args["language"] is not None:
            warnings.warn(f"{model_name} is an English-only model but received '{args['language']}'; using English instead.")
        args["language"] = "en"

    temperature = args.pop("temperature")
    temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
    if temperature_increment_on_fallback is not None:
        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
    else:
        temperature = [temperature]

    threads = args.pop("threads")
    if threads > 0:
        torch.set_num_threads(threads)

    from . import load_model
    model = load_model(model_name, device=device, download_root=model_dir)

    for audio_path in args.pop("audio"):
        result = transcribe(model, audio_path, temperature=temperature, **args)

        audio_basename = os.path.basename(audio_path)

        # save TXT
        with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
            write_txt(result["segments"], file=txt)

        # save VTT
        with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt:
            write_vtt(result["segments"], file=vtt)

        # save SRT
        with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
            write_srt(result["segments"], file=srt)


if __name__ == '__main__':
    cli()
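In this repo the upstream decoding loop has been stripped out: `transcribe()` above only slices the log-mel spectrogram into 30-second windows and keeps the Whisper encoder embeddings for each window. A minimal sketch of calling it (the checkpoint name and audio path are placeholders, and the top-level import assumes the package `__init__` re-exports `load_model` and `transcribe` as in upstream Whisper; `cli()` above relies on `load_model` being available that way):

    import torch
    from latentsync.whisper.whisper import load_model, transcribe

    model = load_model("tiny", device="cuda" if torch.cuda.is_available() else "cpu")
    result = transcribe(model, "example.wav", fp16=torch.cuda.is_available())
    for seg in result["segments"]:
        # start/end are mel-frame indices (100 frames per second of audio)
        print(seg["start"], seg["end"], seg["encoder_embeddings"].shape)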
latentsync/whisper/whisper/utils.py
ADDED
@@ -0,0 +1,87 @@
import zlib
from typing import Iterator, TextIO


def exact_div(x, y):
    assert x % y == 0
    return x // y


def str2bool(string):
    str2val = {"True": True, "False": False}
    if string in str2val:
        return str2val[string]
    else:
        raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")


def optional_int(string):
    return None if string == "None" else int(string)


def optional_float(string):
    return None if string == "None" else float(string)


def compression_ratio(text) -> float:
    return len(text) / len(zlib.compress(text.encode("utf-8")))


def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
    assert seconds >= 0, "non-negative timestamp expected"
    milliseconds = round(seconds * 1000.0)

    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000

    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000

    seconds = milliseconds // 1_000
    milliseconds -= seconds * 1_000

    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"


def write_txt(transcript: Iterator[dict], file: TextIO):
    for segment in transcript:
        print(segment['text'].strip(), file=file, flush=True)


def write_vtt(transcript: Iterator[dict], file: TextIO):
    print("WEBVTT\n", file=file)
    for segment in transcript:
        print(
            f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
            f"{segment['text'].strip().replace('-->', '->')}\n",
            file=file,
            flush=True,
        )


def write_srt(transcript: Iterator[dict], file: TextIO):
    """
    Write a transcript to a file in SRT format.

    Example usage:
        from pathlib import Path
        from whisper.utils import write_srt

        result = transcribe(model, audio_path, temperature=temperature, **args)

        # save SRT
        audio_basename = Path(audio_path).stem
        with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
            write_srt(result["segments"], file=srt)
    """
    for i, segment in enumerate(transcript, start=1):
        # write srt lines
        print(
            f"{i}\n"
            f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
            f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
            f"{segment['text'].strip().replace('-->', '->')}\n",
            file=file,
            flush=True,
        )
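A quick sketch of the formatting helpers above (the segment dict is a made-up example; the dotted import path is assumed from this repo's layout):

    from latentsync.whisper.whisper.utils import format_timestamp, write_srt

    print(format_timestamp(3723.5))                      # "01:02:03.500"
    print(format_timestamp(7.25, decimal_marker=","))    # "00:07,250"

    segments = [{"start": 0.0, "end": 2.4, "text": "hello world"}]
    with open("example.srt", "w", encoding="utf-8") as srt:
        write_srt(segments, file=srt)                    # writes one numbered SRT cue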
preprocess/affine_transform.py
ADDED
@@ -0,0 +1,137 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from latentsync.utils.util import read_video, write_video
from latentsync.utils.image_processor import ImageProcessor
import torch
from einops import rearrange
import os
import tqdm
import subprocess
from multiprocessing import Process
import shutil

paths = []


def gather_video_paths(input_dir, output_dir):
    for video in sorted(os.listdir(input_dir)):
        if video.endswith(".mp4"):
            video_input = os.path.join(input_dir, video)
            video_output = os.path.join(output_dir, video)
            if os.path.isfile(video_output):
                continue
            paths.append((video_input, video_output))
        elif os.path.isdir(os.path.join(input_dir, video)):
            gather_video_paths(os.path.join(input_dir, video), os.path.join(output_dir, video))


class FaceDetector:
    def __init__(self, resolution: int = 512, device: str = "cpu"):
        self.image_processor = ImageProcessor(resolution, "fix_mask", device)

    def affine_transform_video(self, video_path):
        video_frames = read_video(video_path, change_fps=False)
        results = []
        for frame in video_frames:
            frame, _, _ = self.image_processor.affine_transform(frame)
            results.append(frame)
        results = torch.stack(results)

        results = rearrange(results, "f c h w -> f h w c").numpy()
        return results

    def close(self):
        self.image_processor.close()


def combine_video_audio(video_frames, video_input_path, video_output_path, process_temp_dir):
    video_name = os.path.basename(video_input_path)[:-4]
    audio_temp = os.path.join(process_temp_dir, f"{video_name}_temp.wav")
    video_temp = os.path.join(process_temp_dir, f"{video_name}_temp.mp4")

    write_video(video_temp, video_frames, fps=25)

    command = f"ffmpeg -y -loglevel error -i {video_input_path} -q:a 0 -map a {audio_temp}"
    subprocess.run(command, shell=True)

    os.makedirs(os.path.dirname(video_output_path), exist_ok=True)
    command = f"ffmpeg -y -loglevel error -i {video_temp} -i {audio_temp} -c:v libx264 -c:a aac -map 0:v -map 1:a -q:v 0 -q:a 0 {video_output_path}"
    subprocess.run(command, shell=True)

    os.remove(audio_temp)
    os.remove(video_temp)


def func(paths, process_temp_dir, device_id, resolution):
    os.makedirs(process_temp_dir, exist_ok=True)
    face_detector = FaceDetector(resolution, f"cuda:{device_id}")

    for video_input, video_output in paths:
        if os.path.isfile(video_output):
            continue
        try:
            video_frames = face_detector.affine_transform_video(video_input)
        except Exception as e:  # handle the case where no face is detected
            print(f"Exception: {e} - {video_input}")
            continue

        os.makedirs(os.path.dirname(video_output), exist_ok=True)
        combine_video_audio(video_frames, video_input, video_output, process_temp_dir)
        print(f"Saved: {video_output}")

    face_detector.close()


def split(a, n):
    k, m = divmod(len(a), n)
    return (a[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n))


def affine_transform_multi_gpus(input_dir, output_dir, temp_dir, resolution, num_workers):
    print(f"Recursively gathering video paths of {input_dir} ...")
    gather_video_paths(input_dir, output_dir)
    num_devices = torch.cuda.device_count()
    if num_devices == 0:
        raise RuntimeError("No GPUs found")

    if os.path.exists(temp_dir):
        shutil.rmtree(temp_dir)
    os.makedirs(temp_dir, exist_ok=True)

    split_paths = list(split(paths, num_workers * num_devices))

    processes = []

    for i in range(num_devices):
        for j in range(num_workers):
            process_index = i * num_workers + j
            process = Process(
                target=func, args=(split_paths[process_index], os.path.join(temp_dir, f"process_{i}"), i, resolution)
            )
            process.start()
            processes.append(process)

    for process in processes:
        process.join()


if __name__ == "__main__":
    input_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/avatars/resampled/train"
    output_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/avatars/affine_transformed/train"
    temp_dir = "temp"
    resolution = 256
    num_workers = 10  # how many processes per device

    affine_transform_multi_gpus(input_dir, output_dir, temp_dir, resolution, num_workers)
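The script shards the gathered video list across num_devices × num_workers worker processes, with each worker pinned to one GPU. A hypothetical invocation against your own data (the directories below are placeholders rather than the hard-coded paths in `__main__`, and importing `preprocess.affine_transform` as a module is an assumption about how the repo is laid out on your PYTHONPATH):

    from preprocess.affine_transform import affine_transform_multi_gpus

    if __name__ == "__main__":
        affine_transform_multi_gpus(
            input_dir="data/raw_videos",           # recursively scanned for .mp4 files
            output_dir="data/affine_transformed",  # mirrors the input directory structure
            temp_dir="temp",                       # wiped and recreated on every run
            resolution=256,
            num_workers=2,                         # worker processes per GPU
        )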