Error while generating with triton attention
When generating with:
inputs = tokenizer(prompt, return_tensors="pt").to(device='cuda:0').input_ids
res = model.generate(inputs, max_new_tokens=20, do_sample=True, temperature=0.7, top_k=0, top_p=40)
I get:
Traceback (most recent call last):
File "<string>", line 21, in _fwd_kernel
KeyError: ('2-.-0-.-0--d6252949da17ceb5f3a278a70250af13-1af5134066c618146d2cd009138944a0-6fb21260a873f1a3458b67752ca56f63-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), ('vector', True, 128, False, False, True, 128, 128), (True, True, True, True, True, True, True, (False,), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (False, False), (True, False), (True, False), (False, False), (False, False)))
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 937, in build_triton_ir
generator.visit(fn.parse())
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 183, in visit_Module
ast.NodeVisitor.generic_visit(self, node)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/ast.py", line 426, in generic_visit
self.visit(item)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 252, in visit_FunctionDef
has_ret = self.visit_compound_statement(node.body)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 177, in visit_compound_statement
self.last_ret_type = self.visit(stmt)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 678, in visit_For
self.visit_compound_statement(node.body)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 177, in visit_compound_statement
self.last_ret_type = self.visit(stmt)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 319, in visit_AugAssign
self.visit(assign)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 301, in visit_Assign
values = self.visit(node.value)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 339, in visit_BinOp
rhs = self.visit(node.right)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 855, in visit
return super().visit(node)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/ast.py", line 418, in visit
return visitor(node)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 797, in visit_Call
return fn(*args, _builder=self.builder, **kws)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/impl/base.py", line 22, in wrapper
return fn(*args, **kwargs)
TypeError: dot() got an unexpected keyword argument 'trans_b'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3460, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "/tmp/ipykernel_791836/924310910.py", line 2, in <module>
res = model.generate(inputs, max_new_tokens=20) #, do_sample=True, temperature=0.7, top_k=0, top_p=40) #, return_full_text=False)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/transformers/generation/utils.py", line 1437, in generate
return self.greedy_search(
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/transformers/generation/utils.py", line 2248, in greedy_search
outputs = self(
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/u/user983/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b-storywriter/26f3be4e1d8bfe3d4313408401582f960f79ade7/modeling_mpt.py", line 237, in forward
outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/u/user983/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b-storywriter/26f3be4e1d8bfe3d4313408401582f960f79ade7/modeling_mpt.py", line 183, in forward
(x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/u/user983/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b-storywriter/26f3be4e1d8bfe3d4313408401582f960f79ade7/blocks.py", line 36, in forward
(b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/u/user983/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b-storywriter/26f3be4e1d8bfe3d4313408401582f960f79ade7/attention.py", line 171, in forward
(context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
File "/u/user983/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b-storywriter/26f3be4e1d8bfe3d4313408401582f960f79ade7/attention.py", line 111, in triton_flash_attn_fn
attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/flash_attn/flash_attn_triton.py", line 810, in forward
o, lse, ctx.softmax_scale = _flash_attn_forward(
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/flash_attn/flash_attn_triton.py", line 623, in _flash_attn_forward
_fwd_kernel[grid](
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 199, in run
return self.fn.run(*args, **kwargs)
File "<string>", line 41, in _fwd_kernel
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 1620, in compile
next_module = compile(module)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 1549, in <lambda>
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 962, in ast_to_ttir
mod, _ = build_triton_ir(fn, signature, specialization, constants)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/triton/compiler.py", line 942, in build_triton_ir
raise CompilationError(fn.src, node) from e
triton.compiler.CompilationError: at 78:24:
def _fwd_kernel(
Q, K, V, Bias, Out,
Lse, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
softmax_scale,
stride_qb, stride_qh, stride_qm,
stride_kb, stride_kh, stride_kn,
stride_vb, stride_vh, stride_vn,
stride_bb, stride_bh, stride_bm,
stride_ob, stride_oh, stride_om,
nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
BIAS_TYPE: tl.constexpr,
IS_CAUSAL: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hb = tl.program_id(1)
off_b = off_hb // nheads
off_h = off_hb % nheads
# off_b = tl.program_id(1)
# off_h = tl.program_id(2)
# off_hb = off_b * nheads + off_h
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_HEADDIM)
# Initialize pointers to Q, K, V
# Adding parenthesis around indexing might use int32 math instead of int64 math?
# https://github.com/openai/triton/issues/741
# I'm seeing a tiny bit of difference (5-7us)
q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
if BIAS_TYPE == 'vector':
b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
elif BIAS_TYPE == 'matrix':
b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
# initialize pointer to m and l
t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
# load q: it will stay in SRAM throughout
# [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
# tl.load(q_ptrs), we get the wrong output!
if EVEN_M & EVEN_N:
if EVEN_HEADDIM:
q = tl.load(q_ptrs)
else:
q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
else:
q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
other=0.0)
# loop over k, v and update accumulator
end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
for start_n in range(0, end_n, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
if EVEN_HEADDIM:
k = tl.load(k_ptrs + start_n * stride_kn)
else:
k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0)
else:
k = tl.load(k_ptrs + start_n * stride_kn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, trans_b=True)
^
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 2057, in showtraceback
stb = self.InteractiveTB.structured_traceback(
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/IPython/core/ultratb.py", line 1288, in structured_traceback
return FormattedTB.structured_traceback(
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/IPython/core/ultratb.py", line 1177, in structured_traceback
return VerboseTB.structured_traceback(
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/IPython/core/ultratb.py", line 1049, in structured_traceback
formatted_exceptions += self.format_exception_as_a_whole(etype, evalue, etb, lines_of_context,
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/IPython/core/ultratb.py", line 935, in format_exception_as_a_whole
self.get_records(etb, number_of_lines_of_context, tb_offset) if etb else []
File "/u/user983/miniconda3/envs/llama/lib/python3.10/site-packages/IPython/core/ultratb.py", line 1003, in get_records
lines, first = inspect.getsourcelines(etb.tb_frame)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/inspect.py", line 1129, in getsourcelines
lines, lnum = findsource(object)
File "/u/user983/miniconda3/envs/llama/lib/python3.10/inspect.py", line 958, in findsource
raise OSError('could not get source code')
OSError: could not get source code
It happens only when using triton attention (torch attention works fine), with:
flash-attn==1.0.4
triton==2.0.0
Any ideas?
The versions we use are listed here: https://github.com/mosaicml/llm-foundry/blob/3959eaccba53c444c5705d600d333cc3d47bc06c/setup.py#L74-L78. I'd suggest trying those.
Even after aligning with the above setup, I still get the same error. See my pip freeze below.
Also, note this Triton issue: https://github.com/openai/triton/issues/1054, where one of the maintainers hints that it might not work on A100...?
aiofiles==22.1.0
aiosqlite==0.19.0
anyio==3.6.2
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
arrow==1.2.3
asttokens==2.2.1
attrs==23.1.0
Babel==2.12.1
backcall==0.2.0
beautifulsoup4==4.12.2
bleach==6.0.0
certifi==2023.5.7
cffi==1.15.1
charset-normalizer==3.1.0
cmake==3.26.3
coloredlogs==15.0.1
comm==0.1.3
datasets==2.10.1
debugpy==1.6.7
decorator==5.1.1
defusedxml==0.7.1
einops==0.5.0
executing==1.2.0
fastjsonschema==2.16.3
filelock==3.12.0
flash-attn==1.0.3.post0
flatbuffers==23.3.3
fqdn==1.5.1
fsspec==2023.5.0
huggingface-hub==0.14.1
idna==3.4
install==1.3.5
ipykernel==6.23.0
ipython==8.13.2
ipython-genutils==0.2.0
ipywidgets==8.0.6
isoduration==20.11.0
jedi==0.18.2
Jinja2==3.1.2
json5==0.9.11
jsonpointer==2.3
jsonschema==4.17.3
jupyter-events==0.6.3
jupyter-ydoc==0.2.4
jupyter_client==8.2.0
jupyter_core==5.3.0
jupyter_server==2.5.0
jupyter_server_fileid==0.9.0
jupyter_server_terminals==0.4.4
jupyter_server_ydoc==0.8.0
jupyterlab==3.6.3
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.7
jupyterlab_server==2.22.1
llm-foundry==1.0
MarkupSafe==2.1.2
matplotlib-inline==0.1.6
mistune==2.0.5
mosaicml==0.14.1
mosaicml-cli==0.4.0a5
nbclassic==1.0.0
nbclient==0.7.4
nbconvert==7.4.0
nbformat==5.8.0
nest-asyncio==1.5.6
notebook==6.5.4
notebook_shim==0.2.3
numpy==1.24.3
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
omegaconf==2.3.0
onnx==1.13.1
onnxruntime==1.14.1
packaging==23.1
pandocfilters==1.5.0
parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
platformdirs==3.5.0
prometheus-client==0.16.0
prompt-toolkit==3.0.38
protobuf==4.23.0
psutil==5.9.5
ptyprocess==0.7.0
pure-eval==0.2.2
pycparser==2.21
Pygments==2.15.1
pynvml==11.5.0
pyrsistent==0.19.3
python-dateutil==2.8.2
python-json-logger==2.0.7
PyYAML==6.0
pyzmq==25.0.2
regex==2023.5.5
requests==2.30.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
Send2Trash==1.8.2
sentencepiece==0.1.97
six==1.16.0
slack-sdk==3.21.3
sniffio==1.3.0
soupsieve==2.4.1
stack-data==0.6.2
sympy==1.12rc1
terminado==0.17.1
tinycss2==1.2.1
tokenizers==0.13.3
tomli==2.0.1
torch==1.13.1
tornado==6.3.1
tqdm==4.65.0
traitlets==5.9.0
transformers==4.28.1
triton==2.0.0.dev20221202
typing_extensions==4.5.0
uri-template==1.2.0
urllib3==2.0.2
wcwidth==0.2.6
webcolors==1.13
webencodings==0.5.1
websocket-client==1.5.1
widgetsnbextension==4.0.7
xentropy-cuda-lib @ git+https://github.com/HazyResearch/flash-attention.git@33e0860c9c5667fded5af674882e731909096a7f#subdirectory=csrc/xentropy
y-py==0.5.9
ypy-websocket==0.8.2
We have done all of our testing on A100s, so it is possible it only works there. You can always go back to using the torch attention implementation if you aren't able to run triton attention on non-A100 GPUs.
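Roughly, falling back to torch attention means flipping attn_impl in the model config at load time. A minimal sketch, following the public MPT model card (the model name and dtype below are just examples, not pinned to this issue):
import torch
import transformers

name = 'mosaicml/mpt-7b-storywriter'

# Load the config with the custom MPT code, then switch the attention
# implementation from 'triton' to 'torch' before instantiating the model.
config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
config.attn_config['attn_impl'] = 'torch'

model = transformers.AutoModelForCausalLM.from_pretrained(
    name,
    config=config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)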
I don't spot any obvious issues with your installed packages, but you might have better luck starting from our recommended docker image, if possible: mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
Thanks, I'll try that.
FWIW - with the torch attention it's functioning fine (though somewhat slow)
Managed to get it to run locally (CPU only, on a system with 64 GB RAM, using torch attention with cpu_low_mem=True). Can we use transformers' accelerate integration to shard/offload to CPU if we run out of VRAM? I'd honestly like to export the model to ONNX so I can run CPU-only, but I currently struggle to load the model in my local Jupyter notebook.
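For the sharding question: transformers can shard and offload weights via accelerate's device_map. A minimal sketch (device_map, offload_folder and low_cpu_mem_usage are stock transformers/accelerate options, not something verified against MPT's custom code in this thread):
# Let accelerate place layers across the available GPU(s) and CPU, and spill
# anything that still doesn't fit to disk. Requires `pip install accelerate`.
import torch
import transformers

name = 'mosaicml/mpt-7b-storywriter'
config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
config.attn_config['attn_impl'] = 'torch'  # stay on the torch attention path

model = transformers.AutoModelForCausalLM.from_pretrained(
    name,
    config=config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map='auto',          # shard across GPU and CPU automatically
    offload_folder='offload',   # spill remaining weights to disk
    low_cpu_mem_usage=True,
)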
I am hitting the same error when I run it on an A100 with triton. It seems to be a conflict between triton and CUDA, since I have tried many versions of flash-attn and triton.
If anyone makes progress, please leave a comment. Triton is faster than torch and it is worth trying.
+1, the same error. I tried torch 1.13 and torch 2.0 with CUDA 11.7 on an RTX 3090 GPU. Without triton, everything works great.
Tried the recommended docker image (mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04) from scratch. Same error on an RTX 3090.
Finally got it working on the RTX 3090 with the mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04 docker container. I re-installed some Python packages in the container. Final subset of pip freeze:
triton==2.0.0.dev20221202
torch==1.13.1+cu117
transformers==4.29.2
flash-attn==1.0.3.post0
Torch 2.0 uses a version of Triton that uses a version of MLIR which breaks ALiBi, so those of you who are seeing issues with Torch2... we are aware of this and are working on a fix. Closing this as completed because it does work with the correct versions of the dependencies, which can be found in LLM-Foundry
Thanks everyone, I fixed this error by installing transformers==4.29.2. It's a conflict between transformers and triton; in the end I did not use the mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04 docker image. Just making sure the following package versions are installed solved the problem for me:
triton==2.0.0.dev20221202 torch==1.13.1+cu117 transformers==4.29.2 flash-attn==1.0.3.post0
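If it helps anyone double-check, here is a quick way to print the versions the running interpreter actually sees (a generic sketch using importlib.metadata, nothing specific to this repo):
# Print the installed versions of the packages that matter for this error.
from importlib.metadata import version, PackageNotFoundError

for pkg in ('torch', 'triton', 'transformers', 'flash-attn'):
    try:
        print(pkg, version(pkg))
    except PackageNotFoundError:
        print(pkg, 'not installed')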
Torch 2.0 uses a version of Triton that uses a version of MLIR which breaks ALiBi, so those of you who are seeing issues with Torch2... we are aware of this and are working on a fix. Closing this as completed because it does work with the correct versions of the dependencies, which can be found in LLM-Foundry
Not exactly. If you install the flash-attn package using pip, torch 2.0 is automatically uninstalled and torch 1.13 is installed to match the flash-attn version. So this is a problem between transformers and triton, not torch 2.0.
You could pin the transformers version in the model card so other people don't run into this error.
I had the same issue when installing on my local machine. The tricky part for me is that my card runs CUDA 12.1, so flash-attn throws an error about torch not being compiled for the correct CUDA version when I try to run training with torch==1.13.1+cu117, and only the nightly builds of torch>2.0 offer CUDA 12.1 support.
The cause is basically the issue mentioned here:
Torch 2.0 uses a version of Triton that uses a version of MLIR which breaks ALiBi, so those of you who are seeing issues with Torch2... we are aware of this and are working on a fix. Closing this as completed because it does work with the correct versions of the dependencies, which can be found in LLM-Foundry
My workaround is simply to pip uninstall pytorch-triton and then pip install triton==2.0.0.dev20221202 --no-deps, otherwise pip forces a reinstallation of torch.
Hope this is helpful for whoever comes later...