Description
Presentation of the new sampler
I am surprised BlackJax does not have a rejection sampler -- this is usually the first thing I try before jumping to MCMC-based ones. A barebones implementation could look something like this. One benefit is that it runs in a single vectorized pass, i.e. O(1) sequential steps, as long as your GPU can fit the arrays in question, so that's nice :).
import jax.random as jrandom

def rejection_sample_parallel(key, density_fn, proposal_sample, proposal_density, max_ratio, num_samples=1):
    key_propose, key_accept = jrandom.split(key)
    # sample twice the needed-in-expectation amount
    num_candidates = int(num_samples * max_ratio * 2)
    candidates = proposal_sample(key_propose, num_candidates)
    proposal_values = proposal_density(candidates)
    target_values = density_fn(candidates)
    # Accept with probability target / (max_ratio * proposal)
    uniforms = jrandom.uniform(key_accept, (num_candidates,))  # shape is passed as a tuple
    accepted = uniforms * max_ratio * proposal_values <= target_values
    # Boolean-mask indexing has a data-dependent shape: fine eagerly, but not under jax.jit
    samples = candidates[accepted]
    # May return fewer than num_samples if too few candidates were accepted
    return samples[:num_samples]
It could also auto-detect max_ratio if the user does not supply it, and even auto-select the proposal (e.g. always uniform, or something more sophisticated such as finding the best one from some family), and handle the improbable event that fewer than num_samples candidates are accepted.
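To make those two extensions concrete, here is a rough sketch of how they might look. Everything in it is an assumption for illustration, not an existing BlackJax API: the helper names `estimate_max_ratio`, `rejection_sample_batch` and `rejection_sample`, the `safety` and `max_rounds` parameters, and the Beta(2, 2) target with a uniform proposal. Note that a grid-based estimate of max_ratio is only a heuristic: if it underestimates the true bound, the samples are biased.

```python
import jax.numpy as jnp
import jax.random as jrandom


def estimate_max_ratio(density_fn, proposal_density, grid, safety=1.1):
    """Heuristic bound: max of target/proposal over a user-supplied grid, inflated by `safety`."""
    ratios = density_fn(grid) / proposal_density(grid)
    return float(safety * jnp.max(ratios))


def rejection_sample_batch(key, density_fn, proposal_sample, proposal_density,
                           max_ratio, num_candidates):
    """One vectorized accept/reject pass; returns only the accepted candidates."""
    key_propose, key_accept = jrandom.split(key)
    candidates = proposal_sample(key_propose, num_candidates)
    uniforms = jrandom.uniform(key_accept, (num_candidates,))
    accepted = uniforms * max_ratio * proposal_density(candidates) <= density_fn(candidates)
    return candidates[accepted]  # data-dependent shape, so this runs eagerly (no jit)


def rejection_sample(key, density_fn, proposal_sample, proposal_density,
                     max_ratio, num_samples, max_rounds=100):
    """Repeat passes until num_samples have been accepted, covering the shortfall case."""
    batches, collected = [], 0
    for _ in range(max_rounds):
        key, subkey = jrandom.split(key)
        # Propose twice the expected number of candidates still needed.
        num_candidates = int(2 * max_ratio * (num_samples - collected)) + 1
        batch = rejection_sample_batch(subkey, density_fn, proposal_sample,
                                       proposal_density, max_ratio, num_candidates)
        batches.append(batch)
        collected += batch.shape[0]
        if collected >= num_samples:
            return jnp.concatenate(batches)[:num_samples]
    raise RuntimeError("too few acceptances; max_ratio may be far too large")


# Illustrative target: Beta(2, 2) density on [0, 1], with a uniform proposal.
density_fn = lambda x: 6.0 * x * (1.0 - x)
proposal_sample = lambda key, n: jrandom.uniform(key, (n,))
proposal_density = lambda x: jnp.ones_like(x)

max_ratio = estimate_max_ratio(density_fn, proposal_density, jnp.linspace(0.0, 1.0, 1001))
samples = rejection_sample(jrandom.PRNGKey(0), density_fn, proposal_sample,
                           proposal_density, max_ratio, num_samples=1000)
```

In the common case the loop exits after the first pass, so the single-pass behaviour is preserved; later rounds only run when the first batch falls short.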
How does it compare to other algorithms in blackjax?
Performance -- great in 1-d, horrible in many dimensions
Complexity -- dead simple
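The dimension problem can be seen in a small experiment. The sketch below assumes a standard normal target and a wider normal proposal with standard deviation `sigma` (both chosen purely for illustration, as is the helper name `acceptance_rate`). For that pair the optimal max_ratio is sigma^d, so the acceptance rate is sigma^(-d) and halves with every added dimension when sigma = 2.

```python
import jax.numpy as jnp
import jax.random as jrandom


def acceptance_rate(key, dim, sigma=2.0, num_candidates=100_000):
    """Empirical acceptance rate for target N(0, I) and proposal N(0, sigma^2 I)."""
    key_propose, key_accept = jrandom.split(key)
    x = sigma * jrandom.normal(key_propose, (num_candidates, dim))
    # log(target / (max_ratio * proposal)) with max_ratio = sigma**dim,
    # which simplifies to -|x|^2 / 2 * (1 - 1 / sigma^2).
    log_accept_prob = -0.5 * jnp.sum(x**2, axis=1) * (1.0 - 1.0 / sigma**2)
    uniforms = jrandom.uniform(key_accept, (num_candidates,))
    return float(jnp.mean(jnp.log(uniforms) <= log_accept_prob))


key = jrandom.PRNGKey(0)
rates = {dim: acceptance_rate(key, dim) for dim in (1, 2, 4, 8)}
# Theory for sigma = 2: 1/2, 1/4, 1/16, 1/256.
```

Working in log space avoids underflow of the density ratio in higher dimensions.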
Where does it fit in blackjax
A sampling library should have simple sampling algos, in addition to sophisticated ones. Speaking of which -- what about inverse sampling?
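For reference, inverse (transform) sampling is about as small as a sampler gets when the inverse CDF is available in closed form. A minimal sketch, assuming an exponential target with a hypothetical `rate` parameter; the function name `inverse_sample_exponential` is made up for illustration and is not part of BlackJax:

```python
import jax.numpy as jnp
import jax.random as jrandom


def inverse_sample_exponential(key, rate, num_samples):
    """Inverse transform sampling: push uniforms through the inverse CDF.

    For Exponential(rate), the CDF is F(x) = 1 - exp(-rate * x),
    so F^{-1}(u) = -log(1 - u) / rate.
    """
    u = jrandom.uniform(key, (num_samples,))  # u lies in [0, 1)
    return -jnp.log1p(-u) / rate


exp_samples = inverse_sample_exponential(jrandom.PRNGKey(0), rate=2.0, num_samples=100_000)
```

Unlike rejection sampling, every draw is used, but the approach needs a tractable inverse CDF, which mostly limits it to one-dimensional distributions.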
Are you willing to open a PR?
Yes