Generative Combinators
Generative function combinators are Julia functions that take one or more generative functions as input and return a new generative function. Generative function combinators are used to express patterns of repeated computation that appear frequently in generative models. Some generative function combinators are similar to higher order functions from functional programming languages.
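As a point of comparison, the sketch below uses Julia's built-in higher-order function `map` on an ordinary deterministic function. It does not use Gen at all; the helper name `mean_of` and the input values are assumptions chosen for illustration. A combinator plays a similar role for generative functions, with the extra job of keeping each application's random choices separate.

```julia
# Plain Julia, no Gen: `map` applies a function element-wise across several vectors.
# `mean_of` is a hypothetical helper, not part of any library.
mean_of(x, y, z) = x^2 + y^2 + z^2

xs = [0.0, 0.5]
ys = [0.5, 1.0]
zs = [1.0, -1.0]

# One call per index i, using (xs[i], ys[i], zs[i]) as arguments.
means = map(mean_of, xs, ys, zs)
println(means)  # prints [1.25, 2.25]
```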
Map combinator
In the schematic below, the kernel is denoted $\mathcal{G}_{\mathrm{k}}$.
For example, consider the following generative function, which makes one random choice at address :r:
using Gen
@gen function foo(x, y, z)
    r ~ normal(x^2 + y^2 + z^2, 1.0)
    return r
end
DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any, Any, Any], false, Union{Nothing, Some{Any}}[nothing, nothing, nothing], Main.var"##foo#276", Bool[0, 0, 0], false)
We apply the Map combinator to produce a new generative function bar:
bar = Map(foo)
Map{Any, Gen.DynamicDSLTrace}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any, Any, Any], false, Union{Nothing, Some{Any}}[nothing, nothing, nothing], Main.var"##foo#276", Bool[0, 0, 0], false))
We can then obtain a trace of bar:
trace, _ = generate(bar, ([0.0, 0.5], [0.5, 1.0], [1.0, -1.0]))
trace
Gen.VectorTrace{Gen.MapType, Any, Gen.DynamicDSLTrace}(Map{Any, Gen.DynamicDSLTrace}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any, Any, Any], false, Union{Nothing, Some{Any}}[nothing, nothing, nothing], Main.var"##foo#276", Bool[0, 0, 0], false)), Gen.DynamicDSLTrace[Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any, Any, Any], false, Union{Nothing, Some{Any}}[nothing, nothing, nothing], Main.var"##foo#276", Bool[0, 0, 0], false), Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(:r => Gen.ChoiceOrCallRecord{Float64}(0.5102699241218409, -1.1925388257840261, NaN, true)), Dict{Any, Trie{Any, Gen.ChoiceOrCallRecord}}()), false, -1.1925388257840261, 0.0, (0.0, 0.5, 1.0), 0.5102699241218409), Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any, Any, Any], false, Union{Nothing, Some{Any}}[nothing, nothing, nothing], Main.var"##foo#276", Bool[0, 0, 0], false), Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(:r => Gen.ChoiceOrCallRecord{Float64}(3.795322514504554, -2.1129493701220117, NaN, true)), Dict{Any, Trie{Any, Gen.ChoiceOrCallRecord}}()), false, -2.1129493701220117, 0.0, (0.5, 1.0, -1.0), 3.795322514504554)], Any[0.5102699241218409, 3.795322514504554], ([0.0, 0.5], [0.5, 1.0], [1.0, -1.0]), 2, 2, -3.3054881959060376, 0.0)
This causes foo to be invoked twice, once with arguments (0.0, 0.5, 1.0) in address namespace 1 and once with arguments (0.5, 1.0, -1.0) in address namespace 2.
get_choices(trace)
│
├── 1
│   │
│   └── :r : 0.5102699241218409
│
└── 2
    │
    └── :r : 3.795322514504554
For a trace with the random choices shown above, the return value is:
get_retval(trace)
Persistent{Any}[0.5102699241218409, 3.795322514504554]
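Because each application lives in its own address namespace, individual choices can be addressed with a pair such as `2 => :r`. The following is a minimal sketch, assuming the same `foo` and `bar` as above, that constrains the choice made by the second application; the constrained value 3.0 is an arbitrary illustration.

```julia
using Gen

@gen function foo(x, y, z)
    r ~ normal(x^2 + y^2 + z^2, 1.0)
    return r
end

bar = Map(foo)
args = ([0.0, 0.5], [0.5, 1.0], [1.0, -1.0])

# Constrain the choice at address :r inside namespace 2 (the second call to foo).
# The value 3.0 is arbitrary and only serves the example.
constraints = choicemap((2 => :r, 3.0))
trace, weight = generate(bar, args, constraints)

# The constrained choice is read back through the same hierarchical address.
r2 = get_choices(trace)[2 => :r]

# The first application was left unconstrained, so its value is sampled.
r1 = get_choices(trace)[1 => :r]
```

Here `weight` is the log-weight returned by `generate` for the supplied constraints; the unconstrained choice `r1` will differ from run to run.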
Unfold combinator
In the schematic below, the kernel is denoted $\mathcal{G}_{\mathrm{k}}$. The initial state is denoted $y_0$, the number of applications is $n$, and the remaining arguments to the kernel, not including the state, are $z$.
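The data flow described here can be sketched in plain Julia with a deterministic kernel. This is not Gen's implementation; `unfold_sketch` and `flip` are hypothetical names introduced only to show how the state is threaded from one application to the next while the step index and the remaining arguments are passed along unchanged.

```julia
# Plain Julia, no Gen: a deterministic analog of the unfold pattern.
# `unfold_sketch` is a hypothetical helper, not part of Gen.
function unfold_sketch(kernel, n::Int, y0, z...)
    states = Vector{typeof(y0)}(undef, n)
    y = y0
    for t in 1:n
        # Each step receives the step index, the previous state, and the shared arguments z.
        y = kernel(t, y, z...)
        states[t] = y
    end
    return states
end

# A toy kernel that ignores t and z and simply negates the previous state.
flip(t, y_prev, z1, z2) = !y_prev

states = unfold_sketch(flip, 5, false, 0.05, 0.95)
println(states)  # prints Bool[1, 0, 1, 0, 1]
```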
For example, consider the following kernel, with state type Bool, which makes one random choice at address :y:
using Gen
@gen function foo(t::Int, y_prev::Bool, z1::Float64, z2::Float64)
    y = @trace(bernoulli(y_prev ? z1 : z2), :y)
    return y
end
DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Int64, Bool, Float64, Float64], false, Union{Nothing, Some{Any}}[nothing, nothing, nothing, nothing], Main.var"##foo#277", Bool[0, 0, 0, 0], false)
We apply the Unfold combinator to produce a new generative function bar:
bar = Unfold(foo)
Unfold{Any, Gen.DynamicDSLTrace}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Int64, Bool, Float64, Float64], false, Union{Nothing, Some{Any}}[nothing, nothing, nothing, nothing], Main.var"##foo#277", Bool[0, 0, 0, 0], false))
We can then obtain a trace of bar:
trace, _ = generate(bar, (5, false, 0.05, 0.95))
trace
Gen.VectorTrace{Gen.UnfoldType, Any, Gen.DynamicDSLTrace}(Unfold{Any, Gen.DynamicDSLTrace}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Int64, Bool, Float64, Float64], false, Union{Nothing, Some{Any}}[nothing, nothing, nothing, nothing], Main.var"##foo#277", Bool[0, 0, 0, 0], false)), Gen.DynamicDSLTrace[Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Int64, Bool, Float64, Float64], false, Union{Nothing, Some{Any}}[nothing, nothing, nothing, nothing], Main.var"##foo#277", Bool[0, 0, 0, 0], false), Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(:y => Gen.ChoiceOrCallRecord{Bool}(true, -0.05129329438755058, NaN, true)), Dict{Any, Trie{Any, Gen.ChoiceOrCallRecord}}()), false, -0.05129329438755058, 0.0, (1, false, 0.05, 0.95), true), Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Int64, Bool, Float64, Float64], false, Union{Nothing, Some{Any}}[nothing, nothing, nothing, nothing], Main.var"##foo#277", Bool[0, 0, 0, 0], false), Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(:y => Gen.ChoiceOrCallRecord{Bool}(false, -0.05129329438755058, NaN, true)), Dict{Any, Trie{Any, Gen.ChoiceOrCallRecord}}()), false, -0.05129329438755058, 0.0, (2, true, 0.05, 0.95), false), Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Int64, Bool, Float64, Float64], false, Union{Nothing, Some{Any}}[nothing, nothing, nothing, nothing], Main.var"##foo#277", Bool[0, 0, 0, 0], false), Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(:y => Gen.ChoiceOrCallRecord{Bool}(true, -0.05129329438755058, NaN, true)), Dict{Any, Trie{Any, Gen.ChoiceOrCallRecord}}()), false, -0.05129329438755058, 0.0, (3, false, 0.05, 0.95), true), Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), 
Dict{Symbol, Any}(), Type[Int64, Bool, Float64, Float64], false, Union{Nothing, Some{Any}}[nothing, nothing, nothing, nothing], Main.var"##foo#277", Bool[0, 0, 0, 0], false), Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(:y => Gen.ChoiceOrCallRecord{Bool}(false, -0.05129329438755058, NaN, true)), Dict{Any, Trie{Any, Gen.ChoiceOrCallRecord}}()), false, -0.05129329438755058, 0.0, (4, true, 0.05, 0.95), false), Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Int64, Bool, Float64, Float64], false, Union{Nothing, Some{Any}}[nothing, nothing, nothing, nothing], Main.var"##foo#277", Bool[0, 0, 0, 0], false), Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(:y => Gen.ChoiceOrCallRecord{Bool}(true, -0.05129329438755058, NaN, true)), Dict{Any, Trie{Any, Gen.ChoiceOrCallRecord}}()), false, -0.05129329438755058, 0.0, (5, false, 0.05, 0.95), true)], Any[true, false, true, false, true], (5, false, 0.05, 0.95), 5, 5, -0.2564664719377529, 0.0)
This causes foo to be invoked five times. The resulting trace may contain the following random choices:
get_choices(trace)
│
├── 1
│   │
│   └── :y : true
│
├── 2
│   │
│   └── :y : false
│
├── 3
│   │
│   └── :y : true
│
├── 4
│   │
│   └── :y : false
│
└── 5
    │
    └── :y : true
In that case, the return value is:
get_retval(trace)
Persistent{Any}[true, false, true, false, true]
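As with Map, each application of the kernel has its own address namespace, here indexed by the time step. The sketch below, assuming the same kernel `foo` as above, constrains the first two steps and leaves the remaining three to be sampled; the constrained values are chosen arbitrarily for illustration.

```julia
using Gen

@gen function foo(t::Int, y_prev::Bool, z1::Float64, z2::Float64)
    y = @trace(bernoulli(y_prev ? z1 : z2), :y)
    return y
end

bar = Unfold(foo)

# Constrain :y at steps 1 and 2; steps 3 to 5 are sampled as usual.
constraints = choicemap((1 => :y, true), (2 => :y, false))
trace, weight = generate(bar, (5, false, 0.05, 0.95), constraints)

# Step 1 sees y_prev = false, so y = true has probability z2 = 0.95.
# Step 2 sees y_prev = true, so y = false has probability 1 - z1 = 0.95.
expected_weight = 2 * log(0.95)

ys = get_retval(trace)
```

Under these assumptions `weight` should agree with `expected_weight` up to floating-point error, since only the two constrained choices contribute to it.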
Switch combinator
Consider the following constructions:
@gen function line(x)
    z ~ normal(3*x+1, 1.0)
    return z
end
@gen function outlier(x)
    z ~ normal(3*x+1, 10.0)
    return z
end
switch_model = Switch(line, outlier)
Switch{Int64, 2, Tuple{DynamicDSLFunction{Any}, DynamicDSLFunction{Any}}, Any}((DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any], false, Union{Nothing, Some{Any}}[nothing], Main.var"##line#278", Bool[0], false), DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any], false, Union{Nothing, Some{Any}}[nothing], Main.var"##outlier#279", Bool[0], false)), Dict{Int64, Int64}())
This creates a new generative function switch_model whose arguments take the form (branch, args...). By default, branch is an integer indicating which generative function to execute. For example, branch 2 corresponds to outlier:
trace = simulate(switch_model, (2, 5.0))
get_choices(trace)
│
└── :z : 15.146216831214357
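To see what the branch index changes, the following sketch scores the same observation under both branches. It assumes the `line` and `outlier` definitions above; the observed value 16.0 is an arbitrary illustration, chosen to equal the shared mean 3*5.0 + 1.

```julia
using Gen

@gen function line(x)
    z ~ normal(3*x + 1, 1.0)
    return z
end

@gen function outlier(x)
    z ~ normal(3*x + 1, 10.0)
    return z
end

switch_model = Switch(line, outlier)

# Constrain :z to the same (arbitrary) value under each branch.
obs = choicemap((:z, 16.0))
trace_line, _ = generate(switch_model, (1, 5.0), obs)     # branch 1 runs `line`
trace_outlier, _ = generate(switch_model, (2, 5.0), obs)  # branch 2 runs `outlier`

# The value sits at the mean, so the narrow-noise branch gives it a higher log density.
score_line = get_score(trace_line)
score_outlier = get_score(trace_outlier)
```

In both traces the choice appears at the top-level address `:z`, as in the output above, because Switch does not introduce an extra address namespace for the selected branch.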