import logging


def validate_config(cfg):
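    """Sanity-check quantization/adapter settings in the training config.

    Raises or asserts on deprecated or mutually exclusive options and logs
    warnings for discouraged combinations.
    """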
    if cfg.load_4bit:
        raise ValueError("cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq")

    if cfg.adapter == "qlora":
        if cfg.merge_lora:
            # can't merge qlora if loaded in 8bit or 4bit
            assert cfg.load_in_8bit is not True
            assert cfg.gptq is not True
            assert cfg.load_in_4bit is not True
        else:
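            # qlora itself requires the base model to be loaded in 4bit (not 8bit or gptq)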
            assert cfg.load_in_8bit is not True
            assert cfg.gptq is not True
            assert cfg.load_in_4bit is True

    if not cfg.load_in_8bit and cfg.adapter == "lora":
        logging.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
        
    if cfg.trust_remote_code:
        logging.warning("`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model.")

    # TODO
    # MPT 7b
    # https://github.com/facebookresearch/bitsandbytes/issues/25
    # no 8bit adamw w bf16
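

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of calling validate_config. The attribute-style `cfg`
# built from types.SimpleNamespace below is an assumption for illustration;
# the project may pass its own config/dict type exposing the same fields.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        load_4bit=False,       # deprecated flag; must be falsy or ValueError is raised
        adapter="qlora",
        merge_lora=False,
        load_in_8bit=False,
        load_in_4bit=True,     # the qlora path asserts this is True when not merging
        gptq=False,
        trust_remote_code=False,
    )
    validate_config(cfg)  # raises/asserts on invalid combinations, otherwise returns None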