How to specify custom Penalties
This guide shows how to specify custom penalties that help the solver converge to the global minimum.
Motivation
The basic cost in HVI is the negative log of the joint probability, i.e. of the likelihood of the observations given the parameters times the prior probability of the parameters.
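Writing $y$ for the observations and $\theta$ for the parameters, this decomposition of the basic cost reads:

$$
-\log p(y, \theta) = -\log p(y \mid \theta) - \log p(\theta)
$$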
Sometimes there is additional knowledge that is not encoded in the prior, for example that one parameter must be larger than another, or entropy weights of the ML parameters. The solver accepts a function that adds such loss terms. Similarly, the log-likelihood function, which assigns a cost to the mismatch between predictions and observations, often needs to be customized to the specific inversion.
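For instance, the knowledge that one parameter must be at least as large as another can be expressed as a hinge penalty: zero while the constraint holds, growing linearly with the size of the violation. The following is a minimal sketch in plain Julia; `ordering_penalty` is a hypothetical helper for illustration, not part of the package.

```julia
# Hypothetical helper: hinge penalty for the constraint θa >= θb.
# Returns zero when the constraint holds and the size of the violation otherwise.
ordering_penalty(θa, θb) = max(zero(θa), θb - θa)

# Broadcast over several parameter draws
θa = [1.0, 2.0, 0.5]
θb = [0.5, 2.5, 0.5]
penalties = ordering_penalty.(θa, θb)
# penalties == [0.0, 0.5, 0.0]
```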
This guide walks through the specification of such additional penalties.
First, load the necessary packages.
using HybridVariationalInference
using SimpleChains
using ComponentArrays: ComponentArrays as CA
using JLD2
import StableRNGs
This tutorial reuses and modifies the fitted object saved at the end of the Basic workflow without GPU tutorial, which used a log-likelihood function assuming independent, normally distributed observation errors.
fname = "intermediate/basic_cpu_results.jld2"
print(abspath(fname))
prob = probo_normal = load(fname, "probo");
Write a function to compute the penalty loss
The function signature corresponds to the one described in compute_penalty.
In this example, we want to avoid local minima where the parameter r1 is larger than 95% of the maximum observation at a site.
# compute the maximum of observed rates at each site
y_obs = get_hybridproblem_train_dataloader(prob).data[3]
const y_obs_max = map(col -> maximum(x -> isfinite(x) ? x : zero(x), col), eachcol(y_obs))
function compute_penalty_r1(y_pred::AbstractMatrix, addq_pred::AbstractMatrix,
θMs_tr::AbstractMatrix, θP::AbstractVector, i_sites,
ϕg, ϕq::AbstractVector)
# look up the precomputed maxima of the sites in the current batch (y_obs_max is a global constant)
y_obs_max_sites = y_obs_max[i_sites]
# add a penalty if r1 is larger than 0.95 times the maximum
penalty = max.(zero(eltype(θMs_tr)), θMs_tr[:,:r1] .- 0.95 .* y_obs_max_sites)
(; penalty)
end
The penalty function receives the argument i_sites, the indices of the sites in the current batch, which can be used to index the precomputed observation maxima.
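The two steps above, computing a NaN-tolerant maximum per site and then applying a hinge penalty to the sites of the current batch, can be tried on plain arrays. This sketch uses made-up toy numbers (`y_toy`, `r1`) and only Base Julia, so it runs without the package; it mirrors the expressions used in compute_penalty_r1 rather than calling it.

```julia
# Toy observations: rows are observations, columns are sites; NaN marks a gap
y_toy = [1.0  4.0  NaN   5.0;
         3.0  NaN  2.0  20.0;
         2.0  5.0  NaN  10.0]

# Per-site maximum, treating non-finite values as zero (same expression as for y_obs_max)
y_toy_max = map(col -> maximum(x -> isfinite(x) ? x : zero(x), col), eachcol(y_toy))
# y_toy_max == [3.0, 5.0, 2.0, 20.0]

# Suppose the current mini-batch holds sites 2 and 4 (i_sites) with these r1 values
i_sites = [2, 4]
r1 = [3.0, 21.0]

# Hinge penalty: zero below 95% of the site maximum, linear above it
threshold = 0.95 .* y_toy_max[i_sites]   # ≈ [4.75, 19.0]
penalty = max.(zero(eltype(r1)), r1 .- threshold)
# penalty == [0.0, 2.0]
```

Note that replacing non-finite values by zero is only safe because the observed rates are non-negative; for data with negative values a different placeholder would be needed.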
Update the problem and redo the inversion
HybridProblem has the keyword argument penalty_computer, which specifies the callable that computes the penalty. It defaults to ZeroPenaltyComputer, which returns a zero penalty cost.
We can either pass the function directly or construct a CustomPenaltyComputer from it, and then update the problem.
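Accepting either a plain function or a penalty-computer object is a common Julia dispatch pattern. The self-contained sketch below illustrates that pattern with hypothetical types and functions (`MyAbstractPenaltyComputer`, `MyFunctionPenalty`, `MyZeroPenalty`, `my_compute_penalty`, `as_penalty_computer`); it is an assumption about the general design, not the actual implementation of CustomPenaltyComputer or ZeroPenaltyComputer in HybridVariationalInference.

```julia
# Hypothetical stand-ins, NOT the HybridVariationalInference types
abstract type MyAbstractPenaltyComputer end

# Wraps an arbitrary function so it can be used where a penalty computer is expected
struct MyFunctionPenalty{F} <: MyAbstractPenaltyComputer
    f::F
end

# Default computer that contributes no penalty
struct MyZeroPenalty <: MyAbstractPenaltyComputer end

# Dispatch on the computer type
my_compute_penalty(pc::MyFunctionPenalty, args...) = pc.f(args...)
my_compute_penalty(::MyZeroPenalty, args...) = (; penalty = 0.0)

# Accept either a penalty computer (returned unchanged) or a plain function (wrapped)
as_penalty_computer(pc::MyAbstractPenaltyComputer) = pc
as_penalty_computer(f) = MyFunctionPenalty(f)

# Usage with a toy penalty function
hinge(x, limit) = (; penalty = max(zero(x), x - limit))
pc = as_penalty_computer(hinge)
my_compute_penalty(pc, 3.0, 2.0)               # (penalty = 1.0,)
my_compute_penalty(MyZeroPenalty(), 3.0, 2.0)  # (penalty = 0.0,)
```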
#prob_pen = HybridProblem(prob; penalty_computer = compute_penalty_r1)
penalty_computer = CustomPenaltyComputer(compute_penalty_r1)
prob_pen = HybridProblem(prob; penalty_computer)
using OptimizationOptimisers
import Zygote
# silence warning of no GPU backend found (because we did not import CUDA here)
ENV["MLDATADEVICES_SILENCE_WARN_NO_GPU"] = 1
# first, run a few iterations that optimize only the point estimate (mean)
solver_point = HybridPointSolver(; alg=Adam(0.02))
(; probo) = solve(prob_pen, solver_point;
callback = callback_loss(100), # output during fitting
epochs = 5,
); probo_pen_point = probo;
# starting from this, also estimate the posterior uncertainty parameters
solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3)
(; probo) = solve(probo_pen_point, solver;
callback = callback_loss(100), # output during fitting
epochs = 5,
);
Inspect the computed penalties
Function predict_hvi also evaluates the penalties. Internally, the penalty function is called for each sample, but only the average over samples is returned.
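The averaging over samples can be pictured with a plain matrix. In this sketch the per-sample values in `penalty_samples` are made up, and the layout (rows are sites, columns are posterior samples) is an assumption for illustration; it does not reproduce how predict_hvi stores its intermediate results.

```julia
using Statistics: mean

# Hypothetical per-sample penalties: rows are sites, columns are posterior samples
penalty_samples = [0.0 0.0 1.0 1.0;
                   2.0 2.0 2.0 2.0]

# Average across samples (dims = 2), giving one value per site
penalty_mean = vec(mean(penalty_samples; dims = 2))
# penalty_mean == [0.5, 2.0]
```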
rng = StableRNGs.StableRNG(112)
n_sample_pred = 200
(; y, θsP, θsMs_tr, ζsP, ζsMs_tr, penalties) = predict_hvi(rng, probo; n_sample_pred);
size(penalties)
The penalties object is a ComponentMatrix, so we can look up a specific site and a named component returned by the penalty function:
i_site = 3
penalties[i_site, :penalty]
Writing a customized PenaltyComputer
In the above example, the maxima of the observations are accessed through a global variable.
This can be improved: the precomputed maxima can be stored in a struct that subtypes AbstractPenaltyComputer and implements compute_penalty.
struct R1PenaltyComputer{T} <: AbstractPenaltyComputer
r_max::Vector{T}
end
function R1PenaltyComputer(ys::AbstractMatrix)
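# treat non-finite observations (e.g. NaN gaps) as zero, consistent with y_obs_max above
r_max = 0.95 .* map(col -> maximum(x -> isfinite(x) ? x : zero(x), col), eachcol(ys))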
R1PenaltyComputer(r_max)
end
function HybridVariationalInference.compute_penalty(
pc::R1PenaltyComputer,
y_pred::AbstractMatrix, addq_pred::AbstractMatrix, θMs_tr::AbstractMatrix, θP::AbstractVector,
i_sites::AbstractVector,
ϕg, ϕq::AbstractVector
)
# pc.r_max[i_sites] holds 0.95 times the observed maxima of the sites in the current batch
# add a penalty if r1 is larger than r_max
penalty = max.(zero(eltype(θMs_tr)), θMs_tr[:,:r1] .- pc.r_max[i_sites])
(;penalty)
end
penalty_computer = R1PenaltyComputer(y_obs)
Rerun the inversion with the updated PenaltyComputer:
prob_pen = HybridProblem(probo; penalty_computer)
(; probo) = solve(prob_pen, solver;
callback = callback_loss(100), # output during fitting
epochs = 5,
);
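The struct-based design can be checked in isolation on toy data. The sketch below mirrors R1PenaltyComputer with hypothetical names (`ToyAbstractPenaltyComputer`, `ToyR1Penalty`, `toy_penalty`), a `factor` keyword that the original constructor does not have, and a simplified signature that takes only the r1 values and `i_sites`; it does not call the package's compute_penalty.

```julia
# Hypothetical stand-in for AbstractPenaltyComputer, so the sketch runs without the package
abstract type ToyAbstractPenaltyComputer end

struct ToyR1Penalty{T} <: ToyAbstractPenaltyComputer
    r_max::Vector{T}   # upper limit for r1 at each site
end

# Precompute the limits once from the observations (columns are sites),
# treating non-finite values as zero
function ToyR1Penalty(ys::AbstractMatrix; factor = 0.95)
    ymax = map(col -> maximum(x -> isfinite(x) ? x : zero(x), col), eachcol(ys))
    return ToyR1Penalty(factor .* ymax)
end

# Penalty for the r1 values of the sites in the current batch (simplified signature)
toy_penalty(pc::ToyR1Penalty, r1::AbstractVector, i_sites) =
    (; penalty = max.(zero(eltype(r1)), r1 .- pc.r_max[i_sites]))

ys = [4.0  20.0;
      NaN  10.0]
pc = ToyR1Penalty(ys)                  # pc.r_max ≈ [3.8, 19.0]
toy_penalty(pc, [3.0, 21.0], [1, 2])   # (penalty = [0.0, 2.0],)
toy_penalty(pc, [21.0], [2])           # (penalty = [2.0],)
```

Because the limits live in the struct, the penalty no longer depends on a global variable, and a computer built from a different data set can be swapped in without redefining the function.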