LSTM Hybrid Model with EasyHybrid.jl
This tutorial demonstrates how to use EasyHybrid to train a hybrid model with LSTM neural networks on synthetic data for respiration modeling with Q10 temperature sensitivity. The code for this tutorial can be found in docs/src/literate/tutorials/example_synthetic_lstm.jl.
1. Load Packages
Set project path and activate environment
using EasyHybrid
using AxisKeys
using DimensionalData
using Lux
2. Data Loading and Preprocessing
Load synthetic dataset from GitHub - it's tabular data
df = load_timeseries_netcdf("https://github.com/bask0/q10hybrid/raw/master/data/Synthetic4BookChap.nc");
Select a subset of data for faster execution
df = df[1:20000, :];
first(df, 5);
3. Define Neural Network Architectures
Define a standard feedforward neural network
NN = Chain(Dense(15, 15, Lux.sigmoid), Dense(15, 15, Lux.sigmoid), Dense(15, 1))
Chain(
layer_(1-2) = Dense(15 => 15, σ), # 480 (240 x 2) parameters
layer_3 = Dense(15 => 1), # 16 parameters
) # Total: 496 parameters,
# plus 0 states.
Define LSTM-based neural network with memory
TIP
When the Chain ends with a Recurrence layer, EasyHybrid automatically adds a RecurrenceOutputDense layer to handle the sequence output format. The user only needs to define the Recurrence layer itself!
NN_Memory = Chain(
Recurrence(LSTMCell(15 => 15), return_sequence = true),
)
Chain(
layer_1 = Recurrence(
cell = LSTMCell(15 => 15), # 1_920 parameters, plus 1 non-trainable
),
) # Total: 1_920 parameters,
# plus 1 states.
4. Define the Process-Based Model: A Classical Q10 Model for Respiration
"""
RbQ10(; ta, Q10, rb, tref=15.0f0)
Respiration model with Q10 temperature sensitivity.
- `ta`: air temperature [°C]
- `Q10`: temperature sensitivity factor [-]
- `rb`: basal respiration rate [μmol/m²/s]
- `tref`: reference temperature [°C] (default: 15.0)
"""
function RbQ10(; ta, Q10, rb, tref = 15.0f0)
reco = rb .* Q10 .^ (0.1f0 .* (ta .- tref))
return (; reco, Q10, rb)
end
5. Define Model Parameters
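As a quick sanity check on the process model just defined, the Q10 relation can be evaluated with plain Julia. This is a standalone scalar sketch using illustrative values (independent of the hybrid machinery, which broadcasts over arrays):

```julia
# Standalone re-statement of the Q10 relation, for illustration only
q10_reco(ta; Q10 = 2.0f0, rb = 3.0f0, tref = 15.0f0) = rb * Q10^(0.1f0 * (ta - tref))

q10_reco(15.0f0)   # at the reference temperature, reco equals rb: 3.0
q10_reco(25.0f0)   # 10 °C warmer: respiration doubles when Q10 = 2: 6.0
q10_reco(5.0f0)    # 10 °C cooler: respiration halves: 1.5
```

In other words, Q10 is exactly the multiplicative change in respiration per 10 °C of warming.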
Parameter specification: (default, lower_bound, upper_bound)
parameters = (
rb = (3.0f0, 0.0f0, 13.0f0), # Basal respiration [μmol/m²/s]
Q10 = (2.0f0, 1.0f0, 4.0f0), # Temperature sensitivity factor [-]
)
(rb = (3.0f0, 0.0f0, 13.0f0), Q10 = (2.0f0, 1.0f0, 4.0f0))
6. Configure Hybrid Model Components
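Parameter bounds like these are commonly enforced by squashing an unconstrained value into (lower, upper), for example with a scaled sigmoid. The helper below is a hypothetical illustration of that general technique only — it is not EasyHybrid's actual implementation:

```julia
# Hypothetical helper: map an unconstrained value x into the open interval (lo, hi)
scale_to_bounds(x, lo, hi) = lo + (hi - lo) / (1 + exp(-x))

scale_to_bounds(0.0, 0.0, 13.0)   # midpoint of the rb bounds: 6.5
scale_to_bounds(10.0, 1.0, 4.0)   # large x approaches the upper Q10 bound
```

Whatever the exact mapping, the effect is that neural network outputs can never push a parameter outside its physically plausible range.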
Define input variables
Forcing variables (temperature)
forcing = [:ta]
1-element Vector{Symbol}:
 :ta
Predictor variables (solar radiation, and its derivative)
predictors = [:sw_pot, :dsw_pot]
2-element Vector{Symbol}:
 :sw_pot
 :dsw_pot
Target variable (respiration)
target = [:reco]
1-element Vector{Symbol}:
 :reco
Parameter classification
Global parameters (same for all samples)
global_param_names = [:Q10]
1-element Vector{Symbol}:
 :Q10
Neural network predicted parameters
neural_param_names = [:rb]
1-element Vector{Symbol}:
 :rb
7. Construct LSTM Hybrid Model
Create LSTM hybrid model using the unified constructor
hlstm = constructHybridModel(
predictors,
forcing,
target,
RbQ10,
parameters,
neural_param_names,
global_param_names,
hidden_layers = NN_Memory, # Neural network architecture
scale_nn_outputs = true, # Scale neural network outputs
input_batchnorm = false # Apply batch normalization to inputs
)
Hybrid Model (Single NN)
Neural Network:
Chain(
layer_1 = WrappedFunction(identity),
layer_2 = Dense(2 => 15, tanh), # 45 parameters
layer_3 = Recurrence(
cell = LSTMCell(15 => 15), # 1_920 parameters, plus 1 non-trainable
),
layer_4 = RecurrenceOutputDense(
layer = Dense(15 => 15, tanh), # 240 parameters
),
layer_5 = Dense(15 => 1), # 16 parameters
) # Total: 2_221 parameters,
# plus 1 states.
Configuration:
predictors = [:sw_pot, :dsw_pot]
forcing = [:ta]
targets = [:reco]
mechanistic_model = RbQ10
neural_param_names = [:rb]
global_param_names = [:Q10]
fixed_param_names = Symbol[]
scale_nn_outputs = true
start_from_default = true
config = (; hidden_layers = Chain{@NamedTuple{layer_1::Lux.Recurrence{Static.True, Lux.LSTMCell{Static.False, Static.False, Int64, Int64, NTuple{4, Nothing}, NTuple{4, Nothing}, NTuple{4, Nothing}, typeof(WeightInitializers.zeros32), typeof(WeightInitializers.zeros32), Static.True}, Lux.BatchLastIndex}}, Nothing}((layer_1 = Lux.Recurrence{Static.True, Lux.LSTMCell{Static.False, Static.False, Int64, Int64, NTuple{4, Nothing}, NTuple{4, Nothing}, NTuple{4, Nothing}, typeof(WeightInitializers.zeros32), typeof(WeightInitializers.zeros32), Static.True}, Lux.BatchLastIndex}(LSTMCell(15 => 15), Lux.BatchLastIndex(), static(true), false, false),), nothing), activation = tanh, scale_nn_outputs = true, input_batchnorm = false, start_from_default = true,)
Parameters:
Hybrid Parameters
┌─────┬─────────┬───────┬───────┐
│ │ default │ lower │ upper │
├─────┼─────────┼───────┼───────┤
│ rb │ 3.0 │ 0.0 │ 13.0 │
│ Q10 │ 2.0 │ 1.0 │ 4.0 │
└─────┴─────────┴───────┴───────┘
8. Data Preparation Steps (Demonstration)
The following steps demonstrate what happens under the hood during training. In practice, you can skip to Section 9 and use the train function directly.
Both :KeyedArray and :DimArray are supported as array types
pref_array_type = :DimArray
x, y = prepare_data(hlstm, df, array_type = pref_array_type);
Convert a (single) time series into many training samples by windowing.
Each sample consists of:
- input_window: number of past steps given to the model (sequence length)
- output_window: number of steps to predict
- output_shift: stride between consecutive windows (controls overlap)
- lead_time: prediction lead (e.g. lead_time = 1 predicts starting 1 step ahead)
This supports many-to-one / many-to-many forecasting depending on output_window. Windowing creates an array of shape (variable, time, batch_size), with variable being the feature, time the input window, and batch_size the samples 1:n (= full batch).
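To make the windowing concrete, here is a simplified plain-Julia sketch of the idea behind split_into_sequences. This is my own illustration; the real function additionally handles output windows, lead_time, and labeled dimensions:

```julia
# Simplified sketch: slice a (variable, time) matrix into overlapping windows.
# Returns an array of shape (variable, input_window, n_samples).
function make_windows(x::AbstractMatrix; input_window, output_shift = 1)
    nvar, T = size(x)
    starts = 1:output_shift:(T - input_window + 1)   # first index of each window
    xs = Array{eltype(x), 3}(undef, nvar, input_window, length(starts))
    for (k, s) in enumerate(starts)
        xs[:, :, k] = x[:, s:(s + input_window - 1)]
    end
    return xs
end

x = reshape(Float32.(1:40), 2, 20)        # 2 variables, 20 time steps
xs = make_windows(x; input_window = 10)   # size: (2, 10, 11)
# With output_shift = 1, consecutive samples overlap by input_window - 1 steps:
xs[:, 2, 1] == xs[:, 1, 2]                # true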
output_shift = 1
output_window = 1
input_window = 10
xs, ys = split_into_sequences(x, y; input_window = input_window, output_window = output_window, output_shift = output_shift, lead_time = 0);
ys_nan = .!isnan.(ys);
First input_window/sample
xs[:, :, 1]
┌ 3×10 DimArray{Float32, 2} ┐
├───────────────────────────┴─────────────────────────────────────── dims ┐
↓ variable Categorical{Symbol} [:sw_pot, …, :ta] Unordered,
→ time Categorical{Symbol} [:x9_to_x9, …, :x9_to_x0_y0] ReverseOrdered
└─────────────────────────────────────────────────────────────────────────┘
↓ → :x9_to_x9 :x9_to_x8 … :x9_to_x1 :x9_to_x0_y0
:sw_pot 109.817 109.817 109.817 109.817
:dsw_pot 115.595 115.595 115.595 115.595
:ta 2.1 1.98 1.11 0.97
Second input_window/sample
xs[:, :, 2]
┌ 3×10 DimArray{Float32, 2} ┐
├───────────────────────────┴─────────────────────────────────────── dims ┐
↓ variable Categorical{Symbol} [:sw_pot, …, :ta] Unordered,
→ time Categorical{Symbol} [:x9_to_x9, …, :x9_to_x0_y0] ReverseOrdered
└─────────────────────────────────────────────────────────────────────────┘
↓ → :x9_to_x9 :x9_to_x8 … :x9_to_x1 :x9_to_x0_y0
:sw_pot 109.817 109.817 109.817 109.817
:dsw_pot 115.595 115.595 115.595 115.595
:ta 1.98 1.89 0.97 0.87
Test of the shift
xs[:, output_shift + 1, 1] == xs[:, 1, 2]
true
First output_window/sample, with a time label like :x30_to_x5_y4, which indicates an accumulation of memory from x30 to x5 for the prediction of y4
ys[:, :, 1]
┌ 1×1 DimArray{Float32, 2} ┐
├──────────────────────────┴────────────────────────── dims ┐
↓ variable Categorical{Symbol} [:reco] Unordered,
→ time Categorical{Symbol} [:x9_to_x0_y0] ForwardOrdered
└───────────────────────────────────────────────────────────┘
↓ → :x9_to_x0_y0
:reco 0.806911
Second output_window/sample
ys[:, :, 2]
┌ 1×1 DimArray{Float32, 2} ┐
├──────────────────────────┴────────────────────────── dims ┐
↓ variable Categorical{Symbol} [:reco] Unordered,
→ time Categorical{Symbol} [:x9_to_x0_y0] ForwardOrdered
└───────────────────────────────────────────────────────────┘
↓ → :x9_to_x0_y0
:reco 0.803646
Are any values of the first output_window also in the second output_window? Ideally the overlap is small.
overlap = output_window - output_shift
overlap_length = sum(in(ys[:, :, 1]), ys[:, :, 2])
0
Split the (windowed) dataset into train/validation in the same way as train does.
sdf = split_data(df, hlstm, sequence_kwargs = (; input_window = input_window, output_window = output_window, output_shift = output_shift, lead_time = 0), array_type = pref_array_type);
(x_train, y_train), (x_val, y_val) = sdf;
x_train
y_train
y_train_nan = .!isnan.(y_train)
┌ 1×1×15993 DimArray{Bool, 3} ┐
├─────────────────────────────┴──────────────────────────────────────── dims ┐
↓ variable Categorical{Symbol} [:reco] Unordered,
→ time Categorical{Symbol} [:x9_to_x0_y0] ForwardOrdered,
↗ batch_size Sampled{Int64} [1, …, 15993] ForwardOrdered Irregular Points
└────────────────────────────────────────────────────────────────────────────┘
[:, :, 1]
↓ → :x9_to_x0_y0
:reco 1
Wrap the training windows/samples in a DataLoader to form batches.
WARNING
batchsize is the number of windows/samples used per gradient step to update the parameters. Processing 32 windows in one array is usually much faster than doing 32 separate forward/backward passes with batch_size=1.
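How a batch size carves the sample set into minibatches can be pictured with a plain-Julia sketch (illustrative only; EasyHybrid's DataLoader handles this, along with shuffling, internally):

```julia
# Group sample indices into minibatches, as a DataLoader would
n_samples = 100
batchsize = 32
batches = collect(Iterators.partition(1:n_samples, batchsize))

length(batches)          # 4 batches: 32 + 32 + 32 + 4 samples
length(first(batches))   # 32
length(last(batches))    # 4 (the last batch may be smaller)
```

Each gradient step then processes one such group of windows as a single array, which is what makes batching faster than 32 separate passes.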
train_dl = EasyHybrid.DataLoader((x_train, y_train); batchsize = 32);
Run the hybrid model forwards
x_first = first(train_dl)[1]
y_first = first(train_dl)[2]
ps, st = Lux.setup(Random.default_rng(), hlstm);
frun = hlstm(x_first, ps, st);
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore ~/.julia/packages/LuxCore/kQC9S/src/LuxCore.jl:18
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore ~/.julia/packages/LuxCore/kQC9S/src/LuxCore.jl:18
Extract predicted yhat
reco_mod = frun[1].reco
┌ 10×32 DimArray{Float32, 2} ┐
├────────────────────────────┴─────────────────────────────────────── dims ┐
↓ time Categorical{Symbol} [:x9_to_x9, …, :x9_to_x0_y0] ReverseOrdered,
→ batch_size Sampled{Int64} [1, …, 32] ForwardOrdered Irregular Points
└──────────────────────────────────────────────────────────────────────────┘
↓ → 1 2 3 … 30 31 32
:x9_to_x9 3.51847 3.48933 3.46763 3.86629 3.79197 3.75796
:x9_to_x8 3.75827 3.7349 3.77917 4.08424 4.04761 3.80282
:x9_to_x7 3.80997 3.85513 3.86316 4.12896 3.87926 3.88195
:x9_to_x6 3.88699 3.89508 3.80436 3.91132 3.91403 3.84412
⋮ ⋱ ⋮
:x9_to_x3 3.75994 3.72105 3.65712 3.74693 3.66727 3.70047
:x9_to_x2 3.72041 3.65649 3.62118 3.66664 3.69983 3.63627
:x9_to_x1 3.65534 3.62004 3.59504 3.69867 3.63513 3.71666
:x9_to_x0_y0 3.61875 3.59376 3.52468 … 3.63383 3.71533 3.71276
Bring observations into the same shape
reco_obs = dropdims(y_first, dims = 1)
reco_nan = .!isnan.(reco_obs);
Compute loss
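The .!isnan. masks built above serve to exclude missing observations from the loss. The idea can be sketched with a plain-Julia mean squared error over valid entries only (a standalone illustration, not EasyHybrid's loss code):

```julia
# Compute a mean squared error only over non-NaN observations
obs  = Float32[1.0, NaN, 3.0, 4.0]
pred = Float32[1.5, 2.0, 2.5, 4.0]
mask = .!isnan.(obs)                 # Bool mask: true where obs is valid

mse = sum(abs2, (pred .- obs)[mask]) / count(mask)
# residuals over valid entries: 0.5, -0.5, 0.0 → mse ≈ 0.1667
```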
EasyHybrid.compute_loss(hlstm, ps, st, (x_train, (y_train, y_train_nan)), logging = LoggingLoss(train_mode = true))
(21.427757f0, (st_nn = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (rng = TaskLocalRNG(),), layer_4 = NamedTuple(), layer_5 = NamedTuple()), fixed = NamedTuple()), NamedTuple())
9. Train LSTM Hybrid Model
out_lstm = train(
hlstm,
df,
();
nepochs = 100, # Number of training epochs
batchsize = 128, # Batch size of training windows/samples
opt = RMSProp(0.01), # Optimizer and learning rate
monitor_names = [:rb, :Q10], # Parameters to monitor during training
yscale = identity, # Scaling for outputs
shuffleobs = true,
training_loss = :nseLoss,
loss_types = [:nse],
sequence_kwargs = (; input_window = input_window, output_window = output_window, output_shift = output_shift, lead_time = 0),
plotting = false,
show_progress = false,
input_batchnorm = false,
array_type = pref_array_type
);
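The training_loss = :nseLoss used above is built around the Nash–Sutcliffe efficiency (NSE). A minimal standalone version of the metric, assuming its standard definition (EasyHybrid's exact loss formulation may differ, e.g. by using 1 - NSE so that lower is better):

```julia
using Statistics: mean

# Nash–Sutcliffe efficiency: 1 is a perfect fit, 0 is no better than mean(obs)
nse(obs, pred) = 1 - sum(abs2, pred .- obs) / sum(abs2, obs .- mean(obs))

obs = Float32[1, 2, 3, 4, 5]
nse(obs, obs)                  # perfect prediction: 1.0
nse(obs, fill(mean(obs), 5))   # predicting the observed mean: 0.0
```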
out_lstm.val_obs_pred
1×2 Matrix{Any}:
Dict{Any, DataFrame}(:x9_to_x0_y0=>3998×2 DataFrame
Row │ reco index
│ Float32 Int64
──────┼──────────────────
1 │ 2.5595 14511
2 │ 0.0883865 2358
3 │ 0.143945 1741
4 │ 6.17557 10882
5 │ 1.35721 16259
6 │ 0.474447 17659
7 │ 0.587843 4624
8 │ 1.01569 16939
⋮ │ ⋮ ⋮
3992 │ 5.44318 8129
3993 │ 2.60681 5641
3994 │ 4.5168 9359
3995 │ 2.79248 14367
3996 │ 5.35347 11034
3997 │ 0.514723 3758
3998 │ 1.75363 16482
3983 rows omitted) … 39980×1 DataFrame
Row │ reco_pred
│ Float32
───────┼───────────
1 │ 2.95494
2 │ 2.9369
3 │ 2.91584
4 │ 2.89814
5 │ 2.81676
6 │ 2.72723
7 │ 2.69439
8 │ 2.69111
⋮ │ ⋮
39974 │ 2.08417
39975 │ 2.12563
39976 │ 2.14105
39977 │ 2.16433
39978 │ 2.15962
39979 │ 2.17626
39980 │ 2.1728
39965 rows omitted10. Train Single NN Hybrid Model (Optional)
For comparison, we can also train a hybrid model with a standard feed-forward neural network
hm = constructHybridModel(
predictors,
forcing,
target,
RbQ10,
parameters,
neural_param_names,
global_param_names,
hidden_layers = NN, # Neural network architecture
scale_nn_outputs = true, # Scale neural network outputs
input_batchnorm = false, # Apply batch normalization to inputs
);
Train the hybrid model
single_nn_out = train(
hm,
df,
();
nepochs = 100, # Number of training epochs
batchsize = 128, # Batch size for training
opt = RMSProp(0.01), # Optimizer and learning rate
monitor_names = [:rb, :Q10], # Parameters to monitor during training
yscale = identity, # Scaling for outputs
shuffleobs = true,
training_loss = :nseLoss,
loss_types = [:nse],
array_type = :DimArray,
plotting = false,
show_progress = false,
);
[ Info: Plotting disabled.
[ Info: Training data type: DimensionalData.DimMatrix{Float32, Tuple{DimensionalData.Dimensions.Dim{:variable, DimensionalData.Dimensions.Lookups.Categorical{Symbol, SubArray{Symbol, 1, Vector{Symbol}, Tuple{Base.Slice{Base.OneTo{Int64}}}, true}, DimensionalData.Dimensions.Lookups.Unordered, DimensionalData.Dimensions.Lookups.NoMetadata}}, DimensionalData.Dimensions.Dim{:batch_size, DimensionalData.Dimensions.Lookups.Sampled{Int64, SubArray{Int64, 1, UnitRange{Int64}, Tuple{Vector{Int64}}, false}, DimensionalData.Dimensions.Lookups.ForwardOrdered, DimensionalData.Dimensions.Lookups.Irregular{Tuple{Nothing, Nothing}}, DimensionalData.Dimensions.Lookups.Points, DimensionalData.Dimensions.Lookups.NoMetadata}}}, Tuple{}, SubArray{Float32, 2, Matrix{Float32}, Tuple{Base.Slice{Base.OneTo{Int64}}, Vector{Int64}}, false}, DimensionalData.NoName, DimensionalData.Dimensions.Lookups.NoMetadata}
[ Info: Check the saved output (.png, .mp4, .jld2) from training at: /home/runner/work/EasyHybrid.jl/EasyHybrid.jl/docs/build
[ Info: Returning best model from epoch 79 of 100 epochs with best validation loss wrt nse: 0.9974605
Close enough
out_lstm.best_loss
single_nn_out.best_loss
0.9974605f0