Basic workflow without GPU
First, load the necessary packages.
using HybridVariationalInference
using HybridVariationalInference: HybridVariationalInference as HVI
using ComponentArrays: ComponentArrays as CA
using Bijectors
using StableRNGs
using SimpleChains
using StatsFuns
using MLUtils
using DistributionFits
Next, specify the many moving parts of Hybrid Variational Inference (HVI).
The process-based model
The example process-based model (PBM) predicts a double-Monod constrained rate for different substrate concentrations, S1 and S2.
\[ y = r_0 + r_1 \frac{S_1}{K_1 + S_1} \frac{S_2}{K_2 + S_2}\]
function f_doubleMM(θc::CA.ComponentVector{ET}, x) where ET
    # extract parameters independent of their order, i.e. whether they are in θP or θM
    (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
        CA.getdata(θc[par])::ET
    end
    r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2)
end
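As a quick illustration with hypothetical values: because the function accesses components by name, the order of entries in the ComponentVector does not matter.
θc = CA.ComponentVector{Float32}(K2=2.0, r0=0.3, r1=0.5, K1=0.2) # order shuffled on purpose
x = (S1=Float32[0.5, 0.4], S2=Float32[1.0, 3.0]) # hypothetical drivers
f_doubleMM(θc, x) # 2-element vector of predicted rates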
Its formulation is independent of which parameters are global, site-specific, or fixed during the model inversion. However, it cannot assume an ordering of the parameters, but needs to access the components by their symbolic names in the provided ComponentArray.
Likelihood function
HVI requires evaluating the likelihood of the predictions. It corresponds to the cost of the predictions given some observations.
The user specifies a function of the negative log-likelihood neg_logden(obs, pred, uncertainty_parameters), where all of the arguments are arrays with columns for sites.
Here, we use the neg_logden_indep_normal function, which assumes observations to be independently normally distributed around a true value. The provided y_unc uncertainty parameter here corresponds to logσ2, the log of the variance of the normal distribution.
py = neg_logden_indep_normal
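For intuition, the following minimal sketch (an illustration, not the package's implementation) computes such a negative log-density for independently normal observations, dropping additive constants.
# illustrative sketch of an independent-normal negative log-density,
# where logσ2 is the log of the variance (hypothetical helper, not the package function)
function neg_logden_indep_normal_sketch(obs, pred, logσ2)
    σ2 = exp.(logσ2)
    sum(@. (log(σ2) + abs2(obs - pred) / σ2) / 2)
end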
Global and site parameters, transformations, and priors
Global and site-specific parameters
In this example, we assign a fixed value to the r0 parameter, treat the K2 parameter as unknown but the same across sites, and predict r1 and K1 for each site separately, based on covariates known at the sites.
Here, we provide initial values for them using ComponentVectors.
FT = Float32
θM0 = θM = CA.ComponentVector{FT}(r1=0.5, K1=0.2) # site-specific: estimated separately for each site
θP0 = θP = CA.ComponentVector{FT}(K2=2.0) # global: same across all sites
θFix = CA.ComponentVector{FT}(r0=0.3) # fixed, i.e. not estimated
Parameter Transformations
HVI estimates parameters in an unconstrained space, where the probability density is nowhere strictly zero, and transforms them to the original constrained space.
Here, our model parameters are strictly positive, and we use the exponential function to transform unconstrained estimates to the original constrained domain.
transP = Stacked(HVI.Exp())
transM = Stacked(HVI.Exp(), HVI.Exp())
Parameter transformations are specified using the Bijectors package. Because Bijectors.elementwise(exp) has problems with automatic differentiation (AD) on GPU, we use the public but non-exported Exp wrapper inside Bijectors.Stacked.
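As a small check (illustrative, assuming the usual Bijectors calling convention), applying the transformation maps unconstrained values elementwise to the positive domain.
ζ = Float32[-1.0, 0.5] # hypothetical unconstrained estimates for (r1, K1)
transM(ζ) # ≈ [exp(-1.0f0), exp(0.5f0)], both strictly positive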
Prior information on parameters at constrained scale
HVI is an approximate Bayesian analysis and combines prior information on the parameters with the model and the observed data.
Here, we provide a wide prior by fitting LogNormal distributions, using the DistributionFits.jl package, to
- the mode corresponding to the initial value provided above
- the 0.95-quantile at 3 times the mode
θall = vcat(θP, θM)
priors_dict = Dict{Symbol, Distribution}(
    keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95), Val(:mode)))
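As an illustrative check (with mode and quantile from the Distributions package), the fitted prior for K2 should have its mode at the initial value and its 0.95-quantile at three times that value.
d = priors_dict[:K2]
mode(d), quantile(d, 0.95) # ≈ (2.0, 6.0)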
Observations, model drivers and covariates
The model parameters are inverted using information on
- the observed data, y_o
- its uncertainty, y_unc
- known covariates across sites, xM
- model drivers, xP
Here, we use synthetic data generated by the package.
rng = StableRNG(111)
(; xM, xP, y_o, y_unc) = gen_hybridproblem_synthetic(
    rng, DoubleMM.DoubleMMCase(); scenario=Val((:omit_r0,)))
Let's look at them.
size(xM), size(xP), size(y_o), size(y_unc)
((5, 800), (16, 800), (8, 800), (8, 800))
All of them have 800 columns, corresponding to 800 sites. There are 5 site covariates, 16 values of model drivers, and 8 observations per site.
xP[:,1]
ComponentVector{Float32}(S1 = Float32[0.5, 0.5, 0.5, 0.5, 0.4, 0.3, 0.2, 0.1], S2 = Float32[1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0])
Each column of the model drivers is a ComponentVector with components S1 and S2, corresponding to the concentrations for which outputs were observed. This allows the notation x.S1 in the PBM above.
The y_unc gets its meaning from the likelihood function specified with the problem below.
Providing data in batches
HVI uses MLUtils.DataLoader to provide batches of the data during each iteration of the solver. In addition to the data, it provides the indices of the sites inside a tuple.
n_site = size(y_o,2)
n_batch = 20
train_dataloader = MLUtils.DataLoader(
    (xM, xP, y_o, y_unc, 1:n_site), batchsize=n_batch, partial=false)
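Each iteration then yields one batch; for example (illustrative):
# a batch is a tuple of batch columns plus the indices of the sites in the batch
(xM_b, xP_b, y_o_b, y_unc_b, i_sites) = first(train_dataloader)
size(xM_b) # (5, 20): n_covar × n_batch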
The Machine-Learning model
The machine-learning (ML) part predicts parameters of the posterior of the site-specific PBM parameters, given the covariates. Here, we specify a 3-layer feed-forward neural network using the SimpleChains framework, which works efficiently on the CPU.
n_out = length(θM) # number of site-specific parameters to predict
n_input = n_covar = size(xM,1)
g_chain = SimpleChain(
    static(n_input), # input dimension (optional)
    TurboDense{true}(tanh, n_input * 4),
    TurboDense{true}(tanh, n_input * 4),
    # dense layer without bias that maps to n_out outputs in (0..1)
    TurboDense{false}(logistic, n_out)
)
# get a template of the parameter vector, ϕg0
g_chain_app, ϕg0 = construct_ChainsApplicator(rng, g_chain)
The g_chain_app ChainsApplicator predicts the parameters of the posterior approximation given a vector of ML weights, ϕg. During construction, an initial template of this vector is created. This abstraction layer allows using different ML frameworks and replacing the SimpleChains model by Flux or Lux.
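For illustration, under the assumed calling convention of a ModelApplicator (covariates and weights in, predictions out; an assumption to be checked against the package docs):
y_raw = g_chain_app(xM[:, 1:5], ϕg0) # raw predictions in (0..1), one column per site
size(y_raw) # (2, 5): n_out × number of queried sites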
Using priors to scale ML-parameter estimates
In order to balance gradients, the g_chain_app ModelApplicator defined above predicts on a scale of (0..1). Now the priors are used to translate this to the parameter range by using the cumulative distribution function.
Priors were specified at constrained scale, but the ML model predicts parameters at unconstrained scale. This transformation of the distribution can be worked out mathematically for specific prior distribution forms. However, for simplicity, a NormalScalingModelApplicator is fitted to the transformed 5% and 95% quantiles of the original prior.
priorsM = [priors_dict[k] for k in keys(θM)]
lowers, uppers = get_quantile_transformed(priorsM, transM)
g_chain_scaled = NormalScalingModelApplicator(g_chain_app, lowers, uppers, FT)
The g_chain_scaled ModelApplicator now predicts at unconstrained scale: it transforms logistic predictions around 0.5 to the range of high prior probability of the parameters, and transforms ML predictions near 0 or 1 towards the outer, lower-probability ranges.
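By this construction, the bounds should roughly correspond to the prior quantiles mapped to unconstrained scale. The following illustrative check is an assumption about get_quantile_transformed, not its definition.
ζ05 = inverse(transM)(FT.(quantile.(priorsM, 0.05))) # ≈ lowers
ζ95 = inverse(transM)(FT.(quantile.(priorsM, 0.95))) # ≈ uppers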
Assembling the information
All the specifications above are stored in a HybridProblem structure.
Beforehand, a PBMSiteApplicator is constructed that translates an invocation with a vector of global parameters and a matrix of site parameters into invocations of the process-based model (PBM) defined at the beginning.
f_batch = f_allsites = PBMSiteApplicator(f_doubleMM; θP, θM, θFix, xPvec=xP[:,1])
prob = HybridProblem(θP, θM, g_chain_scaled, ϕg0,
    f_batch, f_allsites, priors_dict, py,
    transM, transP, train_dataloader, n_covar, n_site, n_batch)
Perform the inversion
Finally, having assembled all the moving parts of the HVI, we can perform the inversion.
using OptimizationOptimisers
import Zygote
solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3)
(; probo, interpreters) = solve(prob, solver; rng,
    callback = callback_loss(100), # output during fitting
    epochs = 2,
);
The solver object is constructed given the specific stochastic optimization algorithm and the number of Monte Carlo samples that are drawn in each iteration from the predicted parameter posterior.
Then the solver is applied to the problem using solve for a given number of iterations or epochs. For this tutorial, the function that transfers structures to the GPU is left as the identity function, so that everything stays on the CPU, and the tutorial hence does not require a GPU or GPU libraries.
Among the return values are
- probo: a copy of the HybridProblem with updated optimized parameters
- interpreters: a NamedTuple with several ComponentArrayInterpreters that will help analyze the results.
Using a population-level process-based model
So far, the process-based model ran for each single site. For this simple model, some performance gains result from matrix computations when running the model for all sites within one batch simultaneously.
In the following, the PBM specification accepts matrices as arguments for parameters and drivers and returns a matrix of predictions. For the parameters, one row corresponds to one site. For the drivers and predictions, one column corresponds to one site.
function f_doubleMM_sites(θc::CA.ComponentMatrix, xPc::CA.ComponentMatrix)
    # extract the driver sub-matrices from xPc
    ST = typeof(CA.getdata(xPc)[1:1, :]) # workaround for non-type-stable Symbol-indexing
    S1 = CA.getdata(xPc[:S1, :])::ST
    S2 = CA.getdata(xPc[:S2, :])::ST
    #
    # extract the parameters as row-repeated vectors
    n_obs = size(S1, 1)
    VT = typeof(CA.getdata(θc)[:, 1]) # workaround for non-type-stable Symbol-indexing
    (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
        p1 = CA.getdata(θc[:, par])::VT
        repeat(p1', n_obs) # matrix: same value for each concentration row in S1
    end
    #
    # each variable is a matrix (n_obs x n_site)
    r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
end
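To make the calling convention concrete, here is an illustrative direct call with 3 sites. The construction of the parameter ComponentMatrix is a hypothetical sketch; the applicator introduced below takes care of this.
θall3 = vcat(θP, θM, θFix) # full parameter vector
# replicate it across 3 sites as rows of a ComponentMatrix with named columns
θmat = CA.ComponentArray(
    repeat(CA.getdata(θall3)', 3), CA.FlatAxis(), first(CA.getaxes(θall3)))
f_doubleMM_sites(θmat, xP[:, 1:3]) # 8×3 matrix of predictions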
Again, the function should not rely on the order of the parameters but use symbolic indexing to extract the parameter vectors. For type stability of this symbolic indexing, it uses a workaround that first determines the concrete type of a plain slice. Similarly, it uses type hints to index into the drivers, xPc, to extract sub-matrices by symbols. Alternatively, here it could rely on the structure and ordering of the rows in xPc.
A corresponding PBMPopulationApplicator transforms calls with partitioned global and site parameters into calls of this matrix version of the PBM. The HVI problem needs to be updated with this new applicator.
f_batch = PBMPopulationApplicator(f_doubleMM_sites, n_batch; θP, θM, θFix, xPvec=xP[:,1])
f_allsites = PBMPopulationApplicator(f_doubleMM_sites, n_site; θP, θM, θFix, xPvec=xP[:,1])
probo_sites = HybridProblem(probo; f_batch, f_allsites)
For numerical efficiency, the number of sites within one batch is part of the PBMPopulationApplicator. Hence, we have two different functions, one applied to a batch of sites, and another applied to all sites.
As a test of the new applicator, the results are refined by running a few more epochs of the optimization.
(; probo) = solve(probo_sites, solver; rng,
    callback = callback_loss(100), # output during fitting
    epochs = 20,
    #is_inferred = Val(true), # activate type-checks
);
Saving the results
Extracting useful information from the optimized HybridProblem is covered in the following Inspect results of fitted problem tutorial. In order to use the results from this tutorial in other tutorials, the updated probo
HybridProblem
and the interpreters are saved to a JLD2 file.
Beforehand, the problem is updated to use the redefinition DoubleMM.f_doubleMM_sites of the PBM in module DoubleMM rather than module Main, to allow for easier reloading with JLD2.
f_batch = PBMPopulationApplicator(DoubleMM.f_doubleMM_sites, n_batch; θP, θM, θFix, xPvec=xP[:,1])
f_allsites = PBMPopulationApplicator(DoubleMM.f_doubleMM_sites, n_site; θP, θM, θFix, xPvec=xP[:,1])
probo2 = HybridProblem(probo; f_batch, f_allsites)
using JLD2
fname = "intermediate/basic_cpu_results.jld2"
mkpath("intermediate")
if probo2 isa AbstractHybridProblem # do not save on failure above
    jldsave(fname, false, IOStream; probo=probo2, interpreters)
end
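Other tutorials can then reload the saved objects, for example (a sketch using JLD2's load):
probo_loaded, interpreters_loaded = load(fname, "probo", "interpreters")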