Is Attention Actually Learnable?

Transformers work, but the optimization landscape of self-attention is notoriously non-convex. This paper provides the first polynomial-time guarantee for learning Linear Attention by revealing its hidden convex structure.

Before We Dive In: What You'll Need

  • Polynomial Time: Running time that grows as a polynomial in the input size rather than exponentially; "reasonably fast" in theory, even if the polynomial is large in practice.
  • Guarantees: A mathematical proof that the algorithm succeeds, rather than empirical evidence that it usually does.
  • Feature Map: A fixed transformation of the input into a (typically higher-dimensional) space where the patterns of interest become linear.

Context: If you've trained a neural net, you know it's often trial-and-error. This paper tries to replace luck with math—but at a cost.

1 The Optimization Paradox

Standard self-attention involves learning queries ($Q$) and keys ($K$) that interact multiplicatively. In optimization terms, this creates a non-convex landscape full of local minima. Gradient descent should get stuck. Yet, in practice, it works. Why?
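To see the problem in miniature, consider a toy version of the multiplicative interaction (this example is ours, not the paper's): fit a single scalar score $qk$ to a target. The squared loss is already non-convex in $(q, k)$.

```python
import numpy as np

# Toy example (not from the paper): fit a scalar "attention score" q*k to a target of 1.
def loss(q, k):
    return 0.5 * (q * k - 1.0) ** 2

# (1, 1) and (-1, -1) are both global minima, yet their midpoint (0, 0) has higher
# loss than their average -- impossible for a convex function.
print(loss(1, 1), loss(-1, -1), loss(0, 0))   # -> 0.0 0.0 0.5

# The Hessian at the origin, [[0, -1], [-1, 0]], has a negative eigenvalue:
# (0, 0) is a saddle point, the kind of critical point gradient descent can stall near.
print(np.linalg.eigvalsh(np.array([[0.0, -1.0], [-1.0, 0.0]])))   # -> [-1.  1.]
```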

The Landscape Problem

[Interactive demo: switch between a convex and a non-convex loss landscape, then drop a "parameter" ball to see where gradient descent ends up.]

2 The Cubic Feature Map

The authors observe that Multi-Head Linear Attention (MHLA) can be expanded into a polynomial. Specifically, the output for a token is a sum of cubic interactions between input tokens.

The Feature Map $\Psi(Z)$

Instead of learning $Q$ and $K$ directly, we can construct a fixed feature map $\Psi(Z)$ that captures all possible cubic interactions:

$$ \Psi(Z)_{j,(k,\ell)} = \langle z_{j:}, z_{k:} \rangle z_{\ell n} $$

This maps the input $Z$ to a higher-dimensional space $\mathbb{R}^{d \times d^2}$.

By lifting the problem to this space, the non-linear attention mechanism becomes linear with respect to a new parameter matrix $W$.
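Here is a minimal numpy sketch of one plausible reading of this feature map, assuming $Z \in \mathbb{R}^{d \times N}$ with tokens as columns and the final column as the query token $z_{:,n}$ (the shapes and indexing conventions are assumptions, not the paper's code):

```python
import numpy as np

def cubic_feature_map(Z: np.ndarray) -> np.ndarray:
    """Sketch of Psi(Z)[j, (k, l)] = <z_{j:}, z_{k:}> * z_{l, n}.

    Assumes Z has shape (d, N): rows are feature dimensions, columns are
    tokens, and the last column is the query token. Returns shape (d, d*d).
    """
    d, N = Z.shape
    gram = Z @ Z.T            # (d, d) inner products <z_{j:}, z_{k:}>
    z_last = Z[:, -1]         # (d,) coordinates of the final (query) token
    psi = gram[:, :, None] * z_last[None, None, :]   # (d, d, d), indexed (j, k, l)
    return psi.reshape(d, d * d)

# Example: d = 4 features, N = 6 tokens.
Z = np.random.randn(4, 6)
print(cubic_feature_map(Z).shape)   # -> (4, 16)
```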

3 The Convex Relaxation

Here is the key trick. The original problem constrains the rank of the attention matrix (limited by the number of heads $H$). This rank constraint causes the non-convexity.

The authors relax this constraint. Instead of solving for $H$ heads directly, they solve for a full-rank matrix $W \in \mathbb{R}^{d^2 \times d^2}$.

Original Problem

$\min_{\Theta} L(\Theta)$ over the per-head attention parameters $\Theta$

Constrained rank. Non-convex. Many local minima. Hard to learn.

Relaxed Problem

$\min_{W} \frac{1}{2} \sum_i \| y_i - W^\top \Psi(Z_i) \|^2$

Unconstrained rank. Convex (Ridge Regression). Global minimum guaranteed.
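Why does constraining the rank destroy convexity? A one-line illustration (a standard fact, not specific to this paper): the set of matrices with rank at most $H$ is not convex, because averaging two rank-one matrices generically yields a rank-two matrix,

$$ \operatorname{rank}\!\left( \tfrac{1}{2} u_1 v_1^\top + \tfrac{1}{2} u_2 v_2^\top \right) = 2 \quad \text{for generic } u_1, v_1, u_2, v_2. $$

Optimizing over $\{ W : \operatorname{rank}(W) \le H \}$ is therefore a non-convex problem, while optimizing over all of $\mathbb{R}^{d^2 \times d^2}$ is unconstrained least squares, which is convex.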

Math Deep Dive: Algorithm 1

The learning algorithm proceeds in three steps:

  1. Feature Construction: Compute $\Psi(Z_i)$ for all training examples.
  2. Regression: Solve the closed-form Ridge Regression solution to find $W^*$: $$ W^* = (\sum \Psi(Z_i)\Psi(Z_i)^\top + \lambda I)^{-1} \sum \Psi(Z_i)y_i^\top $$
  3. Decomposition: Perform SVD on $W^*$ to recover the attention parameters $V_h, Q_h$. $$ W^* \approx \sum_{h=1}^H \sigma_h u_h v_h^\top $$

    The recovered model has at most $d^2$ heads (the maximum rank of $W$). This is why "width" (number of heads) is key to the convex relaxation.

This guarantees finding the global optimum of the relaxed loss in polynomial time $O(n^3 + d^6)$.
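Here is a compact sketch of the three steps, reusing the `cubic_feature_map` sketch from above. The flattening, index grouping, and regularization choices are assumptions made to keep the example runnable, not the paper's exact pseudocode:

```python
import numpy as np

def learn_mhla(Zs, ys, lam=1e-3, num_heads=None):
    """Sketch of Algorithm 1: ridge regression over cubic features, then SVD."""
    d = Zs[0].shape[0]

    # Step 1: feature construction -- flatten each Psi(Z_i) into a d^3 vector.
    Phi = np.stack([cubic_feature_map(Z).ravel() for Z in Zs])      # (n, d^3)
    Y = np.stack(ys)                                                # (n, d)

    # Step 2: closed-form ridge regression, W_flat = (Phi^T Phi + lam I)^{-1} Phi^T Y.
    W_flat = np.linalg.solve(Phi.T @ Phi + lam * np.eye(d**3), Phi.T @ Y)  # (d^3, d)

    # Step 3: reshape into a d^2 x d^2 matrix (index grouping is an assumption)
    # and take its SVD; the top-H singular triplets are the recovered per-head factors.
    W = W_flat.reshape(d, d, d, d).transpose(3, 0, 1, 2).reshape(d * d, d * d)
    U, S, Vt = np.linalg.svd(W)
    H = num_heads if num_heads is not None else int(np.sum(S > 1e-10))
    heads = [(S[h], U[:, h], Vt[h]) for h in range(H)]
    return W, heads
```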

💡 The Key Insight

By allowing up to $d^2$ heads (more than typically used), we remove the rank constraint that causes non-convexity. The relaxed problem becomes convex and solvable in polynomial time. We then recover attention parameters via SVD—no gradient descent needed.

4 Theoretical Guarantees

Historical Context: Prior work showed Transformers can express complex computations (Turing machines, formal languages), but not how to learn them efficiently. This paper provides the first efficient learning algorithm with provable guarantees.

Agnostic PAC Learning

The algorithm works even when the true function isn't an MHLA (agnostic setting)—it competes with the best MHLA approximation.

Sample Complexity

To achieve error $\epsilon$ with probability $1-\delta$, the algorithm requires $O(\frac{1}{\epsilon}(d^4 + \log(\delta^{-1})))$ samples.
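Combined with the agnostic guarantee above, the statement has roughly the following shape (a paraphrase of this summary, not the paper's exact theorem), where $L_{\mathcal{D}}$ is the population loss and $\hat{f}$ is the model returned by Algorithm 1:

$$ \Pr\left[ L_{\mathcal{D}}(\hat{f}) \le \min_{f \in \mathrm{MHLA}} L_{\mathcal{D}}(f) + \epsilon \right] \ge 1 - \delta \quad \text{once} \quad n = O\!\left(\tfrac{1}{\epsilon}\left(d^4 + \log(\delta^{-1})\right)\right). $$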

Identifiability

If the minimum eigenvalue of $\Lambda_D = \mathbb{E}[H(Z)H(Z)^\top] > 0$, where $H(Z)$ is a specific polynomial feature map, all empirical risk minimizers compute the same function. This is a "truth serum" for the model.
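One way to probe this condition empirically is to estimate the smallest eigenvalue of the empirical feature covariance. A sketch, with `feature_map` standing in for the paper's polynomial map $H(Z)$ (whose exact form is not spelled out here):

```python
import numpy as np

def identifiability_check(Zs, feature_map, tol=1e-8):
    """Estimate the minimum eigenvalue of the empirical feature covariance.

    A value bounded away from zero suggests the identifiability condition
    min-eig(E[H(Z) H(Z)^T]) > 0 holds on this sample. `feature_map` is a
    placeholder for the paper's polynomial map H(Z).
    """
    feats = np.stack([feature_map(Z).ravel() for Z in Zs])   # (n, D)
    cov = feats.T @ feats / len(Zs)                          # (D, D) empirical covariance
    lam_min = np.linalg.eigvalsh(cov)[0]                     # smallest eigenvalue
    return lam_min, bool(lam_min > tol)
```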

Universal Turing Machines

Given input-output traces of a Turing machine's computation, the algorithm can learn to simulate any TM (up to polynomial bounds) on any input—a form of "algorithm learning from data."

🐘 The Elephant in the Room

Q: "If this is solvable in polynomial time, why don't we train all models this way?"

A: Because "polynomial" can still be enormous. The runtime scales as $O(Nd^4)$, and $d^4$ grows brutally with model width: at GPT-3 scale ($d = 12{,}288$), $d^4$ is roughly $2 \times 10^{16}$.

Cost ≈ 300B tokens × (12,288)⁴ ≈ 7×10²⁷ operations
GPT-3 Training ≈ 3×10²³ FLOPs
Difference: roughly 20,000x more expensive
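A quick sanity check of that back-of-envelope estimate (assuming, as above, that the cost scales like $N \cdot d^4$):

```python
# Back-of-envelope check of the comparison above.
N, d = 300e9, 12288               # training tokens, model dimension (GPT-3 scale)
convex_cost = N * d ** 4          # ~7e27 operations
gpt3_flops = 3e23                 # reported GPT-3 training compute
print(f"{convex_cost:.1e} ops, about {convex_cost / gpt3_flops:.0f}x GPT-3's training FLOPs")
```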

Q: "So this is useless?"

A: Not quite! It's a theoretical foundation. Just as matrix multiplication was assumed to require $O(n^3)$ operations until Strassen's algorithm showed otherwise, this result proves a path exists; faster algorithms may follow.

Critical Analysis & Limitations

  • Dimensionality Explosion: The feature map $\Psi(Z)$ has dimension $d \times d^2$. For large $d$, this is computationally prohibitive.
  • Linear vs. Softmax: The result applies to Linear Attention. Most frontier models still use Softmax Attention, which remains non-convex.
  • Sample Complexity Gap: The theory requires $O(d^4)$ samples. For BERT ($d=768$), that's $\approx 3.5 \times 10^{11}$ samples, while BERT was trained on roughly $3 \times 10^9$ words. The theory demands about 100x more data than practice used.

🔬 What Would Convince Me?

As a skeptical reader, here's what I'd want to see next:

  • Experiment: Learn a small MHLA on real language data (Penn Treebank) and compare to SGD.
  • Ablation: Can we approximate $\Psi(Z)$ with sparse features to make it tractable?
  • Analysis: Does real language data satisfy the identifiability condition?