
Persistent Sampling #786

@reubenharry

Description

Presentation of the new sampler

Persistent Sampling (PS) is an extension of SMC, somewhat related to waste-free SMC, that reuses the whole history of previous particles.

The authors of the paper linked above have used it with gradient-free kernels, but they (and I) are interested in implementing it in blackjax, so as to benefit from gradient-based kernels such as HMC and MALA when tackling multimodal problems in higher-dimensional settings.

Key features and benefits of PS highlighted in the paper

  • Leveraging Historical Information: It uses multiple importance sampling (MIS) techniques, treating particles from all prior iterations as samples from a mixture of historical distributions. This creates a richer and more diverse particle pool compared to standard SMC, which typically only uses the most recent particle set.
  • Computational Efficiency: PS achieves its benefits without requiring additional likelihood evaluations. Weights for the persistent (historical) particles are calculated from likelihood values cached at the iterations when those particles were generated (see the sketch after this list). While PS involves more operations per iteration for weight recalculation, the number of computationally expensive likelihood evaluations remains comparable to standard SMC.
  • Improved Accuracy and Stability: This framework leads to more accurate posterior approximations and significantly lower variance in marginal likelihood (evidence) estimates, which is crucial for model comparison.
  • Mitigating SMC Limitations: By using a growing pool of diverse particles, PS directly tackles common SMC issues like particle impoverishment (where resampling leads to many identical particles) and mode collapse in multimodal distributions. The resampling step in PS draws from a much larger pool ((t−1)×N) than standard SMC (N), leading to less correlated, more diverse particles after resampling.
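
To make the reweighting concrete, here is a minimal sketch of the MIS weights under a tempered path, assuming targets of the form pi_s(x) ∝ prior(x) · L(x)^β_s. The function name, argument layout, and the use of running evidence estimates in the mixture denominator are my assumptions for illustration, not blackjax API or the paper's exact notation:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def persistent_log_weights(log_prior, log_lik, betas, log_Z):
    """MIS log-weights for all persistent particles at temperature betas[-1].

    log_prior, log_lik : (M,) cached values for the M = (t-1)*N particles
    betas              : (t,) temperature schedule; betas[-1] is the target
    log_Z              : (t-1,) running log-evidence estimates for past targets
    """
    # Numerator: current tempered target, evaluated from cached likelihoods
    # (no new likelihood evaluations needed).
    log_num = log_prior + betas[-1] * log_lik
    # Denominator: equal-weight mixture of all previous tempered targets,
    # each normalized by its estimated evidence.
    log_components = (log_prior[:, None]
                      + betas[None, :-1] * log_lik[:, None]
                      - log_Z[None, :])
    log_den = logsumexp(log_components, axis=1) - jnp.log(len(log_Z))
    return log_num - log_den
```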

How does it compare to other algorithms in blackjax?

SMC Style: PS is fundamentally an SMC-style algorithm. Like standard SMC in blackjax, it operates by propagating particles through a sequence of intermediate distributions, typically annealing from a prior to the posterior. It requires the prior and likelihood to be provided separately.
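
For reference, the annealed path both algorithms traverse can be written as a one-liner (a sketch; function names are illustrative, not blackjax API):

```python
# beta = 0 gives the prior, beta = 1 the posterior. This separation is why
# blackjax-style SMC takes the prior and likelihood as distinct functions.
def tempered_logdensity(logprior_fn, loglikelihood_fn, beta):
    return lambda x: logprior_fn(x) + beta * loglikelihood_fn(x)
```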

Key Differences from Standard SMC:

  • Particle Reuse: Unlike standard SMC, which discards particles from previous iterations after resampling, PS retains and reweights all past particles at each step.
  • Resampling Pool: PS resamples from a pool of (t−1)×N persistent particles, whereas SMC resamples from N particles. This leads to less particle degeneracy in PS.
  • Kernel Adaptation: PS uses the larger persistent set for adapting MCMC kernels, potentially leading to better adaptation than SMC, especially with limited N or high dimensions.
  • Evidence Estimation: PS computes the marginal likelihood estimate differently, averaging over weights derived from the mixture distribution, and demonstrates significantly lower variance in these estimates compared to standard SMC.
  • Posterior Estimation: PS uses all persistent particles, weighted appropriately, to estimate posterior expectations, whereas standard SMC typically uses only the final particle set (see the sketch below).
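
As a sketch of the last two points, building on `persistent_log_weights` above (helper names are hypothetical, and the evidence formula is one plausible reading of the estimator, not the paper's exact expression):

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def log_evidence_estimate(log_w):
    # Self-normalized IS form: Z_hat = (1/M) * sum_i w_i, computed in log space
    # over the full persistent pool of M = (t-1)*N weights.
    return logsumexp(log_w) - jnp.log(log_w.shape[0])

def persistent_expectation(f_vals, log_w):
    # Posterior expectation over the *full* persistent set, rather than only
    # the final N particles as in standard SMC.
    w = jnp.exp(log_w - logsumexp(log_w))
    return jnp.sum(w * f_vals)
```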

Comparison to Waste-Free SMC (WFSMC)

Both PS and WFSMC resample from a larger pool of particles to improve diversity. However, WFSMC expands its pool by using states generated within the MCMC move step of the current iteration, while PS expands its pool by incorporating the final particle sets from all previous iterations. Experiments in the paper show PS consistently outperforming WFSMC in terms of Mean Squared Error (MSE) for both posterior moments and evidence estimates under matched computational cost. PS is also noted to be compatible with the waste-free idea, suggesting a potential hybrid approach.

Comparison to Recycled SMC (RSMC)

RSMC is a post-processing step applied after a standard SMC run is complete, where particles from intermediate steps are reweighted to target the final posterior. PS incorporates the historical particle information during the sampling process itself, influencing the resampling, moving, and annealing schedule steps. Experiments show PS significantly outperforms RSMC.
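
For contrast, under the same tempered path the RSMC-style recycling weight for an iteration-s particle reduces to a likelihood ratio, since the prior factor cancels (sketch; the names are mine):

```python
def recycle_log_weights(log_lik, beta_s, beta_final):
    # pi_final(x) / pi_s(x) ∝ L(x)**(beta_final - beta_s) under the tempered
    # path; applied once, after the run, to each intermediate particle set.
    return (beta_final - beta_s) * log_lik
```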

Where does it fit in blackjax?

This would likely enter blackjax as a modification or extension of the existing SMC implementation. Key changes would involve the following (a structural sketch comes after the list):

  • Storing all particles $\{\theta_{t'}^{i}\}$ and their cached likelihood values $L(\theta_{t'}^{i})$ for $t' = 1, \dots, t-1$.
  • Implementing the PS reweighting scheme using the mixture distribution denominator.
  • Implementing the PS marginal likelihood estimator.
  • Modifying the resampling step to draw N particles from the full persistent set of (t−1)×N weighted particles.
  • Updating the ESS calculation to use the persistent weights and potentially allowing the target ESS fraction α to exceed 1.
  • (Optionally) Modifying MCMC kernel adaptation routines to use the full persistent set for calculating empirical statistics (e.g., covariance).
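
Putting the pieces together, one PS iteration in the familiar reweight/resample/move pattern might look like this. Everything here (the state layout, the `mcmc_move` signature, where the evidence update happens) is a sketch of my own, reusing `persistent_log_weights` and `log_evidence_estimate` from the snippets above; it is not existing blackjax code:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def ps_step(rng_key, pool, betas, log_Z, n_particles, mcmc_move):
    # pool: dict with 'positions', 'log_prior', 'log_lik' stacked over all
    # previous iterations, i.e. (t-1)*N entries each.
    log_w = persistent_log_weights(pool["log_prior"], pool["log_lik"],
                                   betas, log_Z)
    # Resample N particles from the full persistent pool, not just the
    # latest N as in standard SMC.
    probs = jnp.exp(log_w - logsumexp(log_w))
    key_res, key_move = jax.random.split(rng_key)
    idx = jax.random.choice(key_res, log_w.shape[0], (n_particles,), p=probs)
    # Move the resampled particles with an MCMC kernel (HMC, MALA, ...)
    # targeting the current tempered distribution; the kernel returns fresh
    # positions plus their log-prior and log-likelihood for caching.
    positions, log_prior, log_lik = mcmc_move(
        key_move, pool["positions"][idx], betas[-1])
    # Grow the persistent pool and append the running evidence estimate.
    pool = {
        "positions": jnp.concatenate([pool["positions"], positions]),
        "log_prior": jnp.concatenate([pool["log_prior"], log_prior]),
        "log_lik": jnp.concatenate([pool["log_lik"], log_lik]),
    }
    return pool, jnp.append(log_Z, log_evidence_estimate(log_w))
```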

Given the similarities in the overall structure (reweight, resample, move), modifying the current SMC framework seems feasible.

Are you willing to open a PR?

Yes, a few weeks down the line, most likely.
