RuntimeError: FlashAttention only support fp16 and bf16 data type #259

Open
sankexin opened this issue May 20, 2024 · 0 comments

python train.py --train_args_file train_args/sft/qlora/llama3-8b-sft-qlora.json

GPU: NVIDIA A800 80GB PCIe. Max memory: 79.151 GB. Platform = Linux.
Pytorch: 2.3.0+cu121. CUDA = 8.0. CUDA Toolkit = 12.1.
Bfloat16 = TRUE. Xformers = 0.0.26.post1. FA = True.
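
For reference, the QLoRA setup behind that config file presumably boils down to something like the sketch below. This is only an illustration of where the compute dtype comes from, not the actual Firefly/Unsloth loader code; the model id and the argument names are the standard transformers/bitsandbytes ones and are assumptions, not taken from the repo.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Typical 4-bit QLoRA quantization config. The compute dtype here determines
# the dtype of the dequantized activations that eventually reach attention.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # fp32 here would violate FlashAttention's fp16/bf16 requirement
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",    # assumed model id for llama3-8b
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,      # keep the non-quantized weights in bf16 as well
)
```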

Traceback (most recent call last):
File "/home/sotanv/Firefly/train.py", line 439, in
main()
File "/home/sotanv/Firefly/train.py", line 427, in main
train_result = trainer.train()
File "/usr/local/lib/python3.10/site-packages/transformers/trainer.py", line 1624, in train
return inner_training_loop(
File "", line 354, in _fast_inner_training_loop
File "/usr/local/lib/python3.10/site-packages/transformers/trainer.py", line 2902, in training_step
loss = self.compute_loss(model, inputs)
File "/usr/local/lib/python3.10/site-packages/transformers/trainer.py", line 2925, in compute_loss
outputs = model(**inputs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/accelerate/utils/operations.py", line 817, in forward
return model_forward(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/accelerate/utils/operations.py", line 805, in call
return convert_to_fp32(self.model_forward(*args, **kwargs))
File "/usr/local/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/unsloth/models/llama.py", line 882, in PeftModelForCausalLM_fast_forward
return self.base_model(
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 161, in forward
return self.model.forward(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/unsloth/models/llama.py", line 813, in _CausalLM_fast_forward
outputs = self.model(
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/unsloth/models/llama.py", line 668, in LlamaModel_fast_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
File "/usr/local/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 487, in checkpoint
return CheckpointFunction.apply(function, preserve, *args)
File "/usr/local/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 262, in forward
outputs = run_function(*args)
File "/usr/local/lib/python3.10/site-packages/unsloth/models/llama.py", line 664, in custom_forward
return module(*inputs, past_key_value, output_attentions, padding_mask = padding_mask)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/unsloth/models/llama.py", line 433, in LlamaDecoderLayer_fast_forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/usr/local/lib/python3.10/site-packages/unsloth/models/llama.py", line 359, in LlamaAttention_fast_forward
A = flash_attn_func(Q, K, V, causal = True)
File "/usr/local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 831, in flash_attn_func
return FlashAttnFunc.apply(
File "/usr/local/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 511, in forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
File "/usr/local/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 51, in _flash_attn_forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
RuntimeError: FlashAttention only support fp16 and bf16 data type
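
The failing call is flash_attn_func(Q, K, V, causal = True) in unsloth/models/llama.py, and flash-attn only accepts fp16/bf16 tensors, so Q/K/V are apparently still fp32 when they reach it (for example because the run ended up with an fp32 compute dtype, or autocast is not active at that point). Below is a minimal sketch of the constraint plus a blunt call-site workaround; the tensor names come from the traceback, while the wrapper function itself is hypothetical.

```python
import torch
from flash_attn import flash_attn_func

def flash_attention_bf16(Q, K, V):
    """Cast Q/K/V to bf16 before flash_attn_func, which rejects fp32 inputs."""
    if Q.dtype not in (torch.float16, torch.bfloat16):
        Q = Q.to(torch.bfloat16)
        K = K.to(torch.bfloat16)
        V = V.to(torch.bfloat16)
    # flash_attn_func expects (batch, seqlen, num_heads, head_dim) tensors;
    # causal=True matches the call in the traceback.
    return flash_attn_func(Q, K, V, causal=True)
```

The cleaner fix is to make sure the run really is in bf16 end to end (bf16 training arguments plus a bf16 compute dtype, as in the loader sketch above) rather than patching the attention call.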
