import torch
from torch import nn

from MyModel.partViT import patchNPositionalEmbeddingMaker, transformerEncoderBlock


class ViT(nn.Module):
    """Vision Transformer: patch + positional embeddings, a stack of transformer
    encoder blocks, and a linear head applied to the first (class) token."""

    def __init__(self, inChannels, outChannels, patchSize, imgSize, hiddenLayer,
                 numHeads, MLPdropOut, numTransformLayers, numClasses,
                 embeddingDropOut=0.1, attnDropOut=0):
        super().__init__()
        # Converts the input image into a sequence of patch embeddings with
        # positional information added.
        self.EmbeddingMaker = patchNPositionalEmbeddingMaker(inChannels, outChannels, patchSize, imgSize)
        self.embeddingDrop = nn.Dropout(embeddingDropOut)
        # Stack of numTransformLayers transformer encoder blocks applied in order.
        self.TransformEncoder = nn.Sequential(
            *[transformerEncoderBlock(outChannels, hiddenLayer, numHeads, MLPdropOut, attnDropOut)
              for _ in range(numTransformLayers)]
        )
        # LayerNorm followed by a linear projection to class logits.
        self.Classifier = nn.Sequential(
            nn.LayerNorm(normalized_shape=outChannels),
            nn.Linear(outChannels, numClasses),
        )

    def forward(self, x):
        x = self.EmbeddingMaker(x)      # (batch, numPatches + 1, outChannels)
        x = self.embeddingDrop(x)
        x = self.TransformEncoder(x)
        x = self.Classifier(x[:, 0])    # classify on the first (class) token only
        return x
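

# --- Usage sketch (not part of the original file) ----------------------------
# A minimal example of instantiating and running the ViT defined above. The
# hyperparameter values (ViT-Base-like: 224x224 RGB input, 16x16 patches,
# 768-dim embeddings, 12 heads, 12 layers, 3072-dim MLP) and the class count
# are illustrative assumptions, not the configuration used by this Space.
if __name__ == "__main__":
    model = ViT(
        inChannels=3,
        outChannels=768,
        patchSize=16,
        imgSize=224,
        hiddenLayer=3072,
        numHeads=12,
        MLPdropOut=0.1,
        numTransformLayers=12,
        numClasses=10,
        embeddingDropOut=0.1,
        attnDropOut=0,
    )
    dummy = torch.randn(2, 3, 224, 224)  # (batch, channels, height, width)
    logits = model(dummy)                # expected shape: (2, numClasses)
    print(logits.shape)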