File size: 7,484 Bytes
6045345
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8746b70
 
 
 
6045345
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a472e1
 
 
 
6045345
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import logging
import os
from pathlib import Path
from typing import Optional, Tuple, TYPE_CHECKING

import torch
import transformers
from transformers import (
    AutoModelForCausalLM,
    LlamaForCausalLM,
    LlamaTokenizer,
    AutoTokenizer,
    PreTrainedModel,
)

from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN

if TYPE_CHECKING:
    from peft import PeftModel, PeftConfig
    from attrdict import AttrDefault
    from transformers import PreTrainedTokenizer


def load_model(
    base_model,
    base_model_config,
    model_type,
    tokenizer_type,
    cfg,
    adapter="lora",
    inference=False,
):
    # type: (str, str, str, str, AttrDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
    """Load a causal language model, its tokenizer and an optional PEFT adapter.

    Args:
        base_model: Model name or path handed to ``from_pretrained`` (and to
            ``snapshot_download`` on the 4-bit llama path).
        base_model_config: Config location given to the 4-bit llama loader;
            ``base_model`` is used instead when this is falsy.
        model_type: Name of the ``transformers`` class used to load models
            that are not llama-derived.
        tokenizer_type: Name of the ``transformers`` tokenizer class used for
            models that are not llama-derived.
        cfg: Run configuration. This function reads ``load_in_8bit``,
            ``load_4bit``, ``fp16``, ``model_type``, ``flash_attention``,
            ``xformers_attention``, ``device``, ``device_map``,
            ``base_model_ignore_patterns``, ``gptq_groupsize``,
            ``gptq_model_v1``, ``ddp`` and ``local_rank``; the object is also
            passed on to ``load_adapter``.
        adapter: Adapter kind forwarded to ``load_adapter``.
        inference: When true, the flash-attention patch is not applied.

    Returns:
        A ``(model, tokenizer, lora_config)`` tuple, where ``model`` and
        ``lora_config`` are the pair returned by ``load_adapter``.

    Raises:
        Exception: Whatever the peft / 4-bit monkeypatch imports raise is
            logged and re-raised. Errors from the final
            ``AutoModelForCausalLM`` / ``AutoTokenizer`` fallbacks propagate
            unchanged.
    """

    # TODO refactor as a kwarg
    load_in_8bit = cfg.load_in_8bit
    tokenizer = None
    # Guard against an unset cfg.model_type so `.lower()` is never called on None.
    is_llama_derived_model = "llama" in base_model or bool(
        cfg.model_type and "llama" in cfg.model_type.lower()
    )

    if is_llama_derived_model and cfg.flash_attention:
        if cfg.device not in ["mps", "cpu"] and inference is False:
            from axolotl.flash_attn import replace_llama_attn_with_flash_attn

            logging.info("patching with flash attention")
            replace_llama_attn_with_flash_attn()
    elif is_llama_derived_model and cfg.xformers_attention:
        from alpaca_lora_4bit.monkeypatch.llama_attn_hijack_xformers import (
            hijack_llama_attention,
        )

        logging.info("patching with xformers attention")
        hijack_llama_attention()

    # Must be a plain torch.dtype: a trailing comma here previously turned the
    # value into a 1-tuple, which was then passed to every from_pretrained call.
    torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32

    try:
        if cfg.load_4bit:
            from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
                replace_peft_model_with_int4_lora_model,
            )

            # Patch peft before it is imported below.
            replace_peft_model_with_int4_lora_model()
        from peft import prepare_model_for_int8_training
    except Exception as e:
        logging.exception(e)
        raise

    try:
        if cfg.load_4bit and is_llama_derived_model:
            from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
            from huggingface_hub import snapshot_download

            snapshot_download_kwargs = {}
            if cfg.base_model_ignore_patterns:
                snapshot_download_kwargs["ignore_patterns"] = cfg.base_model_ignore_patterns
            cache_model_path = Path(
                snapshot_download(base_model, **snapshot_download_kwargs)
            )
            # First weight file found wins; *.pt is preferred over
            # *.safetensors, which is preferred over *.bin.
            files = (
                list(cache_model_path.glob("*.pt"))
                + list(cache_model_path.glob("*.safetensors"))
                + list(cache_model_path.glob("*.bin"))
            )
            if files:
                model_path = str(files[0])
            else:
                logging.warning(
                    "unable to find a cached model file, this will likely fail..."
                )
                model_path = str(cache_model_path)
            model, tokenizer = load_llama_model_4bit_low_ram(
                base_model_config if base_model_config else base_model,
                model_path,
                device_map=cfg.device_map,
                groupsize=cfg.gptq_groupsize if cfg.gptq_groupsize else -1,
                # Treated as a v1 GPTQ model unless the config says otherwise.
                is_v1_model=cfg.gptq_model_v1
                if cfg.gptq_model_v1 is not None
                else True,
            )
            # The 4-bit loader and int8 preparation are mutually exclusive.
            load_in_8bit = False
        elif is_llama_derived_model:
            model = LlamaForCausalLM.from_pretrained(
                base_model,
                load_in_8bit=cfg.load_in_8bit,
                torch_dtype=torch_dtype,
                device_map=cfg.device_map,
            )
        else:
            model = getattr(transformers, model_type).from_pretrained(
                base_model,
                load_in_8bit=cfg.load_in_8bit,
                torch_dtype=torch_dtype,
                device_map=cfg.device_map,
            )
    except Exception as e:
        # Deliberate best-effort: any failure above is retried once through
        # the generic auto class before giving up.
        logging.error(
            "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
        )
        logging.exception(e)
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=cfg.load_in_8bit,
            torch_dtype=torch_dtype,
            device_map=cfg.device_map,
        )

    # The 4-bit loader returns its own tokenizer; otherwise load one here.
    if tokenizer is None:
        try:
            # from_pretrained takes a model name/path, not the loaded model
            # object that was passed here before.
            if is_llama_derived_model:
                tokenizer = LlamaTokenizer.from_pretrained(base_model)
            else:
                tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
                    base_model
                )
        except Exception as e:
            logging.warning(
                "Unable to load the requested tokenizer class (%s), "
                "falling back to AutoTokenizer",
                e,
            )
            tokenizer = AutoTokenizer.from_pretrained(base_model)

    logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
    logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
    logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
    logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")

    tokenizer_class_name = tokenizer.__class__.__name__
    if tokenizer_class_name in ["LlamaTokenizer", "LlamaTokenizerFast"]:
        tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN

    if tokenizer_class_name == "GPTNeoXTokenizerFast":
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

    if load_in_8bit and not cfg.load_4bit:
        logging.info("converting model w/ prepare_model_for_int8_training")
        model = prepare_model_for_int8_training(model)

    model, lora_config = load_adapter(model, cfg, adapter)

    if cfg.ddp:
        model.to(f"cuda:{cfg.local_rank}")

    if cfg.load_4bit:
        # Scales to half
        logging.info("Fitting 4bit scales and zeros to half")
        for _name, m in model.named_modules():
            module_type = str(type(m))
            if "Autograd4bitQuantLinear" in module_type or "Linear4bitLt" in module_type:
                if hasattr(m, "is_v1_model") and m.is_v1_model:
                    m.zeros = m.zeros.half()
                m.scales = m.scales.half()
                # Skip layers that carry no bias tensor.
                if m.bias is not None:
                    m.bias = m.bias.half()

    # TODO resume_from_checkpoint handling
    return model, tokenizer, lora_config


def load_adapter(model, cfg, adapter):
    # type: (PreTrainedModel, AttrDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
    """Attach the requested PEFT adapter to ``model``.

    ``None`` returns the model untouched together with a ``None`` config,
    ``"lora"`` delegates to ``load_lora``, and any other value raises
    ``NotImplementedError``.
    """
    no_adapter_requested = adapter is None
    if no_adapter_requested:
        return model, None

    # TODO support Llama-Adapter once merged into peft https://github.com/huggingface/peft/pulls
    if adapter != "lora":
        raise NotImplementedError(f"{adapter} peft adapter not available")

    return load_lora(model, cfg)


def load_lora(model, cfg):
    # type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
    """Wrap ``model`` with a LoRA adapter when ``cfg.adapter`` is ``"lora"``.

    The LoRA settings come from ``cfg.lora_r``, ``cfg.lora_alpha``,
    ``cfg.lora_target_modules``, ``cfg.lora_dropout`` and
    ``cfg.lora_fan_in_fan_out``. If ``cfg.lora_model_dir`` is set, the adapter
    is loaded from there via ``PeftModel.from_pretrained``; otherwise a new
    one is created with ``get_peft_model``.

    Returns the (possibly wrapped) model and the ``LoraConfig`` that was
    built, or ``(model, None)`` when ``cfg.adapter`` is not ``"lora"``.
    """

    from peft import (
        LoraConfig,
        get_peft_model,
        PeftModel,
    )

    # Nothing to do for any other adapter setting.
    if cfg.adapter != "lora":
        return model, None

    lora_settings = {
        "r": cfg.lora_r,
        "lora_alpha": cfg.lora_alpha,
        "target_modules": cfg.lora_target_modules,
        "lora_dropout": cfg.lora_dropout,
        "fan_in_fan_out": cfg.lora_fan_in_fan_out,
        "bias": "none",
        "task_type": "CAUSAL_LM",
    }
    lora_config = LoraConfig(**lora_settings)

    adapter_dir = cfg.lora_model_dir
    if not adapter_dir:
        peft_model = get_peft_model(model, lora_config)
    else:
        peft_model = PeftModel.from_pretrained(
            model,
            adapter_dir,
            device_map=cfg.device_map,
            torch_dtype=torch.float16,
        )

    peft_model.print_trainable_parameters()

    return peft_model, lora_config