Trace Translators
While generative functions define probability distributions on traces, Trace Translators convert from one space of traces to another space of traces. Trace translators are building blocks of inference programs that utilize multiple model representations, like Involutive MCMC.
Trace translators are significantly more general than Bijectors. Trace translators can (i) convert between spaces of traces that include a mix of numeric and discrete random choices, as well as stochastic control flow, and (ii) convert between spaces for which there is no one-to-one correspondence (e.g. between models of different dimensionality, or between discrete and continuous models). Bijectors are limited to deterministic transformations between real-valued vectors of constant dimension.
Deterministic Trace Translators
Inference programs manipulate traces, but they also keep track of probabilities and probability densities associated with these traces. Suppose we have two generative functions p1
and p2
. Given a trace t2
of p2
we can easily compute the probability (or probability density) that the trace would have been generated by p2
using get_score(t2)
. But suppose we want to construct the trace of p2
by first sampling a trace t1
of p1
and then applying a deterministic transformation to that trace to obtain t2
. How can we compute the probability that a trace t2
would have been produced by this process? This probability is needed if, for example, p2
defines a probabilistic model and we want to use p1
as a proposal distribution within importance sampling. If we produce t2
via an arbitrary deterministic transformation of the random choices in t1
, then computing the necessary probability is difficult.
If we restrict ourselves to deterministic transformations that are bijections (one-to-one correspondences) from the set of traces of p1
to the set of traces of p2
, then the problem is much simplified. If the transformation is a bijection this means that (i) each trace of p1
gets mapped to a different trace of p2
, and (ii) for every trace of p2
there is some trace of p1
that maps to it. Bijective transformations between traces are useful components of inference programs because the probability that a given trace t2
of p2
would have been produced by first sampling from p1
and then applying the transform can be computed simply as the probability that p1
would produce the (unique) trace t1
that gets mapped to the given trace by the transform. Conceptually, bijective trace transforms preserve probability. When trace transforms operate on traces with continuous random choices, computing probability densities of the transformed traces requires computing a Jacobian associated with the continuous part of the transformation.
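The role of the Jacobian can be stated with the standard change-of-variables formula. As a sketch, write $f$ for a bijective transform acting on the continuous choices, $\pi_1$ for the density of p1, and $\pi_f$ for the density of the transformed trace (these symbols are notation introduced here, not Gen identifiers):

```latex
\pi_f(t_2) \;=\; \pi_1(t_1)\,\bigl|\det J_f(t_1)\bigr|^{-1},
\qquad t_1 = f^{-1}(t_2),
\qquad J_f(t_1) = \frac{\partial f(t_1)}{\partial t_1}
```

When all choices are discrete there is no Jacobian factor, which recovers the statement that bijective transforms preserve probability.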
Gen provides a DSL for expressing bijections between spaces of traces, called the Trace Transform DSL. We introduce this DSL via an example. Below are two generative functions. The first samples polar coordinates and the second uses cartesian coordinates.
@gen function p1()
    r ~ inv_gamma(1, 1)
    # full circle, so that the polar coordinates cover the whole plane
    theta ~ uniform(-pi, pi)
end
@gen function p2()
    x ~ normal(0, 1)
    y ~ normal(0, 1)
end
Defining a trace transform with the Trace Transform DSL
The following trace transform DSL program defines a transformation (called f
) that transforms traces of p1
into traces of p2
:
@transform f (t1) to (t2) begin
    r = @read(t1[:r], :continuous)
    theta = @read(t1[:theta], :continuous)
    @write(t2[:x], r * cos(theta), :continuous)
    @write(t2[:y], r * sin(theta), :continuous)
end
This transform reads values of random choices in the input trace (t1
) at specific addresses (indicated by the syntax t1[addr]
) using @read
and writes values to the output trace (t2
) using @write
. Each read and write operation is labeled with whether the random choice is discrete or continuous. The section Trace Transform DSL defines the DSL in more detail.
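Because every read and write in this transform is continuous, the density of the output trace involves the Jacobian of the map from (r, theta) to (x, y). The following plain-Julia sketch does not use Gen at all: it re-implements the map as an ordinary function and estimates the Jacobian by finite differences, to show that the absolute determinant equals r. The helper names `polar_to_cartesian` and `numeric_jacobian` are hypothetical and exist only for this illustration.

```julia
using LinearAlgebra

# Plain-Julia stand-in for the transform f: (r, theta) -> (x, y).
polar_to_cartesian(v) = [v[1] * cos(v[2]), v[1] * sin(v[2])]

# Central finite-difference Jacobian of g at v (hypothetical helper).
function numeric_jacobian(g, v; h=1e-6)
    n = length(v)
    J = zeros(length(g(v)), n)
    for k in 1:n
        e = zeros(n)
        e[k] = h
        J[:, k] = (g(v .+ e) .- g(v .- e)) ./ (2h)
    end
    return J
end

r, theta = 2.0, 0.3
J = numeric_jacobian(polar_to_cartesian, [r, theta])
abs_det = abs(det(J))
println("abs(det(J)) = ", abs_det, ", r = ", r)
```

The two printed numbers should agree up to finite-difference error, matching the analytic determinant r cos(theta)^2 + r sin(theta)^2 = r.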
It is usually a good idea to write the inverse of the bijection. The inverse can provide a dynamic check that the transform truly is a bijection. The inverse of the above transformation is:
@transform finv (t2) to (t1) begin
    x = @read(t2[:x], :continuous)
    y = @read(t2[:y], :continuous)
    r = sqrt(x^2 + y^2)
    @write(t1[:r], r, :continuous)
    @write(t1[:theta], atan(y, x), :continuous)
end
We can inform Gen that two transforms are inverses of one another using pair_bijections!
:
pair_bijections!(f, finv)
Wrapping a trace transform in a trace translator
Note that the transform DSL code does not specify what the two generative functions are, or what the arguments to these generative functions are. This information will be required for computing probabilities and probability densities of traces. We provide this information by constructing a Trace Translator that wraps the transform along with this information:
translator = DeterministicTraceTranslator(p2, (), choicemap(), f)
We can then apply the translator to a trace of p1
using function call syntax. The translator returns a trace of p2
and a log-weight that we can use to compute the probability (density) of the resulting trace:
t2, log_weight = translator(t1)
Specifically, the log probability (density) that the trace t2
was produced by first sampling t1 = simulate(p1, ())
and then applying the trace translator, is:
get_score(t1) + log_weight
Let's unpack the previous few code blocks in more detail. First, note that we did not pass in the source generative function (p1
) or its arguments (()
) when we constructed the trace translator. This is because this information will be attached to the input trace t1
itself. We did need to pass in the target generative function (p2
) and its arguments (()
) however, because this information is not included in t1
.
In this case, because continuous random choices are involved, the probabilities are probability densities, and the trace translator used automatic differentiation of the code in the transform f
to compute a change-of-variables Jacobian that is necessary to compute the correct probability density of the new trace t2
.
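Putting the pieces of this section together, the following is a minimal end-to-end sketch, assuming Gen is installed. It repeats the definitions from above so that it can run on its own; the range of theta is set to the full circle here, which is an assumption of this sketch made so that the polar coordinates can reach every point of the plane.

```julia
using Gen

@gen function p1()
    r ~ inv_gamma(1, 1)
    theta ~ uniform(-pi, pi)   # assumption of this sketch: full circle
end

@gen function p2()
    x ~ normal(0, 1)
    y ~ normal(0, 1)
end

@transform f (t1) to (t2) begin
    r = @read(t1[:r], :continuous)
    theta = @read(t1[:theta], :continuous)
    @write(t2[:x], r * cos(theta), :continuous)
    @write(t2[:y], r * sin(theta), :continuous)
end

@transform finv (t2) to (t1) begin
    x = @read(t2[:x], :continuous)
    y = @read(t2[:y], :continuous)
    @write(t1[:r], sqrt(x^2 + y^2), :continuous)
    @write(t1[:theta], atan(y, x), :continuous)
end

pair_bijections!(f, finv)

# Only the target function, its arguments, and its observations are supplied;
# the source function and arguments travel with the input trace.
translator = DeterministicTraceTranslator(p2, (), choicemap(), f)

t1 = simulate(p1, ())
t2, log_weight = translator(t1)

println("r = ", t1[:r], ", theta = ", t1[:theta])
println("x = ", t2[:x], ", y = ", t2[:y], ", log_weight = ", log_weight)
```

The values of x and y in the output trace should equal r * cos(theta) and r * sin(theta) computed from the input trace.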
Observations
Typically, there are observations associated with one or both of the generative functions involved, and we have values for these in a choice map, so we do not want the trace translator to be responsible for populating the values of these observed random choices. For example, suppose we want to condition p2
on an observed random choice z
:
@gen function p2()
    x ~ normal(0, 1)
    y ~ normal(0, 1)
    z ~ normal(x + y, 0.1)
end
observations = choicemap()
observations[:z] = 2.3
When p2
has observations, these can be passed in as an additional argument to the trace translator constructor:
translator = DeterministicTraceTranslator(p2, (), observations, f)
Discrete random choices and stochastic control flow
Trace transforms and trace translators interoperate seamlessly with discrete random choices and stochastic control flow. Both the input and the output traces can contain a mix of discrete and continuous choices, and arbitrary stochastic control flow. Consider the following simple example, where we use two different discrete representations to represent probability distributions on the integers 0-7:
@gen function p1()
    bit1 ~ bernoulli(0.5)
    bit2 ~ bernoulli(0.5)
    bit3 ~ bernoulli(0.5)
end
@gen function p2()
    # eight outcomes, one per integer 0-7; categorical returns an index in 1-8
    n ~ categorical([0.1, 0.1, 0.1, 0.2, 0.2, 0.1, 0.1, 0.1])
end
We define the forward and inverse transforms:
@transform f (t1) to (t2) begin
    n = (
        @read(t1[:bit1], :discrete) * 1 +
        @read(t1[:bit2], :discrete) * 2 +
        @read(t1[:bit3], :discrete) * 4)
    # shift by one because categorical outcomes start at 1
    @write(t2[:n], n + 1, :discrete)
end
@transform finv (t2) to (t1) begin
    # pad=3 guarantees three digits, least-significant bit first
    bits = digits(@read(t2[:n], :discrete) - 1, base=2, pad=3)
    @write(t1[:bit1], bits[1] == 1, :discrete)
    @write(t1[:bit2], bits[2] == 1, :discrete)
    @write(t1[:bit3], bits[3] == 1, :discrete)
end
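The correspondence between three bits and the integers 0-7 can be checked exhaustively without Gen. The sketch below uses ordinary Julia functions as stand-ins for the two transforms; the names `bits_to_int` and `int_to_bits` are hypothetical. It works with the integers 0-7 directly, whereas a transform that writes to a categorical choice has to account for categorical outcomes starting at 1.

```julia
# Plain-Julia stand-ins for the forward and inverse maps.
bits_to_int(b1::Bool, b2::Bool, b3::Bool) = b1 * 1 + b2 * 2 + b3 * 4

function int_to_bits(n::Int)
    bits = digits(n, base=2, pad=3)   # least-significant bit first, always 3 digits
    return (bits[1] == 1, bits[2] == 1, bits[3] == 1)
end

# Enumerate all eight bit patterns and the integers they map to.
all_bits = vec([(b1, b2, b3) for b1 in (false, true), b2 in (false, true), b3 in (false, true)])
codes = [bits_to_int(b...) for b in all_bits]

# Every integer 0-7 is hit exactly once, and decoding undoes encoding.
covers_all = sort(codes) == collect(0:7)
round_trips = all(int_to_bits(bits_to_int(b...)) == b for b in all_bits)
println("covers 0-7: ", covers_all, ", round trips: ", round_trips)
```

The `pad=3` argument matters: without it, `digits` returns fewer than three entries for small integers, and indexing the third bit would fail.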
Here is an example that includes discrete random choices, stochastic control flow, and continuous random choices.
@gen function p1()
    if ({:branch} ~ bernoulli(0.5))
        x ~ normal(0, 1)
    else
        other ~ categorical([0.3, 0.7])
    end
end
@gen function p2()
    k ~ uniform_discrete(1, 4)
    if k <= 2
        y ~ gamma(1, 1)
    end
end
Note that transformations between spaces of traces need not be intuitive (although they probably should be)! Try to convince yourself that the functions below are indeed a pair of bijections between the traces of these two generative functions.
@transform f (t1) to (t2) begin
    if @read(t1[:branch], :discrete)
        x = @read(t1[:x], :continuous)
        if x > 0
            @write(t2[:k], 2, :discrete)
        else
            @write(t2[:k], 1, :discrete)
        end
        @write(t2[:y], abs(x), :continuous)
    else
        other = @read(t1[:other], :discrete)
        @write(t2[:k], (other == 1) ? 3 : 4, :discrete)
    end
end
@transform finv (t2) to (t1) begin
    k = @read(t2[:k], :discrete)
    if k <= 2
        @write(t1[:branch], true, :discrete)
        y = @read(t2[:y], :continuous)
        @write(t1[:x], (k == 1) ? -y : y, :continuous)
    else
        @write(t1[:branch], false, :discrete)
        @write(t1[:other], (k == 3) ? 1 : 2, :discrete)
    end
end
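One way to convince yourself is to mimic the two transforms on plain data and check that they undo each other. In the sketch below, traces are modelled as dictionaries from address to value, which is an assumption made only for illustration (real Gen traces are not dictionaries), and the function names `forward` and `backward` are hypothetical. The inverse also sets the branch choice, since an output trace of p1 needs a value at every address it visits.

```julia
# Stand-in for f: maps a p1-style trace to a p2-style trace.
function forward(t1::Dict)
    t2 = Dict{Symbol,Any}()
    if t1[:branch]
        x = t1[:x]
        t2[:k] = x > 0 ? 2 : 1      # the sign of x is stored in k
        t2[:y] = abs(x)             # the magnitude of x is stored in y
    else
        t2[:k] = t1[:other] == 1 ? 3 : 4
    end
    return t2
end

# Stand-in for finv: maps a p2-style trace back to a p1-style trace.
function backward(t2::Dict)
    t1 = Dict{Symbol,Any}()
    k = t2[:k]
    if k <= 2
        t1[:branch] = true
        t1[:x] = k == 1 ? -t2[:y] : t2[:y]
    else
        t1[:branch] = false
        t1[:other] = k == 3 ? 1 : 2
    end
    return t1
end

# One example per control-flow path.
examples = [
    Dict{Symbol,Any}(:branch => true, :x => 1.5),
    Dict{Symbol,Any}(:branch => true, :x => -0.25),
    Dict{Symbol,Any}(:branch => false, :other => 1),
    Dict{Symbol,Any}(:branch => false, :other => 2),
]
round_trips = [backward(forward(t)) == t for t in examples]
println(round_trips)
```

The single point x = 0 is not distinguished from small negative values by the sign test, but it has probability zero under normal(0, 1).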
General Trace Translators
Note that for two arbitrary generative functions p1
and p2
there may not exist any one-to-one correspondence between their spaces of traces. For example, consider a generative function p1
that samples points within the unit square $[0, 1]^2$:
@gen function p1()
    x ~ uniform(0, 1)
    y ~ uniform(0, 1)
end
and another generative function p2
that samples one of 100 possible discrete values, each value representing one cell of the unit square:
@gen function p2()
    i ~ uniform_discrete(1, 10) # interval [(i-1)/10, i/10]
    j ~ uniform_discrete(1, 10) # interval [(j-1)/10, j/10]
end
There is no one-to-one correspondence between the spaces of traces of these two generative functions: The first is an uncountably infinite set, and the other is a finite set with 100 elements in it.
However, there is an intuitive notion of correspondence that we would like to be able to encode. Each discrete cell $(i, j)$ corresponds to a subset of the unit square $[(i - 1)/10, i/10] \times [(j-1)/10, j/10]$.
We can express this correspondence (and any correspondence between two arbitrary generative functions) by introducing two auxiliary generative functions q1
and q2
. The first function q1
will take a trace of p1
as input, and the second function q2
will take a trace of p2
as input. Then, instead of a transformation between traces of p1
and traces of p2
our trace transform will transform between (i) the space of pairs of traces of p1
and q1
and (ii) the space of pairs of traces of p2
and q2
. We construct q1
and q2
so that the two spaces have the same size, and a one-to-one correspondence is possible.
For our example above, we construct q2
to sample the coordinate ($[0, 0.1]^2$) relative to the cell. We construct q1
to be empty: there is already a mapping from each trace of p1
to a trace of p2
that simply identifies what cell $(i, j)$ a given point in $[0, 1]^2$ is in, so no extra random choices are needed.
@gen function q1(p1_trace)
end
@gen function q2(p2_trace)
    dx ~ uniform(0.0, 0.1)
    dy ~ uniform(0.0, 0.1)
end
Trace transforms between pairs of traces
To handle general trace translators that require auxiliary probability distributions, the trace transform DSL supports defining transformations between pairs of traces. For example, the following defines a trace transform that maps from pairs of traces of p1
and q1
to pairs of traces of p2
and q2
:
@transform f (p1_trace, q1_trace) to (p2_trace, q2_trace) begin
    x = @read(p1_trace[:x], :continuous)
    y = @read(p1_trace[:y], :continuous)
    # cell indices are integers, to match the uniform_discrete choices in p2
    i = ceil(Int, x * 10)
    j = ceil(Int, y * 10)
    @write(p2_trace[:i], i, :discrete)
    @write(p2_trace[:j], j, :discrete)
    @write(q2_trace[:dx], x - (i-1)/10, :continuous)
    @write(q2_trace[:dy], y - (j-1)/10, :continuous)
end
and the inverse transform:
@transform f_inv (p2_trace, q2_trace) to (p1_trace, q1_trace) begin
    i = @read(p2_trace[:i], :discrete)
    j = @read(p2_trace[:j], :discrete)
    dx = @read(q2_trace[:dx], :continuous)
    dy = @read(q2_trace[:dy], :continuous)
    x = (i-1)/10 + dx
    y = (j-1)/10 + dy
    @write(p1_trace[:x], x, :continuous)
    @write(p1_trace[:y], y, :continuous)
end
which we associate as inverses:
pair_bijections!(f, f_inv)
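The arithmetic of this pair of transforms can be checked in isolation. The sketch below uses ordinary Julia functions in place of the transforms (the names `point_to_cell` and `cell_to_point` are hypothetical), mapping a point of the unit square to a cell index plus an offset within the cell, and back.

```julia
# Stand-in for f: a point (x, y) becomes a cell (i, j) plus an offset (dx, dy).
function point_to_cell(x::Float64, y::Float64)
    i = ceil(Int, x * 10)
    j = ceil(Int, y * 10)
    return (i, j, x - (i - 1) / 10, y - (j - 1) / 10)
end

# Stand-in for f_inv: a cell and an offset become a point again.
cell_to_point(i::Int, j::Int, dx::Float64, dy::Float64) =
    ((i - 1) / 10 + dx, (j - 1) / 10 + dy)

i, j, dx, dy = point_to_cell(0.37, 0.82)
x, y = cell_to_point(i, j, dx, dy)
println("cell = (", i, ", ", j, "), offset = (", dx, ", ", dy, ")")
println("recovered point = (", x, ", ", y, ")")
```

Within a fixed cell the continuous part of the map is a translation, (x, y) to (x - (i-1)/10, y - (j-1)/10), so its Jacobian is the identity matrix.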
Constructing a general trace translator
We now wrap the transform above into a general trace translator, by providing the three probabilistic programs p2
, q1
, q2
that it uses (a reference to p1
will be included in the input trace), and the arguments to these functions.
translator = GeneralTraceTranslator(
    p_new=p2,
    p_new_args=(),
    new_observations=choicemap(),
    q_forward=q1,
    q_forward_args=(),
    q_backward=q2,
    q_backward_args=(),
    f=f)
Then, we can apply the trace translator to a trace (t1
) of p1
and get a trace (t2
) of p2
and a log-weight:
t2, log_weight = translator(t1)
Symmetric Trace Translators
When the previous and new generative functions (e.g. p1
and p2
in the previous example) are the same, and their arguments are the same, and q_forward
and q_backward
(and their arguments) are also identical, we call the trace translator a Symmetric Trace Translator. Symmetric trace translators are important because they form the basis of Involutive MCMC. Instead of translating a trace of one generative function to the trace of another generative function, they translate a trace of a generative function to another trace of the same generative function.
Symmetric trace translators have the interesting property that the function f
is an involution, or a function that is its own inverse. To indicate that a trace transform is an involution, use is_involution!
.
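To make the involution property concrete, here is a small plain-Julia sketch that does not use Gen. It treats a pair (x, dx) as a stand-in for a model value and an auxiliary value, and defines a map that moves x by dx and negates dx. The function name `random_walk_involution` is hypothetical and chosen only for this illustration.

```julia
# Hypothetical involution on (model value, auxiliary value):
# move x by dx, and negate dx so that applying the map again undoes the move.
random_walk_involution(x, dx) = (x + dx, -dx)

x, dx = 1.25, 0.5
x_new, dx_new = random_walk_involution(x, dx)
x_back, dx_back = random_walk_involution(x_new, dx_new)

println("after one application:  ", (x_new, dx_new))
println("after two applications: ", (x_back, dx_back))
```

Applying the map twice returns the starting pair, which is the defining property of an involution.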
Because symmetric trace translators translate within the same generative function, their implementation uses update
to incrementally modify the trace from the previous to the new trace. This has two benefits when the previous and new traces have random choices that aren't modified between them: (i) the incremental modification may be more efficient than writing the new trace entirely from scratch, and (ii) the transform DSL program does not need to specify a value for addresses whose value is not changed from the previous trace.
Simple Extending Trace Translators
Simple extending trace translators extend an existing trace with new random choices sampled from a proposal distribution, as well as any new observations. The arguments of the trace may also be updated. This type of trace translation is the basic operation used in Particle Filtering. For example, we might have a model that sequentially samples new latent variables (:z, t)
and observations (:x, t)
for each timestep t
:
@gen function model(T::Int)
    for t in 1:T
        z = {(:z, t)} ~ normal(0, 1)
        x = {(:x, t)} ~ normal(z, 1)
    end
end
Each time we observe a new (:x, t)
, we might want to propose (:z, t)
so that it is close in value:
@gen function proposal(trace::Trace, x::Real)
    t = get_args(trace)[1] + 1
    {(:z, t)} ~ normal(x, 1)
end
Suppose we initially generated a trace up to timestep t=1
, e.g. by calling t1 = simulate(model, (1,))
. Now we observe (:x, 2)
to be 5.0
. By constructing a simple extending trace translator, we can simultaneously update the trace t1
with new arguments, introduce the new observation at (:x, 2)
, and propose a likely value for (:z, 2)
:
translator = SimpleExtendingTraceTranslator(
    p_new_args=(2,), p_argdiffs=(UnknownChange(),),
    new_observations=choicemap((:x, 2) => 5.0),
    q_forward=proposal, q_forward_args=(5.0,))
t2, log_weight = translator(t1)
Similar functionality can be achieved through a combination of propose
on the proposal and update
on the original trace, but using a trace translator provides a nice layer of abstraction.
Trace Transform DSL
The Trace Transform DSL is a differentiable programming language for writing deterministic transformations of traces. Programs written in this DSL are called transforms. Transforms read the value of random choices from input trace(s) and write values of random choices to output trace(s). These programs are not typically executed directly by users, but are instead wrapped into one of the several forms of trace translators listed above (DeterministicTraceTranslator, GeneralTraceTranslator
, and SymmetricTraceTranslator
).
A transform is identified with the @transform
macro, and uses one of the following four syntactic forms for the signature (the name of the transform, and the names of the input and output traces are all user-defined variables; the only keywords are @transform
, to
, begin
, and end
):
A transform from one trace to another, without extra parameters
@transform f t1 to t2 begin
    ...
end
A transform from one trace to another, with extra parameters
@transform f(x, y, ..) t1 to t2 begin
    ...
end
A transform from pairs of traces to pairs of traces, without extra parameters
@transform f (t1, t2) to (t3, t4) begin
    ...
end
A transform from pairs of traces to pairs of traces, with extra parameters
@transform f(x, y, ..) (t1, t2) to (t3, t4) begin
    ...
end
The extra parameters are optional, and can be used to pass arguments to a transform function that is invoked, from another transform function, using the @tcall
macro. For example:
@transform g(x) t1 to t2 begin
    ...
end
@transform f t1 to t2 begin
    x = ..
    @tcall(g(x))
end
The callee transform function (g
above) reads and writes to the same input and output traces as the caller transform function (f
above). Note that the input and output traces can be assigned different names in the caller and the callee.
The body of a transform reads the values of random choices at addresses in the input trace(s), performs computation using regular Julia code (provided this code can be differentiated with ForwardDiff.jl) and writes values of random choices at addresses in the output trace(s). In the body @read
expressions read a value from a particular address of an input trace:
val = @read(<source>, <type-label>)
where <source>
is an expression of the form <trace>[<addr>]
where <trace>
must be the name of an input trace in the transform's signature. The <type-label>
is either :continuous
or :discrete
, and indicates whether the random choice is discrete or continuous (in measure-theoretic terms, whether it uses the counting measure or the Lebesgue measure on a Euclidean space of some dimension). Similarly, @write
expressions write a value to a particular address in an output trace:
@write(<destination>, <value>, <type-label>)
Sometimes trace transforms need to directly copy the value from one address in an input trace to one address in an output trace. In these cases, it is recommended to use the explicit @copy
expression:
@copy(<source>, <destination>)
where <source>
and <destination>
are of the form <trace>[<addr>]
as above. Note you can also copy collections of multiple random choices under an address namespace in an input trace to an address namespace in an output trace. For example,
@copy(trace1[:foo], trace2[:bar])
will copy every random choice in trace1
with an address of the form :foo => <rest>
to trace2
at address :bar => <rest>
.
It is also possible to read the return value from an input trace using the following syntax, but this value must be discrete (in the local neighborhood of traces, the return value must be constant as a function of all continuous random choices in input traces):
val = @read(<trace>[], :discrete)
This feature is useful when the generative function precomputes a quantity as part of its return value, and we would like to reuse this value, instead of having to recompute it as part of the transform. The 'discrete' requirement is needed because the transform DSL does not currently backpropagate through the return value (this feature could be added in the future).
Tips for defining valid transforms:
If you find yourself copying the same continuous source address to multiple locations, it probably means your transform is not valid (the Jacobian matrix will have rows that are identical, and so the Jacobian determinant will be zero).
You can gain some confidence that your transform is valid by enabling dynamic checks (
check=true
) in the trace translator that uses it.