LSTM Hybrid Model with EasyHybrid.jl
This tutorial demonstrates how to use EasyHybrid to train a hybrid model with LSTM neural networks on synthetic data for respiration modeling with Q10 temperature sensitivity. The code for this tutorial can be found in docs/src/literate/tutorials => example_synthetic_lstm.jl.
1. Load Packages
Load the required packages (activate the project environment first if needed)
using EasyHybrid
using AxisKeys
using DimensionalData
2. Data Loading and Preprocessing
Load the synthetic dataset from GitHub (tabular data)
df = load_timeseries_netcdf("https://github.com/bask0/q10hybrid/raw/master/data/Synthetic4BookChap.nc");
Select a subset of the data for faster execution
df = df[1:20000, :];
first(df, 5);
3. Define Neural Network Architectures
Define a standard feedforward neural network
NN = Chain(Dense(15, 15, Lux.sigmoid), Dense(15, 15, Lux.sigmoid), Dense(15, 1))
Chain(
layer_(1-2) = Dense(15 => 15, σ), # 480 (240 x 2) parameters
layer_3 = Dense(15 => 1), # 16 parameters
) # Total: 496 parameters,
# plus 0 states.
Define an LSTM-based neural network with memory
TIP
When the Chain ends with a Recurrence layer, EasyHybrid automatically adds a RecurrenceOutputDense layer to handle the sequence output format. The user only needs to define the Recurrence layer itself!
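As a sanity check on the layers that get assembled around the Recurrence layer, the parameter counts printed in the model summary of Section 7 can be reproduced by hand. The sketch below uses plain Julia only. The helper names `dense_params` and `lstm_params` are hypothetical, and the LSTM formula assumes four gates, each with input weights, hidden weights, and two bias vectors; that assumption is consistent with the 1_920 parameters reported for LSTMCell(15 => 15).

```julia
# Hypothetical helpers for counting trainable parameters (not part of EasyHybrid or Lux).

# A Dense layer has an (n_out x n_in) weight matrix plus one bias vector.
dense_params(n_in, n_out) = n_in * n_out + n_out

# Assumption: 4 gates, each with input weights, hidden weights,
# and two bias vectors (input path and hidden path).
lstm_params(n_in, n_hidden) = 4 * (n_in * n_hidden + n_hidden * n_hidden + 2 * n_hidden)

lstm_count = lstm_params(15, 15)

# Layers as listed in the hybrid model summary:
# Dense(2 => 15), LSTMCell(15 => 15), Dense(15 => 15), Dense(15 => 1)
total_count = dense_params(2, 15) + lstm_count + dense_params(15, 15) + dense_params(15, 1)

println((lstm_count, total_count))
```

Under these assumptions the two numbers match the 1_920 and 2_221 parameters shown in the printed summaries.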
NN_Memory = Chain(
Recurrence(LSTMCell(15 => 15), return_sequence = true),
)
Chain(
layer_1 = Recurrence(
cell = LSTMCell(15 => 15), # 1_920 parameters, plus 1 non-trainable
),
) # Total: 1_920 parameters,
# plus 1 states.
4. Define the Process-Based Model
We define the process-based model, a classical Q10 model for respiration.
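Written as an equation, the function defined in this section computes the following, where the symbols $R_{eco}$, $r_b$, $T_a$, and $T_{ref}$ are shorthand introduced here for `reco`, `rb`, `ta`, and `tref`:

```latex
R_{eco} = r_b \cdot Q_{10}^{\,0.1\,(T_a - T_{ref})}
```

At $T_a = T_{ref}$ the exponent is zero and respiration equals the basal rate $r_b$; for every 10 °C increase in air temperature, respiration is multiplied by $Q_{10}$.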
"""
RbQ10(; ta, Q10, rb, tref=15.0f0)
Respiration model with Q10 temperature sensitivity.
- `ta`: air temperature [°C]
- `Q10`: temperature sensitivity factor [-]
- `rb`: basal respiration rate [μmol/m²/s]
- `tref`: reference temperature [°C] (default: 15.0)
"""
function RbQ10(; ta, Q10, rb, tref = 15.0f0)
reco = rb .* Q10 .^ (0.1f0 .* (ta .- tref))
return (; reco, Q10, rb)
end
5. Define Model Parameters
Parameter specification: (default, lower_bound, upper_bound)
parameters = (
rb = (3.0f0, 0.0f0, 13.0f0), # Basal respiration [μmol/m²/s]
Q10 = (2.0f0, 1.0f0, 4.0f0), # Temperature sensitivity factor [-]
)
(rb = (3.0f0, 0.0f0, 13.0f0), Q10 = (2.0f0, 1.0f0, 4.0f0))
6. Configure Hybrid Model Components
Define input variables
Forcing variables (temperature)
forcing = [:ta]
1-element Vector{Symbol}:
:ta
Predictor variables (solar radiation and its derivative)
predictors = [:sw_pot, :dsw_pot]
2-element Vector{Symbol}:
:sw_pot
:dsw_pot
Target variable (respiration)
target = [:reco]
1-element Vector{Symbol}:
:reco
Parameter classification
Global parameters (same for all samples)
global_param_names = [:Q10]
1-element Vector{Symbol}:
:Q10
Neural network predicted parameters
neural_param_names = [:rb]
1-element Vector{Symbol}:
:rb
7. Construct LSTM Hybrid Model
Create LSTM hybrid model using the unified constructor
hlstm = constructHybridModel(
predictors,
forcing,
target,
RbQ10,
parameters,
neural_param_names,
global_param_names,
hidden_layers = NN_Memory, # Neural network architecture
scale_nn_outputs = true, # Scale neural network outputs
input_batchnorm = false # Apply batch normalization to inputs
)
Hybrid Model (Single NN)
Neural Network:
Chain(
layer_1 = WrappedFunction(identity),
layer_2 = Dense(2 => 15, tanh), # 45 parameters
layer_3 = Recurrence(
cell = LSTMCell(15 => 15), # 1_920 parameters, plus 1 non-trainable
),
layer_4 = RecurrenceOutputDense(
layer = Dense(15 => 15, tanh), # 240 parameters
),
layer_5 = Dense(15 => 1), # 16 parameters
) # Total: 2_221 parameters,
# plus 1 states.
Configuration:
predictors = [:sw_pot, :dsw_pot]
forcing = [:ta]
targets = [:reco]
mechanistic_model = RbQ10
neural_param_names = [:rb]
global_param_names = [:Q10]
fixed_param_names = Symbol[]
scale_nn_outputs = true
start_from_default = true
config = (; hidden_layers = Chain{@NamedTuple{layer_1::Recurrence{Static.True, LSTMCell{Static.False, Static.False, Int64, Int64, NTuple{4, Nothing}, NTuple{4, Nothing}, NTuple{4, Nothing}, typeof(WeightInitializers.zeros32), typeof(WeightInitializers.zeros32), Static.True}, Lux.BatchLastIndex}}, Nothing}((layer_1 = Recurrence{Static.True, LSTMCell{Static.False, Static.False, Int64, Int64, NTuple{4, Nothing}, NTuple{4, Nothing}, NTuple{4, Nothing}, typeof(WeightInitializers.zeros32), typeof(WeightInitializers.zeros32), Static.True}, Lux.BatchLastIndex}(LSTMCell(15 => 15), Lux.BatchLastIndex(), static(true), false, false),), nothing), activation = tanh, scale_nn_outputs = true, input_batchnorm = false, start_from_default = true,)
Parameters:
Hybrid Parameters
┌─────┬─────────┬───────┬───────┐
│ │ default │ lower │ upper │
├─────┼─────────┼───────┼───────┤
│ rb │ 3.0 │ 0.0 │ 13.0 │
│ Q10 │ 2.0 │ 1.0 │ 4.0 │
└─────┴─────────┴───────┴───────┘
8. Data Preparation Steps (Demonstration)
The following steps demonstrate what happens under the hood during training. In practice, you can skip to Section 9 and use the train function directly.
Both :KeyedArray and :DimArray are supported as array types
pref_array_type = :DimArray
x, y = prepare_data(hlstm, df, array_type = pref_array_type);
Convert a (single) time series into many training samples by windowing.
Each sample consists of:
- input_window: number of past steps given to the model (sequence length)
- output_window: number of steps to predict
- output_shift: stride between consecutive windows (controls overlap)
- lead_time: prediction lead (e.g. lead_time=1 predicts starting 1 step ahead)
This supports many-to-one or many-to-many forecasting, depending on output_window. The result is an array of shape (variable, time, batch_size), where variable is the feature dimension, time is the input window, and batch_size indexes the samples 1:n (= full batch).
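The windowing idea can be illustrated with a minimal sketch in plain Julia. The helper `sliding_windows` is hypothetical and is not how `split_into_sequences` is implemented; it only shows how input_window and output_shift determine the number of samples and their overlap for a single variable.

```julia
# Hypothetical helper (not part of EasyHybrid): slice a vector into overlapping windows.
function sliding_windows(series::AbstractVector; input_window::Int, output_shift::Int = 1)
    # Start index of every window; the last window must still fit into the series.
    starts = 1:output_shift:(length(series) - input_window + 1)
    # One column per window, i.e. a (time, sample) layout.
    return reduce(hcat, [series[s:(s + input_window - 1)] for s in starts])
end

series = collect(1.0:12.0)
windows = sliding_windows(series; input_window = 10, output_shift = 1)

println(size(windows))                       # 12 - 10 + 1 = 3 windows of length 10
println(windows[2, 1] == windows[1, 2])      # consecutive windows are shifted by one step
```

By the same arithmetic, the 20000 rows selected above with input_window = 10 and output_shift = 1 yield 20000 - 10 + 1 = 19991 samples, which matches the batch_size dimension printed for ys.reco.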
output_shift = 1
output_window = 1
input_window = 10
(xs, forcings_s), ys = split_into_sequences(x, y; input_window = input_window, output_window = output_window, output_shift = output_shift, lead_time = 0);
First input_window/sample
xs[:, :, 1]
┌ 2×10 DimArray{Float32, 2} ┐
├───────────────────────────┴─────────────────────────────────────── dims ┐
↓ variable Categorical{Symbol} [:sw_pot, :dsw_pot] Unordered,
→ time Categorical{Symbol} [:x9_to_x9, …, :x9_to_x0_y0] ReverseOrdered
└─────────────────────────────────────────────────────────────────────────┘
↓ → :x9_to_x9 :x9_to_x8 … :x9_to_x1 :x9_to_x0_y0
:sw_pot 109.817 109.817 109.817 109.817
:dsw_pot 115.595 115.595 115.595 115.595
Second input_window/sample
xs[:, :, 2]
┌ 2×10 DimArray{Float32, 2} ┐
├───────────────────────────┴─────────────────────────────────────── dims ┐
↓ variable Categorical{Symbol} [:sw_pot, :dsw_pot] Unordered,
→ time Categorical{Symbol} [:x9_to_x9, …, :x9_to_x0_y0] ReverseOrdered
└─────────────────────────────────────────────────────────────────────────┘
↓ → :x9_to_x9 :x9_to_x8 … :x9_to_x1 :x9_to_x0_y0
:sw_pot 109.817 109.817 109.817 109.817
:dsw_pot 115.595 115.595 115.595 115.595
Test of the shift
xs[:, output_shift + 1, 1] == xs[:, 1, 2]
true
First output_window/sample. A time label such as :x30_to_x5_y4 indicates an accumulation of memory from x30 to x5 for the prediction of y4.
ys.reco
┌ 1×19991 DimArray{Float32, 2} ┐
├──────────────────────────────┴─────────────────────────────────────── dims ┐
↓ time Categorical{Symbol} [:x9_to_x0_y0] ForwardOrdered,
→ batch_size Sampled{Int64} [1, …, 19991] ForwardOrdered Irregular Points
└────────────────────────────────────────────────────────────────────────────┘
↓ → 1 2 3 … 19990 19991
:x9_to_x0_y0 0.806911 0.803646 0.794573 0.211483 0.211106
Second output_window/sample
ys.reco[:, 2]
forcings_s.ta
10×19991 Matrix{Float32}:
2.1 1.98 1.89 2.06 2.09 1.75 … 10.7 11.11 11.26 10.29 9.68
1.98 1.89 2.06 2.09 1.75 1.51 11.11 11.26 10.29 9.68 9.22
1.89 2.06 2.09 1.75 1.51 1.36 11.26 10.29 9.68 9.22 9.63
2.06 2.09 1.75 1.51 1.36 1.11 10.29 9.68 9.22 9.63 9.93
2.09 1.75 1.51 1.36 1.11 0.97 9.68 9.22 9.63 9.93 11.11
1.75 1.51 1.36 1.11 0.97 0.87 … 9.22 9.63 9.93 11.11 11.06
1.51 1.36 1.11 0.97 0.87 0.59 9.63 9.93 11.11 11.06 10.96
1.36 1.11 0.97 0.87 0.59 0.23 9.93 11.11 11.06 10.96 10.55
1.11 0.97 0.87 0.59 0.23 0.06 11.11 11.06 10.96 10.55 11.06
0.97 0.87 0.59 0.23 0.06 -0.23 11.06 10.96 10.55 11.06 10.91
Does any value of the first output_window also appear in the second output_window? Ideally the overlap is small.
overlap = output_window - output_shift
overlap_length = sum(in(ys.reco[:, 1]), ys.reco[:, 2])
0
Split the (windowed) dataset into train/validation in the same way as train does.
sdf = split_data(df, hlstm, sequence_kwargs = (; input_window = input_window, output_window = output_window, output_shift = output_shift, lead_time = 0), array_type = pref_array_type);
((x_train, f_train), y_train), ((x_val, f_val), y_val) = sdf;
x_train
y_train
f_train
y_train_nan = map(v -> .!isnan.(v), y_train)
(reco = Bool[1 1 … 1 1],)
Wrap the training windows/samples in a DataLoader to form batches.
WARNING
batchsize is the number of windows/samples used per gradient step to update the parameters. Processing 32 windows in one array is usually much faster than doing 32 separate forward/backward passes with batchsize = 1.
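How a set of windows is divided into batches can be sketched in plain Julia. This is only an illustration of the index bookkeeping with made-up numbers (`n_windows = 100` is an assumption for the example), not the DataLoader implementation.

```julia
# Illustrative numbers, not taken from the tutorial data.
n_windows = 100
batchsize = 32

# Partition the window indices into consecutive chunks of at most `batchsize`.
batches = collect(Iterators.partition(1:n_windows, batchsize))

println(length(batches))     # number of gradient steps per epoch
println(length.(batches))    # the last batch holds the remainder
```

With these numbers one epoch consists of four gradient steps, three on full batches of 32 windows and one on the remaining 4.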
train_dl = EasyHybrid.DataLoader(((x_train, f_train), y_train); batchsize = 32);
┌ Warning: Number of observations less than batch-size, decreasing the batch-size to 8
└ @ MLUtils ~/.julia/packages/MLUtils/5jDrc/src/batchview.jl:104
Run the hybrid model forward
(x_first, forcings_first), y_first = first(train_dl)
ps, st = Lux.setup(Random.default_rng(), hlstm);
xf = (x_first, forcings_first)
frun = hlstm(xf, ps, st);
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore ~/.julia/packages/LuxCore/kQC9S/src/LuxCore.jl:18
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore ~/.julia/packages/LuxCore/kQC9S/src/LuxCore.jl:18
Extract the predicted yhat
reco_mod = frun[1].reco
10×8 Matrix{Float32}:
3.51847 3.48933 3.46763 3.50873 3.51603 3.43414 3.37748 3.34255
3.75827 3.7349 3.77917 3.78704 3.69883 3.63781 3.60018 3.53833
3.80997 3.85513 3.86316 3.77318 3.71093 3.67254 3.60945 3.57459
3.88699 3.89508 3.80436 3.7416 3.7029 3.63928 3.60414 3.57924
3.90866 3.81762 3.75464 3.7158 3.65197 3.6167 3.59172 3.52268
3.82216 3.7591 3.72022 3.65631 3.621 3.59599 3.52687 3.43995
3.75994 3.72105 3.65712 3.6218 3.59678 3.52765 3.44071 3.40041
3.72041 3.65649 3.62118 3.59617 3.52705 3.44012 3.39983 3.33217
3.65534 3.62004 3.59504 3.52594 3.43904 3.39876 3.33112 3.33112
3.61875 3.59376 3.52468 3.43782 3.39755 3.32993 3.32993 3.32302
Bring the observations into the same shape
reco_obs = y_first.reco
reco_nan = .!isnan.(reco_obs);
Compute loss
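The boolean mask computed above marks the valid (non-NaN) observations. A minimal sketch of how such a mask can enter a loss is given here in plain Julia; `masked_mse` is a hypothetical helper with made-up numbers, and it is not claimed to be what `compute_loss` does internally.

```julia
# Hypothetical helper (not EasyHybrid's implementation):
# mean squared error over the entries where the mask is true.
masked_mse(pred, obs, mask) = sum(abs2, pred[mask] .- obs[mask]) / count(mask)

obs  = Float32[1.0, NaN, 3.0]    # the second observation is missing
pred = Float32[1.5, 2.0, 2.0]
mask = .!isnan.(obs)             # same construction as reco_nan

println(masked_mse(pred, obs, mask))   # (0.5^2 + 1.0^2) / 2
```

Without the mask, the single NaN would propagate through the sum and make the whole loss NaN.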
EasyHybrid.compute_loss(hlstm, ps, st, ((x_train, f_train), (y_train, y_train_nan)), logging = LoggingLoss(train_mode = true))
(7.0889f0, (st_nn = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (rng = TaskLocalRNG(),), layer_4 = NamedTuple(), layer_5 = NamedTuple()), fixed = NamedTuple()), NamedTuple())
9. Train LSTM Hybrid Model
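The training call in this section selects an NSE-based loss. As a reminder, the Nash-Sutcliffe efficiency (NSE) compares the squared prediction error with the variance of the observations. The sketch below is plain Julia with a hypothetical helper `nse`; how EasyHybrid converts NSE into a quantity to minimize is not shown in this tutorial and is not assumed here.

```julia
# Hypothetical helper (not EasyHybrid's implementation): Nash-Sutcliffe efficiency.
function nse(pred, obs)
    obs_mean = sum(obs) / length(obs)
    return 1 - sum(abs2, obs .- pred) / sum(abs2, obs .- obs_mean)
end

obs = [1.0, 2.0, 3.0, 4.0]

println(nse(obs, obs))             # perfect prediction
println(nse(fill(2.5, 4), obs))    # always predicting the observed mean
```

A perfect prediction gives an NSE of 1, while always predicting the mean of the observations gives 0; worse predictions give negative values.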
out_lstm = train(
hlstm,
df,
();
nepochs = 100, # Number of training epochs
batchsize = 128, # Batch size of training windows/samples
opt = RMSProp(0.01), # Optimizer and learning rate
monitor_names = [:rb, :Q10], # Parameters to monitor during training
yscale = identity, # Scaling for outputs
shuffleobs = true,
training_loss = :nseLoss,
loss_types = [:nse],
sequence_kwargs = (; input_window = input_window, output_window = output_window, output_shift = output_shift, lead_time = 0),
plotting = false,
show_progress = false,
input_batchnorm = false,
array_type = pref_array_type,
model_name = "RbQ10_synthetic_lstm"
);
out_lstm.val_obs_pred
| Row | reco | reco_pred |
|---|---|---|
| | Float32 | Float32 |
| 1 | 0.803646 | 0.797879 |
| 2 | 0.806911 | 0.803522 |
10. Train Single NN Hybrid Model (Optional)
For comparison, we can also train a hybrid model with a standard feedforward neural network.
hm = constructHybridModel(
predictors,
forcing,
target,
RbQ10,
parameters,
neural_param_names,
global_param_names,
hidden_layers = NN, # Neural network architecture
scale_nn_outputs = true, # Scale neural network outputs
input_batchnorm = false, # Apply batch normalization to inputs
);
Train the hybrid model
single_nn_out = train(
hm,
df,
();
nepochs = 100, # Number of training epochs
batchsize = 128, # Batch size for training
opt = RMSProp(0.01), # Optimizer and learning rate
monitor_names = [:rb, :Q10], # Parameters to monitor during training
yscale = identity, # Scaling for outputs
shuffleobs = true,
training_loss = :nseLoss,
loss_types = [:nse],
array_type = :DimArray,
plotting = false,
show_progress = false,
model_name = "RbQ10_synthetic_single_nn"
);
┌ Warning: training_loss :nseLoss is not in loss_types [:nse], it won't appear in plots
└ @ EasyHybrid ~/work/EasyHybrid.jl/EasyHybrid.jl/src/config/TrainingConfig.jl:118
[ Info: Plotting disabled.
[ Info: Training outputs will be saved to: /Users/runner/work/EasyHybrid.jl/EasyHybrid.jl/docs/build
[ Info: Returning best model from epoch 100 with validation loss: 0.99714035
Close enough
out_lstm.best_loss
single_nn_out.best_loss
0.99714035f0