from functools import partial
import jax
import jax.numpy as jnp
from jax import Array
from .abstracts import AbstractCoupledState, AtmosT, LandT, RadT
from .coupling import ABCoupler
StateT = AbstractCoupledState[RadT, LandT, AtmosT]
[docs]
def warmup(
state: AbstractCoupledState[RadT, LandT, AtmosT],
t: Array,
coupler: ABCoupler,
dt: float,
tstart: float,
) -> AbstractCoupledState[RadT, LandT, AtmosT]:
"""Warmup the model by running it for a few timesteps."""
state = coupler.atmos.warmup(coupler.rad, coupler.land, state, t, dt, tstart)
return state
[docs]
def inner_step(
state: AbstractCoupledState[RadT, LandT, AtmosT],
_: None, # this is here because the function signature requires it for a scan
coupler: ABCoupler,
dt: float,
tstart: float,
) -> tuple[
AbstractCoupledState[RadT, LandT, AtmosT],
AbstractCoupledState[RadT, LandT, AtmosT],
]:
"""Run a single timestep of the model."""
t = state.t
atmos = coupler.atmos.statistics(state, t)
state = state.replace(atmos=atmos)
rad = coupler.rad.run(state, t, dt, tstart)
state = state.replace(rad=rad)
land = coupler.land.run(state)
state = state.replace(land=land)
atmos = coupler.atmos.run(state)
state = state.replace(atmos=atmos)
land = coupler.land.integrate(state.land, dt)
state = state.replace(land=land)
atmos = coupler.atmos.integrate(state.atmos, dt)
state = state.replace(atmos=atmos)
state = coupler.compute_diagnostics(state)
state = state.replace(t=t + 1)
return state, state
[docs]
def outter_step(
state: AbstractCoupledState[RadT, LandT, AtmosT],
_: None, # this is here because the function signature requires it for a scan
coupler: ABCoupler,
inner_dt: float,
inner_tsteps: int,
tstart: float,
) -> tuple[
AbstractCoupledState[RadT, LandT, AtmosT],
AbstractCoupledState[RadT, LandT, AtmosT],
]:
"""A block of inner steps averaging the result."""
initial_t = state.t
step_fn = partial(
inner_step,
coupler=coupler,
dt=inner_dt,
tstart=tstart,
)
state, inner_traj = jax.lax.scan(step_fn, state, None, length=inner_tsteps)
avg_traj = jax.tree.map(lambda x: jnp.mean(x, axis=0), inner_traj)
# the average block is tagged with the initial time
avg_traj = avg_traj.replace(t=initial_t)
return state, avg_traj
[docs]
def integrate(
state: AbstractCoupledState[RadT, LandT, AtmosT],
coupler: ABCoupler,
inner_dt: float,
outter_dt: float,
runtime: float,
tstart: float,
) -> tuple[Array, AbstractCoupledState[RadT, LandT, AtmosT]]:
"""Integrate the coupler forward in time."""
if outter_dt % inner_dt != 0:
outter_dt = inner_dt * int(outter_dt / inner_dt)
print(
"The outter_dt should be a multiple of the inner_dt. Taking the closest multiple."
)
inner_tsteps = int(outter_dt / inner_dt)
outter_tsteps = int(runtime / outter_dt)
# warmup and initial diagnostics (t=0)
state = warmup(state, jnp.asarray(0), coupler, inner_dt, tstart)
state = coupler.compute_diagnostics(state)
# configure outter step function
step_fn = partial(
outter_step,
coupler=coupler,
inner_dt=inner_dt,
inner_tsteps=inner_tsteps,
tstart=tstart,
)
# this is effectively the integration
state, trajectory = jax.lax.scan(step_fn, state, length=outter_tsteps)
# real time as separate output
times = jnp.arange(outter_tsteps) * outter_dt / 3600.0 + tstart
return times, trajectory