Spaces:
Runtime error
Runtime error
first commit
Browse files- lib/attention.py +2 -2
- lib/ddpm_multi.py +2 -2
- lib/openaimodel.py +2 -2
- lib/util.py +2 -2
lib/attention.py
CHANGED
@@ -16,7 +16,7 @@ from torch import nn, einsum
|
|
16 |
from einops import rearrange, repeat
|
17 |
from typing import Optional, Any
|
18 |
|
19 |
-
from utils import checkpoint
|
20 |
|
21 |
try:
|
22 |
import xformers
|
@@ -351,4 +351,4 @@ class SpatialTransformer(nn.Module):
|
|
351 |
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
|
352 |
if not self.use_linear:
|
353 |
x = self.proj_out(x)
|
354 |
-
return x + x_in
|
|
|
16 |
from einops import rearrange, repeat
|
17 |
from typing import Optional, Any
|
18 |
|
19 |
+
from ..utils import checkpoint
|
20 |
|
21 |
try:
|
22 |
import xformers
|
|
|
351 |
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
|
352 |
if not self.use_linear:
|
353 |
x = self.proj_out(x)
|
354 |
+
return x + x_in
|
lib/ddpm_multi.py
CHANGED
@@ -30,7 +30,7 @@ from torchvision.utils import make_grid
|
|
30 |
from pytorch_lightning.utilities.distributed import rank_zero_only
|
31 |
from omegaconf import ListConfig
|
32 |
|
33 |
-
from utils import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
|
34 |
from lib.distributions import normal_kl, DiagonalGaussianDistribution
|
35 |
from lib.autoencoder import IdentityFirstStage, AutoencoderKL
|
36 |
from lib.util import make_beta_schedule, extract_into_tensor, noise_like
|
@@ -1798,4 +1798,4 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
|
|
1798 |
def log_images(self, *args, **kwargs):
|
1799 |
log = super().log_images(*args, **kwargs)
|
1800 |
log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
|
1801 |
-
return log
|
|
|
30 |
from pytorch_lightning.utilities.distributed import rank_zero_only
|
31 |
from omegaconf import ListConfig
|
32 |
|
33 |
+
from ..utils import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
|
34 |
from lib.distributions import normal_kl, DiagonalGaussianDistribution
|
35 |
from lib.autoencoder import IdentityFirstStage, AutoencoderKL
|
36 |
from lib.util import make_beta_schedule, extract_into_tensor, noise_like
|
|
|
1798 |
def log_images(self, *args, **kwargs):
|
1799 |
log = super().log_images(*args, **kwargs)
|
1800 |
log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
|
1801 |
+
return log
|
lib/openaimodel.py
CHANGED
@@ -26,7 +26,7 @@ from lib.util import (
|
|
26 |
timestep_embedding,
|
27 |
)
|
28 |
from attention import SpatialTransformer
|
29 |
-
from utils import exists
|
30 |
|
31 |
|
32 |
# dummy replace
|
@@ -793,4 +793,4 @@ class UNetModel(nn.Module):
|
|
793 |
if self.predict_codebook_ids:
|
794 |
return self.id_predictor(h)
|
795 |
else:
|
796 |
-
return self.out(h)
|
|
|
26 |
timestep_embedding,
|
27 |
)
|
28 |
from attention import SpatialTransformer
|
29 |
+
from ..utils import exists
|
30 |
|
31 |
|
32 |
# dummy replace
|
|
|
793 |
if self.predict_codebook_ids:
|
794 |
return self.id_predictor(h)
|
795 |
else:
|
796 |
+
return self.out(h)
|
lib/util.py
CHANGED
@@ -25,7 +25,7 @@ import torch.nn as nn
|
|
25 |
import numpy as np
|
26 |
from einops import repeat
|
27 |
|
28 |
-
from utils import instantiate_from_config
|
29 |
|
30 |
|
31 |
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
@@ -277,4 +277,4 @@ class HybridConditioner(nn.Module):
|
|
277 |
def noise_like(shape, device, repeat=False):
|
278 |
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
279 |
noise = lambda: torch.randn(shape, device=device)
|
280 |
-
return repeat_noise() if repeat else noise()
|
|
|
25 |
import numpy as np
|
26 |
from einops import repeat
|
27 |
|
28 |
+
from ..utils import instantiate_from_config
|
29 |
|
30 |
|
31 |
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
|
|
277 |
def noise_like(shape, device, repeat=False):
|
278 |
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
279 |
noise = lambda: torch.randn(shape, device=device)
|
280 |
+
return repeat_noise() if repeat else noise()
|