Move flash_attn assert from __init__ into calling func

#32
by rogerxfeng8 - opened

When enabling Phi-3-small on non-CUDA devices, the flash_attn package is not available, and the flash_attn assert in __init__ forces the program to exit. This patch changes the assert into a warning, so that users can plug in a customized flash attention implementation in their own modeling code.
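A minimal sketch of the "warn instead of assert" idea described above. The class name `DenseAttentionSketch` and the `FLASH_ATTN_AVAILABLE` flag are hypothetical illustrations, not the actual Phi-3-small modeling code; the availability probe uses the standard-library `importlib.util.find_spec`, so the sketch runs whether or not flash_attn is installed.

```python
import importlib.util
import warnings

# Probe for the optional flash_attn package without importing it.
FLASH_ATTN_AVAILABLE = importlib.util.find_spec("flash_attn") is not None


class DenseAttentionSketch:
    """Hypothetical stand-in for the model's attention module."""

    def __init__(self):
        # Before: an assert here aborted model construction on machines
        # without flash_attn. After: only warn, so construction succeeds.
        if not FLASH_ATTN_AVAILABLE:
            warnings.warn(
                "flash_attn is not installed; dense attention must be "
                "provided by a custom implementation on this device."
            )
```

With this pattern the model object can still be built on a non-CUDA device, and the missing dependency is reported rather than fatal.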

Hi @nguyenbh, is it possible to review and merge this PR? We found that this assertion blocks non-CUDA devices from running this model. Removing it won't break CUDA devices, but it will allow non-CUDA devices to run the model. Thanks!

Microsoft org

Mostly LGTM. One ask: can you move the assert into _apply_dense_attn, where flash_attn_varlen_kvpacked_func is called?
Otherwise it might be hard to understand that the issue is that flash attention is not available.
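A minimal sketch of the requested change, assuming the check moves from `__init__` to the method that uses the flash-attention kernel. Only the names `_apply_dense_attn` and `flash_attn_varlen_kvpacked_func` come from the discussion; `DenseAttentionSketch`, `CustomDenseAttention`, and `FLASH_ATTN_AVAILABLE` are hypothetical, and the real kernel call is elided.

```python
import importlib.util

# Probe for the optional flash_attn package without importing it.
FLASH_ATTN_AVAILABLE = importlib.util.find_spec("flash_attn") is not None


class DenseAttentionSketch:
    """Hypothetical module; the real one calls flash_attn_varlen_kvpacked_func."""

    def __init__(self):
        # No availability check here: constructing the model never fails.
        pass

    def _apply_dense_attn(self, q, kv):
        # Check at the call site, so the error names the missing dependency
        # exactly where the flash-attention kernel would be used.
        assert FLASH_ATTN_AVAILABLE, (
            "flash_attn is required by _apply_dense_attn "
            "(flash_attn_varlen_kvpacked_func); install it or override this method."
        )
        # The real code would call flash_attn_varlen_kvpacked_func here;
        # this sketch just returns q as a placeholder.
        return q


class CustomDenseAttention(DenseAttentionSketch):
    """Hypothetical override for a non-CUDA device with its own kernel."""

    def _apply_dense_attn(self, q, kv):
        # Placeholder for a device-specific attention implementation.
        return q
```

Placing the assert at the call site keeps CUDA behavior unchanged, lets a subclass replace `_apply_dense_attn` without ever touching flash_attn, and makes the failure message point directly at the missing package.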

rogerxfeng8 changed pull request title from Change the assert to warning in __init__ to Move flash_attn assert from __init__ into calling func

@bapatra hi Barun, thanks for the feedback. PR updated, pls check.

Microsoft org

Thanks @rogerxfeng8 !

nguyenbh changed pull request status to merged
