Spaces:
Running
Running
File size: 1,891 Bytes
89dc200 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
# -*- encoding: utf-8 -*-
'''
@File : sr_group.py
@Time : 2022/04/02 01:17:21
@Author : Ming Ding
@Contact : [email protected]
'''
# here put the import lib
import os
import sys
import math
import random
import numpy as np
import torch
import torch.nn.functional as F
from SwissArmyTransformer.resources import auto_create
from .direct_sr import DirectSuperResolution
from .iterative_sr import IterativeSuperResolution
class SRGroup:
    """Two-stage super-resolution pipeline: direct SR followed by iterative SR.

    Downloads/locates the CogView2 ``dsr`` and ``itersr`` checkpoints via
    ``auto_create`` and chains the two models; the iterative model reuses the
    direct model's transformer (``shared_transformer``) so the weights are
    held in memory only once.
    """

    def __init__(self, args, home_path=None):
        """Build both SR models.

        Args:
            args: parsed CLI/config namespace forwarded to both model wrappers.
            home_path: optional cache directory for ``auto_create``; ``None``
                falls back to the library default.
        """
        dsr_path = auto_create('cogview2-dsr', path=home_path)
        itersr_path = auto_create('cogview2-itersr', path=home_path)
        dsr = DirectSuperResolution(args, dsr_path)
        # Share the transformer weights between both stages to save memory.
        itersr = IterativeSuperResolution(args, itersr_path,
                                          shared_transformer=dsr.model.transformer)
        self.dsr = dsr
        self.itersr = itersr

    def sr_base(self, img_tokens, txt_tokens):
        """Upsample a batch of 20x20 image-token grids to 60x60.

        Args:
            img_tokens: 2-D tensor of shape (batch, 400) — flattened 20x20
                token grids.
            txt_tokens: 1-D tensor of shape (txt_len,) or 2-D of shape
                (batch, txt_len); a 1-D prompt is broadcast across the batch.

        Returns:
            Tensor of shape (batch, 3600) — the refined 60x60 token grids.

        Raises:
            ValueError: if ``img_tokens`` is not 2-D with 400 tokens per row.
        """
        # Explicit validation instead of `assert`, which is stripped under
        # `python -O` and would let malformed input through silently.
        if len(img_tokens.shape) != 2 or img_tokens.shape[-1] != 400:
            raise ValueError(
                f'img_tokens must have shape (batch, 400), got {tuple(img_tokens.shape)}'
            )
        batch_size = img_tokens.shape[0]
        txt_len = txt_tokens.shape[-1]
        if len(txt_tokens.shape) == 1:
            # Share a single text prompt across the whole batch.
            txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len)
        sred_tokens = self.dsr(txt_tokens, img_tokens)
        # Keep only the trailing 3600 tokens (the 60x60 grid) of the direct-SR
        # output — presumably it carries a prefix before the grid; clone so the
        # iterative stage can mutate its input safely.
        iter_tokens = self.itersr(txt_tokens, sred_tokens[:, -3600:].clone())
        return iter_tokens[-batch_size:]

    # def sr_patch(self, img_tokens, txt_tokens):
    #     assert img_tokens.shape[-1] == 3600 and len(img_tokens.shape) == 2
    #     batch_size = img_tokens.shape[0] * 9
    #     txt_len = txt_tokens.shape[-1]
    #     if len(txt_tokens.shape) == 1:
    #         txt_tokens = txt_tokens.unsqueeze(0).expand(batch_size, txt_len)
    #     img_tokens = img_tokens.view(img_tokens.shape[0], 3, 20, 3, 20).permute(0, 1, 3, 2, 4).reshape(batch_size, 400)
    #     iter_tokens = self.sr_base(img_tokens, txt_tokens)
    #     return iter_tokens