# Page header from the hosting site (not part of the module): Spaces, "Running on Zero"; file size: 1,735 bytes; revision 81d8e7c.
from typing import Tuple
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from diffusers.models.modeling_utils import ModelMixin
import torch
class Conv2d(nn.Conv2d):
    """Thin ``nn.Conv2d`` subclass.

    It adds no behavior of its own: ``forward`` delegates directly to
    ``nn.Conv2d.forward``. Constructor arguments are those of
    ``nn.Conv2d``.
    """

    def forward(self, x):
        """Apply the convolution to ``x`` and return the result."""
        return super().forward(x)
class DepthGuider(ModelMixin):
    """Convolutional encoder that maps a conditioning image to an embedding.

    The input is first resized to ``conditioning_size``, then passed through
    ``conv_in``, then through one pair of 3x3 convolutions per stage
    transition (a same-width convolution followed by a stride-2 convolution
    that changes the channel width), and finally projected by ``conv_out``.
    SiLU is applied after every convolution except ``conv_out``.

    With the default arguments a ``(N, 1, H, W)`` input produces a
    ``(N, 4, 64, 64)`` output: 512 is halved by each of the three stride-2
    convolutions.

    Args:
        conditioning_embedding_channels: Number of channels produced by
            ``conv_out``.
        conditioning_channels: Number of channels of the input tensor.
        block_out_channels: Channel width of each stage. ``conv_in`` maps to
            the first width; every adjacent pair of widths adds one
            same-width convolution and one stride-2 convolution.
        conditioning_size: ``(height, width)`` the input is resized to
            before encoding. Defaults to ``(512, 512)``, the value that was
            previously hard-coded in ``forward``.
    """

    def __init__(
        self,
        conditioning_embedding_channels: int = 4,
        conditioning_channels: int = 1,
        block_out_channels: Tuple[int, ...] = (16, 32, 64, 128),
        conditioning_size: Tuple[int, int] = (512, 512),
    ):
        super().__init__()
        # Plain tuple attribute: not a parameter or buffer, so it does not
        # appear in the state dict.
        self.conditioning_size = tuple(conditioning_size)

        self.conv_in = Conv2d(
            conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
        )

        self.blocks = nn.ModuleList([])
        # Walk adjacent stage widths. Layers are created in the same order as
        # before (same-width conv, then downsampling conv) so parameter names
        # and initialization order are unchanged.
        for channel_in, channel_out in zip(
            block_out_channels, block_out_channels[1:]
        ):
            self.blocks.append(
                Conv2d(channel_in, channel_in, kernel_size=3, padding=1)
            )
            self.blocks.append(
                Conv2d(
                    channel_in, channel_out, kernel_size=3, padding=1, stride=2
                )
            )

        self.conv_out = Conv2d(
            block_out_channels[-1],
            conditioning_embedding_channels,
            kernel_size=3,
            padding=1,
        )

    def forward(self, conditioning: torch.Tensor) -> torch.Tensor:
        """Encode ``conditioning`` into an embedding.

        Args:
            conditioning: Input tensor. Bilinear interpolation requires a
                4-D ``(N, C, H, W)`` tensor, with ``C`` equal to the
                ``conditioning_channels`` given at construction.

        Returns:
            The output of ``conv_out``; no activation is applied to it.
        """
        # Bring every input to a fixed resolution so the output spatial size
        # does not depend on the caller's input size.
        conditioning = F.interpolate(
            conditioning,
            size=self.conditioning_size,
            mode='bilinear',
            align_corners=True,
        )

        embedding = F.silu(self.conv_in(conditioning))
        for block in self.blocks:
            embedding = F.silu(block(embedding))

        return self.conv_out(embedding)