15.1. Introduction to Adjoint Sensitivity Analysis
This section presents the SUNAdjointStepper and
SUNAdjointCheckpointScheme classes. The SUNAdjointStepper
represents a generic adjoint sensitivity analysis (ASA) procedure to obtain the adjoint
sensitivities of an IVP of the form
where \(p\) is some set of \(N_s\) problem parameters.
Note
The API itself does not implement ASA, but it provides a common interface for ASA capabilities implemented in the SUNDIALS packages. Currently it supports only the ASA capabilities in ARKODE; the ASA capabilities in CVODES and IDAS must be used directly.
Suppose we have a functional \(g(t_f, y(t_f), p)\) for which we would like to compute the gradients \(dg(t_f, y(t_f), p)/dy(t_0)\) and/or \(dg(t_f, y(t_f), p)/dp\). This most often arises in the form of an optimization problem such as
Warning
The CVODES documentation uses \(\lambda\) to represent the adjoint variables needed to obtain the gradient \(dG/dp\) where \(G\) is an integral of \(g\). Our use of \(\lambda\) in the following is akin to the use of \(\mu\) in the CVODES docs.
The adjoint method is one approach to obtaining the gradients that is particularly efficient when there are relatively few functionals and a large number of parameters. While CVODES and IDAS provide continuous adjoint methods (differentiate-then-discretize), ARKODE provides discrete adjoint methods (discretize-then-differentiate). For the continuous approach, we derive and solve the adjoint IVP backwards in time
where \(\lambda(t) \in \mathbb{R}^{N}\), \(f_y \equiv \partial f/\partial y \in \mathbb{R}^{N \times N}\) and \(g_y \equiv \partial g/\partial y \in \mathbb{R}^{N \times N}\) are the Jacobians with respect to the dependent variable, \(*\) denotes the Hermitian (conjugate) transpose, \(N\) is the size of the original IVP, and \(N_s\) is the number of parameters. When solved with a numerical time integration scheme, the solution to the continuous adjoint IVP is a numerical approximation of the continuous adjoint sensitivities,
The gradients with respect to the parameters can then be obtained as
where \(y_p(t) \equiv \partial y(t)/\partial p \in \mathbb{R}^{N \times N_s}\), and \(g_p \equiv \partial g/\partial p \in \mathbb{R}^{N \times N_s}\) and \(f_p \equiv \partial f/\partial p \in \mathbb{R}^{N \times N_s}\) are the Jacobians with respect to the parameters.
For the discrete adjoint approach, we first numerically discretize the original IVP (15.1) using a time integration scheme, \(\varphi\), so that
For linear multistep methods, \(k \geq 1\), and for one-step methods, \(k = 1\). Reformulating the optimization problem for the discrete case, we have
The gradients of (15.7) can be computed using the transposed chain rule backwards in time to obtain the discrete adjoint variables \(\lambda_n, \lambda_{n-1}, \cdots, \lambda_0\) and \(\mu_n, \mu_{n-1}, \cdots, \mu_0\). The discrete adjoint variables represent the gradients of the discrete cost function (15.7) with respect to changes in the discretized IVP (15.6),
15.1.1. Discrete vs. Continuous Adjoint Method
The continuous adjoint method can be problematic in the context of optimization problems because it provides an approximation to the gradient of a continuous cost function, while the optimizer expects the gradient of the discrete cost function. This discrepancy means that the optimizer can fail to converge due to inconsistent gradients [65, 66]. On the other hand, the discrete adjoint method provides the exact gradient of the discrete cost function, allowing the optimizer to fully converge. Consequently, the discrete adjoint method is often preferable in optimization despite its own drawbacks, such as its (relatively) increased memory usage and the possible introduction of unphysical computational modes [135]. This is not to say that the discrete adjoint approach is always the better choice over the continuous adjoint approach in optimization. Computational efficiency and stability of one approach over the other can be both problem and method dependent. Section 8 in the paper [112] discusses the tradeoffs further and provides numerous references that may help inform users in choosing between the discrete and continuous adjoint approaches.
15.2. The SUNAdjointStepper Class
Added in version 7.3.0.
-
type SUNAdjointStepper
The SUNAdjointStepper class provides a package-agnostic interface to SUNDIALS ASA capabilities. It currently only supports the discrete ASA capabilities in the ARKODE package, but in the future this support may be expanded.
15.2.1. Class Methods
The SUNAdjointStepper class has the following methods:
-
SUNErrCode SUNAdjointStepper_Create(SUNStepper fwd_sunstepper, sunbooleantype own_fwd, SUNStepper adj_sunstepper, sunbooleantype own_adj, suncountertype final_step_idx, sunrealtype tf, N_Vector sf, SUNAdjointCheckpointScheme checkpoint_scheme, SUNContext sunctx, SUNAdjointStepper *adj_stepper)
Creates the SUNAdjointStepper object needed to solve the adjoint problem.
- Parameters:
fwd_sunstepper – The SUNStepper to be used for forward computations of the original ODE.
own_fwd – Whether fwd_sunstepper should be owned (and destroyed) by the SUNAdjointStepper.
adj_sunstepper – The SUNStepper to be used for the backward integration of the adjoint ODE.
own_adj – Whether adj_sunstepper should be owned (and destroyed) by the SUNAdjointStepper.
final_step_idx – The index (step number) of the step corresponding to \(t_f\) for the forward ODE.
tf – The terminal time for the forward ODE (the initial time for the adjoint ODE).
sf – The terminal condition for the adjoint ODE.
checkpoint_scheme – The SUNAdjointCheckpointScheme object that determines the checkpointing strategy to use. This should be the same object provided to the forward integrator/stepper.
sunctx – The SUNContext for the simulation.
adj_stepper – The SUNAdjointStepper to construct (will be NULL on failure).
- Returns:
A SUNErrCode indicating failure or success.
-
SUNErrCode SUNAdjointStepper_ReInit(SUNAdjointStepper self, sunrealtype t0, N_Vector y0, sunrealtype tf, N_Vector sf)
Reinitializes the adjoint stepper to solve a new problem of the same size.
- Parameters:
self – The adjoint solver object.
t0 – The new initial time.
y0 – The new initial condition.
tf – The time to start integrating the adjoint system from.
sf – The terminal condition vector of sensitivity solutions \(\partial g/\partial y_0\) and \(\partial g/\partial p\).
- Returns:
A SUNErrCode indicating failure or success.
-
SUNErrCode SUNAdjointStepper_Evolve(SUNAdjointStepper adj_stepper, sunrealtype tout, N_Vector sens, sunrealtype *tret)
Integrates the adjoint system.
- Parameters:
adj_stepper – The adjoint solver object.
tout – The time at which the adjoint solution is desired.
sens – The vector of sensitivity solutions \(\partial g/\partial y_0\) and \(\partial g/\partial p\).
tret – On return, the time reached by the adjoint solver.
- Returns:
A SUNErrCode indicating failure or success.
-
SUNErrCode SUNAdjointStepper_OneStep(SUNAdjointStepper adj_stepper, sunrealtype tout, N_Vector sens, sunrealtype *tret)
Evolves the adjoint system backwards one step.
- Parameters:
adj_stepper – The adjoint solver object.
tout – The time at which the adjoint solution is desired.
sens – The vector of sensitivity solutions \(\partial g/\partial y_0\) and \(\partial g/\partial p\).
tret – On return, the time reached by the adjoint solver.
- Returns:
A SUNErrCode indicating failure or success.
-
SUNErrCode SUNAdjointStepper_RecomputeFwd(SUNAdjointStepper adj_stepper, suncountertype start_idx, sunrealtype t0, N_Vector y0, sunrealtype tf)
Evolves the forward system in time from (start_idx, t0) to (stop_idx, tf) with dense checkpointing.
- Parameters:
adj_stepper – The SUNAdjointStepper object.
start_idx – the index of the step, w.r.t. the original forward integration, to begin forward integration from.
t0 – the initial time, w.r.t. the original forward integration, to start forward integration from.
y0 – the initial state, w.r.t. the original forward integration, to start forward integration from.
tf – the final time, w.r.t. the original forward integration, to stop forward integration at.
- Returns:
A SUNErrCode indicating failure or success.
-
SUNErrCode SUNAdjointStepper_SetUserData(SUNAdjointStepper adj_stepper, void *user_data)
Sets the user data pointer.
- Parameters:
adj_stepper – The SUNAdjointStepper object.
user_data – the user data pointer that will be passed back to user-supplied callback functions.
- Returns:
A SUNErrCode indicating failure or success.
-
SUNErrCode SUNAdjointStepper_GetNumSteps(SUNAdjointStepper adj_stepper, suncountertype *num_steps)
Retrieves the number of steps taken by the adjoint stepper.
- Parameters:
adj_stepper – The SUNAdjointStepper object.
num_steps – Pointer to store the number of steps.
- Returns:
A SUNErrCode indicating failure or success.
-
SUNErrCode SUNAdjointStepper_GetNumRecompute(SUNAdjointStepper adj_stepper, suncountertype *num_recompute)
Retrieves the number of recomputation steps (in the forward direction) performed by the adjoint stepper.
- Parameters:
adj_stepper – The SUNAdjointStepper object.
num_recompute – Pointer to store the number of recomputations.
- Returns:
A SUNErrCode indicating failure or success.
-
SUNErrCode SUNAdjointStepper_PrintAllStats(SUNAdjointStepper adj_stepper, FILE *outfile, SUNOutputFormat fmt)
Prints the adjoint stepper statistics/counters in a human-readable table format or CSV format.
- Parameters:
adj_stepper – The SUNAdjointStepper object.
outfile – A file to write the output to.
fmt – the format to write in (SUN_OUTPUTFORMAT_TABLE or SUN_OUTPUTFORMAT_CSV).
- Returns:
A SUNErrCode indicating failure or success.
15.2.2. User-Supplied Functions
-
typedef int (*SUNAdjRhsFn)(sunrealtype t, N_Vector y, N_Vector sens, N_Vector sens_dot, void *user_data)
A function of this type computes the adjoint ODE right-hand side.
For ARKODE, this is
\[\begin{split}\Lambda &= f_y^*(t, y, p) \lambda, \quad \text{and if the system has parameters}, \\ \nu &= f_p^*(t, y, p) \lambda,\end{split}\]
which corresponds to (2.74) for explicit Runge–Kutta methods.
- Parameters:
t – the current value of the independent variable.
y – the current value of the forward solution vector.
sens – an NVECTOR_MANYVECTOR object with two subvectors; the first subvector holds \(\lambda\) and the second holds \(\mu\), which is unused in this function.
sens_dot – an NVECTOR_MANYVECTOR object with two subvectors; the first subvector holds \(\Lambda\) and the second holds \(\nu\).
user_data – the user_data pointer that was passed to SUNAdjointStepper_SetUserData().
- Returns:
A SUNAdjRhsFn should return 0 if successful, a positive value if a recoverable error occurred (in which case the integrator may attempt to correct), or a negative value if it failed unrecoverably (in which case the integration is halted and an error is raised).
Note
Allocation of memory for y is handled within the integrator.
The vector sens_dot may be uninitialized on input; it is the user’s responsibility to fill this entire vector with meaningful values.
15.3. The SUNAdjointCheckpointScheme Class
Added in version 7.3.0.
As with other SUNDIALS classes, the SUNAdjointCheckpointScheme abstract base class is
implemented using a C structure containing a content pointer to the derived class member data
and a structure of function pointers to the derived class implementations of the virtual methods.
-
type SUNAdjointCheckpointScheme
A class that provides an interface for checkpointing states during forward integration and accessing them as needed during the backwards integration of the adjoint model.
-
enum SUNDataIOMode
-
enumerator SUNDATAIOMODE_INMEM
The IO mode for data that is stored in addressable random access memory. The location of the memory (e.g., CPU or GPU) is not specified by this mode.
15.3.1. Base Class Methods
-
SUNErrCode SUNAdjointCheckpointScheme_NewEmpty(SUNContext sunctx, SUNAdjointCheckpointScheme *cs_ptr)
- Parameters:
sunctx – The SUNDIALS simulation context
cs_ptr – on output, a pointer to a new SUNAdjointCheckpointScheme object
- Returns:
A SUNErrCode indicating failure or success.
-
SUNErrCode SUNAdjointCheckpointScheme_NeedsSaving(SUNAdjointCheckpointScheme self, suncountertype step_num, suncountertype stage_num, sunrealtype t, sunbooleantype *yes_or_no)
Determines if the (step_num, stage_num) should be checkpointed or not.
- Parameters:
self – the SUNAdjointCheckpointScheme object
step_num – the step number of the checkpoint
stage_num – the stage number of the checkpoint
t – the time of the checkpoint
yes_or_no – boolean indicating if the checkpoint should be saved or not
- Returns:
A SUNErrCode indicating failure or success.
-
SUNErrCode SUNAdjointCheckpointScheme_InsertVector(SUNAdjointCheckpointScheme self, suncountertype step_num, suncountertype stage_num, sunrealtype t, N_Vector y)
Inserts the vector as the checkpoint for (step_num, stage_num).
- Parameters:
self – the SUNAdjointCheckpointScheme object
step_num – the step number of the checkpoint
stage_num – the stage number of the checkpoint
t – the time of the checkpoint
y – the state vector to checkpoint
- Returns:
A SUNErrCode indicating failure or success.
-
SUNErrCode SUNAdjointCheckpointScheme_LoadVector(SUNAdjointCheckpointScheme self, suncountertype step_num, suncountertype stage_num, sunrealtype t, sunbooleantype peek, N_Vector *yout, sunrealtype *tout)
Loads the checkpointed vector for (step_num, stage_num).
- Parameters:
self – the SUNAdjointCheckpointScheme object
step_num – the step number of the checkpoint
stage_num – the stage number of the checkpoint
t – the desired time of the checkpoint
peek – if true, then the checkpoint will be loaded but not deleted regardless of other implementation-specific settings. If false, then the checkpoint may be deleted depending on the implementation.
yout – the loaded state vector
tout – on output, the time of the checkpoint
- Returns:
A SUNErrCode indicating failure or success.
-
SUNErrCode SUNAdjointCheckpointScheme_EnableDense(SUNAdjointCheckpointScheme self, sunbooleantype on_or_off)
Enables or disables dense checkpointing (checkpointing every step/stage). When dense checkpointing is disabled, the checkpointing interval that was set when the object was created is restored.
- Parameters:
self – the SUNAdjointCheckpointScheme object
on_or_off – if true, dense checkpointing will be turned on, if false it will be turned off.
- Returns:
A SUNErrCode indicating failure or success.
-
SUNErrCode SUNAdjointCheckpointScheme_Destroy(SUNAdjointCheckpointScheme *cs_ptr)
Destroys (deallocates) the SUNAdjointCheckpointScheme object.
- Parameters:
cs_ptr – pointer to a SUNAdjointCheckpointScheme object
- Returns:
A SUNErrCode indicating failure or success.
15.3.2. Implementation Specific Methods
This section describes the virtual methods defined by the SUNAdjointCheckpointScheme
abstract base class.
-
typedef SUNErrCode (*SUNAdjointCheckpointSchemeNeedsSavingFn)(SUNAdjointCheckpointScheme check_scheme, suncountertype step_num, suncountertype stage_num, sunrealtype t, sunbooleantype *yes_or_no)
This type represents a function with the signature of SUNAdjointCheckpointScheme_NeedsSaving().
-
typedef SUNErrCode (*SUNAdjointCheckpointSchemeInsertVectorFn)(SUNAdjointCheckpointScheme check_scheme, suncountertype step_num, suncountertype stage_num, sunrealtype t, N_Vector y)
This type represents a function with the signature of SUNAdjointCheckpointScheme_InsertVector().
-
typedef SUNErrCode (*SUNAdjointCheckpointSchemeLoadVectorFn)(SUNAdjointCheckpointScheme check_scheme, suncountertype step_num, suncountertype stage_num, sunrealtype t, sunbooleantype peek, N_Vector *yout, sunrealtype *tout)
This type represents a function with the signature of SUNAdjointCheckpointScheme_LoadVector().
-
typedef SUNErrCode (*SUNAdjointCheckpointSchemeEnableDenseFn)(SUNAdjointCheckpointScheme check_scheme, sunbooleantype on_or_off)
This type represents a function with the signature of SUNAdjointCheckpointScheme_EnableDense().
-
typedef SUNErrCode (*SUNAdjointCheckpointSchemeDestroyFn)(SUNAdjointCheckpointScheme *check_scheme_ptr)
This type represents a function with the signature of SUNAdjointCheckpointScheme_Destroy().
15.3.3. Setting Content and Member Functions
These functions can be used to set the content pointer or virtual method pointers as needed when implementing the abstract base class.
-
SUNErrCode SUNAdjointCheckpointScheme_SetNeedsSavingFn(SUNAdjointCheckpointScheme self, SUNAdjointCheckpointSchemeNeedsSavingFn fn)
This function attaches a SUNAdjointCheckpointSchemeNeedsSavingFn function to a SUNAdjointCheckpointScheme object.
- Parameters:
self – a checkpoint scheme object.
fn – the SUNAdjointCheckpointSchemeNeedsSavingFn function to attach.
- Returns:
A SUNErrCode indicating success or failure.
-
SUNErrCode SUNAdjointCheckpointScheme_SetInsertVectorFn(SUNAdjointCheckpointScheme self, SUNAdjointCheckpointSchemeInsertVectorFn fn)
This function attaches a SUNAdjointCheckpointSchemeInsertVectorFn function to a SUNAdjointCheckpointScheme object.
- Parameters:
self – a checkpoint scheme object.
fn – the SUNAdjointCheckpointSchemeInsertVectorFn function to attach.
- Returns:
A SUNErrCode indicating success or failure.
-
SUNErrCode SUNAdjointCheckpointScheme_SetLoadVectorFn(SUNAdjointCheckpointScheme self, SUNAdjointCheckpointSchemeLoadVectorFn fn)
This function attaches a SUNAdjointCheckpointSchemeLoadVectorFn function to a SUNAdjointCheckpointScheme object.
- Parameters:
self – a checkpoint scheme object.
fn – the SUNAdjointCheckpointSchemeLoadVectorFn function to attach.
- Returns:
A SUNErrCode indicating success or failure.
-
SUNErrCode SUNAdjointCheckpointScheme_SetDestroyFn(SUNAdjointCheckpointScheme self, SUNAdjointCheckpointSchemeDestroyFn fn)
This function attaches a SUNAdjointCheckpointSchemeDestroyFn function to a SUNAdjointCheckpointScheme object.
- Parameters:
self – a checkpoint scheme object.
fn – the SUNAdjointCheckpointSchemeDestroyFn function to attach.
- Returns:
A SUNErrCode indicating success or failure.
-
SUNErrCode SUNAdjointCheckpointScheme_SetEnableDenseFn(SUNAdjointCheckpointScheme self, SUNAdjointCheckpointSchemeEnableDenseFn fn)
This function attaches a SUNAdjointCheckpointSchemeEnableDenseFn function to a SUNAdjointCheckpointScheme object.
- Parameters:
self – a checkpoint scheme object.
fn – the SUNAdjointCheckpointSchemeEnableDenseFn function to attach.
- Returns:
A SUNErrCode indicating success or failure.
-
SUNErrCode SUNAdjointCheckpointScheme_SetContent(SUNAdjointCheckpointScheme self, void *content)
This function attaches a member data (content) pointer to a SUNAdjointCheckpointScheme object.
- Parameters:
self – a checkpoint scheme object.
content – a pointer to the checkpoint scheme member data.
- Returns:
A SUNErrCode indicating success or failure.
-
SUNErrCode SUNAdjointCheckpointScheme_GetContent(SUNAdjointCheckpointScheme self, void **content)
This function retrieves the member data (content) pointer from a SUNAdjointCheckpointScheme object.
- Parameters:
self – a checkpoint scheme object.
content – a pointer to set to the checkpoint scheme member data pointer.
- Returns:
A SUNErrCode indicating success or failure.
15.4. The SUNAdjointCheckpointScheme_Fixed Module
The SUNAdjointCheckpointScheme_Fixed module implements a scheme where a checkpoint is saved at some
fixed interval (in time steps). The module supports checkpointing of time step states only, or time step
states with intermediate stage states as well (for multistage methods). When used with a
fixed time step size, the number of checkpoints that will be saved is fixed. However, with
adaptive time steps the number of checkpoints stored with this scheme is unbounded.
The diagram below illustrates how checkpoints are stored with this scheme:
15.4.1. Base-class Method Overrides
The SUNAdjointCheckpointScheme_Fixed module implements the following SUNAdjointCheckpointScheme functions:
15.4.2. Implementation Specific Methods
The SUNAdjointCheckpointScheme_Fixed module also implements the following module-specific functions:
-
SUNErrCode SUNAdjointCheckpointScheme_Create_Fixed(SUNDataIOMode io_mode, SUNMemoryHelper mem_helper, suncountertype interval, suncountertype estimate, sunbooleantype keep, SUNContext sunctx, SUNAdjointCheckpointScheme *check_scheme_ptr)
Creates a new SUNAdjointCheckpointScheme object that checkpoints at a fixed interval.
- Parameters:
io_mode – The IO mode used for storing the checkpoints.
mem_helper – Memory helper for managing memory.
interval – The interval (in steps) between checkpoints.
estimate – An estimate of the total number of checkpoints needed.
keep – Keep data stored even after it is not needed anymore.
sunctx – The SUNContext for the simulation.
check_scheme_ptr – Pointer to the newly constructed object.
- Returns:
A SUNErrCode indicating success or failure.