# Copyright (c) 2024 NVIDIA CORPORATION. 
#   Licensed under the MIT license.

# Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
#   LICENSE is in incl_licenses directory.

import sys 
sys.path.append('../')

from typing import Optional

from transformers import AutoModelForCausalLM, AutoTokenizer
from laion_clap import CLAP_Module
from ms_clap.src.CLAPWrapper import CLAPWrapper

import torch
from torch import nn

try:
    from .flamingo import Flamingo
    from .flamingo_lm import FlamingoLMMixin
    from .utils import extend_instance
except ImportError:
    from flamingo import Flamingo
    from flamingo_lm import FlamingoLMMixin
    from utils import extend_instance


class CLAP(nn.Module):
    def __init__(self, clap_config):
        super().__init__()
        self.method = clap_config["method"]

        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

        if self.method == 'laion-clap':
            # https://github.com/LAION-AI/CLAP
            if clap_config["model_name"] in ['630k-audioset-best', '630k-best', '630k-audioset-fusion-best', '630k-fusion-best']:
                amodel = 'HTSAT-tiny'
            elif clap_config["model_name"] in ['music_speech_audioset_epoch_15_esc_89.98']:
                amodel = 'HTSAT-base'
            else:
                raise NotImplementedError
        
            enable_fusion = 'fusion' in clap_config["model_name"].lower()
            self.laion_clap = CLAP_Module(enable_fusion=enable_fusion, amodel=amodel, device=device)
            self.laion_clap.load_ckpt(ckpt=clap_config["checkpoint"])
            
            for param in self.laion_clap.parameters():
                param.requires_grad = False
            self.laion_clap.eval()

            print('loaded laion-clap model: {}'.format(clap_config["checkpoint"]))
    
        elif self.method == 'microsoft-clap':
            # https://github.com/microsoft/CLAP
            self.ms_clap = CLAPWrapper(
                clap_config["checkpoint"], 
                config_root=clap_config["config_root"],
                version=clap_config['model_name'], 
                use_cuda=torch.cuda.is_available()
            )
            
            if clap_config['model_name'] in ['2022', '2023']:
                for param in self.ms_clap.clap.parameters():
                    param.requires_grad = False
                self.ms_clap.clap.eval()
            else:
                for param in self.ms_clap.clapcap.parameters():
                    param.requires_grad = False
                self.ms_clap.clapcap.eval()

            print('loaded microsoft-clap model: {}'.format(clap_config["checkpoint"]))
        
        else:
            raise NotImplementedError

    def forward(self, audio_clips):
        # audio_clips: (batch, num_clips, num_samples); allow an unbatched
        # (num_clips, num_samples) input by adding a batch dimension.
        if len(audio_clips.shape) == 2:
            audio_clips = audio_clips.unsqueeze(0)
        assert len(audio_clips.shape) == 3

        audio_embeds = []
        for x in audio_clips:  # x: (num_clips, num_samples) for one example
            if self.method == 'laion-clap':
                audio_embed = self.laion_clap.get_audio_embedding_from_data(x=x, use_tensor=True)
            elif self.method == 'microsoft-clap':
                audio_embed = self.ms_clap.get_audio_embeddings_from_clips(x)

            audio_embeds.append(audio_embed)

        # Stack to (batch, num_clips, audio_embed_dim); the extractor is
        # frozen, so these embeddings carry no gradient.
        audio_embeds = torch.stack(audio_embeds, dim=0)
        audio_embeds.requires_grad = False

        return audio_embeds
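
# --- Illustrative usage of CLAP (a minimal sketch, not part of the original
# file). The config keys mirror the ones read in CLAP.__init__; the checkpoint
# path and tensor shapes below are hypothetical placeholders.
#
#     clap_config = {
#         "method": "laion-clap",
#         "model_name": "630k-audioset-fusion-best",
#         "checkpoint": "/path/to/630k-audioset-fusion-best.pt",  # hypothetical
#     }
#     clap = CLAP(clap_config)
#     clips = torch.randn(2, 4, 160000)  # (batch, num_clips, num_samples)
#     embeds = clap(clips)               # (batch, num_clips, audio_embed_dim)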


def create_model_and_transforms(
    clap_config: dict,
    lang_encoder_path: str,
    tokenizer_path: str,
    audio_transformer_kwargs: dict,
    cross_attn_every_n_layers: int = 1,
    use_local_files: bool = False,
    decoder_layers_attr_name: Optional[str] = None,
    freeze_lm_embeddings: bool = False,
    unfreeze_full_lm: bool = False,
    cache_dir: Optional[str] = None,
    **flamingo_kwargs,
):
    clap = CLAP(clap_config)

    text_tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )
    text_tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<audio>", "<|endofchunk|>"]}
    )
    if text_tokenizer.pad_token is None:
        text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    if text_tokenizer.sep_token is None:
        text_tokenizer.add_special_tokens({"sep_token": "<SEP>"})

    lang_encoder = AutoModelForCausalLM.from_pretrained(
        lang_encoder_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )

    # Graft the Flamingo cross-attention mixin onto the loaded LM instance.
    extend_instance(lang_encoder, FlamingoLMMixin)

    if decoder_layers_attr_name is None:
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
    lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
    # Grow the embedding matrix to cover the newly added special tokens.
    lang_encoder.resize_token_embeddings(len(text_tokenizer))

    # Keep the CLAP encoder frozen; set to True to fine-tune it as well.
    unfreeze_clap = False
        
    model = Flamingo(
        clap,
        unfreeze_clap,
        lang_encoder,
        text_tokenizer.encode("<|endofchunk|>")[-1],
        text_tokenizer.encode("<audio>")[-1],
        text_tokenizer.sep_token_id,
        audio_embed_dim=clap_config["audio_embed_dim"],
        audio_transformer_kwargs=audio_transformer_kwargs, 
        cross_attn_every_n_layers=cross_attn_every_n_layers,
        **flamingo_kwargs,
    )

    # Freeze everything, then selectively unfreeze the trainable submodules below.
    model.requires_grad_(False)
    assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0

    model.audio_transformer.requires_grad_(True)
    model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
    if not freeze_lm_embeddings:
        model.lang_encoder.get_input_embeddings().requires_grad_(True)
    
    if unfreeze_full_lm:
        model.lang_encoder.requires_grad_(True)
    
    if unfreeze_clap:
        model.clap.requires_grad_(True)

    print("Flamingo model initialized with {:,} trainable parameters (audio transformer has {:,}, LM has {:,})".format(
        sum(p.numel() for p in model.parameters() if p.requires_grad),
        sum(p.numel() for p in model.audio_transformer.parameters() if p.requires_grad),
        sum(p.numel() for p in model.lang_encoder.parameters() if p.requires_grad)
    ))

    return model, text_tokenizer
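
# --- Illustrative call to the factory above (a sketch; values are
# assumptions). The checkpoint paths are hypothetical placeholders and the LM
# is just one plausible choice; audio_transformer_kwargs is model-specific and
# left elided.
#
#     model, tokenizer = create_model_and_transforms(
#         clap_config={
#             "method": "microsoft-clap",
#             "model_name": "2023",
#             "checkpoint": "/path/to/clap_weights_2023.pth",  # hypothetical
#             "config_root": "/path/to/ms_clap/configs",       # hypothetical
#             "audio_embed_dim": 1024,
#         },
#         lang_encoder_path="facebook/opt-1.3b",  # example LM, an assumption
#         tokenizer_path="facebook/opt-1.3b",
#         audio_transformer_kwargs={...},  # elided; see training configs
#         cross_attn_every_n_layers=1,
#     )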


def _infer_decoder_layers_attr_name(model):
    for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
        if k.lower() in model.__class__.__name__.lower():
            return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]

    raise ValueError(
        "Could not infer the attribute name of the nn.ModuleList that stores the "
        "decoder's transformer block layers; please pass decoder_layers_attr_name explicitly."
    )


__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
    "opt": "model.decoder.layers",
    "gptj": "transformer.h",
    "gpt-j": "transformer.h",
    "pythia": "gpt_neox.layers",
    "llama": "model.layers",
    "gptneoxforcausallm": "gpt_neox.layers",
    "mpt": "transformer.blocks",
    "mosaicgpt": "transformer.blocks",
}
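

if __name__ == "__main__":
    # Minimal sanity check (illustrative, not part of the original file):
    # _infer_decoder_layers_attr_name matches on the class name, so any class
    # whose name contains "llama" resolves to "model.layers". The dummy class
    # is a hypothetical stand-in, not a real language model.
    class _DummyLlamaForCausalLM(nn.Module):
        pass

    attr = _infer_decoder_layers_attr_name(_DummyLlamaForCausalLM())
    assert attr == "model.layers"
    print("inferred decoder layers attr:", attr)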