Generative Function Combinators
Generative function combinators are Julia functions that take one or more generative functions as input and return a new generative function. Generative function combinators are used to express patterns of repeated computation that appear frequently in generative models. Some generative function combinators are similar to higher order functions from functional programming languages. However, generative function combinators are not 'higher order generative functions', because they are not themselves generative functions (they are regular Julia functions).
Map combinator
Gen.Map (type)
gen_fn = Map(kernel::GenerativeFunction)
Return a new generative function that applies the kernel independently for a vector of inputs.
The returned generative function has one argument with type Vector{X} for each argument of the input generative function with type X. The length of each argument, which must be the same for each argument, determines the number of times the input generative function is called (N). Each call to the input function is made under address namespace i for i=1..N. The return value of the returned function has type FunctionalCollections.PersistentVector{Y}, where Y is the type of the return value of the input function. The map combinator is similar to the 'map' higher-order function in functional programming, except that the map combinator returns a new generative function that must then be separately applied.
If kernel has optional trailing arguments, the corresponding Vector arguments can be omitted from calls to Map(kernel).
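To make the argument-alignment rule concrete, the following plain-Julia sketch models only the deterministic bookkeeping of Map. The helper `map_sketch` and the stand-in kernel `mean_of` are hypothetical and are not part of Gen: the sketch checks that all argument vectors share the same length N and calls the kernel once per index with the i-th element of each vector, ignoring traces, scores, and randomness.

```julia
# Hypothetical helper (not part of Gen): models only how Map(kernel) lines up
# its vector arguments. Randomness, traces and scores are ignored.
function map_sketch(kernel, arg_vectors::AbstractVector...)
    isempty(arg_vectors) && throw(ArgumentError("need at least one argument vector"))
    n = length(first(arg_vectors))
    # Every argument vector must have the same length N.
    all(v -> length(v) == n, arg_vectors) ||
        throw(ArgumentError("all argument vectors must have the same length"))
    # Call i receives the i-th element of every vector; in Gen its random
    # choices would be recorded under address namespace i.
    return [kernel(map(v -> v[i], arg_vectors)...) for i in 1:n]
end

# Deterministic stand-in for a kernel such as foo below: return the mean x1 + x2.
mean_of(x1::Float64, x2::Float64) = x1 + x2

means = map_sketch(mean_of, [0.0, 0.5], [0.5, 1.0])
```

With the inputs used in the example below, the kernel is called with (0.0, 0.5) and then (0.5, 1.0), giving means of 0.5 and 1.5.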
In the schematic below, the kernel is denoted $\mathcal{G}_{\mathrm{k}}$.
For example, consider the following generative function, which makes one random choice at address :z:
```julia
using Gen

@gen function foo(x1::Float64, x2::Float64)
    y = @trace(normal(x1 + x2, 1.0), :z)
    return y
end
```
We apply the map combinator to produce a new generative function bar:
```julia
bar = Map(foo)
```
We can then obtain a trace of bar:
```julia
(trace, _) = generate(bar, ([0.0, 0.5], [0.5, 1.0]))
```
This causes foo to be invoked twice, once with arguments (0.0, 0.5) in address namespace 1 and once with arguments (0.5, 1.0) in address namespace 2. If the resulting trace has random choices:
```
│
├── 1
│   │
│   └── :z : -0.5757913836706721
│
└── 2
    │
    └── :z : 0.7357177113395333
```
then the return value is:
```
FunctionalCollections.PersistentVector{Any}[-0.575791, 0.735718]
```
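The addresses in the tree are hierarchical: the choice :z made by call i lives at `i => :z`. The following sketch assumes Gen is installed; it redefines foo and bar from the example, constrains both choices with `choicemap`, and reads them back. Because every choice is constrained, the weight returned by `generate` should equal the sum of the two normal log densities (means 0.5 and 1.5, standard deviation 1.0).

```julia
using Gen

@gen function foo(x1::Float64, x2::Float64)
    y = @trace(normal(x1 + x2, 1.0), :z)
    return y
end

bar = Map(foo)

# Constrain the choice :z of each kernel call through its namespace i.
constraints = choicemap((1 => :z, -0.5), (2 => :z, 0.75))

(trace, weight) = generate(bar, ([0.0, 0.5], [0.5, 1.0]), constraints)

# The return value collects the kernel return values in call order.
retval = collect(get_retval(trace))

# Individual choices are read back by their hierarchical address.
z1 = trace[1 => :z]
z2 = trace[2 => :z]
```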
Unfold combinator
Gen.Unfold (type)
gen_fn = Unfold(kernel::GenerativeFunction)
Return a new generative function that applies the kernel in sequence, passing the return value of one application as an input to the next.
The kernel accepts the following arguments:
- The first argument is the Int index indicating the position in the sequence (starting from 1).
- The second argument is the state.
- The kernel may have additional arguments after the state.
The return type of the kernel must be the same type as the state.
The returned generative function accepts the following arguments:
- The number of times (N) to apply the kernel.
- The initial state.
- The rest of the arguments (not including the state) that will be passed to each kernel application.
The return type of the returned generative function is FunctionalCollections.PersistentVector{T}, where T is the return type of the kernel.
If kernel has optional trailing arguments, the corresponding arguments can be omitted from calls to Unfold(kernel).
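The way Unfold threads state from one application to the next can be modeled in plain Julia. The helper `unfold_sketch` and the kernel `accumulate_step` are hypothetical illustrations, not Gen's implementation; they ignore traces and randomness and show only the argument order (index, previous state, trailing arguments) and the fact that each return value becomes the next state.

```julia
# Hypothetical helper (not part of Gen): models only how Unfold(kernel)
# threads state between applications. Randomness and traces are ignored.
function unfold_sketch(kernel, n::Int, init_state, rest...)
    states = Vector{typeof(init_state)}(undef, n)
    state = init_state
    for t in 1:n
        # Application t gets its 1-based index, the previous state and the
        # same trailing arguments; its return value becomes the next state.
        state = kernel(t, state, rest...)
        states[t] = state
    end
    # One state per application is returned; the initial state is not included.
    return states
end

# Deterministic stand-in kernel with state type Int.
accumulate_step(t::Int, prev::Int, offset::Int) = prev + t + offset

states = unfold_sketch(accumulate_step, 3, 0, 10)
```

Starting from state 0 with offset 10, the three applications produce 0 + 1 + 10 = 11, then 11 + 2 + 10 = 23, then 23 + 3 + 10 = 36.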
In the schematic below, the kernel is denoted $\mathcal{G}_{\mathrm{k}}$. The initial state is denoted $y_0$, the number of applications is $n$, and the remaining arguments to the kernel, not including the state, are $z$.
For example, consider the following kernel, with state type Bool, which makes one random choice at address :y:
```julia
@gen function foo(t::Int, y_prev::Bool, z1::Float64, z2::Float64)
    y = @trace(bernoulli(y_prev ? z1 : z2), :y)
    return y
end
```
We apply the unfold combinator to produce a new generative function bar:
```julia
bar = Unfold(foo)
```
We can then obtain a trace of bar:
```julia
(trace, _) = generate(bar, (5, false, 0.05, 0.95))
```
This causes foo to be invoked five times. If the resulting trace contains the following random choices:
```
│
├── 1
│   │
│   └── :y : true
│
├── 2
│   │
│   └── :y : false
│
├── 3
│   │
│   └── :y : true
│
├── 4
│   │
│   └── :y : false
│
└── 5
    │
    └── :y : true
```
then the return value is:
```
FunctionalCollections.PersistentVector{Any}[true, false, true, false, true]
```
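The alternating pattern above is the likely outcome because the initial state is false, so the first step uses probability z2 = 0.95, and each true state switches the next step to z1 = 0.05. The sketch below assumes Gen is installed; it redefines foo and bar, constrains every step at address `t => :y` to the alternating sequence, and checks the weight. Under this model each of the five constrained outcomes has probability 0.95, so the weight should be 5 log(0.95).

```julia
using Gen

@gen function foo(t::Int, y_prev::Bool, z1::Float64, z2::Float64)
    y = @trace(bernoulli(y_prev ? z1 : z2), :y)
    return y
end

bar = Unfold(foo)

# Constrain step t of the chain at its namespaced address t => :y.
observed = [true, false, true, false, true]
constraints = choicemap([(t => :y, y) for (t, y) in enumerate(observed)]...)

(trace, weight) = generate(bar, (5, false, 0.05, 0.95), constraints)

# One returned state per application, in order.
states = collect(get_retval(trace))
```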
Recurse combinator
Gen.Recurse (type)
Recurse(production_kernel, aggregation_kernel, max_branch, ::Type{U}, ::Type{V}, ::Type{W})
Constructs a recursive generative function from a production kernel and an aggregation kernel.
Switch combinator
Gen.Switch (type)
gen_fn = Switch(gen_fns::GenerativeFunction...)
Returns a new generative function that accepts an argument tuple of type Tuple{Int, ...}, where the first element indicates which branch to call.
gen_fn = Switch(d::Dict{T, Int}, gen_fns::GenerativeFunction...) where T
Returns a new generative function that accepts an argument tuple of type Tuple{Int, ...} or an argument tuple of type Tuple{T, ...}, where the first element either indicates which branch to call or is a key into d that maps to the selected branch. This form is meant for convenience: it allows the programmer to use d like if-else or case statements.
Switch is designed to allow for the expression of patterns of if-else control flow. gen_fns must satisfy a few requirements:
- Each gen_fn in gen_fns must accept the same argument types.
- Each gen_fn in gen_fns must return the same return type.
Beyond these requirements, the branches can come from different modeling languages, have different trace types, and so on.
Consider the following generative functions:
```julia
@gen function bang((grad)(x::Float64), (grad)(y::Float64))
    std::Float64 = 3.0
    z = @trace(normal(x + y, std), :z)
    return z
end

@gen function fuzz((grad)(x::Float64), (grad)(y::Float64))
    std::Float64 = 3.0
    z = @trace(normal(x + 2 * y, std), :z)
    return z
end

sc = Switch(bang, fuzz)
```
This creates a new generative function sc. We can then obtain a trace of sc (simulate returns the trace itself):
```julia
trace = simulate(sc, (2, 5.0, 3.0))
```
The resulting trace contains the subtrace from the branch with index 2, in this case a call to fuzz:
```
│
└── :z : 13.552870875213735
```
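The Dict form of the constructor lets the first argument be a label instead of an integer. The sketch below assumes Gen is installed and that `Switch(d, gen_fns...)` resolves a label through d as described above; the labels :bang and :fuzz are illustrative, and bang and fuzz are simplified copies without the (grad) annotations. Both calls should select the fuzz branch, so with :z constrained to 11.0 (the fuzz mean for x = 5.0, y = 3.0) they should report the same weight.

```julia
using Gen

@gen function bang(x::Float64, y::Float64)
    z = @trace(normal(x + y, 3.0), :z)
    return z
end

@gen function fuzz(x::Float64, y::Float64)
    z = @trace(normal(x + 2 * y, 3.0), :z)
    return z
end

# Illustrative labels: the Dict maps each label to a branch index.
sc = Switch(Dict(:bang => 1, :fuzz => 2), bang, fuzz)

# The branch's choices appear at the top level of the Switch trace.
constraints = choicemap((:z, 11.0))

# Select the branch by label, then by integer index.
(trace_by_label, w_label) = generate(sc, (:fuzz, 5.0, 3.0), constraints)
(trace_by_index, w_index) = generate(sc, (2, 5.0, 3.0), constraints)
```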
Design and Implementation
Internally, the combinators use custom trace types such as Gen.VectorTrace, and are implemented using the following methods:
Gen.VectorTrace (type)
VectorTrace <: Trace
U is the type of the subtrace, and R is the return value type for the kernel.
Gen.process_all_new! (function)
Process all new applications.
Gen.update_recurse_merge (function)
update_recurse_merge(prev_choices::ChoiceMap, choices::ChoiceMap)
Returns the choices that are in the constraints (choices), merged with all choices in the previous trace (prev_choices) that do not have the same address as some choice in the constraints.
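The merge rule can be sketched in plain Julia, assuming nested Dicts as a stand-in for hierarchical choice maps. `recursive_merge` is a hypothetical helper, not Gen's implementation: constrained values win, and every other previous choice is kept, recursing into namespaces that appear in both maps.

```julia
# Hypothetical helper (not part of Gen): nested Dicts stand in for
# hierarchical ChoiceMaps. Constrained values win; everything else from the
# previous choices is kept, recursing into shared namespaces.
function recursive_merge(prev::AbstractDict, constraints::AbstractDict)
    merged = Dict{Any,Any}(prev)
    for (addr, value) in constraints
        if value isa AbstractDict && get(prev, addr, nothing) isa AbstractDict
            # Both sides have a sub-map at this address: merge inside it.
            merged[addr] = recursive_merge(prev[addr], value)
        else
            # A constrained leaf (or a new namespace) replaces the old entry.
            merged[addr] = value
        end
    end
    return merged
end

prev_choices = Dict(1 => Dict(:z => 0.1, :w => 0.2), 2 => Dict(:z => 0.3))
constraints  = Dict(1 => Dict(:z => 9.0))
merged = recursive_merge(prev_choices, constraints)
```

Here the constraint on `1 => :z` replaces the previous value, while `1 => :w` and `2 => :z` are carried over unchanged.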
Gen.update_discard (function)
update_discard(prev_choices::ChoiceMap, choices::ChoiceMap, new_choices::ChoiceMap)
Returns the choices from the previous trace that either:
- have an address that does not appear in the new trace, or
- have an address that does appear in the constraints.
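Reading the two conditions as alternatives (a previous choice is discarded if either holds), the rule can be sketched in plain Julia. `discard_sketch` is a hypothetical helper that mirrors the argument order of update_discard but uses flat Dicts keyed by full address instead of ChoiceMaps.

```julia
# Hypothetical helper (not part of Gen): flat Dicts keyed by full address
# stand in for ChoiceMaps. A previous choice is discarded if its address is
# absent from the new trace or if the constraints overwrite it.
function discard_sketch(prev_choices::AbstractDict, constraints::AbstractDict,
                        new_choices::AbstractDict)
    return Dict(addr => value for (addr, value) in prev_choices
                if !haskey(new_choices, addr) || haskey(constraints, addr))
end

prev_choices = Dict(:a => 1, :b => 2, :c => 3)
constraints  = Dict(:b => 20)
new_choices  = Dict(:b => 20, :c => 3, :d => 4)
discard = discard_sketch(prev_choices, constraints, new_choices)
```

In this example :a is discarded because it no longer appears in the new trace, :b is discarded (with its old value 2) because the constraints overwrite it, and :c is kept.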