kevinwang676's picture
Upload folder using huggingface_hub
6755a2d verified
raw
history blame contribute delete
622 Bytes
import torch
import torch.nn as nn
from transformers import AutoModel
class Bert(nn.Module):
    """Wrap a pretrained Hugging Face encoder and return its token-level features.

    Args:
        args: Experiment configuration object. It is stored as ``self.args``
            and not otherwise read by this class.
        pretrained_path: Model name or local directory passed to
            ``AutoModel.from_pretrained``. It is keyword-only and defaults to
            the local Chinese BERT-wwm-ext checkpoint that was previously
            hard-coded, so existing ``Bert(args)`` call sites load the same
            weights as before.
    """

    def __init__(self, args, *, pretrained_path='./hfl/chinese-bert-wwm-ext'):
        super().__init__()
        self.args = args
        self.bert = AutoModel.from_pretrained(pretrained_path)

    def forward(self, x, attention_mask=None):
        """Encode token ids and return the encoder's final hidden states.

        Args:
            x: Token-id tensor, passed positionally as the encoder's
                ``input_ids``. Presumably (batch, seq_len) of dtype long;
                TODO confirm against the caller.
            attention_mask: Optional mask with the same shape as ``x``
                (1 = real token, 0 = padding). ``None`` (the default) keeps
                the previous behaviour, in which every position, including
                padding, is attended to.

        Returns:
            The ``last_hidden_state`` field of the encoder output.
        """
        outputs = self.bert(x, attention_mask=attention_mask, return_dict=True)
        return outputs.last_hidden_state