File size: 786 Bytes
7596274
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch.nn as nn
from einops import rearrange

class FiLMLayer(nn.Module):
    """Feature-wise linear modulation: ``out = alpha(condition) * x + beta(condition)``.

    ``alpha`` and ``beta`` are two linear projections of the conditioning
    vector into per-channel scale and shift values. These are broadcast over
    every axis of ``x`` other than the batch axis and the channel axis.

    Args:
        channels: Size of the axis of ``x`` that is modulated.
        conditional_dim: Size of the last axis of ``condition``.
        apply_dim: Axis of ``x`` that holds the channels. Negative values
            count from the end (e.g. ``-1`` for channels-last inputs).
    """

    def __init__(self, channels, conditional_dim=256, apply_dim=1):
        super().__init__()
        self.alpha = nn.Linear(conditional_dim, channels)
        self.beta = nn.Linear(conditional_dim, channels)
        self.apply_dim = apply_dim

    def forward(self, x, condition):
        """Scale and shift ``x`` channel-wise using ``condition``.

        Args:
            x: Tensor whose first axis is the batch and whose ``apply_dim``
                axis has size ``channels``.
            condition: Tensor of shape ``(b, conditional_dim)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        scale = self.alpha(condition)  # (b, channels)
        shift = self.beta(condition)  # (b, channels)

        # Normalize so negative indices and the default (1) share one code path.
        channel_dim = self.apply_dim % x.dim()

        # Move the channel axis to position 1 so that the (b, channels, 1, ...)
        # parameters line up with it. Transposing with the configured axis
        # (rather than a fixed -1) is what makes non-last channel axes work.
        moved = x if channel_dim == 1 else x.transpose(1, channel_dim)

        # Append singleton axes so the parameters broadcast over the rest of x.
        bcast_shape = tuple(scale.shape) + (1,) * (x.dim() - scale.dim())
        out = scale.reshape(bcast_shape) * moved + shift.reshape(bcast_shape)

        if channel_dim != 1:
            # Restore the caller's original axis order.
            out = out.transpose(1, channel_dim)
        return out