Source code for abcmodel.atmos.surface_layer.simple
from dataclasses import dataclass, field, replace
import jax.numpy as jnp
from jax import Array
from ...abstracts import AbstractCoupledState, LandT, RadT
from ..abstracts import (
AbstractSurfaceLayerModel,
AbstractSurfaceLayerState,
CloudT,
MixedT,
)
from ..dayonly import DayOnlyAtmosphereState
[docs]
@dataclass
class SimpleState(AbstractSurfaceLayerState):
    """State carried by the simple surface layer model.

    Only ``ustar`` has to be provided. The momentum fluxes and the
    aerodynamic resistance default to zero-valued scalar arrays.
    """

    #: Surface friction velocity [m/s].
    ustar: Array
    #: Zonal surface momentum flux [m2 s-2].
    uw: Array = field(default_factory=lambda: jnp.array(0.0))
    #: Meridional surface momentum flux [m2 s-2].
    vw: Array = field(default_factory=lambda: jnp.array(0.0))
    #: Aerodynamic resistance [s/m].
    ra: Array = field(default_factory=lambda: jnp.array(0.0))
# limamau: maybe these type variables could be abstracts...
# Coupled-state type consumed by ``SimpleModel.run``: generic in the ``RadT``
# and ``LandT`` components, with a ``DayOnlyAtmosphereState`` whose surface
# layer is fixed to ``SimpleState`` (``MixedT`` and ``CloudT`` stay generic).
StateAlias = AbstractCoupledState[
    RadT, LandT, DayOnlyAtmosphereState[SimpleState, MixedT, CloudT]
]
[docs]
class SimpleModel(AbstractSurfaceLayerModel[SimpleState]):
    """Simple surface layer model with constant friction velocity.

    ``ustar`` is set once in :meth:`init_state` and never modified. Each call
    to :meth:`run` only derives the surface momentum fluxes and the
    aerodynamic resistance from the current atmospheric state.
    """

    def __init__(self):
        # The model takes no configuration; everything it uses is read
        # from the state passed to ``run``.
        pass

    def init_state(self, ustar: float) -> SimpleState:
        """Initialize the model state.

        Args:
            ustar: Friction velocity [m/s].

        Returns:
            The initial surface layer state. ``uw``, ``vw`` and ``ra`` are
            left at their zero defaults.
        """
        return SimpleState(
            ustar=jnp.array(ustar),
        )

    def run(self, state: StateAlias) -> SimpleState:
        """Run the model.

        Args:
            state: Coupled model state. Only ``state.atmos`` is read: its
                surface layer state (``atmos.surface``), the wind components
                ``atmos.u`` and ``atmos.v``, and ``atmos.wstar``.

        Returns:
            The updated surface layer state: a copy of
            ``state.atmos.surface`` with ``uw``, ``vw`` and ``ra`` recomputed
            and ``ustar`` carried over unchanged.
        """
        atmos = state.atmos
        sl_state = atmos.surface
        uw = compute_uw(atmos.u, atmos.v, sl_state.ustar)
        vw = compute_vw(atmos.u, atmos.v, sl_state.ustar)
        ra = compute_ra(atmos.u, atmos.v, atmos.wstar, sl_state.ustar)
        # ``replace`` builds a new dataclass instance; the input state is
        # not mutated.
        return replace(sl_state, uw=uw, vw=vw, ra=ra)
[docs]
def compute_uw(u: Array, v: Array, ustar: Array) -> Array:
    """Calculate the zonal momentum flux from wind components and friction velocity.

    The surface stress ``ustar**2`` is projected onto the zonal direction,
    opposing the mean wind::

        uw = -ustar**2 * u / sqrt(u**2 + v**2)

    This is algebraically identical to
    ``-sign(u) * sqrt(ustar**4 / (v**2 / u**2 + 1))``. Unlike that form, it
    never divides by ``u``, so both the value and its JAX gradients stay
    finite when ``u == 0`` or ``ustar == 0``.

    Args:
        u: Zonal wind component [m/s].
        v: Meridional wind component [m/s].
        ustar: Surface friction velocity [m/s].

    Returns:
        Zonal surface momentum flux [m2 s-2]. It is zero when ``u == 0``,
        including the fully calm case ``u == v == 0``.
    """
    speed_sq = u**2.0 + v**2.0
    # Double-``where`` guard: replace a zero squared speed by 1 *before* the
    # sqrt. Otherwise sqrt'(0) = inf would leak NaN into gradients, even
    # through the branch that ``jnp.where`` does not select. With u == 0 the
    # numerator is already zero, so the substituted value never changes the
    # result.
    is_calm = speed_sq == 0.0
    speed = jnp.sqrt(jnp.where(is_calm, 1.0, speed_sq))
    return -(ustar**2.0) * u / speed
[docs]
def compute_vw(u: Array, v: Array, ustar: Array) -> Array:
    """Calculate the meridional momentum flux from wind components and friction velocity.

    The surface stress ``ustar**2`` is projected onto the meridional
    direction, opposing the mean wind::

        vw = -ustar**2 * v / sqrt(u**2 + v**2)

    This is algebraically identical to
    ``-sign(v) * sqrt(ustar**4 / (u**2 / v**2 + 1))``. Unlike that form, it
    never divides by ``v``, so both the value and its JAX gradients stay
    finite when ``v == 0`` or ``ustar == 0``.

    Args:
        u: Zonal wind component [m/s].
        v: Meridional wind component [m/s].
        ustar: Surface friction velocity [m/s].

    Returns:
        Meridional surface momentum flux [m2 s-2]. It is zero when
        ``v == 0``, including the fully calm case ``u == v == 0``.
    """
    speed_sq = u**2.0 + v**2.0
    # Double-``where`` guard: replace a zero squared speed by 1 *before* the
    # sqrt. Otherwise sqrt'(0) = inf would leak NaN into gradients, even
    # through the branch that ``jnp.where`` does not select. With v == 0 the
    # numerator is already zero, so the substituted value never changes the
    # result.
    is_calm = speed_sq == 0.0
    speed = jnp.sqrt(jnp.where(is_calm, 1.0, speed_sq))
    return -(ustar**2.0) * v / speed
[docs]
def compute_ra(u: Array, v: Array, wstar: Array, ustar: Array) -> Array:
    """Return the aerodynamic resistance for the given winds and friction velocity.

    An effective speed is formed from both horizontal wind components and the
    velocity scale ``wstar`` (presumably the convective velocity scale). It is
    then divided by ``ustar**2``, with ``ustar`` floored at 1e-3 so that the
    denominator never vanishes.

    Args:
        u: Zonal wind component [m/s].
        v: Meridional wind component [m/s].
        wstar: Additional velocity scale added in quadrature [m/s].
        ustar: Surface friction velocity [m/s].

    Returns:
        Aerodynamic resistance [s/m].
    """
    squared_speed = u**2.0 + v**2.0 + wstar**2.0
    effective_speed = jnp.sqrt(squared_speed)
    floored_ustar = jnp.maximum(1.0e-3, ustar)
    return effective_speed / floored_ustar**2.0