Scaling up Test-Time Compute with Latent Reasoning: A Recurrent Depth Approach
We study a novel language model architecture that is capable of scaling test-time computation by implicitly reasoning in latent space. Our model works by iterating a recurrent block, thereby unrolling to arbitrary depth at test-time. This stands in contrast to mainstream reasoning models that scale up compute by producing more tokens. Unlike approaches based on chain-of-thought, our approach does not require any specialized training data, can work with small context windows, and can capture types of reasoning that are not easily represented in words. We scale a proof-of-concept model to 3.5 billion parameters and 800 billion tokens. We show that the resulting model can improve its performance on reasoning benchmarks, sometimes dramatically, up to a computation load equivalent to 50 billion parameters.
Discussion
Host: Hey everyone, welcome back to the podcast! Today we're diving into something really interesting: scaling language models, but with a twist. We're not just talking about making them bigger in terms of parameters; we're talking about how they think and how we can make them think harder at test time. It's about enabling more complex reasoning without necessarily adding more weights.
Guest: Right, what we need now is efficiency. We need models that reason better even within resource constraints.
Host: The paper we're looking at could be a game changer for how LLMs implicitly reason in latent space. It's called 'Scaling up Test-Time Compute with Latent Reasoning: A Recurrent Depth Approach', and the team at UMD is onto something really interesting here.
Guest: Yes. It's a different philosophy: instead of scaling the size, they're scaling computation at test time. It brings a lot to the table. It does not require specialized training data, it trains with a variable compute budget, it can work with small context windows, and it can capture reasoning that is never verbalized.
Host: Absolutely. What stood out to me was the idea of 'thinking in continuous space.' It's analogous to how humans tackle problems, we don't just blurt out answers; there's a lot of internal processing, a 'recurrent firing pattern' happening in our brains before we verbalize anything.
Guest: That's a solid connection. Most reasoning models externalize their thoughts as a chain of thought. This model, on the other hand, 'thinks' in its continuous latent space.
Host: So, let's unpack this 'Recurrent Depth' concept a bit. The paper trains a language model where a recurrent block iterates, creating arbitrary depth at test-time. It's about enabling more computation without a corresponding massive increase in parameters. This is so key to the scalability issue everyone is facing now.
Guest: Precisely, Leo. Think about the Chain-of-Thought (CoT) approach. You need to train your model to verbalize intermediate calculations, which requires special training data and long context windows. The inefficiency is that expensive internal reasoning has to be compressed into a single output token. These models, it appears, avoid that by natively 'thinking' in their continuous latent space, and that opens up a world of potential.
Host: Okay, so how do they actually do this? The architecture, as I understand it, introduces a recurrent unit that runs in a loop, constantly processing and updating its hidden state. That enables the computations to continue indefinitely, right?
Guest: That's right. They build a transformer architecture around a latent depth-recurrent block. During training, the block runs for a randomly sampled number of iterations; at test time, running more iterations improves the model's performance.
Host: And this allows it to compete with larger models, even those with more parameters and training data? Because it's effectively 'thinking harder' with the same resources?
Guest: That’s the key claim, yes. They showed this paradigm can scale and compete with other models that benefit from more params and training data. Moreover, recurrent depth models naturally support interesting features that require substantial tuning and research efforts in non-recurrent models.
Host: Such as?
Guest: The model is able to implement per-token adaptive compute, which is awesome. Then self-speculative decoding, and also KV-cache sharing.
Host: What about visualizing the token paths and trajectories?
Guest: That's the final study. They track the token trajectories in latent space and show that a number of interesting computation behaviors simply emerge with scale, for example, the model rotating shapes in latent space for numerical computations.
Host: That’s mind-blowing. Let's dig deeper into why recurrent depth is a good training strategy. The paper highlights that, compared to standard long-context reasoning, this latent recurrent approach has some serious advantages.
Guest: There's quite a few. First, you don't need to construct bespoke training data. Long context chain of thought requires the model to be trained on demonstrations that are constructed in the domain of interest. That's expensive and time-consuming.
Host: Right, and the model can just train with a variable compute budget, and use standard training data. What else?
Guest: Latent reasoning models require less memory for training and inference than chain-of-thought models, because the latter require long context windows. Separately, recurrent-depth networks perform more FLOPs per parameter than standard transformers, which reduces communication costs between accelerators.
Host: So it uses available hardware much more efficiently. It's leaning into computation rather than solely relying on parameter count. That's a good way to think about it, actually. So, is it also about encouraging a different kind of problem-solving? Like moving away from memorization towards actual 'thinking'?
Guest: That's the hope. The idea is that the model learns meta-strategies, logic, and abstraction instead of memorizing. The recurrent depth approach keeps it small in parameter count but heavy in compute.
Host: Okay, that's compelling. It aligns with what we intuitively understand about intelligence. It's not just about storing information, it's about processing it efficiently and creatively. And philosophically...?
Guest: There are facets of human reasoning that defy verbalization. Spatial thinking, physical intuition, motor planning... Over many iterations of the recurrent process, reasoning in a high-dimensional vector space enables deep exploration of multiple directions, as opposed to linear thinking.
Host: So, it's potentially unlocking a more holistic, less constrained form of reasoning, the kind of stuff that's hard to put into words but that we still rely on all the time. Now, let's talk about the architecture itself. I understand it's built on decoder-only transformer blocks, but with some key distinctions, right?
Guest: Yes. The model is built from decoder-only transformer blocks, but these blocks are organized into three functional groups: the prelude, the core recurrent block, and the coda.
Host: Can you explain each of those?
Guest: The prelude embeds the input data into a latent space. The core recurrent block is the central unit of recurrent computation, which repeatedly modifies the latent state. Finally, the coda block un-embeds from latent space and also contains the prediction head of the model.
Host: So, the core block sits between the prelude and coda, and by looping that core block, you can essentially add 'verses' to your 'song,' as the paper puts it.
Guest: Yes. Given a number of recurrent iterations and a sequence of input tokens, these groups are used to produce output probabilities. First the inputs are embedded by the prelude, then a random initial state is sampled, then the core block is applied repeatedly to the latent state until all iterations are done. At the end, the coda block processes the last state and produces the probabilities of the next token.
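To make the forward pass described here concrete, the following is a minimal Python sketch, not the paper's implementation. The function names (`prelude`, `core_block`, `coda`, `forward`), the toy arithmetic inside each group, and the dimensions are all hypothetical stand-ins; only the order of the four steps comes from the discussion. The core block also receives the embedded input on every iteration, as the guest explains a little later.

```python
import math
import random

def prelude(tokens, dim):
    """Toy embedding: map each token id to a fixed vector (hypothetical stand-in)."""
    return [[math.sin((t + 1) * (j + 1)) for j in range(dim)] for t in tokens]

def core_block(state, embedded):
    """Toy recurrent update: mix the current state with the re-injected embedding."""
    return [[math.tanh(0.5 * s + 0.5 * e) for s, e in zip(s_row, e_row)]
            for s_row, e_row in zip(state, embedded)]

def coda(state, vocab_size):
    """Toy un-embedding: project the last position to logits, then apply softmax."""
    last = state[-1]
    logits = [sum(x * math.cos((v + 1) * (j + 1)) for j, x in enumerate(last))
              for v in range(vocab_size)]
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [x / total for x in exps]

def forward(tokens, num_iterations, dim=8, vocab_size=5, seed=0):
    rng = random.Random(seed)
    embedded = prelude(tokens, dim)                      # 1. embed the inputs
    state = [[rng.gauss(0.0, 1.0) for _ in range(dim)]   # 2. sample a random initial state
             for _ in tokens]
    for _ in range(num_iterations):                      # 3. loop the core block
        state = core_block(state, embedded)
    return coda(state, vocab_size)                       # 4. next-token probabilities

probs = forward(tokens=[1, 2, 3], num_iterations=4)
```

Because `num_iterations` is just a loop bound, the same weights can be run for more or fewer steps at test time, which is the property the rest of the conversation builds on.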
Host: Okay, and why this specific design? What's the motivation?
Guest: They argue this recurrent design is the minimal setup required to learn stable iterative operators. They give the example of gradient descent on a function where 'x' is the variable of interest and 'y' is the data. Note that you need to use y in every step to optimize the function. Similarly, they inject the data again in every step of the recurrence.
Host: Ah, so it's about stability and ensuring the process doesn't just drift off course. If the input was only provided at the start, the iterative process wouldn't be stable, it would depend only on its boundary conditions.
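The gradient descent analogy can be illustrated with a small Python sketch. The quadratic objective, the step size, and the function names here are assumptions chosen for illustration; the point is only that the data `y` enters every update, so the iteration settles on the same answer from any starting point.

```python
def gradient_step(x, y, lr=0.3):
    """One gradient-descent step on f(x) = 0.5 * (x - y)**2; the data y is used in every step."""
    return x - lr * (x - y)

def iterate(x0, y, steps):
    """Apply the same update repeatedly, re-injecting y each time."""
    x = x0
    for _ in range(steps):
        x = gradient_step(x, y)
    return x

# Two very different starting points converge to the same solution, y.
from_above = iterate(x0=10.0, y=2.0, steps=100)
from_below = iterate(x0=-5.0, y=2.0, steps=100)
```

If `y` were only available when choosing `x0`, the update would have nothing to pull the iterate toward, and the result would depend entirely on the starting state.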
Guest: Exactly. The structure of using several layers to embed input tokens into a hidden latent space is based on empirical results analyzing standard fixed-depth transformers. This body of research shows that the initial and end layers of LLMs are noticeably different, whereas middle layers are interchangeable and permutable.
Host: Interesting. So it's leveraging the distinct roles that different layers play in a standard transformer. Now, the paper also touches on a connection to diffusion models. How does that fit in?
Guest: The iterative architecture will look familiar from the other modern iterative modeling paradigm, diffusion models, especially latent diffusion models. They tested iterative schemes even more similar to diffusion models, such as adding noise at each step. However, injecting noise did not help in their preliminary experiments, which is possibly connected to the training objective.
Host: Okay, so they explored those avenues but ultimately found the current approach more effective, at least for their specific goals. What about the more granular design choices within each of these blocks? Like the attention mechanisms and normalization layers?
Guest: Each block contains multiple layers, and each layer contains a standard causal self-attention block using RoPE with a base of 50000, and a gated SiLU MLP. RMSNorm is used as the normalization function. The model has learnable biases on queries and keys. To stabilize the recurrence, they order all layers in the 'sandwich' format.
Host: A 'sandwich' format? What does that mean?
Guest: It's where norm layers sit on both sides of the attention and MLP blocks. They observed that at small scales most normalization strategies work about equally well, but this sandwich normalization was required to train the recurrence at scale.
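One way to picture the 'sandwich' arrangement is the Python sketch below. It assumes one norm before each sub-block and a second norm after the residual addition; that exact placement, the omission of learnable norm scales, and the placeholder `attn` and `mlp` callables are assumptions for illustration, not details confirmed in this conversation.

```python
import math

def rms_norm(x, eps=1e-6):
    """RMSNorm without a learnable scale: divide by the root mean square of x."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

def sandwich_layer(x, attn, mlp):
    """Each sub-block is wrapped by a norm on its input and a norm after the residual add."""
    h = attn(rms_norm(x))                            # norm, then attention
    x = rms_norm([a + b for a, b in zip(x, h)])      # residual add, then norm
    h = mlp(rms_norm(x))                             # norm, then MLP
    return rms_norm([a + b for a, b in zip(x, h)])   # residual add, then norm

# Placeholder sub-blocks; a real layer would use causal self-attention and a gated MLP.
out = sandwich_layer([1.0, -2.0, 3.0, 0.5],
                     attn=lambda v: [2.0 * t for t in v],
                     mlp=lambda v: [t * t for t in v])
```

Because the layer ends with a norm, its output has roughly unit root mean square no matter how many times it is looped, which gives some intuition for why this ordering helps when the same block is applied many times.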
Host: Okay, so specific choices to stabilize training, especially as they scale up the recurrence. And how does the training objective itself encourage this iterative reasoning?
Guest: That's key. To ensure the model functions when scaling up recurrent iterations at test-time, the iteration counts are randomly sampled during training. The loss is optimized in expectation over random samples from the data distribution and random iteration counts drawn from a separate distribution.
Host: So, it's explicitly trained to handle varying amounts of computation. What distribution do they use for those iteration counts?
Guest: They chose it to be a log-normal Poisson distribution. This distribution samples values less than the mean most often, but contains a heavy tail of occasional events where significantly more iterations are taken.
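A log-normal Poisson sampler can be sketched in a few lines of Python. The parameterization below (a normal draw in log space, shifted so the Poisson rate has mean `mean_r`, plus one so at least one iteration always runs) and the values `mean_r=32.0` and `sigma=0.5` are assumptions for illustration, since the conversation does not state them. The Poisson draw uses the standard trick of counting exponential inter-arrival times within one unit of time.

```python
import math
import random

def sample_poisson(rate, rng):
    """Exact Poisson draw: count arrivals of a rate-`rate` process in one unit of time."""
    count, t = 0, rng.expovariate(rate)
    while t < 1.0:
        count += 1
        t += rng.expovariate(rate)
    return count

def sample_iterations(mean_r=32.0, sigma=0.5, rng=None):
    """Log-normal Poisson: draw a rate from a log-normal, then a count from a Poisson."""
    rng = rng or random.Random()
    # Shift by -sigma^2 / 2 so that the expected rate exp(tau) equals mean_r.
    tau = rng.gauss(math.log(mean_r) - 0.5 * sigma ** 2, sigma)
    return sample_poisson(math.exp(tau), rng) + 1   # +1 guarantees at least one iteration

rng = random.Random(0)
counts = [sample_iterations(rng=rng) for _ in range(5000)]
mean_count = sum(counts) / len(counts)
share_below_mean = sum(c < mean_count for c in counts) / len(counts)
```

With these assumed settings, more than half of the draws fall below the sample mean while a few are far larger, matching the guest's description of a distribution that usually samples below the mean but has a heavy tail.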
Host: Okay, so it's biased towards fewer iterations, but with occasional bursts of more intense computation. How do they handle the computational cost of training with this variable iteration count?
Guest: They use truncated backpropagation, backpropagating through only the last 'k' iterations of the recurrent unit. This enables training with the heavy-tailed Poisson distribution, because maximum activation memory and backward compute are independent of the iteration count 'r'. The setup resembles truncated backpropagation through time, as commonly done with RNNs, although this setup is recurrent in depth rather than time.
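The memory argument can be seen in a scalar toy version of truncated backpropagation. The recurrence `tanh(w * state + e)`, the hand-written backward pass, and all names below are hypothetical and chosen so the example runs without a deep learning framework; a real implementation would rely on automatic differentiation and run the early iterations without tracking gradients.

```python
import math

def step(state, w, e):
    """One toy recurrent update with weight w and re-injected input e."""
    return math.tanh(w * state + e)

def truncated_grad(w, e, s0, r, k):
    """Run r steps; backpropagate d(final state)/dw through only the last k of them.

    Returns (final_state, gradient, number_of_saved_states).
    """
    k = min(k, r)
    state = s0
    for _ in range(r - k):           # early steps: nothing is stored for the backward pass
        state = step(state, w, e)
    saved = []                       # only k step inputs are kept, however large r is
    for _ in range(k):
        saved.append(state)
        state = step(state, w, e)
    grad_w, upstream = 0.0, 1.0      # upstream = d(final) / d(output of the current step)
    for s_in in reversed(saved):
        out = step(s_in, w, e)
        local = 1.0 - out * out      # derivative of tanh at this step
        grad_w += upstream * local * s_in
        upstream *= local * w        # push the gradient one step further back
    return state, grad_w, len(saved)

final_short, grad_short, stored_short = truncated_grad(0.8, 0.3, 0.1, r=8, k=4)
final_long, grad_long, stored_long = truncated_grad(0.8, 0.3, 0.1, r=64, k=4)
```

Both calls store four states even though one runs eight times as many iterations, which is why a heavy-tailed iteration count does not blow up the training cost.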
Host: This is interesting! So, they randomly select the number of iterations during training, using something akin to BPTT, but applied in depth. Seems like they've thought of everything! Now that we have a good understanding of the architecture and how it's trained, let's move on and talk more about the actual training runs themselves.