duzx16
committed on
Commit
·
0deb1dd
1
Parent(s):
12c8049
Fix default dtype for classification head
Browse files- modeling_chatglm.py +1 -1
modeling_chatglm.py
CHANGED
@@ -1139,7 +1139,7 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
|
|
1139 |
self.num_labels = config.num_labels
|
1140 |
self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
|
1141 |
|
1142 |
-
self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=
|
1143 |
if config.classifier_dropout is not None:
|
1144 |
self.dropout = nn.Dropout(config.classifier_dropout)
|
1145 |
else:
|
|
|
1139 |
self.num_labels = config.num_labels
|
1140 |
self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
|
1141 |
|
1142 |
+
self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=config.torch_dtype)
|
1143 |
if config.classifier_dropout is not None:
|
1144 |
self.dropout = nn.Dropout(config.classifier_dropout)
|
1145 |
else:
|