Student Perspectives: Bayesian LLM Finetuning

A post by Sam Bowyer, PhD student on the Compass programme.

Training large AI models is tricky business. First you’ll want to raise money, and lots of it. (OpenAI’s GPT-4 reportedly cost over $100 million to train, roughly equivalent to 0.5% of Bristol’s GDP.) With that money you’ll need to buy hardware (25,000 NVIDIA A100 GPUs should do), hire a team of talented engineers, and purchase licences for vast quantities of data (though you might consider forgoing that last one and just hope no one complains…). Once you’ve collected enough data (say, ~13 trillion tokens-worth¹) and settled on a model architecture with hundreds of billions, if not trillions, of parameters (each taking up at least a byte of memory), you can sit back and wait around for 100 days whilst your engineers firefight software and hardware crashes to steer your model’s training to completion.²

But for those of us who can’t afford the $10^{25}$ FLOPs (floating point operations) needed to train such a model (or who might want to avoid the associated environmental costs), what can we do? The answer lies in finetuning: taking one of the available pretrained ‘foundation’ models (such as ChatGPT, or an open source model such as one from Meta’s Llama series) and tweaking it to suit your own purposes.

The basic idea is this: these foundation models are great multitaskers. They’ve been trained well enough to generate reasonable outputs for a wide variety of inputs, but if you’re only interested in using them on a particular set of data ($\mathcal{D}_\text{finetune}$), or for a particular task, then it might be a good idea to spend some extra time training on that data specifically, after the rest of (pre)training has taken place. Similarly, it’s worth noting that the foundation model you get straight out of pretraining will mimic its input dataset, $\mathcal{D}$. In the case that $\mathcal{D}$ is too large to be checked by humans (e.g. 13 trillion tokens, essentially including most of the public internet), your model will almost certainly have learnt undesirable behaviour and be capable of producing dangerous, offensive, and harmful output. Finetuning is critical to the pursuit of safe AI, putting guardrails in place and ensuring that a model’s behaviour is aligned with our desires, both in terms of utility and safety.³

In this blog post, I’ll give an overview of LLM finetuning, specifically parameter-efficient finetuning, which tackles the problem of finetuning models whilst avoiding the computational burden that was required for pretraining. Even if your finetuning dataset $\mathcal{D}_\text{finetune}$ is much smaller than your pretraining set $\mathcal{D}$, you’ve still got the computational problem of the model’s size: how do you efficiently⁴ do gradient-based optimisation on a model with potentially billions of parameters? I’ll also argue that taking a Bayesian approach can be beneficial, and that whilst the added computational cost of Bayes might not be feasible (or even all that helpful) in the pretraining setting, these costs are much less impactful when finetuning.

Parameter Efficient Finetuning

Perhaps the simplest way to finetune a model on $\mathcal{D}_\text{finetune}$ is to carry on training as before (with some gradient-based optimiser like Adam [1]) but on this new dataset (often repeatedly, i.e. for multiple ‘epochs’). This is known as full finetuning (FFT) and usually leads to the best results. However, it’s often infeasible due to the size of the model being finetuned.

Recall that the model we’re working with might have billions of parameters. In order to train these parameters we need to store not only their values, but also their gradients, the activation values of each neuron in the network and, depending on your optimiser, potentially extra per-parameter state (e.g. Adam keeps an exponential moving average (EMA) of the gradients, which acts as momentum, and an EMA of the squared gradients). On a model like Llama-7B, whose 7 billion parameters at 8-bit precision require 7GB of storage, these extra costs can easily overwhelm the 16GB capacity of a typical high-end consumer GPU such as an NVIDIA RTX 4080. (Add to that the fact that we usually want to batch our input data, that is, pass multiple input examples through the model at a time, and you can see where things start to spiral out of control.)
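To make the arithmetic concrete, here is a rough back-of-the-envelope sketch. It assumes (purely for illustration) that the weights, their gradients and both of Adam’s moving averages are all stored at the same 8-bit precision, and it ignores activations and batching entirely, so the real footprint would be larger. The function name `full_finetune_memory_gb` is made up for this example.

```python
def full_finetune_memory_gb(n_params, bytes_per_value=1, optimiser_states=2):
    """Rough memory (in GB) for weights + gradients + optimiser state.

    Assumes every stored value uses `bytes_per_value` bytes and ignores
    activations and batching, which only add to the total.
    """
    copies = 1 + 1 + optimiser_states  # weights, gradients, Adam's two EMAs
    return n_params * bytes_per_value * copies / 1e9


llama_7b = 7_000_000_000
weights_only_gb = llama_7b * 1 / 1e9
adam_total_gb = full_finetune_memory_gb(llama_7b)
print(weights_only_gb, adam_total_gb)  # 7.0 28.0
```

Even under these generous assumptions, the total is already well beyond the 16GB mentioned above.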

This motivates the need for finetuning algorithms that have a smaller memory footprint. There’s an exciting body of literature on model compression and quantisation (using compression techniques to represent your model and its gradients with fewer and fewer bits⁵), but another approach is to simply reduce the number of parameters that you train during finetuning. However, choosing which parameters to train and which to freeze (thus freeing up space that would’ve gone to storing the gradient information of those parameters) is far from trivial.

Partial Finetuning

In order to discuss finetuning techniques, it’ll be useful to briefly touch on the basic architecture of neural networks. The simplest type of neural network is a multilayer perceptron, or MLP, which consists of $L$ layers in which the output of layer $l-1$, $x^{l-1} \in \mathbb{R}^{d_{l-1}}$, is multiplied by a learnable weight matrix $W^l \in \mathbb{R}^{d_{l} \times d_{l-1}}$ and added to a learnable bias vector $b^l \in \mathbb{R}^{d_{l}}$ before being transformed through a nonlinearity, such as a sigmoid $\sigma(x) = (1+e^{-x})^{-1}$:
$$x^l = \sigma(W^l x^{l-1} + b^l),$$
with $x^0 \in \mathbb{R}^{d_0}$ being input data.
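As a minimal sketch of the recursion above, the following NumPy snippet runs a forward pass through a small MLP. The layer sizes and random weights are arbitrary choices for illustration, not anything taken from a real model.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def mlp_forward(x0, weights, biases):
    """Apply x^l = sigmoid(W^l x^{l-1} + b^l) for l = 1, ..., L."""
    x = x0
    for W, b in zip(weights, biases):
        x = sigmoid(W @ x + b)
    return x


rng = np.random.default_rng(0)
dims = [4, 8, 8, 3]  # d_0, ..., d_L with L = 3 (hypothetical sizes)
weights = [rng.normal(size=(dims[l], dims[l - 1])) for l in range(1, len(dims))]
biases = [rng.normal(size=dims[l]) for l in range(1, len(dims))]

out = mlp_forward(rng.normal(size=dims[0]), weights, biases)
print(out.shape)  # (3,)
```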

A common strategy for finetuning is to freeze all weights in earlier layers, say, up until the final $\hat{L}$ layers, and only train the set of parameters $\{W^l, b^l \mid l > L-\hat{L}\}$. Assuming constant network width $d = d_0 = \ldots = d_L$, this reduces the number of trainable parameters from $L(d^2 + d)$ to $\hat{L}(d^2 + d)$.

Another simple finetuning strategy is BitFit [3], which works by only training the bias parameters, leading to a total of $Ld$ trainable parameters (though of course this does make the iterative finetuning updates significantly less expressive).
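To get a feel for the savings, the sketch below simply evaluates the three parameter counts for a constant-width MLP. The values of $L$, $d$ and $\hat{L}$ are hypothetical, and the helper `count_trainable` is an illustrative name rather than part of any library.

```python
def count_trainable(L, d, strategy, L_hat=None):
    """Trainable parameters in a constant-width MLP under each strategy."""
    weights_per_layer = d * d
    biases_per_layer = d
    if strategy == "full":
        return L * (weights_per_layer + biases_per_layer)
    if strategy == "last_layers":
        return L_hat * (weights_per_layer + biases_per_layer)
    if strategy == "bitfit":
        return L * biases_per_layer
    raise ValueError(f"unknown strategy: {strategy}")


L, d, L_hat = 32, 4096, 2  # hypothetical depth, width and number of unfrozen layers
full = count_trainable(L, d, "full")
last = count_trainable(L, d, "last_layers", L_hat=L_hat)
bitfit = count_trainable(L, d, "bitfit")
print(f"last layers: {last / full:.2%} of full, BitFit: {bitfit / full:.4%} of full")
```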

It’s important to note that the final-layers-only approach can also be applied more generally. Most LLM architectures use transformers [4] as their backbone, which (very loosely speaking) consist of multi-headed attention layers (another, more complicated type of neural network) followed by an MLP (plus a whole bunch of other stuff containing yet more parameters), with each transformer’s output typically going on to form the input of another transformer. So it’s common to see only the final transformer finetuned, or even only the final transformer’s MLP.

Since it would be ill-advised to take a long detour into the definition of multi-headed attention here (as that’d be fairly involved and might take the momentum out of our finetuning discussion), I won’t do that. (Instead, I’ll banish it to yet another (increasingly-obnoxious) footnote⁶.)

Adapter Tuning

Rather than retraining the weights already in your model, most modern finetuning approaches actually add new parameters to the model, termed ‘adapters’, and only train these instead. For example, [5], [6], and [7] all essentially propose techniques in which we insert two-layer MLPs at different places inside a transformer, with varying results.

Adapter methods have the benefit of being ‘plug-and-play’, in the sense that you can train multiple adapters on different finetuning tasks and then insert them into your model if you detect that it would be helpful for a user’s given request.

Low Rank Adaptation (LoRA)

By far the most common (and almost de facto standard as of 2024) finetuning method is Low Rank Adaptation (LoRA) [8]. The intuition behind LoRA is that the parameters inside your pretrained model are probably fairly close to their finetuned optimal values already, in the sense that those optimal values can probably be reached using only updates in a low-rank subspace. As such we can pose our finetuning problem in terms of finding the low-rank matrix $\Delta W \in \mathbb{R}^{d_\text{in} \times d_\text{out}}$ that optimises a given pretrained weight matrix $W_0$, leading to
$$W_\text{finetune} = W_0 + \Delta W,$$
where the low rank of $\Delta W$ is enforced by parameterising it as $$\begin{aligned}\Delta W & = B A \\ B & \in \mathbb{R}^{d_\text{in} \times r} \\ A & \in \mathbb{R}^{r \times d_\text{out}} \end{aligned}$$so that $\text{rank}(\Delta W) \leq r \ll \text{rank}(W_0) \leq \min (d_\text{in}, d_\text{out})$. (Note that LoRA places the adapter in parallel to a pretrained weight matrix $W_0$, in contrast to the serial/in-between placement of the MLP adapters mentioned in the previous section.)

Figure reproduced from [8], showing a Gaussian-initialisation of $A$ and a zero-initialisation of $B$.
By only learning $A$ and $B$, we reduce the number of parameters from $d_\text{in} d_\text{out}$ to $r(d_\text{in} + d_\text{out})$. In practice, often $d_\text{in}$ and $d_\text{out}$ will be in the range ~$10^3$-$10^4$ whilst we’ll choose $r$ to be somewhere between 4 and 128. To cut down on the number of trainable parameters even further, we often only apply LoRA adapters to certain weight matrices in a model, for example only the query and value matrices ($W^Q$ and $W^V$) of attention layers.
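A minimal NumPy sketch of a single LoRA-adapted layer might look as follows. It assumes a row-vector convention (inputs multiply $W_0$ from the left, so that $W_0 \in \mathbb{R}^{d_\text{in} \times d_\text{out}}$ as in the text), and the sizes and initialisation scale are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, r = 512, 256, 8  # hypothetical sizes

W0 = rng.normal(size=(d_in, d_out))           # frozen pretrained weights
B = np.zeros((d_in, r))                       # zero-initialised, trainable
A = rng.normal(scale=0.01, size=(r, d_out))   # Gaussian-initialised, trainable


def lora_forward(x, W0, B, A):
    """x has shape (batch, d_in); the adapter BA runs in parallel to W0."""
    return x @ W0 + (x @ B) @ A


x = rng.normal(size=(2, d_in))
full_params = d_in * d_out
lora_params = r * (d_in + d_out)
print(full_params, lora_params)  # 131072 6144
```

Because $B$ starts at zero, $\Delta W = BA$ is zero at initialisation, so the adapted layer initially reproduces the pretrained layer exactly. Computing `(x @ B) @ A` rather than `x @ (B @ A)` also avoids ever forming the full $d_\text{in} \times d_\text{out}$ update matrix.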

LoRA’s success has led to a large number of variants, such as AdaLoRA [9] which adaptively decides which weight matrices to apply LoRA to based on their singular values. Other methods include PiSSA (Principal Singular Values and Singular Vectors Adaptation) [10] which performs LoRA updates only on the first few principal components of each weight matrix and freezes the ‘residuals’ which come from later principal components. One recent paper presents GaLore (Gradient Low Rank Projection) [11], which periodically computes a singular value decomposition of each weight matrix’s gradient and performs low-rank updates by only optimising in the (low-rank) space spanned by the first few singular vectors.

Bayesian Finetuning

Although work has been done to introduce uncertainty estimation into pretraining, the results often aren’t worth the extra computational costs [12, 13]. Not only are the model sizes too large to make uncertainty quantification feasible, but the fact that your pretraining dataset, $\mathcal{D}$, is gigantic provides little uncertainty to reason about. However, in the context of finetuning we typically have a much smaller dataset, for which we’ll likely have much more uncertainty, and we also tend to work with far fewer parameters, allowing for extra computational budget to go towards the use of Bayesian methods.

Consider splitting up our finetuning set into prompt and target/response pairs $(X,y) \in \mathcal{D}_\text{finetune}$ where $X \in \mathcal{T}^{B \times n}$ is a matrix of $B$ sequences each of maximum length $n$ (potentially padded out with null-tokens) constructed with the token set $\mathcal{T}$, and $y \in \mathcal{Y}^B$ could be a corresponding batch of single tokens (in which case $\mathcal{Y} = \mathcal{T}$), or a batch of classification labels (e.g. in sentiment analysis, or multiple-choice Q&A, in which case $\mathcal{Y}$ might be different to $\mathcal{T}$).

What we fundamentally want to learn is a posterior distribution over all learnable parameters $$p(\theta | \mathcal{D}_\text{finetune}) = p(\theta | X, y),$$where, for example, in the case of LoRA finetuning, $\theta$ is the collection of all adapter weights $A$ and $B$. This not only gives us information about the uncertainty in the model’s parameters, which can be useful in itself, but can also be used to give us the posterior predictive distribution for a test input $x^* \in \mathcal{T}^n$, $$p(y^* | x^*, \mathcal{D}_\text{finetune}) = \int p(y^* | x^*, \theta)p(\theta|\mathcal{D}_\text{finetune})d\theta.$$
This is often more desirable than a predictive distribution that only uses a point estimate of $\theta$ and which would then ignore the uncertainty present in the model’s parameters.
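In practice the integral is rarely tractable, but given samples of $\theta$ from (an approximation to) the posterior it can be estimated by a simple Monte Carlo average. The sketch below uses a toy linear-softmax model and a made-up Gaussian ‘posterior’ purely for illustration; nothing here is specific to LLMs.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def predictive(theta, x):
    """p(y | x, theta) for a toy linear-softmax model; theta is (classes, features)."""
    return softmax(theta @ x)


rng = np.random.default_rng(0)
n_classes, n_features, n_samples = 3, 4, 1000

# Stand-in posterior: a Gaussian around some point estimate (an assumption).
theta_mean = rng.normal(size=(n_classes, n_features))
theta_samples = theta_mean + 0.5 * rng.normal(size=(n_samples, n_classes, n_features))

x_star = rng.normal(size=n_features)
point_estimate = predictive(theta_mean, x_star)
posterior_predictive = np.mean([predictive(t, x_star) for t in theta_samples], axis=0)
print(point_estimate, posterior_predictive)
```

Comparing the two printed distributions shows how averaging over parameter uncertainty changes the prediction relative to plugging in a single point estimate.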

Bayesian LoRA (via Laplace Approximation and KFAC)

Yang et al. [14] suggest a method for finding the posterior $$p(\theta | X, y) \propto p(y | X, \theta)p(\theta)$$post-hoc, i.e. after regular finetuning (with LoRA), using a Laplace approximation, which assumes the posterior is a Gaussian centred at the maximum a posteriori (MAP) solution, $\theta_\text{MAP}$.

First, we note that the MAP solution can be written as the maximum of the log-joint $\mathcal{L}(y, X; \theta)$, $$\begin{align} \mathcal{L}(y, X; \theta) &= \log p(y | X, \theta) +\log p(\theta) = \log p(\theta | X, y) + \text{const} \\ \theta_\text{MAP} &= \arg\max{}_\theta \mathcal{L}(y, X; \theta). \end{align}$$
Then, assuming that the finetuning successfully optimised $\theta$, i.e. reached parameter values $\theta_\text{MAP}$, the Laplace approximation involves taking the second-order Taylor expansion of the log-joint around $\theta_\text{MAP}$, $$\mathcal{L}(y, X; \theta) \approx \mathcal{L}(y, X; \theta_\text{MAP}) + \frac{1}{2}(\theta - \theta_\text{MAP})^T(\nabla_\theta^2 \mathcal{L}(y, X; \theta)|_{\theta_\text{MAP}})(\theta - \theta_\text{MAP}).$$(The expansion’s first-order term disappears because the gradient of the MAP objective at $\theta_\text{MAP}$ is zero.) Since $\theta_\text{MAP}$ is a maximum, the Hessian there is negative (semi-)definite, so this quadratic form can be written as a Gaussian density with mean $\theta_\text{MAP}$ and covariance given by the negative inverse of the log-joint Hessian: $$\begin{align}p(\theta | X, y) &\approx \mathcal{N}(\theta ; \theta_\text{MAP}, \Sigma), \\
\Sigma &= -(\nabla_\theta^2 \mathcal{L}(y, X; \theta)|_{\theta_\text{MAP}})^{-1}.\end{align}$$
The authors make use of various tricks to render computing this Hessian inverse feasible, most notably Kronecker-Factored Approximate Curvature (KFAC) [15]. (A nice explanation of which can be found at this blog post.)

Using the Laplace approximation comes with added benefits. Specifically, we can make use of the Gaussian form of the (approximate) posterior to easily compute two values of interest: samples from the posterior predictive distribution, and estimates of the marginal likelihood.

For the first of these, we can linearise our model, with output $f_\theta(x^*)$ approximated by a first-order Taylor expansion around $\theta_\text{MAP}$, $$f_\theta(x^*) \approx f_{\theta_\text{MAP}}(x^*) + \nabla_\theta f_\theta(x^*)|^T_{\theta_\text{MAP}}(\theta - \theta_\text{MAP}).$$
Since $\theta$ is Gaussian under the Laplace posterior and the linearised output is an affine function of $\theta$, the output is Gaussian too: $$f_\theta(x^*) \sim \mathcal{N}(f_{\theta_\text{MAP}}(x^*), \Lambda)$$ where $$\Lambda = (\nabla_\theta f_\theta(x^*)|^T_{\theta_\text{MAP}})\Sigma(\nabla_\theta f_\theta(x^*)|_{\theta_\text{MAP}}).$$
With this, we can easily obtain samples from our predictive posterior through reparameterised sampling of some Gaussian noise $\mathbf{\xi} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ and a Cholesky decomposition $\Lambda = LL^T$: $$\hat{y} = f_\theta(x^*) = f_{\theta_\text{MAP}}(x^*) + L\mathbf{\xi}.$$
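The whole pipeline can be sketched on a toy problem. The example below is not the method of [14] (there is no LoRA and no KFAC here); it is a two-parameter Bayesian logistic regression with a Gaussian prior, chosen as an assumption so that the exact Hessian is cheap to compute. Newton’s method stands in for ‘regular finetuning’, and because the logit is linear in $\theta$ the linearisation step happens to be exact.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: binary logistic regression with D = 2 parameters.
X = rng.normal(size=(50, 2))
y = (X @ np.array([1.5, -1.0]) + 0.3 * rng.normal(size=50) > 0).astype(float)
prior_var = 1.0  # theta ~ N(0, prior_var * I)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def grad_and_hessian(theta):
    """Gradient and Hessian of the log-joint log p(y|X,theta) + log p(theta)."""
    p = sigmoid(X @ theta)
    grad = X.T @ (y - p) - theta / prior_var
    hess = -(X.T * (p * (1 - p))) @ X - np.eye(2) / prior_var
    return grad, hess


# Newton's method to find theta_MAP (standing in for regular finetuning).
theta_map = np.zeros(2)
for _ in range(50):
    g, H = grad_and_hessian(theta_map)
    theta_map = theta_map - np.linalg.solve(H, g)

# Laplace covariance: negative inverse of the log-joint Hessian at the MAP.
_, H = grad_and_hessian(theta_map)
Sigma = -np.linalg.inv(H)

# Linearised predictive for the logit f_theta(x*) = theta @ x*.
x_star = np.array([0.5, 0.5])
jac = x_star[None, :]                      # df/dtheta at theta_MAP, shape (1, D)
f_map = jac @ theta_map                    # shape (1,)
Lambda = jac @ Sigma @ jac.T               # shape (1, 1)
chol = np.linalg.cholesky(Lambda)
xi = rng.normal(size=(1, 1000))
f_samples = f_map[:, None] + chol @ xi     # 1000 reparameterised samples
print(f_map, f_samples.std())
```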

The second value of interest is the marginal likelihood (also known as the model evidence), which is useful for hyperparameter optimisation and can be computed simply as follows $$\begin{align}p(y|X) &= \int p(y|X,\theta)p(\theta)d\theta \\ &\approx \exp (\mathcal{L}(y, X; \theta_\text{MAP}))(2\pi)^{D/2}\det(\Sigma)^{1/2},\end{align}$$where $D$ is the number of parameters in $\theta$.
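One way to sanity-check this formula is on a model where the log-joint is exactly quadratic, so that the Laplace approximation is exact. The sketch below assumes a toy conjugate model (a scalar $\theta \sim \mathcal{N}(0, 1)$ with observations $y_i \sim \mathcal{N}(\theta, 1)$), which is an illustrative choice of mine, and compares the Laplace estimate of the log evidence against the closed-form answer.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 20
y = rng.normal(loc=0.7, size=N)


def log_joint(theta):
    """log p(y | theta) + log p(theta) for the toy Gaussian model."""
    log_lik = -0.5 * np.sum((y - theta) ** 2) - 0.5 * N * np.log(2 * np.pi)
    log_prior = -0.5 * theta ** 2 - 0.5 * np.log(2 * np.pi)
    return log_lik + log_prior


D = 1                              # number of parameters
hessian = -(N + 1.0)               # second derivative of the log-joint
Sigma = -1.0 / hessian             # Laplace covariance
theta_map = y.sum() / (N + 1.0)    # maximiser of the log-joint

# Laplace estimate, computed in log space to avoid underflow.
log_evidence_laplace = (
    log_joint(theta_map) + 0.5 * D * np.log(2 * np.pi) + 0.5 * np.log(Sigma)
)

# Exact answer: marginally, y ~ N(0, I + 11^T).
cov_y = np.eye(N) + np.ones((N, N))
_, logdet = np.linalg.slogdet(cov_y)
log_evidence_exact = -0.5 * (
    N * np.log(2 * np.pi) + logdet + y @ np.linalg.solve(cov_y, y)
)
print(log_evidence_laplace, log_evidence_exact)
```

For non-Gaussian posteriors the two would no longer agree exactly, which is where the ‘approximately equal’ sign in the formula comes from.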

Using Stein Variational Gradient Descent (SVGD)

A reasonable question to ask is whether it might be feasible to learn the posterior distribution during finetuning, rather than afterwards. One method for achieving this is Stein variational gradient descent (SVGD) [16], in which a collection of $n$ parameter particles $\{\theta_i^{(0)}\}_{i=1}^n$ is iteratively updated to fit the true posterior using some similarity function (i.e. a kernel) $k: \Theta \times \Theta \to \mathbb{R}$, $$\begin{align}\theta_i^{(t+1)} &= \theta_i^{(t)} + \epsilon_i \phi(\theta_i^{(t)}) \\ \phi(\theta_i) &= \frac{1}{n} \sum_{j=1}^n \left[\frac{1}{T}k(\theta_j,\theta_i)\nabla_{\theta_j}\log p(\theta_j | \mathcal{D}_\text{finetune}) + \nabla_{\theta_j} k(\theta_j, \theta_i) \right],\end{align}$$where $\epsilon_i$ is a learning rate and $T$ is a temperature hyperparameter. The basic interpretation of the update is that the first term inside the summation drives particles towards areas of high posterior probability, whilst the second term penalises particles that are too similar to one another, acting as a repulsive force that encourages exploration of the parameter-space.
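The update can be sketched in a few lines of NumPy. The example below makes several assumptions for illustration: the ‘posterior’ is a known 2D Gaussian (so its score is available in closed form), the kernel is an RBF with a fixed bandwidth, and a single step size is shared by all particles. None of these choices come from the finetuning setting itself.

```python
import numpy as np


def grad_log_p(theta, mean, var):
    """Score of an isotropic Gaussian target N(mean, var * I)."""
    return -(theta - mean) / var


def svgd_step(particles, step_size, bandwidth, mean, var, temperature=1.0):
    n = particles.shape[0]
    diffs = particles[:, None, :] - particles[None, :, :]   # diffs[j, i] = theta_j - theta_i
    sq_dists = np.sum(diffs ** 2, axis=-1)
    K = np.exp(-sq_dists / (2 * bandwidth ** 2))            # K[j, i] = k(theta_j, theta_i)
    scores = grad_log_p(particles, mean, var)               # shape (n, dim)
    drive = K.T @ scores / temperature                      # sum_j k(theta_j, theta_i) * score_j
    grad_K = -diffs / bandwidth ** 2 * K[:, :, None]        # gradient of k w.r.t. theta_j
    repulse = grad_K.sum(axis=0)                            # sum over j, pushes particles apart
    phi = (drive + repulse) / n
    return particles + step_size * phi


rng = np.random.default_rng(0)
target_mean, target_var = np.array([2.0, -1.0]), 1.0
particles = rng.normal(size=(50, 2)) - 3.0                  # start far from the target

for _ in range(1000):
    particles = svgd_step(particles, 0.1, 1.0, target_mean, target_var)

print(particles.mean(axis=0), particles.std(axis=0))
```

Dropping the `repulse` term would turn this into plain gradient ascent on the log-posterior, with every particle collapsing onto the mode.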

Once the particles have converged, we can simply approximate the posterior predictive as the average output of the network across each parameter particle $\theta_i$, $$p(y^* | x^*, \mathcal{D}_\text{finetune}) \approx \frac{1}{n} \sum_{i=1}^n f_{\theta_i}(x^*).$$

My current research lies in applying SVGD to LoRA adapters. The hope is that we can learn a richer, multi-modal posterior distribution using SVGD’s particles without making the Gaussian posterior assumption of the Laplace approximation. Recent concurrent work [17] applies a very similar technique to computer-vision tasks and achieves promising results.

Conclusion

I hope this blog has been a useful introduction to the finetuning of LLMs. Feel free to get in touch if you’re interested! My email is sam.bowyer@bristol.ac.uk.

Footnotes

1: LLMs split input text up into a sequence of tokens. Roughly speaking, most words are split into one or two tokens depending on how common and how long they are. Using GPT-4’s tokenizer, this sentence is made from 17 tokens.
2: Spare a moment, if you will, for the Meta engineers behind the OPT-175B (175 billion-parameter) model, the training logbook of which reads at times like that of a doomed ship at sea.

3: Note that in the case of LLMs specifically, the straight-out-of-pretraining model will also likely be a poor virtual assistant, in the way we tend to desire of chatbots like ChatGPT. A model which can complete sentences to match the general patterns found in $\mathcal{D}$ won’t necessarily be much good at the user-agent back-and-forth conversation style we’d like, and as such might not have properly ‘learnt’ how to, for example, follow instructions and answer questions. It’s because of this that most public-facing LLMs go through what’s known as instruction finetuning, in which the model is finetuned on a large dataset of instruction-following chat logs before being deployed.
4: That is, without using 25,000 GPUs…
5: Consider this paper [2] by Huang et al. which boasts 1.08-bit quantisationᵃ of 16-bit models, all whilst retaining impressive levels of performance.

a: i.e. representing parameters with an average precision of 1.08 bits.

6:

An (Ill-Advised) Aside: Attention

Attention layers work by taking three matrices as input, $Q_\text{input}, K_\text{input}, V_\text{input} \in \mathbb{R}^{n \times d_\text{model}}$, typically representing $d_\text{model}$-dimensional embeddings of a sequence of $n$ tokens. First we project these matrices using learnable weight matrices $W^Q, W^K \in \mathbb{R}^{d_\text{model} \times d_k}$, and $W^V \in \mathbb{R}^{d_\text{model} \times d_v}$ to obtain our queries, keys and values: $$\begin{align}
Q &= Q_\text{input} W^Q \in \mathbb{R}^{n \times d_k} \\
K &= K_\text{input} W^K \in \mathbb{R}^{n \times d_k} \\
V &= V_\text{input} W^V \in \mathbb{R}^{n \times d_v}.
\end{align}$$With these, we then compute attention as $$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$ where $\text{softmax}$ is applied over each row such that, denoting the $i$th row of the matrix $A = QK^T/\sqrt{d_k}$ as $A^{(i)}$ and that row’s $j$th element as $a^{(i)}_j$, we define: $$\text{softmax}(A)^{(i)} = \frac{\exp A^{(i)}}{\sum_{j=1}^n \exp a^{(i)}_j}.$$
The intuition behind this is that our $n \times n$ attention matrix $\text{softmax}(QK^T/\sqrt{d_k})$ has entries representing how much token $i$ relates (or ‘attends’) to token $j$. The $\text{softmax}$ normalises each row so that the entries all add up to one, allowing us to think of each row as a distribution over tokens. The final multiplication with $V$ might then be thought of as selecting (or weighting) tokens in $V$ according to those distributions.
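A single attention head can be sketched directly from these formulas. The snippet assumes self-attention (the same token embeddings are used for all three inputs) and uses arbitrary small dimensions and random weights for illustration.

```python
import numpy as np


def softmax_rows(M):
    M = M - M.max(axis=-1, keepdims=True)  # for numerical stability
    E = np.exp(M)
    return E / E.sum(axis=-1, keepdims=True)


def attention(Q, K, V):
    d_k = Q.shape[-1]
    weights = softmax_rows(Q @ K.T / np.sqrt(d_k))  # (n, n), each row sums to one
    return weights @ V, weights


rng = np.random.default_rng(0)
n, d_model, d_k, d_v = 5, 16, 8, 8       # hypothetical sizes
tokens = rng.normal(size=(n, d_model))   # self-attention: one input for Q, K and V
W_Q = rng.normal(size=(d_model, d_k))
W_K = rng.normal(size=(d_model, d_k))
W_V = rng.normal(size=(d_model, d_v))

out, weights = attention(tokens @ W_Q, tokens @ W_K, tokens @ W_V)
print(out.shape, weights.sum(axis=-1))
```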

One important limitation of the attention mechanism we’ve just described is that it only allows us to consider how each token attends to each other token in some universal way, whereas in reality there are multiple ways that words in a sentence (for example) can relate to each other. Because of this, most of the time we actually use multi-headed attention (MHA), in which we compute attention between the token sequences $H \in \mathbb{N}$ times, each time with different learnable weight matrices $W^Q_h, W^K_h, W^V_h$ for $h \in \{1,\ldots,H\}$. Then we combine these separate attention heads, using yet another learnable weight matrix $W^O \in \mathbb{R}^{H d_v \times d_\text{model}}$, $$\text{MultiHead}(Q_\text{input}, K_\text{input}, V_\text{input}) = \text{Concat}(\text{head}_1,\ldots,\text{head}_H)W^O \in \mathbb{R}^{n \times d_\text{model}},$$ where $\text{head}_h = \text{Attention}(Q_\text{input}W^Q_h, K_\text{input}W^K_h, V_\text{input}W^V_h)$. Allowing the model to learn different types of attention on different heads makes MHA an incredibly powerful and expressive part of a neural network.

To summarise and return to the discussion of finetuning: MHA layers contain a ton of learnable parameters (specifically, $2H d_\text{model} (d_k + d_v)$ of them).
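As a quick check of that count, the sketch below builds a multi-headed self-attention layer in NumPy (with arbitrary small dimensions and random weights, and with no biases, as in the description above) and counts its parameters directly.

```python
import numpy as np


def softmax_rows(M):
    M = M - M.max(axis=-1, keepdims=True)  # for numerical stability
    E = np.exp(M)
    return E / E.sum(axis=-1, keepdims=True)


def multi_head_attention(x, W_Q, W_K, W_V, W_O):
    """Self-attention: x of shape (n, d_model) is used for queries, keys and values."""
    heads = []
    for W_q, W_k, W_v in zip(W_Q, W_K, W_V):
        Q, K, V = x @ W_q, x @ W_k, x @ W_v
        weights = softmax_rows(Q @ K.T / np.sqrt(Q.shape[-1]))
        heads.append(weights @ V)
    return np.concatenate(heads, axis=-1) @ W_O


rng = np.random.default_rng(0)
n, d_model, H, d_k, d_v = 5, 16, 4, 4, 4   # hypothetical sizes
W_Q = rng.normal(size=(H, d_model, d_k))
W_K = rng.normal(size=(H, d_model, d_k))
W_V = rng.normal(size=(H, d_model, d_v))
W_O = rng.normal(size=(H * d_v, d_model))

out = multi_head_attention(rng.normal(size=(n, d_model)), W_Q, W_K, W_V, W_O)
n_params = W_Q.size + W_K.size + W_V.size + W_O.size
print(out.shape, n_params, 2 * H * d_model * (d_k + d_v))
```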

 

References

[1] Kingma, D.P. and Ba, J., 2014. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980.
[2] Huang, W., Liu, Y., Qin, H., Li, Y., Zhang, S., Liu, X., Magno, M. and Qi, X., 2024. BiLLM: Pushing the limit of post-training quantization for LLMs. arXiv preprint arXiv:2402.04291.
[3] Zaken, E.B., Ravfogel, S. and Goldberg, Y., 2021. BitFit: Simple parameter-efficient fine-tuning for transformer-based masked language-models. arXiv preprint arXiv:2106.10199.
[4] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, Ł. and Polosukhin, I., 2017. Attention is all you need. arXiv preprint arXiv:1706.03762.
[5] Houlsby, N., Giurgiu, A., Jastrzebski, S., Morrone, B., De Laroussilhe, Q., Gesmundo, A., Attariyan, M. and Gelly, S., 2019, May. Parameter-efficient transfer learning for NLP. In International conference on machine learning (pp. 2790-2799). PMLR.
[6] Lin, Z., Madotto, A. and Fung, P., 2020. Exploring versatile generative language model via parameter-efficient transfer learning. arXiv preprint arXiv:2004.03829.
[7] Pfeiffer, J., Kamath, A., Rücklé, A., Cho, K. and Gurevych, I., 2021. AdapterFusion: Non-Destructive Task Composition for Transfer Learning. EACL 2021.
[8] Hu, E.J., Shen, Y., Wallis, P., Allen-Zhu, Z., Li, Y., Wang, S., Wang, L. and Chen, W., 2021. LoRA: Low-rank adaptation of large language models. arXiv preprint arXiv:2106.09685.
[9] Zhang, Q., Chen, M., Bukharin, A., Karampatziakis, N., He, P., Cheng, Y., Chen, W. and Zhao, T., 2023. AdaLoRA: Adaptive budget allocation for parameter-efficient fine-tuning. arXiv preprint arXiv:2303.10512.
[10] Meng, F., Wang, Z. and Zhang, M., 2024. PiSSA: Principal singular values and singular vectors adaptation of large language models. arXiv preprint arXiv:2404.02948.
[11] Zhao, J., Zhang, Z., Chen, B., Wang, Z., Anandkumar, A. and Tian, Y., 2024. GaLore: Memory-efficient LLM training by gradient low-rank projection. arXiv preprint arXiv:2403.03507.
[12] Cinquin, T., Immer, A., Horn, M. and Fortuin, V., 2021. Pathologies in priors and inference for Bayesian transformers. arXiv preprint arXiv:2110.04020.
[13] Chen, W. and Li, Y., 2023. Calibrating transformers via sparse Gaussian processes. arXiv preprint arXiv:2303.02444.
[14] Yang, A.X., Robeyns, M., Wang, X. and Aitchison, L., 2023. Bayesian low-rank adaptation for large language models. arXiv preprint arXiv:2308.13111.
[15] Martens, J. and Grosse, R., 2015, June. Optimizing neural networks with Kronecker-factored approximate curvature. In International conference on machine learning (pp. 2408-2417). PMLR.
[16] Liu, Q. and Wang, D., 2016. Stein variational gradient descent: A general purpose Bayesian inference algorithm. Advances in neural information processing systems, 29.
[17] Doan, B.G., Shamsi, A., Guo, X.Y., Mohammadi, A., Alinejad-Rokny, H., Sejdinovic, D., Ranasinghe, D.C. and Abbasnejad, E., 2024. Bayesian Low-Rank LeArning (Bella): A Practical Approach to Bayesian Neural Networks. arXiv preprint arXiv:2407.20891.
