Trace Translators

While generative functions define probability distributions on traces, Trace Translators convert from one space of traces to another space of traces. Trace translators are building blocks of inference programs that utilize multiple model representations, like Involutive MCMC.

Trace translators are significantly more general than Bijectors. Trace translators can (i) convert between spaces of traces that include a mix of discrete and continuous random choices, as well as stochastic control flow, and (ii) convert between spaces for which there is no one-to-one correspondence (e.g. between models of different dimensionality, or between discrete and continuous models). Bijectors are limited to deterministic transformations between real-valued vectors of constant dimension.

Deterministic Trace Translators

Inference programs manipulate traces, but they also keep track of probabilities and probability densities associated with these traces. Suppose we have two generative functions p1 and p2. Given a trace t2 of p2, we can easily compute the probability (or probability density) that the trace would have been generated by p2 using get_score(t2). But suppose we want to construct the trace of p2 by first sampling a trace t1 of p1 and then applying a deterministic transformation to that trace to obtain t2. How can we compute the probability that a trace t2 would have been produced by this process? This probability is needed if, for example, p2 defines a probabilistic model and we want to use p1 as a proposal distribution within importance sampling. If we produce t2 via an arbitrary deterministic transformation of the random choices in t1, then computing the necessary probability is difficult.

If we restrict ourselves to deterministic transformations that are bijections (one-to-one correspondences) from the set of traces of p1 to the set of traces of p2, then the problem simplifies considerably. A bijection has two properties: (i) each trace of p1 gets mapped to a different trace of p2, and (ii) for every trace of p2 there is some trace of p1 that maps to it. Bijective transformations between traces are useful components of inference programs because the probability that a given trace t2 of p2 would have been produced by first sampling from p1 and then applying the transform can be computed simply as the probability that p1 would produce the (unique) trace t1 that gets mapped to the given trace by the transform. Conceptually, bijective trace transforms preserve probability. When trace transforms operate on traces with continuous random choices, computing probability densities of the transformed traces requires computing a Jacobian associated with the continuous part of the transformation.
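In symbols (the notation here is ours, for illustration): write $f$ for the bijection, $t_1 = f^{-1}(t_2)$ for the unique preimage of $t_2$, and $q(t_2)$ for the probability or density that sampling from p1 and then applying $f$ yields $t_2$. For purely discrete traces, $q(t_2) = p_1(t_1)$. When continuous choices are involved, the standard change-of-variables formula gives

$$q(t_2) = \frac{p_1(t_1)}{\left|\det J_f(t_1)\right|}, \qquad t_1 = f^{-1}(t_2),$$

where $J_f(t_1)$ is the Jacobian of the continuous outputs of $f$ with respect to the continuous choices in $t_1$.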

Gen provides a DSL for expressing bijections between spaces of traces, called the Trace Transform DSL. We introduce this DSL via an example. Below are two generative functions. The first samples polar coordinates and the second uses Cartesian coordinates.

@gen function p1()
    r ~ inv_gamma(1, 1)
    theta ~ uniform(-pi, pi)
end
@gen function p2()
    x ~ normal(0, 1)
    y ~ normal(0, 1)
end

Defining a trace transform with the Trace Transform DSL

The following trace transform DSL program defines a transformation (called f) that transforms traces of p1 into traces of p2:

@transform f (t1) to (t2) begin
    r = @read(t1[:r], :continuous)
    theta = @read(t1[:theta], :continuous)
    @write(t2[:x], r * cos(theta), :continuous)
    @write(t2[:y], r * sin(theta), :continuous)
end

This transform reads values of random choices in the input trace (t1) at specific addresses (indicated by the syntax t1[addr]) using @read and writes values to the output trace (t2) using @write. Each read and write operation is labeled with whether the random choice is discrete or continuous. The section Trace Transform DSL defines the DSL in more detail.

It is usually a good idea to also write the inverse of the bijection, because the inverse enables a dynamic check that the transform truly is a bijection. The inverse of the above transformation is:

@transform finv (t2) to (t1) begin
    x = @read(t2[:x], :continuous)
    y = @read(t2[:y], :continuous)
    r = sqrt(x^2 + y^2)
    @write(t1[:r], r, :continuous)
    @write(t1[:theta], atan(y, x), :continuous)
end
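Before relying on a dynamic check, it can help to test the underlying arithmetic directly. The sketch below restates the bodies of f and finv as ordinary Julia functions (the names polar_to_cartesian and cartesian_to_polar are ours, not part of Gen) and checks numerically that they undo each other.

```julia
# Plain-Julia restatement of the arithmetic in f and finv (no Gen needed).
# The function names are illustrative, not part of Gen.
polar_to_cartesian(r, theta) = (r * cos(theta), r * sin(theta))
cartesian_to_polar(x, y) = (sqrt(x^2 + y^2), atan(y, x))

# Round-trip sample points. atan(y, x) returns angles in (-pi, pi], so theta
# must lie in that range for the round trip to recover it.
for (r, theta) in [(1.0, 0.3), (2.5, -1.2), (0.7, 2.8)]
    x, y = polar_to_cartesian(r, theta)
    r2, theta2 = cartesian_to_polar(x, y)
    @assert isapprox(r2, r) && isapprox(theta2, theta)
end
```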

We can inform Gen that two transforms are inverses of one another using pair_bijections!:

pair_bijections!(f, finv)

Wrapping a trace transform in a trace translator

Note that the transform DSL code does not specify what the two generative functions are, or what the arguments to these generative functions are. This information will be required for computing probabilities and probability densities of traces. We provide this information by constructing a Trace Translator that wraps the transform together with this information:

translator = DeterministicTraceTranslator(p2, (), choicemap(), f)

We can then apply the translator to a trace of p1 using function call syntax. The translator returns a trace of p2 and a log-weight that we can use to compute the probability (density) of the resulting trace:

t2, log_weight = translator(t1)

Specifically, the log probability (density) that the trace t2 was produced by first sampling t1 = simulate(p1, ()) and then applying the trace translator, is:

get_score(t1) + log_weight

Let's unpack the previous few code blocks in more detail. First, note that we did not pass in the source generative function (p1) or its arguments (()) when we constructed the trace translator. This is because this information is attached to the input trace t1 itself. We did need to pass in the target generative function (p2) and its arguments (()), however, because this information is not included in t1.

In this case, because continuous random choices are involved, the probabilities are probability densities, and the trace translator used automatic differentiation of the code in the transform f to compute a change-of-variables Jacobian that is necessary to compute the correct probability density of the new trace t2.
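For this particular transform, the Jacobian that Gen obtains by automatic differentiation can also be worked out by hand. The sketch below (plain Julia plus the LinearAlgebra standard library; the function names are ours) computes it analytically and applies the change-of-variables rule. It illustrates the mathematics only; it is not Gen's implementation, and it makes no claim about exactly how this quantity enters log_weight.

```julia
using LinearAlgebra

# Analytic Jacobian of the map (r, theta) -> (x, y) = (r cos(theta), r sin(theta)).
# Rows are the outputs (x, y); columns are the inputs (r, theta).
function polar_jacobian(r, theta)
    dx_dr, dx_dtheta = cos(theta), -r * sin(theta)
    dy_dr, dy_dtheta = sin(theta), r * cos(theta)
    return [dx_dr dx_dtheta; dy_dr dy_dtheta]
end

# Change of variables: the log density of (x, y) is the log density of the
# (r, theta) that produced it, minus log|det J|. Here |det J| equals r.
pushforward_logdensity(logp1, r, theta) =
    logp1(r, theta) - log(abs(det(polar_jacobian(r, theta))))

@assert isapprox(det(polar_jacobian(2.0, 0.7)), 2.0)
```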

Observations

Typically, there are observations associated with one or both of the generative functions involved. We already have values for these observations in a choice map, so we do not want the trace translator to be responsible for populating the observed random choices. For example, suppose we want to condition p2 on an observed random choice z:

@gen function p2()
    x ~ normal(0, 1)
    y ~ normal(0, 1)
    z ~ normal(x + y, 0.1)
end
observations = choicemap()
observations[:z] = 2.3

When p2 has observations, these can be passed in as an additional argument to the trace translator constructor:

translator = DeterministicTraceTranslator(p2, (), observations, f)

Discrete random choices and stochastic control flow

Trace transforms and trace translators interoperate seamlessly with discrete random choices and stochastic control flow. Both the input and the output traces can contain a mix of discrete and continuous choices, and arbitrary stochastic control flow. Consider the following simple example, where we use two different discrete representations of probability distributions on eight values (which p1 encodes as the integers 0 through 7 in binary):

@gen function p1()
    bit1 ~ bernoulli(0.5)
    bit2 ~ bernoulli(0.5)
    bit3 ~ bernoulli(0.5)
end
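@gen function p2()
    n ~ categorical([0.1, 0.1, 0.1, 0.2, 0.2, 0.15, 0.1, 0.05])
end

We define the forward and inverse transforms. Because categorical returns values starting at 1, the transforms encode the three bits as n - 1:

@transform f (t1) to (t2) begin
    # categorical returns values in 1:8, so shift the 3-bit integer by one
    n = 1 + (
        @read(t1[:bit1], :discrete) * 1 +
        @read(t1[:bit2], :discrete) * 2 +
        @read(t1[:bit3], :discrete) * 4)
    @write(t2[:n], n, :discrete)
end
@transform finv (t2) to (t1) begin
    # pad=3 always yields three binary digits (e.g. when n - 1 == 0)
    bits = digits(@read(t2[:n], :discrete) - 1, base=2, pad=3)
    # bernoulli choices are Bool, so convert each 0/1 digit
    @write(t1[:bit1], bits[1] == 1, :discrete)
    @write(t1[:bit2], bits[2] == 1, :discrete)
    @write(t1[:bit3], bits[3] == 1, :discrete)
end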
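Whether a discrete encoding is a bijection can be checked exhaustively when the space is small. The sketch below restates the encoding in plain Julia, with n counted from 1 so that it matches the values categorical returns (bits_to_n and n_to_bits are hypothetical helper names, not Gen functions), and enumerates all eight bit patterns.

```julia
# Plain-Julia restatement of the 3-bit encoding (hypothetical helper names).
bits_to_n(b1::Bool, b2::Bool, b3::Bool) = 1 + b1 * 1 + b2 * 2 + b3 * 4

function n_to_bits(n::Int)
    d = digits(n - 1, base=2, pad=3)  # least significant digit first
    return (d[1] == 1, d[2] == 1, d[3] == 1)
end

all_bits = [(b1, b2, b3) for b1 in (false, true), b2 in (false, true), b3 in (false, true)]
images = [bits_to_n(b...) for b in all_bits]

# Bijection: the 8 patterns map to 8 distinct values covering 1:8,
# and n_to_bits undoes bits_to_n on every one of them.
@assert sort(vec(images)) == collect(1:8)
@assert all(n_to_bits(bits_to_n(b...)) == b for b in all_bits)
```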

Here is an example that includes discrete random choices, stochastic control flow, and continuous random choices.

@gen function p1()
    if ({:branch} ~ bernoulli(0.5))
        x ~ normal(0, 1)
    else
        other ~ categorical([0.3, 0.7])
    end
end
@gen function p2()
    k ~ uniform_discrete(1, 4)
    if k <= 2
        y ~ gamma(1, 1)
    end
end

Note that transformations between spaces of traces need not be intuitive (although they probably should be)! Try to convince yourself that the functions below are indeed a pair of bijections between the traces of these two generative functions.

@transform f (t1) to (t2) begin
    if @read(t1[:branch], :discrete)
        x = @read(t1[:x], :continuous)
        if x > 0
            @write(t2[:k], 2, :discrete)
        else
            @write(t2[:k], 1, :discrete)
        end
        @write(t2[:y], abs(x), :continuous)
    else
        other = @read(t1[:other], :discrete)
        @write(t2[:k], (other == 1) ? 3 : 4, :discrete)
    end
end
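@transform finv (t2) to (t1) begin
    k = @read(t2[:k], :discrete)
    if k <= 2
        y = @read(t2[:y], :continuous)
        # the inverse must reconstruct every choice of p1, including :branch
        @write(t1[:branch], true, :discrete)
        @write(t1[:x], (k == 1) ? -y : y, :continuous)
    else
        @write(t1[:branch], false, :discrete)
        @write(t1[:other], (k == 3) ? 1 : 2, :discrete)
    end
end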
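One way to convince yourself is to round-trip a representative trace from each branch. The sketch below models traces as Julia Dicts from addresses to values (a simplification; real Gen traces are not Dicts), and f_plain and finv_plain are hypothetical plain-Julia mirrors of the two transforms. Note that the inverse has to reconstruct every choice of p1, including :branch.

```julia
# Plain-Julia sketch (no Gen): traces modeled as Dicts from addresses to values.
function f_plain(t1::Dict{Symbol,Any})
    if t1[:branch]
        x = t1[:x]
        return Dict{Symbol,Any}(:k => (x > 0 ? 2 : 1), :y => abs(x))
    else
        return Dict{Symbol,Any}(:k => (t1[:other] == 1 ? 3 : 4))
    end
end

function finv_plain(t2::Dict{Symbol,Any})
    k = t2[:k]
    if k <= 2
        y = t2[:y]
        return Dict{Symbol,Any}(:branch => true, :x => (k == 1 ? -y : y))
    else
        return Dict{Symbol,Any}(:branch => false, :other => (k == 3 ? 1 : 2))
    end
end

# One example per region of the trace space: x > 0, x < 0, other = 1, other = 2.
examples = [
    Dict{Symbol,Any}(:branch => true, :x => 1.3),
    Dict{Symbol,Any}(:branch => true, :x => -0.4),
    Dict{Symbol,Any}(:branch => false, :other => 1),
    Dict{Symbol,Any}(:branch => false, :other => 2),
]
for t1 in examples
    @assert finv_plain(f_plain(t1)) == t1
end
```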

General Trace Translators

Note that for two arbitrary generative functions p1 and p2 there may not exist any one-to-one correspondence between their spaces of traces. For example, consider a generative function p1 that samples points within the unit square $[0, 1]^2$:

@gen function p1()
    x ~ uniform(0, 1)
    y ~ uniform(0, 1)
end

and another generative function p2 that samples one of 100 possible discrete values, each value representing one cell of the unit square:

@gen function p2()
    i ~ uniform_discrete(1, 10) # interval [(i-1)/10, i/10]
    j ~ uniform_discrete(1, 10) # interval [(j-1)/10, j/10]
end

There is no one-to-one correspondence between the spaces of traces of these two generative functions: The first is an uncountably infinite set, and the other is a finite set with 100 elements in it.

However, there is an intuitive notion of correspondence that we would like to be able to encode. Each discrete cell $(i, j)$ corresponds to a subset of the unit square $[(i - 1)/10, i/10] \times [(j-1)/10, j/10]$.

We can express this correspondence (and any correspondence between two arbitrary generative functions) by introducing two auxiliary generative functions q1 and q2. The first function q1 takes a trace of p1 as input, and the second function q2 takes a trace of p2 as input. Then, instead of a transformation between traces of p1 and traces of p2, our trace transform will transform between (i) the space of pairs of traces of p1 and q1 and (ii) the space of pairs of traces of p2 and q2. We construct q1 and q2 so that a one-to-one correspondence between these two spaces of pairs exists.

For our example above, we construct q2 to sample the offset $(dx, dy) \in [0, 0.1]^2$ of the point relative to the lower-left corner of its cell. We construct q1 to be empty: each trace of p1 already determines a trace of p2, simply by identifying which cell $(i, j)$ a given point in $[0, 1]^2$ falls in, so no extra random choices are needed.

@gen function q1(p1_trace)
end

@gen function q2(p2_trace)
    dx ~ uniform(0.0, 0.1)
    dy ~ uniform(0.0, 0.1)
end
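As a sanity check on this choice of q1 and q2, the probabilities on the two sides agree. The pair (p1, q1) has density 1 at every point of $[0, 1]^2$, because q1 makes no choices, and for any cell $(i, j)$ and offset $(dx, dy)$ the pair (p2, q2) gives

$$p_2(i, j)\, q_2(dx, dy) = \frac{1}{10} \cdot \frac{1}{10} \cdot \frac{1}{0.1} \cdot \frac{1}{0.1} = 1 = p_1(x, y).$$

Within a fixed cell, the map $(x, y) \mapsto (dx, dy)$ is a translation, so its Jacobian determinant is 1 and no correction factor is needed.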

Trace transforms between pairs of traces

To handle general trace translators that require auxiliary probability distributions, the trace transform DSL supports defining transformations between pairs of traces. For example, the following defines a trace transform that maps from pairs of traces of p1 and q1 to pairs of traces of p2 and q2:

@transform f (p1_trace, q1_trace) to (p2_trace, q2_trace) begin
    x = @read(p1_trace[:x], :continuous)
    y = @read(p1_trace[:y], :continuous)
    i = ceil(Int, x * 10)  # cell index in 1:10, as an Int for uniform_discrete
    j = ceil(Int, y * 10)
    @write(p2_trace[:i], i, :discrete)
    @write(p2_trace[:j], j, :discrete)
    @write(q2_trace[:dx], x - (i-1)/10, :continuous)
    @write(q2_trace[:dy], y - (j-1)/10, :continuous)
end

and the inverse transform:

@transform f_inv (p2_trace, q2_trace) to (p1_trace, q1_trace) begin
    i = @read(p2_trace[:i], :discrete)
    j = @read(p2_trace[:j], :discrete)
    dx = @read(q2_trace[:dx], :continuous)
    dy = @read(q2_trace[:dy], :continuous)
    x = (i-1)/10 + dx
    y = (j-1)/10 + dy
    @write(p1_trace[:x], x, :continuous)
    @write(p1_trace[:y], y, :continuous)
end
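The arithmetic of f and f_inv can be exercised on its own. In the sketch below, to_cell and from_cell are hypothetical plain-Julia names for the bodies of the two transforms; the check confirms that a point lands in the expected cell and that the offsets recover the original coordinates.

```julia
# Plain-Julia restatement of f and f_inv for the unit-square example
# (hypothetical names; no Gen required).
function to_cell(x, y)
    i, j = ceil(Int, x * 10), ceil(Int, y * 10)  # cell indices in 1:10
    return (i, j, x - (i - 1) / 10, y - (j - 1) / 10)
end
from_cell(i, j, dx, dy) = ((i - 1) / 10 + dx, (j - 1) / 10 + dy)

x, y = 0.437, 0.081
i, j, dx, dy = to_cell(x, y)
@assert (i, j) == (5, 1)
@assert all(isapprox.(from_cell(i, j, dx, dy), (x, y)))
```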

which we associate as inverses:

pair_bijections!(f, f_inv)

Constructing a general trace translator

We now wrap the transform above into a general trace translator by providing the three probabilistic programs p2, q1, and q2 that it uses (a reference to p1 will be included in the input trace), and the arguments to these functions.

translator = GeneralTraceTranslator(
    p_new=p2,
    p_new_args=(),
    new_observations=choicemap(),
    q_forward=q1,
    q_forward_args=(),
    q_backward=q2,
    q_backward_args=(),
    f=f)

Then, we can apply the trace translator to a trace (t1) of p1 and get a trace (t2) of p2 and a log-weight:

t2, log_weight = translator(t1)

Symmetric Trace Translators

When the previous and new generative functions (e.g. p1 and p2 in the previous example) are the same, their arguments are the same, and q_forward and q_backward (and their arguments) are also identical, we call the trace translator a Symmetric Trace Translator. Symmetric trace translators are important because they form the basis of Involutive MCMC. Instead of translating a trace of one generative function to the trace of another generative function, they translate a trace of a generative function to another trace of the same generative function.

Symmetric trace translators have the interesting property that the function f is an involution, or a function that is its own inverse. To indicate that a trace transform is an involution, use is_involution!.
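The defining property of an involution is easy to check outside of Gen. The sketch below models a trace as a Julia Dict (a simplification) and defines a hypothetical transform swap_ab that exchanges the values at two addresses while leaving every other address untouched; applying it twice returns the original trace. In the transform DSL, an analogous transform might be written with two @copy expressions and then marked with is_involution!.

```julia
# Hypothetical involution on a trace modeled as a Dict: swap the values
# stored at :a and :b, leaving every other address (here :c) unchanged.
swap_ab(t::Dict{Symbol,Any}) = merge(t, Dict{Symbol,Any}(:a => t[:b], :b => t[:a]))

t = Dict{Symbol,Any}(:a => 1.5, :b => -0.2, :c => 3)
@assert swap_ab(t) != t
@assert swap_ab(swap_ab(t)) == t   # applying it twice recovers the original
```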

Because symmetric trace translators translate within the same generative function, their implementation uses update to incrementally modify the trace from the previous to the new trace. This has two benefits when the previous and new traces have random choices that aren't modified between them: (i) the incremental modification may be more efficient than writing the new trace entirely from scratch, and (ii) the transform DSL program does not need to specify a value for addresses whose value is not changed from the previous trace.

Simple Extending Trace Translators

Simple extending trace translators extend an existing trace with new random choices sampled from a proposal distribution, as well as any new observations. The arguments of the trace may also be updated. This type of trace translation is the basic operation used in Particle Filtering. For example, we might have a model that sequentially samples new latent variables (:z, t) and observations (:x, t) for each timestep t:

@gen function model(T::Int)
    for t in 1:T
        z = {(:z, t)} ~ normal(0, 1)
        x = {(:x, t)} ~ normal(z, 1)
    end
end

Each time we observe a new (:x, t), we might want to propose (:z, t) so that it is close in value to the observation:

@gen function proposal(trace::Trace, x::Real)
    t = get_args(trace)[1] + 1
    {(:z, t)} ~ normal(x, 1)
end

Suppose we initially generated a trace up to timestep t=1, e.g. by calling t1 = simulate(model, (1,)). Now we observe (:x, 2) to be 5.0. By constructing a simple extending trace translator, we can simultaneously update the trace t1 with new arguments, introduce the new observation at (:x, 2), and propose a likely value for (:z, 2):

translator = SimpleExtendingTraceTranslator(
    p_new_args=(2,), p_argdiffs=(UnknownChange(),),
    new_observations=choicemap((:x, 2) => 5.0),
    q_forward=proposal, q_forward_args=(5.0,))
t2, log_weight = translator(t1)
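For intuition about what a weight for this particle-filtering step has to account for, the sketch below computes the textbook sequential importance sampling weight for extending a trace by one timestep of this model: the new model terms for (:z, t) and (:x, t) minus the proposal's log-density. It is plain Julia with a hand-written normal log-density (normal_logpdf and extension_logweight are our names), stated for illustration; we are not asserting that it is exactly the quantity Gen returns as log_weight.

```julia
# Log-density of normal(mu, sigma) at v, written out in plain Julia.
normal_logpdf(v, mu, sigma) = -0.5 * log(2pi) - log(sigma) - 0.5 * ((v - mu) / sigma)^2

# Textbook incremental importance weight for extending a particle by one
# timestep: new model terms for (:z, t) and (:x, t) minus the proposal term.
function extension_logweight(z_new, x_obs)
    log_prior = normal_logpdf(z_new, 0.0, 1.0)       # (:z, t) ~ normal(0, 1)
    log_lik = normal_logpdf(x_obs, z_new, 1.0)       # (:x, t) ~ normal(z, 1)
    log_proposal = normal_logpdf(z_new, x_obs, 1.0)  # proposal: normal(x, 1)
    return log_prior + log_lik - log_proposal
end

# With these particular distributions the likelihood and proposal terms
# cancel, so the weight reduces to the prior term for the proposed z.
@assert isapprox(extension_logweight(4.2, 5.0), normal_logpdf(4.2, 0.0, 1.0))
```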

Similar functionality can be achieved through a combination of propose on the proposal and update on the original trace, but using a trace translator provides a nice layer of abstraction.

Trace Transform DSL

The Trace Transform DSL is a differentiable programming language for writing deterministic transformations of traces. Programs written in this DSL are called transforms. Transforms read the value of random choices from input trace(s) and write values of random choices to output trace(s). These programs are not typically executed directly by users, but are instead wrapped into one of the several forms of trace translators listed above (e.g. DeterministicTraceTranslator, GeneralTraceTranslator, and SymmetricTraceTranslator).

A transform is identified with the @transform macro, and uses one of the following four syntactic forms for the signature (the name of the transform, and the names of the input and output traces are all user-defined variables; the only keywords are @transform, to, begin, and end):

A transform from one trace to another, without extra parameters

@transform f t1 to t2 begin
    ...
end

A transform from one trace to another, with extra parameters

@transform f(x, y, ..) t1 to t2 begin
    ...
end

A transform from pairs of traces to pairs of traces, without extra parameters

@transform f (t1, t2) to (t3, t4) begin
    ...
end

A transform from pairs of traces to pairs of traces, with extra parameters

@transform f(x, y, ..) (t1, t2) to (t3, t4) begin
    ...
end

The extra parameters are optional. They can be used to pass arguments to a transform function that is invoked from another transform function using the @tcall macro. For example:

@transform g(x) t1 to t2 begin
    ...
end
@transform f t1 to t2 begin
    x = ..
    @tcall(g(x))
end

The callee transform function (g above) reads and writes to the same input and output traces as the caller transform function (f above). Note that the input and output traces can be assigned different names in the caller and the callee.

The body of a transform reads the values of random choices at addresses in the input trace(s), performs computation using regular Julia code (provided this code can be differentiated with ForwardDiff.jl), and writes values of random choices at addresses in the output trace(s). In the body, @read expressions read a value from a particular address of an input trace:

val = @read(<source>, <type-label>)

where <source> is an expression of the form <trace>[<addr>] where <trace> must be the name of an input trace in the transform's signature. The <type-label> is either :continuous or :discrete, and indicates whether the random choice is discrete or continuous (in measure-theoretic terms, whether it uses the counting measure, or a Lebesgue measure on a Euclidean space of some dimension). Similarly, @write expressions write a value to a particular address in an output trace:

@write(<destination>, <value>, <type-label>)

Sometimes trace transforms need to directly copy the value from one address in an input trace to one address in an output trace. In these cases, it is recommended to use the explicit @copy expression:

@copy(<source>, <destination>)

where <source> and <destination> are of the form <trace>[<addr>] as above. Note you can also copy collections of multiple random choices under an address namespace in an input trace to an address namespace in an output trace. For example,

@copy(trace1[:foo], trace2[:bar])

will copy every random choice in trace1 with an address of the form :foo => <rest> to trace2 at address :bar => <rest>.
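The namespace semantics of @copy can be pictured with ordinary data structures. The sketch below is a hypothetical model, not Gen's implementation: choices are stored in a Dict keyed by hierarchical addresses (Julia Pairs such as :foo => :a), and copy_namespace! copies every entry under one prefix to the same relative address under another.

```julia
# Hypothetical model of @copy on namespaces: choices stored in a Dict keyed
# by hierarchical addresses (Pairs), copied from prefix `src` to prefix `dst`.
function copy_namespace!(dest::Dict, source::Dict, src::Symbol, dst::Symbol)
    for (addr, val) in source
        if addr isa Pair && addr.first == src
            dest[dst => addr.second] = val   # :foo => rest becomes :bar => rest
        end
    end
    return dest
end

trace1 = Dict{Any,Any}((:foo => :a) => 1.0, (:foo => (:b => 2)) => true, :other => 7)
trace2 = copy_namespace!(Dict{Any,Any}(), trace1, :foo, :bar)
@assert trace2 == Dict{Any,Any}((:bar => :a) => 1.0, (:bar => (:b => 2)) => true)
```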

It is also possible to read the return value from an input trace using the following syntax, but this value must be discrete (in the local neighborhood of traces, the return value must be constant as a function of all continuous random choices in input traces):

val = @read(<trace>[], :discrete)

This feature is useful when the generative function precomputes a quantity as part of its return value, and we would like to reuse this value, instead of having to recompute it as part of the transform. The :discrete requirement is needed because the transform DSL does not currently backpropagate through the return value (this feature could be added in the future).

Tips for defining valid transforms:

  • If you find yourself copying the same continuous source address to multiple locations, it probably means your transform is not valid (the Jacobian matrix will have rows that are identical, and so the Jacobian determinant will be zero).

  • You can gain some confidence that your transform is valid by enabling dynamic checks (check=true) in the trace translator that uses it.
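The first tip can be seen on a two-dimensional example. The sketch below (plain Julia with the LinearAlgebra standard library; both transforms are hypothetical) writes down the Jacobian of a transform that copies one continuous input to two outputs, (a, b) -> (a, a), next to an invertible alternative, (a, b) -> (a, a + b).

```julia
using LinearAlgebra

# Invalid transform: copies continuous input a to two outputs, (a, b) -> (a, a),
# discarding b. Rows are outputs, columns are inputs; the rows are identical.
J_invalid = [1.0 0.0;
             1.0 0.0]

# Valid alternative: (a, b) -> (a, a + b), which is invertible.
J_valid = [1.0 0.0;
           1.0 1.0]

@assert iszero(det(J_invalid))        # volume collapses: not a bijection
@assert isapprox(det(J_valid), 1.0)
```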