Pull request #2: "Update modeling_zhinao.py" — opened by neofung.

Files changed: modeling_zhinao.py (+12 lines, −13 lines).

Summary: move the `config.use_flash_attn == "auto"` resolution logic in
`ZhinaoForCausalLM.__init__` so that it runs BEFORE `self.model = ZhinaoModel(config)`
is constructed, ensuring the model is built with the resolved flash-attention setting.
--- a/modeling_zhinao.py
+++ b/modeling_zhinao.py
@@ -748,6 +748,17 @@ class ZhinaoForCausalLM(ZhinaoPreTrainedModel):
 
     def __init__(self, config):
         super().__init__(config)
+        if config.use_flash_attn == "auto":
+            if flash_attn_varlen_func:
+                if config.bf16 or config.fp16:
+                    logger.warn("Try importing flash-attention.")
+                    config.use_flash_attn = True
+                else:
+                    config.use_flash_attn = False
+                    logger.warn("Flash attention will be disabled because it does NOT support fp32.")
+            else:
+                config.use_flash_attn = False
+                logger.warn("Please install FlashAttention first, " "e.g., with pip install flash-attn")
         self.model = ZhinaoModel(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
@@ -761,19 +772,7 @@ class ZhinaoForCausalLM(ZhinaoPreTrainedModel):
         if config.fp16:
             self.model.half()
             self.lm_head.half()
-            self.linear.half()
-
-        if config.use_flash_attn == "auto":
-            if flash_attn_varlen_func:
-                if config.bf16 or config.fp16:
-                    logger.warn("Try importing flash-attention.")
-                    config.use_flash_attn = True
-                else:
-                    config.use_flash_attn = False
-                    logger.warn("Flash attention will be disabled because it does NOT support fp32.")
-            else:
-                config.use_flash_attn = False
-                logger.warn("Please install FlashAttention first, " "e.g., with pip install flash-attn")
+            self.linear.half()
 
         self.post_init()
 