How to Customize Incremental Computation
Iterative inference techniques like Markov chain Monte Carlo involve repeatedly updating the execution traces of generative models. In some cases, the output of a deterministic computation within the model can be incrementally computed during each of these updates, instead of being computed from scratch.
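As a minimal sketch of the idea, independent of Gen, the snippet below corrects a cached sum after a single element changes instead of recomputing it. The helper name updated_sum is hypothetical and exists only for illustration.

```julia
# Hypothetical helper, not part of Gen: correct a cached sum after
# one element changes from old_value to new_value.
function updated_sum(prev_sum::Float64, old_value::Float64, new_value::Float64)
    return prev_sum + (new_value - old_value)
end

xs = [1.0, 2.0, 3.0, 4.0]
cached = sum(xs)                 # from-scratch computation, O(n)

old_value = xs[3]
xs[3] = 10.0                     # a single element changes
cached = updated_sum(cached, old_value, xs[3])   # O(1) correction

@assert cached == sum(xs)        # agrees with recomputing from scratch
```

The correction costs a constant amount of work per changed element, which is what makes incremental updates attractive when only a small part of the input changes between iterations.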
To add a custom incremental computation for a deterministic computation, define a concrete subtype of CustomUpdateGF with the following methods: apply_with_state, update_with_state, and num_args.
The second type parameter of CustomUpdateGF is the type of the state that may be used internally to facilitate incremental computation within update_with_state.
For example, we can implement a function for computing the sum of a vector that efficiently computes the new sum when a small fraction of the vector elements change:
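using Gen

# Internal state: the previous argument vector and its sum.
struct MyState
    prev_arr::Vector{Float64}
    sum::Float64
end

struct MySum <: CustomUpdateGF{Float64,MyState} end

# Initial run: compute the sum from scratch and record the state.
function Gen.apply_with_state(::MySum, args)
    arr = args[1]
    s = sum(arr)
    state = MyState(arr, s)
    (s, state)
end

# Incremental update: adjust the previous sum using only what changed.
function Gen.update_with_state(::MySum, state, args, argdiffs::Tuple{VectorDiff})
    arr = args[1]
    prev_sum = state.sum
    retval = prev_sum
    # Elements flagged as updated in the diff.
    for i in keys(argdiffs[1].updated)
        retval += (arr[i] - state.prev_arr[i])
    end
    prev_length = length(state.prev_arr)
    new_length = length(arr)
    # Elements appended to the vector.
    for i=prev_length+1:new_length
        retval += arr[i]
    end
    # Elements removed from the end; these exist only in the previous vector.
    for i=new_length+1:prev_length
        retval -= state.prev_arr[i]
    end
    state = MyState(arr, retval)
    (state, retval, UnknownChange())
end

Gen.num_args(::MySum) = 1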
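The following is a hedged usage sketch, assuming the MyState and MySum definitions above have already been evaluated in the same session. It also assumes that VectorDiff can be constructed positionally from a new length, a previous length, and a dictionary of updated indices; the lengths are equal here, so their order does not affect the result. The new argument is a fresh vector rather than an in-place mutation, because the stored state keeps a reference to the previous vector.

```julia
using Gen

# Assumes MyState and MySum from the listing above are already defined.
arr = [1.0, 2.0, 3.0]
trace = simulate(MySum(), (arr,))        # runs apply_with_state
@assert get_retval(trace) == sum(arr)

# Build a new vector in which only element 2 differs.
new_arr = [1.0, 5.0, 3.0]

# Assumed constructor form: lengths first, then the updated indices.
argdiff = VectorDiff(3, 3, Dict{Int,Gen.Diff}(2 => UnknownChange()))

# update dispatches to update_with_state with the stored state.
new_trace, weight, retdiff, discard =
    update(trace, (new_arr,), (argdiff,), choicemap())

@assert get_retval(new_trace) == sum(new_arr)
```

Only the entry listed in the diff is visited during the update, so the cost scales with the number of changed elements rather than with the length of the vector.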