# -*- encoding: utf-8 -*-
'''
@File    :   sr_group.py
@Time    :   2022/04/02 01:17:21
@Author  :   Ming Ding
@Contact :   [email protected]
'''

# here put the import lib
import os
import sys
import math
import random
import numpy as np

import torch
import torch.nn.functional as F
from SwissArmyTransformer.resources import auto_create

from .direct_sr import DirectSuperResolution
from .iterative_sr import IterativeSuperResolution

class SRGroup:
    def __init__(self, args, home_path=None,):
        # Download (or reuse cached) checkpoints for the two super-resolution stages.
        dsr_path = auto_create('cogview2-dsr', path=home_path)
        itersr_path = auto_create('cogview2-itersr', path=home_path)
        # Direct super-resolution enlarges the 20x20 token grid to 60x60;
        # iterative super-resolution refines it, reusing the same transformer weights.
        dsr = DirectSuperResolution(args, dsr_path)
        itersr = IterativeSuperResolution(args, itersr_path, shared_transformer=dsr.model.transformer)
        self.dsr = dsr
        self.itersr = itersr

    def sr_base(self, img_tokens, txt_tokens):
        # Expect a batch of low-resolution images, each encoded as 20x20 = 400 tokens.
        assert img_tokens.shape[-1] == 400 and len(img_tokens.shape) == 2
        batch_size = img_tokens.shape[0]
        txt_len = txt_tokens.shape[-1]
        if len(txt_tokens.shape) == 1:
            # Broadcast a single text prompt across the whole image batch.
            txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len)
        # Stage 1: direct super-resolution produces the enlarged 60x60 token grid.
        sred_tokens = self.dsr(txt_tokens, img_tokens)
        # Stage 2: iterative super-resolution refines the last 3600 (60x60) tokens.
        iter_tokens = self.itersr(txt_tokens, sred_tokens[:, -3600:].clone())
        return iter_tokens[-batch_size:]

    # def sr_patch(self, img_tokens, txt_tokens):
    #     assert img_tokens.shape[-1] == 3600 and len(img_tokens.shape) == 2
    #     batch_size = img_tokens.shape[0] * 9
    #     txt_len = txt_tokens.shape[-1]
    #     if len(txt_tokens.shape) == 1:
    #         txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len)
    #     img_tokens = img_tokens.view(img_tokens.shape[0], 3, 20, 3, 20).permute(0, 1, 3, 2, 4).reshape(batch_size, 400)
    #     iter_tokens = self.sr_base(img_tokens, txt_tokens)
    #     return iter_tokens
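
# A minimal usage sketch (not part of the original file). It assumes `args` comes from
# the CogView2 inference argument parser and that `img_tokens` / `txt_tokens` are the
# base-stage 20x20 image tokens and the tokenized text prompt; the variable names below
# are illustrative only.
#
#   srg = SRGroup(args)
#   hr_tokens = srg.sr_base(img_tokens, txt_tokens)
#   # hr_tokens is expected to have shape (batch_size, 3600), i.e. a 60x60 token grid.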