
error while generating

#10
by roy650 - opened

When generating with:

inputs = tokenizer(prompt, return_tensors="pt").to(device='cuda:0').input_ids
res = model.generate(inputs, max_new_tokens=20, do_sample=True, temperature=0.7, top_k=0, top_p=40) 

I get:

Traceback (most recent call last):
  File "<string>", line 21, in _fwd_kernel
KeyError: ('2-.-0-.-0--d6252949da17ceb5f3a278a70250af13-1af5134066c618146d2cd009138944a0-6fb21260a873f1a3458b67752ca56f63-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), ('vector', True, 128, False, False, True, 128, 128), (True, True, True, True, True, True, True, (False,), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (False, False), (True, False), (True, False), (False, False), (False, False)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 937, in build_triton_ir
    generator.visit(fn.parse())
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 183, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/ast.py", line 426, in generic_visit
    self.visit(item)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 252, in visit_FunctionDef
    has_ret = self.visit_compound_statement(node.body)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 177, in visit_compound_statement
    self.last_ret_type = self.visit(stmt)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 678, in visit_For
    self.visit_compound_statement(node.body)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 177, in visit_compound_statement
    self.last_ret_type = self.visit(stmt)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 319, in visit_AugAssign
    self.visit(assign)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 301, in visit_Assign
    values = self.visit(node.value)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 339, in visit_BinOp
    rhs = self.visit(node.right)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
    return super().visit(node)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 797, in visit_Call
    return fn(*args, _builder=self.builder, **kws)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/impl/base.py", line 22, in wrapper
    return fn(*args, **kwargs)
TypeError: dot() got an unexpected keyword argument 'trans_b'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3460, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_791836/924310910.py", line 2, in <module>
    res = model.generate(inputs, max_new_tokens=20) #, do_sample=True, temperature=0.7, top_k=0, top_p=40) #, return_full_text=False)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/transformers/generation/utils.py", line 1437, in generate
    return self.greedy_search(
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/transformers/generation/utils.py", line 2248, in greedy_search
    outputs = self(
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/u/user983/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b-storywriter/26f3be4e1d8bfe3d4313408401582f960f79ade7/modeling_mpt.py", line 237, in forward
    outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/u/user983/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b-storywriter/26f3be4e1d8bfe3d4313408401582f960f79ade7/modeling_mpt.py", line 183, in forward
    (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/u/user983/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b-storywriter/26f3be4e1d8bfe3d4313408401582f960f79ade7/blocks.py", line 36, in forward
    (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/u/user983/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b-storywriter/26f3be4e1d8bfe3d4313408401582f960f79ade7/attention.py", line 171, in forward
    (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
  File "/u/user983/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b-storywriter/26f3be4e1d8bfe3d4313408401582f960f79ade7/attention.py", line 111, in triton_flash_attn_fn
    attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/flash_attn/flash_attn_triton.py", line 810, in forward
    o, lse, ctx.softmax_scale = _flash_attn_forward(
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/flash_attn/flash_attn_triton.py", line 623, in _flash_attn_forward
    _fwd_kernel[grid](
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 199, in run
    return self.fn.run(*args, **kwargs)
  File "<string>", line 41, in _fwd_kernel
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 1620, in compile
    next_module = compile(module)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 1549, in <lambda>
    lambda src: ast_to_ttir(src, signature, configs[0], constants)),
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 962, in ast_to_ttir
    mod, _ = build_triton_ir(fn, signature, specialization, constants)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 942, in build_triton_ir
    raise CompilationError(fn.src, node) from e
triton.compiler.CompilationError: at 78:24:
def _fwd_kernel(
    Q, K, V, Bias, Out,
    Lse, TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_bb, stride_bh, stride_bm,
    stride_ob, stride_oh, stride_om,
    nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
    CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # off_b = tl.program_id(1)
    # off_h = tl.program_id(2)
    # off_hb = off_b * nheads + off_h
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # Initialize pointers to Q, K, V
    # Adding parenthesis around indexing might use int32 math instead of int64 math?
    # https://github.com/openai/triton/issues/741
    # I'm seeing a tiny bit of difference (5-7us)
    q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
    if BIAS_TYPE == 'vector':
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
    elif BIAS_TYPE == 'matrix':
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
    # initialize pointer to m and l
    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
    # tl.load(q_ptrs), we get the wrong output!
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                        other=0.0)
    # loop over k, v and update accumulator
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,
                            other=0.0)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn,
                            mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                            other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
                        ^

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 2057, in showtraceback
    stb = self.InteractiveTB.structured_traceback(
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/IPython/core/ultratb.py", line 1288, in structured_traceback
    return FormattedTB.structured_traceback(
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/IPython/core/ultratb.py", line 1177, in structured_traceback
    return VerboseTB.structured_traceback(
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/IPython/core/ultratb.py", line 1049, in structured_traceback
    formatted_exceptions += self.format_exception_as_a_whole(etype, evalue, etb, lines_of_context,
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/IPython/core/ultratb.py", line 935, in format_exception_as_a_whole
    self.get_records(etb, number_of_lines_of_context, tb_offset) if etb else []
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/IPython/core/ultratb.py", line 1003, in get_records
    lines, first = inspect.getsourcelines(etb.tb_frame)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/inspect.py", line 1129, in getsourcelines
    lines, lnum = findsource(object)
  File "/u/user983/miniconda3/envs/llama/lib/python3.10/inspect.py", line 958, in findsource
    raise OSError('could not get source code')
OSError: could not get source code

It happens only when using the triton attention implementation (torch attention works fine), with:

flash-attn==1.0.4
triton==2.0.0

Any ideas?

Even after I align to the setup above, I still get the same error. See my pip freeze below.
Also, note this issue from Triton: https://github.com/openai/triton/issues/1054 where one of the maintainers hints that it might not work on A100...?

aiofiles==22.1.0
aiosqlite==0.19.0
anyio==3.6.2
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
arrow==1.2.3
asttokens==2.2.1
attrs==23.1.0
Babel==2.12.1
backcall==0.2.0
beautifulsoup4==4.12.2
bleach==6.0.0
certifi==2023.5.7
cffi==1.15.1
charset-normalizer==3.1.0
cmake==3.26.3
coloredlogs==15.0.1
comm==0.1.3
datasets==2.10.1
debugpy==1.6.7
decorator==5.1.1
defusedxml==0.7.1
einops==0.5.0
executing==1.2.0
fastjsonschema==2.16.3
filelock==3.12.0
flash-attn==1.0.3.post0
flatbuffers==23.3.3
fqdn==1.5.1
fsspec==2023.5.0
huggingface-hub==0.14.1
idna==3.4
install==1.3.5
ipykernel==6.23.0
ipython==8.13.2
ipython-genutils==0.2.0
ipywidgets==8.0.6
isoduration==20.11.0
jedi==0.18.2
Jinja2==3.1.2
json5==0.9.11
jsonpointer==2.3
jsonschema==4.17.3
jupyter-events==0.6.3
jupyter-ydoc==0.2.4
jupyter_client==8.2.0
jupyter_core==5.3.0
jupyter_server==2.5.0
jupyter_server_fileid==0.9.0
jupyter_server_terminals==0.4.4
jupyter_server_ydoc==0.8.0
jupyterlab==3.6.3
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.7
jupyterlab_server==2.22.1
llm-foundry==1.0
MarkupSafe==2.1.2
matplotlib-inline==0.1.6
mistune==2.0.5
mosaicml==0.14.1
mosaicml-cli==0.4.0a5
nbclassic==1.0.0
nbclient==0.7.4
nbconvert==7.4.0
nbformat==5.8.0
nest-asyncio==1.5.6
notebook==6.5.4
notebook_shim==0.2.3
numpy==1.24.3
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
omegaconf==2.3.0
onnx==1.13.1
onnxruntime==1.14.1
packaging==23.1
pandocfilters==1.5.0
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
platformdirs==3.5.0
prometheus-client==0.16.0
prompt-toolkit==3.0.38
protobuf==4.23.0
psutil==5.9.5
ptyprocess==0.7.0
pure-eval==0.2.2
pycparser==2.21
Pygments==2.15.1
pynvml==11.5.0
pyrsistent==0.19.3
python-dateutil==2.8.2
python-json-logger==2.0.7
PyYAML==6.0
pyzmq==25.0.2
regex==2023.5.5
requests==2.30.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
Send2Trash==1.8.2
sentencepiece==0.1.97
six==1.16.0
slack-sdk==3.21.3
sniffio==1.3.0
soupsieve==2.4.1
stack-data==0.6.2
sympy==1.12rc1
terminado==0.17.1
tinycss2==1.2.1
tokenizers==0.13.3
tomli==2.0.1
torch==1.13.1
tornado==6.3.1
tqdm==4.65.0
traitlets==5.9.0
transformers==4.28.1
triton==2.0.0.dev20221202
typing_extensions==4.5.0
uri-template==1.2.0
urllib3==2.0.2
wcwidth==0.2.6
webcolors==1.13
webencodings==0.5.1
websocket-client==1.5.1
widgetsnbextension==4.0.7
xentropy-cuda-lib @ git+https://github.com/HazyResearch/flash-attention.git@33e0860c9c5667fded5af674882e731909096a7f#subdirectory=csrc/xentropy
y-py==0.5.9
ypy-websocket==0.8.2

We have done all of our testing on A100s, so it is possible it only works there. You can always go back to using the torch attention implementation if you aren't able to run it with triton on non-A100 GPUs.
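For anyone looking for the exact switch, here is a minimal sketch of selecting the attention implementation when loading the model, following the pattern in the MPT model card (model name and dtype are illustrative; adjust to your setup):

import torch
import transformers

name = 'mosaicml/mpt-7b-storywriter'

config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
config.attn_config['attn_impl'] = 'torch'  # or 'triton' once the environment supports it

model = transformers.AutoModelForCausalLM.from_pretrained(
    name,
    config=config,
    torch_dtype=torch.bfloat16,  # illustrative; pick a dtype your hardware supports
    trust_remote_code=True,
).to('cuda:0')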

I don't spot any obvious issues with your installed packages, but you might have better luck starting from our recommended docker image, if possible: mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04

Thanks, I'll try that.
FWIW, with the torch attention it's functioning fine (though somewhat slow).

Managed to get it to run locally (CPU only, on a system with 64 GB RAM, using torch attention with cpu_low_mem=True). Can we use transformers' accelerate-based offloading to shard to CPU if we run out of VRAM? I would honestly like to export to ONNX so I can run CPU-only, but I currently struggle to load the model in my local Jupyter notebook.
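For reference, CPU/disk offloading via device_map is a generic transformers + accelerate pattern rather than anything MPT-specific, and whether a particular remote-code model accepts device_map depends on its implementation, so treat this as a starting point. A minimal sketch (model name, dtype, and offload path are illustrative):

import torch
import transformers

name = 'mosaicml/mpt-7b-storywriter'

config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
config.attn_config['attn_impl'] = 'torch'

# device_map='auto' requires the accelerate package; it fills the GPU first,
# spills remaining weights to CPU RAM, and can offload to disk via offload_folder.
model = transformers.AutoModelForCausalLM.from_pretrained(
    name,
    config=config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map='auto',
    offload_folder='offload',   # illustrative path
    low_cpu_mem_usage=True,
)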

I hit the same error when I run it on an A100 with triton. It seems to be a conflict between triton and CUDA, because I have tried many versions of flash-attn and triton.
If anyone has made progress, please leave a comment. Triton is faster than torch and it is worth trying.

+1, the same error. I tried torch 1.13 and torch 2.0 with CUDA 11.7 on an RTX 3090 GPU. Without triton, everything works great.

I don't spot any obvious issues with your installed packages, but you might have better luck starting from our recommended docker image, if possible: mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04

Tried this docker image from scratch. The same error on an RTX 3090.

Finally, it works on an RTX 3090 with the mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04 docker container. I re-installed some Python packages in the container. Final subset of pip freeze:

triton==2.0.0.dev20221202
torch==1.13.1+cu117
transformers==4.29.2
flash-attn==1.0.3.post0

Torch 2.0 uses a version of Triton that uses a version of MLIR which breaks ALiBi, so those of you who are seeing issues with Torch2... we are aware of this and are working on a fix. Closing this as completed because it does work with the correct versions of the dependencies, which can be found in LLM-Foundry

sam-mosaic changed discussion status to closed

Thanks everyone, I fixed this error by installing transformers==4.29.2. It's a conflict between transformers and triton; in the end I did not use the mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04 docker image. Just making sure of the following package versions solves the problem:

triton==2.0.0.dev20221202
torch==1.13.1+cu117
transformers==4.29.2
flash-attn==1.0.3.post0

Torch 2.0 uses a version of Triton that uses a version of MLIR which breaks ALiBi, so those of you who are seeing issues with Torch2... we are aware of this and are working on a fix. Closing this as completed because it does work with the correct versions of the dependencies, which can be found in LLM-Foundry

Not exactly. If you install the flash-attn package using pip, torch 2.0 is automatically uninstalled and torch 1.13 is installed in its place to match the flash-attn version. So this is a problem between transformers and triton, not torch 2.0.
You could pin the transformers version in the model card so other people don't run into this error.

I had the same issue when installing on my local machine. The tricky part for me is that my card runs CUDA 12.1, so flash-attn throws an error about Torch not being compiled for the right CUDA version when I try to run training with torch==1.13.1+cu117 ... and only the nightly builds of torch>2.0 offer CUDA 12.1 support.

The cause is basically the issue mentioned here:

Torch 2.0 uses a version of Triton that uses a version of MLIR which breaks ALiBi, so those of you who are seeing issues with Torch2... we are aware of this and are working on a fix. Closing this as completed because it does work with the correct versions of the dependencies, which can be found in LLM-Foundry

My workaround is simply to run
pip uninstall pytorch-triton
and then
pip install triton==2.0.0.dev20221202 --no-deps
since otherwise pip forces a reinstallation of torch.
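
After juggling packages like this, a quick sanity check of what actually ended up installed can save a round trip (a minimal, generic sketch; nothing here is specific to MPT):

import importlib.metadata as md

import torch

# print the installed versions that matter for the triton flash-attention path
for pkg in ('torch', 'triton', 'transformers', 'flash-attn'):
    try:
        print(f'{pkg:14s}', md.version(pkg))
    except md.PackageNotFoundError:
        print(f'{pkg:14s}', 'not installed')

print('cuda available:', torch.cuda.is_available())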

Hope this is helpful for whoever comes later...
