
LSTM Hybrid Model with EasyHybrid.jl

This tutorial demonstrates how to use EasyHybrid to train a hybrid model with an LSTM neural network on synthetic data, modeling respiration with Q10 temperature sensitivity. The code for this tutorial is in example_synthetic_lstm.jl under docs/src/literate/tutorials.

1. Load Packages

Activate the project environment, then load the packages:

julia
using EasyHybrid
using AxisKeys
using DimensionalData

2. Data Loading and Preprocessing

Load the synthetic dataset from GitHub; the data are tabular:

julia
df = load_timeseries_netcdf("https://github.com/bask0/q10hybrid/raw/master/data/Synthetic4BookChap.nc");

Select the first 20,000 rows for faster execution:

julia
df = df[1:20000, :];
first(df, 5);

3. Define Neural Network Architectures

Define a standard feedforward neural network

julia
NN = Chain(Dense(15, 15, Lux.sigmoid), Dense(15, 15, Lux.sigmoid), Dense(15, 1))
Chain(
    layer_(1-2) = Dense(15 => 15, σ),             # 480 (240 x 2) parameters
    layer_3 = Dense(15 => 1),                     # 16 parameters
)         # Total: 496 parameters,
          #        plus 0 states.

Define an LSTM-based neural network with memory

TIP

When the Chain ends with a Recurrence layer, EasyHybrid automatically adds a RecurrenceOutputDense layer to handle the sequence output format. The user only needs to define the Recurrence layer itself!

julia
NN_Memory = Chain(
    Recurrence(LSTMCell(15 => 15), return_sequence = true),
)
Chain(
    layer_1 = Recurrence(
        cell = LSTMCell(15 => 15),                # 1_920 parameters, plus 1 non-trainable
    ),
)         # Total: 1_920 parameters,
          #        plus 1 states.
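To see what "memory" means here, the recurrence inside LSTMCell(15 => 15) can be sketched in plain Julia. This is a hypothetical, minimal cell written for illustration only: TinyLSTMCell, lstm_step, and run_sequence are not Lux or EasyHybrid functions, and the gate order is this sketch's own choice. Assuming two bias vectors of length 4h, the parameter count 4h(n_in + h) + 2·4h gives 1920 for n_in = h = 15, which matches the 1_920 printed above.

```julia
using Random

# Logistic sigmoid used by the input, forget, and output gates.
sigmoid(x) = 1 / (1 + exp(-x))

# Hypothetical minimal LSTM cell (illustration only, not Lux's LSTMCell).
# The four gates are stacked row-wise in this order: input, forget, candidate, output.
struct TinyLSTMCell
    W_ih::Matrix{Float32}  # (4h × n_in) input weights
    W_hh::Matrix{Float32}  # (4h × h)    recurrent weights
    b_ih::Vector{Float32}  # (4h)        input bias
    b_hh::Vector{Float32}  # (4h)        recurrent bias
end

TinyLSTMCell(rng::AbstractRNG, n_in::Int, h::Int) = TinyLSTMCell(
    0.1f0 .* randn(rng, Float32, 4h, n_in), 0.1f0 .* randn(rng, Float32, 4h, h),
    zeros(Float32, 4h), zeros(Float32, 4h))

nparams(cell::TinyLSTMCell) = sum(length, (cell.W_ih, cell.W_hh, cell.b_ih, cell.b_hh))

# One time step: update the hidden state h and the cell state c (the memory).
function lstm_step(cell::TinyLSTMCell, x::AbstractVector, h::AbstractVector, c::AbstractVector)
    n = length(h)
    z = cell.W_ih * x .+ cell.b_ih .+ cell.W_hh * h .+ cell.b_hh
    i = sigmoid.(z[1:n])        # input gate: how much new information to write
    f = sigmoid.(z[n+1:2n])     # forget gate: how much old memory to keep
    g = tanh.(z[2n+1:3n])       # candidate memory
    o = sigmoid.(z[3n+1:4n])    # output gate
    c_new = f .* c .+ i .* g
    h_new = o .* tanh.(c_new)
    return h_new, c_new
end

# Run the cell over a (features × time) sequence and keep h at every step,
# similar in spirit to return_sequence = true.
function run_sequence(cell::TinyLSTMCell, xs::AbstractMatrix)
    h_dim = size(cell.W_hh, 2)
    h = zeros(Float32, h_dim)
    c = zeros(Float32, h_dim)
    hs = Matrix{Float32}(undef, h_dim, size(xs, 2))
    for t in axes(xs, 2)
        h, c = lstm_step(cell, view(xs, :, t), h, c)
        hs[:, t] = h
    end
    return hs
end

rng = Xoshiro(42)
cell = TinyLSTMCell(rng, 15, 15)
hs = run_sequence(cell, randn(rng, Float32, 15, 10))  # 15 features, 10 time steps
println(nparams(cell), " parameters; hidden states of size ", size(hs))
```

The forget gate is what lets information from early steps in the window persist in c and influence later outputs.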

4. Define the Process-Based Model

We use a classical Q10 model for respiration, reco = rb · Q10^((ta − tref)/10):

julia
"""
    RbQ10(; ta, Q10, rb, tref=15.0f0)

Respiration model with Q10 temperature sensitivity.

- `ta`: air temperature [°C]
- `Q10`: temperature sensitivity factor [-]
- `rb`: basal respiration rate [μmol/m²/s]
- `tref`: reference temperature [°C] (default: 15.0)
"""
function RbQ10(; ta, Q10, rb, tref = 15.0f0)
    reco = rb .* Q10 .^ (0.1f0 .* (ta .- tref))
    return (; reco, Q10, rb)
end
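A quick numerical check of the formula: at ta = tref the model returns rb, and every 10 °C increase multiplies respiration by Q10. The block repeats RbQ10 so that it runs without EasyHybrid; the temperatures and parameter values are arbitrary choices for illustration.

```julia
# Same Q10 model as above, repeated so this block runs on its own.
function RbQ10(; ta, Q10, rb, tref = 15.0f0)
    reco = rb .* Q10 .^ (0.1f0 .* (ta .- tref))
    return (; reco, Q10, rb)
end

ta = Float32[5, 15, 25]  # air temperatures [°C]: 10 °C below, at, and above tref
out = RbQ10(; ta = ta, Q10 = 2.0f0, rb = 3.0f0)

println(out.reco)                    # respiration at the three temperatures
println(out.reco[3] / out.reco[2])   # ratio for a 10 °C increase, equal to Q10
```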

5. Define Model Parameters

Each parameter is specified as a tuple (default, lower_bound, upper_bound):

julia
parameters = (
    rb = (3.0f0, 0.0f0, 13.0f0),  # Basal respiration [μmol/m²/s]
    Q10 = (2.0f0, 1.0f0, 4.0f0),   # Temperature sensitivity factor [-]
)
(rb = (3.0f0, 0.0f0, 13.0f0), Q10 = (2.0f0, 1.0f0, 4.0f0))
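The bounds keep parameter values physically plausible. One common way to enforce such bounds, sketched here as an assumption rather than EasyHybrid's documented internals, is to learn an unconstrained value z and map it into (lower, upper) with a sigmoid; the inverse map gives the z that reproduces the default. to_bounds and from_bounds are hypothetical helpers.

```julia
# Hypothetical helpers illustrating bounded parameters; not necessarily the
# transformation EasyHybrid applies internally.
sigmoid(z) = 1 / (1 + exp(-z))

# Map an unconstrained real z into (lower, upper).
to_bounds(z, lower, upper) = lower + (upper - lower) * sigmoid(z)

# Inverse map: the unconstrained value that lands exactly on p.
from_bounds(p, lower, upper) = log((p - lower) / (upper - p))

parameters = (
    rb  = (3.0f0, 0.0f0, 13.0f0),
    Q10 = (2.0f0, 1.0f0, 4.0f0),
)

for (name, (default, lower, upper)) in pairs(parameters)
    z0 = from_bounds(default, lower, upper)
    println(name, ": z0 = ", round(z0; digits = 3), " maps back to ", to_bounds(z0, lower, upper))
end
```

Whatever value z takes during training, to_bounds never leaves [lower, upper], so an optimizer can work in unconstrained space.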

6. Configure Hybrid Model Components

Define the input variables. Forcing variables (air temperature) are passed directly to the process-based model:

julia
forcing = [:ta]
1-element Vector{Symbol}:
 :ta

Predictor variables for the neural network (potential solar radiation and its derivative):

julia
predictors = [:sw_pot, :dsw_pot]
2-element Vector{Symbol}:
 :sw_pot
 :dsw_pot

Target variable (respiration)

julia
target = [:reco]
1-element Vector{Symbol}:
 :reco

Classify the parameters. Global parameters take a single value shared by all samples:

julia
global_param_names = [:Q10]
1-element Vector{Symbol}:
 :Q10

Parameters predicted by the neural network (they vary with the predictors):

julia
neural_param_names = [:rb]
1-element Vector{Symbol}:
 :rb
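The configuration above can be pictured as a data flow: the predictors feed the neural network, which outputs rb; Q10 is one learned value for all samples; and the forcing ta enters RbQ10 together with both parameters. The sketch imitates that flow on toy data. Its "network" is a fixed random linear layer standing in for the trained Chain, and to_bounds (a sigmoid rescaling into rb's bounds) is an assumption about how outputs are kept within bounds, not EasyHybrid's API.

```julia
using Random

sigmoid(z) = 1 / (1 + exp(-z))
to_bounds(z, lower, upper) = lower + (upper - lower) * sigmoid(z)  # hypothetical helper

function RbQ10(; ta, Q10, rb, tref = 15.0f0)
    reco = rb .* Q10 .^ (0.1f0 .* (ta .- tref))
    return (; reco, Q10, rb)
end

rng = Xoshiro(1)
n = 8                                            # samples in the batch
sw_pot  = 100f0 .* rand(rng, Float32, n)         # toy predictor values
dsw_pot = randn(rng, Float32, n)
ta      = 10f0 .+ 5f0 .* randn(rng, Float32, n)  # toy forcing [°C]

# Toy stand-in for the neural network: one fixed linear layer + tanh,
# mapping the 2 predictors to one raw output per sample.
W = 0.01f0 .* randn(rng, Float32, 1, 2)
z = vec(tanh.(W * vcat(sw_pot', dsw_pot')))
rb = to_bounds.(z, 0.0f0, 13.0f0)                # neural parameter, one per sample

Q10 = 2.0f0                                      # global parameter, shared by all samples
out = RbQ10(; ta = ta, Q10 = Q10, rb = rb)       # process-based model gives reco
println(out.reco)
```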

7. Construct LSTM Hybrid Model

Create the LSTM hybrid model with the unified constructor constructHybridModel:

julia
hlstm = constructHybridModel(
    predictors,
    forcing,
    target,
    RbQ10,
    parameters,
    neural_param_names,
    global_param_names,
    hidden_layers = NN_Memory,  # Neural network architecture (LSTM)
    scale_nn_outputs = true,    # Scale neural network outputs
    input_batchnorm = false     # No batch normalization of the inputs
)
Hybrid Model (Single NN)
Neural Network: 
  Chain(
      layer_1 = WrappedFunction(identity),
      layer_2 = Dense(2 => 15, tanh),               # 45 parameters
      layer_3 = Recurrence(
          cell = LSTMCell(15 => 15),                # 1_920 parameters, plus 1 non-trainable
      ),
      layer_4 = RecurrenceOutputDense(
          layer = Dense(15 => 15, tanh),            # 240 parameters
      ),
      layer_5 = Dense(15 => 1),                     # 16 parameters
  )         # Total: 2_221 parameters,
            #        plus 1 states.
Configuration:
  predictors = [:sw_pot, :dsw_pot]
  forcing = [:ta]
  targets = [:reco]
  mechanistic_model = RbQ10
  neural_param_names = [:rb]
  global_param_names = [:Q10]
  fixed_param_names = Symbol[]
  scale_nn_outputs = true
  start_from_default = true
  config = (; hidden_layers = Chain{@NamedTuple{layer_1::Recurrence{Static.True, LSTMCell{Static.False, Static.False, Int64, Int64, NTuple{4, Nothing}, NTuple{4, Nothing}, NTuple{4, Nothing}, typeof(WeightInitializers.zeros32), typeof(WeightInitializers.zeros32), Static.True}, Lux.BatchLastIndex}}, Nothing}((layer_1 = Recurrence{Static.True, LSTMCell{Static.False, Static.False, Int64, Int64, NTuple{4, Nothing}, NTuple{4, Nothing}, NTuple{4, Nothing}, typeof(WeightInitializers.zeros32), typeof(WeightInitializers.zeros32), Static.True}, Lux.BatchLastIndex}(LSTMCell(15 => 15), Lux.BatchLastIndex(), static(true), false, false),), nothing), activation = tanh, scale_nn_outputs = true, input_batchnorm = false, start_from_default = true,)

Parameters:
  Hybrid Parameters
    ┌─────┬─────────┬───────┬───────┐
    │     │ default │ lower │ upper │
    ├─────┼─────────┼───────┼───────┤
    │  rb │     3.0 │   0.0 │  13.0 │
    │ Q10 │     2.0 │   1.0 │   4.0 │
    └─────┴─────────┴───────┴───────┘
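The printed totals can be reproduced from the standard parameter-count formulas: a Dense(n_in => n_out) layer has n_in·n_out weights plus n_out biases. The LSTM formula assumes two bias vectors of length 4h, which is what reproduces Lux's 1_920. This also shows where the extra layers of the constructed model come from: the input layer, the RecurrenceOutputDense layer, and the final output layer. dense_params and lstm_params are illustrative helpers.

```julia
# Parameter counts (weights + biases) from standard formulas.
dense_params(n_in, n_out) = n_in * n_out + n_out
# LSTM: 4 gates with input and recurrent weights, plus two bias vectors of length 4h
# (an assumption that matches Lux's printed count).
lstm_params(n_in, h) = 4h * (n_in + h) + 2 * 4h

# Feedforward NN from Section 3: two Dense(15 => 15) layers and one Dense(15 => 1).
feedforward = 2 * dense_params(15, 15) + dense_params(15, 1)

hybrid = dense_params(2, 15) +   # layer_2: Dense(2 => 15), one input per predictor
         lstm_params(15, 15) +   # layer_3: Recurrence(LSTMCell(15 => 15))
         dense_params(15, 15) +  # layer_4: RecurrenceOutputDense(Dense(15 => 15))
         dense_params(15, 1)     # layer_5: Dense(15 => 1), one output for rb

println((feedforward, hybrid))   # (496, 2221)
```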

8. Data Preparation Steps (Demonstration)

The following steps demonstrate what happens under the hood during training. In practice, you can skip to Section 9 and use the train function directly.

Both :KeyedArray and :DimArray are supported as array types; this tutorial uses :DimArray.

julia
pref_array_type = :DimArray
x, y = prepare_data(hlstm, df, array_type = pref_array_type);

Convert a (single) time series into many training samples by windowing.

Each sample consists of:

  • input_window: number of past steps given to the model (sequence length)

  • output_window: number of steps to predict

  • output_shift: stride between consecutive windows (controls overlap)

  • lead_time: prediction lead (e.g. lead_time=1 predicts starting 1 step ahead)

This supports many-to-one or many-to-many forecasting, depending on output_window. The result is an array of shape (variable, time, batch_size), where variable indexes the features, time indexes the steps of the input window, and batch_size indexes the samples 1:n (the full batch).
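A plain-Julia sketch of this windowing, under assumed semantics chosen to match the outputs shown in this section: 19,991 samples from 20,000 rows, and targets aligned with the last input step when lead_time = 0 (the :x9_to_x0_y0 label). window_ranges and make_windows are hypothetical helpers, not EasyHybrid's split_into_sequences.

```julia
# Assumed semantics: window k uses inputs s:e with s = 1 + (k - 1) * output_shift
# and e = s + input_window - 1; its outputs start at e + lead_time.
function window_ranges(T; input_window, output_window, output_shift, lead_time)
    n = (T - input_window - lead_time - output_window + 1) ÷ output_shift + 1
    map(1:n) do k
        s = 1 + (k - 1) * output_shift
        e = s + input_window - 1
        (s:e, (e + lead_time):(e + lead_time + output_window - 1))
    end
end

# Stack windows of a (variable × time) matrix into (variable, time, batch_size),
# and the matching targets into (output_window, batch_size).
function make_windows(x::AbstractMatrix, y::AbstractVector; kw...)
    r = window_ranges(size(x, 2); kw...)
    iw = length(first(r)[1])
    ow = length(first(r)[2])
    xs = similar(x, size(x, 1), iw, length(r))
    ys = similar(y, ow, length(r))
    for (k, (in_r, out_r)) in enumerate(r)
        xs[:, :, k] = x[:, in_r]
        ys[:, k] = y[out_r]
    end
    return xs, ys
end

T = 20
x = Float32.(permutedims(hcat(1:T, 101:100+T)))  # toy data: 2 variables × T steps
y = Float32.(1:T)
xs, ys = make_windows(x, y; input_window = 10, output_window = 1, output_shift = 1, lead_time = 0)
println(size(xs), " ", size(ys))                 # (2, 10, 11) (1, 11)
```

With output_shift = 1, consecutive windows share all but one input step, which is what the shift test further down checks on the real data.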

julia
output_shift = 1
output_window = 1
input_window = 10
(xs, forcings_s), ys = split_into_sequences(x, y; input_window = input_window, output_window = output_window, output_shift = output_shift, lead_time = 0);

First input window (sample):

julia
xs[:, :, 1]
2×10 DimArray{Float32, 2}
├───────────────────────────┴─────────────────────────────────────── dims ┐
variable Categorical{Symbol} [:sw_pot, :dsw_pot] Unordered,
time Categorical{Symbol} [:x9_to_x9, …, :x9_to_x0_y0] ReverseOrdered
└─────────────────────────────────────────────────────────────────────────┘
           :x9_to_x9  :x9_to_x8  …  :x9_to_x1  :x9_to_x0_y0
  :sw_pot   109.817    109.817   …   109.817    109.817
  :dsw_pot  115.595    115.595   …   115.595    115.595

Second input window (sample):

julia
xs[:, :, 2]
2×10 DimArray{Float32, 2}
├───────────────────────────┴─────────────────────────────────────── dims ┐
variable Categorical{Symbol} [:sw_pot, :dsw_pot] Unordered,
time Categorical{Symbol} [:x9_to_x9, …, :x9_to_x0_y0] ReverseOrdered
└─────────────────────────────────────────────────────────────────────────┘
           :x9_to_x9  :x9_to_x8  …  :x9_to_x1  :x9_to_x0_y0
  :sw_pot   109.817    109.817   …   109.817    109.817
  :dsw_pot  115.595    115.595   …   115.595    115.595

Check the shift: time step output_shift + 1 of the first window equals time step 1 of the second window.

julia
xs[:, output_shift + 1, 1] == xs[:, 1, 2]
true

Output windows (targets). A time label such as :x30_to_x5_y4 indicates that memory accumulated from x30 to x5 is used to predict y4; here every sample carries the label :x9_to_x0_y0:

julia
ys.reco
1×19991 DimArray{Float32, 2}
├──────────────────────────────┴─────────────────────────────────────── dims ┐
time Categorical{Symbol} [:x9_to_x0_y0] ForwardOrdered,
batch_size Sampled{Int64} [1, …, 19991] ForwardOrdered Irregular Points
└────────────────────────────────────────────────────────────────────────────┘
                 1         2         3  …     19990     19991
  :x9_to_x0_y0  0.806911  0.803646  0.794573  …  0.211483  0.211106

Second output window (sample), followed by the windowed forcing (air temperature), with one column per sample:

julia
ys.reco[:, 2]

forcings_s.ta
10×19991 Matrix{Float32}:
 2.1   1.98  1.89  2.06  2.09   1.75  …  10.7   11.11  11.26  10.29   9.68
 1.98  1.89  2.06  2.09  1.75   1.51     11.11  11.26  10.29   9.68   9.22
 1.89  2.06  2.09  1.75  1.51   1.36     11.26  10.29   9.68   9.22   9.63
 2.06  2.09  1.75  1.51  1.36   1.11     10.29   9.68   9.22   9.63   9.93
 2.09  1.75  1.51  1.36  1.11   0.97      9.68   9.22   9.63   9.93  11.11
 1.75  1.51  1.36  1.11  0.97   0.87  …   9.22   9.63   9.93  11.11  11.06
 1.51  1.36  1.11  0.97  0.87   0.59      9.63   9.93  11.11  11.06  10.96
 1.36  1.11  0.97  0.87  0.59   0.23      9.93  11.11  11.06  10.96  10.55
 1.11  0.97  0.87  0.59  0.23   0.06     11.11  11.06  10.96  10.55  11.06
 0.97  0.87  0.59  0.23  0.06  -0.23     11.06  10.96  10.55  11.06  10.91

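Check whether any value of the first output window also appears in the second. Ideally the overlap is small; measured in time steps, the expected overlap is output_window - output_shift (here 0).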

julia
overlap = output_window - output_shift
overlap_length = sum(in(ys.reco[:, 1]), ys.reco[:, 2])
0
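The check above compares values, so coincidentally equal values could also be counted. An index-based version counts shared time steps directly, which should equal output_window - output_shift, clamped at zero when the stride exceeds the window. output_range and overlap are hypothetical helpers using the same assumed semantics as the windowing step (outputs of window k start (k - 1) * output_shift steps after those of window 1).

```julia
# Output time-step range of window k, starting at `start` for the first window.
output_range(k; start, output_window, output_shift) =
    (start + (k - 1) * output_shift):(start + (k - 1) * output_shift + output_window - 1)

# Number of time steps shared by the outputs of two consecutive windows.
overlap(; output_window, output_shift) = length(intersect(
    output_range(1; start = 10, output_window, output_shift),
    output_range(2; start = 10, output_window, output_shift)))

for (ow, sh) in [(1, 1), (5, 2), (3, 3), (2, 4)]
    println("output_window = $ow, output_shift = $sh: overlap ",
            overlap(; output_window = ow, output_shift = sh),
            ", max(ow - sh, 0) = ", max(ow - sh, 0))
end
```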

Split the windowed dataset into training and validation sets, the same way train does.

julia
sdf = split_data(df, hlstm, sequence_kwargs = (; input_window = input_window, output_window = output_window, output_shift = output_shift, lead_time = 0), array_type = pref_array_type);

((x_train, f_train), y_train), ((x_val, f_val), y_val) = sdf;
x_train
y_train
f_train
y_train_nan = map(v -> .!isnan.(v), y_train)  # mask: true where the observation is not NaN
(reco = Bool[1 1 … 1 1],)
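Despite its name, y_train_nan is true where an observation is present. A loss can use such a mask to ignore missing observations; the sketch uses a masked mean squared error on toy values. masked_mse is a hypothetical helper, and EasyHybrid's own loss handling is not shown here.

```julia
# Toy observations with a gap (NaN) and predictions of the same shape.
y_obs  = Float32[0.81, NaN, 0.79, 0.75]
y_pred = Float32[0.80, 0.70, 0.78, 0.77]

mask = .!isnan.(y_obs)  # true where an observation exists

# Mean squared error over valid entries only, so NaNs never enter the sum.
masked_mse(y, ŷ, m) = sum(abs2, (y .- ŷ)[m]) / count(m)

println(count(mask), " valid values, MSE = ", masked_mse(y_obs, y_pred, mask))
```

Without the mask, a single NaN would make the whole loss NaN and stop learning.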

Wrap the training windows/samples in a DataLoader to form batches.

WARNING

batchsize is the number of windows/samples used per gradient step to update the parameters. Processing 32 windows in one array is usually much faster than doing 32 separate forward/backward passes with batch_size=1.

julia
train_dl = EasyHybrid.DataLoader(((x_train, f_train), y_train); batchsize = 32);
Warning: Number of observations less than batch-size, decreasing the batch-size to 8
@ MLUtils ~/.julia/packages/MLUtils/5jDrc/src/batchview.jl:104
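Batching can be sketched as splitting the sample (last) dimension into chunks of batchsize windows. Shrinking the batch size when there are fewer samples than batchsize mirrors the warning above. batches is a hypothetical helper, not EasyHybrid.DataLoader; with 19,991 windows and batchsize = 128 it yields 156 full batches and one batch of 23.

```julia
using Random

# Hypothetical batching sketch: split the last dimension of a
# (variable, time, sample) array into chunks of `batchsize` samples.
function batches(x::AbstractArray{<:Any,3}; batchsize, shuffle = false, rng = Xoshiro(0))
    n = size(x, 3)
    bs = min(batchsize, n)                  # shrink the batch size if n < batchsize
    idx = shuffle ? randperm(rng, n) : collect(1:n)
    return [x[:, :, i] for i in Iterators.partition(idx, bs)]
end

x = rand(Float32, 2, 10, 19991)             # toy windows: 2 variables × 10 steps
bl = batches(x; batchsize = 128, shuffle = true)
println(length(bl), " batches; last batch has ", size(last(bl), 3), " samples")
```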

Run the hybrid model forward on the first batch:

julia
(x_first, forcings_first), y_first = first(train_dl)

ps, st = Lux.setup(Random.default_rng(), hlstm);
xf = (x_first, forcings_first)
frun = hlstm(xf, ps, st);
Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
@ LuxCore ~/.julia/packages/LuxCore/kQC9S/src/LuxCore.jl:18
Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
@ LuxCore ~/.julia/packages/LuxCore/kQC9S/src/LuxCore.jl:18

Extract the predicted respiration (ŷ):

julia
reco_mod = frun[1].reco
10×8 Matrix{Float32}:
 3.51847  3.48933  3.46763  3.50873  3.51603  3.43414  3.37748  3.34255
 3.75827  3.7349   3.77917  3.78704  3.69883  3.63781  3.60018  3.53833
 3.80997  3.85513  3.86316  3.77318  3.71093  3.67254  3.60945  3.57459
 3.88699  3.89508  3.80436  3.7416   3.7029   3.63928  3.60414  3.57924
 3.90866  3.81762  3.75464  3.7158   3.65197  3.6167   3.59172  3.52268
 3.82216  3.7591   3.72022  3.65631  3.621    3.59599  3.52687  3.43995
 3.75994  3.72105  3.65712  3.6218   3.59678  3.52765  3.44071  3.40041
 3.72041  3.65649  3.62118  3.59617  3.52705  3.44012  3.39983  3.33217
 3.65534  3.62004  3.59504  3.52594  3.43904  3.39876  3.33112  3.33112
 3.61875  3.59376  3.52468  3.43782  3.39755  3.32993  3.32993  3.32302

Extract the matching observations and mark which of them are not NaN:

julia
reco_obs = y_first.reco
reco_nan = .!isnan.(reco_obs);

Compute the loss on the training split:

julia
EasyHybrid.compute_loss(hlstm, ps, st, ((x_train, f_train), (y_train, y_train_nan)), logging = LoggingLoss(train_mode = true))
(7.0889f0, (st_nn = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (rng = TaskLocalRNG(),), layer_4 = NamedTuple(), layer_5 = NamedTuple()), fixed = NamedTuple()), NamedTuple())

9. Train LSTM Hybrid Model

julia
out_lstm = train(
    hlstm,
    df,
    ();
    nepochs = 100,           # Number of training epochs
    batchsize = 128,         # Batch size of training windows/samples
    opt = RMSProp(0.01),   # Optimizer and learning rate
    monitor_names = [:rb, :Q10], # Parameters to monitor during training
    yscale = identity,       # Scaling for outputs
    shuffleobs = true,
    training_loss = :nseLoss,
    loss_types = [:nse],
    sequence_kwargs = (; input_window = input_window, output_window = output_window, output_shift = output_shift, lead_time = 0),
    plotting = false,
    show_progress = false,
    input_batchnorm = false,
    array_type = pref_array_type,
    model_name = "RbQ10_synthetic_lstm"
);

out_lstm.val_obs_pred
2×2 DataFrame
 Row │ reco      reco_pred
     │ Float32   Float32
─────┼─────────────────────
   1 │ 0.803646   0.797879
   2 │ 0.806911   0.803522
10. Train Single NN Hybrid Model (Optional)

For comparison, we can also train a hybrid model with the standard feedforward neural network NN defined in Section 3.

julia
hm = constructHybridModel(
    predictors,
    forcing,
    target,
    RbQ10,
    parameters,
    neural_param_names,
    global_param_names,
    hidden_layers = NN,         # Neural network architecture (feedforward)
    scale_nn_outputs = true,    # Scale neural network outputs
    input_batchnorm = false,    # No batch normalization of the inputs
);

Train the single-NN hybrid model:

julia
single_nn_out = train(
    hm,
    df,
    ();
    nepochs = 100,           # Number of training epochs
    batchsize = 128,         # Batch size for training
    opt = RMSProp(0.01),   # Optimizer and learning rate
    monitor_names = [:rb, :Q10], # Parameters to monitor during training
    yscale = identity,       # Scaling for outputs
    shuffleobs = true,
    training_loss = :nseLoss,
    loss_types = [:nse],
    array_type = :DimArray,
    plotting = false,
    show_progress = false,
    model_name = "RbQ10_synthetic_single_nn"
);
Warning: training_loss :nseLoss is not in loss_types [:nse], it won't appear in plots
@ EasyHybrid ~/work/EasyHybrid.jl/EasyHybrid.jl/src/config/TrainingConfig.jl:118
[ Info: Plotting disabled.
[ Info: Training outputs will be saved to: /Users/runner/work/EasyHybrid.jl/EasyHybrid.jl/docs/build
[ Info: Returning best model from epoch 100 with validation loss: 0.99714035

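Compare the best losses of the two models; they should be close. Only the value of the last expression, single_nn_out.best_loss, is printed: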

julia
out_lstm.best_loss
single_nn_out.best_loss
0.99714035f0