
Parallelized rejection sampler #790

@Vilin97


Presentation of the new sampler

I am surprised BlackJAX does not have a rejection sampler -- this is usually the first thing I try before jumping to MCMC-based ones. A barebones implementation can look something like this. One benefit is that all candidates are proposed and accepted or rejected in a single vectorized pass, so as long as your GPU can fit the arrays in question, it takes O(1) sequential steps, which is nice :).

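import jax.random as jrandom

def rejection_sample_parallel(key, density_fn, proposal_sample, proposal_density, max_ratio, num_samples=1):
    # Assumes density_fn(x) <= max_ratio * proposal_density(x) for all x, and that
    # density_fn and proposal_density are vectorized over the leading (batch) axis.
    key_propose, key_accept = jrandom.split(key)
    # sample twice the needed-in-expectation amount (assumes density_fn integrates to 1)
    num_candidates = int(num_samples * max_ratio * 2)
    candidates = proposal_sample(key_propose, num_candidates)
    proposal_values = proposal_density(candidates)
    target_values = density_fn(candidates)
    
    # Accept with probability target / (max_ratio * proposal)
    u = jrandom.uniform(key_accept, (num_candidates,))  # shape must be a tuple
    accepted = u * max_ratio * proposal_values <= target_values
    # Boolean-mask indexing has a data-dependent shape: fine eagerly, not under jax.jit
    samples = candidates[accepted]
    
    # May hold fewer than num_samples rows if too few candidates were accepted
    return samples[:num_samples]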
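The line `samples = candidates[accepted]` uses boolean-mask indexing, whose output shape depends on the data, so that version cannot be wrapped in `jax.jit`. Below is a minimal fixed-shape sketch; `rejection_sample_fixed` and its `num_candidates` parameter are hypothetical names (not a BlackJAX API). It uses `jnp.nonzero(..., size=...)` to pick the first `num_samples` accepted candidates and returns a `valid` mask flagging padded slots when too few were accepted. The Beta(2, 2) target with a uniform proposal (`max_ratio = 1.5`) is just an illustrative choice.

```python
from functools import partial

import jax
import jax.numpy as jnp


def rejection_sample_fixed(key, density_fn, proposal_sample, proposal_density,
                           max_ratio, num_samples, num_candidates):
    """Hypothetical fixed-shape rejection sampler that can run under jax.jit."""
    key_propose, key_accept = jax.random.split(key)
    candidates = proposal_sample(key_propose, num_candidates)
    u = jax.random.uniform(key_accept, (num_candidates,))
    accepted = u * max_ratio * proposal_density(candidates) <= density_fn(candidates)
    # Indices of the first num_samples accepted candidates, padded with index 0.
    (idx,) = jnp.nonzero(accepted, size=num_samples, fill_value=0)
    # valid[i] is False for padded slots (fewer than num_samples acceptances).
    valid = jnp.arange(num_samples) < jnp.sum(accepted)
    return candidates[idx], valid


def beta22_pdf(x):
    return 6.0 * x * (1.0 - x)  # Beta(2, 2) density on [0, 1]; its maximum is 1.5


def uniform_sample(key, n):
    return jax.random.uniform(key, (n,))


def uniform_pdf(x):
    return jnp.ones_like(x)


# Bind the callables and the shape-determining integers, then jit over the key only.
sample_beta = jax.jit(partial(
    rejection_sample_fixed,
    density_fn=beta22_pdf, proposal_sample=uniform_sample,
    proposal_density=uniform_pdf, max_ratio=1.5,
    num_samples=1000, num_candidates=3000,
))
samples, valid = sample_beta(jax.random.PRNGKey(0))
```

Because `num_samples` and `num_candidates` fix the array shapes, they have to be static (here they are baked in with `functools.partial`); changing them triggers a recompile.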

It could also auto-detect max_ratio if the user does not provide it, auto-select the proposal (e.g. always uniform, or something more sophisticated such as the best member of some family), and handle the unlikely event that the candidate batch yields fewer than num_samples acceptances.
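Two of those extensions can be sketched roughly as follows, under stated assumptions. `estimate_max_ratio` and `rejection_sample_until` are hypothetical helpers, not existing BlackJAX functions. The max_ratio estimate is only a heuristic (the largest observed p/q over proposal draws, inflated by a `safety` factor): if the true supremum is larger, the output is silently biased. The retry loop keeps drawing batches, each sized at twice the expected number still needed, until enough samples are accepted; like the original, that sizing assumes `density_fn` is normalized.

```python
import jax
import jax.numpy as jnp


def estimate_max_ratio(key, density_fn, proposal_sample, proposal_density,
                       num_probe=10_000, safety=1.5):
    """Heuristic estimate of sup_x p(x) / q(x) from proposal draws (hypothetical)."""
    x = proposal_sample(key, num_probe)
    ratios = density_fn(x) / proposal_density(x)
    # Inflate the empirical maximum; this reduces, but does not remove, the risk
    # of underestimating the true supremum.
    return float(safety * jnp.max(ratios))


def rejection_sample_until(key, density_fn, proposal_sample, proposal_density,
                           max_ratio, num_samples, max_rounds=100):
    """Draw batches until num_samples candidates are accepted (hypothetical)."""
    collected, total = [], 0
    for _ in range(max_rounds):
        key, key_propose, key_accept = jax.random.split(key, 3)
        # Propose twice the expected number of candidates still needed.
        num_candidates = max(1, int((num_samples - total) * max_ratio * 2))
        x = proposal_sample(key_propose, num_candidates)
        u = jax.random.uniform(key_accept, (num_candidates,))
        accepted = u * max_ratio * proposal_density(x) <= density_fn(x)
        batch = x[accepted]  # data-dependent shape: runs eagerly, outside jit
        collected.append(batch)
        total += batch.shape[0]
        if total >= num_samples:
            return jnp.concatenate(collected)[:num_samples]
    raise RuntimeError("too few samples accepted after max_rounds batches")


# Illustrative target: Beta(2, 2) on [0, 1] with a uniform proposal.
target = lambda x: 6.0 * x * (1.0 - x)
propose = lambda key, n: jax.random.uniform(key, (n,))
proposal_pdf = lambda x: jnp.ones_like(x)

key_est, key_draw = jax.random.split(jax.random.PRNGKey(0))
max_ratio = estimate_max_ratio(key_est, target, propose, proposal_pdf)
samples = rejection_sample_until(key_draw, target, propose, proposal_pdf,
                                 max_ratio, num_samples=500)
```

An overestimated max_ratio only costs efficiency, while an underestimated one costs correctness, which is why the sketch errs on the high side.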

How does it compare to other algorithms in blackjax?

Performance -- great in 1-D, but usually horrible in many dimensions, since the acceptance rate tends to decay exponentially with the dimension
Complexity -- dead simple
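A standard toy calculation (an illustration, not from the original issue) shows the dimension problem. Take the target $\mathcal{N}(0, I_d)$ and proposal $\mathcal{N}(0, \sigma^2 I_d)$ with $\sigma > 1$. Then $p(x)/q(x) = \sigma^d \exp\!\big(-\tfrac{\|x\|^2}{2}(1 - \sigma^{-2})\big)$, which is maximized at $x = 0$, so max_ratio $= \sigma^d$ and the acceptance rate is $\sigma^{-d}$. Even with $\sigma = 1.1$ that is about $0.39$ at $d = 10$ and under $0.01$ at $d = 50$. The sketch below checks the formula empirically; `acceptance_rate` is a hypothetical helper, and it works in log space to avoid underflow.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def acceptance_rate(key, d, sigma=1.1, num_candidates=200_000):
    """Empirical acceptance rate for target N(0, I_d), proposal N(0, sigma^2 I_d)."""
    key_propose, key_accept = jax.random.split(key)
    x = sigma * jax.random.normal(key_propose, (num_candidates, d))
    log_p = norm.logpdf(x).sum(axis=1)
    log_q = norm.logpdf(x, scale=sigma).sum(axis=1)
    # sup_x p(x) / q(x) = sigma^d, attained at x = 0
    log_max_ratio = d * jnp.log(sigma)
    log_u = jnp.log(jax.random.uniform(key_accept, (num_candidates,)))
    return float(jnp.mean(log_u + log_max_ratio + log_q <= log_p))


for d in (1, 10, 50):
    print(d, acceptance_rate(jax.random.PRNGKey(d), d), 1.1 ** -d)
```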

Where does it fit in blackjax

A sampling library should have simple sampling algos in addition to sophisticated ones. Speaking of which -- what about inverse transform sampling?
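For 1-D targets, inverse transform sampling can be approximated on a grid. This is a minimal sketch under stated assumptions: `inverse_transform_sample` is a hypothetical helper, the target lives on a known interval `[lo, hi]`, and the density is positive in the interior so the tabulated CDF is strictly increasing. It builds the CDF with the cumulative trapezoid rule (the density may be unnormalized) and inverts it with `jnp.interp`.

```python
import jax
import jax.numpy as jnp


def inverse_transform_sample(key, density_fn, lo, hi, num_samples, num_grid=10_001):
    """Approximate inverse-CDF sampling for a 1-D density on [lo, hi] (hypothetical)."""
    grid = jnp.linspace(lo, hi, num_grid)
    pdf = density_fn(grid)
    # Cumulative trapezoid rule; the grid spacing cancels after normalization.
    increments = 0.5 * (pdf[1:] + pdf[:-1])
    cdf = jnp.concatenate([jnp.zeros(1), jnp.cumsum(increments)])
    cdf = cdf / cdf[-1]
    u = jax.random.uniform(key, (num_samples,))
    # Map each uniform draw u back to x by interpolating the tabulated CDF.
    return jnp.interp(u, cdf, grid)


# Illustrative target: Beta(2, 2) on [0, 1] (mean 0.5, variance 0.05).
samples = inverse_transform_sample(
    jax.random.PRNGKey(0), lambda x: 6.0 * x * (1.0 - x), 0.0, 1.0, 10_000
)
```

Unlike rejection sampling, every uniform draw yields a sample, but the grid makes this an approximation and it does not extend gracefully beyond a few dimensions either.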

Are you willing to open a PR?

Yes
