Extending Gen
Gen is designed for extensibility. To implement behaviors that are not directly supported by the existing modeling languages, users can implement 'black-box' generative functions directly, without using the built-in modeling languages. These generative functions can then be invoked by generative functions defined using the built-in modeling languages. This includes several special cases:
Extending Gen with custom gradient computations
Extending Gen with custom incremental computation of return values
Extending Gen with new modeling languages.
Custom gradients
To add a custom gradient for a differentiable deterministic computation, define a concrete subtype of CustomGradientGF with the following methods: apply, gradient, and has_argument_grads. For example:
using Gen

struct MyPlus <: CustomGradientGF{Float64} end
Gen.apply(::MyPlus, args) = args[1] + args[2]
# d(a + b)/da = d(a + b)/db = 1, so both argument gradients equal retgrad
Gen.gradient(::MyPlus, args, retval, retgrad) = (retgrad, retgrad)
Gen.has_argument_grads(::MyPlus) = (true, true)
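Before wrapping a hand-written gradient in a CustomGradientGF, it can help to check it against central finite differences. The sketch below does this in plain Julia without loading Gen; mytimes_apply, mytimes_gradient, and fd_gradient are hypothetical names, and the rule shown is for a product a * b rather than the sum in the MyPlus example.

```julia
# Standalone sketch (does not load Gen): hypothetical plain-Julia stand-ins for
# `apply` and `gradient` of f(a, b) = a * b.
mytimes_apply(args) = args[1] * args[2]

# Chain rule: d(a*b)/da = b and d(a*b)/db = a, each scaled by `retgrad`,
# the gradient of the downstream objective with respect to the return value.
mytimes_gradient(args, retval, retgrad) = (retgrad * args[2], retgrad * args[1])

# Central finite-difference estimate of the same quantity, one argument at a time.
function fd_gradient(f, args::Tuple, retgrad; h=1e-6)
    grads = Float64[]
    for i in eachindex(args)
        plus = collect(args); plus[i] += h
        minus = collect(args); minus[i] -= h
        push!(grads, retgrad * (f(Tuple(plus)) - f(Tuple(minus))) / (2h))
    end
    return Tuple(grads)
end

args = (3.0, 4.0)
retgrad = 2.0
analytic = mytimes_gradient(args, mytimes_apply(args), retgrad)
numeric = fd_gradient(mytimes_apply, args, retgrad)
println("analytic = ", analytic, ", finite differences = ", numeric)
```

If the two tuples disagree beyond finite-difference error, the gradient rule (or the argument order in the returned tuple) is likely wrong.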
Gen.CustomGradientGF (Type)
CustomGradientGF{T}
Abstract type for a generative function with a custom gradient computation, and default behaviors for all other generative function interface methods. T is the type of the return value.
Gen.apply (Function)
retval = apply(gen_fn::CustomGradientGF, args)
Apply the function to the arguments.
Gen.gradient (Function)
arg_grads = gradient(gen_fn::CustomGradientGF, args, retval, retgrad)
Return the tuple of gradients with respect to the arguments, with nothing in the position of any argument whose gradient is not available.
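To illustrate the nothing convention, the plain-Julia sketch below (which does not load Gen) uses a hypothetical function f(x, n) = x * n with a differentiable Float64 argument and a non-differentiable Int argument. The names scale_apply, scale_gradient, and scale_has_argument_grads are illustrative only; in Gen they would be methods of apply, gradient, and has_argument_grads on a CustomGradientGF subtype.

```julia
# Standalone sketch (does not load Gen). Hypothetical function f(x, n) = x * n,
# where x::Float64 is differentiable and n::Int is an integer count.
scale_apply(args) = args[1] * args[2]

# The gradient tuple has one entry per argument; `nothing` marks the integer
# argument n, whose gradient is not available.
scale_gradient(args, retval, retgrad) = (retgrad * args[2], nothing)

# Matching flags, in the style of has_argument_grads: (true, false).
scale_has_argument_grads() = (true, false)

args = (1.5, 3)
grads = scale_gradient(args, scale_apply(args), 1.0)
for (i, (flag, g)) in enumerate(zip(scale_has_argument_grads(), grads))
    println("argument ", i, ": has gradient = ", flag, ", gradient = ", g)
end
```

Keeping the positions of nothing aligned with the false entries of has_argument_grads avoids mismatches between the two methods.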
Custom incremental computation
Iterative inference techniques like Markov chain Monte Carlo involve repeatedly updating the execution traces of generative models. In some cases, the output of a deterministic computation within the model can be incrementally computed during each of these updates, instead of being computed from scratch.
To add a custom incremental computation for a deterministic computation, define a concrete subtype of CustomUpdateGF with methods including apply_with_state and update_with_state.
The second type parameter of CustomUpdateGF is the type of the state that may be used internally to facilitate incremental computation within update_with_state.
For example, we can implement a function for computing the sum of a vector that efficiently computes the new sum when a small fraction of the vector elements change:
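using Gen

# State carried between calls: the previous input vector and its sum
struct MyState
    prev_arr::Vector{Float64}
    sum::Float64
end

struct MySum <: CustomUpdateGF{Float64,MyState} end

# Initial execution: compute the sum from scratch and record the state
function Gen.apply_with_state(::MySum, args)
    arr = args[1]
    s = sum(arr)
    state = MyState(arr, s)
    (s, state)
end

# Incremental update: adjust the previous sum using only the changed,
# appended, and removed elements described by the VectorDiff
function Gen.update_with_state(::MySum, state, args, argdiffs::Tuple{VectorDiff})
    arr = args[1]
    prev_sum = state.sum
    retval = prev_sum
    for i in keys(argdiffs[1].updated)
        retval += (arr[i] - state.prev_arr[i])
    end
    prev_length = length(state.prev_arr)
    new_length = length(arr)
    for i=prev_length+1:new_length
        retval += arr[i]
    end
    for i=new_length+1:prev_length
        # removed elements exist only in the previous vector
        retval -= state.prev_arr[i]
    end
    state = MyState(arr, retval)
    (state, retval, UnknownChange())
end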
Gen.num_args(::MySum) = 1
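The incremental logic in update_with_state can be checked against a from-scratch sum outside Gen. The sketch below is plain Julia; incremental_sum is a hypothetical helper, and the updated argument stands in for keys(argdiffs[1].updated), assumed here to hold only indices present in both the previous and the new vector. It exercises both the growing case and the shrinking case, where removed elements must be read from the previous vector because they no longer exist in the new one.

```julia
# Standalone sketch (does not load Gen). `updated` plays the role of
# keys(argdiffs[1].updated): the indices that exist in both the previous and
# the new vector and whose values changed.
function incremental_sum(prev_arr, prev_sum, arr, updated)
    retval = prev_sum
    for i in updated
        retval += arr[i] - prev_arr[i]
    end
    prev_length, new_length = length(prev_arr), length(arr)
    for i in prev_length+1:new_length   # elements appended to the vector
        retval += arr[i]
    end
    for i in new_length+1:prev_length   # elements removed from the vector
        retval -= prev_arr[i]
    end
    return retval
end

prev = [1.0, 2.0, 3.0, 4.0]
grown = [1.0, 5.0, 3.0, 4.0, 10.0]   # index 2 changed, one element appended
shrunk = [1.0, 2.0, 7.0]             # index 3 changed, one element removed
s_grown = incremental_sum(prev, sum(prev), grown, [2])
s_shrunk = incremental_sum(prev, sum(prev), shrunk, [3])
println(s_grown, " ", s_shrunk)
```

The cost of the update is proportional to the number of changed, appended, and removed elements rather than to the full vector length.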
Gen.CustomUpdateGF (Type)
CustomUpdateGF{T,S}
Abstract type for a generative function with a custom update computation, and default behaviors for all other generative function interface methods. T is the type of the return value and S is the type of state used internally for incremental computation.
Gen.apply_with_state (Function)
retval, state = apply_with_state(gen_fn::CustomUpdateGF, args)
Execute the generative function and return the return value and the state.
Gen.update_with_state (Function)
state, retval, retdiff = update_with_state(gen_fn::CustomUpdateGF, state, args, argdiffs)
Update the arguments to the generative function and return the new state, the new return value, and a diff (retdiff) describing the change in the return value.
Custom distributions
Users can extend Gen with new probability distributions, which can then be used to make random choices within generative functions. Simple transformations of existing distributions can be created using the @dist DSL. For arbitrary distributions, including distributions that cannot be expressed in the @dist DSL, users can define a custom distribution by implementing Gen's Distribution interface directly, as defined below.
Probability distributions are singleton types whose supertype is Distribution{T}, where T indicates the data type of the random sample.
abstract type Distribution{T} end
A new Distribution type must implement the following methods: random, logpdf, has_output_grad, has_argument_grads, and logpdf_grad.
By convention, distributions have a global constant lower-case name for the singleton value. For example:
struct Bernoulli <: Distribution{Bool} end
const bernoulli = Bernoulli()
Distribution values should also be callable; calling a distribution is syntactic sugar with the same behavior as calling random:
bernoulli(0.5) # identical to random(bernoulli, 0.5) and random(Bernoulli(), 0.5)
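The singleton-plus-callable pattern can be mirrored in plain Julia. The sketch below does not load Gen; ToyDistribution, ToyBernoulli, toy_bernoulli, and toy_random are hypothetical stand-ins for Distribution, Bernoulli, bernoulli, and random. It assumes one way to obtain the callable behavior, namely a call method on the abstract type (supported since Julia 1.3); it is not a description of how Gen itself implements this.

```julia
# Standalone sketch (does not load Gen). ToyDistribution, ToyBernoulli, and
# toy_random are hypothetical names that mirror the Distribution{T} pattern.
abstract type ToyDistribution{T} end

struct ToyBernoulli <: ToyDistribution{Bool} end
const toy_bernoulli = ToyBernoulli()   # lower-case singleton value

# Sampling: returns true with probability p.
toy_random(::ToyBernoulli, p::Real) = rand() < p

# Make every ToyDistribution callable, forwarding to toy_random.
(dist::ToyDistribution)(args...) = toy_random(dist, args...)

println(toy_bernoulli(0.5))
```

Because ToyBernoulli has no fields, every ToyBernoulli() is the same object, so the constant toy_bernoulli and ToyBernoulli() can be used interchangeably.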
Gen.random (Function)
val::T = random(dist::Distribution{T}, args...)
Sample a random choice from the given distribution with the given arguments.
Gen.logpdf (Function)
lpdf = logpdf(dist::Distribution{T}, value::T, args...)
Evaluate the log probability (density) of the value.
Gen.has_output_grad (Function)
has::Bool = has_output_grad(dist::Distribution)
Return true if the distribution computes the gradient of the logpdf with respect to the value of the random choice.
Gen.logpdf_grad (Function)
grads::Tuple = logpdf_grad(dist::Distribution{T}, value::T, args...)
Compute the gradient of the logpdf with respect to the value and to each of the arguments.
If has_output_grad returns false, then the first element of the returned tuple is nothing. Otherwise, the first element of the tuple is the gradient with respect to the value. If the value returned by has_argument_grads is false at position i, then the (i+1)th element of the returned tuple is nothing. Otherwise, this element contains the gradient with respect to the ith argument.
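As a worked instance of the logpdf_grad tuple layout, the sketch below uses a hypothetical exponential distribution with rate λ, whose log density is log(λ) - λx for x ≥ 0. It is plain Julia that does not load Gen; exp_logpdf, exp_logpdf_grad, exp_has_output_grad, and exp_has_argument_grads are illustrative names. The analytic derivatives are -λ with respect to the value and 1/λ - x with respect to the rate, and both are compared against central finite differences.

```julia
# Standalone sketch (does not load Gen). Hypothetical exponential distribution
# with rate λ, written as plain functions that follow the tuple convention
# described above: (gradient w.r.t. value, gradient w.r.t. each argument).
exp_logpdf(x::Real, rate::Real) = x < 0 ? -Inf : log(rate) - rate * x

exp_has_output_grad() = true
exp_has_argument_grads() = (true,)

function exp_logpdf_grad(x::Real, rate::Real)
    value_grad = -rate          # d/dx of log(rate) - rate * x
    rate_grad = 1 / rate - x    # d/d(rate) of the same expression
    return (value_grad, rate_grad)
end

x, rate = 0.7, 2.0
grads = exp_logpdf_grad(x, rate)
h = 1e-6
fd_x = (exp_logpdf(x + h, rate) - exp_logpdf(x - h, rate)) / (2h)
fd_rate = (exp_logpdf(x, rate + h) - exp_logpdf(x, rate - h)) / (2h)
println(grads, " vs ", (fd_x, fd_rate))
```

The returned tuple has 1 + (number of arguments) entries; here both flags are true, so no entry is nothing.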
Custom generative functions
We recommend following the steps below when implementing a new type of generative function, and looking at the implementation of the DynamicDSLFunction type as an example.
Define a trace data type
struct MyTraceType <: Trace
..
end
Decide the return type for the generative function
Suppose our return type is Vector{Float64}.
Define a data type for your generative function
This should be a subtype of GenerativeFunction, with the appropriate type parameters.
struct MyGenerativeFunction <: GenerativeFunction{Vector{Float64},MyTraceType}
..
end
Note that your generative function may not need to have any fields. You can create a constructor for it, e.g.:
function MyGenerativeFunction(...)
..
end
Decide what the arguments to a generative function should be
For example, our generative function might take two arguments, a (of type Int) and b (of type Float64). Then, the argument tuple passed to e.g. generate will have two elements.
NOTE: Be careful to distinguish between arguments to the generative function itself and arguments to the constructor of the generative function. For example, if your generative function type is parametrized by modeling DSL code, that DSL code would be a parameter of the generative function constructor, not an argument to the generative function.
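The distinction between constructor parameters and call-time arguments can be sketched in plain Julia. The example below does not load Gen; ToyNoisyGF and toy_call are hypothetical. The field sigma is fixed once at construction, playing the role that DSL code plays for a generative function constructor, while a::Int and b::Float64 arrive in an argument tuple on each call, matching the two-argument example above.

```julia
# Standalone sketch (does not load Gen). ToyNoisyGF and toy_call are
# hypothetical. `sigma` is fixed when the object is constructed (like DSL code
# passed to a generative function constructor); `a` and `b` arrive in the
# argument tuple on each call, as they would for e.g. generate.
using Random

struct ToyNoisyGF
    sigma::Float64   # constructor parameter
end

function toy_call(gf::ToyNoisyGF, args::Tuple{Int,Float64}, rng=Random.default_rng())
    a, b = args
    return [b + gf.sigma * randn(rng) for _ in 1:a]
end

gf = ToyNoisyGF(0.1)           # construct once
y1 = toy_call(gf, (3, 1.0))    # call with different arguments
y2 = toy_call(gf, (5, -2.0))
println(length(y1), " ", length(y2))
```

One constructed object can be called many times with different arguments; changing sigma requires constructing a new object.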
Decide what the traced random choices (if any) will be
Remember that each random choice is assigned a unique address in a (possibly hierarchical) address space. You are free to design this address space as you wish, although you should document it for users of your generative function type.
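To make the earlier design steps concrete, the sketch below shows one possible trace layout for the running example (arguments a::Int and b::Float64, return type Vector{Float64}). It is plain Julia and does not load Gen or subtype Trace; ToyTrace, toy_simulate, normal_logpdf, and the address scheme (:x, i) are all hypothetical choices made for illustration. The model draws a values x_i ~ Normal(b, 1) and records each one under its address, along with the log density of all choices.

```julia
# Standalone sketch (does not load Gen). ToyTrace and toy_simulate are
# hypothetical. The model draws a values x_i ~ Normal(b, 1) at the addresses
# (:x, 1), ..., (:x, a) and returns them as a Vector{Float64}.
using Random

struct ToyTrace
    args::Tuple{Int,Float64}
    choices::Dict{Tuple{Symbol,Int},Float64}   # address => value
    retval::Vector{Float64}
    score::Float64                              # log density of the choices
end

normal_logpdf(x, mu, sigma) = -0.5 * log(2pi) - log(sigma) - 0.5 * ((x - mu) / sigma)^2

function toy_simulate(args::Tuple{Int,Float64}, rng=Random.default_rng())
    a, b = args
    choices = Dict{Tuple{Symbol,Int},Float64}()
    score = 0.0
    for i in 1:a
        x = b + randn(rng)
        choices[(:x, i)] = x
        score += normal_logpdf(x, b, 1.0)
    end
    retval = [choices[(:x, i)] for i in 1:a]
    return ToyTrace(args, choices, retval, score)
end

trace = toy_simulate((3, 0.5))
println("retval = ", trace.retval, ", score = ", trace.score)
```

Storing the arguments, choices, return value, and score together in the trace is what later interface methods (for example, ones that read back the score or the choices) would draw on.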
Implement methods of the Generative Function Interface
At minimum, you need to implement the following methods:
If you want to use the generative function within models, you should implement:
If you want to use MCMC on models that call your generative function, then implement:
If you want to use gradient-based inference techniques on models that call your generative function, then implement:
If your generative function has trainable parameters, then implement:
Custom modeling languages
Gen can be extended with new modeling languages by implementing new generative function types, and constructors for these types that take models as input. This typically requires implementing the entire generative function interface, and is an advanced use of Gen.