How do I actually run the model without getting runtime errors?
I first tried the example from the README:
from transformers import AutoModelForMaskedLM, BertTokenizer, pipeline
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
mlm = AutoModelForMaskedLM.from_pretrained('mosaicml/mosaic-bert-base', trust_remote_code=True,
                                           revision='24512df')  # I tried with or without revision
classifier = pipeline('fill-mask', model=mlm, tokenizer=tokenizer)
classifier("I [MASK] to the store yesterday.")
The example does not work; it fails with:
File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mosaic-bert-base/fcc434c97e2d475d5dd1a69fca9f734af7a41772/flash_attn_triton.py:781, in _flash_attn_forward(q, k, v, bias, causal, softmax_scale)
778 assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
779 assert q.dtype in [torch.float16,
780 torch.bfloat16], 'Only support fp16 and bf16'
--> 781 assert q.is_cuda and k.is_cuda and v.is_cuda
That part is trivial to fix by moving the model and the pipeline to the GPU:
mlm = AutoModelForMaskedLM.from_pretrained('mosaicml/mosaic-bert-base', trust_remote_code=True, revision='24512df').cuda()
classifier = pipeline('fill-mask', model=mlm, tokenizer=tokenizer, device="cuda:0")
classifier("I [MASK] to the store yesterday.")
And then I hit a new error:
KeyError Traceback (most recent call last)
File <string>:21, in _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE, IS_CAUSAL, BLOCK_HEADDIM, EVEN_M, EVEN_N, EVEN_HEADDIM, BLOCK_M, BLOCK_N, grid, num_warps, num_stages, extern_libs, stream, warmup)
KeyError: ('2-.-0-.-0-83ca8b715a9dc5f32dc1110973485f64-d6252949da17ceb5f3a278a70250af13-3b85c7bef5f0a641282f3b73af50f599-975a5a907f067e8e36a802ec0cd5bc10-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.float16, torch.float16, torch.float16, torch.float16, torch.float16, torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), ('matrix', False, 64, False, False, True, 128, 128), (True, True, True, True, True, True, True, (False,), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (False, False), (False, False), (True, False), (True, False), (True, False), (False, False), (False, False), (False, False), (True, False), (True, False), (True, False), (True, False)))
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:937, in build_triton_ir(fn, signature, specialization, constants)
936 try:
--> 937 generator.visit(fn.parse())
938 except Exception as e:
File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node)
854 warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8
--> 855 return super().visit(node)
File /usr/lib/python3.11/ast.py:418, in NodeVisitor.visit(self, node)
417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)
File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:183, in CodeGenerator.visit_Module(self, node)
182 def visit_Module(self, node):
--> 183 ast.NodeVisitor.generic_visit(self, node)
File /usr/lib/python3.11/ast.py:426, in NodeVisitor.generic_visit(self, node)
425 if isinstance(item, AST):
--> 426 self.visit(item)
427 elif isinstance(value, AST):
File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node)
854 warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8
--> 855 return super().visit(node)
File /usr/lib/python3.11/ast.py:418, in NodeVisitor.visit(self, node)
417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)
File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:252, in CodeGenerator.visit_FunctionDef(self, node)
251 # visit function body
--> 252 has_ret = self.visit_compound_statement(node.body)
253 # finalize function
File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:177, in CodeGenerator.visit_compound_statement(self, stmts)
176 for stmt in stmts:
--> 177 self.last_ret_type = self.visit(stmt)
178 if isinstance(stmt, ast.Return):
File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node)
854 warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8
--> 855 return super().visit(node)
File /usr/lib/python3.11/ast.py:418, in NodeVisitor.visit(self, node)
417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)
File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:678, in CodeGenerator.visit_For(self, node)
677 self.scf_stack.append(node)
--> 678 self.visit_compound_statement(node.body)
679 self.scf_stack.pop()
File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:177, in CodeGenerator.visit_compound_statement(self, stmts)
176 for stmt in stmts:
--> 177 self.last_ret_type = self.visit(stmt)
178 if isinstance(stmt, ast.Return):
File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node)
854 warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8
--> 855 return super().visit(node)
File /usr/lib/python3.11/ast.py:418, in NodeVisitor.visit(self, node)
417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)
File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:319, in CodeGenerator.visit_AugAssign(self, node)
318 assign = ast.Assign(targets=[node.target], value=rhs)
--> 319 self.visit(assign)
320 return self.get_value(name)
File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node)
854 warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8
--> 855 return super().visit(node)
File /usr/lib/python3.11/ast.py:418, in NodeVisitor.visit(self, node)
417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)
File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:301, in CodeGenerator.visit_Assign(self, node)
300 names = _names[0]
--> 301 values = self.visit(node.value)
302 if not isinstance(names, tuple):
File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node)
854 warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8
--> 855 return super().visit(node)
File /usr/lib/python3.11/ast.py:418, in NodeVisitor.visit(self, node)
417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)
File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:339, in CodeGenerator.visit_BinOp(self, node)
338 lhs = self.visit(node.left)
--> 339 rhs = self.visit(node.right)
340 fn = {
341 ast.Add: '__add__',
342 ast.Sub: '__sub__',
(...)
352 ast.BitXor: '__xor__',
353 }[type(node.op)]
File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:855, in CodeGenerator.visit(self, node)
854 warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8
--> 855 return super().visit(node)
File /usr/lib/python3.11/ast.py:418, in NodeVisitor.visit(self, node)
417 visitor = getattr(self, method, self.generic_visit)
--> 418 return visitor(node)
File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:797, in CodeGenerator.visit_Call(self, node)
795 if (hasattr(fn, '__self__') and self.is_triton_tensor(fn.__self__)) \
796 or impl.is_builtin(fn):
--> 797 return fn(*args, _builder=self.builder, **kws)
798 if fn in self.builtins.values():
File ~/src/sd/sd/lib/python3.11/site-packages/triton/impl/base.py:22, in builtin.<locals>.wrapper(*args, **kwargs)
18 raise ValueError(
19 "Did you forget to add
@triton
.jit ? "
20 "(`_builder` argument must be provided outside of JIT functions.)"
21 )
---> 22 return fn(*args, **kwargs)
TypeError: dot() got an unexpected keyword argument 'trans_b'
The above exception was the direct cause of the following exception:
CompilationError Traceback (most recent call last)
Cell In[3], line 8
4 mlm = AutoModelForMaskedLM.from_pretrained('mosaicml/mosaic-bert-base', trust_remote_code=True, revision='24512df').cuda()
6 classifier = pipeline('fill-mask', model=mlm, tokenizer=tokenizer,device="cuda:0")
----> 8 classifier("I [MASK] to the store yesterday.")
File ~/src/sd/sd/lib/python3.11/site-packages/transformers/pipelines/fill_mask.py:239, in FillMaskPipeline.__call__(self, inputs, *args, **kwargs)
217 def __call__(self, inputs, *args, **kwargs):
218 """
219 Fill the masked token in the text(s) given as inputs.
220
(...)
237 - **token_str** (`str`) -- The predicted token (to replace the masked one).
238 """
--> 239 outputs = super().__call__(inputs, **kwargs)
240 if isinstance(inputs, list) and len(inputs) == 1:
241 return outputs[0]
File ~/src/sd/sd/lib/python3.11/site-packages/transformers/pipelines/base.py:1118, in Pipeline.__call__(self, inputs, num_workers, batch_size, *args, **kwargs)
1110 return next(
1111 iter(
1112 self.get_iterator(
(...)
1115 )
1116 )
1117 else:
-> 1118 return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
File ~/src/sd/sd/lib/python3.11/site-packages/transformers/pipelines/base.py:1125, in Pipeline.run_single(self, inputs, preprocess_params, forward_params, postprocess_params)
1123 def run_single(self, inputs, preprocess_params, forward_params, postprocess_params):
1124 model_inputs = self.preprocess(inputs, **preprocess_params)
-> 1125 model_outputs = self.forward(model_inputs, **forward_params)
1126 outputs = self.postprocess(model_outputs, **postprocess_params)
1127 return outputs
File ~/src/sd/sd/lib/python3.11/site-packages/transformers/pipelines/base.py:1024, in Pipeline.forward(self, model_inputs, **forward_params)
1022 with inference_context():
1023 model_inputs = self._ensure_tensor_on_device(model_inputs, device=self.device)
-> 1024 model_outputs = self._forward(model_inputs, **forward_params)
1025 model_outputs = self._ensure_tensor_on_device(model_outputs, device=torch.device("cpu"))
1026 else:
File ~/src/sd/sd/lib/python3.11/site-packages/transformers/pipelines/fill_mask.py:101, in FillMaskPipeline._forward(self, model_inputs)
100 def _forward(self, model_inputs):
--> 101 model_outputs = self.model(**model_inputs)
102 model_outputs["input_ids"] = model_inputs["input_ids"]
103 return model_outputs
File ~/src/sd/sd/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mosaic-bert-base/fcc434c97e2d475d5dd1a69fca9f734af7a41772/bert_layers.py:850, in BertForMaskedLM.forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, labels, output_attentions, output_hidden_states, return_dict)
846 masked_tokens_mask = labels > 0
848 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
--> 850 outputs = self.bert(
851 input_ids,
852 attention_mask=attention_mask,
853 token_type_ids=token_type_ids,
854 position_ids=position_ids,
855 head_mask=head_mask,
856 inputs_embeds=inputs_embeds,
857 encoder_hidden_states=encoder_hidden_states,
858 encoder_attention_mask=encoder_attention_mask,
859 output_attentions=output_attentions,
860 output_hidden_states=output_hidden_states,
861 return_dict=return_dict,
862 masked_tokens_mask=masked_tokens_mask,
863 )
865 sequence_output = outputs[0]
866 prediction_scores = self.cls(sequence_output)
File ~/src/sd/sd/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mosaic-bert-base/fcc434c97e2d475d5dd1a69fca9f734af7a41772/bert_layers.py:669, in BertModel.forward(self, input_ids, token_type_ids, attention_mask, position_ids, output_all_encoded_layers, masked_tokens_mask, **kwargs)
666 first_col_mask[:, 0] = True
667 subset_mask = masked_tokens_mask | first_col_mask
--> 669 encoder_outputs = self.encoder(
670 embedding_output,
671 attention_mask,
672 output_all_encoded_layers=output_all_encoded_layers,
673 subset_mask=subset_mask)
675 if masked_tokens_mask is None:
676 sequence_output = encoder_outputs[-1]
File ~/src/sd/sd/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mosaic-bert-base/fcc434c97e2d475d5dd1a69fca9f734af7a41772/bert_layers.py:507, in BertEncoder.forward(self, hidden_states, attention_mask, output_all_encoded_layers, subset_mask)
505 if subset_mask is None:
506 for layer_module in self.layer:
--> 507 hidden_states = layer_module(hidden_states,
508 cu_seqlens,
509 seqlen,
510 None,
511 indices,
512 attn_mask=attention_mask,
513 bias=alibi_attn_mask)
514 if output_all_encoded_layers:
515 all_encoder_layers.append(hidden_states)
File ~/src/sd/sd/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mosaic-bert-base/fcc434c97e2d475d5dd1a69fca9f734af7a41772/bert_layers.py:388, in BertLayer.forward(self, hidden_states, cu_seqlens, seqlen, subset_idx, indices, attn_mask, bias)
366 def forward(
367 self,
368 hidden_states: torch.Tensor,
(...)
374 bias: Optional[torch.Tensor] = None,
375 ) -> torch.Tensor:
376 """Forward pass for a BERT layer, including both attention and MLP.
377
378 Args:
(...)
386 bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
387 """
--> 388 attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
389 subset_idx, indices, attn_mask, bias)
390 layer_output = self.mlp(attention_output)
391 return layer_output
File ~/src/sd/sd/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mosaic-bert-base/fcc434c97e2d475d5dd1a69fca9f734af7a41772/bert_layers.py:301, in BertUnpadAttention.forward(self, input_tensor, cu_seqlens, max_s, subset_idx, indices, attn_mask, bias)
279 def forward(
280 self,
281 input_tensor: torch.Tensor,
(...)
287 bias: Optional[torch.Tensor] = None,
288 ) -> torch.Tensor:
289 """Forward pass for scaled self-attention without padding.
290
291 Arguments:
(...)
299 bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
300 """
--> 301 self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
302 attn_mask, bias)
303 if subset_idx is not None:
304 return self.output(index_first_axis(self_output, subset_idx),
305 index_first_axis(input_tensor, subset_idx))
File ~/src/sd/sd/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mosaic-bert-base/fcc434c97e2d475d5dd1a69fca9f734af7a41772/bert_layers.py:233, in BertUnpadSelfAttention.forward(self, hidden_states, cu_seqlens, max_seqlen_in_batch, indices, attn_mask, bias)
231 bias_dtype = bias.dtype
232 bias = bias.to(torch.float16)
--> 233 attention = flash_attn_qkvpacked_func(qkv, bias)
234 attention = attention.to(orig_dtype)
235 bias = bias.to(bias_dtype)
File ~/src/sd/sd/lib/python3.11/site-packages/torch/autograd/function.py:506, in Function.apply(cls, *args, **kwargs)
503 if not torch._C._are_functorch_transforms_active():
504 # See NOTE: [functorch vjp and autograd interaction]
505 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 506 return super().apply(*args, **kwargs) # type: ignore[misc]
508 if cls.setup_context == _SingleLevelFunction.setup_context:
509 raise RuntimeError(
510 'In order to use an autograd.Function with functorch transforms '
511 '(vmap, grad, jvp, jacrev, ...), it must override the setup_context '
512 'staticmethod. For more details, please see '
513 'https://pytorch.org/docs/master/notes/extending.func.html')
File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mosaic-bert-base/fcc434c97e2d475d5dd1a69fca9f734af7a41772/flash_attn_triton.py:1021, in _FlashAttnQKVPackedFunc.forward(ctx, qkv, bias, causal, softmax_scale)
1019 if qkv.stride(-1) != 1:
1020 qkv = qkv.contiguous()
-> 1021 o, lse, ctx.softmax_scale = _flash_attn_forward(
1022 qkv[:, :, 0],
1023 qkv[:, :, 1],
1024 qkv[:, :, 2],
1025 bias=bias,
1026 causal=causal,
1027 softmax_scale=softmax_scale)
1028 ctx.save_for_backward(qkv, o, lse, bias)
1029 ctx.causal = causal
File ~/.cache/huggingface/modules/transformers_modules/mosaicml/mosaic-bert-base/fcc434c97e2d475d5dd1a69fca9f734af7a41772/flash_attn_triton.py:826, in _flash_attn_forward(q, k, v, bias, causal, softmax_scale)
823 # BLOCK = 128
824 # num_warps = 4 if d <= 64 else 8
825 grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch * nheads)
--> 826 _fwd_kernel[grid]( # type: ignore
827 q,
828 k,
829 v,
830 bias,
831 o,
832 lse,
833 tmp,
834 softmax_scale,
835 q.stride(0),
836 q.stride(2),
837 q.stride(1),
838 k.stride(0),
839 k.stride(2),
840 k.stride(1),
841 v.stride(0),
842 v.stride(2),
843 v.stride(1),
844 *bias_strides,
845 o.stride(0),
846 o.stride(2),
847 o.stride(1),
848 nheads,
849 seqlen_q,
850 seqlen_k,
851 seqlen_q_rounded,
852 d,
853 seqlen_q // 32,
854 seqlen_k // 32, # key for triton cache (limit number of compilations)
855 # Can't use kwargs here because triton autotune expects key to be args, not kwargs
856 # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
857 bias_type,
858 causal,
859 BLOCK_HEADDIM,
860 # BLOCK_M=BLOCK, BLOCK_N=BLOCK,
861 # num_warps=num_warps,
862 # num_stages=1,
863 )
864 return o, lse, softmax_scale
File ~/src/sd/sd/lib/python3.11/site-packages/triton/runtime/autotuner.py:90, in Autotuner.run(self, *args, **kwargs)
88 if config.pre_hook is not None:
89 config.pre_hook(self.nargs)
---> 90 return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
File ~/src/sd/sd/lib/python3.11/site-packages/triton/runtime/autotuner.py:199, in Heuristics.run(self, *args, **kwargs)
197 for v, heur in self.values.items():
198 kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs})
--> 199 return self.fn.run(*args, **kwargs)
File <string>:41, in _fwd_kernel(Q, K, V, Bias, Out, Lse, TMP, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_bb, stride_bh, stride_bm, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, BIAS_TYPE, IS_CAUSAL, BLOCK_HEADDIM, EVEN_M, EVEN_N, EVEN_HEADDIM, BLOCK_M, BLOCK_N, grid, num_warps, num_stages, extern_libs, stream, warmup)
File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:1621, in compile(fn, **kwargs)
1619 next_module = parse(path)
1620 else:
-> 1621 next_module = compile(module)
1622 fn_cache_manager.put(next_module, f"{name}.{ir}")
1623 if os.path.exists(path):
File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:1550, in compile.<locals>.<lambda>(src)
1545 extern_libs = kwargs.get("extern_libs", dict())
1546 # build compilation stages
1547 stages = {
1548 "ast": (lambda path: fn, None),
1549 "ttir": (lambda path: parse_mlir_module(path, context),
-> 1550 lambda src: ast_to_ttir(src, signature, configs[0], constants)),
1551 "ttgir": (lambda path: parse_mlir_module(path, context),
1552 lambda src: ttir_to_ttgir(src, num_warps, num_stages, capability)),
1553 "llir": (lambda path: Path(path).read_text(),
1554 lambda src: ttgir_to_llir(src, extern_libs, capability)),
1555 "ptx": (lambda path: Path(path).read_text(),
1556 lambda src: llir_to_ptx(src, capability)),
1557 "cubin": (lambda path: Path(path).read_bytes(),
1558 lambda src: ptx_to_cubin(src, capability))
1559 }
1560 # find out the signature of the function
1561 if isinstance(fn, triton.runtime.JITFunction):
File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:962, in ast_to_ttir(fn, signature, specialization, constants)
961 def ast_to_ttir(fn, signature, specialization, constants):
--> 962 mod, _ = build_triton_ir(fn, signature, specialization, constants)
963 return optimize_triton_ir(mod)
File ~/src/sd/sd/lib/python3.11/site-packages/triton/compiler.py:942, in build_triton_ir(fn, signature, specialization, constants)
940 if node is None or isinstance(e, (NotImplementedError, CompilationError)):
941 raise e
--> 942 raise CompilationError(fn.src, node) from e
943 ret = generator.module
944 # module takes ownership of the context
CompilationError: at 114:24:
def _fwd_kernel(
Q,
K,
V,
Bias,
Out,
Lse,
TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
softmax_scale,
stride_qb,
stride_qh,
stride_qm,
stride_kb,
stride_kh,
stride_kn,
stride_vb,
stride_vh,
stride_vn,
stride_bb,
stride_bh,
stride_bm,
stride_ob,
stride_oh,
stride_om,
nheads,
seqlen_q,
seqlen_k,
seqlen_q_rounded,
headdim,
CACHE_KEY_SEQLEN_Q,
CACHE_KEY_SEQLEN_K,
BIAS_TYPE: tl.constexpr,
IS_CAUSAL: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr,
EVEN_N: tl.constexpr,
EVEN_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hb = tl.program_id(1)
off_b = off_hb // nheads
off_h = off_hb % nheads
# off_b = tl.program_id(1)
# off_h = tl.program_id(2)
# off_hb = off_b * nheads + off_h
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_HEADDIM)
# Initialize pointers to Q, K, V
# Adding parenthesis around indexing might use int32 math instead of int64 math?
# https://github.com/openai/triton/issues/741
# I'm seeing a tiny bit of difference (5-7us)
q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (
offs_m[:, None] * stride_qm + offs_d[None, :])
k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (
offs_n[:, None] * stride_kn + offs_d[None, :])
v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (
offs_n[:, None] * stride_vn + offs_d[None, :])
if BIAS_TYPE == 'vector':
b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
elif BIAS_TYPE == 'matrix':
b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (
offs_m[:, None] * stride_bm + offs_n[None, :])
else:
raise ValueError("BIAS_TYPE must be one of {'vector', 'matrix'}")
# initialize pointer to m and l
t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
# load q: it will stay in SRAM throughout
# [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
# tl.load(q_ptrs), we get the wrong output!
if EVEN_M & EVEN_N:
if EVEN_HEADDIM:
q = tl.load(q_ptrs)
else:
q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
else:
q = tl.load(q_ptrs,
mask=(offs_m[:, None] < seqlen_q) &
(offs_d[None, :] < headdim),
other=0.0)
# loop over k, v and update accumulator
end_n = seqlen_k if not IS_CAUSAL else tl.minimum(
(start_m + 1) * BLOCK_M, seqlen_k)
for start_n in range(0, end_n, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
if EVEN_HEADDIM:
k = tl.load(k_ptrs + start_n * stride_kn)
else:
k = tl.load(k_ptrs + start_n * stride_kn,
mask=offs_d[None, :] < headdim,
other=0.0)
else:
if EVEN_HEADDIM:
k = tl.load(k_ptrs + start_n * stride_kn,
mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0)
else:
k = tl.load(k_ptrs + start_n * stride_kn,
mask=((start_n + offs_n)[:, None] < seqlen_k) &
(offs_d[None, :] < headdim),
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, trans_b=True)
My environment:
In [8]: sys.version_info
Out[8]: sys.version_info(major=3, minor=11, micro=3, releaselevel='final', serial=0)
In [9]: torch.__version__
Out[9]: '2.0.1+cu117'
In [10]: import triton
In [11]: triton.__version__
Out[12]: '2.0.0'
In [13]: import triton.language as tl
In [14]: tl.__version__
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[14], line 1
----> 1 tl.__version__
AttributeError: module 'triton.language' has no attribute '__version__'
In [15]: tl.dot
Out[15]: <function triton.language.core.dot(input, other, allow_tf32=True, _builder=None)>
Do I need a specific revision= string to make it work? Did the Triton language have breaking changes?
This is a great question. Has anyone managed to successfully resolve it yet?
Being able to run the example code without errors would certainly increase confidence in the model immensely.
I've managed to run it after changing the Triton version:
pip uninstall triton
pip install --no-deps triton==2.0.0.dev20221202
I used --no-deps because otherwise pip wanted to downgrade torch from 2.0.1 to 2.0.0 (no, thank you very much).
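As a quick sanity check (just a sketch; this only reads the installed package metadata and doesn't touch the model), you can confirm what actually ended up installed:
from importlib.metadata import version
print(version("triton"))  # should now report 2.0.0.dev20221202
print(version("torch"))   # torch should still be 2.0.1 thanks to --no-deps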
Here's a fully working example, run from the directory of the downloaded model (hence os.getcwd(); you can't use from_pretrained('.') in this case, as it causes weird errors down the line):
$ cat runme.py
import os
import torch
from transformers import AutoModelForMaskedLM, BertTokenizer, pipeline

# Load the downloaded model from the current directory in bfloat16 and move it to the GPU.
mlm = AutoModelForMaskedLM.from_pretrained(os.getcwd(), trust_remote_code=True, torch_dtype=torch.bfloat16).cuda()
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Run the fill-mask pipeline on GPU 0.
classifier = pipeline('fill-mask', model=mlm, tokenizer=tokenizer, device=0)
print(classifier("I [MASK] to the store yesterday."))
$ python runme.py
[{'score': 0.8977681398391724, 'token': 2253, 'token_str': 'went', 'sequence': 'i went to the store yesterday.'}, {'score': 0.02546772174537182, 'token': 2234, 'token_str': 'came', 'sequence': 'i came to the store yesterday.'}, {'score': 0.021113483235239983, 'token': 2939, 'token_str': 'walked', 'sequence': 'i walked to the store yesterday.'}, {'score': 0.013631888665258884, 'token': 2288, 'token_str': 'got', 'sequence': 'i got to the store yesterday.'}, {'score': 0.00997330341488123, 'token': 5225, 'token_str': 'drove', 'sequence': 'i drove to the store yesterday.'}]
Things I also tried:
- Replacing all tl.dot(A, B, trans_a=True) with tl.dot(tl.trans(A), B), but either I wasn't accurate or it's too compute-intensive: Python either hung or I lost patience.
- Throwing away flash attention and using torch's scaled_dot_product_attention. I couldn't figure out how to massage the parameters into the correct shape (see the sketch after this list).
- Removing the local flash_attn_triton module and importing the one from the flash-attn package. It dumped a giant error log, but that's where I noticed it was using not triton 2.0.0 but a 2.0.0 dev build.
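For what it's worth, here is a rough sketch of what the scaled_dot_product_attention route could look like. This is only an untested illustration: the packed qkv layout (batch, seqlen, 3, nheads, headdim) and the additive bias shape are my assumptions, not something I verified against bert_layers.py.
import torch
import torch.nn.functional as F

def sdpa_fallback(qkv: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Assumed layout: qkv is (batch, seqlen, 3, nheads, headdim); bias is an additive
    # ALiBi mask broadcastable to (batch, nheads, seqlen, seqlen).
    q, k, v = qkv.unbind(dim=2)                       # each (batch, seqlen, nheads, headdim)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> (batch, nheads, seqlen, headdim)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
    return out.transpose(1, 2)                        # back to (batch, seqlen, nheads, headdim)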
After replacing the Triton version, everything works.
The magic version string was taken from the flash-attn Python package.
This worked beautifully! Thank you so much for sharing this solution.
Just out of curiosity: if you did not download the model and use os.getcwd(), did you actually get errors, or did the model simply produce nonsensical inferences?