Source code for abcmodel.coupling

from dataclasses import dataclass, field
from typing import Generic

import jax.numpy as jnp
from flax import nnx
from jax import Array

from .abstracts import (
    AbstractAtmosphereModel,
    AbstractCoupledState,
    AbstractLandModel,
    AbstractRadiationModel,
    AtmosT,
    LandT,
    RadT,
)


[docs] @dataclass class CoupledState( AbstractCoupledState[RadT, LandT, AtmosT], Generic[RadT, LandT, AtmosT] ): """Hierarchical coupled state, generic over component types.""" rad: RadT land: LandT atmos: AtmosT t: Array = field(default_factory=lambda: jnp.array(-1)) total_water_mass: Array = field(default_factory=lambda: jnp.array(0.0))
[docs] class ABCoupler(nnx.Module): """Coupling class to bound all the components.""" def __init__( self, rad: AbstractRadiationModel, land: AbstractLandModel, atmos: AbstractAtmosphereModel, ): self.rad = rad self.land = land self.atmos = atmos
[docs] def init_state( self, rad_state: RadT | None, land_state: LandT | None, atmos_state: AtmosT | None, ) -> CoupledState[RadT, LandT, AtmosT]: return CoupledState( rad=rad_state if rad_state is not None else self.rad.init_state(), land=land_state if land_state is not None else self.land.init_state(), atmos=atmos_state if atmos_state is not None else self.atmos.init_state(), )
[docs] def compute_diagnostics(self, state: AbstractCoupledState) -> AbstractCoupledState: """Compute diagnostic variables for total water budget.""" # limamau: this needs to be re-implemented total_water_mass = jnp.array(0.0) return state.replace(total_water_mass=total_water_mass)