Source code for abcmodel.integration

from functools import partial

import jax
import jax.numpy as jnp
from jax import Array

from .abstracts import AbstractCoupledState, AtmosT, LandT, RadT
from .coupling import ABCoupler

# Shorthand for the coupled state parameterised by the three component type
# variables imported from `.abstracts` (presumably radiation, land and
# atmosphere sub-states, judging by their names -- confirm there).
StateT = AbstractCoupledState[RadT, LandT, AtmosT]


def warmup(
    state: AbstractCoupledState[RadT, LandT, AtmosT],
    t: Array,
    coupler: ABCoupler,
    dt: float,
    tstart: float,
) -> AbstractCoupledState[RadT, LandT, AtmosT]:
    """Warm up the coupled model before the main integration.

    This is a thin wrapper: the work is delegated to the atmosphere
    component's ``warmup`` method, which is handed the coupler's radiation
    and land components along with the state and timing arguments.

    Args:
        state: Coupled state to warm up.
        t: Step index forwarded to the atmosphere warmup.
        coupler: Coupler holding the ``rad``, ``land`` and ``atmos`` components.
        dt: Time step forwarded to the atmosphere warmup.
        tstart: Start time forwarded to the atmosphere warmup.

    Returns:
        The state returned by ``coupler.atmos.warmup``.
    """
    return coupler.atmos.warmup(
        coupler.rad,
        coupler.land,
        state,
        t,
        dt,
        tstart,
    )
def inner_step(
    state: AbstractCoupledState[RadT, LandT, AtmosT],
    _: None,  # unused per-step input; a scan body is called as f(carry, x)
    coupler: ABCoupler,
    dt: float,
    tstart: float,
) -> tuple[
    AbstractCoupledState[RadT, LandT, AtmosT],
    AbstractCoupledState[RadT, LandT, AtmosT],
]:
    """Advance the coupled state by one inner timestep.

    The components are evaluated in a fixed order, and each result is written
    back into the state before the next component is called, so later
    components see the earlier updates:

    1. atmosphere statistics
    2. radiation run
    3. land run
    4. atmosphere run
    5. land integration over ``dt``
    6. atmosphere integration over ``dt``
    7. coupler diagnostics
    8. step counter ``t`` incremented by one

    Args:
        state: Coupled state at the beginning of the step.
        _: Ignored; present only so the function fits the scan body signature.
        coupler: Coupler holding the ``rad``, ``land`` and ``atmos`` components.
        dt: Time step passed to the radiation run and to both integrations.
        tstart: Start time passed to the radiation run.

    Returns:
        The updated state twice: once as the scan carry and once as the
        per-step output collected into the trajectory.
    """
    step = state.t

    # Diagnostic / tendency phase.
    state = state.replace(atmos=coupler.atmos.statistics(state, step))
    state = state.replace(rad=coupler.rad.run(state, step, dt, tstart))
    state = state.replace(land=coupler.land.run(state))
    state = state.replace(atmos=coupler.atmos.run(state))

    # Time-integration phase: land first, then atmosphere.
    state = state.replace(land=coupler.land.integrate(state.land, dt))
    state = state.replace(atmos=coupler.atmos.integrate(state.atmos, dt))

    # Refresh diagnostics on the integrated state and bump the step counter.
    advanced = coupler.compute_diagnostics(state).replace(t=step + 1)
    return advanced, advanced
def outter_step(
    state: AbstractCoupledState[RadT, LandT, AtmosT],
    _: None,  # unused per-step input; a scan body is called as f(carry, x)
    coupler: ABCoupler,
    inner_dt: float,
    inner_tsteps: int,
    tstart: float,
) -> tuple[
    AbstractCoupledState[RadT, LandT, AtmosT],
    AbstractCoupledState[RadT, LandT, AtmosT],
]:
    """Run one block of inner steps and return the block average.

    Args:
        state: Coupled state at the beginning of the block.
        _: Ignored; present only so the function fits the scan body signature.
        coupler: Coupler forwarded to every inner step.
        inner_dt: Time step of each inner step.
        inner_tsteps: Number of inner steps in the block.
        tstart: Start time forwarded to every inner step.

    Returns:
        A pair ``(state, averaged)``: the state after the last inner step
        (the scan carry) and the mean of every leaf of the inner trajectory
        over its leading axis, with ``t`` overwritten by the value of
        ``state.t`` at the start of the block.
    """
    block_start = state.t

    def _advance(carry, unused):
        # Bind the block-constant arguments; scan supplies (carry, x).
        return inner_step(
            carry,
            unused,
            coupler=coupler,
            dt=inner_dt,
            tstart=tstart,
        )

    def _block_mean(leaf):
        # Scan stacks per-step outputs along axis 0.
        return jnp.mean(leaf, axis=0)

    state, block = jax.lax.scan(_advance, state, None, length=inner_tsteps)

    # The averaged block is labelled with the time at which it started.
    averaged = jax.tree.map(_block_mean, block).replace(t=block_start)
    return state, averaged
def integrate(
    state: AbstractCoupledState[RadT, LandT, AtmosT],
    coupler: ABCoupler,
    inner_dt: float,
    outter_dt: float,
    runtime: float,
    tstart: float,
) -> tuple[Array, AbstractCoupledState[RadT, LandT, AtmosT]]:
    """Integrate the coupler forward in time.

    The model is advanced with step ``inner_dt``; every block of
    ``outter_dt / inner_dt`` inner steps is averaged into one output record.

    If ``outter_dt`` is not a whole multiple of ``inner_dt`` it is replaced
    by the nearest multiple (never less than one inner step) and a notice is
    printed. Exact multiples are left untouched.

    Args:
        state: Initial coupled state.
        coupler: Coupler holding the model components.
        inner_dt: Time step of the underlying integration.
        outter_dt: Length of one averaged output block.
        runtime: Total time to integrate; only whole output blocks are run.
        tstart: Start time, forwarded to the components and used as the
            offset of the returned time axis.

    Returns:
        A pair ``(times, trajectory)``. ``times`` has one entry per output
        block, computed as ``k * outter_dt / 3600 + tstart`` (so presumably
        the time steps are in seconds and ``tstart`` in hours -- confirm
        against the callers). ``trajectory`` is the stacked block-averaged
        state produced by the outer scan.
    """
    # Number of inner steps per output block: nearest whole count, at least 1.
    # Flooring (as int(outter_dt / inner_dt) does) yields 0 when
    # outter_dt < inner_dt, which later divides by zero, and also drops a
    # step on float noise such as 0.3 / 0.1 == 2.9999999999999996.
    ratio = outter_dt / inner_dt
    inner_tsteps = max(1, round(ratio))

    # Tolerant comparison so exact multiples expressed in floats do not
    # trigger the adjustment.
    if abs(ratio - inner_tsteps) > 1e-9 * max(1.0, abs(ratio)):
        outter_dt = inner_dt * inner_tsteps
        print(
            "The outter_dt should be a multiple of the inner_dt. Taking the closest multiple."
        )

    outter_tsteps = int(runtime / outter_dt)

    # warmup and initial diagnostics (t=0)
    state = warmup(state, jnp.asarray(0), coupler, inner_dt, tstart)
    state = coupler.compute_diagnostics(state)

    # configure outter step function
    step_fn = partial(
        outter_step,
        coupler=coupler,
        inner_dt=inner_dt,
        inner_tsteps=inner_tsteps,
        tstart=tstart,
    )

    # this is effectively the integration
    state, trajectory = jax.lax.scan(step_fn, state, length=outter_tsteps)

    # real time as separate output
    times = jnp.arange(outter_tsteps) * outter_dt / 3600.0 + tstart
    return times, trajectory