How to move computations to GPU
This guide shows how to configure the setup and inversion of a HybridProblem so that computations of the ML model, and optionally also of the process-based model, are executed on GPU.
Motivation
Machine learning is often accelerated by moving computations from CPU to GPU. The same holds for HVI.
First load necessary packages.
using HybridVariationalInference
using ComponentArrays: ComponentArrays as CA
using Bijectors
using Lux
using SimpleChains # only needed for loading the saved object
using StatsFuns
using StableRNGs
using MLUtils
using JLD2
using Random
using MLDataDevices
# using CairoMakie
# using PairPlots # scatterplot matrices
This tutorial reuses and modifies the fitted object saved at the end of the Basic workflow without GPU tutorial.
fname = "intermediate/basic_cpu_results.jld2"
print(abspath(fname))
prob = probo_chain = load(fname, "probo");
Updating the ML model of the problem to use Lux
Because the SimpleChains ML model used in the basic tutorial does not support GPU, we reconstruct the model using the Lux framework. The setup is almost the same as in the basic workflow; the only difference is that a Lux.Chain object is provided to construct_ChainsApplicator.
n_out = length(prob.θM) # number of PBM parameters predicted by the ML model
n_covar = 5 #size(xM,1)
n_input = n_covar
g_lux = Lux.Chain(
Lux.Dense(n_covar => n_covar * 4, tanh),
Lux.Dense(n_covar * 4 => n_covar * 4, tanh),
Lux.Dense(n_covar * 4 => n_out, logistic, use_bias = false)
)
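As a quick sanity check, the chain can be applied to dummy covariate data using the standard Lux API. This is a minimal sketch, not part of the workflow; rng_demo and x_demo are names introduced here for illustration only.

```julia
# Lux.setup returns the initial parameters and states of the chain.
rng_demo = StableRNG(42)
ps, st = Lux.setup(rng_demo, g_lux)
# Apply the chain to a dummy batch of covariates (n_covar × n_batch);
# a Lux.Chain returns the output together with the updated state.
x_demo = rand(rng_demo, Float32, n_covar, 8)
y_demo, _ = g_lux(x_demo, ps, st)
size(y_demo)  # (n_out, 8): one prediction column per batch member
```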
# get a template of the parameter vector, ϕg0
rng = StableRNG(111)
g_chain_app, ϕg0 = construct_ChainsApplicator(rng, g_lux)
#
priorsM = [prob.priors[k] for k in keys(prob.θM)]
lowers, uppers = get_quantile_transformed(priorsM, prob.transM)
FT = eltype(prob.θM)
g_chain_scaled = NormalScalingModelApplicator(g_chain_app, lowers, uppers, FT)
Update the HybridProblem to use this ML model.
prob_lux = HybridProblem(probo_chain; g=g_chain_scaled, ϕg=ϕg0)
Specifying GPU devices during solve
The solve method for the HybridPosteriorSolver accepts the argument gdevs. It is a NamedTuple with entries gdev_M and gdev_P, for the ML model and the process-based model (PBM), respectively. They specify functions that are applied to move callables and data to GPU. They default to identity, meaning that nothing is moved from CPU to GPU. The function gpu_device() from package MLDataDevices can be used instead for the standard GPU device.
Hence specify
gdevs = (; gdev_M=gpu_device(), gdev_P=gpu_device()) to move both the ML model and the PBM to GPU, or
gdevs = (; gdev_M=gpu_device(), gdev_P=identity) to move only the ML model to GPU and execute the PBM (and parameter transformation) on CPU.
Currently, putting the PBM on GPU is not efficient during inversion, because the prior distribution needs to be evaluated for each sample. However, sampling and prediction using a fitted model are efficient on GPU.
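The device functions from MLDataDevices can be tried on their own; note that gpu_device() falls back to the CPU device when no functional GPU backend package is loaded. The sketch below uses only the documented MLDataDevices API; the variable names are illustrative.

```julia
using MLDataDevices
gdev = gpu_device()   # CUDADevice() when CUDA and cuDNN are loaded and functional, else a CPU fallback
cdev = cpu_device()
x = rand(Float32, 5, 8)
x_dev = gdev(x)       # on a CUDA device this returns a CuArray
x_back = cdev(x_dev)  # transfers the data back to a plain Array
```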
In addition, the libraries of the GPU device need to be activated by importing the respective Julia packages. Currently, only CUDA is tested with the HybridVariationalInference package.
import CUDA, cuDNN # so that gpu_device() returns a CUDADevice
using OptimizationOptimisers
import Zygote
solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3)
(; probo) = solve(prob_lux, solver;
callback = callback_loss(100),
epochs = 10,
gdevs = (; gdev_M=gpu_device(), gdev_P=identity)
); probo_lux = probo;
Moving results from GPU to CPU
The sampling and prediction methods also take this gdevs keyword argument.
n_sample_pred = 400
(y_dev, θsP_dev, θsMs_dev) = (; y, θsP, θsMs) = predict_hvi(
rng, probo_lux; n_sample_pred,
gdevs = (; gdev_M=gpu_device(), gdev_P=gpu_device()));
If gdev_P is not an AbstractGPUDevice, then all results are on the CPU. If gdev_P is an AbstractGPUDevice, then the results are GPUArrays and need to be transferred to the CPU.
typeof(θsMs_dev)
ComponentArrays.ComponentArray{Float32, 3, CUDA.CuArray{Float32, 3, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(i = 1:800,)}, ComponentArrays.Axis{(r1 = 1, K1 = 2)}, ComponentArrays.Axis{(i = 1:400,)}}}
Handling a ComponentArray backed by a GPUArray can result in scalar-indexing errors. Therefore, use a semicolon to suppress printing. Also, when moving the ComponentArray to CPU, use the function apply_preserve_axes to circumvent this error.
cdev = cpu_device()
y = cdev(y_dev)
θsP = apply_preserve_axes(cdev, θsP_dev)
θsMs = apply_preserve_axes(cdev, θsMs_dev)
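As a sanity check (a sketch; CA.getdata is the ComponentArrays accessor for the backing storage, and these lines are not part of the original workflow), one can confirm that the results now live on the CPU:

```julia
# The backing storage of a transferred ComponentArray should be a plain Array,
# no longer a CUDA.CuArray:
CA.getdata(θsMs) isa Array
CA.getdata(θsP) isa Array
```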