
LSTM Hybrid Model with EasyHybrid.jl

This tutorial demonstrates how to use EasyHybrid to train a hybrid model with an LSTM neural network on synthetic data, modeling respiration with Q10 temperature sensitivity. The code for this tutorial can be found in docs/src/literate/tutorials/example_synthetic_lstm.jl.

1. Load Packages

Load the required packages

julia
using EasyHybrid
using AxisKeys
using DimensionalData
using Lux

2. Data Loading and Preprocessing

Load the synthetic dataset from GitHub; the NetCDF file is read into tabular data (a DataFrame)

julia
df = load_timeseries_netcdf("https://github.com/bask0/q10hybrid/raw/master/data/Synthetic4BookChap.nc");

Select a subset of data for faster execution

julia
df = df[1:20000, :];
first(df, 5);

3. Define Neural Network Architectures

Define a standard feedforward neural network

julia
NN = Chain(Dense(15, 15, Lux.sigmoid), Dense(15, 15, Lux.sigmoid), Dense(15, 1))
Chain(
    layer_(1-2) = Dense(15 => 15, σ),             # 480 (240 x 2) parameters
    layer_3 = Dense(15 => 1),                     # 16 parameters
)         # Total: 496 parameters,
          #        plus 0 states.

Define LSTM-based neural network with memory

TIP

When the Chain ends with a Recurrence layer, EasyHybrid automatically adds a RecurrenceOutputDense layer to handle the sequence output format. The user only needs to define the Recurrence layer itself!

julia
NN_Memory = Chain(
    Recurrence(LSTMCell(15 => 15), return_sequence = true),
)
Chain(
    layer_1 = Recurrence(
        cell = LSTMCell(15 => 15),                # 1_920 parameters, plus 1 non-trainable
    ),
)         # Total: 1_920 parameters,
          #        plus 1 states.

4. Define the Process-Based Model: A Classical Q10 Respiration Model

julia
"""
    RbQ10(; ta, Q10, rb, tref=15.0f0)

Respiration model with Q10 temperature sensitivity.

- `ta`: air temperature [°C]
- `Q10`: temperature sensitivity factor [-]
- `rb`: basal respiration rate [μmol/m²/s]
- `tref`: reference temperature [°C] (default: 15.0)
"""
function RbQ10(; ta, Q10, rb, tref = 15.0f0)
    reco = rb .* Q10 .^ (0.1f0 .* (ta .- tref))
    return (; reco, Q10, rb)
end
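As a quick sanity check with hand-picked values (illustrative, not taken from the dataset): by construction, Q10 is the factor by which respiration changes per 10 °C of warming.

```julia
# Illustrative check of the Q10 formula with hand-picked values
rb, Q10, tref = 3.0f0, 2.0f0, 15.0f0
reco(ta) = rb * Q10^(0.1f0 * (ta - tref))

reco(15.0f0)  # at the reference temperature: reco == rb == 3.0
reco(25.0f0)  # 10 °C warmer: reco == rb * Q10 == 6.0
```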

5. Define Model Parameters

Parameter specification: (default, lower_bound, upper_bound)

julia
parameters = (
    rb = (3.0f0, 0.0f0, 13.0f0),  # Basal respiration [μmol/m²/s]
    Q10 = (2.0f0, 1.0f0, 4.0f0),   # Temperature sensitivity factor [-]
)
(rb = (3.0f0, 0.0f0, 13.0f0), Q10 = (2.0f0, 1.0f0, 4.0f0))

6. Configure Hybrid Model Components

Define the input variables. Forcing variables (air temperature):

julia
forcing = [:ta]
1-element Vector{Symbol}:
 :ta

Predictor variables (potential shortwave radiation and its derivative):

julia
predictors = [:sw_pot, :dsw_pot]
2-element Vector{Symbol}:
 :sw_pot
 :dsw_pot

Target variable (respiration)

julia
target = [:reco]
1-element Vector{Symbol}:
 :reco

Parameter classification. Global parameters (shared across all samples):

julia
global_param_names = [:Q10]
1-element Vector{Symbol}:
 :Q10

Neural-network-predicted parameters:

julia
neural_param_names = [:rb]
1-element Vector{Symbol}:
 :rb

7. Construct LSTM Hybrid Model

Create LSTM hybrid model using the unified constructor

julia
hlstm = constructHybridModel(
    predictors,
    forcing,
    target,
    RbQ10,
    parameters,
    neural_param_names,
    global_param_names,
    hidden_layers = NN_Memory, # Neural network architecture
    scale_nn_outputs = true, # Scale neural network outputs
    input_batchnorm = false   # Do not apply batch normalization to inputs
)
Hybrid Model (Single NN)
Neural Network: 
  Chain(
      layer_1 = WrappedFunction(identity),
      layer_2 = Dense(2 => 15, tanh),               # 45 parameters
      layer_3 = Recurrence(
          cell = LSTMCell(15 => 15),                # 1_920 parameters, plus 1 non-trainable
      ),
      layer_4 = RecurrenceOutputDense(
          layer = Dense(15 => 15, tanh),            # 240 parameters
      ),
      layer_5 = Dense(15 => 1),                     # 16 parameters
  )         # Total: 2_221 parameters,
            #        plus 1 states.
Configuration:
  predictors = [:sw_pot, :dsw_pot]
  forcing = [:ta]
  targets = [:reco]
  mechanistic_model = RbQ10
  neural_param_names = [:rb]
  global_param_names = [:Q10]
  fixed_param_names = Symbol[]
  scale_nn_outputs = true
  start_from_default = true
  config = (; hidden_layers = Chain{@NamedTuple{layer_1::Lux.Recurrence{Static.True, Lux.LSTMCell{Static.False, Static.False, Int64, Int64, NTuple{4, Nothing}, NTuple{4, Nothing}, NTuple{4, Nothing}, typeof(WeightInitializers.zeros32), typeof(WeightInitializers.zeros32), Static.True}, Lux.BatchLastIndex}}, Nothing}((layer_1 = Lux.Recurrence{Static.True, Lux.LSTMCell{Static.False, Static.False, Int64, Int64, NTuple{4, Nothing}, NTuple{4, Nothing}, NTuple{4, Nothing}, typeof(WeightInitializers.zeros32), typeof(WeightInitializers.zeros32), Static.True}, Lux.BatchLastIndex}(LSTMCell(15 => 15), Lux.BatchLastIndex(), static(true), false, false),), nothing), activation = tanh, scale_nn_outputs = true, input_batchnorm = false, start_from_default = true,)

Parameters:
  Hybrid Parameters
    ┌─────┬─────────┬───────┬───────┐
    │     │ default │ lower │ upper │
    ├─────┼─────────┼───────┼───────┤
    │  rb │     3.0 │   0.0 │  13.0 │
    │ Q10 │     2.0 │   1.0 │   4.0 │
    └─────┴─────────┴───────┴───────┘
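Schematically, the forward pass of this hybrid model combines the two parameter kinds as follows (an illustrative sketch, not EasyHybrid's internals; `hybrid_forward` and `toy_nn` are made-up names): the neural network maps the predictors to a per-sample rb, while Q10 is one learnable scalar shared by all samples, and both are passed to RbQ10.

```julia
# Schematic of the hybrid forward pass (illustration only, not EasyHybrid code)
function hybrid_forward(nn, Q10_global, predictors, ta; tref = 15.0f0)
    rb = nn(predictors)                                 # per-sample, NN-predicted
    reco = rb .* Q10_global .^ (0.1f0 .* (ta .- tref))  # mechanistic Q10 model
    return reco
end

# Toy "network" that always returns rb = 3 (shape 1 × n_samples)
toy_nn(x) = fill(3.0f0, 1, size(x, 2))

hybrid_forward(toy_nn, 2.0f0, rand(Float32, 2, 4), fill(25.0f0, 1, 4))  # 1×4, all 6.0
```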

8. Data Preparation Steps (Demonstration)

The following steps demonstrate what happens under the hood during training. In practice, you can skip to Section 9 and use the train function directly.

:KeyedArray and :DimArray are supported

julia
pref_array_type = :DimArray
x, y = prepare_data(hlstm, df, array_type = pref_array_type);

Convert a (single) time series into many training samples by windowing.

Each sample consists of:

  • input_window: number of past steps given to the model (sequence length)

  • output_window: number of steps to predict

  • output_shift: stride between consecutive windows (controls overlap)

  • lead_time: prediction lead (e.g. lead_time=1 predicts starting 1 step ahead)

This supports many-to-one / many-to-many forecasting, depending on output_window. It creates an array of shape (variable, time, batch_size), with variable being the feature dimension, time the input window, and batch_size the samples 1:n (= full batch).
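The number of windows obtained can be sketched as follows (a simplified illustration, not the actual split_into_sequences logic; it assumes the output window ends at the last input step when lead_time = 0, consistent with labels like :x9_to_x0_y0):

```julia
# Simplified count of windows produced from a series of n_steps
# (assumes the output window is aligned with the end of the input window
#  when lead_time = 0)
function count_windows(n_steps; input_window, output_shift = 1, lead_time = 0)
    last_start = n_steps - input_window - lead_time + 1
    return length(1:output_shift:last_start)
end

count_windows(20000; input_window = 10)  # 19991 windows from the 20000-step subset
```

Under these assumptions, 19991 windows match the 15993 training plus 3998 validation samples that appear in the split further below.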

julia
output_shift = 1
output_window = 1
input_window = 10
xs, ys = split_into_sequences(x, y; input_window = input_window, output_window = output_window, output_shift = output_shift, lead_time = 0);
ys_nan = .!isnan.(ys);

First input_window/sample

julia
xs[:, :, 1]
3×10 DimArray{Float32, 2}
├───────────────────────────┴─────────────────────────────────────── dims ┐
variable Categorical{Symbol} [:sw_pot, …, :ta] Unordered,
time Categorical{Symbol} [:x9_to_x9, …, :x9_to_x0_y0] ReverseOrdered
└─────────────────────────────────────────────────────────────────────────┘
           :x9_to_x9     :x9_to_x8   …   :x9_to_x1     :x9_to_x0_y0
  :sw_pot   109.817       109.817         109.817       109.817
  :dsw_pot  115.595       115.595         115.595       115.595
  :ta         2.1           1.98            1.11          0.97

Second input_window/sample

julia
xs[:, :, 2]
3×10 DimArray{Float32, 2}
├───────────────────────────┴─────────────────────────────────────── dims ┐
variable Categorical{Symbol} [:sw_pot, …, :ta] Unordered,
time Categorical{Symbol} [:x9_to_x9, …, :x9_to_x0_y0] ReverseOrdered
└─────────────────────────────────────────────────────────────────────────┘
           :x9_to_x9     :x9_to_x8   …   :x9_to_x1     :x9_to_x0_y0
  :sw_pot   109.817       109.817         109.817       109.817
  :dsw_pot  115.595       115.595         115.595       115.595
  :ta         1.98          1.89            0.97          0.87

Test of the shift: the second sample's window starts output_shift steps after the first.

julia
xs[:, output_shift + 1, 1] == xs[:, 1, 2]
true

First output_window/sample. A time label like :x30_to_x5_y4 indicates that memory is accumulated from x30 to x5 for the prediction of y4.

julia
ys[:, :, 1]
1×1 DimArray{Float32, 2}
├──────────────────────────┴────────────────────────── dims ┐
variable Categorical{Symbol} [:reco] Unordered,
time Categorical{Symbol} [:x9_to_x0_y0] ForwardOrdered
└───────────────────────────────────────────────────────────┘
      :x9_to_x0_y0
  :reco  0.806911

Second output_window/sample

julia
ys[:, :, 2]
1×1 DimArray{Float32, 2}
├──────────────────────────┴────────────────────────── dims ┐
variable Categorical{Symbol} [:reco] Unordered,
time Categorical{Symbol} [:x9_to_x0_y0] ForwardOrdered
└───────────────────────────────────────────────────────────┘
      :x9_to_x0_y0
  :reco  0.803646

Is any element of the first output window also in the second output window? Ideally the overlap is small.

julia
overlap = output_window - output_shift
overlap_length = sum(in(ys[:, :, 1]), ys[:, :, 2])
0

Split the (windowed) dataset into train/validation in the same way as train does.

julia
sdf = split_data(df, hlstm, sequence_kwargs = (; input_window = input_window, output_window = output_window, output_shift = output_shift, lead_time = 0), array_type = pref_array_type);

(x_train, y_train), (x_val, y_val) = sdf;
x_train
y_train
y_train_nan = .!isnan.(y_train)
1×1×15993 DimArray{Bool, 3}
├─────────────────────────────┴──────────────────────────────────────── dims ┐
variable Categorical{Symbol} [:reco] Unordered,
time Categorical{Symbol} [:x9_to_x0_y0] ForwardOrdered,
batch_size Sampled{Int64} [1, …, 15993] ForwardOrdered Irregular Points
└────────────────────────────────────────────────────────────────────────────┘
[:, :, 1]
      :x9_to_x0_y0
  :reco  1

Wrap the training windows/samples in a DataLoader to form batches.

WARNING

batchsize is the number of windows/samples used per gradient step to update the parameters. Processing 32 windows in one array is usually much faster than doing 32 separate forward/backward passes with batch_size=1.

julia
train_dl = EasyHybrid.DataLoader((x_train, y_train); batchsize = 32);
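The batch shapes can be inspected with a standalone sketch (assuming EasyHybrid's DataLoader behaves like MLUtils' DataLoader, which batches along the last dimension; the arrays here are random stand-ins for the windowed data):

```julia
using MLUtils  # assumption: EasyHybrid.DataLoader follows this DataLoader's behavior

x = rand(Float32, 3, 10, 100)   # (variable, time, n_windows)
y = rand(Float32, 1, 1, 100)
dl = DataLoader((x, y); batchsize = 32)

xb, yb = first(dl)
size(xb)  # (3, 10, 32): 32 windows stacked along the batch dimension
size(yb)  # (1, 1, 32)
```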

Run the hybrid model forward

julia
x_first = first(train_dl)[1]
y_first = first(train_dl)[2]

using Random  # for Random.default_rng()
ps, st = Lux.setup(Random.default_rng(), hlstm);
frun = hlstm(x_first, ps, st);
Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
@ LuxCore ~/.julia/packages/LuxCore/kQC9S/src/LuxCore.jl:18

Extract the predicted ŷ (reco)

julia
reco_mod = frun[1].reco
10×32 DimArray{Float32, 2}
├────────────────────────────┴─────────────────────────────────────── dims ┐
time Categorical{Symbol} [:x9_to_x9, …, :x9_to_x0_y0] ReverseOrdered,
batch_size Sampled{Int64} [1, …, 32] ForwardOrdered Irregular Points
└──────────────────────────────────────────────────────────────────────────┘
                1        2        3    …    30        31        32
  :x9_to_x9     3.51847  3.48933  3.46763      3.86629   3.79197   3.75796
  :x9_to_x8     3.75827  3.7349   3.77917      4.08424   4.04761   3.80282
  :x9_to_x7     3.80997  3.85513  3.86316      4.12896   3.87926   3.88195
  :x9_to_x6     3.88699  3.89508  3.80436      3.91132   3.91403   3.84412
 ⋮                                         ⋱   ⋮                  
  :x9_to_x3     3.75994  3.72105  3.65712      3.74693   3.66727   3.70047
  :x9_to_x2     3.72041  3.65649  3.62118      3.66664   3.69983   3.63627
  :x9_to_x1     3.65534  3.62004  3.59504      3.69867   3.63513   3.71666
  :x9_to_x0_y0  3.61875  3.59376  3.52468  …   3.63383   3.71533   3.71276

Bring the observations into the same shape

julia
reco_obs = dropdims(y_first, dims = 1)
reco_nan = .!isnan.(reco_obs);

Compute loss

julia
EasyHybrid.compute_loss(hlstm, ps, st, (x_train, (y_train, y_train_nan)), logging = LoggingLoss(train_mode = true))
(21.427757f0, (st_nn = (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = (rng = TaskLocalRNG(),), layer_4 = NamedTuple(), layer_5 = NamedTuple()), fixed = NamedTuple()), NamedTuple())
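Training below uses an NSE-based loss (training_loss = :nseLoss). The Nash–Sutcliffe efficiency itself can be sketched as follows (an illustrative definition; EasyHybrid's :nseLoss may differ in details such as sign or normalization, since a loss is minimized while NSE = 1 is a perfect fit):

```julia
# Illustrative Nash–Sutcliffe efficiency: 1 = perfect fit,
# 0 = no better than predicting the mean of the observations
function nse(ŷ, y)
    ȳ = sum(y) / length(y)
    return 1 - sum(abs2, y .- ŷ) / sum(abs2, y .- ȳ)
end

y = Float32[1, 2, 3, 4]
ŷ = Float32[1.1, 1.9, 3.2, 3.8]
nse(ŷ, y)  # ≈ 0.98 for this close fit
```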

9. Train LSTM Hybrid Model

julia
out_lstm = train(
    hlstm,
    df,
    ();
    nepochs = 100,           # Number of training epochs
    batchsize = 128,         # Batch size of training windows/samples
    opt = RMSProp(0.01),   # Optimizer and learning rate
    monitor_names = [:rb, :Q10], # Parameters to monitor during training
    yscale = identity,       # Scaling for outputs
    shuffleobs = true,
    training_loss = :nseLoss,
    loss_types = [:nse],
    sequence_kwargs = (; input_window = input_window, output_window = output_window, output_shift = output_shift, lead_time = 0),
    plotting = false,
    show_progress = false,
    input_batchnorm = false,
    array_type = pref_array_type
);

out_lstm.val_obs_pred
1×2 Matrix{Any}:
 Dict{Any, DataFrame}(:x9_to_x0_y0=>3998×2 DataFrame
  Row │ reco       index
      │ Float32    Int64
──────┼──────────────────
    1 │ 2.5595     14511
    2 │ 0.0883865   2358
    3 │ 0.143945    1741
    4 │ 6.17557    10882
    5 │ 1.35721    16259
    6 │ 0.474447   17659
    7 │ 0.587843    4624
    8 │ 1.01569    16939
  ⋮   │     ⋮        ⋮
 3992 │ 5.44318     8129
 3993 │ 2.60681     5641
 3994 │ 4.5168      9359
 3995 │ 2.79248    14367
 3996 │ 5.35347    11034
 3997 │ 0.514723    3758
 3998 │ 1.75363    16482
        3983 rows omitted)  …  39980×1 DataFrame
   Row │ reco_pred
       │ Float32
───────┼───────────
     1 │   2.95494
     2 │   2.9369
     3 │   2.91584
     4 │   2.89814
     5 │   2.81676
     6 │   2.72723
     7 │   2.69439
     8 │   2.69111
   ⋮   │     ⋮
 39974 │   2.08417
 39975 │   2.12563
 39976 │   2.14105
 39977 │   2.16433
 39978 │   2.15962
 39979 │   2.17626
 39980 │   2.1728
 39965 rows omitted

10. Train Single NN Hybrid Model (Optional)

For comparison, we can also train a hybrid model with a standard feedforward neural network

julia
hm = constructHybridModel(
    predictors,
    forcing,
    target,
    RbQ10,
    parameters,
    neural_param_names,
    global_param_names,
    hidden_layers = NN, # Neural network architecture
    scale_nn_outputs = true, # Scale neural network outputs
    input_batchnorm = false,   # Do not apply batch normalization to inputs
);

Train the hybrid model

julia
single_nn_out = train(
    hm,
    df,
    ();
    nepochs = 100,           # Number of training epochs
    batchsize = 128,         # Batch size for training
    opt = RMSProp(0.01),   # Optimizer and learning rate
    monitor_names = [:rb, :Q10], # Parameters to monitor during training
    yscale = identity,       # Scaling for outputs
    shuffleobs = true,
    training_loss = :nseLoss,
    loss_types = [:nse],
    array_type = :DimArray,
    plotting = false,
    show_progress = false,
);
[ Info: Plotting disabled.
[ Info: Training data type: DimensionalData.DimMatrix{Float32, Tuple{DimensionalData.Dimensions.Dim{:variable, DimensionalData.Dimensions.Lookups.Categorical{Symbol, SubArray{Symbol, 1, Vector{Symbol}, Tuple{Base.Slice{Base.OneTo{Int64}}}, true}, DimensionalData.Dimensions.Lookups.Unordered, DimensionalData.Dimensions.Lookups.NoMetadata}}, DimensionalData.Dimensions.Dim{:batch_size, DimensionalData.Dimensions.Lookups.Sampled{Int64, SubArray{Int64, 1, UnitRange{Int64}, Tuple{Vector{Int64}}, false}, DimensionalData.Dimensions.Lookups.ForwardOrdered, DimensionalData.Dimensions.Lookups.Irregular{Tuple{Nothing, Nothing}}, DimensionalData.Dimensions.Lookups.Points, DimensionalData.Dimensions.Lookups.NoMetadata}}}, Tuple{}, SubArray{Float32, 2, Matrix{Float32}, Tuple{Base.Slice{Base.OneTo{Int64}}, Vector{Int64}}, false}, DimensionalData.NoName, DimensionalData.Dimensions.Lookups.NoMetadata}
[ Info: Check the saved output (.png, .mp4, .jld2) from training at: /home/runner/work/EasyHybrid.jl/EasyHybrid.jl/docs/build
[ Info: Returning best model from epoch 79 of 100 epochs with best validation loss wrt nse: 0.9974605

Close enough: both models reach a similar best validation loss.

julia
out_lstm.best_loss
single_nn_out.best_loss
0.9974605f0