# Cross-Validation in EasyHybrid.jl
This tutorial demonstrates one way to run cross-validation in EasyHybrid. The code for this tutorial can be found in `docs/src/literate/tutorials/folds.jl`.
## 1. Load Packages

```julia
using EasyHybrid
using OhMyThreads
using CairoMakie
```
## 2. Data Loading and Preprocessing

Load the synthetic dataset from GitHub:

```julia
df = load_timeseries_netcdf("https://github.com/bask0/q10hybrid/raw/master/data/Synthetic4BookChap.nc");
```
Select a subset of the data for faster execution:

```julia
df = df[1:20000, :];
first(df, 5)
```
`5×6 DataFrame`

| Row | time | sw_pot | dsw_pot | ta | reco | rb |
|----:|:-----|-------:|--------:|---:|-----:|---:|
|     | DateTime | Float64? | Float64? | Float64? | Float64? | Float64? |
| 1 | 2003-01-01T00:15:00 | 109.817 | 115.595 | 2.1 | 0.844741 | 1.42522 |
| 2 | 2003-01-01T00:45:00 | 109.817 | 115.595 | 1.98 | 0.840641 | 1.42522 |
| 3 | 2003-01-01T01:15:00 | 109.817 | 115.595 | 1.89 | 0.837579 | 1.42522 |
| 4 | 2003-01-01T01:45:00 | 109.817 | 115.595 | 2.06 | 0.843372 | 1.42522 |
| 5 | 2003-01-01T02:15:00 | 109.817 | 115.595 | 2.09 | 0.844399 | 1.42522 |
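Before training, it can help to sanity-check value ranges and missingness. DataFrames.jl's `describe` does this (an optional check, not part of the original script):

```julia
using DataFrames

# Per-column minimum, maximum, and number of missing values.
describe(df, :min, :max, :nmissing)
```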
## 3. Define the Physical Model

```julia
"""
    RbQ10(; ta, Q10, rb, tref=15.0f0)

Respiration model with Q10 temperature sensitivity.

- `ta`: air temperature [°C]
- `Q10`: temperature sensitivity factor [-]
- `rb`: basal respiration rate [μmol/m²/s]
- `tref`: reference temperature [°C] (default: 15.0)
"""
function RbQ10(; ta, Q10, rb, tref = 15.0f0)
    reco = rb .* Q10 .^ (0.1f0 .* (ta .- tref))
    return (; reco, Q10, rb)
end
```
```
Main.RbQ10
```
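The function implements the standard Q10 relationship

```math
R_{eco} = R_b \cdot Q_{10}^{(T_a - T_{ref})/10}
```

so respiration equals the basal rate `rb` at the reference temperature and is multiplied by `Q10` for every 10 °C of warming. A quick sanity check with arbitrary scalar values:

```julia
# At ta = tref the exponent is zero, so reco equals rb.
RbQ10(ta = 15.0f0, Q10 = 2.0f0, rb = 3.0f0).reco  # 3.0f0

# 10 °C above tref multiplies respiration by Q10.
RbQ10(ta = 25.0f0, Q10 = 2.0f0, rb = 3.0f0).reco  # 6.0f0
```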
## 4. Define Model Parameters

Parameter specification: `(default, lower_bound, upper_bound)`

```julia
parameters = (
    rb  = (3.0f0, 0.0f0, 13.0f0),
    Q10 = (2.0f0, 1.0f0, 4.0f0)
)
```
```
(rb = (3.0f0, 0.0f0, 13.0f0),
 Q10 = (2.0f0, 1.0f0, 4.0f0),)
```
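With `scale_nn_outputs = true` (set below), a common way to respect such bounds is to squash an unconstrained network output through a scaled sigmoid. Here is a minimal sketch of that idea; the exact transform EasyHybrid applies internally may differ:

```julia
# Map an unconstrained x ∈ ℝ into the open interval (lower, upper).
bounded(x, lower, upper) = lower + (upper - lower) / (1 + exp(-x))

bounded(0.0f0, 1.0f0, 4.0f0)    # 2.5f0: the midpoint of the Q10 bounds
bounded(-3.0f0, 0.0f0, 13.0f0)  # ≈ 0.62f0: close to the lower rb bound
```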
## 5. Configure Hybrid Model Components

First, define the input variables. Forcing variables (air temperature):

```julia
forcing = [:ta]
```

```
1-element Vector{Symbol}:
 :ta
```
Predictor variables (potential solar radiation and its derivative):

```julia
predictors = [:sw_pot, :dsw_pot]
```

```
2-element Vector{Symbol}:
 :sw_pot
 :dsw_pot
```
Target variable (respiration):

```julia
target = [:reco]
```

```
1-element Vector{Symbol}:
 :reco
```
Next, classify the parameters. Global parameters (a single value shared across all samples):

```julia
global_param_names = [:Q10]
```

```
1-element Vector{Symbol}:
 :Q10
```
Neural-network-predicted parameters (one value per sample, predicted from the predictors):

```julia
neural_param_names = [:rb]
```

```
1-element Vector{Symbol}:
 :rb
```
## 6. Construct the Hybrid Model

```julia
hybrid_model = constructHybridModel(
    predictors,
    forcing,
    target,
    RbQ10,
    parameters,
    neural_param_names,
    global_param_names,
    hidden_layers = [16, 16],
    activation = sigmoid,
    scale_nn_outputs = true,
    input_batchnorm = true
)
```
```
Neural Network:
Chain(
    layer_1 = BatchNorm(2, affine=false, track_stats=true),
    layer_2 = Dense(2 => 16, σ),    # 48 parameters
    layer_3 = Dense(16 => 16, σ),   # 272 parameters
    layer_4 = Dense(16 => 1),       # 17 parameters
)         # Total: 337 parameters,
          #   plus 5 states.
Predictors: [:sw_pot, :dsw_pot]
Forcing: [:ta]
Neural parameters: [:rb]
Global parameters: [:Q10]
Fixed parameters: Symbol[]
Scale NN outputs: true
Parameter defaults and bounds:
HybridParams{typeof(Main.RbQ10)}(
┌─────┬─────────┬───────┬───────┐
│     │ default │ lower │ upper │
├─────┼─────────┼───────┼───────┤
│  rb │     3.0 │   0.0 │  13.0 │
│ Q10 │     2.0 │   1.0 │   4.0 │
└─────┴─────────┴───────┴───────┘
)
```
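For reference, the printed network corresponds to the following Lux.jl chain, written out by hand here for illustration (`constructHybridModel` assembles it internally from `hidden_layers`, `activation`, and `input_batchnorm`):

```julia
using Lux

nn = Chain(
    BatchNorm(2, affine = false, track_stats = true),  # input_batchnorm = true
    Dense(2 => 16, sigmoid),   # two predictors → 16 units, 48 parameters
    Dense(16 => 16, sigmoid),  # 272 parameters
    Dense(16 => 1),            # one output per neural parameter (:rb), 17 parameters
)
```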
## 7. Model Training: k-Fold Cross-Validation
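`make_folds` assigns every row of `df` to one of `k` folds; each fold then serves once as the validation set while the remaining `k - 1` folds form the training set. Conceptually, the assignment works like this (an illustrative sketch, not EasyHybrid's internals):

```julia
using Random

# Give each of n rows a fold id in 1:k, shuffled so folds are random,
# disjoint, and nearly equal in size.
n, k = 20_000, 3
fold_ids = shuffle(repeat(1:k, ceil(Int, n / k))[1:n])

count(==(3), fold_ids)  # 6666 here; folds 1 and 2 get 6667 rows each
```

These sizes match the train/validation counts reported in the logs below.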
```julia
k = 3
folds = make_folds(df, k = k, shuffle = true)
results = Vector{Any}(undef, k)

@time @tasks for val_fold in 1:k
    @info "Split data outside of train function. Training fold $val_fold of $k"
    sdata = split_data(df, hybrid_model; val_fold = val_fold, folds = folds)
    out = train(
        hybrid_model,
        sdata,
        ();
        nepochs = 10,
        patience = 10,
        batchsize = 512,        # batch size for training
        opt = RMSProp(0.001),   # optimizer and learning rate
        monitor_names = [:rb, :Q10],
        hybrid_name = "folds_$(val_fold)",
        folder_to_save = "CV_results",
        file_name = "trained_model_folds_$(val_fold).jld2",
        show_progress = false,
        plotting = false
    )
    results[val_fold] = out
end
```
```
[ Info: Split data outside of train function. Training fold 1 of 3
┌ Warning: shuffleobs is not supported when using folds and val_fold, this will be ignored and should be done during fold constructions
└ @ EasyHybrid ~/work/EasyHybrid.jl/EasyHybrid.jl/src/train.jl:425
[ Info: K-fold via external assignments: val_fold=1 → train=13333 val=6667
[ Info: Plotting disabled.
┌ Warning: data was prepared already, none of the keyword arguments for split_data will be used
└ @ EasyHybrid ~/work/EasyHybrid.jl/EasyHybrid.jl/src/train.jl:384
[ Info: Check the saved output (.png, .mp4, .jld2) from training at: /home/runner/work/EasyHybrid.jl/EasyHybrid.jl/docs/build
[ Info: Returning best model from epoch 10 of 10 epochs with best validation loss wrt mse: 0.48590943
[ Info: Split data outside of train function. Training fold 2 of 3
┌ Warning: shuffleobs is not supported when using folds and val_fold, this will be ignored and should be done during fold constructions
└ @ EasyHybrid ~/work/EasyHybrid.jl/EasyHybrid.jl/src/train.jl:425
[ Info: K-fold via external assignments: val_fold=2 → train=13333 val=6667
[ Info: Plotting disabled.
┌ Warning: data was prepared already, none of the keyword arguments for split_data will be used
└ @ EasyHybrid ~/work/EasyHybrid.jl/EasyHybrid.jl/src/train.jl:384
[ Info: Check the saved output (.png, .mp4, .jld2) from training at: /home/runner/work/EasyHybrid.jl/EasyHybrid.jl/docs/build
[ Info: Returning best model from epoch 10 of 10 epochs with best validation loss wrt mse: 0.5063128
[ Info: Split data outside of train function. Training fold 3 of 3
┌ Warning: shuffleobs is not supported when using folds and val_fold, this will be ignored and should be done during fold constructions
└ @ EasyHybrid ~/work/EasyHybrid.jl/EasyHybrid.jl/src/train.jl:425
[ Info: K-fold via external assignments: val_fold=3 → train=13334 val=6666
[ Info: Plotting disabled.
┌ Warning: data was prepared already, none of the keyword arguments for split_data will be used
└ @ EasyHybrid ~/work/EasyHybrid.jl/EasyHybrid.jl/src/train.jl:384
[ Info: Check the saved output (.png, .mp4, .jld2) from training at: /home/runner/work/EasyHybrid.jl/EasyHybrid.jl/docs/build
[ Info: Returning best model from epoch 10 of 10 epochs with best validation loss wrt mse: 0.51476365
 11.821922 seconds (27.71 M allocations: 1.882 GiB, 2.63% gc time, 76.18% compilation time: 29% of which was recompilation)
```
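With all folds trained, summarize the cross-validation performance. The best validation MSE per fold is reported in the logs above, so a simple summary is:

```julia
using Statistics

# Best validation MSE per fold, copied from the training logs above.
fold_losses = [0.48590943, 0.5063128, 0.51476365]

mean(fold_losses), std(fold_losses)  # ≈ (0.502, 0.015)
```

Each entry of `results` holds whatever `train` returned for its fold, so per-fold outputs can also be inspected programmatically from there.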