
bidirectional attention or causal attention for embedding? #15

Open
yonxie opened this issue Mar 15, 2024 · 5 comments

Comments

@yonxie

yonxie commented Mar 15, 2024

You mention that bidirectional attention is used for the embedding task, but it appears that you only use the last hidden states from the pretrained LLM to generate embeddings. Is the final projection the only bidirectional part?

@Muennighoff
Contributor

The last hidden state is produced via bidirectional attention in the model itself.

@Hisarlik
Contributor

Hisarlik commented Apr 9, 2024

Hi, I'm currently trying to train gritlm using Gemma2b to generate embeddings. While reviewing the training script for Mistral7b, I noticed the use of bidirectional attention with attn='bbcc'. In the context of embeddings, would it be more advantageous to train with 'bbcc' or 'cccc'?

However, when I tried to use attn='bbcc' with Gemma, I encountered an error: TypeError: GemmaModel.forward() got an unexpected keyword argument 'is_causal'. To fix this, I commented out the following line in gritlm.py:

if (self.attn is not None) and (self.attn[:2] == 'bb'): inputs["is_causal"] = False
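
For context, my understanding of why the error appears (the two backbone classes below are just hypothetical stand-ins, not the real gritlm or transformers classes): gritlm.py forwards is_causal to the wrapped model, so the wrapped model's forward() has to declare that argument, which stock GemmaModel.forward() does not.

```python
# Hypothetical stand-ins illustrating the mechanism: gritlm.py injects
# `is_causal` into the kwargs it forwards to the wrapped backbone, so the
# backbone's forward() must accept that argument.
import torch

class BackboneWithIsCausal(torch.nn.Module):
    def forward(self, input_ids, attention_mask=None, is_causal=True):
        return input_ids  # placeholder for the real hidden states

class BackboneWithoutIsCausal(torch.nn.Module):
    def forward(self, input_ids, attention_mask=None):
        return input_ids  # placeholder for the real hidden states

inputs = {"input_ids": torch.tensor([[1, 2, 3]]), "is_causal": False}
BackboneWithIsCausal()(**inputs)           # fine
try:
    BackboneWithoutIsCausal()(**inputs)    # same failure mode as stock GemmaModel
except TypeError as e:
    print(e)                               # ... unexpected keyword argument 'is_causal'
```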

Is this correct?

@Muennighoff
Contributor

bbcc is better, and commenting out that line will make it equivalent to cccc, so it's not a good idea. Also see #24.
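
For intuition, the first two letters of the attn string control the mask used in the embedding forward pass: 'bb' means a full (bidirectional) mask, 'cc' the usual lower-triangular (causal) one. A minimal conceptual sketch (not the gritlm code):

```python
# Conceptual sketch only: what 'bb' (bidirectional) vs 'cc' (causal) means for
# the attention mask applied during the embedding forward pass.
import torch

seq_len = 4
bidirectional = torch.ones(seq_len, seq_len, dtype=torch.bool)       # every token attends to every token
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # token i attends only to tokens <= i

print(bidirectional.int())
print(causal.int())
```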

@Vincent-Li-9701

Vincent-Li-9701 commented Apr 17, 2024

Hi @Muennighoff, amazing work! I have a similar confusion to @yonxie. I can see here that you did a final pooling.
You mentioned that "The last hidden state is produced via bidirectional attention in the model itself". Would you mind pointing out where this is done?

I was also looking at the query-doc caching example at page 63. In order to reuse the key-value cache (if I understand correctly, the key values are produced during the forward pass using bidirectional attention), does that mean GritLM functions as a prefix LM with two independent prefixes during RAG?

@Muennighoff
Contributor

Sorry for the confusion. I mean that inside the model, bidirectional attention is applied in every transformer layer. The attention mask for that is created here:

attention_mask = _prepare_4d_attention_mask_for_sdpa(

The pooling that you point to is then applied to the final hidden state returned from the model to remove the sequence length dimension.
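
Putting those two pieces together, a minimal sketch (not the exact gritlm code; it assumes mean pooling and uses `_prepare_4d_attention_mask_for_sdpa` from `transformers.modeling_attn_mask_utils`):

```python
# Minimal sketch: a bidirectional 4D mask built from the padding mask, then
# pooling over the last hidden state to drop the sequence-length dimension.
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa

last_hidden = torch.randn(2, 5, 32)                   # [batch, seq_len, hidden] from the model
attention_mask = torch.tensor([[1, 1, 1, 1, 0],
                               [1, 1, 1, 0, 0]])      # padding mask

# Non-causal: only padded positions are masked, no lower-triangular structure.
# (May return None when nothing needs masking, letting SDPA run unmasked.)
mask_4d = _prepare_4d_attention_mask_for_sdpa(attention_mask, dtype=last_hidden.dtype)
print(None if mask_4d is None else mask_4d.shape)     # torch.Size([2, 1, 5, 5])

# Mean pooling over non-padded positions -> one embedding vector per sequence.
w = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
embeddings = (last_hidden * w).sum(dim=1) / w.sum(dim=1)
print(embeddings.shape)                               # torch.Size([2, 32])
```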

if I understand correctly, the key values are produced during the forward pass using bidirectional attention

Yes

does that mean GritLM functions as a prefix LM with two independent prefixes during RAG?

The two caches (or prefixes, if you will) are concatenated but have not attended to one another (maybe this is what you mean by independent). You may find it helpful to look at this code example: https://github.com/ContextualAI/gritlm?tab=readme-ov-file#caching
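
Schematically (a toy illustration of the cache layout only, not the code from the README; the shapes follow the legacy Hugging Face past_key_values format):

```python
# Toy illustration of the "two independent prefixes" point: each prefix's
# key/value cache comes from its own forward pass, so neither prefix has
# attended to the other; for generation the per-layer caches are concatenated
# along the sequence axis.
import torch

def fake_cache(seq_len, n_layers=2, n_heads=4, head_dim=8, batch=1):
    # Legacy HF past_key_values layout: tuple over layers of (key, value),
    # each shaped [batch, n_heads, seq_len, head_dim].
    return tuple(
        (torch.randn(batch, n_heads, seq_len, head_dim),
         torch.randn(batch, n_heads, seq_len, head_dim))
        for _ in range(n_layers)
    )

query_cache = fake_cache(seq_len=6)    # from the query's embedding forward pass
doc_cache = fake_cache(seq_len=10)     # from the document's embedding forward pass

# Order is arbitrary in this illustration; the real prompt format decides it.
merged = tuple(
    (torch.cat([dk, qk], dim=2), torch.cat([dv, qv], dim=2))
    for (dk, dv), (qk, qv) in zip(doc_cache, query_cache)
)
print(merged[0][0].shape)              # torch.Size([1, 4, 16, 8])
```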
