Xianbin
Update instruct model to latest weights
881b143
raw
history blame
170 Bytes
from torch import nn
# Registry mapping a backend name to the fully-connected (linear) layer class
# used for that backend. "torch" (plain ``torch.nn.Linear``) is always
# registered; "te" (``transformer_engine.pytorch.Linear``) is registered only
# when Transformer Engine can be imported in the current environment.
FC_CLASS_REGISTRY = {"torch": nn.Linear}

try:
    import transformer_engine.pytorch as te

    FC_CLASS_REGISTRY["te"] = te.Linear
except Exception:
    # Transformer Engine is an optional dependency, so registration is
    # best-effort. A missing package raises ImportError. An installed but
    # unusable one may presumably raise other error types (TODO confirm), so
    # we catch Exception. A bare ``except:`` would also swallow
    # KeyboardInterrupt/SystemExit, which must always propagate.
    pass