import torch.nn as nn from einops import rearrange class FiLMLayer(nn.Module): def __init__(self,channels,conditional_dim=256, apply_dim = 1): super().__init__() self.alpha = nn.Linear(conditional_dim,channels) self.beta = nn.Linear(conditional_dim,channels) self.apply_dim = apply_dim def forward(self,x,condition): alpha = self.alpha(condition) beta = self.beta(condition) input = x if self.apply_dim != 1: input = input.transpose(1,-1) alpha = rearrange(alpha,"b d -> b d"+" 1"*(x.dim()-alpha.dim())) beta = rearrange(beta,"b d -> b d"+" 1"*(x.dim()-beta.dim())) out = alpha*input+beta if self.apply_dim != 1: out = out.transpose(1,-1) return out