Probability Distributions
Gen provides a library of built-in probability distributions, and four ways of constructing custom distributions, each of which is explained below:
- The HeterogeneousMixture and HomogeneousMixture constructors, for distributions that are mixtures of other distributions.
- The ProductDistribution constructor, for distributions that are products of other distributions.
- The @dist constructor, for a distribution that can be expressed as a simple deterministic transformation (technically, a pushforward) of an existing distribution.
- An API for defining arbitrary custom distributions in plain Julia code.
Built-In Distributions
Gen.bernoulli (Constant)
bernoulli(prob_true::Real)
Samples a Bool value that is true with probability prob_true.
Gen.beta (Constant)
beta(alpha::Real, beta::Real)
Samples a Float64 from a beta distribution with shape parameters alpha and beta.
Gen.beta_uniform (Constant)
beta_uniform(theta::Real, alpha::Real, beta::Real)
Samples a Float64 value from a mixture of a uniform distribution on [0, 1] (with probability 1-theta) and a beta distribution with parameters alpha and beta (with probability theta).
Gen.binom (Constant)
binom(n::Integer, p::Real)
Samples an Int from the binomial distribution with parameters n (number of trials) and p (probability of success in each trial).
Gen.categorical (Constant)
categorical(probs::AbstractArray{U, 1}) where {U <: Real}
Given a vector of probabilities probs where sum(probs) = 1, samples an Int i from the set {1, 2, ..., length(probs)} with probability probs[i].
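A common way to draw such an index is inverse-CDF sampling: draw u uniformly from [0, 1) and return the first i whose cumulative probability exceeds u. The sketch below illustrates that idea with a hypothetical plain-Julia helper, sample_categorical; it is not taken from Gen's source, and Gen's categorical may be implemented differently.

```julia
# Hypothetical helper: inverse-CDF sampling from a categorical distribution.
# `u` defaults to a fresh uniform draw on [0, 1).
function sample_categorical(probs::AbstractVector{<:Real}, u::Real = rand())
    @assert all(p -> p >= 0, probs) "probabilities must be non-negative"
    @assert isapprox(sum(probs), 1.0) "probabilities must sum to 1"
    cdf = cumsum(probs)
    # First index whose cumulative probability is strictly greater than u;
    # `min` guards against round-off leaving cdf[end] slightly below u.
    return min(searchsortedlast(cdf, u) + 1, length(probs))
end

probs = [0.2, 0.5, 0.3]          # cumulative probabilities: 0.2, 0.7, 1.0
sample_categorical(probs, 0.1)   # returns 1
sample_categorical(probs, 0.6)   # returns 2
sample_categorical(probs, 0.95)  # returns 3
```

Using a strict "greater than u" comparison means an index with zero probability is never returned, even when u is exactly 0.0.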
Gen.cauchy (Constant)
cauchy(x0::Real, gamma::Real)
Samples a Float64 value from a Cauchy distribution with location x0 and scale gamma.
Gen.dirichlet (Constant)
dirichlet(alpha::Vector{Float64})
Samples a simplex Vector{Float64} (non-negative entries that sum to 1) from a Dirichlet distribution with concentration parameters alpha.
Gen.exponential (Constant)
exponential(rate::Real)
Samples a Float64 from the exponential distribution with rate parameter rate.
Gen.gamma (Constant)
gamma(shape::Real, scale::Real)
Samples a Float64 from a gamma distribution with shape parameter shape and scale parameter scale.
Gen.geometric (Constant)
geometric(p::Real)
Samples an Int from the geometric distribution with parameter p.
Gen.inv_gamma (Constant)
inv_gamma(shape::Real, scale::Real)
Samples a Float64 from an inverse gamma distribution with shape parameter shape and scale parameter scale.
Gen.laplace (Constant)
laplace(loc::Real, scale::Real)
Samples a Float64 from a Laplace distribution with location loc and scale scale.
Gen.mvnormal (Constant)
mvnormal(mu::AbstractVector{T}, cov::AbstractMatrix{U}) where {T<:Real, U<:Real}
Samples a Vector{Float64} value from a multivariate normal distribution with mean vector mu and covariance matrix cov.
Gen.neg_binom (Constant)
neg_binom(r::Real, p::Real)
Samples an Int from a negative binomial distribution. Returns the number of failures before the rth success in a sequence of independent Bernoulli trials. r is the number of successes (which may be fractional) and p is the probability of success per trial.
Gen.normal (Constant)
normal(mu::Real, std::Real)
Samples a Float64 value from a normal distribution with mean mu and standard deviation std.
Gen.piecewise_uniform (Constant)
piecewise_uniform(bounds, probs)
Samples a Float64 value from a piecewise uniform continuous distribution.
There are n bins, where n = length(probs) and n + 1 = length(bounds). Bounds must satisfy bounds[i] < bounds[i+1] for all i. The probability density at x is zero if x <= bounds[1] or x >= bounds[end], and is otherwise probs[bin] / (bounds[bin+1] - bounds[bin]), where bin is the index satisfying bounds[bin] < x <= bounds[bin+1].
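The density rule above can be checked with a small hypothetical helper, piecewise_uniform_density, written in plain Julia for illustration only (it is not Gen's implementation). Each bin's density is its probability divided by its width, so bin i carries total probability probs[i].

```julia
# Hypothetical plain-Julia version of the piecewise uniform density rule.
function piecewise_uniform_density(x::Real, bounds::AbstractVector{<:Real},
                                   probs::AbstractVector{<:Real})
    @assert length(bounds) == length(probs) + 1
    @assert all(diff(bounds) .> 0) "bounds must be strictly increasing"
    (x <= bounds[1] || x >= bounds[end]) && return 0.0
    # bin is the index with bounds[bin] < x <= bounds[bin+1]
    bin = searchsortedfirst(bounds, x) - 1
    return probs[bin] / (bounds[bin+1] - bounds[bin])   # probability / bin width
end

bounds = [0.0, 1.0, 3.0]
probs = [0.5, 0.5]
piecewise_uniform_density(0.5, bounds, probs)  # returns 0.5  (bin 1 has width 1)
piecewise_uniform_density(2.0, bounds, probs)  # returns 0.25 (bin 2 has width 2)
piecewise_uniform_density(3.5, bounds, probs)  # returns 0.0  (outside the bounds)
```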
Gen.poisson (Constant)
poisson(lambda::Real)
Samples an Int from the Poisson distribution with rate lambda.
Gen.uniform (Constant)
uniform(low::Real, high::Real)
Samples a Float64 from the uniform distribution on the interval [low, high].
Gen.uniform_discrete (Constant)
uniform_discrete(low::Integer, high::Integer)
Samples an Int from the uniform distribution on the set {low, low + 1, ..., high - 1, high}.
Gen.broadcasted_normal (Constant)
broadcasted_normal(mu::AbstractArray{<:Real, N1}, std::AbstractArray{<:Real, N2}) where {N1, N2}
Samples an Array{Float64, max(N1, N2)} of shape Broadcast.broadcast_shapes(size(mu), size(std)), where each element is independently normally distributed. This is equivalent to (a reshape of) a multivariate normal with a diagonal covariance matrix, but its implementation is more efficient than that of the more general mvnormal for this case.
The shapes of mu and std must be broadcast-compatible.
If all arguments are 0-dimensional arrays, then sampling via broadcasted_normal(...) returns a Float64 rather than properly returning an Array{Float64, 0}. This is consistent with Julia's own inconsistency on the matter:
julia> typeof(ones())
Array{Float64,0}
julia> typeof(ones() .* ones())
Float64
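To illustrate the equivalence with a diagonal-covariance multivariate normal, the sketch below draws independent elementwise normals over the broadcast shape of mu and std, then checks that the sum of the per-element log densities equals the multivariate normal log density of the flattened array. It uses only Julia's LinearAlgebra standard library; normal_logpdf and mvnormal_logpdf are hypothetical hand-written helpers, not Gen functions.

```julia
using LinearAlgebra

# Hypothetical hand-written log densities (not Gen's implementations).
normal_logpdf(x, mu, std) = -0.5 * log(2π) - log(std) - 0.5 * ((x - mu) / std)^2

function mvnormal_logpdf(x::AbstractVector, mu::AbstractVector, cov::AbstractMatrix)
    d = length(x)
    r = x - mu
    return -0.5 * (d * log(2π) + logdet(cov) + dot(r, cov \ r))
end

mu = [0.0 1.0 2.0]        # 1×3 matrix
std = [0.5, 1.0]          # 2-element vector, broadcast as a 2×1 column
shape = size(mu .+ std)   # broadcast shape of the two arguments: (2, 3)

# Independent elementwise draws, in the spirit of broadcasted_normal(mu, std).
x = mu .+ std .* randn(shape...)

# Sum of per-element log densities ...
lp_elementwise = sum(normal_logpdf.(x, mu, std))

# ... equals the mvnormal log density of the flattened (column-major) array
# with a diagonal covariance built from the broadcast standard deviations.
mu_full = vec(mu .+ zero(std))
std_full = vec(zero(mu) .+ std)
lp_mvnormal = mvnormal_logpdf(vec(x), mu_full, Diagonal(std_full .^ 2))
```

Storing only the diagonal avoids the O(d^2) covariance matrix and the general solve, which is the kind of saving the more specialized distribution can exploit.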
Mixture Distribution Constructors
There are two built-in constructors for defining mixture distributions:
Gen.HomogeneousMixture (Type)
HomogeneousMixture(distribution::Distribution, dims::Vector{Int})
Define a new distribution that is a mixture of some number of instances of a single base distribution.
The first argument defines the base distribution of each component in the mixture.
The second argument must have length equal to the number of arguments taken by the base distribution. A value of 0 at a position in the vector indicates that the corresponding argument to the base distribution is a scalar, and an integer value i >= 1 indicates that the corresponding argument is an i-dimensional array.
Example:
mixture_of_normals = HomogeneousMixture(normal, [0, 0])
The resulting distribution (e.g. mixture_of_normals above) can then be used like built-in distributions such as normal. The distribution takes n+1 arguments, where n is the number of arguments taken by the base distribution. The first argument to the distribution is a vector of non-negative mixture weights, which must sum to 1.0. The remaining arguments to the distribution correspond to the arguments of the base distribution, but have a different type: if an argument to the base distribution is a scalar of type T, then the corresponding argument to the mixture distribution is a Vector{T}, where each element of this vector is the argument to the corresponding mixture component. If an argument to the base distribution is an Array{T,N} for some N, then the corresponding argument to the mixture distribution is of the form arr::Array{T,N+1}, where each slice of the array of the form arr[:,:,...,i] is the argument for the ith mixture component.
Example:
mixture_of_normals = HomogeneousMixture(normal, [0, 0])
mixture_of_mvnormals = HomogeneousMixture(mvnormal, [1, 2])
@gen function foo()
    # mixture of two normal distributions
    # with means -1.0 and 1.0
    # and standard deviations 0.1 and 10.0
    # the first normal distribution has weight 0.4; the second has weight 0.6
    x ~ mixture_of_normals([0.4, 0.6], [-1.0, 1.0], [0.1, 10.0])
    # mixture of two multivariate normal distributions
    # with means: [0.0, 0.0] and [1.0, 1.0]
    # and covariance matrices: [1.0 0.0; 0.0 1.0] and [10.0 0.0; 0.0 10.0]
    # the first multivariate normal distribution has weight 0.4;
    # the second has weight 0.6
    means = [0.0 1.0; 0.0 1.0] # or, cat([0.0, 0.0], [1.0, 1.0], dims=2)
    covs = cat([1.0 0.0; 0.0 1.0], [10.0 0.0; 0.0 10.0], dims=3)
    y ~ mixture_of_mvnormals([0.4, 0.6], means, covs)
end
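For intuition about what such a mixture evaluates, its log density at x is log(sum_i w_i * p_i(x)), which is usually computed stably as a log-sum-exp over log(w_i) + logpdf_i(x). The sketch below writes this out for a mixture of normals with the same argument layout as mixture_of_normals (weights, then a vector of means, then a vector of standard deviations). The helpers normal_logpdf, logsumexp, and mixture_of_normals_logpdf are hypothetical plain-Julia code, not part of Gen, and Gen's own implementation may differ.

```julia
# Hypothetical plain-Julia helpers (not Gen's implementation).
normal_logpdf(x, mu, std) = -0.5 * log(2π) - log(std) - 0.5 * ((x - mu) / std)^2

# Numerically stable log(sum(exp.(v))).
function logsumexp(v::AbstractVector{<:Real})
    m = maximum(v)
    isfinite(m) || return m
    return m + log(sum(exp.(v .- m)))
end

# Mirrors mixture_of_normals(weights, mus, stds): the i-th entry of each
# argument vector parameterizes the i-th mixture component.
function mixture_of_normals_logpdf(x, weights, mus, stds)
    @assert length(weights) == length(mus) == length(stds)
    @assert all(w -> w >= 0, weights) && isapprox(sum(weights), 1.0)
    return logsumexp([log(w) + normal_logpdf(x, mu, s)
                      for (w, mu, s) in zip(weights, mus, stds)])
end

lp = mixture_of_normals_logpdf(0.5, [0.4, 0.6], [-1.0, 1.0], [0.1, 10.0])
```

If every component is identical, the mixture density reduces to the component density, which is a quick sanity check on the weights.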
Gen.HeterogeneousMixture (Type)
HeterogeneousMixture(distributions::Vector{Distribution{T}}) where {T}
Define a new distribution that is a mixture of a given list of base distributions.
The argument is the vector of base distributions, one for each mixture component.
Note that the base distributions must have the same output type.
Example:
uniform_beta_mixture = HeterogeneousMixture([uniform, beta])
The resulting mixture distribution takes n+1 arguments, where n is the sum of the number of arguments taken by each distribution in the list. The first argument to the mixture distribution is a vector of non-negative mixture weights, which must sum to 1.0. The remaining arguments are the arguments to each mixture component distribution, in the order in which the distributions are passed into the constructor.
Example:
@gen function foo(lower, upper, a, b)
    # mixture of a uniform distribution on the interval [`lower`, `upper`]
    # and a beta distribution with alpha parameter `a` and beta parameter `b`
    # the uniform has weight 0.4 and the beta has weight 0.6
    x ~ uniform_beta_mixture([0.4, 0.6], lower, upper, a, b)
end
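To make the flat argument layout concrete, the sketch below splits the arguments after the weights into consecutive chunks, one per component, sized by each component's number of arguments, and evaluates the weighted sum of component densities. It is a hypothetical plain-Julia illustration (uniform_pdf, exponential_pdf, and hetero_mixture_pdf are not Gen functions), and it substitutes a hand-written exponential density for the beta component to stay within Julia's standard library.

```julia
# Hypothetical hand-written component densities (stand-ins, not Gen's).
uniform_pdf(x, low, high) = low <= x <= high ? 1 / (high - low) : 0.0
exponential_pdf(x, rate) = x >= 0 ? rate * exp(-rate * x) : 0.0

# Density of a heterogeneous mixture. `arities[i]` is the number of arguments
# taken by component i; the flat `args` are consumed in constructor order.
function hetero_mixture_pdf(x, pdfs::Tuple, arities::Tuple, weights, args...)
    @assert length(pdfs) == length(arities) == length(weights)
    @assert length(args) == sum(arities)
    @assert all(w -> w >= 0, weights) && isapprox(sum(weights), 1.0)
    total = 0.0
    start = 1
    for (pdf, n, w) in zip(pdfs, arities, weights)
        component_args = args[start:start+n-1]   # this component's chunk
        total += w * pdf(x, component_args...)
        start += n
    end
    return total
end

# Like uniform_beta_mixture([0.4, 0.6], lower, upper, a, b), but with an
# exponential(rate) component in place of the beta: args are (low, high, rate).
p = hetero_mixture_pdf(0.5, (uniform_pdf, exponential_pdf), (2, 1), [0.4, 0.6],
                       0.0, 1.0, 2.0)
```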
Product Distribution Constructors
There is a built-in constructor for defining product distributions:
Gen.ProductDistribution (Type)
ProductDistribution(distributions::Vararg{<:Distribution})
Define a new distribution that is the product of the given nonempty list of distributions having a common type.
The arguments comprise the list of base distributions.
Example:
normal_strip = ProductDistribution(uniform, normal)
The resulting product distribution takes n arguments, where n is the sum of the numbers of arguments taken by each distribution in the list. These arguments are the arguments to each component distribution, in the order in which the distributions are passed to the constructor.
Example:
@gen function unit_strip_and_near_seven()
    x ~ normal_strip(0.0, 0.1, 7.0, 0.01)
end
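As a rough model of the semantics, the log density of a product distribution is the sum of its components' log densities, with the flat argument list split in constructor order as in the example above (the uniform's two arguments, then the normal's two). The sketch assumes a sampled value is represented as a tuple with one entry per component; that representation, like the helper names uniform_logpdf, normal_logpdf, and product_logpdf, is an assumption for illustration rather than Gen's documented API.

```julia
# Hypothetical hand-written component log densities (not Gen's).
uniform_logpdf(x, low, high) = low <= x <= high ? -log(high - low) : -Inf
normal_logpdf(x, mu, std) = -0.5 * log(2π) - log(std) - 0.5 * ((x - mu) / std)^2

# Assumed representation: `value` is a tuple with one entry per component.
# The product's log density is the sum of the component log densities.
function product_logpdf(value::Tuple, logpdfs::Tuple, arities::Tuple, args...)
    @assert length(value) == length(logpdfs) == length(arities)
    @assert length(args) == sum(arities)
    total = 0.0
    start = 1
    for (v, lpdf, n) in zip(value, logpdfs, arities)
        total += lpdf(v, args[start:start+n-1]...)   # this component's arguments
        start += n
    end
    return total
end

# uniform on [0.0, 0.1] times normal(7.0, 0.01), evaluated at (0.05, 7.0)
lp = product_logpdf((0.05, 7.0), (uniform_logpdf, normal_logpdf), (2, 2),
                    0.0, 0.1, 7.0, 0.01)
```

A value outside any component's support gives a log density of -Inf for the whole product, since the components are independent.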
The @dist DSL
The @dist DSL allows the user to concisely define a distribution, as long as that distribution can be expressed as a certain type of deterministic transformation of an existing distribution. The syntax of the @dist DSL, as well as the class of permitted deterministic transformations, is explained below.
@dist name(arg1, arg2, ..., argN) = body
or
@dist function name(arg1, arg2, ..., argN)
    body
end
Here body is ordinary Julia code, with the constraint that body must contain exactly one random choice. The value of the @dist expression is then a Gen.Distribution object called name, parameterized by arg1, ..., argN, representing the distribution over return values of body. Arguments are optionally typed.
This DSL is designed to address the issue that sometimes, values stored in the trace do not correspond to the most natural physical elements of the model state space, making inference programming and querying more taxing than necessary. For example, suppose we have a model of classes at a school, where the number of students is random, with mean 10, but always at least 3. Rather than writing the model as
@gen function class_model()
    n_students = @trace(poisson(7), :n_students_minus_3) + 3
    ...
end
and thinking about the random variable :n_students_minus_3, you can use the @dist DSL to instead write
@dist student_distr(mean, min) = poisson(mean-min) + min
@gen function class_model()
    n_students = @trace(student_distr(10, 3), :n_students)
    ...
end
and think about the more natural random variable :n_students. This leads to more natural inference programs, which can constrain and propose directly to the :n_students trace address.
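For a shift like this, the pushforward's probability mass is easy to state: P(n_students = n) = Poisson(n - min; mean - min). The plain-Julia sketch below writes that out with hypothetical helpers poisson_logpmf and student_distr_logpmf, to show the kind of log probability the student_distr definition represents; it is not the code @dist generates.

```julia
# Hypothetical hand-written Poisson log-pmf (standard library only).
function poisson_logpmf(k::Integer, lambda::Real)
    k < 0 && return -Inf
    return k * log(lambda) - lambda - sum(log.(1:k))   # sum(log.(1:k)) = log(k!)
end

# n = k + min with k ~ Poisson(mean - min), so the log-pmf of n is the
# base log-pmf evaluated at the shifted value n - min.
student_distr_logpmf(n::Integer, mean::Real, min::Integer) =
    poisson_logpmf(n - min, mean - min)

student_distr_logpmf(2, 10, 3)   # -Inf: fewer than 3 students is impossible
student_distr_logpmf(3, 10, 3)   # -7.0, i.e. log(exp(-7)) for zero extra students
```

Constraining :n_students to an observed count then amounts to evaluating this shifted log-pmf directly, with no separate :n_students_minus_3 bookkeeping.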
Permitted constructs for the body of a @dist
It is not possible for @dist to work on any arbitrary body. We now describe which constructs are permitted inside the body of a @dist expression.
We can think of the body of an @dist function as containing ordinary Julia code, except that in addition to being described by their ordinary Julia types, each expression also belongs to one of three "type spaces." These are:
- CONST: Constants, whose value is known at the time this @dist expression is evaluated.
- ARG: Arguments and (deterministic, differentiable) functions of arguments. All expressions representing non-random values that depend on distribution arguments are ARG expressions.
- RND: Random variables. All expressions whose runtime values may differ across multiple calls to this distribution (with the same arguments) are RND expressions.
Importantly, Julia control flow constructs generally expect CONST values: the condition of an if or the range of a for loop cannot be ARG or RND.
The body expression as a whole must be a RND expression, representing a random variable. The behavior of the @dist definition is then to define a new distribution (with name name) that samples and evaluates the logpdf of the random variable represented by the body expression.
Expressions are typed compositionally, with the following typing rules:
1. Literals and free variables are CONSTs. Literals and symbols that appear free in the @dist body are of type CONST.
2. Arguments are ARGs. Symbols bound as arguments in the @dist declaration have type ARG in its body.
3. Drawing from a distribution gives RND. If d is a distribution, and x_i are of type ARG or CONST, d(x_1, x_2, ...) is of type RND.
4. Functions of CONSTs are CONSTs. If f is a deterministic function and x_i are all of type CONST, f(x_1, x_2, ...) is of type CONST.
5. Functions of CONSTs and ARGs are ARGs. If f is a differentiable function, and each x_i is either a CONST or a scalar ARG (with at least one x_i being an ARG), then f(x_1, x_2, ...) is of type ARG.
6. Functions of CONSTs, ARGs, and RNDs are RNDs. If f is one of a special set of deterministic functions we've defined (+, -, *, /, exp, log, getindex), and exactly one of its arguments x_i is of type RND, then f(x_1, x_2, ...) is of type RND.
One way to think about this, without all the rules, is that CONST values are "contaminated" by interaction with ARG values (becoming ARGs themselves), and both CONST and ARG are "contaminated" by interaction with RND. Thinking of the body as an AST, the journey from leaf node to root node always involves transitions in the direction of CONST -> ARG -> RND, never in reverse.
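The rules above can be modeled as a small checker over an ordered enum, where declaration order gives CONST < ARG < RND. The sketch below is a hypothetical plain-Julia model (the TypeSpace enum and the draw_space and apply_space functions are not part of Gen); it encodes rules 3 through 6, ignores rule 5's differentiability and scalar conditions, and types the body of the weird example from the Examples section bottom-up.

```julia
# Hypothetical model of the @dist typing rules; not Gen's implementation.
# Declaration order gives CONST < ARG < RND (the one-way "contamination").
@enum TypeSpace CONST ARG RND

# Functions allowed to take a RND argument (rule 6).
const RND_FUNCTIONS = Set([:+, :-, :*, :/, :exp, :log, :getindex])

# Rule 3: drawing from a distribution whose arguments are CONST or ARG.
function draw_space(arg_spaces::Vector{TypeSpace})
    any(==(RND), arg_spaces) && error("distribution arguments cannot be RND")
    return RND
end

# Rules 4-6: applying a deterministic function f to typed arguments.
# (This sketch does not check rule 5's differentiability/scalar conditions.)
function apply_space(f::Symbol, arg_spaces::Vector{TypeSpace})
    n_rnd = count(==(RND), arg_spaces)
    if n_rnd == 0
        # rule 4 (all CONST) or rule 5 (at least one ARG)
        return isempty(arg_spaces) ? CONST : maximum(arg_spaces)
    elseif n_rnd == 1 && f in RND_FUNCTIONS
        return RND   # rule 6
    else
        error("$f cannot be applied to these arguments inside @dist")
    end
end

# Typing log(normal(exp(x), exp(x))) + (x * (2 + 3)) bottom-up:
five = apply_space(:+, [CONST, CONST])                                   # CONST
scaled = apply_space(:*, [ARG, five])                                    # ARG
draw = draw_space([apply_space(:exp, [ARG]), apply_space(:exp, [ARG])])  # RND
body = apply_space(:+, [apply_space(:log, [draw]), scaled])              # RND
```

In this model, combining two RND values (for example x + x with x random) raises an error, in line with the restriction that a RND name may be used only once.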
Restrictions
Users may not reassign to arguments (like x in the examples below), and may not apply functions with side effects. Names bound to expressions of type RND must be used only once. For example, let x = normal(0, 1); x + x end is not allowed.
Examples
Let's walk through some examples.
@dist f(x) = exp(normal(x, 1))
We can annotate with types:
1 :: CONST (by rule 1)
x :: ARG (by rule 2)
normal(x, 1) :: RND (by rule 3)
exp(normal(x, 1)) :: RND (by rule 6)
Here's another:
@dist function labeled_cat(labels, probs)
    index = categorical(probs)
    labels[index]
end
And the types:
probs :: ARG (by rule 2)
categorical(probs) :: RND (by rule 3)
index :: RND (Julia assignment)
labels :: ARG (by rule 2)
labels[index] :: RND (by rule 6, f == getindex)
Note that getindex is designed to work on anything indexable, not just vectors. So, for example, it also works with Dicts.
Another one (not as realistic, but it uses all the rules):
@dist function weird(x)
    log(normal(exp(x), exp(x))) + (x * (2 + 3))
end
And the types:
2, 3 :: CONST (by rule 1)
2 + 3 :: CONST (by rule 4)
x :: ARG (by rule 2)
x * (2 + 3) :: ARG (by rule 5)
exp(x) :: ARG (by rule 5)
normal(exp(x), exp(x)) :: RND (by rule 3)
log(normal(exp(x), exp(x))) :: RND (by rule 6)
log(normal(exp(x), exp(x))) + (x * (2 + 3)) :: RND (by rule 6)
API
Gen.random (Function)
val::T = random(dist::Distribution{T}, args...)
Sample a random choice from the given distribution with the given arguments.
Gen.logpdf (Function)
lpdf = logpdf(dist::Distribution{T}, value::T, args...)
Evaluate the log probability (density) of the value.
Gen.logpdf_grad (Function)
grads::Tuple = logpdf_grad(dist::Distribution{T}, value::T, args...)
Compute the gradient of the logpdf with respect to the value, and each of the arguments.
If has_output_grad returns false, then the first element of the returned tuple is nothing. Otherwise, the first element of the tuple is the gradient with respect to the value. If the return value of has_argument_grads has a false value at position i, then the (i+1)th element of the returned tuple has value nothing. Otherwise, this element contains the gradient with respect to the ith argument.
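As a concrete illustration of that tuple layout, the sketch below hand-derives the gradient of the normal log density, ordered as (value, mu, std) in the same way logpdf_grad orders its result, and checks it against central finite differences. The helpers normal_logpdf, normal_logpdf_grad, and fd_grad are hypothetical plain-Julia code, not Gen's implementation; for a distribution that supplies every gradient, none of the entries would be nothing.

```julia
# Hypothetical hand-written normal log density and its gradient (not Gen's).
normal_logpdf(x, mu, std) = -0.5 * log(2π) - log(std) - 0.5 * ((x - mu) / std)^2

# Returns (d/dx, d/dmu, d/dstd): first the value, then each argument in order.
function normal_logpdf_grad(x, mu, std)
    z = (x - mu) / std
    return (-z / std,               # gradient w.r.t. the value x
            z / std,                # gradient w.r.t. argument 1 (mu)
            -1 / std + z^2 / std)   # gradient w.r.t. argument 2 (std)
end

# Central finite-difference gradient of f at the point `args`.
function fd_grad(f, args::Tuple; h = 1e-6)
    return ntuple(length(args)) do i
        up = ntuple(j -> j == i ? args[j] + h : args[j], length(args))
        dn = ntuple(j -> j == i ? args[j] - h : args[j], length(args))
        (f(up...) - f(dn...)) / (2h)
    end
end

args = (0.3, -0.2, 1.5)                   # (value, mu, std)
analytic = normal_logpdf_grad(args...)
numeric = fd_grad(normal_logpdf, args)
```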
Gen.has_output_grad (Function)
has::Bool = has_output_grad(dist::Distribution)
Return true if the distribution computes the gradient of the logpdf with respect to the value of the random choice.
Gen.is_discrete (Function)
discrete::Bool = is_discrete(::Distribution)
Return true if the distribution is discrete, false otherwise.
The has_argument_grads function is also part of the distribution API.