data:image/s3,"s3://crabby-images/38233/382330277e7485e81977150147486ccc5c7665f6" alt="Source: https://unsplash.com/"
1. Quick Start: Why Data Assimilation
Weather forecasting models are chaotic dynamical systems: small perturbations in the model state can make forecasts unstable, so trusting them blindly is risky. Even so, current forecasting services, such as the European Centre for Medium-Range Weather Forecasts (ECMWF), achieve high accuracy from medium-range (up to 15 days) to seasonal forecasts. The trick behind these good forecasts is 4-dimensional variational data assimilation (4D-Var), which ECMWF has used since 1997. This algorithm incorporates real-time observations to improve forecasts. As a main technique for limiting the butterfly effect (the high sensitivity to initial conditions), 4D-Var is also widely used in operational time-series forecasting systems in other fields.
data:image/s3,"s3://crabby-images/662d9/662d9b158d79e1110d2bf2d2313453ef3685887d" alt="A schematic of the data assimilation in ECMWF. Source: https://www.ecmwf.int/sites/default/files/2023-08/Fact sheet - Reanalysis -v3_1.pdf. Copyright © ECMWF, licensed under CC BY-SA 4.0."
This blog post introduces the mathematics behind 4D-Var and implements it on a simple example using PyTorch, a modern deep learning framework whose features can accelerate traditional data assimilation. By the end of this post, you should understand how 4D-Var works and be ready to apply it to your own problems. The post is structured as follows:
- Quick Start: Why Data Assimilation
- Prerequisites
- Standard 4D-Var Algorithm
- Incremental 4D-Var Algorithm
- 4D-Var Implementation
- Case Study
- Summary
2. Prerequisites
The derivation of 4D-Var requires basic knowledge of vector calculus, Bayes’ theorem, and multivariate statistics. It is totally fine to play with the code without diving into the math. Experience with Python and PyTorch will help you understand the code. The code repository is available at pytorch-data-assimilation-tutorial.
3. Standard 4D-Var Algorithm
3.1. Problem Formulation from Bayes’s Perspective
Unlike sequential data assimilation methods, such as the Kalman Filter, which update model states at each time step, 4D-Var performs batch optimization. It updates all model states across several consecutive steps simultaneously. Suppose a simulation model consists of two functions: a state transition model m and an observation operator h:
data:image/s3,"s3://crabby-images/6bd9c/6bd9c6c369a2cff7833166a799a00e7821eb9747" alt=""
Here, x denotes the model states and y the model outputs. Because each state depends on the previous one, optimizing a batch of states is equivalent to optimizing only the initial state, x0. Recall Bayes’ theorem:
data:image/s3,"s3://crabby-images/ef9ea/ef9ea20121b2ea63fba2b20a3903ec5104e5de0d" alt=""
When using a Gaussian prior distribution p(x) and likelihood p(y|x), we have:
data:image/s3,"s3://crabby-images/38a7b/38a7bc021a35827b2c8abbf8c172a5ce14646060" alt=""
The posterior probability can be log-transformed into a cost function J to be minimized:
data:image/s3,"s3://crabby-images/38473/384739f4dfaf018a17bd9d5fcbb604ca457958e4" alt=""
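As a rough illustration of such a cost function, the sketch below evaluates a background term plus an observation term over a short window. Everything in it is an assumption made for the example, not the model used later in this post: the toy functions `m` and `h`, the weights `B_inv` and `R_inv`, and the synthetic noise-free observations.

```python
import torch

def m(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical state transition model (mildly nonlinear decay)
    return 0.9 * x + 0.1 * torch.sin(x)

def h(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical observation operator: observes the sum of the states
    return x.sum(dim=-1, keepdim=True)

def cost(x0, x_b, ys, B_inv, R_inv):
    # Background misfit at the start of the window
    dx = x0 - x_b
    J = 0.5 * dx @ B_inv @ dx
    # Observation misfit accumulated along the trajectory started from x0
    x = x0
    for t in range(len(ys)):
        if t > 0:
            x = m(x)
        d = ys[t] - h(x)
        J = J + 0.5 * d @ R_inv @ d
    return J

x_true = torch.tensor([1.0, -0.5])   # state used to generate the observations
x_b = torch.tensor([0.8, -0.2])      # background (first guess)
B_inv = torch.eye(2)                 # assumed inverse background covariance
R_inv = 10.0 * torch.eye(1)          # assumed inverse observation covariance

# Synthetic, noise-free observations generated from x_true
ys, x = [], x_true
for t in range(5):
    if t > 0:
        x = m(x)
    ys.append(h(x))
ys = torch.stack(ys)

J_background = cost(x_b, x_b, ys, B_inv, R_inv)   # only the observation term is non-zero
J_truth = cost(x_true, x_b, ys, B_inv, R_inv)     # only the background term is non-zero
print(J_background.item(), J_truth.item())
```

At the background state the first term vanishes, and at the true state the second term vanishes because the observations are noise-free. The minimizer of J sits between the two, weighted by `B_inv` and `R_inv`.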
3.2. Gradient Descent Optimization
The gradient of each term in the cost function is:
data:image/s3,"s3://crabby-images/6ca8e/6ca8ec3631963a11d1788407817a5d8a2f0ee64e" alt=""
All derivatives can be expressed in the form of Jacobian matrices:
data:image/s3,"s3://crabby-images/c0e07/c0e07814c7f0fc36557c3b1f9a518f8ebe7ad3ff" alt=""
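For reference, one common way to write the cost and its gradient in Jacobian form is sketched below. The symbols are assumptions and may differ from the notation in the figures: $\mathbf{x}_b$ is the background state, $\mathbf{B}$ the background error covariance (the code in Section 5 passes the corresponding inverse as `Q_inv`), $\mathbf{R}$ the observation error covariance, $\mathbf{M}_t$ the Jacobian of m evaluated at $\mathbf{x}_{t-1}$, and $\mathbf{H}_t$ the Jacobian of h evaluated at $\mathbf{x}_t$, with $\mathbf{x}_t = m(\mathbf{x}_{t-1})$.

```latex
\begin{aligned}
J(\mathbf{x}_0) &= \tfrac{1}{2}(\mathbf{x}_0-\mathbf{x}_b)^{\top}\mathbf{B}^{-1}(\mathbf{x}_0-\mathbf{x}_b)
  + \tfrac{1}{2}\sum_{t=0}^{T}\big(\mathbf{y}_t-h(\mathbf{x}_t)\big)^{\top}\mathbf{R}^{-1}\big(\mathbf{y}_t-h(\mathbf{x}_t)\big), \\
\nabla_{\mathbf{x}_0} J &= \mathbf{B}^{-1}(\mathbf{x}_0-\mathbf{x}_b)
  - \sum_{t=0}^{T}\mathbf{M}_1^{\top}\cdots\mathbf{M}_t^{\top}\,\mathbf{H}_t^{\top}\mathbf{R}^{-1}\big(\mathbf{y}_t-h(\mathbf{x}_t)\big).
\end{aligned}
```

For $t = 0$ the product of transposed Jacobians is empty and is read as the identity. The chain $\mathbf{M}_1^{\top}\cdots\mathbf{M}_t^{\top}$ is the transpose of $\partial\mathbf{x}_t/\partial\mathbf{x}_0 = \mathbf{M}_t\cdots\mathbf{M}_1$.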
3.3. Simplification Technique: Adjoint
It is much easier to compute the gradient terms above recursively, using the adjoint operator from linear algebra:
data:image/s3,"s3://crabby-images/a6392/a6392f218c4fae732183b58b4cec83e2ada73083" alt=""
This recursion propagates information backward in time, allowing efficient computation of the gradient of the cost function:
data:image/s3,"s3://crabby-images/7a192/7a19228c681b9fd9b021f09eefc21a5fd0faa17f" alt=""
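One way to gain confidence in a backward recursion of this kind is to compare it with automatic differentiation of the cost itself. The sketch below does that on a toy problem. The functions `m` and `h`, the weights `B_inv` and `R_inv`, and all numbers are assumptions made for the check, not the case-study model.

```python
import torch

DTYPE = torch.float64

def m(x):
    # Hypothetical state transition model
    return 0.9 * x + 0.1 * torch.sin(x)

def h(x):
    # Hypothetical observation operator: observes the sum of the states
    return x.sum(dim=-1, keepdim=True)

def cost(x0, x_b, ys, B_inv, R_inv):
    dx = x0 - x_b
    J = 0.5 * dx @ B_inv @ dx
    x = x0
    for t in range(len(ys)):
        if t > 0:
            x = m(x)
        d = ys[t] - h(x)
        J = J + 0.5 * d @ R_inv @ d
    return J

def adjoint_gradient(x0, x_b, ys, B_inv, R_inv):
    # Forward pass: store the trajectory
    xs = [x0]
    for _ in range(1, len(ys)):
        xs.append(m(xs[-1]))
    # Backward pass: lam holds the gradient of the observation term w.r.t. x_t
    lam = torch.zeros_like(x0)
    for t in range(len(ys) - 1, -1, -1):
        H_t = torch.autograd.functional.jacobian(h, xs[t])
        forcing = -H_t.T @ R_inv @ (ys[t] - h(xs[t]))
        if t < len(ys) - 1:
            # Jacobian of the step x_t -> x_{t+1}, applied transposed
            M_next = torch.autograd.functional.jacobian(m, xs[t])
            lam = M_next.T @ lam + forcing
        else:
            lam = forcing
    # Add the gradient of the background term
    return lam + B_inv @ (x0 - x_b)

x_b = torch.tensor([0.8, -0.2], dtype=DTYPE)
x0 = torch.tensor([1.0, -0.5], dtype=DTYPE)
ys = torch.tensor([[0.5], [0.4], [0.3], [0.2]], dtype=DTYPE)
B_inv = torch.eye(2, dtype=DTYPE)
R_inv = 10.0 * torch.eye(1, dtype=DTYPE)

grad_adjoint = adjoint_gradient(x0, x_b, ys, B_inv, R_inv)

# Reference gradient from reverse-mode automatic differentiation
x0_ad = x0.clone().requires_grad_(True)
grad_autograd, = torch.autograd.grad(cost(x0_ad, x_b, ys, B_inv, R_inv), x0_ad)

print(torch.allclose(grad_adjoint, grad_autograd, atol=1e-8))
```

Reverse-mode automatic differentiation is itself an adjoint computation, which is why the two results should agree up to floating-point error.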
3.4. Wrapping Up: Algorithm
data:image/s3,"s3://crabby-images/4ad19/4ad19a9d7e68d55bd339bc512f8f06e86518c406" alt=""
4. Incremental 4D-Var Algorithm
4.1. Problem Formulation
If we define the increments of model states and observations:
data:image/s3,"s3://crabby-images/9a954/9a954008f8db9b9507e40a4e27756332a1a11797" alt=""
Then, the cost function in standard 4D-Var can be rewritten as:
data:image/s3,"s3://crabby-images/2412f/2412f669fbdc39b7bc2d2bd84acea547770cbe09" alt=""
4.2. Further Simplification: Tangent Linear Model
This brings us to what ECMWF refers to as "the heart of 4D-Var": linearized physics. This concept is based on the mathematical rule that, when the increment is small, a nonlinear function can be approximated by a tangent linear model, i.e., the Jacobian of the function:
data:image/s3,"s3://crabby-images/c0ddb/c0ddb7b80365f2f1c04418926149d76b9e0a593d" alt=""
In this way, the nonlinear optimization problem for x0 is transformed into a quadratic optimization problem for the increment of x0:
data:image/s3,"s3://crabby-images/d0bae/d0baecb32813544f5f2f69071c86133f760871cd" alt=""
A quadratic optimization problem is much easier to solve than a general nonlinear one, for example with conjugate gradient or quasi-Newton optimizers rather than plain gradient descent.
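The tangent linear approximation behind this quadratic problem can be checked numerically. The sketch below compares the exact change of a toy transition function with the change predicted by its Jacobian, for two increment sizes. The function `m` and the chosen state are assumptions for illustration only.

```python
import torch

def m(x):
    # Hypothetical nonlinear state transition model
    return 0.9 * x + 0.1 * torch.sin(x)

x = torch.tensor([1.0, -0.5], dtype=torch.float64)

# Tangent linear model: the Jacobian of m evaluated at the reference state x
M = torch.autograd.functional.jacobian(m, x)

direction = torch.tensor([1.0, 1.0], dtype=torch.float64)
scales = [1e-1, 1e-2]
errors = []
for scale in scales:
    dx = scale * direction
    exact = m(x + dx) - m(x)    # true response to the increment
    linear = M @ dx             # response predicted by the tangent linear model
    errors.append((exact - linear).norm().item())

for scale, err in zip(scales, errors):
    print(f"|dx| scale {scale:g}: linearization error {err:.2e}")
```

The linearization error shrinks roughly with the square of the increment size, which is why the approximation is only trusted for small increments and the reference trajectory is refreshed between linearizations.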
4.3. Wrapping Up: Algorithm
data:image/s3,"s3://crabby-images/d8678/d8678ce100f56a4a7db661c10a774c07da238a5f" alt=""
4.4. Incremental vs. Standard 4D-Var
Incremental 4D-Var breaks the nonlinear optimization problem into a sequence of quadratic ones. It does this with two nested loops: an outer loop that updates the state at the initial time step (x0), and an inner loop that derives the optimal increment of x0. Since the quadratic optimization in the inner loop reaches its global optimum faster than gradient descent does, the outer loop converges more quickly in incremental 4D-Var than in standard 4D-Var.
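The difference in convergence speed on a quadratic problem can be seen on a tiny example. The sketch below minimizes 0.5 xᵀAx − bᵀx with a textbook linear conjugate gradient and with fixed-step gradient descent. The matrix `A`, the vector `b`, and the learning rate are arbitrary assumptions, and the hand-written solver is only a stand-in for the optimizer used in the implementation section.

```python
import numpy as np

# Hypothetical symmetric positive-definite Hessian and linear term
A = np.array([[10.0, 1.0],
              [1.0, 1.0]])
b = np.array([1.0, 2.0])

def conjugate_gradient(A, b, n_iter):
    # Textbook linear CG for A x = b, started from zero
    x = np.zeros_like(b)
    r = b - A @ x          # residual, equal to minus the gradient
    p = r.copy()           # search direction
    for _ in range(n_iter):
        Ap = A @ p
        alpha = (r @ r) / (p @ Ap)
        x = x + alpha * p
        r_new = r - alpha * Ap
        if np.linalg.norm(r_new) < 1e-12:
            break
        beta = (r_new @ r_new) / (r @ r)
        p = r_new + beta * p
        r = r_new
    return x

def gradient_descent(A, b, n_iter, lr):
    # Fixed-step gradient descent on the same quadratic
    x = np.zeros_like(b)
    for _ in range(n_iter):
        x = x - lr * (A @ x - b)
    return x

x_star = np.linalg.solve(A, b)                      # exact minimizer
x_cg = conjugate_gradient(A, b, n_iter=2)           # two iterations for a 2-D problem
x_gd = gradient_descent(A, b, n_iter=2, lr=0.09)    # same iteration budget

print("CG error:", np.linalg.norm(x_cg - x_star))
print("GD error:", np.linalg.norm(x_gd - x_star))
```

In exact arithmetic, linear conjugate gradient reaches the minimizer of an n-dimensional positive-definite quadratic in at most n iterations, while gradient descent only approaches it gradually at a rate set by the conditioning of `A`.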
5. 4D-Var Implementation
First, import the necessary libraries. The entire pipeline is implemented using PyTorch, which simplifies the derivation of the Jacobian matrix and tangent linear model through its automatic gradient calculation.
import torch
import numpy as np
from typing import Dict, Tuple
import pandas as pd
from scipy.optimize import minimize
import seaborn as sns
from sklearn.metrics import root_mean_squared_error
5.1. Standard 4D-Var
In a standard 4D-Var update step, observations within a time window are used to derive the optimal model state at the start of the window (x0). The state x0 is then updated using gradient descent.
def standard_4dvar_update(x_background: torch.Tensor,
                          window_observations: torch.Tensor,
                          model: callable,
                          obs_operator: callable,
                          Q_inv: torch.Tensor,
                          R_inv: torch.Tensor,
                          window_model_input: Dict[str, torch.Tensor],
                          n_iter: int = 10,
                          lr: float = 0.01) -> torch.Tensor:
    '''
    :param x_background: (state_dim,), initial guess of state
    :param window_observations: (assimilation_window, y_dim), observations over the assimilation window
    :param model: state transition
    :param obs_operator: observation operator
    :param Q_inv: (state_dim, state_dim), inverse of the state transition noise covariance matrix
    :param R_inv: (y_dim, y_dim), inverse of the observation noise covariance matrix
    :param window_model_input: dict of (assimilation_window, input_dim), external input for the state transition model
    :param n_iter: number of iterations for the gradient descent
    :param lr: learning rate
    :return: (state_dim, ), updated state
    '''
    num_steps = len(window_observations)
    # Observation dimension, taken from the observations themselves
    y_dim = window_observations.shape[1]
    # Initialize the state
    x0 = x_background.clone()
    for _ in range(n_iter):
        # Forward integration (state propagation over the assimilation window)
        x_forward = torch.zeros((num_steps, x0.shape[0]), dtype=torch.float32)
        x_forward[0] = x0
        innovations = torch.zeros((num_steps, y_dim), dtype=torch.float32)
        for t in range(0, num_steps):
            if t > 0:
                step_model_input = {k: v[t] for k, v in window_model_input.items()}
                x_forward[t] = model(x_forward[t - 1], **step_model_input)
            innovations[t] = window_observations[t] - obs_operator(x_forward[t])
        # Backward integration to compute gradients of the cost w.r.t. initial state using adjoint
        x0_grad = nonlinear_delta(x_forward, innovations, x_background, model, obs_operator,
                                  Q_inv, R_inv, window_model_input)
        # Apply gradient descent to update x0
        x0 = x0 - lr * x0_grad
    return x0
In each gradient descent step, the total gradient of the cost function is calculated by propagating gradients, represented by Jacobian matrices, backward in time. This process is implemented in the function nonlinear_delta. You can refer to the algorithm section of this post (3.4) to understand the following code.
def nonlinear_delta(x_forward: torch.Tensor,
                    innovations: torch.Tensor,
                    x_background: torch.Tensor,
                    model: callable,
                    obs_operator: callable,
                    Q_inv: torch.Tensor,
                    R_inv: torch.Tensor,
                    window_model_input: Dict[str, torch.Tensor]) -> torch.Tensor:
    '''
    :param x_forward: (assimilation_window, state_dim), state trajectory over the assimilation window
    :param innovations: (assimilation_window, y_dim), innovations over the assimilation window
    :param x_background: (state_dim,), background state at the start of the window
    :param model: state transition
    :param obs_operator: observation operator
    :param Q_inv: (state_dim, state_dim), inverse of the state transition noise covariance matrix
    :param R_inv: (y_dim, y_dim), inverse of the observation noise covariance matrix
    :param window_model_input: dict of (assimilation_window, input_dim), external input for the state transition model
    :return: (state_dim, ), gradient of the cost function w.r.t. the initial state
    '''
    num_steps = len(innovations)
    # Initialize adjoint tensor with zeros
    adjoint = torch.zeros_like(x_forward)
    # Backward pass to compute gradients with respect to the initial state (x0)
    for t in range(num_steps - 1, -1, -1):  # Going backward in time
        # Gradients of the observation term (J_y)
        operator_tlm = torch.autograd.functional.jacobian(lambda x: obs_operator(x), x_forward[t])
        grad_obs = -operator_tlm.t() @ R_inv @ innovations[t]
        # Compute the adjoint at time t (this is the gradient w.r.t. the state at time t)
        adjoint[t] = grad_obs  # Adjoint at time t comes from the innovation term
        if t < num_steps - 1:
            # Backpropagate through the model: Jacobian of the step from t to t + 1,
            # which uses the external input at t + 1 (as in the forward integration)
            step_model_input = {k: v[t + 1] for k, v in window_model_input.items()}
            model_tlm = torch.autograd.functional.jacobian(lambda x: model(x, **step_model_input), x_forward[t])
            # Chain rule to propagate the adjoint backward
            adjoint[t] += model_tlm.t() @ adjoint[t + 1]
    J_b_grad = Q_inv @ (x_forward[0] - x_background)  # Gradient of the background term
    x0_grad = adjoint[0] + J_b_grad
    return x0_grad
5.2. Incremental 4D-Var
Similar to standard 4D-Var, incremental 4D-Var optimizes the starting state x0 using the observations within a time window. However, in each optimization step, incremental 4D-Var updates x0 by directly finding the optimal increment dx0 through quadratic optimization, which is a simpler problem compared to the nonlinear optimization in standard 4D-Var.
def incremental_4dvar_update(x_background: torch.Tensor,
                             window_observations: torch.Tensor,
                             model: callable,
                             obs_operator: callable,
                             Q_inv: torch.Tensor,
                             R_inv: torch.Tensor,
                             window_model_input: Dict[str, torch.Tensor],
                             n_iter: int = 10) -> torch.Tensor:
    '''
    :param x_background: (state_dim,), initial guess of state
    :param window_observations: (assimilation_window, y_dim), observations over the assimilation window
    :param model: state transition
    :param obs_operator: observation operator
    :param Q_inv: (state_dim, state_dim), inverse of the state transition noise covariance matrix
    :param R_inv: (y_dim, y_dim), inverse of the observation noise covariance matrix
    :param window_model_input: dict of (assimilation_window, input_dim), external input for the state transition model
    :param n_iter: number of iterations for the increment calculation
    :return: (state_dim, ), updated state
    '''
    num_steps = len(window_observations)
    # Observation dimension, taken from the observations themselves
    y_dim = window_observations.shape[1]
    # Initialize the state
    x0 = x_background.clone()
    for _ in range(n_iter):
        # Outer loop: forward integration (state propagation over the assimilation window)
        x_forward = torch.zeros((num_steps, x0.shape[0]), dtype=torch.float32)
        x_forward[0] = x0
        innovations = torch.zeros((num_steps, y_dim), dtype=torch.float32)
        for t in range(0, num_steps):
            if t > 0:
                step_model_input = {k: v[t] for k, v in window_model_input.items()}
                x_forward[t] = model(x_forward[t - 1], **step_model_input)
            innovations[t] = window_observations[t] - obs_operator(x_forward[t])
        # Inner loop: compute the optimized increment
        dx0 = linear_delta(x_forward, innovations, model, obs_operator, Q_inv, R_inv, window_model_input)
        # Update the initial state (x0)
        x0 = x0 + dx0
    return x0
In the inner loop function linear_delta, the tangent linear models (represented by Jacobian matrices along the trajectory starting from the reference state) make the cost function quadratic. This allows the cost function to be minimized by the conjugate gradient method from scipy.optimize.minimize.
def linear_delta(x_forward: torch.Tensor,
                 innovations: torch.Tensor,
                 model: callable,
                 obs_operator: callable,
                 Q_inv: torch.Tensor,
                 R_inv: torch.Tensor,
                 window_model_input: Dict[str, torch.Tensor]) -> torch.Tensor:
    '''
    :param x_forward: (assimilation_window, state_dim), state trajectory over the assimilation window
    :param innovations: (assimilation_window, y_dim), innovations over the assimilation window
    :param model: state transition
    :param obs_operator: observation operator
    :param Q_inv: (state_dim, state_dim), inverse of the state transition noise covariance matrix
    :param R_inv: (y_dim, y_dim), inverse of the observation noise covariance matrix
    :param window_model_input: dict of (assimilation_window, input_dim), external input for the state transition model
    :return: (state_dim, ), optimized increment of the initial state
    '''
    num_steps = len(innovations)
    state_dim = x_forward.shape[1]
    y_dim = innovations.shape[1]
    # Calculate the Tangent Linear Model (TLM) matrices
    model_tlm_stacks = torch.zeros((num_steps, state_dim, state_dim), dtype=torch.float32)
    operator_tlm_stacks = torch.zeros((num_steps, y_dim, state_dim), dtype=torch.float32)
    for t in range(num_steps):
        operator_tlm_stacks[t] = torch.autograd.functional.jacobian(lambda x: obs_operator(x), x_forward[t])
        if t < num_steps - 1:
            # The step from t to t + 1 uses the external input at t + 1,
            # matching the forward integration; the last slot stays unused
            step_model_input = {k: v[t + 1] for k, v in window_model_input.items()}
            model_tlm_stacks[t] = torch.autograd.functional.jacobian(
                lambda x: model(x, **step_model_input), x_forward[t])

    def cost_function(dx0, num_steps, model_tlm_stacks_np, operator_tlm_stacks_np, innovations_np, Q_inv_np, R_inv_np):
        # Background term of the quadratic cost
        J_b = 0.5 * dx0.T @ Q_inv_np @ dx0
        J_y = 0.
        # Propagate the perturbation using the TLM
        dxi = dx0.copy()
        for t in range(num_steps):
            if t > 0:
                dxi = model_tlm_stacks_np[t - 1] @ dxi
            dyi = operator_tlm_stacks_np[t] @ dxi
            J_y += 0.5 * (innovations_np[t] - dyi).T @ R_inv_np @ (innovations_np[t] - dyi)
        return J_b + J_y

    # First guess for the increment
    dx0_init = np.full(state_dim, 0.5)
    # Detach before converting, in case any tensor is attached to an autograd graph
    model_tlm_stacks_np = model_tlm_stacks.detach().numpy()
    operator_tlm_stacks_np = operator_tlm_stacks.detach().numpy()
    innovations_np = innovations.detach().numpy()
    Q_inv_np = Q_inv.detach().numpy()
    R_inv_np = R_inv.detach().numpy()
    res = minimize(
        cost_function, dx0_init,
        args=(num_steps, model_tlm_stacks_np, operator_tlm_stacks_np, innovations_np, Q_inv_np, R_inv_np),
        method='CG',
        options={'gtol': 1e-6}
    )
    dx0 = torch.tensor(res.x, dtype=torch.float32)
    return dx0
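Because the inner-loop cost is quadratic in dx0, its minimizer can also be obtained by solving a linear system, which gives a handy cross-check for an iterative optimizer when the state dimension is small. The sketch below is an alternative, not the approach used in this post. The helper names `quadratic_cost` and `solve_increment` and the random test matrices are assumptions made for illustration.

```python
import numpy as np

def quadratic_cost(dx0, model_tlms, operator_tlms, innovations, Q_inv, R_inv):
    # Same structure as the inner-loop cost: background term plus observation term
    J = 0.5 * dx0 @ Q_inv @ dx0
    dx = dx0.copy()
    for t in range(len(innovations)):
        if t > 0:
            dx = model_tlms[t - 1] @ dx
        residual = innovations[t] - operator_tlms[t] @ dx
        J += 0.5 * residual @ R_inv @ residual
    return J

def solve_increment(model_tlms, operator_tlms, innovations, Q_inv, R_inv):
    # Setting the gradient to zero gives
    # (Q_inv + sum_t G_t^T R_inv G_t) dx0 = sum_t G_t^T R_inv d_t,
    # where G_t = H_t M_{t-1} ... M_0 maps dx0 to the observation space at step t
    state_dim = Q_inv.shape[0]
    hessian = Q_inv.copy()
    rhs = np.zeros(state_dim)
    propagator = np.eye(state_dim)
    for t in range(len(innovations)):
        if t > 0:
            propagator = model_tlms[t - 1] @ propagator
        G_t = operator_tlms[t] @ propagator
        hessian += G_t.T @ R_inv @ G_t
        rhs += G_t.T @ R_inv @ innovations[t]
    return np.linalg.solve(hessian, rhs)

# Random placeholder inputs (not taken from the case study)
rng = np.random.default_rng(0)
T, n, p = 4, 3, 2
model_tlms = 0.5 * rng.normal(size=(T, n, n))
operator_tlms = rng.normal(size=(T, p, n))
innovations = rng.normal(size=(T, p))
Q_inv = np.eye(n)
R_inv = 4.0 * np.eye(p)

dx0_star = solve_increment(model_tlms, operator_tlms, innovations, Q_inv, R_inv)
J_star = quadratic_cost(dx0_star, model_tlms, operator_tlms, innovations, Q_inv, R_inv)
print(dx0_star, J_star)
```

Forming the Hessian explicitly is only practical for small state dimensions. In large systems the matrix is never built, which is one reason iterative methods such as conjugate gradient are preferred.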
5.3. Assimilation Loop
In real-world problems, data assimilation is performed in consecutive windows to produce seamless temporal model corrections. The following function is a simplified operational data assimilation module that runs 4D-Var (either standard or incremental) over a period and updates the states for all time steps.
def vad_4dvar_loop(x_background: torch.Tensor,
                   observations_y: torch.Tensor,
                   assimilation_window: int,
                   model: callable,
                   obs_operator: callable,
                   Q: torch.Tensor,
                   R: torch.Tensor,
                   model_input: Dict[str, torch.Tensor],
                   method: str = "standard",
                   n_iter: int = 10,
                   lr: float = 0.01,
                   ) -> torch.Tensor:
    '''
    :param x_background: (state_dim,), initial guess of state
    :param observations_y: (time_steps, y_dim), observations
    :param assimilation_window: number of time steps to assimilate in a batch
    :param model: state transition
    :param obs_operator: observation operator
    :param Q: (state_dim, state_dim), state transition noise covariance matrix
    :param R: (y_dim, y_dim), observation noise covariance matrix
    :param model_input: dict of (time_steps, input_dim), external input for the state transition model
    :param method: "standard" or "incremental"
    :param n_iter: number of iterations for the gradient descent (standard) or increment calculation (incremental)
    :param lr: learning rate for the gradient descent
    :return: (time_steps, state_dim), updated state trajectory
    '''
    x_assimilated = []
    Q_inv, R_inv = torch.inverse(Q), torch.inverse(R)
    for start in range(0, len(observations_y), assimilation_window):
        window_observations = observations_y[start:(start + assimilation_window)]
        window_model_input = {k: v[start:(start + assimilation_window)] for k, v in model_input.items()}
        # Initialize the background state for the current window
        if start > 0:
            step_model_input = {k: v[0] for k, v in window_model_input.items()}
            x_background = model(x_assimilated[-1], **step_model_input)
        if method == "standard":
            x0 = standard_4dvar_update(x_background, window_observations, model, obs_operator, Q_inv, R_inv, window_model_input, n_iter, lr)
        elif method == "incremental":
            x0 = incremental_4dvar_update(x_background, window_observations, model, obs_operator, Q_inv, R_inv, window_model_input, n_iter)
        else:
            raise ValueError(f"Unknown method: {method}")
        # Update simulation states
        x_assimilated.append(x0)
        # Propagate to the end of the current window; the last window may be
        # shorter than assimilation_window, so use its actual length
        for t in range(1, len(window_observations)):
            step_model_input = {k: v[t] for k, v in window_model_input.items()}
            x_assimilated.append(model(x_assimilated[-1], **step_model_input))
    return torch.stack(x_assimilated)
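The way the loop above cuts the time series into consecutive windows can be illustrated in isolation. The helper `window_bounds` below is hypothetical and only mirrors the slicing logic. It shows that the final window is shorter whenever the series length is not a multiple of the window size.

```python
from typing import List, Tuple

def window_bounds(n_steps: int, assimilation_window: int) -> List[Tuple[int, int]]:
    # (start, end) index pairs; end is exclusive and clipped to the series length
    return [(start, min(start + assimilation_window, n_steps))
            for start in range(0, n_steps, assimilation_window)]

bounds = window_bounds(10, 4)
print(bounds)  # → [(0, 4), (4, 8), (8, 10)]
```

Python slicing clips the end index silently, so any loop over the steps inside a window should rely on the actual length of the slice rather than on the nominal window size.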
6. Case Study
This post uses a simplified rainfall-runoff model as a case study. A description of the model can be found in the Problem Overview section of Addressing the Butterfly Effect: Data Assimilation Using Ensemble Kalman Filter, and the code for the model and data generation is available in pytorch-data-assimilation-tutorial. We assess the impact of assimilation on simulation accuracy by comparing the following:
- Simulated runoff (model output)
- Assimilated runoff (model output after assimilation)
- Observed runoff (ground truth)
The next two figures illustrate these comparisons. Both the standard and incremental 4D-Var algorithms produce model states and output variables that align much more closely with the ground truth than those from pure simulation.
data:image/s3,"s3://crabby-images/90920/90920167e41c0ec258f616ad7aa9f5def4f64980" alt="Trajectories of model states and outputs (true vs. simulated vs. assimilated) in the case study. The assimilation algorithm is standard 4D-Var. Source: by author."
data:image/s3,"s3://crabby-images/16dbb/16dbbe91333c2c56ba646b2317d0aa11add4de66" alt="Trajectories of model states and outputs (true vs. simulated vs. assimilated) in the case study. The assimilation algorithm is incremental 4D-Var. Source: by author."
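A visual comparison like the one above is often complemented by a scalar error metric such as the root mean squared error (RMSE) against the observed runoff. The sketch below shows the computation on placeholder arrays. The numbers are invented purely to demonstrate the formula and are not results from the case study.

```python
import numpy as np

def rmse(prediction, reference) -> float:
    # Root mean squared error between two equally shaped series
    prediction = np.asarray(prediction, dtype=float)
    reference = np.asarray(reference, dtype=float)
    return float(np.sqrt(np.mean((prediction - reference) ** 2)))

# Placeholder series (hypothetical values, for illustration only)
observed = np.array([1.0, 2.0, 3.0, 4.0])
simulated = observed + 1.0                                   # constant bias of 1
assimilated = observed + np.array([0.1, -0.1, 0.1, -0.1])    # small alternating error

rmse_simulated = rmse(simulated, observed)
rmse_assimilated = rmse(assimilated, observed)
print(rmse_simulated, rmse_assimilated)
```

With real data, `simulated` and `assimilated` would be the model outputs before and after assimilation, evaluated over the same time steps as the observations.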
7. Summary
In this post, you’ve learned about standard 4D-Var and its more efficient variant, incremental 4D-Var, key techniques for stabilizing weather forecasts, which are highly sensitive to the state of the weather system. The post introduced the mathematical foundation of 4D-Var, rooted in Bayesian inference and nonlinear optimization. It also showed how modern computational tools with automatic differentiation, such as PyTorch, simplify the algorithm’s implementation. In real-world operational weather forecasting systems, where data dimensions and volumes are large, integrating 4D-Var is far more complex, and many techniques remain to be learned.
8. Further Reading
- https://www.ecmwf.int/en/newsletter/175/earth-system-science/linearised-physics-heart-ecmwfs-4d-var
- https://www.ecmwf.int/sites/default/files/elibrary/2003/76079-variational-data-assimiltion-theory-and-overview_0.pdf
- https://www.ecmwf.int/en/about/media-centre/news/2022/25-years-4d-var-how-machine-learning-can-improve-use-observations
- https://www.ecmwf.int/sites/default/files/elibrary/2018/18542-tangent-linear-and-adjoint-models-data-assimilation.pdf