Generative Functions
— Type GenerativeFunction{T,U <: Trace}
Abstract type for a generative function with return value type T and trace type U.
— Type Trace
Abstract type for a trace of a generative function.
The complete set of methods in the generative function interface (GFI) is:
— Function trace = simulate(gen_fn, args)
Execute the generative function and return the trace.
Given arguments $x$ (args), sample $(r, t) \sim p(\cdot; x)$ and return a trace with choice map $t$.
If gen_fn has optional trailing arguments (i.e., default values are provided), the optional arguments can be omitted from the args tuple. The generated trace will have default values filled in.
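A minimal sketch of simulate with an omitted optional argument, assuming Gen.jl is installed and that its dynamic modeling language (@gen) accepts a default value on a trailing argument. The model coin_flips is hypothetical.

```julia
using Gen

# Hypothetical model; `p` is an optional trailing argument with default 0.3.
@gen function coin_flips(n::Int, p::Float64=0.3)
    heads = 0
    for i in 1:n
        # Each flip is a random choice at the hierarchical address :flip => i.
        flip = {:flip => i} ~ bernoulli(p)
        heads += flip
    end
    return heads
end

# Omit `p` from the argument tuple; the trace should record the default value.
trace = simulate(coin_flips, (5,))

println(get_args(trace))    # expected to include the default 0.3
println(get_retval(trace))  # number of heads, between 0 and 5
```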
— Function (trace::U, weight) = generate(gen_fn::GenerativeFunction{T,U}, args::Tuple)
Return a trace of a generative function.
(trace::U, weight) = generate(gen_fn::GenerativeFunction{T,U}, args::Tuple, constraints::ChoiceMap)
Return a trace of a generative function that is consistent with the given constraints on the random choices.
Given arguments $x$ (args) and assignment $u$ (constraints) (which is empty for the first form), sample $t \sim q(\cdot; u, x)$ and $r \sim q(\cdot; x, t)$, and return the trace $(x, r, t)$ (trace). Also return the weight (weight):
\[\log \frac{p(r, t; x)}{q(t; u, x) q(r; x, t)}\]
If gen_fn has optional trailing arguments (i.e., default values are provided), the optional arguments can be omitted from the args tuple. The generated trace will have default values filled in.
Example without constraints:
(trace, weight) = generate(foo, (2, 4))
Example with constraint that address :z takes value true:
(trace, weight) = generate(foo, (2, 4), choicemap((:z, true)))
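The sketch below, assuming Gen.jl is installed, uses a hypothetical stand-in for foo (its original definition is not shown here) to illustrate the weight. In Gen's dynamic modeling language, unconstrained choices are sampled from their prior, so constraining only :z should give a weight of log p(z = true), and constraining every choice should give a weight equal to the trace's score.

```julia
using Gen

# Hypothetical stand-in for `foo`: a coin :z selects the mean of :x.
@gen function foo(a::Real, b::Real)
    z ~ bernoulli(0.5)
    x ~ normal(z ? a : b, 1.0)
    return x
end

# Constrain only :z. The unconstrained :x is drawn from its prior, so the
# proposal density of :x cancels and the weight is log p(z = true) = log(0.5).
(trace, weight) = generate(foo, (2, 4), choicemap((:z, true)))

# Constrain every choice. Nothing is sampled, so the weight equals the score.
(full_trace, full_weight) = generate(foo, (2, 4), choicemap((:z, true), (:x, 3.0)))
```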
— Function (new_trace, weight, retdiff, discard) = update(trace, args::Tuple, argdiffs::Tuple, constraints::ChoiceMap)
Update a trace by changing the arguments and/or providing new values for some existing random choice(s) and values for some newly introduced random choice(s).
Given a previous trace $(x, r, t)$ (trace), new arguments $x'$ (args), and a choice map $u$ (constraints), return a new trace $(x', r', t')$ (new_trace) that is consistent with $u$. The values of choices in $t'$ are either copied from $t$ or from $u$ (with $u$ taking precedence) or are sampled from the internal proposal distribution. All choices in $u$ must appear in $t'$. Also return an assignment $v$ (discard) containing the choices in $t$ that were overwritten by values from $u$, and any choices in $t$ whose address does not appear in $t'$. The new trace is sampled as $t' \sim q(\cdot; x', t + u)$ and $r' \sim q(\cdot; x', t')$, where $t + u$ is the choice map obtained by merging $t$ and $u$ with $u$ taking precedence for overlapping addresses. Also return a weight (weight):
\[\log \frac{p(r', t'; x')}{q(r'; x', t') q(t'; x', t + u)} - \log \frac{p(r, t; x)}{q(r; x, t)}\]
Note that argdiffs is expected to be the same length as args. If the function that generated trace supports default values for trailing arguments, then these arguments can be omitted from args and argdiffs. Note that if the original trace was generated using non-default argument values, then for each optional argument that is omitted, the old value will be overwritten by the default argument value in the updated trace.
(new_trace, weight, retdiff, discard) = update(trace, constraints::ChoiceMap)
Shorthand variant of update which assumes the arguments are unchanged.
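A minimal sketch of update, assuming Gen.jl is installed and using a hypothetical foo with choices :z and :x. Because :x already exists in the old trace and is not constrained, its value is copied, nothing new is sampled, and the weight should reduce to the difference in scores.

```julia
using Gen

# Hypothetical model: a coin :z selects the mean of :x.
@gen function foo(a::Real, b::Real)
    z ~ bernoulli(0.5)
    x ~ normal(z ? a : b, 1.0)
    return x
end

# Start from a fully constrained trace so the numbers are reproducible.
trace, _ = generate(foo, (2.0, 4.0), choicemap((:z, true), (:x, 3.0)))

# Keep both arguments (NoChange) and overwrite :z. The old :x is reused,
# and the overwritten value of :z is returned in `discard`.
(new_trace, weight, retdiff, discard) =
    update(trace, (2.0, 4.0), (NoChange(), NoChange()), choicemap((:z, false)))
```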
— Function (new_trace, weight, retdiff) = regenerate(trace, args::Tuple, argdiffs::Tuple, selection::Selection)
Update a trace by changing the arguments and/or randomly sampling new values for selected random choices using the internal proposal distribution family.
Given a previous trace $(x, r, t)$ (trace), new arguments $x'$ (args), and a set of addresses $A$ (selection), return a new trace $(x', r', t')$ (new_trace) such that $t'$ agrees with $t$ on all addresses not in $A$ ($t$ and $t'$ may have different sets of addresses). Let $u$ denote the restriction of $t$ to the complement of $A$. The new trace is sampled as $t' \sim q(\cdot; u, x')$ and $r' \sim q(\cdot; x', t')$. Also return the weight (weight):
\[\log \frac{p(r', t'; x')}{q(t'; u, x') q(r'; x', t')} - \log \frac{p(r, t; x)}{q(t; u', x) q(r; x, t)}\]
where $u'$ is the restriction of $t'$ to the complement of $A$.
Note that argdiffs is expected to be the same length as args. If the function that generated trace supports default values for trailing arguments, then these arguments can be omitted from args and argdiffs. Note that if the original trace was generated using non-default argument values, then for each optional argument that is omitted, the old value will be overwritten by the default argument value in the regenerated trace.
(new_trace, weight, retdiff) = regenerate(trace, selection::Selection)
Shorthand variant of regenerate which assumes the arguments are unchanged.
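A sketch of regenerate with a hypothetical foo, assuming Gen.jl is installed. Only :z is selected, so :x keeps its value. Since bernoulli(0.5) gives both values of :z the same probability, the weight in this particular model should equal the change in score (in general it does not).

```julia
using Gen

# Hypothetical model: a coin :z selects the mean of :x.
@gen function foo(a::Real, b::Real)
    z ~ bernoulli(0.5)
    x ~ normal(z ? a : b, 1.0)
    return x
end

trace, _ = generate(foo, (2.0, 4.0), choicemap((:z, true), (:x, 3.0)))

# Resample only :z from the internal proposal; :x is outside the selection
# and is carried over unchanged.
(new_trace, weight, retdiff) =
    regenerate(trace, (2.0, 4.0), (NoChange(), NoChange()), select(:z))
```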
— Function get_args(trace)
Return the argument tuple for a given execution.
args::Tuple = get_args(trace)
— Function get_retval(trace)
Return the return value of the given execution.
Example for generative function with return type T:
retval::T = get_retval(trace)
— Function get_choices(trace)
Return a value implementing the assignment interface.
Note that the value of any non-addressed randomness is not externally accessible.
choices::ChoiceMap = get_choices(trace)
z_val = choices[:z]
— Function get_score(trace)
Return the score of the given trace:
\[\log \frac{p(r, t; x)}{q(r; x, t)}\]
When there is no non-addressed randomness, this simplifies to the log probability $\log p(t; x)$.
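The accessors above can be exercised on a fully constrained trace of a hypothetical foo, assuming Gen.jl is installed. Because this model has no non-addressed randomness, the score should be the sum of the log densities of its two choices.

```julia
using Gen

# Hypothetical model: a coin :z selects the mean of :x.
@gen function foo(a::Real, b::Real)
    z ~ bernoulli(0.5)
    x ~ normal(z ? a : b, 1.0)
    return x
end

trace, _ = generate(foo, (2.0, 4.0), choicemap((:z, true), (:x, 3.0)))

args = get_args(trace)        # (2.0, 4.0)
retval = get_retval(trace)    # 3.0, the value of :x
choices = get_choices(trace)
z_val = choices[:z]           # true

# With no non-addressed randomness, the score is log p(t; x).
expected = logpdf(bernoulli, true, 0.5) + logpdf(normal, 3.0, 2.0, 1.0)
```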
— Function gen_fn::GenerativeFunction = get_gen_fn(trace)
Return the generative function that produced the given trace.
— Function value = getindex(trace::Trace, addr)
Get the value of the random choice, or auxiliary state (e.g. return value of inner function call), at address addr.
retval = getindex(trace::Trace)
retval = trace[]
Synonym for get_retval.
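A short sketch of indexing a trace and recovering its generative function, assuming Gen.jl is installed and a hypothetical foo.

```julia
using Gen

# Hypothetical model: a coin :z selects the mean of :x.
@gen function foo(a::Real, b::Real)
    z ~ bernoulli(0.5)
    x ~ normal(z ? a : b, 1.0)
    return x
end

trace, _ = generate(foo, (2.0, 4.0), choicemap((:z, true), (:x, 3.0)))

x_val = trace[:x]            # value of the random choice at :x
ret = trace[]                # same as get_retval(trace)
same_fn = get_gen_fn(trace)  # the generative function that produced the trace
```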
— Function weight = project(trace::U, selection::Selection)
Estimate the probability that the selected choices take the values they do in a trace.
Given a trace $(x, r, t)$ (trace) and a set of addresses $A$ (selection), let $u$ denote the restriction of $t$ to $A$. Return the weight (weight):
\[\log \frac{p(r, t; x)}{q(t; u, x) q(r; x, t)}\]
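A sketch of project on a fully constrained trace of a hypothetical foo, assuming Gen.jl is installed. Selecting only :z leaves :x to the internal proposal, which in Gen's dynamic modeling language is its prior, so the weight should be log p(z = true). Selecting every address should recover the score.

```julia
using Gen

# Hypothetical model: a coin :z selects the mean of :x.
@gen function foo(a::Real, b::Real)
    z ~ bernoulli(0.5)
    x ~ normal(z ? a : b, 1.0)
    return x
end

trace, _ = generate(foo, (2.0, 4.0), choicemap((:z, true), (:x, 3.0)))

w_z = project(trace, select(:z))   # log p(z = true) = log(0.5)
w_all = project(trace, selectall()) # equals get_score(trace)
```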
— Function (choices, weight, retval) = propose(gen_fn::GenerativeFunction, args::Tuple)
Sample an assignment and compute the probability of proposing that assignment.
Given arguments $x$ (args), sample $t \sim p(\cdot; x)$ and $r \sim p(\cdot; x, t)$, and return $t$ (choices), the return value (retval), and the weight (weight):
\[\log \frac{p(r, t; x)}{q(r; x, t)}\]
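A sketch pairing propose with assess, assuming Gen.jl is installed and a hypothetical foo. With no non-addressed randomness, both weights should equal log p(t; x) for the proposed choices.

```julia
using Gen

# Hypothetical model: a coin :z selects the mean of :x.
@gen function foo(a::Real, b::Real)
    z ~ bernoulli(0.5)
    x ~ normal(z ? a : b, 1.0)
    return x
end

# Sample a full assignment together with its log probability.
(choices, weight, retval) = propose(foo, (2.0, 4.0))

# Scoring the same assignment with assess should give the same weight.
(assessed_weight, assessed_retval) = assess(foo, (2.0, 4.0), choices)
```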
— Function (weight, retval) = assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap)
Return the probability of proposing an assignment.
Given arguments $x$ (args) and an assignment $t$ (choices) such that $p(t; x) > 0$, sample $r \sim q(\cdot; x, t)$ and return the weight (weight):
\[\log \frac{p(r, t; x)}{q(r; x, t)}\]
It is an error if $p(t; x) = 0$.
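A sketch of assess on a fully specified assignment, assuming Gen.jl is installed and a hypothetical foo. The weight is checked against log densities computed by hand.

```julia
using Gen

# Hypothetical model: a coin :z selects the mean of :x.
@gen function foo(a::Real, b::Real)
    z ~ bernoulli(0.5)
    x ~ normal(z ? a : b, 1.0)
    return x
end

# Every choice is supplied, so nothing is sampled.
(weight, retval) = assess(foo, (2.0, 4.0), choicemap((:z, true), (:x, 3.0)))

# log p(z = true) + log N(3.0; 2.0, 1.0)
expected = log(0.5) + logpdf(normal, 3.0, 2.0, 1.0)
```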
— Function bools::Tuple = has_argument_grads(gen_fn::Union{GenerativeFunction,Distribution})
Return a tuple of booleans indicating whether a gradient is available for each of its arguments.
— Function has_submap(choices::ChoiceMap, addr)
Return true if there is a non-empty sub-assignment at the given address.
— Function req::Bool = accepts_output_grad(gen_fn::GenerativeFunction)
Return a boolean indicating whether the return value is dependent on any of the gradient source elements for any trace.
The gradient source elements are:
Any argument whose position is true in
Any trainable parameter
Random choices made at a set of addresses that are selectable by
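A sketch of querying gradient support, assuming Gen.jl is installed. The functions foo and scaled are hypothetical; scaled marks its argument with (grad) and the function itself with @gen (grad), which is Gen's syntax for declaring argument and output gradients.

```julia
using Gen

# Hypothetical model with no gradient annotations.
@gen function foo(a::Real, b::Real)
    z ~ bernoulli(0.5)
    x ~ normal(z ? a : b, 1.0)
    return x
end

# Hypothetical differentiable model: argument and return value are annotated.
@gen (grad) function scaled((grad)(mu::Float64))
    x ~ normal(mu, 1.0)
    return 2.0 * x
end

println(has_argument_grads(foo))     # no argument gradients
println(has_argument_grads(scaled))  # gradient available for mu
println(accepts_output_grad(scaled)) # return value accepts a gradient
```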
— Function arg_grads = accumulate_param_gradients!(trace, retgrad=nothing, scale_factor=1.)
Increment gradient accumulators for parameters by the gradient of the log-probability of the trace, optionally scaled, and return the gradient with respect to the arguments (not scaled).
Given a previous trace $(x, t)$ (trace) and a gradient with respect to the return value $∇_y J$ (retgrad), return the following gradient (arg_grads) with respect to the arguments $x$:
\[∇_x \left( \log P(t; x) + J \right)\]
The length of arg_grads will be equal to the number of arguments to the function that generated trace (including any optional trailing arguments). If an argument is not annotated with (grad), the corresponding value in arg_grads will be nothing.
Also increment the gradient accumulators for the trainable parameters $Θ$ of the function by:
\[∇_Θ \left( \log P(t; x) + J \right)\]
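A sketch of accumulate_param_gradients! on a hypothetical one-parameter model lin, assuming Gen.jl is installed. For y ~ normal(θx, 1), the derivative of the log density with respect to θ is (y - θx)x, which is 6 at θ = 0, x = 2, y = 3.

```julia
using Gen

# Hypothetical model with one trainable parameter `theta`.
@gen function lin(x::Float64)
    @param theta::Float64
    y ~ normal(theta * x, 1.0)
    return y
end

# Set theta = 0.0 (this also resets its gradient accumulator).
init_param!(lin, :theta, 0.0)
trace, _ = generate(lin, (2.0,), choicemap((:y, 3.0)))

# `lin` does not accept an output gradient, so retgrad is nothing.
# `x` is not annotated with (grad), so its entry in arg_grads is nothing.
arg_grads = accumulate_param_gradients!(trace, nothing)

# Accumulated d/dtheta log N(y; theta * x, 1) = (3 - 0) * 2 = 6
theta_grad = get_param_grad(lin, :theta)
```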
— Function (arg_grads, choice_values, choice_grads) = choice_gradients(trace, selection=EmptySelection(), retgrad=nothing)
Given a previous trace $(x, t)$ (trace) and a gradient with respect to the return value $∇_y J$ (retgrad), return the following gradient (arg_grads) with respect to the arguments $x$:
\[∇_x \left( \log P(t; x) + J \right)\]
The length of arg_grads will be equal to the number of arguments to the function that generated trace (including any optional trailing arguments). If an argument is not annotated with (grad), the corresponding value in arg_grads will be nothing.
Also given a set of addresses $A$ (selection) that are continuous-valued random choices, return the following gradient (choice_grads) with respect to the values of these choices:
\[∇_A \left( \log P(t; x) + J \right)\]
The gradient is represented as a choicemap whose value at (hierarchical) address addr is $∂(\log P(t; x) + J)/∂t[\texttt{addr}]$.
Also return the choicemap (choice_values) that is the restriction of $t$ to $A$.
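A sketch of choice_gradients on a hypothetical one-choice model g, assuming Gen.jl is installed. For x ~ normal(μ, 1), the derivative of the log density with respect to x is -(x - μ), which is -1.5 at x = 1.5, μ = 0.

```julia
using Gen

# Hypothetical model with a single continuous choice.
@gen function g(mu::Float64)
    x ~ normal(mu, 1.0)
    return x
end

trace, _ = generate(g, (0.0,), choicemap((:x, 1.5)))

# Gradient of log p(t; x) with respect to the selected choice :x.
# g does not accept an output gradient, so retgrad is nothing.
(arg_grads, choice_values, choice_grads) = choice_gradients(trace, select(:x), nothing)
```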
— Function get_params(gen_fn::GenerativeFunction)
Return an iterable over the trainable parameters of the generative function.
— Type Diff
Abstract supertype for information about a change to a value.
— Type NoChange
Singleton to indicate the value did not change.
— Type UnknownChange
Singleton to indicate the change to the value is unknown or unprovided.
— Type SetDiff <: Diff
— Type Diffed{V,DV <: Diff}
Container for a value and information about a change to its value.
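Argdiffs passed to update are built from these Diff values. The sketch below, assuming Gen.jl is installed and a hypothetical foo, marks the first argument as changed (UnknownChange) and the second as unchanged (NoChange). The existing :x is kept, so the weight is the change in the log density of :x under the new mean.

```julia
using Gen

# Hypothetical model: a coin :z selects the mean of :x.
@gen function foo(a::Real, b::Real)
    z ~ bernoulli(0.5)
    x ~ normal(z ? a : b, 1.0)
    return x
end

trace, _ = generate(foo, (2.0, 4.0), choicemap((:z, true), (:x, 3.0)))

# Change `a` from 2.0 to 2.5; `b` is unchanged. No new constraints.
(new_trace, weight, retdiff, discard) =
    update(trace, (2.5, 4.0), (UnknownChange(), NoChange()), choicemap())

# weight = log N(3; 2.5, 1) - log N(3; 2.0, 1) = -0.125 - (-0.5) = 0.375
```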
— Type CustomUpdateGF{T,S}
Abstract type for a generative function with a custom update computation, and default behaviors for all other generative function interface methods.
T is the type of the return value and S is the type of state used internally for incremental computation.
— Function retval, state = apply_with_state(gen_fn::CustomUpdateGF, args)
Execute the generative function and return the return value and the state.
— Function state, retval, retdiff = update_with_state(gen_fn::CustomUpdateGF, state, args, argdiffs)
Update the arguments to the generative function and return the new state, the new return value, and a diff describing the change to the return value (retdiff).
— Type CustomGradientGF{T}
Abstract type for a generative function with a custom gradient computation, and default behaviors for all other generative function interface methods.
T is the type of the return value.
— Function retval = apply(gen_fn::CustomGradientGF, args)
Apply the function to the arguments.
— Function arg_grads = gradient(gen_fn::CustomGradientGF, args, retval, retgrad)
Return the gradient tuple with respect to the arguments, with nothing in the position of any argument whose gradient is not available.
— Function state = init_update_state(conf, gen_fn::GenerativeFunction, param_list::Vector)
Get the initial state for a parameter update to the given parameters of the given generative function.
param_list is a vector of references to parameters of gen_fn. conf configures the update.
— Function apply_update!(state)
Apply one parameter update, mutating the values of the trainable parameters, and possibly also the given state.
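A low-level sketch of one parameter update on a hypothetical model lin, assuming Gen.jl is installed, that init_update_state and apply_update! have the signatures documented above (they are called with the Gen. prefix in case they are not exported), and that FixedStepGradientDescent adds step_size times the accumulated gradient to each parameter.

```julia
using Gen

# Hypothetical model with one trainable parameter `theta`.
@gen function lin(x::Float64)
    @param theta::Float64
    y ~ normal(theta * x, 1.0)
    return y
end

init_param!(lin, :theta, 0.0)
trace, _ = generate(lin, (2.0,), choicemap((:y, 3.0)))

# Accumulate d/dtheta log p = (y - theta * x) * x = 6.0.
accumulate_param_gradients!(trace, nothing)

# Build update state for all parameters returned by get_params, then apply
# a single fixed-step update (assumed: theta += 0.1 * gradient).
conf = FixedStepGradientDescent(0.1)
state = Gen.init_update_state(conf, lin, collect(get_params(lin)))
Gen.apply_update!(state)
```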