Customizing Gradients

Deterministic Functions with Custom Gradients

To add a custom gradient for a differentiable deterministic computation, define a concrete subtype of CustomGradientGF with the following methods:

- Gen.apply
- Gen.gradient
- Gen.has_argument_grads

For example, we can implement binary addition with a manually-defined gradient:

struct MyPlus <: CustomGradientGF{Float64} end

# compute the return value from the argument tuple
Gen.apply(::MyPlus, args) = args[1] + args[2]

# map the gradient with respect to the return value (retgrad) to gradients
# with respect to each argument; for addition, each is retgrad unchanged
Gen.gradient(::MyPlus, args, retval, retgrad) = (retgrad, retgrad)

# both arguments support gradients
Gen.has_argument_grads(::MyPlus) = (true, true)
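
The same pattern extends to functions whose gradients depend on the arguments. As a hypothetical sketch (MyTimes is not part of Gen), binary multiplication would propagate the return-value gradient through the chain rule:

struct MyTimes <: CustomGradientGF{Float64} end

Gen.apply(::MyTimes, args) = args[1] * args[2]

# chain rule: d(x*y)/dx = y and d(x*y)/dy = x, each scaled by retgrad
Gen.gradient(::MyTimes, args, retval, retgrad) = (retgrad * args[2], retgrad * args[1])

Gen.has_argument_grads(::MyTimes) = (true, true)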

Customizing Parameter Updates

To add support for a new type of gradient-based parameter update, create a new parameter update configuration type, and define the following methods for the types of generative functions that are to be supported:

- Gen.init_update_state
- Gen.apply_update!

As an example, the built-in update configuration, FixedStepGradientDescent, is implemented as follows:

struct FixedStepGradientDescent
    step_size::Float64
end

# update state for generative functions constructed with the built-in modeling DSLs
mutable struct FixedStepGradientDescentBuiltinDSLState
    step_size::Float64
    gen_fn::Union{Gen.DynamicDSLFunction,Gen.StaticIRGenerativeFunction}
    param_list::Vector
end

function Gen.init_update_state(conf::FixedStepGradientDescent,
        gen_fn::Union{Gen.DynamicDSLFunction,Gen.StaticIRGenerativeFunction}, param_list::Vector)
    FixedStepGradientDescentBuiltinDSLState(conf.step_size, gen_fn, param_list)
end

function Gen.apply_update!(state::FixedStepGradientDescentBuiltinDSLState)
    for param_name in state.param_list
        value = Gen.get_param(state.gen_fn, param_name)
        grad = Gen.get_param_grad(state.gen_fn, param_name)
        # step in the direction of the gradient: the accumulated gradient is of
        # an objective to be maximized (e.g. a log probability), so it is added
        Gen.set_param!(state.gen_fn, param_name, value + grad * state.step_size)
        # reset the gradient accumulator for the next update
        Gen.zero_param_grad!(state.gen_fn, param_name)
    end
end
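
These methods are typically not called directly. Instead, a configuration is wrapped in a ParamUpdate, which calls Gen.init_update_state, and each call to Gen.apply! then invokes Gen.apply_update!. A minimal usage sketch, assuming foo is a hypothetical generative function with trainable parameters whose gradient accumulators have already been populated (e.g. via accumulate_param_gradients!):

update = Gen.ParamUpdate(Gen.FixedStepGradientDescent(0.001), foo)
Gen.apply!(update)  # take one gradient step and zero the accumulators

Following the same pattern as the built-in configuration, a hypothetical configuration (not part of Gen) whose step size decays after every update might look like:

struct DecayingStepGradientDescent
    initial_step_size::Float64
    decay::Float64  # multiplicative decay factor in (0, 1]
end

mutable struct DecayingStepGradientDescentState
    step_size::Float64
    decay::Float64
    gen_fn::Union{Gen.DynamicDSLFunction,Gen.StaticIRGenerativeFunction}
    param_list::Vector
end

function Gen.init_update_state(conf::DecayingStepGradientDescent,
        gen_fn::Union{Gen.DynamicDSLFunction,Gen.StaticIRGenerativeFunction}, param_list::Vector)
    DecayingStepGradientDescentState(conf.initial_step_size, conf.decay, gen_fn, param_list)
end

function Gen.apply_update!(state::DecayingStepGradientDescentState)
    for param_name in state.param_list
        value = Gen.get_param(state.gen_fn, param_name)
        grad = Gen.get_param_grad(state.gen_fn, param_name)
        Gen.set_param!(state.gen_fn, param_name, value + grad * state.step_size)
        Gen.zero_param_grad!(state.gen_fn, param_name)
    end
    state.step_size *= state.decay  # shrink the step for the next update
end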