Generative Function Interface
One of the core abstractions in Gen is the generative function. The interface for interacting with generative functions is called the generative function interface (GFI). Generative functions are used to represent a variety of different types of probabilistic computations including generative models, inference models, custom proposal distributions, and variational approximations.
Introduction
There are various kinds of generative functions, which are represented by concrete subtypes of GenerativeFunction.
For example, the Built-in Modeling Language allows generative functions to be constructed using Julia function definition syntax:
```julia
@gen function foo(a, b=0)
    if @trace(bernoulli(0.5), :z)
        return a + b + 1
    else
        return a + b
    end
end;
```
DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any, Any], true, Union{Nothing, Some{Any}}[nothing, Some(0)], Main.var"##foo#275", Bool[0, 0], false)
Users can also extend Gen by implementing their own custom generative functions, which can be new modeling languages, or just specialized optimized implementations of a fragment of a specific model.
Generative functions behave like Julia functions in some respects. For example, we can call a generative function foo on arguments and get an output value using regular Julia call syntax:
foo(2, 4)
6
However, generative functions are distinct from Julia functions because they support additional behaviors, described in the remainder of this section.
Mathematical concepts
Generative functions represent computations that accept some arguments, may use randomness internally, return an output, and cannot mutate externally observable state. We represent the randomness used during an execution of a generative function as a choice map from unique addresses to values of random choices, denoted $t : A \to V$ where $A$ is a finite (but not a priori bounded) address set and $V$ is a set of possible values that random choices can take. In this section, we assume that random choices are discrete to simplify notation. We say that two choice maps $t$ and $s$ agree if they assign the same value for any address that is in both of their domains.
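As a minimal sketch of the "agree" relation, assume a choice map is modelled as a plain Julia Dict from addresses to values. Gen's own choice maps are richer, and agree is a hypothetical helper used only for illustration, not part of Gen:

```julia
# Hypothetical helper (not part of Gen): choice maps are modelled as Dicts
# from addresses to values. t and s agree if they assign the same value to
# every address that is in both of their domains.
function agree(t::AbstractDict, s::AbstractDict)
    for addr in keys(t)
        if haskey(s, addr) && s[addr] != t[addr]
            return false
        end
    end
    return true
end

t = Dict(:a => false, :b => true, :c => false)
s = Dict(:b => true, :d => true)   # shares only :b with t, with the same value
r = Dict(:a => true)               # assigns a different value to :a

println(agree(t, s))  # true
println(agree(t, r))  # false
```

Note that, under this definition, two choice maps with disjoint domains always agree.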
Generative functions may also use non-addressable randomness, which is not included in the map $t$. We denote non-addressable randomness by $r$. Untraced randomness is useful, for example, when calling black-box Julia code that implements a randomized algorithm.
The observable behavior of every generative function is defined by the following mathematical objects:
Input type
The set of valid argument tuples to the function, denoted $X$.
Probability distribution family
A family of probability distributions $p(t, r; x)$ on maps $t$ from random choice addresses to their values, and non-addressable randomness $r$, indexed by arguments $x$, for all $x \in X$. Note that the distribution must be normalized:
\[\sum_{t, r} p(t, r; x) = 1 \;\; \text{for all} \;\; x \in X\]
This corresponds to a requirement that the function terminate with probability 1 for all valid arguments. We use $p(t; x)$ to denote the marginal distribution on the map $t$:
\[p(t; x) := \sum_{r} p(t, r; x)\]
And we denote the conditional distribution on non-addressable randomness $r$, given the map $t$, as:
\[p(r | t; x) := p(t, r; x) / p(t; x)\]
Return value function
A (deterministic) function $f$ that maps the tuple $(x, t)$ of the arguments and the choice map to the return value of the function (which we denote by $y$). Note that the return value cannot depend on the non-addressable randomness.
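For a concrete instance, the example function foo from the introduction has a single random choice :z ~ bernoulli(0.5) and no non-addressable randomness. The sketch below restates its distribution family and return value function in plain Julia; p_foo and f_foo are hypothetical names used only for illustration:

```julia
# Hypothetical plain-Julia restatement of foo (not Gen code).
# Distribution family: the only random choice is :z ~ bernoulli(0.5),
# so each complete choice map has probability 0.5 for any arguments x.
p_foo(t::Dict{Symbol,Bool}, x::Tuple) = t[:z] ? 0.5 : 1 - 0.5

# Return value function f(x, t): deterministic given the arguments and choices.
function f_foo(x::Tuple, t::Dict{Symbol,Bool})
    a, b = x
    return t[:z] ? a + b + 1 : a + b
end

x = (2, 4)
all_t = [Dict(:z => true), Dict(:z => false)]

total = sum(p_foo(t, x) for t in all_t)
println(total)                         # 1.0
println([f_foo(x, t) for t in all_t])  # [7, 6]
```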
Auxiliary state
Generative functions may expose additional auxiliary state associated with an execution, besides the choice map and the return value. This auxiliary state is a function $z = h(x, t, r)$ of the arguments, choice map, and non-addressable randomness. Like the choice map, the auxiliary state is indexed by addresses. We require that the addresses of auxiliary state are disjoint from the addresses in the choice map. Note that when a generative function is called within a model, the auxiliary state is not available to the caller. It is typically used by inference programs, for logging and for caching the results of deterministic computations that would otherwise need to be reconstructed.
Internal proposal distribution family
A family of probability distributions $q(t; x, u)$ on maps $t$ from random choice addresses to their values, indexed by tuples $(x, u)$ where $u$ is a map from random choice addresses to values, and where $x$ are the arguments to the function. It must satisfy the following conditions:
\[\sum_{t} q(t; x, u) = 1 \;\; \text{for all} \;\; x \in X, u\]
\[p(t; x) > 0 \text{ if and only if } q(t; x, u) > 0 \text{ for all } u \text{ where } u \text{ and } t \text{ agree }\]
\[q(t; x, u) > 0 \text{ implies that } u \text{ and } t \text{ agree }.\]
There is also a family of probability distributions $q(r; x, t)$ on non-addressable randomness, that satisfies:
\[q(r; x, t) > 0 \text{ if and only if } p(r | t; x) > 0\]
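One internal proposal family that satisfies these conditions for foo copies :z from u when u constrains it, and otherwise samples :z from its prior. The following plain-Julia sketch checks the three conditions on $q(t; x, u)$ by enumeration; the names q_foo, agree, and p_foo are hypothetical, and this is an illustration rather than Gen's implementation:

```julia
# Hypothetical sketch of an internal proposal family q(t; x, u) for foo
# (not Gen's implementation): copy :z from u if u constrains it,
# otherwise sample :z from its prior bernoulli(0.5).
function q_foo(t::Dict{Symbol,Bool}, x::Tuple, u::Dict{Symbol,Bool})
    if haskey(u, :z)
        return t[:z] == u[:z] ? 1.0 : 0.0
    else
        return 0.5
    end
end

# t and u agree if they assign the same value to every shared address.
agree(t, u) = all(!haskey(u, k) || u[k] == t[k] for k in keys(t))

p_foo(t, x) = 0.5   # prior probability of either value of :z

x = (2, 4)
all_t = [Dict(:z => true), Dict(:z => false)]
all_u = [Dict{Symbol,Bool}(), Dict(:z => true), Dict(:z => false)]

for u in all_u
    # Normalization: the sum over t of q(t; x, u) is 1.
    @assert sum(q_foo(t, x, u) for t in all_t) == 1.0
    for t in all_t
        # q(t; x, u) > 0 implies that u and t agree.
        @assert q_foo(t, x, u) == 0.0 || agree(t, u)
        # For agreeing u: p(t; x) > 0 if and only if q(t; x, u) > 0.
        if agree(t, u)
            @assert (p_foo(t, x) > 0) == (q_foo(t, x, u) > 0)
        end
    end
end
println("all three conditions hold")
```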
Traces
An execution trace (or just trace) is a record of an execution of a generative function. Traces are the primary data structures manipulated by Gen inference programs. There are various methods for producing, updating, and inspecting traces. Traces contain:
the arguments to the generative function
the choice map
the return value
auxiliary state
other implementation-specific state that is not exposed to the caller or user of the generative function, but is used internally to facilitate e.g. incremental updates to executions and automatic differentiation
any necessary record of the non-addressable randomness
Different concrete types of generative functions use different data structures and different Julia types for their traces, but all traces are subtypes of Trace.
The concrete trace type that a generative function uses is the second type parameter of the GenerativeFunction abstract type. For example, the trace type of DynamicDSLFunction is DynamicDSLTrace.
A generative function can be executed to produce a trace of the execution using simulate:
trace = simulate(foo, (2, 4))
A traced execution that satisfies constraints on the choice map can be generated using generate:
trace, weight = generate(foo, (2, 4), choicemap((:z, false)))
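A runnable version of these two calls, assuming the Gen.jl package is installed. Because the constraint fixes every random choice of foo, the internal proposal has nothing left to sample, so we expect the weight to be $\log p(t; x) = \log 0.5$:

```julia
using Gen

@gen function foo(a, b=0)
    if @trace(bernoulli(0.5), :z)
        return a + b + 1
    else
        return a + b
    end
end

# Unconstrained: :z is sampled from bernoulli(0.5).
trace = simulate(foo, (2, 4))

# Constrained: :z is fixed to false, so the return value is 2 + 4 = 6.
constrained_trace, weight = generate(foo, (2, 4), choicemap((:z, false)))

println(get_retval(trace))              # 6 or 7
println(get_retval(constrained_trace))  # 6
println(exp(weight))                    # approximately 0.5
```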
There are various methods for inspecting traces, including:
get_args (returns the arguments to the function)
get_retval (returns the return value of the function)
get_choices (returns the choice map)
get_score (returns the log probability that the random choices took the values they did)
get_gen_fn (returns a reference to the generative function)
You can also access the values in the choice map and the auxiliary state of the trace by passing the address to Base.getindex. For example, to retrieve the value of the random choice at address :z:
z = trace[:z]
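Putting the inspection methods together on a trace of foo (again assuming Gen.jl is installed; the commented values follow from the definition of foo):

```julia
using Gen

@gen function foo(a, b=0)
    if @trace(bernoulli(0.5), :z)
        return a + b + 1
    else
        return a + b
    end
end

trace = simulate(foo, (2, 4))

args = get_args(trace)        # (2, 4)
retval = get_retval(trace)    # 7 if :z is true, otherwise 6
choices = get_choices(trace)  # choice map with the single address :z
score = get_score(trace)      # log(0.5), whichever value :z took
gen_fn = get_gen_fn(trace)    # foo itself
z = trace[:z]                 # same value as choices[:z]

println((args, retval, z, exp(score)))
```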
When a generative function has default values specified for trailing arguments, those arguments can be left out when calling simulate, generate, and other functions provided by the generative function interface. The default values will automatically be filled in:
trace = simulate(foo, (2,));
get_args(trace)
(2, 0)
Updating traces
It is often important to incrementally modify the trace of a generative function (e.g. within MCMC, numerical optimization, sequential Monte Carlo, etc.). In Gen, traces are functional data structures, meaning they can be treated as immutable values. There are several methods that take a trace of a generative function as input and return a new trace of the generative function based on adjustments to the execution history of the function. We will illustrate these methods using the following generative function:
```julia
@gen function bar()
    val = @trace(bernoulli(0.3), :a)
    if @trace(bernoulli(0.4), :b)
        val = @trace(bernoulli(0.6), :c) && val
    else
        val = @trace(bernoulli(0.1), :d) && val
    end
    val = @trace(bernoulli(0.7), :e) && val
    return val
end
```
DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[], false, Union{Nothing, Some{Any}}[], Main.var"##bar#276", Bool[], false)
Suppose we have a trace (trace) of bar with initial choices:
│
├── :a : false
│
├── :b : true
│
├── :e : true
│
└── :c : false
Note that address :d is not present: the branch in which :d is sampled was not taken, because random choice :b had value true.
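To check the probabilities used in the examples below, the following plain-Julia sketch writes out the joint probability of a complete choice map of bar by hand (p_bar and bern are hypothetical helpers, not Gen functions), enumerates all 16 choice maps that bar can produce, and confirms that they sum to 1:

```julia
# Hypothetical helpers (not Gen functions).
# Probability that a Bernoulli(p) random choice takes value v.
bern(p, v) = v ? p : 1 - p

# Joint probability of a complete choice map of bar, read off the model code.
function p_bar(t::Dict{Symbol,Bool})
    prob = bern(0.3, t[:a]) * bern(0.4, t[:b])
    # Only one branch is traced: :c when :b is true, :d when :b is false.
    prob *= t[:b] ? bern(0.6, t[:c]) : bern(0.1, t[:d])
    prob *= bern(0.7, t[:e])
    return prob
end

# Enumerate every choice map bar can produce.
all_t = Dict{Symbol,Bool}[]
for a in (false, true), b in (false, true), cd in (false, true), e in (false, true)
    branch = b ? :c : :d
    push!(all_t, Dict(:a => a, :b => b, branch => cd, :e => e))
end

println(length(all_t))                 # 16
println(sum(p_bar(t) for t in all_t))  # approximately 1.0

# The example trace from the text: 0.7 × 0.4 × 0.4 × 0.7.
t0 = Dict(:a => false, :b => true, :c => false, :e => true)
println(p_bar(t0))                     # approximately 0.0784
```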
Update
The update method takes a trace and generates an adjusted trace that is consistent with given changes to the arguments to the function, and changes to the values of random choices made.
Example. Suppose we run update on the example trace, with the following constraints:
│
├── :b : false
│
└── :d : true
constraints = choicemap((:b, false), (:d, true))
(new_trace, w, _, discard) = update(trace, (), (), constraints)
(Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[], false, Union{Nothing, Some{Any}}[], Main.var"##bar#276", Bool[], false), Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(:a => Gen.ChoiceOrCallRecord{Bool}(false, -0.35667494393873245, NaN, true), :b => Gen.ChoiceOrCallRecord{Bool}(false, -0.5108256237659907, NaN, true), :d => Gen.ChoiceOrCallRecord{Bool}(true, -2.3025850929940455, NaN, true), :e => Gen.ChoiceOrCallRecord{Bool}(true, -0.35667494393873245, NaN, true)), Dict{Any, Trie{Any, Gen.ChoiceOrCallRecord}}()), false, -3.5267606046375013, 0.0, (), false), -0.9808292530117263, UnknownChange(), DynamicChoiceMap(Dict{Any, Any}(:b => true, :c => false), Dict{Any, Any}()))
Then get_choices(new_trace) will be:
│
├── :a : false
│
├── :b : false
│
├── :d : true
│
└── :e : true
and discard will be:
│
├── :b : true
│
└── :c : false
Note that the discard contains both the previous values of addresses that were overwritten, and the values for addresses that were in the previous trace but are no longer in the new trace. The weight (w) is computed as:
\[p(t; x) = 0.7 × 0.4 × 0.4 × 0.7 = 0.0784\\ p(t'; x') = 0.7 × 0.6 × 0.1 × 0.7 = 0.0294\\ w = \log p(t'; x')/p(t; x) = \log 0.0294/0.0784 = \log 0.375\]
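This example can be reproduced with Gen.jl (assumed installed) by first building the example trace with generate, constraining all four of its choices, and then calling update. Since no choice is sampled, the weight is deterministic:

```julia
using Gen

# The example model bar from the text.
@gen function bar()
    val = @trace(bernoulli(0.3), :a)
    if @trace(bernoulli(0.4), :b)
        val = @trace(bernoulli(0.6), :c) && val
    else
        val = @trace(bernoulli(0.1), :d) && val
    end
    val = @trace(bernoulli(0.7), :e) && val
    return val
end

# Reproduce the example trace exactly by constraining every choice it contains.
initial = choicemap((:a, false), (:b, true), (:c, false), (:e, true))
trace, _ = generate(bar, (), initial)

# Flip :b and supply a value for :d, which the new branch now requires.
constraints = choicemap((:b, false), (:d, true))
new_trace, w, retdiff, discard = update(trace, (), (), constraints)

println(exp(get_score(trace)))      # p(t; x), approximately 0.0784
println(exp(get_score(new_trace)))  # p(t'; x'), approximately 0.0294
println(exp(w))                     # approximately 0.375
println(discard[:b], " ", discard[:c])
```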
Example. Suppose we run update on the example trace, with the following constraints, which do not contain a value for :d:
│
└── :b : false
constraints = choicemap((:b, false))
(new_trace, w, _, discard) = update(trace, (), (), constraints)
(Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[], false, Union{Nothing, Some{Any}}[], Main.var"##bar#276", Bool[], false), Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(:a => Gen.ChoiceOrCallRecord{Bool}(false, -0.35667494393873245, NaN, true), :b => Gen.ChoiceOrCallRecord{Bool}(false, -0.5108256237659907, NaN, true), :d => Gen.ChoiceOrCallRecord{Bool}(false, -0.10536051565782628, NaN, true), :e => Gen.ChoiceOrCallRecord{Bool}(true, -0.35667494393873245, NaN, true)), Dict{Any, Trie{Any, Gen.ChoiceOrCallRecord}}()), false, -1.329536027301282, 0.0, (), false), 1.3217558399823193, UnknownChange(), DynamicChoiceMap(Dict{Any, Any}(:b => true, :c => false), Dict{Any, Any}()))
Since :b is constrained to false, the updated trace must now sample at address :d (note that address :e remains fixed). There are two possibilities for get_choices(new_trace):
The first choicemap:
│
├── :a : false
│
├── :b : false
│
├── :d : true
│
└── :e : true
occurs with probability 0.1. The second:
│
├── :a : false
│
├── :b : false
│
├── :d : false
│
└── :e : true
occurs with probability 0.9. Also, discard will be:
│
├── :b : true
│
└── :c : false
If the former case occurs and :d is assigned to true, then the weight (w) is computed as:
\[p(t; x) = 0.7 × 0.4 × 0.4 × 0.7 = 0.0784\\ p(t'; x') = 0.7 × 0.6 × 0.1 × 0.7 = 0.0294\\ q(t'; x', t + u) = 0.1\\ w = \log p(t'; x')/(p(t; x) q(t'; x', t + u)) = \log 0.0294/(0.0784 \cdot 0.1) = \log (3.75)\]
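In the latter case, where :d is false, the weight is the same: $p(t'; x') = 0.7 × 0.6 × 0.9 × 0.7 = 0.2646$ and $q(t'; x', t + u) = 0.9$, so $w = \log 0.2646/(0.0784 \cdot 0.9) = \log 3.75$. In this model the internal proposal for :d coincides with its prior, so the factor for :d cancels. A sketch that checks this with Gen.jl (assumed installed) by repeating the update several times:

```julia
using Gen

# The example model bar from the text.
@gen function bar()
    val = @trace(bernoulli(0.3), :a)
    if @trace(bernoulli(0.4), :b)
        val = @trace(bernoulli(0.6), :c) && val
    else
        val = @trace(bernoulli(0.1), :d) && val
    end
    val = @trace(bernoulli(0.7), :e) && val
    return val
end

initial = choicemap((:a, false), (:b, true), (:c, false), (:e, true))
trace, _ = generate(bar, (), initial)

# Constrain only :b; the new branch samples :d from bernoulli(0.1).
weights = Float64[]
d_values = Bool[]
for _ in 1:50
    new_trace, w, _, _ = update(trace, (), (), choicemap((:b, false)))
    push!(weights, w)
    push!(d_values, new_trace[:d])
end

# Analytic weights for the two cases; they coincide.
w_true = log((0.7 * 0.6 * 0.1 * 0.7) / (0.0784 * 0.1))
w_false = log((0.7 * 0.6 * 0.9 * 0.7) / (0.0784 * 0.9))
println((w_true, w_false, log(3.75)))
println(all(isapprox(w, log(3.75)) for w in weights))  # true
```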
Regenerate
The regenerate method takes a trace and generates an adjusted trace that is consistent with a change to the arguments to the function, and also generates new values for selected random choices.
Example. Suppose we run regenerate on the example trace, with selection :a and :b:
(new_trace, w, _) = regenerate(trace, (), (), select(:a, :b))
(Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[], false, Union{Nothing, Some{Any}}[], Main.var"##bar#276", Bool[], false), Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(:a => Gen.ChoiceOrCallRecord{Bool}(false, -0.35667494393873245, NaN, true), :b => Gen.ChoiceOrCallRecord{Bool}(false, -0.5108256237659907, NaN, true), :d => Gen.ChoiceOrCallRecord{Bool}(false, -0.10536051565782628, NaN, true), :e => Gen.ChoiceOrCallRecord{Bool}(true, -0.35667494393873245, NaN, true)), Dict{Any, Trie{Any, Gen.ChoiceOrCallRecord}}()), false, -1.329536027301282, 0.0, (), false), 0.0, UnknownChange())
Then, a new value for :a will be sampled from bernoulli(0.3), and a new value for :b will be sampled from bernoulli(0.4). If the new value for :b is true, then the previous value for :c (false) will be retained. If the new value for :b is false, then a new value for :d will be sampled from bernoulli(0.1). The previous value for :e will always be retained. Suppose the new value for :a is true, and the new value for :b is true. Then get_choices(new_trace) will be:
│
├── :a : true
│
├── :b : true
│
├── :c : false
│
└── :e : true
The weight (w) is $\log 1 = 0$.
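The following Gen.jl sketch (package assumed installed) repeats this regenerate call many times. The selected choices :a and :b, and any newly introduced :d, are sampled from their priors, and the unselected choices keep their values under unchanged arguments, so every weight should be 0:

```julia
using Gen

# The example model bar from the text.
@gen function bar()
    val = @trace(bernoulli(0.3), :a)
    if @trace(bernoulli(0.4), :b)
        val = @trace(bernoulli(0.6), :c) && val
    else
        val = @trace(bernoulli(0.1), :d) && val
    end
    val = @trace(bernoulli(0.7), :e) && val
    return val
end

# Build the example trace exactly, then regenerate :a and :b many times.
initial = choicemap((:a, false), (:b, true), (:c, false), (:e, true))
trace, _ = generate(bar, (), initial)

weights = Float64[]
kept_e = Bool[]
for _ in 1:50
    new_trace, w, _ = regenerate(trace, (), (), select(:a, :b))
    push!(weights, w)
    push!(kept_e, new_trace[:e])   # :e is not selected
    if new_trace[:b]
        # :b is still true, so the unselected :c keeps its old value.
        @assert new_trace[:c] == false
    end
end

println(all(w -> isapprox(w, 0.0; atol=1e-12), weights))  # true
println(all(kept_e))                                       # true
```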
Argdiffs
In addition to the input trace, and other arguments that indicate how to adjust the trace, each of these methods also accepts an args argument and an argdiffs argument, both of which are tuples. The args argument contains the new arguments to the generative function, which may differ from the previous arguments to the generative function (which can be retrieved by applying get_args
to the previous trace). In many cases, the adjustment to the execution specified by the other arguments to these methods is 'small' and only affects certain parts of the computation. Therefore, it is often possible to generate the new trace and the appropriate log probability ratios required for these methods without revisiting every state of the computation of the generative function.
To enable this, the argdiffs argument provides additional information about the difference between each of the previous arguments to the generative function and its new argument value. This argdiff information permits the implementation of the update method to avoid inspecting the entire argument data structure to identify which parts were updated. Note that the correctness of the argdiff is in general not verified by Gen; passing incorrect argdiff information may result in incorrect behavior.
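The trace update methods for all generative functions above should accept at least the following types of argdiffs: NoChange, which indicates that the argument has not changed, and UnknownChange, which indicates that the argument may have changed in an arbitrary way.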
Generative functions may also be able to process more specialized diff data types for each of their arguments, which allow more precise information about the difference to be supplied.
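A minimal sketch of passing argdiffs, assuming Gen.jl is installed: only the first argument of foo changes, so it is marked UnknownChange() and the second NoChange(). No constraints are given, so :z keeps its value, and because the distribution of :z does not depend on the arguments, the weight is 0:

```julia
using Gen

@gen function foo(a, b=0)
    if @trace(bernoulli(0.5), :z)
        return a + b + 1
    else
        return a + b
    end
end

# Start from a trace with :z fixed to true, so the return value is 2 + 4 + 1 = 7.
trace, _ = generate(foo, (2, 4), choicemap((:z, true)))

# Change the first argument only, and describe the change with argdiffs.
new_args = (3, 4)
argdiffs = (UnknownChange(), NoChange())
new_trace, w, retdiff, discard = update(trace, new_args, argdiffs, choicemap())

println(get_args(new_trace))    # (3, 4)
println(get_retval(new_trace))  # 8, since :z is still true
println(w)                      # 0.0
```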
Retdiffs
To enable generative functions that invoke other functions to efficiently make use of incremental computation, the trace update methods of generative functions also return a retdiff value, which provides information about the difference between the return value of the previous trace and the return value of the new trace.
Differentiable Programming
The trace of a generative function may support computation of gradients of its log probability with respect to some subset of (i) its arguments, (ii) values of random choices, and (iii) any of its trainable parameters (see below).
To compute gradients with respect to the arguments as well as certain selected random choices, use choice_gradients.
To compute gradients with respect to the arguments, and to increment a stateful gradient accumulator for the trainable parameters of the generative function, use accumulate_param_gradients!.
A generative function statically reports whether or not it is able to compute gradients with respect to each of its arguments, through the function has_argument_grads.
Trainable parameters
The trainable parameters of a generative function are (unlike arguments and random choices) state of the generative function itself, and are not contained in the trace. Generative functions that have trainable parameters maintain gradient accumulators for these parameters, which are incremented by the gradient induced by a given trace through a call to accumulate_param_gradients!. Users then use these accumulated gradients to update the values of the trainable parameters.
Return value gradient
The set of elements (either arguments, random choices, or trainable parameters) for which gradients are available is called the gradient source set. If the return value of the function is conditionally dependent on any element in the gradient source set given the arguments and values of all other random choices, for all possible traces of the function, then the generative function requires a return value gradient to compute gradients with respect to elements of the gradient source set. This static property of the generative function is reported by accepts_output_grad.
API
The following GFI methods should be implemented for a Trace:
Gen.Trace (Type)
Trace
Abstract type for a trace of a generative function.
Gen.get_args (Function)
get_args(trace)
Return the argument tuple for a given execution.
Example:
args::Tuple = get_args(trace)
Gen.get_retval (Function)
get_retval(trace)
Return the return value of the given execution.
Example for a generative function with return type T:
retval::T = get_retval(trace)
Gen.get_choices (Function)
get_choices(trace)
Return a value implementing the assignment (choice map) interface.
Note that the value of any non-addressed randomness is not externally accessible.
Example:
choices::ChoiceMap = get_choices(trace)
z_val = choices[:z]
Gen.get_score (Function)
get_score(trace)
Return:
\[\log \frac{p(r, t; x)}{q(r; x, t)}\]
When there is no non-addressed randomness, this simplifies to the log probability $\log p(t; x)$.
Gen.get_gen_fn (Function)
gen_fn::GenerativeFunction = get_gen_fn(trace)
Return the generative function that produced the given trace.
Base.getindex (Function)
value = getindex(trace::Trace, addr)
Get the value of the random choice, or auxiliary state (e.g. the return value of an inner function call), at address addr.
retval = getindex(trace::Trace)
retval = trace[]
Synonym for get_retval.
The following GFI methods should be implemented for a GenerativeFunction and its associated Trace datatype:
Gen.GenerativeFunction (Type)
GenerativeFunction{T,U <: Trace}
Abstract type for a generative function with return value type T and trace type U.
Gen.simulate (Function)
trace = simulate(gen_fn, args)
Execute the generative function and return the trace.
Given arguments (args), sample $(r, t) \sim p(\cdot; x)$ and return a trace with choice map $t$.
If gen_fn has optional trailing arguments (i.e., default values are provided), the optional arguments can be omitted from the args tuple. The generated trace will have default values filled in.
Gen.generate (Function)
(trace::U, weight) = generate(gen_fn::GenerativeFunction{T,U}, args::Tuple)
Return a trace of a generative function.
(trace::U, weight) = generate(gen_fn::GenerativeFunction{T,U}, args::Tuple,
constraints::ChoiceMap)
Return a trace of a generative function that is consistent with the given constraints on the random choices.
Given arguments $x$ (args) and assignment $u$ (constraints) (which is empty for the first form), sample $t \sim q(\cdot; x, u)$ and $r \sim q(\cdot; x, t)$, and return the trace $(x, r, t)$ (trace). Also return the weight (weight):
\[\log \frac{p(r, t; x)}{q(t; x, u) q(r; x, t)}\]
If gen_fn has optional trailing arguments (i.e., default values are provided), the optional arguments can be omitted from the args tuple. The generated trace will have default values filled in.
Example without constraints:
(trace, weight) = generate(foo, (2, 4))
Example with constraint that address :z takes value true:
(trace, weight) = generate(foo, (2, 4), choicemap((:z, true)))
Gen.update (Function)
(new_trace, weight, retdiff, discard) = update(trace, args::Tuple, argdiffs::Tuple,
constraints::ChoiceMap)
Update a trace by changing the arguments and/or providing new values for some existing random choice(s) and values for some newly introduced random choice(s).
Given a previous trace $(x, r, t)$ (trace), new arguments $x'$ (args), and a map $u$ (constraints), return a new trace $(x', r', t')$ (new_trace) that is consistent with $u$. The values of choices in $t'$ are either copied from $t$ or from $u$ (with $u$ taking precedence) or are sampled from the internal proposal distribution. All choices in $u$ must appear in $t'$. Also return an assignment $v$ (discard) containing the choices in $t$ that were overwritten by values from $u$, and any choices in $t$ whose address does not appear in $t'$. Sample $t' \sim q(\cdot; x', t + u)$, and $r' \sim q(\cdot; x', t')$, where $t + u$ is the choice map obtained by merging $t$ and $u$ with $u$ taking precedence for overlapping addresses. Also return a weight (weight):
\[\log \frac{p(r', t'; x')}{q(r'; x', t') q(t'; x', t + u)} - \log \frac{p(r, t; x)}{q(r; x, t)}\]
Note that argdiffs is expected to be the same length as args. If the function that generated trace supports default values for trailing arguments, then these arguments can be omitted from args and argdiffs. Note that if the original trace was generated using non-default argument values, then for each optional argument that is omitted, the old value will be overwritten by the default argument value in the updated trace.
(new_trace, weight, retdiff, discard) = update(trace, constraints::ChoiceMap)
Shorthand variant of update which assumes the arguments are unchanged.
Gen.regenerate (Function)
(new_trace, weight, retdiff) = regenerate(trace, args::Tuple, argdiffs::Tuple,
selection::Selection)
Update a trace by changing the arguments and/or randomly sampling new values for selected random choices using the internal proposal distribution family.
Given a previous trace $(x, r, t)$ (trace), new arguments $x'$ (args), and a set of addresses $A$ (selection), return a new trace $(x', r', t')$ (new_trace) such that $t'$ agrees with $t$ on all addresses not in $A$ ($t$ and $t'$ may have different sets of addresses). Let $u$ denote the restriction of $t$ to the complement of $A$. Sample $t' \sim q(\cdot; x', u)$ and sample $r' \sim q(\cdot; x', t')$. Return the new trace $(x', r', t')$ (new_trace) and the weight (weight):
\[\log \frac{p(r', t'; x')}{q(t'; x', u) q(r'; x', t')} - \log \frac{p(r, t; x)}{q(t; x, u') q(r; x, t)}\]
where $u'$ is the restriction of $t'$ to the complement of $A$.
Note that argdiffs is expected to be the same length as args. If the function that generated trace supports default values for trailing arguments, then these arguments can be omitted from args and argdiffs. Note that if the original trace was generated using non-default argument values, then for each optional argument that is omitted, the old value will be overwritten by the default argument value in the regenerated trace.
(new_trace, weight, retdiff) = regenerate(trace, selection::Selection)
Shorthand variant of regenerate which assumes the arguments are unchanged.
Gen.project (Function)
weight = project(trace::U, selection::Selection)
Estimate the probability that the selected choices take the values they do in a trace.
Given a trace $(x, r, t)$ (trace) and a set of addresses $A$ (selection), let $u$ denote the restriction of $t$ to $A$. Return the weight (weight):
\[\log \frac{p(r, t; x)}{q(t; x, u) q(r; x, t)}\]
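As an illustration with Gen.jl (assumed installed), consider project on the example trace of bar. Assuming that, in the built-in modeling language, the internal proposal samples unselected choices from their priors, the ratio above reduces to the probability of the selected choices alone: $0.7 × 0.4 = 0.28$ for :a = false and :b = true. Selecting everything should recover get_score, and selecting nothing should give 0:

```julia
using Gen

# The example model bar from the text.
@gen function bar()
    val = @trace(bernoulli(0.3), :a)
    if @trace(bernoulli(0.4), :b)
        val = @trace(bernoulli(0.6), :c) && val
    else
        val = @trace(bernoulli(0.1), :d) && val
    end
    val = @trace(bernoulli(0.7), :e) && val
    return val
end

initial = choicemap((:a, false), (:b, true), (:c, false), (:e, true))
trace, _ = generate(bar, (), initial)

w_ab = project(trace, select(:a, :b))      # only :a and :b are scored
w_all = project(trace, AllSelection())     # every choice is scored
w_none = project(trace, EmptySelection())  # nothing is scored

println(exp(w_ab))                 # approximately 0.28
println(w_all ≈ get_score(trace))  # true
println(w_none)                    # 0.0
```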
Gen.propose (Function)
(choices, weight, retval) = propose(gen_fn::GenerativeFunction, args::Tuple)
Sample an assignment and compute the probability of proposing that assignment.
Given arguments $x$ (args), sample $t \sim p(\cdot; x)$ and $r \sim p(\cdot | t; x)$, and return $t$ (choices) and the weight (weight):
\[\log \frac{p(r, t; x)}{q(r; x, t)}\]
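Since the weight returned by propose has the same form as the one returned by assess, a quick consistency check is to score a proposed choice map with assess. A Gen.jl sketch (package assumed installed); bar has no non-addressable randomness, so both weights should equal $\log p(t; x)$:

```julia
using Gen

# The example model bar from the text.
@gen function bar()
    val = @trace(bernoulli(0.3), :a)
    if @trace(bernoulli(0.4), :b)
        val = @trace(bernoulli(0.6), :c) && val
    else
        val = @trace(bernoulli(0.1), :d) && val
    end
    val = @trace(bernoulli(0.7), :e) && val
    return val
end

# Sample a choice map from bar together with its weight.
choices, w, retval = propose(bar, ())

# Score the same choice map.
w2, retval2 = assess(bar, (), choices)

println(isapprox(w, w2))    # true
println(retval == retval2)  # true
```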
Gen.assess (Function)
(weight, retval) = assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap)
Return the probability of proposing an assignment.
Given arguments $x$ (args) and an assignment $t$ (choices) such that $p(t; x) > 0$, sample $r \sim q(\cdot; x, t)$ and return the weight (weight):
\[\log \frac{p(r, t; x)}{q(r; x, t)}\]
It is an error if $p(t; x) = 0$.
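For example, with Gen.jl (assumed installed), assessing the complete example choice map of bar should return $\log 0.0784$ and the return value false. We assume that, in the built-in modeling language, a choice map that omits a choice the function needs (here :b) raises an error rather than returning a weight; the sketch only checks that some error is thrown:

```julia
using Gen

# The example model bar from the text.
@gen function bar()
    val = @trace(bernoulli(0.3), :a)
    if @trace(bernoulli(0.4), :b)
        val = @trace(bernoulli(0.6), :c) && val
    else
        val = @trace(bernoulli(0.1), :d) && val
    end
    val = @trace(bernoulli(0.7), :e) && val
    return val
end

complete = choicemap((:a, false), (:b, true), (:c, false), (:e, true))
w, retval = assess(bar, (), complete)
println(exp(w))  # approximately 0.0784
println(retval)  # false, since :a is false

# :b is missing, so bar cannot be scored with this choice map.
threw = try
    assess(bar, (), choicemap((:a, false)))
    false
catch
    true
end
println(threw)   # true
```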
Generative functions that support gradient computation with respect to arguments or trainable parameters should implement the following static properties:
Gen.has_argument_grads (Function)
bools::Tuple = has_argument_grads(gen_fn::Union{GenerativeFunction,Distribution})
Return a tuple of booleans indicating whether a gradient is available for each of its arguments.
Gen.accepts_output_grad (Function)
req::Bool = accepts_output_grad(gen_fn::GenerativeFunction)
Return a boolean indicating whether the return value is dependent on any of the gradient source elements for any trace.
The gradient source elements are:
Any argument whose position is true in has_argument_grads
Any trainable parameter
Random choices made at a set of addresses that are selectable by choice_gradients
Gen.choice_gradients (Function)
(arg_grads, choice_values, choice_grads) = choice_gradients(
    trace, selection=EmptySelection(), retgrad=nothing)
Given a previous trace $(x, t)$ (trace) and a gradient with respect to the return value $∇_y J$ (retgrad), return the following gradient (arg_grads) with respect to the arguments $x$:
\[∇_x \left( \log p(t; x) + J \right)\]
The length of arg_grads will be equal to the number of arguments to the function that generated trace (including any optional trailing arguments). If an argument is not annotated with (grad), the corresponding value in arg_grads will be nothing.
Also given a set of addresses $A$ (selection) that are continuous-valued random choices, return the following gradient (choice_grads) with respect to the values of these choices:
\[∇_A \left( \log p(t; x) + J \right)\]
The gradient is represented as a choicemap whose value at (hierarchical) address addr is $∂(\log p(t; x) + J)/∂t[\texttt{addr}]$.
Also return the choicemap (choice_values) that is the restriction of $t$ to $A$.
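A minimal sketch with Gen.jl (assumed installed), using a hypothetical one-argument model whose argument mu is annotated with (grad). For y ~ normal(mu, 1), $\log p = -(y - \mu)^2/2 + \text{const}$, so with no return value gradient we expect $y - \mu$ for the argument and $\mu - y$ for the choice :y:

```julia
using Gen

# Hypothetical model: the argument mu is annotated with (grad).
@gen function model((grad)(mu::Float64))
    y = @trace(normal(mu, 1.0), :y)
    return y
end

mu = 0.5
trace = simulate(model, (mu,))
y = trace[:y]

# The model is not marked (grad) itself, so accepts_output_grad(model) is false
# and no return value gradient is passed.
arg_grads, choice_values, choice_grads = choice_gradients(trace, select(:y), nothing)

println(isapprox(arg_grads[1], y - mu; atol=1e-10))      # true
println(isapprox(choice_grads[:y], mu - y; atol=1e-10))  # true
println(choice_values[:y] == y)                          # true
```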
Gen.accumulate_param_gradients! (Function)
arg_grads = accumulate_param_gradients!(trace, retgrad=nothing, scale_factor=1.)
Increment gradient accumulators for parameters by the gradient of the log-probability of the trace, optionally scaled, and return the gradient with respect to the arguments (not scaled).
Given a previous trace $(x, t)$ (trace) and a gradient with respect to the return value $∇_y J$ (retgrad), return the following gradient (arg_grads) with respect to the arguments $x$:
\[∇_x \left( \log p(t; x) + J \right)\]
The length of arg_grads will be equal to the number of arguments to the function that generated trace (including any optional trailing arguments). If an argument is not annotated with (grad), the corresponding value in arg_grads will be nothing.
Also increment the gradient accumulators for the trainable parameters $Θ$ of the function by:
\[s * ∇_Θ \left( \log p(t; x) + J \right)\]
where $s$ is scale_factor.
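A sketch of accumulating parameter gradients with Gen.jl (assumed installed), using a hypothetical linear-Gaussian model lin with one trainable parameter theta. It relies on init_param! and get_param_grad, which we assume are available in the installed Gen version. With theta = 0, x = 2, and observed y = 1, the gradient of the log probability with respect to theta is $(y - \theta x) x = 2$:

```julia
using Gen

# Hypothetical model with one trainable parameter theta.
@gen function lin(x::Float64)
    @param theta::Float64
    y = @trace(normal(theta * x, 1.0), :y)
    return y
end

# Set theta = 0; this also sets its gradient accumulator to zero.
init_param!(lin, :theta, 0.0)

# Observe y = 1 at input x = 2.
trace, _ = generate(lin, (2.0,), choicemap((:y, 1.0)))

# d/dtheta log p = (y - theta * x) * x = (1 - 0) * 2 = 2.
arg_grads = accumulate_param_gradients!(trace, nothing, 1.0)
println(get_param_grad(lin, :theta))  # 2.0

# A second call with scale_factor 0.5 adds 0.5 * 2 = 1 more.
accumulate_param_gradients!(trace, nothing, 0.5)
println(get_param_grad(lin, :theta))  # 3.0
println(arg_grads[1] === nothing)     # true: x is not annotated with (grad)
```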
Gen.get_params (Function)
get_params(gen_fn::GenerativeFunction)
Return an iterable over the trainable parameters of the generative function.