Parameter Optimization

Gen provides support for gradient-based optimization of the trainable parameters of generative functions, e.g. for maximum likelihood learning, expectation-maximization, or empirical Bayes.

Trainable parameters of generative functions are initialized differently depending on the type of generative function. Trainable parameters of the built-in modeling language are initialized with init_param!.

Gradient-based optimization of the trainable parameters of generative functions is based on interleaving two steps:

  • Incrementing gradient accumulators for trainable parameters by calling accumulate_param_gradients! on one or more traces.

  • Updating the value of trainable parameters and resetting the gradient accumulators to zero, by calling apply! on a parameter update, as described below.
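The two steps can be illustrated without Gen itself. The following plain-Julia sketch is a hypothetical analogue, not Gen's API: ToyParam, accumulate_grad! and apply_step! are illustrative names standing in for a trainable parameter, accumulate_param_gradients! and apply!. The toy model assumes each observation y is drawn from Normal(theta, 1), so each observation contributes y - theta to the gradient of the log density.

```julia
# Hypothetical analogue of a trainable parameter with its gradient accumulator.
# Not Gen's API; the model assumed here is y ~ Normal(theta, 1).
mutable struct ToyParam
    value::Float64
    grad::Float64   # gradient accumulator
end

# Analogue of accumulate_param_gradients!: add d/dtheta log N(y; theta, 1) = y - theta.
function accumulate_grad!(p::ToyParam, y::Float64)
    p.grad += y - p.value
    return p
end

# Analogue of apply!: take a gradient-ascent step, then reset the accumulator to zero.
function apply_step!(p::ToyParam, step_size::Float64)
    p.value += step_size * p.grad
    p.grad = 0.0
    return p
end

data = [1.0, 2.0, 3.0, 4.0]
theta = ToyParam(0.0, 0.0)
for iter in 1:100
    for y in data
        accumulate_grad!(theta, y)   # step 1: accumulate over several observations
    end
    apply_step!(theta, 0.05)         # step 2: update the value and reset the accumulator
end
println(theta.value)  # approaches the sample mean 2.5
```

Accumulating over several observations before each update is what lets one update step use a whole minibatch of traces.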

Parameter update

A parameter update reads from the gradient accumulators for certain trainable parameters, updates the values of those parameters, and resets the gradient accumulators to zero.

A parameter update is constructed by combining an update configuration with the set of trainable parameters to which the update should be applied:

Gen.ParamUpdate (Type)
update = ParamUpdate(conf, param_lists...)

Return an update configured by conf that applies to the set of parameters defined by param_lists.

Each element of param_lists is a pair of a generative function and a vector of references to its parameters.

Example. To construct an update that applies a gradient descent update to the parameters :a and :b of generative function foo and to the parameter :theta of generative function bar:

update = ParamUpdate(GradientDescent(0.001, 100), foo => [:a, :b], bar => [:theta])

Syntactic sugar for the constructor form above.

update = ParamUpdate(conf, gen_fn::GenerativeFunction)

Return an update configured by conf that applies to all trainable parameters owned by the given generative function.

Note that trainable parameters not owned by the given generative function will not be updated, even if they are used during execution of the function.

Example. If generative function foo has parameters :a and :b, to construct an update that applies a gradient descent update to the parameters :a and :b:

update = ParamUpdate(GradientDescent(0.001, 100), foo)

The set of possible update configurations is described in Update configurations.

An update is applied with:

Gen.apply! (Function)
apply!(update::ParamUpdate)

Perform one step of the update.


Update configurations

Gen has built-in support for the following types of update configurations.

Gen.GradientDescent (Type)
conf = GradientDescent(step_size_init, step_size_beta)

Configuration for a stochastic gradient descent update with step size given by (t::Int) -> step_size_init * (step_size_beta + 1) / (step_size_beta + t), where t is the iteration number.
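A small sketch of this schedule, evaluated for the GradientDescent(0.001, 100) configuration used in the examples above. gd_step_size is a hypothetical helper that simply transcribes the formula; it is not part of Gen.

```julia
# Hypothetical helper transcribing the documented GradientDescent schedule:
# step size at iteration t = step_size_init * (step_size_beta + 1) / (step_size_beta + t).
gd_step_size(step_size_init, step_size_beta, t::Integer) =
    step_size_init * (step_size_beta + 1) / (step_size_beta + t)

for t in (1, 100, 1000)
    println("t = ", t, ": step size = ", gd_step_size(0.001, 100, t))
end
```

At t = 1 the step size equals step_size_init; once t is much larger than step_size_beta it decays roughly like 1/t, so step_size_beta controls how long the step size stays near its initial value.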

Gen.ADAM (Type)
conf = ADAM(learning_rate, beta1, beta2, epsilon)

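Configuration for an ADAM update, where learning_rate is the step size, beta1 and beta2 are the exponential decay rates of the first- and second-moment estimates of the gradient, and epsilon is a small constant added to the denominator for numerical stability.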
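For reference, the textbook ADAM step (Kingma and Ba) is sketched below on a plain vector. It is not Gen's internal implementation: AdamState and adam_step! are hypothetical names, the bias-corrected form of the algorithm is assumed, and the step is written as gradient ascent on the assumption that, as in train!, the accumulated gradients belong to an objective being maximized.

```julia
# Textbook ADAM state for a vector parameter (hypothetical names, not Gen's internals).
mutable struct AdamState
    m::Vector{Float64}   # first-moment (mean) estimate of the gradient
    v::Vector{Float64}   # second-moment (uncentered variance) estimate
    t::Int               # iteration counter
end
AdamState(n::Int) = AdamState(zeros(n), zeros(n), 0)

# One bias-corrected ADAM step, written as ascent on the objective whose gradient is `grad`.
function adam_step!(theta, grad, s::AdamState;
                    learning_rate=0.01, beta1=0.9, beta2=0.999, epsilon=1e-8)
    s.t += 1
    s.m .= beta1 .* s.m .+ (1 - beta1) .* grad
    s.v .= beta2 .* s.v .+ (1 - beta2) .* grad .^ 2
    m_hat = s.m ./ (1 - beta1^s.t)   # bias correction for zero initialization
    v_hat = s.v ./ (1 - beta2^s.t)
    theta .+= learning_rate .* m_hat ./ (sqrt.(v_hat) .+ epsilon)
    return theta
end

# Maximize f(theta) = -sum((theta .- target).^2), whose gradient is -2 (theta - target).
target = [1.0, -2.0]
theta = zeros(2)
state = AdamState(2)
for _ in 1:2000
    grad = -2 .* (theta .- target)
    adam_step!(theta, grad, state; learning_rate=0.05)
end
println(theta)
```

The toy objective has its maximum at target, so after these iterations theta should sit close to [1.0, -2.0].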


Custom update configurations should implement the following methods:

Gen.init_update_state (Function)
state = init_update_state(conf, gen_fn::GenerativeFunction, param_list::Vector)

Get the initial state for a parameter update to the given parameters of the given generative function.

param_list is a vector of references to parameters of gen_fn. conf configures the update.

Gen.apply_update! (Function)
apply_update!(state)

Apply one parameter update, mutating the values of the trainable parameters, and possibly also the given state.
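The division of labor between these two methods can be mimicked in plain Julia: initialization builds a state that remembers the configuration and the parameter list, and the apply step mutates parameter values and possibly the state. Everything here is hypothetical (ToyParamStore, SignAscent, toy_init_update_state, toy_apply_update!); a real configuration would add methods to Gen.init_update_state and Gen.apply_update! and operate on the generative function's own parameters. The sketch also resets the gradient accumulators inside the apply step, which is an assumption about where that reset happens.

```julia
# Hypothetical stand-in for a generative function's parameters (not Gen's API).
struct ToyParamStore
    values::Dict{Symbol,Float64}
    grads::Dict{Symbol,Float64}   # gradient accumulators
end

# Hypothetical custom configuration: fixed-step sign gradient ascent.
struct SignAscent
    step_size::Float64
end

# State for one configuration applied to one list of parameters.
mutable struct SignAscentState
    conf::SignAscent
    store::ToyParamStore
    param_list::Vector{Symbol}
    num_steps::Int
end

# Analogue of init_update_state(conf, gen_fn, param_list).
toy_init_update_state(conf::SignAscent, store::ToyParamStore, param_list::Vector{Symbol}) =
    SignAscentState(conf, store, param_list, 0)

# Analogue of apply_update!(state): mutate the listed parameter values,
# reset their accumulators (assumed), and record the step in the state.
function toy_apply_update!(state::SignAscentState)
    for name in state.param_list
        g = state.store.grads[name]
        state.store.values[name] += state.conf.step_size * sign(g)
        state.store.grads[name] = 0.0
    end
    state.num_steps += 1
    return nothing
end

store = ToyParamStore(Dict(:a => 0.0, :b => 1.0), Dict(:a => 3.0, :b => -0.5))
state = toy_init_update_state(SignAscent(0.1), store, [:a])
toy_apply_update!(state)   # only :a changes; :b is not in the parameter list
```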


Training generative functions

The train! method can be used to train the parameters of a generative function to maximize the likelihood of a dataset defined by a data_generator:

Gen.train! (Function)
train!(gen_fn::GenerativeFunction, data_generator::Function,
       update::ParamUpdate,
       num_epoch, epoch_size, num_minibatch, minibatch_size; verbose::Bool=false)

Train the given generative function to maximize the expected conditional log probability (density) that gen_fn generates the choice map constraints given the inputs, where the expectation is taken under the output distribution of data_generator.

The function data_generator is a function of no arguments that returns a tuple (inputs, constraints), where inputs is a Tuple of inputs (arguments) to gen_fn and constraints is a ChoiceMap.
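To make the roles of data_generator and the epoch and minibatch arguments concrete, here is a plain-Julia analogue of the training loop for a one-parameter model. It is not Gen's implementation: a Dict stands in for a ChoiceMap, the model y ~ Normal(slope * x, 1) and all function names are hypothetical, and the reading of num_epoch, epoch_size, num_minibatch and minibatch_size (fresh data each epoch, minibatches sampled from it) is an assumption.

```julia
using Random

const rng = MersenneTwister(1)

# Hypothetical data generator: takes no arguments and returns (inputs, constraints).
# A Dict stands in for a Gen ChoiceMap; the data follow y = 2x + small noise.
function data_generator()
    x = randn(rng)
    y = 2.0 * x + 0.1 * randn(rng)
    return ((x,), Dict(:y => y))
end

# Toy model y ~ Normal(slope * x, 1): gradient of the log density with respect to slope.
loglik_grad(slope, x, y) = (y - slope * x) * x

# Assumed semantics (not taken from the Gen docs): each epoch draws epoch_size
# examples, then takes num_minibatch ascent steps, each on minibatch_size
# examples sampled with replacement from that epoch's data.
function toy_train(num_epoch, epoch_size, num_minibatch, minibatch_size; step_size=0.05)
    slope = 0.0
    for _ in 1:num_epoch
        epoch_data = [data_generator() for _ in 1:epoch_size]
        for _ in 1:num_minibatch
            grad = 0.0  # gradient accumulator for this minibatch
            for i in rand(rng, 1:epoch_size, minibatch_size)
                ((x,), constraints) = epoch_data[i]
                grad += loglik_grad(slope, x, constraints[:y])
            end
            slope += step_size * grad / minibatch_size  # apply one ascent step
        end
    end
    return slope
end

slope = toy_train(20, 100, 10, 10)
println("learned slope: ", slope)  # should be close to the true slope 2.0
```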

update specifies both the update configuration (the optimization algorithm) and the trainable parameters to which it is applied, as described in Parameter update.

Maximizing this objective is equivalent to minimizing the expected KL divergence from the conditional distribution constraints | inputs of the data generator to the distribution represented by the generative function, where the expectation is taken under the marginal distribution on inputs determined by the data generator.
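The equivalence with the KL formulation takes one line to check. Writing x for the inputs, c for the constraints, p_data for the data generator's conditional distribution and p_θ for the distribution of gen_fn with parameters θ (notation introduced here for the derivation), the expected KL divergence splits into a term that does not depend on θ minus the training objective:

```latex
\mathbb{E}_{x}\!\left[\mathrm{KL}\big(p_{\mathrm{data}}(\cdot \mid x)\,\|\,p_\theta(\cdot \mid x)\big)\right]
= \underbrace{\mathbb{E}_{x}\,\mathbb{E}_{c \sim p_{\mathrm{data}}(\cdot \mid x)}\!\left[\log p_{\mathrm{data}}(c \mid x)\right]}_{\text{independent of } \theta}
\;-\; \mathbb{E}_{x}\,\mathbb{E}_{c \sim p_{\mathrm{data}}(\cdot \mid x)}\!\left[\log p_\theta(c \mid x)\right]
```

Since the first term is constant in θ, minimizing the left-hand side over θ is the same as maximizing the expected conditional log probability of the constraints given the inputs.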
