File size: 1,071 Bytes
9bb0389
 
5045193
9bb0389
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
from torch import nn
from MyModel.partViT import patchNPositionalEmbeddingMaker,transformerEncoderBlock

class ViT(nn.Module):
    """Vision Transformer classifier assembled from the project's ViT parts.

    Pipeline (see ``forward``): embedding via ``patchNPositionalEmbeddingMaker``
    -> dropout -> ``numTransformLayers`` stacked ``transformerEncoderBlock``s
    -> LayerNorm + Linear head applied to ``encoded[:, 0]`` (presumably a
    class token prepended by the embedding maker -- confirm).

    Args:
        inChannels: forwarded to ``patchNPositionalEmbeddingMaker``
            (presumably the number of input image channels -- confirm).
        outChannels: embedding width; passed to the embedding maker and to
            every encoder block, and used as the LayerNorm size and the
            Linear head's input features.
        patchSize: forwarded to ``patchNPositionalEmbeddingMaker``.
        imgSize: forwarded to ``patchNPositionalEmbeddingMaker``.
        hiddenLayer: forwarded to each ``transformerEncoderBlock``
            (presumably the MLP hidden width -- confirm).
        numHeads: forwarded to each ``transformerEncoderBlock``.
        MLPdropOut: forwarded to each ``transformerEncoderBlock``.
        numTransformLayers: number of encoder blocks chained in sequence.
        numClasses: output features of the final Linear layer.
        embeddingDropOut: dropout probability applied after the embedding.
        attnDropOut: forwarded to each ``transformerEncoderBlock``.
    """

    def __init__(self, inChannels, outChannels, patchSize, imgSize, hiddenLayer,
                 numHeads, MLPdropOut, numTransformLayers, numClasses,
                 embeddingDropOut=0.1, attnDropOut=0):
        super().__init__()
        # NOTE: attribute names and construction order below determine the
        # state_dict keys and the order random initializers run; keep them
        # stable so existing checkpoints and seeded runs are unaffected.
        self.EmbeddingMaker = patchNPositionalEmbeddingMaker(
            inChannels, outChannels, patchSize, imgSize)
        self.embeddingDrop = nn.Dropout(embeddingDropOut)

        encoder_layers = (
            transformerEncoderBlock(outChannels, hiddenLayer, numHeads,
                                    MLPdropOut, attnDropOut)
            for _ in range(numTransformLayers)
        )
        self.TransformEncoder = nn.Sequential(*encoder_layers)

        head_norm = nn.LayerNorm(normalized_shape=outChannels)
        head_proj = nn.Linear(outChannels, numClasses)
        self.Classifier = nn.Sequential(head_norm, head_proj)

    def forward(self, x):
        """Return class scores for the input batch ``x``.

        ``x`` goes straight into ``EmbeddingMaker``, which defines its expected
        layout (presumably (batch, inChannels, imgSize, imgSize) -- confirm).
        The returned tensor has ``numClasses`` features in its last dimension.
        """
        tokens = self.embeddingDrop(self.EmbeddingMaker(x))
        encoded = self.TransformEncoder(tokens)
        # Only index 0 along dim 1 feeds the head -- presumably the class
        # token prepended by EmbeddingMaker; confirm in MyModel.partViT.
        first_token = encoded[:, 0]
        return self.Classifier(first_token)