Diffusion models have emerged as effective distribution estimators in vision, language, and reinforcement learning, but their use as priors in downstream tasks poses an intractable posterior inference problem. This paper studies amortized sampling of the posterior over data, $\mathbf{x}\sim p^\text{post}(\mathbf{x})\propto p(\mathbf{x})r(\mathbf{x})$, in a model that consists of a diffusion generative model prior $p(\mathbf{x})$ and a black-box constraint or likelihood function $r(\mathbf{x})$. We state and prove the asymptotic correctness of a data-free learning objective, relative trajectory balance, for training a diffusion model that samples from this posterior, a problem that existing methods solve only approximately or in restricted cases. Relative trajectory balance arises from the generative flow network perspective on diffusion models, which allows the use of deep reinforcement learning techniques to improve mode coverage. Experiments illustrate the broad potential of unbiased inference of arbitrary posteriors under diffusion priors: in vision (classifier guidance), language (infilling under a discrete diffusion LLM), and multimodal data (text-to-image generation). Beyond generative modeling, we apply relative trajectory balance to the problem of continuous control with a score-based behavior prior, achieving state-of-the-art results on benchmarks in offline reinforcement learning.
Given a diffusion model prior \( p(\mathbf{x}) \) and a black-box likelihood function \( r(\mathbf{x}) \), our goal is to sample from the posterior \( p^{\text{post}}(\mathbf{x}) \propto p(\mathbf{x}) r(\mathbf{x}) \). Conventional approaches often rely on heuristic guidance, which introduces bias or restricts applicability. In contrast, we derive a principled objective for posterior sampling, rooted in the Generative Flow Network (GFlowNet) perspective: it is asymptotically unbiased, requires no data, and admits off-policy training for improved mode coverage.
The Relative Trajectory Balance (RTB) objective enforces that, for every denoising trajectory \( \tau = (\mathbf{x}_T, \ldots, \mathbf{x}_0) \), the ratio of the trajectory probability under the posterior model \( p_\phi^{\text{post}} \) to that under the prior model \( p_\theta \) is proportional to the constraint function \( r(\mathbf{x}_0) \). This is achieved by minimizing the loss
\[
\mathcal{L}_{\text{RTB}}(\tau; \phi) = \left( \log \frac{Z_\phi \prod_{t=1}^{T} p_\phi^{\text{post}}(\mathbf{x}_{t-1} \mid \mathbf{x}_t)}{r(\mathbf{x}_0) \prod_{t=1}^{T} p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)} \right)^2 .
\]
Here, \( Z_{\phi} \) is a learnable scalar that estimates the normalization constant of the posterior. Satisfying the RTB constraint (driving the loss to zero) for all diffusion trajectories guarantees unbiased sampling from the desired posterior distribution \( p^{\text{post}}(\mathbf{x}) \propto p_\theta(\mathbf{x}) r(\mathbf{x}) \).
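As a concrete illustration, here is a minimal PyTorch sketch of the RTB loss for a batch of denoising trajectories, assuming the per-step log-probabilities under both models have already been computed. The names (`rtb_loss`, `log_pf_post`, `log_pf_prior`, `log_r`, `log_Z`) are illustrative, not taken from the paper's code.

```python
import torch

def rtb_loss(log_pf_post, log_pf_prior, log_r, log_Z):
    """Relative trajectory balance loss for a batch of denoising trajectories.

    log_pf_post:  (B, T) log p^post_phi(x_{t-1} | x_t) at each denoising step
    log_pf_prior: (B, T) log p_theta(x_{t-1} | x_t) under the frozen prior
    log_r:        (B,)   log r(x_0) of the terminal samples
    log_Z:        ()     learnable scalar estimating the log-normalization constant
    """
    # log of the ratio  Z_phi * p_phi^post(tau) / (r(x_0) * p_theta(tau));
    # the RTB constraint holds exactly when this is zero for every trajectory
    delta = log_Z + log_pf_post.sum(dim=1) - log_pf_prior.sum(dim=1) - log_r
    return (delta ** 2).mean()

# log_Z is trained jointly with the parameters of the posterior denoiser
log_Z = torch.nn.Parameter(torch.zeros(()))
```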
Because RTB is an off-policy objective, the trajectories used for training need not be sampled from the current posterior model, which enables better exploration and mode coverage. Useful strategies include replay buffers and local search.
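The sketch below shows one way such off-policy training could be organized, reusing `rtb_loss` from the snippet above. The `posterior.sample_trajectory()` and `.log_prob()` methods, the buffer policy, and `p_replay` are illustrative assumptions rather than the paper's implementation.

```python
import random

def train_rtb(posterior, prior, log_reward, optimizer, log_Z,
              n_iters=10_000, p_replay=0.5):
    """Schematic off-policy RTB training loop with a replay buffer (illustrative API)."""
    buffer = []  # stores full denoising trajectories (x_T, ..., x_0)
    for _ in range(n_iters):
        if buffer and random.random() < p_replay:
            # off-policy update: replay a previously visited (e.g. high-reward) trajectory
            traj = random.choice(buffer)
        else:
            # exploratory rollout of the current posterior model
            traj = posterior.sample_trajectory()
            buffer.append(traj)

        # re-evaluate per-step log-probabilities of the trajectory under both models
        log_pf_post = posterior.log_prob(traj)    # (B, T), differentiable
        log_pf_prior = prior.log_prob(traj)       # (B, T), frozen prior
        loss = rtb_loss(log_pf_post, log_pf_prior, log_reward(traj[-1]), log_Z)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```

A local-search variant could, for instance, refine high-reward terminal samples (partially re-noising and re-denoising them) before adding the resulting trajectories to the buffer.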
We finetune unconditional diffusion priors for class-conditional generation on the MNIST and CIFAR-10 datasets. Starting from pretrained unconditional models \( p_\theta(x) \), we apply the RTB objective to adapt the priors to sample from posteriors conditioned on a class label \( c \), using the classifier probability \( r(x) = p(c | x) \) as the constraint during finetuning. The figure shows a selection of results: RTB effectively balances reward maximization and sample diversity, whether conditioning on a single class or on a multimodal constraint (e.g., even digits).
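For such classifier-defined constraints, the log-reward could be computed roughly as follows; `make_log_reward` and `classifier` are illustrative placeholders. In the multimodal case (e.g., all even digits), \( r(x) \) is simply the total classifier probability of the allowed labels.

```python
import torch
import torch.nn.functional as F

def make_log_reward(classifier, target_classes):
    """Constraint log r(x) = log p(c | x) from a pretrained classifier (illustrative).

    `target_classes` may contain several labels (e.g. the even MNIST digits),
    in which case r(x) is the total probability mass of the allowed classes."""
    def log_reward(x0):
        logits = classifier(x0)                      # (B, num_classes)
        log_probs = F.log_softmax(logits, dim=-1)    # log p(c | x)
        return torch.logsumexp(log_probs[:, target_classes], dim=-1)
    return log_reward

# e.g. a posterior over even MNIST digits:
# log_r = make_log_reward(mnist_classifier, [0, 2, 4, 6, 8])
```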
We apply RTB to KL-regularized finetuning of text-to-image diffusion priors (stable-diffusion-1.5) with a reward model trained on human preferences (ImageReward). RTB-optimized posteriors achieve high reward while maintaining diversity in the generated images. In the images shown below for different text prompts, the first row shows samples from the diffusion prior, the second row shows biased posteriors finetuned with KL-regularized RL (DPOK), and the third row shows posteriors finetuned with RTB.
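In KL-regularized finetuning, the optimal policy is the prior reweighted by an exponentiated, tempered reward, so the constraint takes the form \( r(x) = \exp(\beta \cdot \text{score}(x, \text{prompt})) \), where \( \beta \) corresponds to the inverse KL-regularization strength. A minimal sketch, with `reward_model` standing in for however ImageReward scores are queried:

```python
import torch

def make_log_reward_from_preferences(reward_model, beta=1.0):
    """log r(x) = beta * score(prompt, image); `reward_model` is a placeholder callable."""
    def log_reward(prompts, images):
        with torch.no_grad():
            scores = reward_model(prompts, images)   # (B,) human-preference scores
        return beta * scores
    return log_reward
```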
RTB is generally applicable to hierarchical generative models, including discrete diffusion. We apply RTB to infilling stories with a discrete diffusion model prior, outperforming finetuned autoregressive models for this task.
An important problem in offline RL is KL-regularized policy extraction, using the behavior policy as a prior together with a Q-function trained by an off-the-shelf Q-learning algorithm. Diffusion policies are expressive and can model highly multimodal behavior policies. Given such a diffusion prior \(\mu(a|s)\) and a Q-function \(Q(s,a)\) trained with IQL, we use RTB to obtain the KL-regularized optimal policy of the form \(\pi^*(a|s) \propto \mu(a|s)e^{Q(s,a)}\). We match state-of-the-art results on the D4RL benchmark.
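Under the RTB formulation, this is a state-conditional posterior over actions with constraint \( \log r(a; s) = Q(s, a) \) (a temperature can be absorbed into \( Q \)). A minimal sketch, where `q_network`, the temperature `alpha`, and the per-state normalizer network are illustrative assumptions:

```python
import torch
import torch.nn as nn

def make_log_reward_offline_rl(q_network, alpha=1.0):
    """Constraint for pi*(a|s) ∝ mu(a|s) exp(Q(s,a) / alpha); `q_network` is a frozen IQL critic."""
    def log_reward(states, actions):
        with torch.no_grad():
            q = q_network(states, actions)   # (B,)
        return q / alpha                     # log r(a; s) = Q(s, a) / alpha
    return log_reward

# Because the posterior is conditioned on the state, the normalization constant is
# state-dependent; one option is to predict log Z_phi(s) with a small network.
state_dim = 17  # e.g. a D4RL locomotion task; illustrative
log_Z_net = nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU(), nn.Linear(64, 1))
```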
@inproceedings{venkatraman2024amortizing,
  title={Amortizing intractable inference in diffusion models for vision, language, and control},
  author={Siddarth Venkatraman and Moksh Jain and Luca Scimeca and Minsu Kim and Marcin Sendera and Mohsin Hasan and Luke Rowe and Sarthak Mittal and Pablo Lemos and Emmanuel Bengio and Alexandre Adam and Jarrid Rector-Brooks and Yoshua Bengio and Glen Berseth and Nikolay Malkin},
  booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
  year={2024},
  url={https://openreview.net/forum?id=gVTkMsaaGI}
}