
Cross-Validation in EasyHybrid.jl

This tutorial demonstrates one approach to cross-validation in EasyHybrid. The code for this tutorial can be found in docs/src/literate/tutorials/folds.jl.

1. Load Packages

```julia
using EasyHybrid
using OhMyThreads
using CairoMakie
```

2. Data Loading and Preprocessing

Load synthetic dataset from GitHub

```julia
df = load_timeseries_netcdf("https://github.com/bask0/q10hybrid/raw/master/data/Synthetic4BookChap.nc");
```

Select a subset of data for faster execution

```julia
df = df[1:20000, :];
first(df, 5)
```

```
5×6 DataFrame
 Row │ time                 sw_pot    dsw_pot   ta        reco      rb
     │ DateTime             Float64?  Float64?  Float64?  Float64?  Float64?
─────┼──────────────────────────────────────────────────────────────────────
   1 │ 2003-01-01T00:15:00   109.817   115.595      2.1   0.844741  1.42522
   2 │ 2003-01-01T00:45:00   109.817   115.595      1.98  0.840641  1.42522
   3 │ 2003-01-01T01:15:00   109.817   115.595      1.89  0.837579  1.42522
   4 │ 2003-01-01T01:45:00   109.817   115.595      2.06  0.843372  1.42522
   5 │ 2003-01-01T02:15:00   109.817   115.595      2.09  0.844399  1.42522
```

3. Define the Physical Model

```julia
"""
    RbQ10(; ta, Q10, rb, tref=15.0f0)

Respiration model with Q10 temperature sensitivity.

- `ta`: air temperature [°C]
- `Q10`: temperature sensitivity factor [-]
- `rb`: basal respiration rate [μmol/m²/s]
- `tref`: reference temperature [°C] (default: 15.0)
"""
function RbQ10(; ta, Q10, rb, tref = 15.0f0)
    reco = rb .* Q10 .^ (0.1f0 .* (ta .- tref))
    return (; reco, Q10, rb)
end
```

```
Main.RbQ10
```
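As a quick standalone sanity check of the relation above (illustrative values only, not tied to the dataset): with Q10 = 2, respiration doubles for every 10 °C above the reference temperature and halves for every 10 °C below it.

```julia
# Scalar restatement of the Q10 relation used by RbQ10, for illustration.
q10_resp(ta, Q10, rb; tref = 15.0f0) = rb * Q10^(0.1f0 * (ta - tref))

q10_resp(15.0f0, 2.0f0, 3.0f0)  # at tref: reco equals rb, i.e. 3.0
q10_resp(25.0f0, 2.0f0, 3.0f0)  # 10 °C warmer: 6.0 (doubles when Q10 = 2)
q10_resp(5.0f0,  2.0f0, 3.0f0)  # 10 °C cooler: 1.5 (halves)
```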

4. Define Model Parameters

Parameter specification: (default, lower_bound, upper_bound)

```julia
parameters = (
    rb  = (3.0f0, 0.0f0, 13.0f0),
    Q10 = (2.0f0, 1.0f0, 4.0f0)
)
```

```
(rb = (3.0f0, 0.0f0, 13.0f0),
 Q10 = (2.0f0, 1.0f0, 4.0f0),)
```
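The bounds matter once parameters are predicted or optimized: unconstrained values must be mapped into (lower, upper). A common way to do this is a sigmoid squashing; this is only an illustrative sketch, not necessarily how EasyHybrid scales its outputs internally.

```julia
# Generic bounded-parameter transform (illustration only).
logistic(x) = 1 / (1 + exp(-x))
to_bounds(x, lower, upper) = lower + (upper - lower) * logistic(x)

to_bounds(0.0, 1.0, 4.0)     # midpoint of the Q10 bounds: 2.5
to_bounds(-1.0e3, 1.0, 4.0)  # extreme inputs saturate at the lower bound
```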

5. Configure Hybrid Model Components

Define the input variables. Forcing variables (temperature):

```julia
forcing     = [:ta]
```

```
1-element Vector{Symbol}:
 :ta
```

Predictor variables (solar radiation and its derivative):

```julia
predictors  = [:sw_pot, :dsw_pot]
```

```
2-element Vector{Symbol}:
 :sw_pot
 :dsw_pot
```

Target variable (respiration):

```julia
target      = [:reco]
```

```
1-element Vector{Symbol}:
 :reco
```

Parameter classification. Global parameters (same for all samples):

```julia
global_param_names = [:Q10]
```

```
1-element Vector{Symbol}:
 :Q10
```

Parameters predicted by the neural network:

```julia
neural_param_names = [:rb]
```

```
1-element Vector{Symbol}:
 :rb
```

6. Construct the Hybrid Model

```julia
hybrid_model = constructHybridModel(
    predictors,
    forcing,
    target,
    RbQ10,
    parameters,
    neural_param_names,
    global_param_names,
    hidden_layers = [16, 16],
    activation = sigmoid,
    scale_nn_outputs = true,
    input_batchnorm = true
)
```

```
Neural Network:
    Chain(
        layer_1 = BatchNorm(2, affine=false, track_stats=true),
        layer_2 = Dense(2 => 16, σ),                  # 48 parameters
        layer_3 = Dense(16 => 16, σ),                 # 272 parameters
        layer_4 = Dense(16 => 1),                     # 17 parameters
    )         # Total: 337 parameters,
              #        plus 5 states.
Predictors: [:sw_pot, :dsw_pot]
Forcing: [:ta]
Neural parameters: [:rb]
Global parameters: [:Q10]
Fixed parameters: Symbol[]
Scale NN outputs: true
Parameter defaults and bounds:
    HybridParams{typeof(Main.RbQ10)}(
    ┌─────┬─────────┬───────┬───────┐
    │     │ default │ lower │ upper │
    ├─────┼─────────┼───────┼───────┤
    │  rb │     3.0 │   0.0 │  13.0 │
    │ Q10 │     2.0 │   1.0 │   4.0 │
    └─────┴─────────┴───────┴───────┘
    )
```
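The parameter total in the printout can be verified by hand: a dense `in => out` layer carries `in * out` weights plus `out` biases, and the batchnorm layer here is `affine=false`, so it contributes no trainable parameters.

```julia
# Trainable parameters of a Dense(in => out) layer: weights plus biases.
nparams(nin, nout) = nin * nout + nout

nparams(2, 16) + nparams(16, 16) + nparams(16, 1)  # 48 + 272 + 17 = 337
```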

7. Model Training: k-Fold Cross-Validation

```julia
k = 3
folds = make_folds(df, k = k, shuffle = true)

results = Vector{Any}(undef, k)

@time @tasks for val_fold in 1:k
    @info "Split data outside of train function. Training fold $val_fold of $k"
    sdata = split_data(df, hybrid_model; val_fold = val_fold, folds = folds)
    out = train(
        hybrid_model,
        sdata,
        ();
        nepochs = 10,
        patience = 10,
        batchsize = 512,         # Batch size for training
        opt = RMSProp(0.001),    # Optimizer and learning rate
        monitor_names = [:rb, :Q10],
        hybrid_name = "folds_$(val_fold)",
        folder_to_save = "CV_results",
        file_name = "trained_model_folds_$(val_fold).jld2",
        show_progress = false,
        plotting = false
    )
    results[val_fold] = out
end
```

```
[ Info: Split data outside of train function. Training fold 1 of 3
┌ Warning: shuffleobs is not supported when using folds and val_fold, this will be ignored and should be done during fold constructions
└ @ EasyHybrid ~/work/EasyHybrid.jl/EasyHybrid.jl/src/train.jl:425
[ Info: K-fold via external assignments: val_fold=1 → train=13333 val=6667
[ Info: Plotting disabled.
┌ Warning: data was prepared already, none of the keyword arguments for split_data will be used
└ @ EasyHybrid ~/work/EasyHybrid.jl/EasyHybrid.jl/src/train.jl:384
[ Info: Check the saved output (.png, .mp4, .jld2) from training at: /home/runner/work/EasyHybrid.jl/EasyHybrid.jl/docs/build
[ Info: Returning best model from epoch 10 of 10 epochs with best validation loss wrt mse: 0.48590943
[ Info: Split data outside of train function. Training fold 2 of 3
┌ Warning: shuffleobs is not supported when using folds and val_fold, this will be ignored and should be done during fold constructions
└ @ EasyHybrid ~/work/EasyHybrid.jl/EasyHybrid.jl/src/train.jl:425
[ Info: K-fold via external assignments: val_fold=2 → train=13333 val=6667
[ Info: Plotting disabled.
┌ Warning: data was prepared already, none of the keyword arguments for split_data will be used
└ @ EasyHybrid ~/work/EasyHybrid.jl/EasyHybrid.jl/src/train.jl:384
[ Info: Check the saved output (.png, .mp4, .jld2) from training at: /home/runner/work/EasyHybrid.jl/EasyHybrid.jl/docs/build
[ Info: Returning best model from epoch 10 of 10 epochs with best validation loss wrt mse: 0.5063128
[ Info: Split data outside of train function. Training fold 3 of 3
┌ Warning: shuffleobs is not supported when using folds and val_fold, this will be ignored and should be done during fold constructions
└ @ EasyHybrid ~/work/EasyHybrid.jl/EasyHybrid.jl/src/train.jl:425
[ Info: K-fold via external assignments: val_fold=3 → train=13334 val=6666
[ Info: Plotting disabled.
┌ Warning: data was prepared already, none of the keyword arguments for split_data will be used
└ @ EasyHybrid ~/work/EasyHybrid.jl/EasyHybrid.jl/src/train.jl:384
[ Info: Check the saved output (.png, .mp4, .jld2) from training at: /home/runner/work/EasyHybrid.jl/EasyHybrid.jl/docs/build
[ Info: Returning best model from epoch 10 of 10 epochs with best validation loss wrt mse: 0.51476365
 11.821922 seconds (27.71 M allocations: 1.882 GiB, 2.63% gc time, 76.18% compilation time: 29% of which was recompilation)
```
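The fold sizes reported in the log (validation sets of 6667, 6667, and 6666 rows for 20000 samples and k = 3) are what a balanced assignment of rows to folds produces. The sketch below illustrates one such assignment; `make_folds`' actual implementation and return type may differ.

```julia
using Random

# Balanced k-fold assignment: each of the n rows gets a fold id in 1:k,
# then the ids are shuffled so folds are random but equally sized.
function fold_ids(n, k; rng = MersenneTwister(42))
    ids = repeat(1:k, ceil(Int, n / k))[1:n]
    return Random.shuffle(rng, ids)
end

ids = fold_ids(20_000, 3)
count(==(1), ids), count(==(3), ids)  # (6667, 6666)
```

Each fold id then serves once as the validation set (`val_fold`), with the remaining rows used for training, matching the `train=13333 val=6667` splits in the log.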