import torch
import torch.nn as nn
from transformers import MusicgenForConditionalGeneration, AutoModel, PretrainedConfig, PreTrainedModel


class Im2Mu(nn.Module):
    """Image-to-music model: a MuVis ViT encodes the image, a linear layer maps the
    197 ViT tokens to 256 conditioning positions, and MusicGen decodes audio codes."""

    def __init__(self, embed_dims=768, seq_len=64):
        super().__init__()
        self.musicgen = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
        self.muvis = AutoModel.from_pretrained("juliagsy/muvis", trust_remote_code=True).model.vit
        self.loss_ce = nn.CrossEntropyLoss(label_smoothing=0.1, ignore_index=-100)
        # Projects the ViT sequence length (196 patch tokens + 1 CLS token) to 256 encoder positions.
        self.img_lin = nn.Linear(197, 256)

    def shift_right(self, input_ids):
        # Standard teacher-forcing shift: prepend a zero start token and drop the last token.
        shifted_input_ids = torch.zeros_like(input_ids)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = 0
        return shifted_input_ids

    def forward(self, img, wav):
        # Encode the image with the MuVis ViT, run it through MusicGen's text encoder,
        # then project the token dimension from 197 to 256.
        img_e = self.muvis(**img)["last_hidden_state"]
        img_embeds = self.musicgen.get_encoder()(inputs_embeds=img_e)["last_hidden_state"]
        img_embeds = img_embeds.permute(0, 2, 1)
        img_embeds = self.img_lin(img_embeds)
        img_embeds = img_embeds.permute(0, 2, 1)

        # Tokenize the target audio with MusicGen's EnCodec audio encoder.
        wav_tokens = self.musicgen.get_audio_encoder().encode(**wav)["audio_codes"]
        wav_size = wav_tokens.size()
        # audio_codes are (frames, batch, num_codebooks, seq_len); fold batch and codebooks together.
        wav_tokens = wav_tokens.view((wav_size[1] * wav_size[2], wav_size[-1]))

        # Next-token objective: feed the right-shifted codes to the decoder and
        # score the logits against the unshifted codes.
        decoder_input_ids = self.shift_right(wav_tokens)
        ret = self.musicgen(
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=(img_embeds,),
        )
        loss = self.loss_ce(
            ret.logits.view(-1, self.musicgen.config.audio_encoder.codebook_size),
            wav_tokens.view(-1),
        )
        return loss

    def generate(self, img, wav=None, guidance_scale=3, max_new_tokens=256, device="cpu"):
        # Encode and project the image (here the 197 -> 256 projection runs before MusicGen's text encoder).
        img_embeds = self.muvis(**img)["last_hidden_state"]
        img_embeds = img_embeds.permute(0, 2, 1)
        img_embeds = self.img_lin(img_embeds)
        img_embeds = img_embeds.permute(0, 2, 1)
        img_embeds = self.musicgen.get_encoder()(inputs_embeds=img_embeds)["last_hidden_state"]

        if wav is not None:
            # Continue from a reference clip: tokenize it and use it as the decoder prompt.
            input_ids = self.musicgen.get_audio_encoder().encode(**wav)["audio_codes"]
            wav_size = input_ids.size()
            input_ids = input_ids.view((wav_size[1] * wav_size[2], wav_size[-1]))
            input_ids = self.shift_right(input_ids)
            ret = self.musicgen.generate(
                decoder_input_ids=input_ids,
                encoder_outputs=(img_embeds,),
                do_sample=True,
                guidance_scale=guidance_scale,
                max_new_tokens=max_new_tokens,
            )
        else:
            # Unconditional decoder start: one zero token per codebook (4 codebooks in musicgen-small).
            input_ids = torch.zeros((4, 1)).long().to(device)
            decoder_attention_mask = torch.ones((img_embeds.size(0), 1)).long().to(device)
            ret = self.musicgen.generate(
                decoder_input_ids=input_ids,
                decoder_attention_mask=decoder_attention_mask,
                encoder_outputs=(img_embeds,),
                do_sample=True,
                guidance_scale=guidance_scale,
                max_new_tokens=max_new_tokens,
            )
        return ret


class ImagicConfig(PretrainedConfig):
    model_type = "imagic"

    def __init__(
        self,
        embed_dims=768,
        seq_len=64,
        **kwargs,
    ):
        self.embed_dims = embed_dims
        self.seq_len = seq_len
        super().__init__(**kwargs)


class ImagicModel(PreTrainedModel):
    config_class = ImagicConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = Im2Mu(
            embed_dims=config.embed_dims,
            seq_len=config.seq_len,
        )

    def forward(self, img, wav):
        return self.model.forward(img, wav)

    def generate(self, img, wav=None, guidance_scale=3, device="cpu"):
        return self.model.generate(img, wav=wav, guidance_scale=guidance_scale, device=device)
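

if __name__ == "__main__":
    # Usage sketch (illustrative only, not part of the model code): the image
    # preprocessor below is an assumption -- any processor that produces the
    # 224x224 pixel_values the MuVis ViT backbone expects should work -- and
    # "scene.jpg" is a placeholder path.
    from PIL import Image
    from transformers import AutoImageProcessor
    import scipy.io.wavfile

    model = ImagicModel(ImagicConfig()).eval()

    # Assumed preprocessor; swap in the processor the MuVis checkpoint was trained with if it differs.
    image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
    img = image_processor(images=Image.open("scene.jpg"), return_tensors="pt")

    with torch.no_grad():
        # wav=None: start the decoder from zero tokens and condition only on the image embeddings.
        audio = model.generate(img, guidance_scale=3)

    # MusicgenForConditionalGeneration.generate decodes the generated codes back to a
    # waveform of shape (batch, channels, samples); write the first sample to disk.
    sampling_rate = model.model.musicgen.config.audio_encoder.sampling_rate
    scipy.io.wavfile.write("imagic_out.wav", sampling_rate, audio[0, 0].cpu().numpy())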