Correctly initialize the text model (Mistral) of Idefics2 with Flash Attention #30395

Draft: zafstojano wants to merge 4 commits into main from idefics2-init-mistral-with-flash-attention
Conversation

zafstojano (Contributor)

This PR attempts to resolve the issue of the text model not being loaded with Flash Attention 2

Relevant issue: #30394


Currently, no matter which combination of parameters I pass when instantiating the Idefics2 models, the text model is not loaded with Flash Attention 2. Here are several examples:

  1. Pass _attn_implementation to from_pretrained:
import torch
from transformers import Idefics2ForConditionalGeneration

model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    torch_dtype=torch.bfloat16,
    _attn_implementation="flash_attention_2",
)
print("Mistral model attention implementation: ", model.model.text_model._attn_implementation)

Output:

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Loading checkpoint shards: 100%|██████████| 7/7 [00:02<00:00,  3.21it/s]
Mistral model attention implementation:  sdpa
  2. Pass attn_implementation to from_pretrained:
import torch
from transformers import Idefics2ForConditionalGeneration

model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
print("Mistral model attention implementation: ", model.model.text_model._attn_implementation)

Output:

Loading checkpoint shards: 100%|██████████| 7/7 [00:02<00:00,  3.08it/s]
Mistral model attention implementation:  sdpa
  3. Pass a config object with the property _attn_implementation set:
import torch
from transformers import AutoConfig, Idefics2ForConditionalGeneration

config = AutoConfig.from_pretrained("HuggingFaceM4/idefics2-8b")
config._attn_implementation = "flash_attention_2"
config.torch_dtype = torch.bfloat16

model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    config=config,
    torch_dtype=torch.bfloat16,
)
print("Mistral model attention implementation: ", model.model.text_model._attn_implementation)

Output:

Loading checkpoint shards: 100%|██████████| 7/7 [00:02<00:00,  3.09it/s]
Mistral model attention implementation:  sdpa
  4. Pass both config and attn_implementation to from_pretrained:
import torch
from transformers import AutoConfig, Idefics2ForConditionalGeneration

config = AutoConfig.from_pretrained("HuggingFaceM4/idefics2-8b")
config._attn_implementation = "flash_attention_2"
config.torch_dtype = torch.bfloat16

model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    config=config,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
print("Mistral model attention implementation: ", model.model.text_model._attn_implementation)

Output:

Loading checkpoint shards: 100%|██████████| 7/7 [00:02<00:00,  3.04it/s]
Mistral model attention implementation:  sdpa

This PR contains a simple patch which allows the text model to be loaded with Flash Attention. Here is the result with the changes included:

import torch
from transformers import AutoConfig, Idefics2ForConditionalGeneration

config = AutoConfig.from_pretrained("HuggingFaceM4/idefics2-8b")
config._attn_implementation = "flash_attention_2"
config.torch_dtype = torch.bfloat16

model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    config=config,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
print("Mistral model attention implementation: ", model.model.text_model._attn_implementation)

Output:

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.
Loading checkpoint shards: 100%|██████████| 7/7 [00:02<00:00,  3.13it/s]
Mistral model attention implementation:  flash_attention_2

It is not an ideal fix, since it requires passing both a config object and an attn_implementation parameter. Moreover, it relies on the use_flash_attention_2 parameter, which is deprecated and may be removed soon.

Criticism, feedback, and requests for changes are welcome.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Ping: @amyeroberts

@amyeroberts (Collaborator)

Hi @zafstojano, thanks for opening this PR and addressing this issue!

At the moment, the diff and commit history contain many changes that are unrelated to this PR, which should be resolved before merge. It looks like what happens after rebasing and then pushing without force-pushing. If this is the case, simply force-pushing should resolve it.

@zafstojano force-pushed the idefics2-init-mistral-with-flash-attention branch from bb9c2b4 to 669f7b1 on April 22, 2024, 19:16
@zafstojano (Contributor, Author)

@amyeroberts I have now force-pushed only my changes 👍

@ydshieh (Collaborator) commented on Apr 22, 2024

(curious) how to push without force after rebasing ... 👀 ?

@amyeroberts (Collaborator)

(curious) how to push without force after rebasing ... 👀 ?

I've done it before but can't remember exactly the steps I took to achieve it! I think it rejects the push; you can then pull and push again.

@amyeroberts (Collaborator) left a comment:

Thanks for adding this fix!

We should add tests to make sure the attention implementation and torch dtype properly get set from the configs.
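
A minimal sketch of the kind of test being suggested, not part of the PR; it reuses the checkpoint and attribute paths already shown in this thread and assumes a GPU with flash-attn installed:

import torch
from transformers import Idefics2ForConditionalGeneration


def test_text_model_attn_implementation_and_dtype():
    # Load the full model requesting Flash Attention 2 and bfloat16 weights.
    model = Idefics2ForConditionalGeneration.from_pretrained(
        "HuggingFaceM4/idefics2-8b",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    # The text sub-model should pick up the requested attention implementation ...
    assert model.model.text_model._attn_implementation == "flash_attention_2"
    # ... and its weights should be loaded in the requested dtype.
    assert model.model.text_model.dtype == torch.bfloat16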

@@ -1473,16 +1473,23 @@ def __init__(self, config: Idefics2Config):
super().__init__(config)
self.padding_idx = self.config.text_config.pad_token_id
self.vocab_size = self.config.text_config.vocab_size
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

If we directly pass the config's attention implementation we don't need this

Suggested change:
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

Comment on lines 1480 to 1489
text_model_kwargs = {}
if self._use_flash_attention_2:
    text_model_kwargs["use_flash_attention_2"] = True
torch_dtype = None
if config.text_config.torch_dtype is not None:
    torch_dtype = config.text_config.torch_dtype
elif config.torch_dtype is not None:
    torch_dtype = config.torch_dtype
text_model_kwargs["torch_dtype"] = torch_dtype
self.text_model = AutoModel.from_config(config.text_config, **text_model_kwargs)

  • Let's use the same pattern as in llava-next. This is more robust to future attention implementations

  • It's a bit funny to have two possible options for torch_dtype here, with the one from the text_config taking precedence. If a user specified model = Idefics2ForConditionalGeneration.from_pretrained(checkpoint, torch_dtype=torch.float16), I'd expect torch.float16 to be used.

Suggested change:
- text_model_kwargs = {}
- if self._use_flash_attention_2:
-     text_model_kwargs["use_flash_attention_2"] = True
- torch_dtype = None
- if config.text_config.torch_dtype is not None:
-     torch_dtype = config.text_config.torch_dtype
- elif config.torch_dtype is not None:
-     torch_dtype = config.torch_dtype
- text_model_kwargs["torch_dtype"] = torch_dtype
- self.text_model = AutoModel.from_config(config.text_config, **text_model_kwargs)
+ torch_dtype = config.text_config.torch_dtype
+ if config.torch_dtype is not None:
+     torch_dtype = config.torch_dtype
+ self.text_model = AutoModel.from_config(
+     config.text_config,
+     attn_implementation=config._attn_implementation,
+     torch_dtype=torch_dtype,
+ )

@zafstojano (Contributor, Author)

@amyeroberts thank you for the constructive feedback.

I am currently experiencing some weird behavior when I integrate those changes; perhaps I am not 100% familiar with the internals of the transformers library.

For the following implementation of the init method in Idefics2Model:

class Idefics2Model(Idefics2PreTrainedModel):
    def __init__(self, config: Idefics2Config):
        super().__init__(config)
        self.padding_idx = self.config.text_config.pad_token_id
        self.vocab_size = self.config.text_config.vocab_size

        self.vision_model = Idefics2VisionTransformer(config.vision_config)
        self.connector = Idefics2Connector(config)
        torch_dtype = config.text_config.torch_dtype
        if config.torch_dtype is not None:
            torch_dtype = config.torch_dtype
        attn_implementation = config.text_config._attn_implementation
        if config._attn_implementation is not None:
            attn_implementation = config._attn_implementation
        print("=================")
        print("torch_dtype being passed to text_model in Idefics2Model.__init__():", torch_dtype)
        print("=================")
        self.text_model = AutoModel.from_config(
            config.text_config, 
            attn_implementation=attn_implementation,
            torch_dtype=torch_dtype,    
        )
        self.image_seq_len = config.perceiver_config.resampler_n_latents
        self.image_token_id = self.config.image_token_id

        self.post_init()

and the following code sample:

import torch
from transformers import Idefics2ForConditionalGeneration

model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

print("Perceiver model flash attention: ", model.model.connector.perceiver_resampler._use_flash_attention_2)
print("Vision model flash attention: ", model.model.vision_model._use_flash_attention_2)
print("Text model flash attention: ", model.model.text_model._attn_implementation == "flash_attention_2")
print('-----------------')
print("Model dtype: ", model.dtype)

I get the output:

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in MistralModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
=================
torch_dtype being passed to text_model in Idefics2Model.__init__(): torch.float32
=================
Loading checkpoint shards: 100%|██████████| 7/7 [00:02<00:00,  2.68it/s]
Perceiver model flash attention:  True
Vision model flash attention:  True
Text model flash attention:  True
-----------------
Model dtype:  torch.bfloat16

So, the flash attention setting is correctly propagated to all submodules when the user specifies attn_implementation="flash_attention_2" in Idefics2ForConditionalGeneration.from_pretrained, but for some reason the torch_dtype is not: note the torch_dtype being passed to the text model.

Do you have any idea why this is happening?

Unrelated to the above issue, I have another suggestion: since the vision model is initialized from config.vision_config, and in turn uses this sub-config to infer the attention implementation, it would be a good idea to override the config.vision_config._attn_implementation property with the one inferred above. What do you think?
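
A hedged, call-site sketch of that idea (in the PR itself the override would live inside Idefics2Model.__init__ rather than in user code):

import torch
from transformers import AutoConfig, Idefics2ForConditionalGeneration

config = AutoConfig.from_pretrained("HuggingFaceM4/idefics2-8b")
config._attn_implementation = "flash_attention_2"
# Propagate the top-level choice to the vision sub-config, which the vision
# tower reads when it is built.
config.vision_config._attn_implementation = config._attn_implementation

model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    config=config,
    torch_dtype=torch.bfloat16,
)
print("Vision model flash attention: ", model.model.vision_model._use_flash_attention_2)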

@zafstojano marked this pull request as draft on April 23, 2024, 14:00
@amyeroberts (Collaborator)

Hi @zafstojano, thanks for sharing this script!

OK, so the behaviour of torch_dtype is quite complex and not the area of the code I'm most familiar with. In terms of what's happening in the script, I think:

  • torch_dtype isn't set in the Idefics2Config, so it defaults to torch.float32
  • this is what is passed along when constructing the model. However, when loading the model from a checkpoint, what actually happens is that we create an empty model and then fill in the values when loading the weights. The torch.float32 you're seeing is from when this empty model is made.
  • When loading the weights, the torch_dtype value passed in the from_pretrained call determines how to load the weights. If it's "auto", then we use the value in the config. If it's set to torch.xxx, then this value is used. If unset, it defaults to torch.float32. This will determine the model's dtype.
  • In this case, we don't want to pass torch_dtype from the config, as this just describes the format of the weights as they were saved. Instead, we should just pass the params for setting the attention value and skip the torch_dtype logic (these cases are illustrated in the sketch after this list).
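
A brief sketch of the three cases just described, using the same checkpoint as the rest of this thread:

import torch
from transformers import Idefics2ForConditionalGeneration

ckpt = "HuggingFaceM4/idefics2-8b"

# Explicit dtype: weights are loaded in bfloat16, regardless of what the config records.
model = Idefics2ForConditionalGeneration.from_pretrained(ckpt, torch_dtype=torch.bfloat16)

# "auto": the torch_dtype recorded in the checkpoint's config is used.
model = Idefics2ForConditionalGeneration.from_pretrained(ckpt, torch_dtype="auto")

# Unset: defaults to torch.float32.
model = Idefics2ForConditionalGeneration.from_pretrained(ckpt)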

cc @younesbelkada to confirm whether this is right and if there's anything else to be aware of. To understand more: if composite models like this one and llava have their language model saved in e.g. float16 and their vision tower in float32, what will happen when we use torch_dtype="auto"?

Since the vision model is initialized from config.vision_config, and in turn it uses this sub-config file to infer the attention implementation, it would be a good idea to override the config.vision_config._attn_implementation property with the one inferred above. What do you think?

Yes! Good idea

@younesbelkada (Contributor) left a comment:

Thanks!
To answer @amyeroberts's question, I think it is fine not to pass any torch_dtype, since cls._set_default_torch_dtype() is called here:

dtype_orig = cls._set_default_torch_dtype(torch_dtype)
This makes sure the model that gets initialized here is initialized with the correct dtype (i.e. either the one passed via torch_dtype, or the one from the config if one passes "auto").

Once the model is initialized the original dtype is set again here.

@zafstojano,
Note that torch_dtype always dictates the dtype of the whole model; even though Idefics2 is in fact a combination of models, it should be seen as a standalone model. If one wants to try out complex combinations, such as loading the vision part in fp32 and the text model in fp16, they should first load the entire model in fp16 and then upcast the vision part to fp32 (or the other way around).
See also the solution in llava; we've been doing this with that architecture and everything looks fine so far. I think the fix I propose here should be sufficient; can you double-check that?
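
A short sketch of the load-then-upcast approach described above, using the sub-module path that appears elsewhere in this thread:

import torch
from transformers import Idefics2ForConditionalGeneration

# Load the whole model in half precision ...
model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    torch_dtype=torch.float16,
)
# ... then upcast just the vision tower to fp32, if that combination is wanted.
model.model.vision_model.to(torch.float32)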

Comment on lines 1479 to 1489
torch_dtype = config.text_config.torch_dtype
if config.torch_dtype is not None:
    torch_dtype = config.torch_dtype
attn_implementation = config.text_config._attn_implementation
if config._attn_implementation is not None:
    attn_implementation = config._attn_implementation
self.text_model = AutoModel.from_config(
    config.text_config,
    attn_implementation=attn_implementation,
    torch_dtype=torch_dtype,
)

Suggested change:
- torch_dtype = config.text_config.torch_dtype
- if config.torch_dtype is not None:
-     torch_dtype = config.torch_dtype
- attn_implementation = config.text_config._attn_implementation
- if config._attn_implementation is not None:
-     attn_implementation = config._attn_implementation
- self.text_model = AutoModel.from_config(
-     config.text_config,
-     attn_implementation=attn_implementation,
-     torch_dtype=torch_dtype,
- )
+ self.text_model = AutoModel.from_config(
+     config.text_config, attn_implementation=config._attn_implementation
+ )

@zafstojano (Contributor, Author)

Hi @amyeroberts @younesbelkada

With the implementation you proposed, for the following sample code:

import torch
from transformers import Idefics2ForConditionalGeneration

model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

print("Perceiver model flash attention: ", model.model.connector.perceiver_resampler._use_flash_attention_2)
print("Vision model flash attention: ", model.model.vision_model._use_flash_attention_2)
print("Text model flash attention: ", model.model.text_model._attn_implementation == "flash_attention_2")
print('-----------------')
print("Model dtype: ", model.dtype)

I get the following output:

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
Loading checkpoint shards: 100%|██████████| 7/7 [00:02<00:00,  3.23it/s]
Perceiver model flash attention:  True
Vision model flash attention:  True
Text model flash attention:  True
-----------------
Model dtype:  torch.bfloat16

The reason I wanted to explicitly pass torch_dtype is the warning "You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour".

Is this acceptable?

@zafstojano (Contributor, Author)

Moreover, when using the vision tower with Flash Attention, I get this exception:

The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
Traceback (most recent call last):
  File "/home/z/.pyenv/versions/3.11.5/lib/python3.11/runpy.py", line 198, in _run_module_as_main
    return _run_code(code, main_globals, None,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/lib/python3.11/runpy.py", line 88, in _run_code
    exec(code, run_globals)
  File "/home/z/.vscode-server/extensions/ms-python.debugpy-2024.4.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/home/z/.vscode-server/extensions/ms-python.debugpy-2024.4.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/home/z/.vscode-server/extensions/ms-python.debugpy-2024.4.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/home/z/.vscode-server/extensions/ms-python.debugpy-2024.4.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.vscode-server/extensions/ms-python.debugpy-2024.4.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/home/z/.vscode-server/extensions/ms-python.debugpy-2024.4.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "uxo_ml/idefics/train.py", line 162, in <module>
    trainer.train()
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/trainer.py", line 1859, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/trainer.py", line 2203, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/trainer.py", line 3138, in training_step
    loss = self.compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/trainer.py", line 3161, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/accelerate/utils/operations.py", line 825, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/accelerate/utils/operations.py", line 813, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/peft/peft_model.py", line 563, in forward
    return self.get_base_model()(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/models/idefics2/modeling_idefics2.py", line 1823, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/models/idefics2/modeling_idefics2.py", line 1643, in forward
    image_hidden_states = self.connector(
                          ^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/models/idefics2/modeling_idefics2.py", line 1317, in forward
    image_hidden_states = self.perceiver_resampler(context=image_hidden_states, attention_mask=attention_mask)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/models/idefics2/modeling_idefics2.py", line 1287, in forward
    layer_outputs = perceiver_layer(
                    ^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/models/idefics2/modeling_idefics2.py", line 1220, in forward
    latents, self_attn_weights, present_key_value = self.self_attn(
                                                    ^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/models/idefics2/modeling_idefics2.py", line 1004, in forward
    attn_output = self._flash_attention_forward(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/transformers/models/idefics2/modeling_idefics2.py", line 1071, in _flash_attention_forward
    attn_output_unpad = flash_attn_varlen_func(
                        ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 1066, in flash_attn_varlen_func
    return FlashAttnVarlenFunc.apply(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 581, in forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
                                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/z/.pyenv/versions/3.11.5/envs/idefics/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 86, in _flash_attn_varlen_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
                                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: query and key must have the same dtype

The above error can be fixed by casting the image hidden states to the same dtype as the input tokens going into the Mistral model:

            # Get sequence from the vision encoder
            image_hidden_states = self.vision_model(
                pixel_values=pixel_values,
                patch_attention_mask=patch_attention_mask,
            ).last_hidden_state.to(dtype=self.dtype, device=input_ids.device)

            # Modality projection & resampling
            image_hidden_states = self.connector(
                image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1)
            ).to(dtype=self.dtype, device=input_ids.device)

However, I still get the warning about upcasting:

The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.

Weirdly, this happens even if I cast all params to the same dtype (e.g. bfloat16).
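
For reference, a small sketch (not part of the PR) of how one might check for leftover float32 parameters after casting everything to bfloat16:

import torch
from transformers import Idefics2ForConditionalGeneration

model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
model.to(torch.bfloat16)  # cast all floating-point parameters and buffers

leftover = [name for name, p in model.named_parameters() if p.dtype == torch.float32]
print("Parameters still in float32:", leftover)  # expected to be empty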

@amyeroberts (Collaborator)

@zafstojano I see. So this is an issue, and a tricky one at that.

@younesbelkada it doesn't seem to be the case that passing torch_dtype correctly propagates the specified dtype to the other classes. Although it seems to be set globally (the parent model has bfloat16 set), if I query the weights in the loaded Mistral model they're all in float32.

@younesbelkada (Contributor)

Hmm, interesting, OK, I will have a deeper look then!

@amyeroberts (Collaborator)

@younesbelkada @zafstojano Just to follow up on the dtype investigation, I suspect there might be a difference between the torch_dtype being passed in the model inits during instantiation, and the torch dtype used when the pretrained weights are loaded in. I just ran a quick test, and the weights do seem to be loaded in as expected:

import torch
from transformers import Idefics2ForConditionalGeneration

print("Loading in as torch.bfloat16")
model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

model.dtype
print(model.model.text_model.dtype)
print(model.model.vision_model.embeddings.position_embedding.weight.dtype)
print(model.model.connector.perceiver_resampler.latents.dtype)

print("\nLoading in as torch.float32")
model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    attn_implementation="flash_attention_2",
)

model.dtype
print(model.model.text_model.dtype)
print(model.model.vision_model.embeddings.position_embedding.weight.dtype)
print(model.model.connector.perceiver_resampler.latents.dtype)

Produces output:

Loading in as torch.bfloat16
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:03<00:00,  2.06it/s]
torch.bfloat16
torch.bfloat16
torch.bfloat16

Loading in as torch.float32
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:03<00:00,  1.82it/s]
torch.float32
torch.float32
torch.float32

@younesbelkada (Contributor)

Thanks for investigating @amyeroberts, and apologies for not investigating myself! Seems all is good then? 🙏

@amyeroberts (Collaborator)

@younesbelkada Yep! I think so
