Cavendish Blueprints

QAPR 5: grokking is maybe not *that* big a deal?


By Quintin Pope
First published July 23, 2023 • Last updated July 31, 2023

Crossposted from https://www.lesswrong.com/posts/GpSzShaaf8po4rcmA/qapr-5-grokking-is-maybe-not-that-big-a-deal

Introduction

Grokking refers to an observation by Power et al. (below) that models trained on simple modular arithmetic tasks would first overfit to their training data, reaching near-perfect training accuracy while still performing at chance on held-out data, but that training well past the point of overfitting would eventually cause the models to generalize to unseen test data. The rest of this post discusses a number of recent papers on grokking.
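
Before diving in, for concreteness, here is a minimal sketch of the kind of dataset these experiments use (the modulus and train/test split below are illustrative choices, not taken from the paper):

```python
import itertools, random

# All pairs (a, b) labeled with (a + b) mod p; a fixed fraction is held out as the "unseen test data".
p = 97
examples = [((a, b), (a + b) % p) for a, b in itertools.product(range(p), repeat=2)]
random.shuffle(examples)
split = len(examples) // 2
train, test = examples[:split], examples[split:]
print(len(train), len(test))  # the grokking observation: models fit `train` long before they do well on `test`
```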

Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets

In this paper we propose to study generalization of neural networks on small algorithmically generated datasets. In this setting, questions about data efficiency, memorization, generalization, and speed of learning can be studied in great detail. In some situations we show that neural networks learn through a process of "grokking" a pattern in the data, improving generalization performance from random chance level to perfect generalization, and that this improvement in generalization can happen well past the point of overfitting. We also study generalization as a function of dataset size and find that smaller datasets require increasing amounts of optimization for generalization. We argue that these datasets provide a fertile ground for studying a poorly understood aspect of deep learning: generalization of overparametrized neural networks beyond memorization of the finite training dataset.

My opinion:

When I first read this paper, I was very excited. It seemed like a pared-down / "minimal" example that could let us study the underlying mechanism behind neural network generalization. You can read more of my initial opinion on grokking in the post Hypothesis: gradient descent prefers general circuits.

I now think I was way too excited about this paper, that grokking is probably a not-particularly-important optimization artifact, and that grokking is no more connected to the "core" of deep learning generalization than, say, the fact that it's possible for deep learning to generalize from an MNIST training set to the testing set.

I also think that using the word "grokking" was anthropomorphizing and potentially misleading (like calling the adaptive information routing component of a transformer model its "attention"). Evocative names risk letting the connotations of the name filter into the analysis of the object being named. E.g., I've heard several people say things about grokking that seemed to lean more on the word's connotations than on the actual experimental results.

A more factual and descriptive phrase for "grokking" would be something like "eventual recovery from overfitting". I've personally found that using more neutral mental labels for "grokking" helps me think about it more clearly. It lets me more easily think about just the empirical results of grokking experiments and their implications, without priming myself with potentially unwarranted connotations.

An aside on the suddenness of grokking:

Example of grokking from Power et al. Figure 1.

People often talk as though grokking is a sudden process. This isn't necessarily true. For example, the grokking shown in this plot above is not sudden. Rather, the log base-10 scale of the x-axis makes it look sudden. If you actually measure the graph, you'll see that the grokking phase takes up the majority of the training steps (between ~80% and 95%, depending on when you place the start / end of the grokking period).
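
As a toy calculation (with made-up step counts, not measurements of the actual figure): if test accuracy only starts improving around step 1e5 and saturates around step 1e6 in a 1e6-step run, that improvement covers 90% of the training steps on a linear axis, but only about the final sixth of a log axis that starts at step 1.

```python
import math

# Hypothetical numbers for illustration only; not read off Power et al.'s plot.
total_steps = 10**6
grok_start, grok_end = 10**5, 10**6  # when test accuracy starts / stops improving

linear_fraction = (grok_end - grok_start) / total_steps
log_fraction = (math.log10(grok_end) - math.log10(grok_start)) / math.log10(total_steps)  # log axis starting at step 1

print(f"share of training steps spent generalizing: {linear_fraction:.0%}")  # 90%
print(f"share of a log10 x-axis it occupies:        {log_fraction:.0%}")     # ~17%
```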

To be clear, grokking can be sudden. Rapid grokking most often happens when training with an explicit regularizer such as weight decay (e.g., in the paper below). However, relying on weaker implicit regularizers can lead to much more gradual grokking. The plot above shows a training run that relied on the slingshot mechanism as its source of implicit regularization, which occurs when numerical underflow errors in calculating training losses create anomalous gradients that adaptive gradient optimizers like Adam then propagate. This can act as a 'poor man's gradient noise' and thus as a source of regularization.
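
To illustrate the numerical-precision point (a toy example of precision loss in the loss calculation, not a reproduction of the slingshot mechanism itself):

```python
import numpy as np

# Once the model is very confident, a float32 cross-entropy term can round to exactly zero,
# erasing the gradient signal that a float64 computation would still retain.
p_correct = 1.0 - 1e-10  # probability assigned to the right label

loss32 = -np.log(np.float32(p_correct))  # p_correct rounds to 1.0 in float32, so the loss is 0.0
loss64 = -np.log(np.float64(p_correct))  # ~1e-10: tiny, but a nonzero training signal

print(loss32, loss64)
```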

From personal experiments, I've found that I can fully avoid grokking on modular arithmetic by avoiding explicit weight decay regularization and minimizing implicit regularization: using a low learning rate, full-batch gradient descent, and 64-bit precision for the loss calculations.
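
A minimal sketch of what such a setup might look like in PyTorch (the architecture and hyperparameters here are illustrative, not the exact ones from these experiments):

```python
import torch
import torch.nn as nn

torch.set_default_dtype(torch.float64)  # 64-bit precision everywhere, including the loss

# Modular addition dataset: all (a, b) pairs with label (a + b) mod p, half held out for testing.
p = 97
pairs = torch.cartesian_prod(torch.arange(p), torch.arange(p))
labels = (pairs[:, 0] + pairs[:, 1]) % p
perm = torch.randperm(len(pairs))
split = len(pairs) // 2
train_x, train_y = pairs[perm[:split]], labels[perm[:split]]
test_x, test_y = pairs[perm[split:]], labels[perm[split:]]

model = nn.Sequential(
    nn.Embedding(p, 128),  # embed each operand
    nn.Flatten(),          # concatenate the two operand embeddings
    nn.Linear(256, 256), nn.ReLU(),
    nn.Linear(256, p),
)

# Plain full-batch gradient descent: no weight decay, no momentum, low learning rate.
opt = torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay=0.0)
loss_fn = nn.CrossEntropyLoss()

for step in range(50_000):  # run length is illustrative
    opt.zero_grad()
    loss = loss_fn(model(train_x), train_y)
    loss.backward()
    opt.step()
    if step % 1_000 == 0:
        with torch.no_grad():
            test_acc = (model(test_x).argmax(dim=-1) == test_y).double().mean()
        print(step, loss.item(), test_acc.item())
```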

A Mechanistic Interpretability Analysis of Grokking

This is a write-up of an independent research project I did into understanding grokking through the lens of mechanistic interpretability. My most important claim is that grokking has a deep relationship to phase changes. Phase changes, i.e. sudden changes in the model's performance for some capability during training, are a general phenomenon that occurs when training models, and they have also been observed in large models trained on non-toy tasks. For example, the sudden change in a transformer's capacity to do in-context learning when it forms induction heads. In this work I examine several toy settings where a model trained to solve them exhibits a phase change in test loss, regardless of how much data it is trained on. I show that if a model is trained on limited data with high regularisation, then the model shows grokking.

My opinion:

This post provides an awesome example of mechanistic interpretability analysis to understand how models use Fourier transforms and trig identities to build general solutions to modular arithmetic problems, and tracks how that solution develops over time as the model groks. The post also connects grokking to the much more general and widespread phenomena of phase changes in ML training.
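
As a rough, hand-coded sketch of that kind of algorithm (written down directly here rather than extracted from a trained network): embed each operand as cos/sin waves at a few frequencies, combine them with the angle-addition identity, and read out whichever answer c maximizes cos(w(a + b - c)) summed over frequencies.

```python
import numpy as np

# Hand-coded Fourier / trig-identity algorithm for (a + b) mod p. A trained network represents
# something like this distributed across its weights; writing it out directly shows why it works.
p = 113
freqs = np.array([3, 17, 42])      # a few arbitrary "key frequencies" (illustrative choice)
w = 2 * np.pi * freqs / p

def mod_add_logits(a, b):
    # Angle-addition identities give cos/sin of w(a + b) from cos/sin of wa and wb.
    cos_ab = np.cos(w * a) * np.cos(w * b) - np.sin(w * a) * np.sin(w * b)
    sin_ab = np.sin(w * a) * np.cos(w * b) + np.cos(w * a) * np.sin(w * b)
    c = np.arange(p)[:, None]
    # cos(w(a + b - c)) equals 1 at every frequency exactly when c = (a + b) mod p,
    # so summing over frequencies makes the correct residue the argmax.
    return (cos_ab * np.cos(w * c) + sin_ab * np.sin(w * c)).sum(axis=1)

a, b = 57, 90
print(int(np.argmax(mod_add_logits(a, b))), (a + b) % p)  # both print 34
```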

However, I don't think that grokking is the best testbed for studying phase changes more generally. We have much more realistic deep learning settings that undergo phase transitions, such as double descent, the formation of induction heads and emergent outliers in language models, or (possibly) OpenFold's series of sequential transitions as its outputs move from being zero dimensional, to one dimensional, to two dimensional, and finally to three dimensional.

Towards Understanding Grokking: An Effective Theory of Representation Learning

We aim to understand grokking, a phenomenon where models generalize long after overfitting their training set. We present both a microscopic analysis anchored by an effective theory and a macroscopic analysis of phase diagrams describing learning performance across hyperparameters. We find that generalization originates from structured representations whose training dynamics and dependence on training set size can be predicted by our effective theory in a toy setting. We observe empirically the presence of four learning phases: comprehension, grokking, memorization, and confusion. We find representation learning to occur only in a "Goldilocks zone" (including comprehension and grokking) between memorization and confusion. We find on transformers the grokking phase stays closer to the memorization phase (compared to the comprehension phase), leading to delayed generalization. The Goldilocks phase is reminiscent of "intelligence from starvation" in Darwinian evolution, where resource limitations drive discovery of more efficient solutions. This study not only provides intuitive explanations of the origin of grokking, but also highlights the usefulness of physics-inspired tools, e.g., effective theories and phase diagrams, for understanding deep learning.

My opinion:

I really liked the paper's illustrations of how the representation spaces differ before and after grokking.

Generalization seems to correspond to simpler and smoother geometries of the representation spaces. This meshes with another perspective that points to geometric simplicity / smoothness as one of the key inductive biases driving generalization in deep learning, which also seems in line with Power et al.'s finding (in section A.5) that post-grokking solutions correspond to flatter local minima.

However, I think the key result of this paper is that it's possible to avoid grokking completely by choosing different training hyperparameters (see below). From a capabilities perspective, grokking is a mistake. The ideal network doesn't grok. Rather, it starts generalizing immediately.

Omnigrok: Grokking Beyond Algorithmic Data

Grokking, the unusual phenomenon for algorithmic datasets where generalization happens long after overfitting the training data, has remained elusive. We aim to understand grokking by analyzing the loss landscapes of neural networks, identifying the mismatch between training and test losses as the cause for grokking. We refer to this as the "LU mechanism" because training and test losses (against model weight norm) typically resemble "L" and "U", respectively. This simple mechanism can nicely explain many aspects of grokking: data size dependence, weight decay dependence, the emergence of representations, etc. Guided by the intuitive picture, we are able to induce grokking on tasks involving images, language and molecules. In the reverse direction, we are able to eliminate grokking for algorithmic datasets. We attribute the dramatic nature of grokking for algorithmic datasets to representation learning.

My opinion:

The above two papers suggest grokking is a consequence of moderately bad training setups. I.e., training setups that are bad enough that the model starts out by just memorizing the data, but which also contain some sort of weak regularization that eventually corrects this initial mistake.

If that story is true, then I think it casts doubt on the relevance of studying grokking to AGI safety. Presumably, an AGI's training process is going to have a pretty good setup. Why should we expect results from studying grokking to transfer?

E.g., Omnigrok indicates that the reason we don't see grokking in MNIST is because we've extensively tuned the training setups for MNIST models (including their initialization processes), and conversely, the reason we do see grokking in algorithmic tasks is because we haven't extensively tuned the training setups for such models. Given this, how useful should we expect algorithmic grokking results to be for improving / tuning / controlling MNIST models?

A Tale of Two Circuits: Grokking as Competition of Sparse and Dense Subnetworks

Grokking is a phenomenon where a model trained on an algorithmic task first overfits but, then, after a large amount of additional training, undergoes a phase transition to generalize perfectly. We empirically study the internal structure of networks undergoing grokking on the sparse parity task, and find that the grokking phase transition corresponds to the emergence of a sparse subnetwork that dominates model predictions. On an optimization level, we find that this subnetwork arises when a small subset of neurons undergoes rapid norm growth, whereas the other neurons in the network decay slowly in norm. Thus, we suggest that the grokking phase transition can be understood to emerge from competition of two largely distinct subnetworks: a dense one that dominates before the transition and generalizes poorly, and a sparse one that dominates afterwards.

My opinion:

This paper doesn't seem too novel in its implications, but it does seem to confirm some of the findings in A Mechanistic Interpretability Analysis of Grokking and Omnigrok: Grokking Beyond Algorithmic Data.

Similar to A Mechanistic Interpretability Analysis of Grokking, this paper finds two competing solutions inside the network, and that grokking occurs as a phase transition where the general solution takes over from the memorizing solution.

In Omnigrok, decreasing the initialization norms leads to immediate generalization. This paper finds that grokking corresponds to increasing weight norms for the generalizing subnetwork and decreasing weight norms for the rest of the network.
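
As a hedged sketch of how one might track this kind of per-neuron norm dynamic in one's own experiments (the layer names and the norm definition here are my own illustrative choices, not the paper's exact methodology):

```python
import torch
import torch.nn as nn

def per_neuron_norms(fc_in: nn.Linear, fc_out: nn.Linear) -> torch.Tensor:
    # One scalar per hidden unit: norm of its incoming weight row plus norm of its outgoing column.
    incoming = fc_in.weight.norm(dim=1)   # shape: (hidden,)
    outgoing = fc_out.weight.norm(dim=0)  # shape: (hidden,)
    return incoming + outgoing

# Inside a training loop, logging these every few hundred steps lets you look for a small subset
# of units whose norms grow while the rest decay, e.g.:
#   norm_history.append(per_neuron_norms(model.fc_in, model.fc_out).detach().cpu())
```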

Unifying Grokking and Double Descent

A principled understanding of generalization in deep learning may require unifying disparate observations under a single conceptual framework. Previous work has studied grokking, a training dynamic in which a sustained period of near-perfect training performance and near-chance test performance is eventually followed by generalization, as well as the superficially similar double descent. These topics have so far been studied in isolation. We hypothesize that grokking and double descent can be understood as instances of the same learning dynamics within a framework of pattern learning speeds. We propose that this framework also applies when varying model capacity instead of optimization steps, and provide the first demonstration of model-wise grokking.

My opinion:

I find this paper somewhat dubious. Their key novel result, that grokking can happen with increasing model size, is illustrated in their Figure 4.

However, these results seem explainable by the widely-observed tendency of larger models to learn faster and generalize better, given equal optimization steps.

I also don't see how saying 'different patterns are learned at different speeds' is supposed to have any explanatory power. It doesn't explain why some types of patterns are faster to learn than others, or what determines the relative learnability of memorizing versus generalizing patterns across domains. It feels like saying 'bricks fall because it's in a brick's nature to move towards the ground': both are repackaging an observation as an explanation.

Grokking of Hierarchical Structure in Vanilla Transformers

For humans, language production and comprehension is sensitive to the hierarchical structure of sentences. In natural language processing, past work has questioned how effectively neural sequence models like transformers capture this hierarchical structure when generalizing to structurally novel inputs. We show that transformer language models can learn to generalize hierarchically after training for extremely long periods – far beyond the point when in-domain accuracy has saturated. We call this phenomenon structural grokking. On multiple datasets, structural grokking exhibits inverted U-shaped scaling in model depth: intermediate-depth models generalize better than both very deep and very shallow transformers. When analyzing the relationship between model-internal properties and grokking, we find that optimal depth for grokking can be identified using the tree-structuredness metric of Murty et al. (2023). Overall, our work provides strong evidence that, with extended training, vanilla transformers discover and use hierarchical structure.

My opinion:

I liked this paper a lot.

Firstly, its use of language as a domain (even if limited to synthetic data) makes it more relevant to the current paradigm for making progress on AGI.

Secondly, many grokking experiments compare "generalization versus memorization". I.e., they compare the most complicated possible solution[1] to the training data against the least complicated one. Realistically, we're more interested in which of many possible generalizations a deep learning model develops, where the relative simplicities of the generalizations may not be clear.

[1] If the goal is to better understand how neural networks pick between out-of-distribution generalizations, then a solution with no generalization capacity at all (memorization) feels like a degenerate case.

Finally, this paper finds phenomena that don't fit neatly with prior grokking results, such as the inverted U-shaped scaling of structural grokking with model depth.

A full account of generalization for realistic problems probably involves interactions between optimizer, architecture, and dataset properties. The inverted U-shaped depth scaling results suggest this paper is starting to probe how the inductive biases of a given architecture do or do not match the structure of a dataset, and how that interaction ties into the resulting generalization patterns (2/3 isn't a bad start!).

Conclusion

I don't currently think that grokking is particularly core to the underlying mechanisms of deep learning generalization. That's not to say grokking has nothing to do with generalization, or that we couldn't possibly learn more about generalization by studying grokking. Rather, I don't think the current evidence implies that studying grokking would be particularly more fruitful than, say, studying generalization on CIFAR10, TinyStories, or full-on language modeling.

I also worry that the extremely simplified domains in which grokking is often studied will lead to biased results that don't generalize to more realistic setups. This makes me more excited about attempts to analyze grokking in less-simplified domains.

Future

I intend to restart this series. I won't be aiming for a weekly update schedule, but I will aim to release at least two more before the end of the summer.

My next topic will be on runtime interventions in neural net cognition, similar to Steering GPT-2-XL by adding an activation vector. However, feel free to suggest other topics for future roundups.


If you'd like to cite this article, you can use this:

@misc{Pope2023qapr5,
  author = "Quintin Pope",
  title = "QAPR 5: grokking is maybe not *that* big a deal?",
  year = 2023,
  howpublished = "Blog post",
  url = "https://cavendishlabs.org/blog/qapr5/"
}