Parameter Optimization
Gen provides support for gradient-based optimization of the trainable parameters of generative functions, e.g. for maximum likelihood learning, expectation-maximization, or empirical Bayes.
Trainable parameters of generative functions are initialized differently depending on the type of generative function. Trainable parameters of the built-in modeling language are initialized with init_param!.
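For example, in the built-in modeling language, trainable parameters are declared with @param and must be initialized before the function is used. A minimal sketch (the model foo, its parameters :a and :b, and all values below are hypothetical):

using Gen

@gen function foo(xs::Vector{Float64})
    @param a::Float64
    @param b::Float64
    for (i, x) in enumerate(xs)
        {(:y, i)} ~ normal(a * x + b, 1.0)
    end
end

# Trainable parameters must be initialized before the function is used.
init_param!(foo, :a, 0.0)
init_param!(foo, :b, 0.0)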
Gradient-based optimization of the trainable parameters of generative functions is based on interleaving two steps:
1. Incrementing gradient accumulators for trainable parameters by calling accumulate_param_gradients! on one or more traces.
2. Updating the value of trainable parameters and resetting the gradient accumulators to zero, by calling apply! on a parameter update, as described below.
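As a rough sketch of this interleaving (continuing the hypothetical foo model above, with made-up observations), a basic maximum likelihood training loop might look like:

# Hypothetical observed data for foo.
xs = collect(range(-1.0, 1.0, length=20))
observations = choicemap()
for (i, x) in enumerate(xs)
    observations[(:y, i)] = 2.0 * x + 1.0
end

update = ParamUpdate(FixedStepGradientDescent(0.001), foo)

for iter in 1:100
    # Obtain a trace of foo that agrees with the observations.
    (trace, _) = generate(foo, (xs,), observations)

    # Step 1: increment the gradient accumulators for :a and :b.
    accumulate_param_gradients!(trace)

    # Step 2: update the parameter values and reset the accumulators to zero.
    apply!(update)
end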
Parameter update
A parameter update reads from the gradient accumulators for certain trainable parameters, updates the values of those parameters, and resets the gradient accumulators to zero.
A parameter update is constructed by combining an update configuration with the set of trainable parameters to which the update should be applied:
Gen.ParamUpdate — Type

update = ParamUpdate(conf, param_lists...)

Return an update configured by conf that applies to the set of parameters defined by param_lists.

Each element in param_lists is a pair of a generative function and a vector of its parameter references.
Example. To construct an update that applies a gradient descent update to the parameters :a and :b of generative function foo and the parameter :theta of generative function bar:
update = ParamUpdate(GradientDescent(0.001, 100), foo => [:a, :b], bar => [:theta])
update = ParamUpdate(conf, gen_fn::GenerativeFunction)

Syntactic sugar for the constructor form above. Return an update configured by conf that applies to all trainable parameters owned by the given generative function.

Note that trainable parameters not owned by the given generative function will not be updated, even if they are used during execution of the function.

Example. If generative function foo has parameters :a and :b, to construct an update that applies a gradient descent update to the parameters :a and :b:
update = ParamUpdate(GradientDescent(0.001, 100), foo)
The set of possible update configurations is described in Update configurations. An update is applied with:

Gen.apply! — Function

apply!(update::ParamUpdate)

Perform one step of the update.
Update configurations
Gen has built-in support for the following types of update configurations.
Gen.FixedStepGradientDescent — Type

conf = FixedStepGradientDescent(step_size)

Configuration for stochastic gradient descent update with fixed step size.
Gen.GradientDescent — Type

conf = GradientDescent(step_size_init, step_size_beta)

Configuration for stochastic gradient descent update with step size given by (t::Int) -> step_size_init * (step_size_beta + 1) / (step_size_beta + t), where t is the iteration number.
Gen.ADAM — Type

conf = ADAM(learning_rate, beta1, beta2, epsilon)

Configuration for ADAM update.
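For reference, the built-in configurations might be constructed as follows (the hyperparameter values are illustrative only):

conf1 = FixedStepGradientDescent(0.001)   # constant step size
conf2 = GradientDescent(0.001, 1000)      # step size 0.001 * 1001 / (1000 + t) at iteration t
conf3 = ADAM(0.001, 0.9, 0.999, 1e-8)     # commonly used ADAM hyperparameters
update = ParamUpdate(conf3, foo)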
Custom update configurations should implement the following methods:
Gen.init_update_state — Function

state = init_update_state(conf, gen_fn::GenerativeFunction, param_list::Vector)

Get the initial state for a parameter update to the given parameters of the given generative function. param_list is a vector of references to parameters of gen_fn. conf configures the update.
Gen.apply_update! — Function

apply_update!(state)

Apply one parameter update, mutating the values of the trainable parameters, and possibly also the given state.
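As a minimal sketch of a custom configuration, the following implements a plain fixed-step update using the accumulated gradients, assuming the trainable parameter accessors of the built-in modeling language (get_param, get_param_grad, set_param!, zero_param_grad!); the names MyFixedStep and MyFixedStepState are hypothetical:

using Gen
import Gen: init_update_state, apply_update!

struct MyFixedStep
    step_size::Float64
end

struct MyFixedStepState
    conf::MyFixedStep
    gen_fn::GenerativeFunction
    param_list::Vector
end

init_update_state(conf::MyFixedStep, gen_fn::GenerativeFunction, param_list::Vector) =
    MyFixedStepState(conf, gen_fn, param_list)

function apply_update!(state::MyFixedStepState)
    for name in state.param_list
        value = get_param(state.gen_fn, name)
        grad = get_param_grad(state.gen_fn, name)
        # Step in the direction of the accumulated gradient (the same sign
        # convention as the built-in updates), then reset the accumulator.
        set_param!(state.gen_fn, name, value + state.conf.step_size * grad)
        zero_param_grad!(state.gen_fn, name)
    end
end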
Training generative functions
The train! method can be used to train the parameters of a generative function to maximize the likelihood of a dataset defined by a data_generator:
Gen.train! — Function

train!(gen_fn::GenerativeFunction, data_generator::Function,
       update::ParamUpdate,
       num_epoch, epoch_size, num_minibatch, minibatch_size; verbose::Bool=false)

Train the given generative function to maximize the expected conditional log probability (density) that gen_fn generates the assignment constraints given inputs, where the expectation is taken under the output distribution of data_generator. This is equivalent to minimizing the expected KL divergence from the conditional distribution constraints | inputs of the data generator to the distribution represented by the generative function, where the expectation is taken under the marginal distribution on inputs determined by the data generator.

The function data_generator is a function of no arguments that returns a tuple (inputs, constraints), where inputs is a Tuple of inputs (arguments) to gen_fn, and constraints is a ChoiceMap.

update specifies both the optimization algorithm used (via its configuration) and the trainable parameters to which the update is applied.
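A rough usage sketch, continuing the hypothetical foo model above with a made-up data generator and illustrative settings:

# Draws random inputs and simulates noisy observations from a fixed
# "ground truth" setting of the parameters.
function data_generator()
    xs = rand(20) .* 2.0 .- 1.0
    constraints = choicemap()
    for (i, x) in enumerate(xs)
        constraints[(:y, i)] = 2.0 * x + 1.0 + 0.1 * randn()
    end
    ((xs,), constraints)  # inputs must be a Tuple of arguments to foo
end

update = ParamUpdate(ADAM(0.001, 0.9, 0.999, 1e-8), foo)
train!(foo, data_generator, update,
       100,  # num_epoch
       50,   # epoch_size
       10,   # num_minibatch
       5;    # minibatch_size
       verbose=true)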