@celoyd
Last active July 31, 2025 06:27
The usual implementation of attention transformers (SDPA) is kind of bad, actually

Introduction

I was writing a note to a friend that mentioned my tedious opinions on “AI” discourse. It veered off into my usual argument that big “AI” companies are shaping the industry ecosystem to their own ends by setting up a situation where expensive-to-run models are overvalued. I think they’re doing this because they have a competitive advantage in that tier of the market, having bought (time on) a lot of GPUs. It’s like how a company that owns diamond mines will probably promote the idea that large, mined diamonds are important and valuable, and that there’s something off about running a sub-industrial mine or lab-growing diamonds. You can do this without lying at all, but I still dislike it. Large mined diamonds here are $O(n^2)$ models.

To support this argument, I started making my case against the necessity of the standard transformer model. I admit that the case is scattershot and circumstantial. It’s not that SDPA (the normal transformer architecture) is a fraud, or that there is something much better ready to replace it everywhere and immediately. But maybe I can sow some doubts that SDPA is as good as the median ML practitioner assumes, and raise some hopes for better kinds of models in the pipeline.

That got out of hand in the e-mail I was writing, so I cut it out and put it here.

This note covers:

  • how some standard ML model families work, not in great depth but in order to have some context around…
  • how SDPA (the standard transformer) works;
  • some specific reasons I dislike SDPA; and
  • some things I hope might replace it.

This note does not make:

  • Normative judgments about any person or organization mentioned or not mentioned. I have very strong opinions about some of them, especially ones not mentioned, and my points here underlie some of those opinions. But it is not those opinions.
  • Any airtight case that SDPA is bad. If you love SDPA, you will probably still love SDPA after reading. That’s fine with me.
  • A nice, brief, well-organized argument. It was written in a sitting and when I came back to trim it down I accidentally added more. (And removed an embarrassing mistake where I said RWKV uses SSMs. I don’t know why I said that.)

Seven years ago, if you asked for the general architectures of the most studied and most widely applied ML models, you might get this list:

1. Fully connected networks (FCNs)

All inputs are fed in at once. A multi-layer perceptron (or a recognizable development of one) digests them, and you get some output.

Early on, these were studied for images, where each pixel is an input, so for example a 1e3 × 1e3 image is a vector of 1e6 inputs. It soon turned out (1) that this was wildly expensive to implement, because for example if the first hidden layer is also 1e6 wide then you have to do order of 1e12 operations right there, and also (2) that these networks are virtually impossible to train.
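To make point (1) concrete, here is the back-of-the-envelope arithmetic, with the layer sizes taken straight from the paragraph above (a rough sketch, nothing framework-specific):

```python
# Cost of the first fully connected layer on a 1e3 x 1e3 image,
# assuming the first hidden layer is also 1e6 units wide.
height = width = 1_000
inputs = height * width            # 1e6 pixel values
hidden = 1_000_000                 # first hidden layer, same width
weights = inputs * hidden          # one weight per (input, hidden unit) pair
print(f"{weights:.0e}")            # 1e+12 weights, and as many multiply-adds per image
```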

This trainability problem can be seen in many inter-related ways. One of them: Nearly all useful operations on images will rely on adjacency, for example two adjacent dark pixels being possibly part of a crack in a surface. But with a pure FCN we are asking the network to learn the entire topology of the image from scratch: separately for each pair of adjacent pixels. This is wildly inefficient; it is, roughly speaking, asking the network to fully explore the entire $\mathbb{R}^{1,000,000}$ space in order to find the manifold of images-of-cracks embedded in it. This is only realistic for toy problems.

2. Recurrent and convolutional networks

2.1 Recurrent networks (RNNs)

These are networks designed to process a stream of inputs, usually text, and as a sequence. They are head-based in the sense that they see only one item at a time, but they carry state. Something vaguely FCN-like in the head calculates (old state, input) → (new state, output). At a cartoon level, an RNN is a bit like a Turing machine that is only allowed to move forward. To do useful things, you have separately trained heads iterating on the stream; the early ones pass messages to the later ones in intermediate outputs.
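A cartoon of that head as a minimal NumPy sketch; the tanh, the sizes, and the random weights are my illustrative choices, not any particular RNN variant:

```python
import numpy as np

rng = np.random.default_rng(0)
d_state, d_in = 32, 16
W_state = 0.1 * rng.standard_normal((d_state, d_state))
W_in = 0.1 * rng.standard_normal((d_state, d_in))
W_out = rng.standard_normal((d_in, d_state))

def step(state, x):
    """(old state, input) -> (new state, output)."""
    new_state = np.tanh(W_state @ state + W_in @ x)
    return new_state, W_out @ new_state

state = np.zeros(d_state)
for x in rng.standard_normal((100, d_in)):   # the head sees one item at a time...
    state, output = step(state, x)           # ...but carries state forward
```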

We can apply an RNN to an image by mapping the image to 1D, but it’s easier to work with the 2D topology of an image and use…

2.2 Convolutional networks (CNNs)

Convolutions are a tidy answer to some of the problems with FCNs in the image domain. Many operations that are intractable to learn as image → image are very practical when factored out as neighborhood → neighborhood. An idealized CNN pushes data through many alternations of convolutions and pixelwise FCNs, so it never has to consider the whole image at once; it uses linear time over pixel count.

This leaves a problem. We now have finite receptive fields. If we have, say, 10 convolutional layers, each of 3×3 pixels, then information at a given place is only in the “light cone” (the receptive field) of other places within (3 - 1)/2 × 10 = 10 pixels. For some purposes, this might be fine. For others, it’s prohibitive. Say we’re trying to remove diffraction spikes from astronomical images. If a spike from an especially bright star can be 30 px long, the model will find itself in situations where it sees a bright line that could equally be part of a spike or part of a galaxy. It will have to make an under-informed decision about whether to try to remove it or not.
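The light-cone arithmetic, spelled out (a trivial sketch, using the kernel size and depth from the example above):

```python
def receptive_radius(kernel_size=3, n_layers=10):
    # Each k x k convolution lets information travel (k - 1) / 2 pixels
    # further in every direction; stacking layers adds those radii up.
    return (kernel_size - 1) // 2 * n_layers

print(receptive_radius(3, 10))   # 10 px: not enough to see a 30 px diffraction spike whole
```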

Resampling helps here. Resampling doesn’t do anything that convolutions can’t in principle, but it’s convenient because it’s (1) an unlearned and highly optimized function, and (2) reduces the data size, which makes everything faster, which we can use as budget for more convolutional layers. It also lets us at least loosely connect what we’re doing to several decades of research on image pyramids, scale spaces, multi-resolution analysis, and other more classical work.

The apotheosis of a CNN, at least for image → image networks, is a U-net, which is hourglass-shaped and has skip connections between its encoder and decoder levels. In theory this collects and distributes information from across the frame. At the pinch of the hourglass, the image may be compressed down to a single pixel, though a very “tall” pixel, consisting of a high-dimensional vector of abstract information characterizing the image as a whole. This then informs the level-by-level rebuilding of the image.
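A minimal hourglass sketch in PyTorch, just to make the shape concrete: two downsampling steps, a coarse bottleneck, and a decoder that gets each encoder level back through skip connections. The layer sizes and names are mine; this is not a faithful reimplementation of any published U-net.

```python
import torch
import torch.nn as nn

def block(c_in, c_out):
    return nn.Sequential(nn.Conv2d(c_in, c_out, 3, padding=1), nn.ReLU())

class TinyUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc1 = block(1, 16)           # full resolution
        self.enc2 = block(16, 32)          # 1/2 resolution
        self.bottleneck = block(32, 64)    # the "pinch": coarse but tall
        self.dec2 = block(64 + 32, 32)     # skip from enc2 concatenated in
        self.dec1 = block(32 + 16, 16)     # skip from enc1 concatenated in
        self.out = nn.Conv2d(16, 3, 1)     # e.g. monochrome -> color
        self.down = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode="nearest")

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.down(e1))
        b = self.bottleneck(self.down(e2))
        d2 = self.dec2(torch.cat([self.up(b), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up(d2), e1], dim=1))
        return self.out(d1)

print(TinyUNet()(torch.randn(1, 1, 64, 64)).shape)   # torch.Size([1, 3, 64, 64])
```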

Problems coordinating across space

For example, if we’re colorizing a monochrome photo including a mid-tone shirt, we want the network to choose a shirt color conditioned on the style of shirt. A plain tee could be almost any color; a plaid is likely in a more restricted range; and a sports jersey of a recognizable team has only one plausible color, at least to an informed human viewer.

We do not want the model to regress toward the mean by averaging the probability distribution of colors for that style of shirt. That will almost always be grayish. We want it to actually draw from the distribution. This is a problem that pops up far beyond CNNs, but it sets up the narrower idea that…

We also do not want what a CNN with a small receptive field might tend to do, which is to draw red in one neighborhood of a shirt and blue in another, and then have trouble rendering anything plausible where they adjoin.

If you look at low-quality ML-based colorization, undersaturation and patchiness are in fact common. It’s relatively rare that a typical model picks a color that is definitely incorrect to a human observer (green stop signs and such); the consistency is harder than the general plausibility. See for example this poorly “remastered” video – lots of interesting footage in there, but a wunderkammer of artifacts. I happened to skip to 2:20, where the libreria’s (bookshop’s) sign is an example of patchiness. (The undersaturation tends to hide behind the look of poorly stored color film, which I think is cheating.)

Anyway, broadly, CNNs tend to be bad at coordinating distant chunks of information. In theory they have the architecture necessary to do this, but the task of somehow connecting representations to where they’re spatially needed seems to be hard to learn in a robust way that overcomes the necessary information destruction of convolution and resizing. There are a lot of pixels in 2D, and it’s hard for two distant ones to reach each other and say “aha, we seem to be working on the same problem; we’re conditioned on each other” whether that’s deciding a shirt color or anything else.

We want to check for connections between each pixel and each other pixel, or something like that, because we don’t understand the problem very clearly. Whatever it is, it smells strongly of $O(n^2)$ time complexity, and a CNN is by design $O(n)$ in pixel count; we seem to be trying to pack a sleeping bag into a manila folder here.

(One of the only people to show a principled and maybe even properly mathematical understanding of this problem was Hinton, who I feel is a remarkably good researcher and not a remarkably good public intellectual. (Like the elms, eminent academics are prone to a certain malady, and I’m afraid it’s found him.) You can track his lab working on what they called capsule networks with dynamic routing, taking object pose as their main interest, c. 2000-2020. They were doing research in the Apollonian style, thinking about the problem in great depth and setting careful experiments. But it was other people who discovered what they were looking for (or the first clear example of the kind of thing they were looking for), and it was by stumbling over it backwards. The normal course of science, I guess.)

Local and nonlocal means

Let’s step away from ML for a moment. If we want to remove spatially independent noise from an image, we can take local means: replace the color at each pixel with an average of the n×n neighborhood around it. Unfortunately, this is only a blur; when we say “I have high-frequency noise” we usually want an answer less snappy than “so low-pass the image”. However, we can do non-local means: instead of averaging each pixel with its neighborhood in the image, we average it with its neighborhood in color space. So for example if we are looking at a green pixel, we average it with every other pixel (in the image, or in some large radius) that is green or greenish. Outlier colors, presumably created by the noise, are pulled toward common colors. This is a little bit like k-means or an n-body gravity simulation. Various elaborations – for example, using patches instead of single pixels – develop a whole family of denoising methods, some good enough to use in practice (for example, BM3D).

We are still averaging, but in a feature space (specifically, color space) rather than in the image space (the 2D topology of the pixel grid).
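A toy version of that idea in NumPy: every pixel is averaged with every other pixel, weighted by closeness in color space rather than by position. It is $O(n^2)$ in pixel count, so only for tiny images; the Gaussian weighting and the bandwidth `h` are my simplifications, not the original algorithm.

```python
import numpy as np

def toy_nonlocal_means(img, h=0.1):
    pix = img.reshape(-1, img.shape[-1])                     # (n, channels)
    d2 = ((pix[:, None, :] - pix[None, :, :]) ** 2).sum(-1)  # pairwise color distances
    w = np.exp(-d2 / h**2)                                   # similar colors weigh more
    w /= w.sum(axis=1, keepdims=True)
    return (w @ pix).reshape(img.shape)                      # each pixel -> weighted mean

rng = np.random.default_rng(0)
noisy = 0.5 + 0.05 * rng.standard_normal((16, 16, 3))
print(noisy.std(), toy_nonlocal_means(noisy).std())          # the output varies less
```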

I haven’t seen the connection made, but to me this is the intuitive way to get to…

3. SDPA, a.k.a. transformers

In 2017, some researchers noticed that a technique used to boost RNN performance could be gussied up a bit to actually work better than the RNN, and wrote it up as Attention Is All You Need, one of the most cited papers of the last 25 years, and, as a better index of its influence, instigator of a cohort of unamusing X Is All You Need titles for papers, much like X Considered Harmful back in the day. (The one cookie-cutter title I consider acceptable is Hopfield Networks is All You Need.)

The self-attention transformer technique, a.k.a. various other bad names, including scaled dot-product self-attention or SDPA, goes like so. We’ll use language processing, the originally intended domain, as an example.

We take our input string and break it into words and word-like fragments (“,”, “w”, “of the”, “ing”, end of string marker, …), called tokens. Each token is mapped in some previously learned way to an array of length 256 to 4096 or so, which we call an embedding, or a vector in the latent space of a language model.

To each of these items we add a position encoding, so it knows where it comes from in the input string.
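In miniature, with a made-up vocabulary and sizes: an embedding table maps each token to a vector (random here, learned in a real model), and a position encoding is added on top. The sinusoidal recipe is the one from the original paper; plenty of variants exist.

```python
import numpy as np

vocab = {"the": 0, "orange": 1, "cat": 2, "<eos>": 3}   # toy vocabulary
dim = 256
embed = np.random.randn(len(vocab), dim)                # one vector per token id

def sinusoidal_positions(n_tokens, dim):
    pos = np.arange(n_tokens)[:, None]
    freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
    enc = np.zeros((n_tokens, dim))
    enc[:, 0::2] = np.sin(pos * freq)
    enc[:, 1::2] = np.cos(pos * freq)
    return enc

tokens = ["the", "orange", "cat", "<eos>"]
items = embed[[vocab[t] for t in tokens]]                    # one embedding per token
items = items + sinusoidal_positions(len(tokens), dim)       # now each knows where it sits
```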

From each item we produce, in a learned way (i.e., by a small FCN), three derived vectors: a query Q, a key K, and a value V. Here it’s useful to tell a bit of a just-so story to keep track of what we imagine each of these is for:

  • Q is a search vector for information useful to this item.

  • K is a description of the information that this item can provide.

  • V is the information payload.

So imagine we’re looking at the item for the word “orange”. In the first layer of the network, the relevant FCN has learned that this is a word with a lot of semantic ambiguity and it needs to resolve that. So if we’re at some position X, the item might produce – in opaque, human-illegible terms – vectors like:

  • Q: Any colors near X? Any fruit near X? Any references to the Netherlands? Any nouns representing physical objects immediately after position X? Any references to eating or food around here?

  • K: Color. Fruit. Unknown part of speech at X. Possible noun at X. Possible adjective at X. Possible Netherlandish context. Possible autumnal context.

  • V: The color orange. Adjective. Citrus fruit. Food. Noun.

Now the network takes the dot-product similarity of every item’s Q against every item’s K. For each Q, it takes an average of every V as weighted by the similarity. The output is this mixture of Vs.
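That operation, as a minimal single-head NumPy sketch. The learned projections are just random matrices here, and the weighting uses the softmax of scaled dot products that gives SDPA its name, a detail the description above glosses over.

```python
import numpy as np

def attention(X, Wq, Wk, Wv):
    Q, K, V = X @ Wq, X @ Wk, X @ Wv                 # each item asks, offers, tells
    scores = Q @ K.T / np.sqrt(K.shape[-1])          # every Q against every K: O(n^2)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)   # softmax over items
    return weights @ V                               # each item's weighted mixture of Vs

rng = np.random.default_rng(0)
n, d = 12, 64                                        # 12 items, 64-dimensional
X = rng.standard_normal((n, d))
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))
print(attention(X, Wq, Wk, Wv).shape)                # (12, 64)
```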

So what we’re looking at here is roughly another elaboration of nonlocal means. Instead of averaging in the color space, we have an item project itself into a new feature space learned just for this purpose, which lets it “ask” (Q), “offer” (K), and “tell” (V) functions of itself other than the identity. It’s also generally iterative, multi-headed, and so on; of course there are a lot of details we’re glossing over. But the underlying idea of letting items coordinate independently of their native topological space by simple averaging is intact.

This also reminds me of the improv game Convergence. Two people think of random words. On every turn, they count to 3 and say their words, talking over each other. They both win when they say the same word. The fun strategy is to think of words between the two words on each turn, so you get games like:

| Round | Alice      | Bob        |
|-------|------------|------------|
| 1     | horse      | gloves     |
| 2     | horseshoes | rawhide    |
| 3     | jockey     | blacksmith |
| 4     | farrier    | farrier    |

For images, we tend to chop the input pixels up into 16×16 or similar chips, send them through some learned digestion, call those the items, and proceed in much the same way.

SDPA architectures are now the best-regarded way of doing almost everything in applied ML. The T in ChatGPT™ stands for this kind of transformer; image generators use them; grad students can (and I suspect actually do) throw a dart at a list of tasks, apply a transformer to whatever it hits, and post that on arXiv.

I dislike SDPA – at least relative to the general love for it in the ML research world. It seems like a half-formed and inelegant version of something yet to be discovered. A few of my objections:

  1. It’s $O(n^2)$. Dotting n Q vectors with n K vectors has to be $O(n^2)$ operations. Every first-year CS student knows that $O(n^2)$ is cause for concern. To be fair, it turns out the bottleneck on real hardware is usually memory access, and, using an impressive optimization called flash attention, you can make SDPA look more like $O(n)$ for short-ish sequences. GPUs are, after all, massively parallel. But it’s still asymptotically $O(n^2)$, and that still matters!

  2. SDPA loses the inductive bias toward locality of CNNs. A reasonable complaint about CNNs is that they tend to all learn the same thing from scratch in their first few layers: something like Gábor filters or shearlets. A very similar set of basic kernels arises for almost all vision tasks, and deriving them out of pure noise in every training session is a bit silly. But SDPA does something even sillier: it learns the input domain’s topology from scratch every time. It’s a set → set function; until you add position information to the inputs, it’s completely permutation equivariant at the item scale! This is unlike our experience of the world. Permuting sensory experiences generally changes them. We don’t read a book as confetti, each fragment marked with its page and line number, in no particular order. Unlinearized video is a joke. We make non-local connections while reading, certainly, and we find value in indexes and concordances and in thematic analysis, but the topology of experience is important to us. Likewise, though some vital information in images is not local, most of it is. Almost always, almost all the meaning of a pixel is conditioned on its neighborhood. And every SDPA has to learn this completely from scratch on every training run. This gets mentioned as an interesting quirk, but I see it as a severe design flaw that merits great effort to contain and mitigate, and of course ideally to fix.

  3. With images, SDPA is too expensive to run with one item per pixel, so for non-tiny images we have to chunk out 16×16 or similar patches, and that means we lose translation equivariance. CNNs’ equivariance is a wonderful property, but apparently we’ve taken it for granted. Papers about SDPA like to brag about wide receptive fields, but you look at the attention maps and they’re visibly gridded by the patching process. We should be able to agree that, as a rule, if its attention map doesn’t look like a Gaussian, your vision network has not fully converged. It defies natural justice to claim that a pixel matters differently depending on where it falls in a lattice defined by arbitrary framing. In fact one of the trendsetting vision transformer models, Swin, is named after shifted windows: they were proud that they had a scheme for making offset patches in successive layers. This is just re-inventing convolutions but badly, because doing it right would take SDPA a chilling number of operations. It’s especially galling in combination with #2: SDPA starts training with complete ignorance of a well-established statistical fact about the world (locality – built into convolutions), but at the same time, very strong opinions based on something arbitrary (patch arrangement – sidestepped by convolutions). It has to learn something we know a priori, and to unlearn a meaningless choice built into it.

  4. It’s not always clear that SDPA, though definitely theoretically able to do things that CNNs can’t, is actually doing them. In 2022, a friendly but incisive paper called ConvNeXt claimed that (for one kind of image task) although modern SDPA models beat old-fashioned CNN models on benchmarks, the modern v. old-fashioned part was more important than the SDPA v. CNN part. They equipped a CNN with incremental micro-architectural refinements and training tricks that people had come up with over the last few years, and it beat a respected SDPA model in a fair race. A little later, when a particularly elegant and powerful self-supervised learning technique called MAE came out, apparently closely tied to the SDPA architecture, Team Convolution came back with ConvNeXt 2 showing that a CNN with one unusual implementation detail (sparse convolutions) could use the putatively SDPA-specific strategy to get state-of-the-art results. So, at least in this niche, problems that seem to show off special properties of SDPA can be solved equally well by CNNs for similar parameter counts, GPU-hours, researcher-hours, and so on. I do not think this proves anything big, but it does help color my reading of people saying things that sound a bit like “We did something with SDPA and got good results, so it must be because SDPA is good.”

  5. SDPA does not reliably start converging. This stinks. One of the actually potentially deeply interesting findings in ML is that if you overparameterize a sufficiently large model made of trivial parts (layers of matmul and relu) and hook it up to backpropagation, then, under substantial but not grotesque assumptions, it will almost always converge. In other words, in some still little-understood sense, you can [and here I’m waving my hands wildly and hoping you don’t ask for details] make a nonconvex problem act like a convex one. That, I think, is one of the only things that ML research has turned up that might merit a line in the Big Book of Interesting Ideas. There’s even a rule of thumb that there are no bad local minima. But not for SDPA! Sometimes it can’t seem to get there from here. I am told that even wildly successful SDPA architectures often have false starts in training. You initialize it sensibly and you start running backpropagation on a seemingly clear task … and it shrugs at you and makes what the French call a moue. If I recall correctly, you hear this issue mentioned in this presentation, which actually makes one of the best cases for vision transformers that I know. It draws some figures from How Do Vision Transformers Work, which for my money is one of the most useful and least annoying ML papers of the decade, although the tree video doesn’t interpret the paper entirely the way I do.

What can I say about SDPA that’s nice? Well, empirically, it can do impressive things. Whatever I say about efficiency and elegance, SDPA does things. And The Bitter Lesson – which I take as basically connectionism for the age of TSMC – is an account that I find insightful though not convincing. Plied with strong drink and prodded to be rude, I might say that SDPA converging poorly (getting trapped but also just being generally slow and kind of bad) acts as a regularization that makes optimizing a wastefully large network tractable, even though I think the best regularization for most models is to have fewer than A BILLION PARAMETERS. And I would belch, and wipe my face on my sleeve, and ask the waiter for a glass of ice water. But even then, I would admit that Sutton makes some okay points.

So what’s better than SDPA?

Well, nothing. Yet. Arguably.

State-space models

There is an interesting line of research in state-space models. Control theorists long ago worked out that a very large class of systems can be modeled by inputs, a state matrix with some transition-rule matrices, and outputs. (I’m bluffing; every time I read about control theory I feel like a deer who wandered into a Costco.) We can take a stream of language, for example, as kind of like inputs to a mental state, and this seems to work out more or less comparably to the SDPA-based LLM approach, but at lower quality … but $O(n)$. It’s unclear to me, and I think to the experts, whether there are sound ways to buy quality with the flops freed up by the subquadratic efficiency.
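The classical recurrence being gestured at, stripped to its skeleton (the sizes and the small random matrices are mine; real state-space models learn structured, carefully parameterized versions of A, B, and C):

```python
import numpy as np

rng = np.random.default_rng(0)
d_state, d_in = 16, 8
A = 0.05 * rng.standard_normal((d_state, d_state))   # state transition (kept stable)
B = rng.standard_normal((d_state, d_in))              # how inputs enter the state
C = rng.standard_normal((d_in, d_state))              # how the state is read out

def run_ssm(inputs):                      # inputs: (seq_len, d_in)
    x = np.zeros(d_state)
    outputs = []
    for u in inputs:                      # one step per item: O(n) in sequence length
        x = A @ x + B @ u
        outputs.append(C @ x)
    return np.stack(outputs)

print(run_ssm(rng.standard_normal((100, d_in))).shape)   # (100, 8)
```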

RWKV

There’s an open-source team called RWKV making models that synthesize ideas from RNNs and state-space models to do SDPA-like things. An index of the underlying/overarching/suffusing economic problem here – the thing I’m trying to illuminate in all of this! – is that RWKV’s “entire organization has less compute than a single researcher at Google”. RWKV is only about as good as a commercial SDPA model of similar parameter count (depending on training, benchmark choice, etc.), but it’s had far less effort put into it. This too, I think, has a chance of breaking out into something really interesting.

Holographic Reduced Representations

There are other subquadratic attention ideas out there. The Hopfield network paper I mentioned above, for example, is on my list. But I’ll close on a kind of scheme that’s gotten under my skin. These are the vector symbolic architectures (VSAs). Long ago, when the world was young and dinosaurs roamed the Earth, Gábor invented holography, which is probably a whole page in the Big Book of Interesting Ideas. (Large amounts of history skipped here.) In 1995, Plate comes out with Holographic Reduced Representations (HRR), pointing out that you can define some computationally simple operations such that you can build an approximate associative memory on the sole data type of a fixed-length, random, norm = 1 array of floats. That is, with all capital letters as arrays of some globally fixed length (say 1024), you can do things like this:

A = bind(B, C);
D = bind(E, F);
G = add(A, D);
assert unbind(G, E) ≈ F; // !

So we can merge and probabilistically recover vectors, keyed by other vectors, that we have assigned meaning, and all on vectors of the same dimension. Our accuracy in recovery is clearly bounded; we do not expect that we can merge, say, 100 arrays of length 10 onto each other and get them all back intact. How far we can push things depends of course on factors like how much data we have per vector (float32 v. float64 and how many of them) and the gory details of the bind/merge/unbind scheme used, e.g., naïvely pointwise v. Fourier-based with renormalizations. But if you choose your methods and constants right, you now have a key-value store. It’s approximate, but it is a key-value store. And it has no size limits, other than that it will degrade, and it runs in constant time. Hmm. Hmm!
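To make the pseudocode above runnable: a Plate-style sketch in NumPy, binding by circular convolution (done in Fourier space) and unbinding by circular correlation. The dimension, the random unit vectors, and the cosine-similarity check are my choices; recovery is approximate, which is the point.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 1024

def rand_vec():
    v = rng.standard_normal(dim)
    return v / np.linalg.norm(v)                  # fixed-length, random, norm = 1

def bind(a, b):                                   # circular convolution via FFT
    return np.fft.irfft(np.fft.rfft(a) * np.fft.rfft(b), n=dim)

def unbind(s, a):                                 # circular correlation: conjugate in Fourier space
    return np.fft.irfft(np.fft.rfft(s) * np.fft.rfft(a).conj(), n=dim)

B, C, E, F = (rand_vec() for _ in range(4))
G = bind(B, C) + bind(E, F)                       # add() is plain vector addition
recovered = unbind(G, E)
cosine = recovered @ F / (np.linalg.norm(recovered) * np.linalg.norm(F))
print(cosine)                                     # well above chance, though not exactly 1
```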

Look at the SDPA operation again. You can replace the dot-product scheme with each item doing a bind(K, V) on its own K and V, all of those merging into a big KV, and then each item getting unbind(KV, Q) back. This value is some kind of messy mixture of Vs in proportion that the item’s Q matched the corresponding K. But FCNs are good at untangling messes, so information comes through.
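A sketch of that swap, under my reading of it (not a reproduction of any particular paper’s implementation): every item binds its K to its V, everything is summed into one fixed-size memory, and each item unbinds that memory with its Q. One pass over the items, so linear rather than quadratic in n.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 1024

def bind(a, b):                                   # circular convolution via FFT
    return np.fft.irfft(np.fft.rfft(a) * np.fft.rfft(b), n=dim)

def unbind(s, a):                                 # circular correlation via FFT
    return np.fft.irfft(np.fft.rfft(s) * np.fft.rfft(a).conj(), n=dim)

def hrr_attention(Q, K, V):                       # each of shape (n, dim)
    memory = sum(bind(k, v) for k, v in zip(K, V))    # one shared, fixed-size KV store
    return np.stack([unbind(memory, q) for q in Q])   # a noisy mixture of Vs per item

n = 8
Q = rng.standard_normal((n, dim))
Q /= np.linalg.norm(Q, axis=1, keepdims=True)
K = Q.copy()                                      # perfect key matches, for the demo
V = rng.standard_normal((n, dim))
print(hrr_attention(Q, K, V).shape)               # (8, 1024)
```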

Someone went and implemented HRR in place of SDPA (actually still doing the S in SDPA, which seems generous to me) and got remarkably good results on a benchmark that measures performance over very long sequences.

Does this work for images? Well, when I try, it doesn’t converge well. But I might have made a typo, or my old GPU may not start showing results before I get bored, or there may be some intrinsic reason why HRR – or HRR as I have interpreted it – doesn’t fit well with images. Maybe my homebrew implementation of 2D positional encoding, which I resent having to think about, is no good. Maybe HRR is as helpless as SDPA. All this means little either way, and I’m very interested in others trying. But no one cares! Everyone is still sending SDPA 10% of their income, singing it Happy Birthday, making K-pop–style fancams about it, and so on.

Thoughtful conclusion that ties the themes together and leaves the reader with a good but reflective feeling, highly detailed, trending on ArtStation

With HRR in particular, VSAs generally, and subquadratic attention systems even more generally, I think we might be seeing the beginning of something that revives the part of all this that’s actually interesting to me. With some luck, running a million GPUs for smarmy chatbots and leaking stuff meant to make competition seem hopeless will start to be a less appealing strategy.
