Basic workflow without GPU

First, load the necessary packages.

using HybridVariationalInference
using HybridVariationalInference: HybridVariationalInference as HVI
using ComponentArrays: ComponentArrays as CA
using Bijectors
using StableRNGs
using SimpleChains
using StatsFuns
using MLUtils
using DistributionFits

Next, specify the many moving parts of Hybrid Variational Inference (HVI).

The process-based model

The example process-based model (PBM) predicts a double-Monod-constrained rate for different substrate concentrations, S1 and S2.

\[ y = r_0+ r_1 \frac{S_1}{K_1 + S_1} \frac{S_2}{K_2 + S_2}\]

function f_doubleMM(θc::CA.ComponentVector{ET}, x) where ET
    # extract parameters not depending on order, i.e whether they are in θP or θM
    (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
        CA.getdata(θc[par])::ET
    end
    r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2)
end

Its formulation is independent of which parameters are global, site-specific, or fixed during the model inversion. However, it must not assume any ordering of the parameters; instead, it accesses the components by their symbolic names in the provided ComponentArray.
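To illustrate why name-based access makes the PBM independent of how parameters are partitioned, here is a minimal sketch that uses plain NamedTuples as a stand-in for ComponentVectors. The function f_doubleMM_nt and the example values are hypothetical and not part of HVI; the point is only that merging the global, site, and fixed parts in different orders yields identical predictions.

```julia
# Hypothetical stand-in for the PBM, using NamedTuples instead of ComponentVectors:
# parameters are looked up by name, so their position does not matter.
function f_doubleMM_nt(θ::NamedTuple, x::NamedTuple)
    (; r0, r1, K1, K2) = θ   # destructure by name, not by position
    r0 .+ r1 .* x.S1 ./ (K1 .+ x.S1) .* x.S2 ./ (K2 .+ x.S2)
end

x = (S1 = [0.5, 0.4], S2 = [1.0, 5.0])  # two concentration pairs
θP = (K2 = 2.0,)                         # global
θM = (r1 = 0.5, K1 = 0.2)                # site-specific
θFix = (r0 = 0.3,)                       # fixed

# The same parameters merged in two different orders give the same predictions
y_a = f_doubleMM_nt(merge(θP, θM, θFix), x)
y_b = f_doubleMM_nt(merge(θFix, θM, θP), x)
```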

Likelihood function

HVI requires evaluating the likelihood of the predictions, which corresponds to the cost of the predictions given the observations.

The user specifies a function of the negative log-likelihood, neg_logden(obs, pred, uncertainty_parameters), where all arguments are arrays with one column per site.

Here, we use the neg_logden_indep_normal function, which assumes the observations to be independently normally distributed around a true value. Here, the provided uncertainty parameters, y_unc, correspond to logσ2, the log of the variance parameter of the normal distribution.

py = neg_logden_indep_normal
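As a rough illustration of what such a likelihood function computes, the following sketch evaluates the negative log-density of independent normal observations with uncertainty given as log-variance. The name neg_logden_normal_sketch is hypothetical, and dropping the constant 0.5·log(2π) per observation is an assumption; HVI's neg_logden_indep_normal may differ in constants or scaling.

```julia
# Hypothetical sketch (not HVI's neg_logden_indep_normal): negative log-density of
# independent normal observations, uncertainty given as log-variance logσ2.
function neg_logden_normal_sketch(obs, pred, logσ2)
    # 0.5 * Σ [ logσ² + (obs - pred)² / σ² ]; constant 0.5*log(2π) per obs omitted
    0.5 * sum(logσ2 .+ abs2.(obs .- pred) .* exp.(-logσ2))
end

obs = [1.0 2.0; 3.0 4.0]    # 2 observations (rows) × 2 sites (columns)
pred = [1.5 2.0; 3.0 3.0]   # model predictions with the same layout
logσ2 = zeros(2, 2)         # σ² = 1 for every observation
nll = neg_logden_normal_sketch(obs, pred, logσ2)
```

With unit variances, the result reduces to half the sum of squared residuals, (0.25 + 1) / 2.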

Global-Site, transformations, and priors

Global and site-specific parameters

In this example, we assign a fixed value to the r0 parameter, treat the K2 parameter as unknown but the same across sites, and predict r1 and K1 for each site separately, based on covariates known at the sites.

Here, we provide initial values for them using ComponentVectors.

FT = Float32
θM0 = θM = CA.ComponentVector{FT}(r1=0.5, K1=0.2) # separately for each individual
θP0 = θP = CA.ComponentVector{FT}(K2=2.0)  # population: same across individuals
θFix = CA.ComponentVector{FT}(r0=0.3) # fixed: r0 is not estimated
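The partition can be pictured as follows: every site shares K2 and r0, but has its own r1 and K1. This minimal sketch uses plain NamedTuples and a matrix with one row per site as hypothetical stand-ins for HVI's internal representation; it also counts the estimated parameters, one global plus two per site.

```julia
# Hypothetical sketch of the partition: K2 is global, r1 and K1 vary by site,
# r0 is fixed. Plain NamedTuples stand in for ComponentVectors.
θP = (K2 = 2.0,)
θFix = (r0 = 0.3,)
# site parameters: one row per site, columns r1 and K1
θMs = [0.5 0.2;
       0.6 0.25;
       0.4 0.15]
n_site = size(θMs, 1)

# full parameter set of site i: shared global + its own row + fixed values
site_params(i) = merge(θP, (r1 = θMs[i, 1], K1 = θMs[i, 2]), θFix)
all_params = [site_params(i) for i in 1:n_site]

# number of estimated values: one global + two per site (r0 is not estimated)
n_estimated = length(θP) + length(θMs)
```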

Parameter Transformations

HVI allows parameters to be transformed from an unconstrained space, where the probability density is nowhere strictly zero, to the original constrained space.

Here, our model parameters are strictly positive, and we use the exponential function to transform unconstrained estimates to the original constrained domain.

transP = Stacked(HVI.Exp())  
transM = Stacked(HVI.Exp(), HVI.Exp())

Parameter transformations are specified using the Bijectors package. Because Bijectors.elementwise(exp) has problems with automatic differentiation (AD) on the GPU, we use HVI's public but non-exported Exp wrapper inside Bijectors.Stacked.
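The math behind an elementwise exponential transformation is small enough to write out directly. The sketch below is plain Julia and does not use the Bijectors or HVI types; the function names are hypothetical. It shows the forward map to positive values, its inverse, and the log-absolute-determinant of the Jacobian, which for elementwise exp is simply the sum of the unconstrained values.

```julia
# Minimal sketch of an elementwise exp transformation (not the Bijectors/HVI types).
to_constrained(ζ) = exp.(ζ)        # unconstrained ℝ -> positive values
to_unconstrained(θ) = log.(θ)      # inverse: positive values -> ℝ
logabsdetjac_exp(ζ) = sum(ζ)       # log|det J| of elementwise exp equals Σ ζ

ζ = [log(0.5), log(0.2)]           # unconstrained values for r1 and K1
θ = to_constrained(ζ)
```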

Prior information on parameters at constrained scale

HVI is an approximate Bayesian analysis that combines prior information on the parameters with the model and the observed data.

Here, we provide a wide prior by fitting LogNormal distributions to

  • the mode, set to the initial value provided above
  • the 0.95-quantile, set to 3 times the mode

using the DistributionFits.jl package.

θall = vcat(θP, θM)
priors_dict = Dict{Symbol, Distribution}(
    keys(θall) .=> fit.(LogNormal, θall, QuantilePoint.(θall .* 3, 0.95), Val(:mode)))
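For this particular combination of constraints, the fit has a closed form, which helps to see what the priors look like. A LogNormal(μ, σ) has mode exp(μ - σ²) and 0.95-quantile exp(μ + zσ), with z the 0.95-quantile of the standard normal. Requiring the quantile to be 3 times the mode gives the quadratic σ² + zσ = log 3. The helper lognormal_from_mode_q95 below is a hypothetical illustration; DistributionFits solves the general case and is what the tutorial actually uses.

```julia
# Closed-form sketch for this specific case (not DistributionFits' implementation).
# LogNormal(μ, σ): mode = exp(μ - σ²), 0.95-quantile = exp(μ + z σ).
const z95 = 1.6448536269514722   # 0.95-quantile of the standard normal

function lognormal_from_mode_q95(mode, q95)
    # log(q95) - log(mode) = σ² + z σ  ->  positive root of the quadratic in σ
    d = log(q95) - log(mode)
    σ = (-z95 + sqrt(z95^2 + 4d)) / 2
    μ = log(mode) + σ^2
    (μ, σ)
end

μ, σ = lognormal_from_mode_q95(0.5, 3 * 0.5)   # prior for r1: mode 0.5, q95 = 1.5
```

Because only the ratio of quantile to mode enters σ, all three priors in this tutorial share the same σ ≈ 0.51 and differ only in μ.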

Observations, model drivers and covariates

The model parameters are inverted using information on the

  • observed data, y_o
  • its uncertainty, y_unc
  • known covariates across sites, xM
  • model drivers, xP

Here, we use synthetic data generated by the package.

rng = StableRNG(111)
(; xM, xP, y_o, y_unc) = gen_hybridproblem_synthetic(
    rng, DoubleMM.DoubleMMCase(); scenario=Val((:omit_r0,)))

Let's look at their sizes.

size(xM), size(xP), size(y_o), size(y_unc)
((5, 800), (16, 800), (8, 800), (8, 800))

All of them have 800 columns, corresponding to 800 sites. There are 5 site covariates, 16 values of model drivers, and 8 observations per site.

xP[:,1]
ComponentVector{Float32}(S1 = Float32[0.5, 0.5, 0.5, 0.5, 0.4, 0.3, 0.2, 0.1], S2 = Float32[1.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0])

Each column of the model drivers is a ComponentVector with components S1 and S2 holding the concentrations for which outputs were observed. This enables the notation x.S1 in the PBM above.

The meaning of y_unc is given by the likelihood function, which is specified with the problem below.

Providing data in batches

HVI uses MLUtils.DataLoader to provide batches of the data during each iteration of the solver. In addition to the data, each batch tuple includes the indices of its sites.

n_site = size(y_o,2)
n_batch = 20
train_dataloader = MLUtils.DataLoader(
    (xM, xP, y_o, y_unc, 1:n_site), batchsize=n_batch, partial=false)
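The way complete batches of sites arise can be sketched in plain Julia. The helper site_batches is hypothetical and only mimics the effect of batchsize and partial=false: sites are split into consecutive groups of n_batch, an incomplete last group is dropped, and each batch selects columns of the data arrays. It does not reproduce MLUtils internals such as shuffling.

```julia
# Hypothetical sketch of batching site indices (plain Julia, not MLUtils):
# with partial = false, a final batch smaller than n_batch is dropped.
function site_batches(n_site, n_batch; partial = false)
    batches = collect(Iterators.partition(1:n_site, n_batch))
    partial ? batches : filter(b -> length(b) == n_batch, batches)
end

batches = site_batches(800, 20)               # 800 sites in batches of 20

y_o = reshape(collect(1.0:6400.0), 8, 800)    # stand-in data: 8 observations × 800 sites
y_batch = y_o[:, batches[2]]                  # columns of the second batch, sites 21:40
```

With 800 sites and a batch size of 20, one epoch consists of 40 batches.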

The Machine-Learning model

The machine-learning (ML) part predicts parameters of the posterior of site-specific PBM parameters, given the covariates. Here, we specify a 3-layer feed-forward neural network using the SimpleChains framework, which works efficiently on the CPU.

n_out = length(θM) # number of site-specific parameters predicted per site
n_input = n_covar = size(xM,1)

g_chain = SimpleChain(
    static(n_input), # input dimension (optional)
    TurboDense{true}(tanh, n_input * 4),
    TurboDense{true}(tanh, n_input * 4),
    # dense layer without bias that maps to n_out outputs in (0..1)
    TurboDense{false}(logistic, n_out)
)
# get a template of the parameter vector, ϕg0
g_chain_app, ϕg0 = construct_ChainsApplicator(rng, g_chain)

The g_chain_app ChainsApplicator predicts the parameters of the posterior approximation, given a vector of ML weights, ϕg. During construction, an initial template of this vector is created. This abstraction layer allows using different ML frameworks, for example replacing the SimpleChains model by Flux or Lux.
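To make the idea of "predictions given a flat weight vector ϕg" concrete, the following sketch unpacks one flat vector into weight matrices and biases and runs the same layer structure as the chain above: two tanh layers of width 4·n_input and a bias-free logistic output layer. The helper names and the parameter layout are hypothetical; SimpleChains orders its parameters in its own way.

```julia
# Hypothetical sketch of a forward pass driven by one flat weight vector ϕg
# (plain Julia; SimpleChains' actual parameter layout may differ).
sigmoid(x) = 1 / (1 + exp(-x))

function unpack(ϕg, n_in, n_hidden, n_out)
    sizes = [(n_hidden, n_in), (n_hidden,), (n_hidden, n_hidden), (n_hidden,), (n_out, n_hidden)]
    parts = Any[]
    offset = 0
    for s in sizes
        n = prod(s)
        push!(parts, reshape(ϕg[offset+1:offset+n], s))
        offset += n
    end
    @assert offset == length(ϕg)   # every weight is used exactly once
    parts
end

function mlp_forward(ϕg, x; n_in, n_hidden, n_out)
    W1, b1, W2, b2, W3 = unpack(ϕg, n_in, n_hidden, n_out)
    h1 = tanh.(W1 * x .+ b1)
    h2 = tanh.(W2 * h1 .+ b2)
    sigmoid.(W3 * h2)              # no bias in the last layer; outputs in (0, 1)
end

n_in, n_out = 5, 2
n_hidden = 4 * n_in
n_par = n_hidden * n_in + n_hidden + n_hidden^2 + n_hidden + n_out * n_hidden
ϕg = 0.1 .* sin.(1:n_par)                          # deterministic stand-in weights
x = reshape(collect(range(-1, 1; length = 15)), 5, 3)  # 5 covariates × 3 sites
out = mlp_forward(ϕg, x; n_in, n_hidden, n_out)
```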

Using priors to scale ML-parameter estimates

In order to balance gradients, the g_chain_app ModelApplicator defined above predicts on the scale (0..1). Now, the priors are used to translate this to the parameter range using the cumulative distribution function.

Priors were specified at constrained scale, but the ML model predicts parameters on unconstrained scale. This transformation of the distribution can be mathematically worked out for specific prior distribution forms. However, for simplicity, a NormalScalingModelApplicator is fitted to the transformed 5% and 95% quantiles of the original prior.

priorsM = [priors_dict[k] for k in keys(θM)]
lowers, uppers = get_quantile_transformed(priorsM, transM)
g_chain_scaled = NormalScalingModelApplicator(g_chain_app, lowers, uppers, FT)

The g_chain_scaled ModelApplicator now predicts on the unconstrained scale: it transforms logistic predictions around 0.5 to the range of high prior probability of the parameters, and transforms ML predictions near 0 or 1 towards the outer, lower-probability ranges.
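One plausible reading of this scaling, sketched under assumptions: a normal distribution is fitted to the unconstrained 5% and 95% quantiles (lower, upper), and a prediction p ∈ (0, 1) is mapped through that normal's quantile function, so that p = 0.5 lands at the center and p = 0.95 at the upper quantile. The functions below are hypothetical and may not match NormalScalingModelApplicator exactly. Because Base Julia has no erf, the standard normal CDF is computed by Simpson's rule and inverted by bisection.

```julia
# Hypothetical sketch of normal scaling of (0, 1) predictions
# (the actual NormalScalingModelApplicator may differ).
const z95 = 1.6448536269514722        # Φ⁻¹(0.95)

# standard normal CDF: 0.5 + ∫₀ˣ φ(t) dt via Simpson's rule (n even)
function normcdf(x; n = 2000)
    φ(t) = exp(-t^2 / 2) / sqrt(2π)
    h = x / n
    s = φ(0.0) + φ(x)
    for k in 1:n-1
        s += (isodd(k) ? 4 : 2) * φ(k * h)
    end
    0.5 + s * h / 3
end

# standard normal quantile by bisection on normcdf
function norminv(p; lo = -10.0, hi = 10.0)
    for _ in 1:60
        mid = (lo + hi) / 2
        normcdf(mid) < p ? (lo = mid) : (hi = mid)
    end
    (lo + hi) / 2
end

# normal fitted to the 5% and 95% quantiles, evaluated at quantile level p
function scale_prediction(p, lower, upper)
    μ = (lower + upper) / 2
    σ = (upper - lower) / (2 * z95)
    μ + σ * norminv(p)
end

# unconstrained quantiles of a LogNormal(μ0, σ0) prior under an exp transform
μ0, σ0 = -0.43, 0.51
lower, upper = μ0 - z95 * σ0, μ0 + z95 * σ0
```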

Assembling the information

All the specifications above are stored in a HybridProblem structure.

Before that, a PBMSiteApplicator is constructed. It translates an invocation with a vector of global parameters and a matrix of site parameters into invocations of the process-based model (PBM) defined at the beginning.

f_batch = f_allsites = PBMSiteApplicator(f_doubleMM; θP, θM, θFix, xPvec=xP[:,1])
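Conceptually, a site applicator loops over sites, assembles each site's full parameter set by name, calls the per-site PBM, and collects one prediction column per site. The sketch below is a hypothetical illustration with plain NamedTuples, not HVI's PBMSiteApplicator; in particular, it hardcodes the site parameter names r1 and K1, whereas the real applicator takes them from θM.

```julia
# Hypothetical sketch of a site applicator (not HVI's PBMSiteApplicator).
f_site(θ, x) = θ.r0 .+ θ.r1 .* x.S1 ./ (θ.K1 .+ x.S1) .* x.S2 ./ (θ.K2 .+ x.S2)

function apply_sites(f, θP, θMs, θFix, xs)
    cols = map(axes(θMs, 1)) do i
        # names r1 and K1 are hardcoded here for illustration only
        θ = merge(θP, (r1 = θMs[i, 1], K1 = θMs[i, 2]), θFix)
        f(θ, xs[i])
    end
    reduce(hcat, cols)   # n_obs × n_site
end

θP = (K2 = 2.0,)
θFix = (r0 = 0.3,)
θMs = [0.5 0.2; 0.6 0.25]                                  # one row per site: r1, K1
xs = [(S1 = [0.5, 0.1], S2 = [1.0, 5.0]) for _ in 1:2]     # drivers per site
y = apply_sites(f_site, θP, θMs, θFix, xs)
```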

prob = HybridProblem(θP, θM, g_chain_scaled, ϕg0, 
    f_batch, f_allsites, priors_dict, py,
    transM, transP, train_dataloader, n_covar, n_site, n_batch)

Perform the inversion

Finally, having assembled all the moving parts of the HVI, we can perform the inversion.

using OptimizationOptimisers
import Zygote

solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3)

(; probo, interpreters) = solve(prob, solver; rng,
    callback = callback_loss(100), # output during fitting
    epochs = 2,
);

The solver object is constructed given the specific stochastic optimization algorithm and the number of Monte-Carlo samples that are drawn in each iteration from the predicted parameter posterior.
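The role of n_MC can be illustrated with a generic Monte-Carlo estimate based on the reparameterization trick: unconstrained samples ζ = μ + σ·ε are drawn with ε standard normal, transformed to the constrained scale, and the cost is averaged over the samples. This is a hypothetical sketch of the general technique, not HVI's internal loss; mc_expected_cost and cost are made-up names.

```julia
using Random

# Generic sketch of a Monte-Carlo cost estimate with n_MC samples
# (reparameterization trick); not HVI's internal loss.
function mc_expected_cost(cost, μ, logσ, n_MC, rng)
    total = 0.0
    for _ in 1:n_MC
        ε = randn(rng, length(μ))
        ζ = μ .+ exp.(logσ) .* ε   # sample on unconstrained scale, differentiable in μ, logσ
        total += cost(exp.(ζ))     # evaluate the cost at the constrained sample
    end
    total / n_MC
end

cost(θ) = sum(abs2, θ .- [0.5, 0.2])   # hypothetical cost function
rng = MersenneTwister(42)
c3 = mc_expected_cost(cost, log.([0.5, 0.2]), fill(-2.0, 2), 3, rng)
```

Few samples per iteration give a noisy but cheap gradient estimate; the stochastic optimizer averages this noise over many iterations.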

Then the solver is applied to the problem using solve for a given number of iterations or epochs. For this tutorial, the function that transfers structures to the GPU is the identity function, so that everything stays on the CPU, and this tutorial hence does not require a GPU or GPU libraries.

Among the return values are

  • probo: A copy of the HybridProblem, with updated optimized parameters
  • interpreters: A NamedTuple with several ComponentArrayInterpreters that will help analyze the results.

Using a population-level process-based model

So far, the process-based model ran for each site separately. For this simple model, some performance gains result from matrix computations when running the model for all sites within one batch simultaneously.

In the following, the PBM specification accepts matrices as arguments for parameters and drivers and returns a matrix of predictions. For the parameters, one row corresponds to one site. For the drivers and predictions, one column corresponds to one site.

function f_doubleMM_sites(θc::CA.ComponentMatrix, xPc::CA.ComponentMatrix)
    # extract the drivers S1 and S2 from xPc
    ST = typeof(CA.getdata(xPc)[1:1,:])  # workaround for non-type-stable Symbol-indexing
    S1 = (CA.getdata(xPc[:S1,:])::ST)   
    S2 = (CA.getdata(xPc[:S2,:])::ST)
    #
    # extract the parameters as row-repeated vectors
    n_obs = size(S1, 1)
    VT = typeof(CA.getdata(θc)[:,1])   # workaround for non-type-stable Symbol-indexing
    (r0, r1, K1, K2) = map((:r0, :r1, :K1, :K2)) do par
        p1 = CA.getdata(θc[:, par]) ::VT
        repeat(p1', n_obs)  # matrix: same for each concentration row in S1
    end
    #
    # each variable is a matrix (n_obs x n_site)
    r0 .+ r1 .* S1 ./ (K1 .+ S1) .* S2 ./ (K2 .+ S2)
end

Again, the function should not rely on the order of parameters but use symbolic indexing to extract the parameter vectors. For type stability of this symbolic indexing, it uses a workaround to get the type of a single column of θc. Similarly, it uses type hints to index into the drivers, xPc, to extract sub-matrices by symbols. Alternatively, it could rely here on the structure and ordering of the rows in xPc.

A corresponding PBMPopulationApplicator transforms calls with partitioned global and site parameters into calls of this matrix version of the PBM. The HybridProblem needs to be updated with this new applicator.
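The equivalence between the matrix version and a per-site loop can be checked with plain matrices standing in for ComponentMatrices. The sketch below mirrors the repeat(p', n_obs) construction of f_doubleMM_sites and compares it with a loop over sites. It also shows, as an alternative worth considering, that broadcasting the 1 × n_site row p' directly yields the same result without materializing repeated copies; whether that suits HVI's AD and GPU settings is not verified here.

```julia
# Sketch with plain matrices: parameters are site vectors,
# drivers and predictions are n_obs × n_site.
n_obs, n_site = 3, 4
S1 = repeat([0.5, 0.3, 0.1], 1, n_site)
S2 = repeat([1.0, 3.0, 5.0], 1, n_site)
r0 = fill(0.3, n_site); K2 = fill(2.0, n_site)
r1 = [0.5, 0.6, 0.4, 0.55]; K1 = [0.2, 0.25, 0.15, 0.3]

# matrix version: repeat each parameter vector (as a row) for every observation row
rep(p) = repeat(p', n_obs)
y_mat = rep(r0) .+ rep(r1) .* S1 ./ (rep(K1) .+ S1) .* S2 ./ (rep(K2) .+ S2)

# broadcasting the 1 × n_site row p' directly gives the same values without copies
y_bc = r0' .+ r1' .* S1 ./ (K1' .+ S1) .* S2 ./ (K2' .+ S2)

# per-site loop for comparison
y_loop = similar(S1)
for j in 1:n_site
    y_loop[:, j] = r0[j] .+ r1[j] .* S1[:, j] ./ (K1[j] .+ S1[:, j]) .* S2[:, j] ./ (K2[j] .+ S2[:, j])
end
```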

f_batch = PBMPopulationApplicator(f_doubleMM_sites, n_batch; θP, θM, θFix, xPvec=xP[:,1])
f_allsites = PBMPopulationApplicator(f_doubleMM_sites, n_site; θP, θM, θFix, xPvec=xP[:,1])
probo_sites = HybridProblem(probo; f_batch, f_allsites)

For numerical efficiency, the number of sites within one batch is part of the PBMPopulationApplicator. Hence, we have two different functions: one applied to a batch of sites, and another applied to all sites.

As a test of the new applicator, the results are refined by running a few more epochs of the optimization.

(; probo) = solve(probo_sites, solver; rng,
    callback = callback_loss(100), # output during fitting
    epochs = 20,
    #is_inferred = Val(true), # activate type-checks 
);

Saving the results

Extracting useful information from the optimized HybridProblem is covered in the following Inspect results of fitted problem tutorial. In order to use the results from this tutorial in other tutorials, the updated probo HybridProblem and the interpreters are saved to a JLD2 file.

Before saving, the problem is updated to use DoubleMM.f_doubleMM_sites, a redefinition of the PBM in module DoubleMM rather than in module Main, which allows for easier reloading with JLD2.

f_batch = PBMPopulationApplicator(DoubleMM.f_doubleMM_sites, n_batch; θP, θM, θFix, xPvec=xP[:,1])
f_allsites = PBMPopulationApplicator(DoubleMM.f_doubleMM_sites, n_site; θP, θM, θFix, xPvec=xP[:,1])
probo2 = HybridProblem(probo; f_batch, f_allsites)
using JLD2
fname = "intermediate/basic_cpu_results.jld2"
mkpath("intermediate")
if probo2 isa AbstractHybridProblem # do not save on failure above
    jldsave(fname, false, IOStream; probo=probo2, interpreters)
end