Bayesian Attention Mechanism Impact

By Arthur S. Bianchessi

Recently, I published "Bayesian Attention Mechanism: A Probabilistic Framework for Positional Encoding and Context Length Extrapolation" on arXiv. The paper introduces the Bayesian Attention Mechanism, a theoretical framework that enhances transformer models by enabling them to extrapolate to context lengths far beyond those seen in training. This article discusses the potential impact of this framework on the future of large language models, particularly in terms of compute savings and the implications for hardware investments.

Introduction

Current models have scaled up dramatically thanks to the growing ability to invest more computation in model training. However, there are two quadratic costs associated with training transformer models:

  1. To train a model with twice the number of parameters, you also need twice the amount of data, therefore resulting in a 4x increase in computational cost.
  2. The attention mechanism in transformers has a quadratic cost in the sequence length, which means that doubling the sequence length results in a 4x increase in computational cost.
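
As a toy check of both scalings, here is a small sketch (my own, with hypothetical unit-free numbers, not taken from any real model):

    # Toy check of the two quadratic costs above (hypothetical, unit-free numbers).
    params, tokens, seq_len = 1.0, 1.0, 1.0

    # 1. Doubling the parameters while keeping the tokens-per-parameter ratio fixed
    #    also doubles the data, so cost ~ params * tokens increases 4x.
    print((2 * params) * (2 * tokens) / (params * tokens))   # 4.0

    # 2. Attention cost per sequence grows with the square of the sequence length,
    #    so doubling the sequence length also increases it 4x.
    print((2 * seq_len) ** 2 / seq_len ** 2)                 # 4.0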

Although current model sizes have stagnated, sequence lengths have not. Companies such as OpenAI, Google, and Meta have invested substantial resources in increasing the sequence length of their models, resulting in models that can handle millions of tokens. However, in the current literature there is no way to extrapolate a model's context length beyond its training data without modifying the underlying model[1]. This means that if a model is trained on 128K tokens, it cannot handle more than 128K tokens at inference time.

In our paper, we managed to train a small model (~120M parameters) on a sequence length of 512 tokens and extrapolate its context length to more than 128K tokens[2], beating the context length of models such as Llama-3. To estimate the savings that could be achieved with this method, I am going to use the Llama-3 herd of models as a reference.

Estimating Savings

Llama-3's main model, Llama-3.1-405B, was trained on around 15 trillion tokens. These tokens were used in different stages of training, with the context length increasing at each stage; the specific schedule can be found in Section 3.4 of The Llama 3 Herd of Models. For the sake of simplicity, I am considering:

  - 15 trillion tokens trained at a context length of 8K tokens[3]
  - 100 billion tokens trained at a context length of 128K tokens[4]

Considering these variables, we can estimate the proportion of compute used in training the model at 128K context length following the formula:

$$ \text{Proportion} = \frac{\text{Compute(Long Context)}}{\text{Compute(8K)} + \text{Compute(Long Context)}} $$

Where the compute is approximated as the number of tokens multiplied by the square of the context length. This means that we can calculate the proportion of compute used in training the model at 128K context length as follows:

$$ \frac{(1 \times 10^{11}) \times 128\ 000^2}{\Big((15 \times 10^{12}) \times 8\ 000^2\Big) + \Big((1 \times 10^{11}) \times 128\ 000^2\Big)} \approx 63.054\% $$
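
As a sanity check, here is a small Python snippet (variable names are my own) that reproduces this calculation and the implied savings factor:

    # Assumptions from this post: ~15T tokens at 8K context, ~100B tokens at 128K context.
    short_tokens, short_ctx = 15e12, 8_000
    long_tokens,  long_ctx  = 1e11,  128_000

    # Compute approximated as number of tokens times the square of the context length.
    short_compute = short_tokens * short_ctx**2
    long_compute  = long_tokens * long_ctx**2

    proportion = long_compute / (short_compute + long_compute)
    savings    = (short_compute + long_compute) / short_compute   # factor saved by skipping the 128K stage
    print(f"{proportion:.3%} of compute, ~{savings:.1f}x less without the long-context stage")
    # 63.054% of compute, ~2.7x less without the long-context stage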

So around 63% of the compute used to train Llama-3.1-405B went into training the model for 128K context length. This is a very rough estimate, but it gives us a lower bound for the savings that could be achieved: in this specific case, training the same model would require almost $3\times$ less compute. If we consider context lengths of 1M tokens, the savings are far greater, since the long-context stage accounts for an even larger share of the total compute, as shown below:

$$ \frac{(1 \times 10^{11}) \times (10^6)^2}{\Big((15 \times 10^{12}) \times 8\ 000^2\Big) + \Big((1 \times 10^{11}) \times (10^6)^2\Big)} \approx 99.049\% $$

In this case, training would require around $100\times$ less compute. Even though neither the number of tokens nor the context length was adjusted for this scenario, the estimate is still an underestimation. This is due to the context-length extrapolation itself, which allows an increase of roughly $500\times$ in context at $\sim 100\%$ accuracy in passkey retrieval, meaning the base pre-training could use a context length of only 2K tokens. Considering this setup:

$$ \frac{(1 \times 10^{11}) \times (10^6)^2}{\Big((15 \times 10^{12}) \times 2\ 000^2\Big) + \Big((1 \times 10^{11}) \times (10^6)^2\Big)} \approx 99.94\% $$

In this case, training would require around $1600\times$ less compute, which is an absurd reduction in compute and training cost.
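
For completeness, here is the same sanity check for the two 1M-token scenarios (again a sketch with my own variable names):

    # Both 1M-context scenarios from this post; the 2K base assumes the ~500x extrapolation.
    long_compute = 1e11 * 1_000_000**2          # ~100B tokens at 1M context
    for base_ctx in (8_000, 2_000):
        short_compute = 15e12 * base_ctx**2     # ~15T tokens at the base context length
        proportion = long_compute / (short_compute + long_compute)
        savings    = (short_compute + long_compute) / short_compute
        print(f"base {base_ctx:>5}: {proportion:.3%} of compute, ~{savings:,.0f}x less")

    # base  8000: 99.049% of compute, ~105x less
    # base  2000: 99.940% of compute, ~1,668x less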

Conclusion

As shown, the Bayesian Attention Mechanism allows for a significant reduction in the compute required to train models with large context lengths. Although the estimates in this post are meant to be a lower bound, they will not necessarily translate directly into realized savings: companies such as Meta and OpenAI have invested heavily in expanding their compute capacity, and they are likely to reinvest any savings into other improvements, such as training larger models. In other words, if the results of the Bayesian Attention Mechanism prove to be consistent, this should lead to a significant reduction in the cost of training models as well as an increase in model capabilities.

Notes

1. There are some methods that allow for context-length extrapolation by modifying the underlying model without any additional training. Scalable-Softmax is the best-performing one I am aware of, published on January 31st, 2025. It allows for around $8\times$ extrapolation of context length with minimal degradation in performance. The Llama 4 herd of models is an example of models that use this method to extrapolate context length; in this case, Meta seems to have used it to extend models trained at 128K tokens to 1M tokens (10M in the case of Llama 4 Scout).

2. We did train models with larger context lengths: the model trained on a sequence length of 2048 achieved 100% accuracy in passkey retrieval (10 samples) at 1 million tokens (this test was done after publishing the paper). Initial testing of larger models (~500M parameters) also shows that these models continue to extrapolate by more than two orders of magnitude in context length ($>100\times$).

3. Fewer than 15 trillion tokens were used at this context length, but I rounded it up to 15 trillion so that the estimate would underestimate the savings.

4. According to the report, Meta used 800B tokens divided into 6 stages for long-context pre-training. Because of this, I believe that 100B tokens trained at 128K context length is a good approximation, and an underestimation. This is also corroborated by approximating the sequence length trained on 100B tokens, assuming exponentially increasing sequence lengths at each batch: the equivalent sequence length trained on 100B tokens is around 153K tokens. This value was calculated using the following Python code:

    import torch
    import math

    # Long-context pre-training budget reported for Llama-3.1 (800B tokens).
    initial_sequence_length      = 8_000
    final_sequence_length        = 128_000
    total_number_of_tokens       = 800_000_000_000
    tokens_per_batch             = 16_000_000
    number_of_batches            = total_number_of_tokens // tokens_per_batch

    # Sequence length per batch, interpolated geometrically between 8K and 128K
    # (generated in decreasing order, dropping the last point so 128K is included).
    sequence_length_per_batch    = torch.logspace(math.log(final_sequence_length),
                                                  math.log(initial_sequence_length),
                                                  number_of_batches+1, base=math.e,
                                                  dtype=torch.float64)[:-1]

    # Compute per batch approximated as tokens * sequence_length^2, as in the formulas above.
    cost_per_batch               = tokens_per_batch * (sequence_length_per_batch**2)
    total_cost                   = cost_per_batch.sum()

    # Sequence length that 100B tokens would have to be trained at to match this total compute.
    equivalent_seq_len           = (total_cost / 100_000_000_000)**0.5
    print("100B sequence length equivalent: ", end="\t")
    print(f"{equivalent_seq_len:,.2f} tokens")

Although the above code does not maintain `equivalent_seq_len > final_sequence_length` when `final_sequence_length = 1_000_000`, it retains this property if the number of tokens is scaled by $\frac{\log{(\frac{\text{new large context}}{8\ 000})}}{\log(\frac{128\ 000}{8\ 000})}$. This scaling is needed so that the proportional increase in context length at each step stays the same as when training for 128K context length; intuitively, it should take more steps to scale a model to 1M tokens than to 128K tokens. Because of this, I used the "100B tokens trained at the final sequence length" rule of thumb for the other calculations.
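
For illustration, here is a small variation of the snippet above (my own sketch, under the same assumptions) with `final_sequence_length` set to 1M and the token budget scaled by that factor; the equivalent sequence length then comes out above the 1M target (roughly 1.2M with these numbers):

    import torch
    import math

    # Sketch (my own, same assumptions as above): scale the token budget so each step
    # grows the context by the same proportion as in the 128K schedule, then recompute
    # the equivalent sequence length for a 1M final context.
    initial_sequence_length = 8_000
    final_sequence_length   = 1_000_000
    scale = math.log(final_sequence_length / initial_sequence_length) \
            / math.log(128_000 / initial_sequence_length)              # ~1.74
    total_number_of_tokens  = int(800_000_000_000 * scale)             # ~1.39T tokens
    tokens_per_batch        = 16_000_000
    number_of_batches       = total_number_of_tokens // tokens_per_batch
    sequence_length_per_batch = torch.logspace(math.log(final_sequence_length),
                                               math.log(initial_sequence_length),
                                               number_of_batches + 1, base=math.e,
                                               dtype=torch.float64)[:-1]
    total_cost = (tokens_per_batch * sequence_length_per_batch**2).sum()
    equivalent_seq_len = (total_cost / 100_000_000_000)**0.5
    print(f"{equivalent_seq_len:,.2f} tokens")   # > 1,000,000, as claimed above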

Disclaimer

This post is not financial advice. I am not a financial advisor.

Although the main reason for this post is to publicize the technology, I do have a financial interest in it.

For all the reasons above, and because I am not an established researcher, I strongly recommend independently verifying the results of the paper. I believe this paper could have a significant impact on the stock market, but again, these results should be verified. The code in the GitHub repository is designed to be easy to run, and reproducing the results of the paper should take less than a day on 8 NVIDIA A100 40GB GPUs (the code should also be checked; anyone can print 100% accuracy).