File size: 364 Bytes
4b8b024
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
from transformers import PreTrainedModel
from .configuration_resnet import ResnetConfig


class ResnetModel(PreTrainedModel):
    """Minimal ``PreTrainedModel`` wrapping a single ``torch.nn.Linear(5, 10)`` layer.

    ``config_class`` pairs this model with ``ResnetConfig`` so the
    ``transformers`` save/load machinery knows which configuration class
    belongs to it.
    """

    config_class = ResnetConfig

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): in/out sizes are fixed at 5 -> 10 and are not read from
        # ``config``; confirm whether they should come from config fields.
        self.model = torch.nn.Linear(5, 10)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped linear layer to ``tensor``.

        Args:
            tensor: Input whose last dimension has size 5 (the layer's
                ``in_features``).

        Returns:
            A tensor with the same leading dimensions as ``tensor`` and a
            last dimension of size 10 (the layer's ``out_features``).
        """
        # Call the module itself: ``torch.nn.Linear`` exposes no
        # ``forward_features`` method, and going through ``__call__`` also
        # runs any hooks registered on the layer.
        return self.model(tensor)