Particle Filtering and Sequential Monte Carlo
Gen.jl provides support for Sequential Monte Carlo (SMC) inference in the form of particle filtering. The state of a particle filter is represented as a ParticleFilterState object.
Gen.ParticleFilterState (type)
ParticleFilterState{U}
Represents the state of a particle filter as a collection of weighted traces, where the type of each trace is U.
Fields
- traces: A vector of traces, one for each particle.
- new_traces: A preallocated vector for storing new traces.
- log_weights: A vector of log importance weights for each trace.
- log_ml_est: Estimate of the log marginal likelihood before the last resampling step.
- parents: The parent indices of each trace in traces.
The fields above are an implementation detail and are subject to change.
Particle Filtering Steps
The basic steps of particle filtering are initialization (via initialize_particle_filter), updating (via particle_filter_step!), and resampling (via maybe_resample!). The latter two operations are applied to a ParticleFilterState, modifying it in place.
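As a minimal sketch of these three steps, the loop below assumes a hypothetical one-dimensional random-walk model hmm (not part of Gen), whose argument T is the number of time steps, and a vector of made-up observations ys. It uses the default proposals throughout. Passing UnknownChange() as the argdiff is a conservative choice that tells Gen nothing about how the argument changed.

```julia
using Gen

# Hypothetical model (not part of Gen): latent random walk x_t, noisy observations y_t.
@gen function hmm(T::Int)
    x = 0.0
    for t in 1:T
        x = {(:x, t)} ~ normal(x, 1.0)
        {(:y, t)} ~ normal(x, 0.5)
    end
    return x
end

ys = [0.3, 0.9, 1.4, 1.1, 2.0]   # made-up observations
num_particles = 100

# Initialization: condition on y_1; x_1 is sampled by the default proposal.
state = initialize_particle_filter(hmm, (1,), choicemap(((:y, 1), ys[1])), num_particles)

for t in 2:length(ys)
    # Resampling: only happens when the ESS drops below half the particle count.
    maybe_resample!(state, ess_threshold=num_particles / 2)
    # Updating: extend every trace to T = t and condition on the new observation.
    obs = choicemap(((:y, t), ys[t]))
    particle_filter_step!(state, (t,), (UnknownChange(),), obs)
end

println("log marginal likelihood estimate: ", log_ml_estimate(state))
```

Because hmm is written as a plain loop, each update re-executes the model from t = 1. For long sequences, Gen's Unfold combinator is typically used so that each step only touches the new time step.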
Gen.initialize_particle_filter (function)
state = initialize_particle_filter(model::GenerativeFunction, model_args::Tuple,
    observations::ChoiceMap, proposal::GenerativeFunction, proposal_args::Tuple,
    num_particles::Int)
Initialize the state of a particle filter using a custom proposal for the initial latent state.
state = initialize_particle_filter(model::GenerativeFunction, model_args::Tuple,
    observations::ChoiceMap, num_particles::Int)
Initialize the state of a particle filter, using the default proposal for the initial latent state.
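The sketch below contrasts the two forms. It assumes a hypothetical random-walk model hmm and a hypothetical proposal init_proposal (neither is part of Gen), and it assumes the custom proposal is called with proposal_args as its arguments, here the first observation. The proposal samples the initial latent choice (:x, 1), which the observations leave unconstrained.

```julia
using Gen

# Hypothetical model (not part of Gen): latent random walk x_t, noisy observations y_t.
@gen function hmm(T::Int)
    x = 0.0
    for t in 1:T
        x = {(:x, t)} ~ normal(x, 1.0)
        {(:y, t)} ~ normal(x, 0.5)
    end
    return x
end

# Hypothetical proposal: for this model it is the exact posterior p(x_1 | y_1),
# a normal with mean 0.8 * y1 and variance 0.2.
@gen function init_proposal(y1::Float64)
    {(:x, 1)} ~ normal(0.8 * y1, sqrt(0.2))
end

y1 = 0.3
obs = choicemap(((:y, 1), y1))

# Custom proposal: (:x, 1) is sampled from init_proposal(y1).
state_custom = initialize_particle_filter(hmm, (1,), obs, init_proposal, (y1,), 100)

# Default proposal: (:x, 1) is sampled by the model's internal proposal.
state_default = initialize_particle_filter(hmm, (1,), obs, 100)
```

Because this init_proposal is the exact posterior for the toy model, every particle in state_custom should receive the same log weight, log p(y_1); with the default proposal the weights vary between particles.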
Gen.particle_filter_step! (function)
(log_incremental_weights,) = particle_filter_step!(
    state::ParticleFilterState, new_args::Tuple, argdiffs,
    observations::ChoiceMap, proposal::GenerativeFunction, proposal_args::Tuple)
Perform a particle filter update, where the model arguments are adjusted, new observations are added, and some combination of a custom proposal and the model's internal proposal is used for proposing new latent state. That is, for each particle,
- The proposal function proposal is evaluated with arguments Tuple(t_old, proposal_args...) (where t_old is the old model trace) and produces its own trace (call it proposal_trace); and
- The old model trace is replaced by a new model trace (call it t_new).
The choicemap of t_new satisfies the following conditions:
- get_choices(t_old) is a subset of get_choices(t_new);
- observations is a subset of get_choices(t_new);
- get_choices(proposal_trace) is a subset of get_choices(t_new).
Here, when we say that one choicemap a is a "subset" of another choicemap b, we mean that every address that occurs in a also occurs in b, and the values at those addresses are equal.
It is an error if no trace t_new satisfying the above conditions exists in the support of the model (with the new arguments). If such a trace exists, then the random choices not determined by the above requirements are sampled using the internal proposal.
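A minimal sketch of the custom-proposal form, assuming a hypothetical random-walk model hmm and a hypothetical proposal step_proposal (neither is part of Gen). Following the rules above, step_proposal receives the old trace followed by proposal_args and proposes only the new latent choice (:x, t); the observation (:y, t) comes from observations, and the earlier choices are carried over from the old trace.

```julia
using Gen

# Hypothetical model (not part of Gen): latent random walk x_t, noisy observations y_t.
@gen function hmm(T::Int)
    x = 0.0
    for t in 1:T
        x = {(:x, t)} ~ normal(x, 1.0)
        {(:y, t)} ~ normal(x, 0.5)
    end
    return x
end

# Hypothetical proposal, called as step_proposal(old_trace, t, y_t).
# For this linear-Gaussian model it samples the exact conditional p(x_t | x_{t-1}, y_t).
@gen function step_proposal(old_trace, t::Int, y_t::Float64)
    x_prev = old_trace[(:x, t - 1)]
    {(:x, t)} ~ normal((x_prev + 4 * y_t) / 5, sqrt(0.2))
end

ys = [0.3, 0.9, 1.4]   # made-up observations
state = initialize_particle_filter(hmm, (1,), choicemap(((:y, 1), ys[1])), 100)

for t in 2:length(ys)
    maybe_resample!(state)
    obs = choicemap(((:y, t), ys[t]))
    # One incremental log weight per particle is returned.
    (log_incr,) = particle_filter_step!(state, (t,), (UnknownChange(),), obs,
                                        step_proposal, (t, ys[t]))
end
```

Proposing only (:x, t) keeps the three subset conditions easy to satisfy: the proposal never re-samples a choice already in t_old or in observations.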
(log_incremental_weights,) = particle_filter_step!(
    state::ParticleFilterState, new_args::Tuple, argdiffs,
    observations::ChoiceMap)
Perform a particle filter update, where the model arguments are adjusted, new observations are added, and the default proposal is used for new latent state.
Gen.maybe_resample! (function)
did_resample::Bool = maybe_resample!(state::ParticleFilterState;
    ess_threshold::Float64=length(state.traces)/2, verbose=false)
Perform a resampling step if the effective sample size (ESS) of the particles is below ess_threshold, which defaults to half the number of particles. Return true if resampling occurred and false otherwise.
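The resampling criterion can be sketched in plain Julia using the standard definition ESS = 1 / sum_i w_i^2, where w_i are the normalized weights. The helper names below are hypothetical, and we assume that maybe_resample! uses an equivalent criterion; this is an illustration, not Gen's code.

```julia
# Numerically stable log(sum(exp.(xs))).
function log_sum_exp(xs::AbstractVector{<:Real})
    m = maximum(xs)
    return m + log(sum(exp.(xs .- m)))
end

# Hypothetical helper: ESS = 1 / sum_i w_i^2 of the normalized weights w_i,
# computed from unnormalized log weights.
function effective_sample_size_sketch(log_weights::AbstractVector{<:Real})
    log_normalized = log_weights .- log_sum_exp(log_weights)
    return 1.0 / sum(exp.(2 .* log_normalized))
end

# Hypothetical helper: resample when the ESS is below the threshold.
function should_resample(log_weights::AbstractVector{<:Real};
                         ess_threshold::Real = length(log_weights) / 2)
    return effective_sample_size_sketch(log_weights) < ess_threshold
end

should_resample(zeros(4))                  # equal weights: ESS is 4, so false
should_resample([0.0, -Inf, -Inf, -Inf])   # all weight on one particle: ESS is 1, so true
```

Equal weights give the maximum ESS (the number of particles), and a single dominant particle gives the minimum (one), which is why a threshold of half the particle count is a common default.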
Accessors
The following accessor functions can be used to return information about a ParticleFilterState, or to sample traces from the distribution that the particle filter approximates.
Gen.log_ml_estimate (function)
estimate = log_ml_estimate(state::ParticleFilterState)
Return the particle filter's current estimate of the log marginal likelihood.
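The standard SMC estimate is a product of average weights, one factor per resampling stage; in log space it becomes a sum. The plain-Julia sketch below assumes that the log weights are reset to equal values after each resample, and that the running sum plays the role of the log_ml_est field described above. The helper names are hypothetical and this is not Gen's implementation.

```julia
# Numerically stable log(sum(exp.(xs))).
function log_sum_exp(xs::AbstractVector{<:Real})
    m = maximum(xs)
    return m + log(sum(exp.(xs .- m)))
end

# Log of the average importance weight, log((1/N) * sum_i exp(lw_i)).
log_mean_weight(lw::AbstractVector{<:Real}) = log_sum_exp(lw) - log(length(lw))

# Hypothetical helper: weights_before_each_resample holds the particle log weights
# just before each resampling step; current_log_weights holds the log weights
# accumulated since the last resample.
function log_ml_estimate_sketch(weights_before_each_resample, current_log_weights)
    total = 0.0
    for lw in weights_before_each_resample
        total += log_mean_weight(lw)   # contribution of one completed stage
    end
    return total + log_mean_weight(current_log_weights)
end

# No resampling yet: weights 1 and 3 give log((1 + 3) / 2) = log(2).
log_ml_estimate_sketch(Vector{Float64}[], log.([1.0, 3.0]))
```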
Gen.get_traces (function)
traces = get_traces(state::ParticleFilterState)
Return the vector of traces in the current state, one for each particle.
Gen.get_log_weights (function)
log_weights = get_log_weights(state::ParticleFilterState)
Return the vector of log weights for the current state, one for each particle.
The weights are unnormalized and in log space.
Gen.sample_unweighted_traces (function)
traces::Vector = sample_unweighted_traces(state::ParticleFilterState, num_samples::Int)
Sample a vector of num_samples traces from the weighted collection of traces in the given particle filter state, with each trace chosen with probability proportional to its weight.
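A short sketch of the accessors after a few filtering steps, assuming a hypothetical random-walk model hmm and made-up observations ys (neither is part of Gen). Since get_log_weights returns unnormalized log weights, they are normalized here before computing a weighted posterior mean of the final latent state; sample_unweighted_traces instead returns equally weighted traces.

```julia
using Gen

# Hypothetical model (not part of Gen): latent random walk x_t, noisy observations y_t.
@gen function hmm(T::Int)
    x = 0.0
    for t in 1:T
        x = {(:x, t)} ~ normal(x, 1.0)
        {(:y, t)} ~ normal(x, 0.5)
    end
    return x
end

ys = [0.3, 0.9, 1.4]   # made-up observations
num_particles = 200
state = initialize_particle_filter(hmm, (1,), choicemap(((:y, 1), ys[1])), num_particles)
for t in 2:length(ys)
    maybe_resample!(state)
    particle_filter_step!(state, (t,), (UnknownChange(),), choicemap(((:y, t), ys[t])))
end

traces = get_traces(state)
log_weights = get_log_weights(state)   # unnormalized, in log space

# Self-normalized importance estimate of E[x_T | y_1:T].
w = exp.(log_weights .- maximum(log_weights))
w ./= sum(w)
x_last = [tr[(:x, length(ys))] for tr in traces]
posterior_mean = sum(w .* x_last)

samples = sample_unweighted_traces(state, 10)   # equally weighted traces
lml = log_ml_estimate(state)
```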
Advanced Particle Filtering
For a richer set of particle filtering techniques, including support for stratified sampling, multiple resampling methods, MCMC rejuvenation moves, and particle filter resizing, we recommend the GenParticleFilters.jl extension library.
For Sequential Monte Carlo with Probabilistic Program Proposals (SMCP³), a generalization of standard SMC, use the API provided by GenSMCP3.jl, or implement an UpdatingTraceTranslator in GenParticleFilters.jl.
Even more advanced SMC techniques (such as divide-and-conquer SMC) are not currently supported by Gen.