
Decoding the Hack behind Accurate Weather Forecasting: Variational Data Assimilation

Learn how variational data assimilation works, with the mathematical details and an efficient PyTorch implementation.

Source: https://unsplash.com/

1. Quick Start: Why Data Assimilation

Weather forecasting models are chaotic dynamical systems: small perturbations in the model states can make forecasts unstable, so trusting them blindly is risky. Yet current forecasting services, such as the European Centre for Medium-Range Weather Forecasts (ECMWF), achieve high accuracy from mid-range (15 days) to seasonal predictions. The hack behind these good forecasts is 4-dimensional variational data assimilation (4D-Var), used operationally at ECMWF since 1997. The algorithm incorporates real-time observations to improve forecasts. As the main technique for mitigating the butterfly effect – the high sensitivity to initial conditions – 4D-Var is also widely used in operational time-series forecasting systems in other fields.

A schematic of the data assimilation in ECMWF. Source: https://www.ecmwf.int/sites/default/files/2023-08/Fact sheet - Reanalysis -v3_1.pdf. Copyright © ECMWF, licensed under CC BY-SA 4.0.

This blog post introduces the mathematics behind 4D-Var and walks through its implementation with a simple example in PyTorch, a modern deep learning framework whose powerful features accelerate traditional data assimilation. By the end of this post, you will fully understand how 4D-Var works and be ready to apply it to your own problems. The post is structured as follows:

  • Quick Start: Why Data Assimilation
  • Prerequisites
  • Standard 4D-Var Algorithm
  • Incremental 4D-Var Algorithm
  • 4D-Var Implementation
  • Case Study
  • Summary

2. Prerequisites

The derivation of 4D-Var requires basic knowledge of vector calculus, Bayes’ theorem, and multivariate statistics. It is totally fine to play with the code without diving into the math. Experience with Python and PyTorch will help you understand the code. The code repository is available at pytorch-data-assimilation-tutorial.

3. Standard 4D-Var Algorithm

3.1. Problem Formulation from a Bayesian Perspective

Unlike sequential data assimilation methods, such as the Kalman Filter, which update model states at each time step, 4D-Var performs batch optimization. It updates all model states across several consecutive steps simultaneously. Suppose a simulation model consists of two functions: a state transition model m and an observation operator h:
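In one common state-space notation, with k indexing the time steps within the assimilation window:

$$
x_k = m(x_{k-1}), \qquad y_k = h(x_k)
$$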

Here, x denotes the model states and y the model outputs. Because each state depends on the previous one, optimizing a whole batch of states is equivalent to optimizing only the initial state, x0. Recall Bayes’ theorem:
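Written for the initial state x0 and the full window of observations y:

$$
p(x_0 \mid y) = \frac{p(y \mid x_0)\,p(x_0)}{p(y)} \;\propto\; p(y \mid x_0)\,p(x_0)
$$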

When using a Gaussian prior distribution p(x) and likelihood p(y|x), we have:
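Assuming the prior is centered on the background state x_b with covariance Q (matching how Q_inv is used in the code below), and the observation errors are independent across time steps with covariance R, a sketch of the two densities over a window of K steps is:

$$
p(x_0) \propto \exp\!\Big(-\tfrac{1}{2}(x_0 - x_b)^\top Q^{-1}(x_0 - x_b)\Big), \qquad
p(y \mid x_0) \propto \exp\!\Big(-\tfrac{1}{2}\sum_{k=0}^{K-1}\big(y_k - h(x_k)\big)^\top R^{-1}\big(y_k - h(x_k)\big)\Big)
$$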

The posterior probability can be log-transformed into a cost function J to be minimized:
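Dropping constant terms, one common way to write it is a background term J_b plus an observation term J_y:

$$
J(x_0) = \underbrace{\tfrac{1}{2}(x_0 - x_b)^\top Q^{-1}(x_0 - x_b)}_{J_b}
       + \underbrace{\tfrac{1}{2}\sum_{k=0}^{K-1}\big(y_k - h(x_k)\big)^\top R^{-1}\big(y_k - h(x_k)\big)}_{J_y},
\qquad x_k = m(x_{k-1})
$$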

3.2. Gradient Descent Optimization

The gradient of each term in the cost function is:
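Sketched in the notation above, with H_k the Jacobian of h at x_k (defined formally just below):

$$
\nabla_{x_0} J_b = Q^{-1}(x_0 - x_b), \qquad
\nabla_{x_0} J_y = -\sum_{k=0}^{K-1} \Big(\frac{\partial x_k}{\partial x_0}\Big)^{\!\top} H_k^\top R^{-1}\big(y_k - h(x_k)\big)
$$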

All derivatives can be expressed in the form of Jacobian matrices:
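With H_k and M_k denoting the Jacobians of the observation operator and the state transition model along the trajectory, the chain rule gives:

$$
H_k = \frac{\partial h}{\partial x}\bigg|_{x_k}, \qquad
M_k = \frac{\partial m}{\partial x}\bigg|_{x_{k-1}}, \qquad
\frac{\partial x_k}{\partial x_0} = M_k M_{k-1} \cdots M_1
$$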

3.3. Simplification Technique: Adjoint

It is much easier to compute the gradient term above recursively, leveraging the adjoint operator from linear algebra:
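Concretely, with innovations d_k = y_k - h(x_k) and adjoint variables λ_k, the backward recursion reads (this is what nonlinear_delta implements in Section 5):

$$
\lambda_{K-1} = -H_{K-1}^\top R^{-1} d_{K-1}, \qquad
\lambda_k = -H_k^\top R^{-1} d_k + M_{k+1}^\top \lambda_{k+1}
$$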

This recursion propagates information backward in time, allowing efficient computation of the gradient of the cost function:
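In the notation above, the result of the backward sweep is simply:

$$
\nabla_{x_0} J = Q^{-1}(x_0 - x_b) + \lambda_0
$$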

3.4. Wrapping Up: Algorithm

4. Incremental 4D-Var Algorithm

4.1. Problem Formulation

If we define the increments of model states and observations:
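Using x^r for the reference trajectory started from the background state (one common convention):

$$
\delta x_k = x_k - x_k^{r}, \qquad
\delta y_k = h(x_k) - h\big(x_k^{r}\big), \qquad
d_k = y_k - h\big(x_k^{r}\big)
$$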

Then, the cost function in standard 4D-Var can be rewritten as:
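Substituting these definitions, one common way to write the cost is (with the reference trajectory started from the background, so that δx_0 = x_0 - x_b):

$$
J(\delta x_0) = \tfrac{1}{2}\,\delta x_0^\top Q^{-1}\,\delta x_0
            + \tfrac{1}{2}\sum_{k=0}^{K-1}\big(d_k - \delta y_k\big)^\top R^{-1}\big(d_k - \delta y_k\big)
$$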

4.2. Further Simplification: Tangent Linear Model

This brings us to what ECMWF refers to as "the heart of 4D-Var": linearized physics. The concept rests on the fact that, for a small increment, a nonlinear function can be approximated by its tangent linear model, i.e., the Jacobian of the function:
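Expanding m and h to first order around the reference trajectory (M_k and H_k are the Jacobians introduced earlier, now evaluated along the reference trajectory):

$$
\delta x_k \approx M_k\,\delta x_{k-1}, \qquad \delta y_k \approx H_k\,\delta x_k
$$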

In this way, the nonlinear optimization problem for x0 is transformed into a quadratic optimization problem for its increment dx0:
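Chaining the tangent linear models back to the initial increment gives, roughly (read M_k⋯M_1 as the identity for k = 0):

$$
J(\delta x_0) \approx \tfrac{1}{2}\,\delta x_0^\top Q^{-1}\,\delta x_0
            + \tfrac{1}{2}\sum_{k=0}^{K-1}\big(d_k - H_k M_k \cdots M_1\,\delta x_0\big)^\top R^{-1}\big(d_k - H_k M_k \cdots M_1\,\delta x_0\big)
$$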

Compared with plain gradient descent, a quadratic optimization problem is much easier to solve with methods such as conjugate gradient or quasi-Newton optimizers.

4.3. Wrapping Up: Algorithm

4.4. Incremental vs. Standard 4D-Var

Incremental 4D-Var breaks the nonlinear optimization problem into multiple linear ones. It does this by dividing the process into two separate loops: an outer loop that updates the state at the initial time step (x0), and an inner loop that derives the optimal increment of x0. Since the quadratic optimization in the inner loop reaches its global optimum faster than gradient descent does, the outer loop of incremental 4D-Var converges more quickly than standard 4D-Var.

5. 4D-Var Implementation

First, import the necessary libraries. The entire pipeline is implemented using PyTorch, which simplifies the derivation of the Jacobian matrix and tangent linear model through its automatic gradient calculation.

import torch
import numpy as np
from typing import Dict, Tuple
import pandas as pd
from scipy.optimize import minimize
import seaborn as sns
from sklearn.metrics import root_mean_squared_error

5.1. Standard 4D-Var

In a standard 4D-Var update step, observations within a time window are used to derive the optimal model state at the start of the window (x0). The state x0 is then updated using gradient descent.

def standard_4dvar_update(x_background: torch.Tensor,
                          window_observations: torch.Tensor,
                          model: callable,
                          obs_operator: callable,
                          Q_inv: torch.Tensor,
                          R_inv: torch.Tensor,
                          window_model_input: Dict[str, torch.Tensor],
                          n_iter: int = 10,
                          lr: float = 0.01) -> torch.Tensor:
    '''
    :param x_background: (state_dim,), initial guess of state
    :param window_observations: (assimilation_window, y_dim), observations over the assimilation window
    :param model: state transition
    :param obs_operator: observation operator
    :param Q_inv: (state_dim, state_dim), inverse of the state transition noise covariance matrix
    :param R_inv: (y_dim, y_dim), inverse of the observation noise covariance matrix
    :param window_model_input: dict of (assimilation_window, input_dim), external input for the state transition model
    :param n_iter: number of iterations for the gradient descent
    :param lr: learning rate
    :return: (state_dim, ), updated state
    '''
    num_steps, y_dim = window_observations.shape

    # Initialize the state
    x0 = x_background.clone()

    for _ in range(n_iter):
        # Forward integration (state propagation over the assimilation window)
        x_forward = torch.zeros((num_steps, x0.shape[0]), dtype=torch.float32)
        x_forward[0] = x0
        innovations = torch.zeros((num_steps, y_dim), dtype=torch.float32)
        for t in range(0, num_steps):
            if t > 0:
                step_model_input = {k: v[t] for k, v in window_model_input.items()}
                x_forward[t] = model(x_forward[t - 1], **step_model_input)
            innovations[t] = window_observations[t] - obs_operator(x_forward[t])

        # Backward integration to compute gradients of the cost w.r.t. initial state using adjoint
        x0_grad = nonlinear_delta(x_forward, innovations, x_background, model, obs_operator, Q_inv, R_inv, window_model_input)

        # Apply gradient descent to update x0
        x0 = x0 - lr * x0_grad

    return x0

In each gradient descent step, the total gradient of the cost function is calculated by propagating gradients, represented by Jacobian matrices, backward in time. This process is implemented in the function nonlinear_delta. You can refer to the algorithm section in this post to understand the following code.

def nonlinear_delta(x_forward: torch.Tensor,
                    innovations: torch.Tensor,
                    x_background: torch.Tensor,
                    model: callable,
                    obs_operator: callable,
                    Q_inv: torch.Tensor,
                    R_inv: torch.Tensor,
                    window_model_input: Dict[str, torch.Tensor]) -> torch.Tensor:
    '''
    :param x_forward: (assimilation_window, state_dim), state trajectory over the assimilation window
    :param innovations: (assimilation_window, y_dim), innovations over the assimilation window
    :param x_background: (state_dim,), background (initial guess) state used in the background term
    :param model: state transition
    :param obs_operator: observation operator
    :param Q_inv: (state_dim, state_dim), inverse of the state transition noise covariance matrix
    :param R_inv: (y_dim, y_dim), inverse of the observation noise covariance matrix
    :param window_model_input: dict of (assimilation_window, input_dim), external input for the state transition model
    :return: (state_dim, ), gradient of the cost function w.r.t. the initial state
    '''
    num_steps = len(innovations)

    # Initialize adjoint tensor with zeros
    adjoint = torch.zeros_like(x_forward)

    # Backward pass to compute gradients with respect to the initial state (x0)
    for t in range(num_steps - 1, -1, -1):  # Going backward in time

        # Gradients of the observation term (J_y)
        operator_tlm = torch.autograd.functional.jacobian(lambda x: obs_operator(x), x_forward[t])
        grad_obs = -operator_tlm.t() @ R_inv @ innovations[t]

        # Compute the adjoint at time t (this is the gradient w.r.t. the state at time t)
        adjoint[t] = grad_obs  # Adjoint at time t comes from the innovation term

        if t < num_steps - 1:
            # Backpropagate through the model: compute the adjoint of the model at time t
            step_model_input = {k: v[t + 1] for k, v in window_model_input.items()}
            model_tlm = torch.autograd.functional.jacobian(lambda x: model(x, **step_model_input), x_forward[t])

            # Chain rule to propagate the adjoint backward
            adjoint[t] += model_tlm.t() @ adjoint[t + 1]

    J_b_grad = Q_inv @ (x_forward[0] - x_background)  # Gradient of the background term
    x0_grad = adjoint[0] + J_b_grad
    return x0_grad

5.2. Incremental 4D-Var

Similar to standard 4D-Var, incremental 4D-Var optimizes the starting state x0 using the observations within a time window. However, in each optimization step, incremental 4D-Var updates x0 by directly finding the optimal increment dx0 through quadratic optimization, which is a simpler problem compared to the nonlinear optimization in standard 4D-Var.

def incremental_4dvar_update(x_background: torch.Tensor,
                             window_observations: torch.Tensor,
                             model: callable,
                             obs_operator: callable,
                             Q_inv: torch.Tensor,
                             R_inv: torch.Tensor,
                             window_model_input: Dict[str, torch.Tensor],
                             n_iter: int = 10) -> torch.Tensor:
    '''
    :param x_background: (state_dim,), initial guess of state
    :param window_observations: (assimilation_window, y_dim), observations over the assimilation window
    :param model: state transition
    :param obs_operator: observation operator
    :param Q_inv: (state_dim, state_dim), inverse of the state transition noise covariance matrix
    :param R_inv: (y_dim, y_dim), inverse of the observation noise covariance matrix
    :param window_model_input: dict of (assimilation_window, input_dim), external input for the state transition model
    :param n_iter: number of iterations for the increment calculation
    :return: (state_dim, ), updated state
    '''
    num_steps, y_dim = window_observations.shape

    # Initialize the state
    x0 = x_background.clone()

    for _ in range(n_iter):
        # Outer loop: forward integration (state propagation over the assimilation window)
        x_forward = torch.zeros((num_steps, x0.shape[0]), dtype=torch.float32)
        x_forward[0] = x0
        innovations = torch.zeros((num_steps, y_dim), dtype=torch.float32)
        for t in range(0, num_steps):
            if t > 0:
                step_model_input = {k: v[t] for k, v in window_model_input.items()}
                x_forward[t] = model(x_forward[t - 1], **step_model_input)
            innovations[t] = window_observations[t] - obs_operator(x_forward[t])

        # Compute the optimized increment
        dx0 = linear_delta(x_forward, innovations, model, obs_operator, Q_inv, R_inv, window_model_input)

        # Update the initial state (x0)
        x0 = x0 + dx0

    return x0

In the inner-loop function linear_delta, the tangent linear models (Jacobian matrices evaluated along the trajectory that starts from the reference state) make the cost function quadratic, so it can be minimized with the conjugate gradient method via scipy.optimize.minimize.

def linear_delta(x_forward: torch.Tensor,
                 innovations: torch.Tensor,
                 model: callable,
                 obs_operator: callable,
                 Q_inv: torch.Tensor,
                 R_inv: torch.Tensor,
                 window_model_input: Dict[str, torch.Tensor]) -> torch.Tensor:
    '''
    :param x_forward: (assimilation_window, state_dim), state trajectory over the assimilation window
    :param innovations: (assimilation_window, y_dim), innovations over the assimilation window
    :param model: state transition
    :param obs_operator: observation operator
    :param Q_inv: (state_dim, state_dim), inverse of the state transition noise covariance matrix
    :param R_inv: (y_dim, y_dim), inverse of the observation noise covariance matrix
    :param window_model_input: dict of (assimilation_window, input_dim), external input for the state transition model
    :return: (state_dim, ), optimized increment of the initial state
    '''
    num_steps, y_dim = innovations.shape
    state_dim = x_forward.shape[1]

    # Calculate the Tangent Linear Model (TLM) matrices along the reference trajectory
    model_tlm_stacks = torch.zeros((num_steps, state_dim, state_dim), dtype=torch.float32)
    operator_tlm_stacks = torch.zeros((num_steps, y_dim, state_dim), dtype=torch.float32)
    for t in range(num_steps):
        operator_tlm_stacks[t] = torch.autograd.functional.jacobian(lambda x: obs_operator(x), x_forward[t])
        if t < num_steps - 1:
            # TLM of the step x_t -> x_{t+1}: linearize the model at x_t with the input of step t + 1,
            # matching the forward integration x_t = model(x_{t-1}, input_t); the last slot stays unused
            step_model_input = {k: v[t + 1] for k, v in window_model_input.items()}
            model_tlm_stacks[t] = torch.autograd.functional.jacobian(lambda x: model(x, **step_model_input), x_forward[t])

    def cost_function(dx0, num_steps, model_tlm_stacks_np, operator_tlm_stacks_np, innovations_np, Q_inv_np, R_inv_np):

        # Propagate the perturbation using the TLM and accumulate the quadratic cost
        J_b = 0.5 * dx0.T @ Q_inv_np @ dx0
        J_y = 0.
        dxi = dx0.copy()
        for t in range(num_steps):
            if t > 0:
                dxi = model_tlm_stacks_np[t - 1] @ dxi
            dyi = operator_tlm_stacks_np[t] @ dxi
            J_y += 0.5 * (innovations_np[t] - dyi).T @ R_inv_np @ (innovations_np[t] - dyi)
        return J_b + J_y

    dx0_init = np.full(state_dim, 0.5, dtype=np.float32)  # initial guess for the increment
    model_tlm_stacks_np = model_tlm_stacks.numpy()
    operator_tlm_stacks_np = operator_tlm_stacks.numpy()
    innovations_np = innovations.numpy()
    Q_inv_np = Q_inv.numpy()
    R_inv_np = R_inv.numpy()

    res = minimize(
        cost_function, dx0_init,
        args=(num_steps, model_tlm_stacks_np, operator_tlm_stacks_np, innovations_np, Q_inv_np, R_inv_np),
        method='CG',
        options={'gtol': 1e-6}
    )
    dx0 = torch.tensor(res.x, dtype=torch.float32)

    return dx0

5.3. Assimilation Loop

In real-world problems, data assimilation is performed in consecutive windows to produce seamless temporal model corrections. The following function is a simplified operational data assimilation module that runs 4D-Var (either standard or incremental) over a period and updates the states for all time steps.

def vad_4dvar_loop(x_background: torch.Tensor,
                   observations_y: torch.Tensor,
                   assimilation_window: int,
                   model: callable,
                   obs_operator: callable,
                   Q: torch.Tensor,
                   R: torch.Tensor,
                   model_input: Dict[str, torch.Tensor],
                   method: str = "standard",
                   n_iter: int = 10,
                   lr: float = 0.01,
                   ) -> torch.Tensor:
    '''
    :param x_background: (state_dim,), initial guess of state
    :param observations_y: (time_steps, y_dim), observations
    :param assimilation_window: number of time steps to assimilate in a batch
    :param model: state transition
    :param obs_operator: observation operator
    :param Q: (state_dim, state_dim), state transition noise covariance matrix
    :param R: (y_dim, y_dim), observation noise covariance matrix
    :param model_input: dict of (time_steps, input_dim), external input for the state transition model
    :param method: "standard" or "incremental"
    :param n_iter: number of iterations for the gradient descent (standard) or increment calculation (incremental)
    :param lr: learning rate for the gradient descent
    :return: (time_steps, state_dim), updated state trajectory
    '''
    x_assimilated = []
    Q_inv, R_inv = torch.inverse(Q), torch.inverse(R)

    for start in range(0, len(observations_y), assimilation_window):
        window_observations = observations_y[start:(start + assimilation_window)]
        window_model_input = {k: v[start:(start + assimilation_window)] for k, v in model_input.items()}

        # Initialize the background state for the current window
        if start > 0:
            step_model_input = {k: v[0] for k, v in window_model_input.items()}
            x_background = model(x_assimilated[-1], **step_model_input)

        if method == "standard":
            x0 = standard_4dvar_update(x_background, window_observations, model, obs_operator, Q_inv, R_inv, window_model_input, n_iter, lr)
        elif method == "incremental":
            x0 = incremental_4dvar_update(x_background, window_observations, model, obs_operator, Q_inv, R_inv, window_model_input, n_iter)
        else:
            raise ValueError(f"Unknown method: {method}")

        # Update simulation states
        x_assimilated.append(x0)
        for t in range(1, len(window_observations)):  # propagate through the rest of the window
            step_model_input = {k: v[t] for k, v in window_model_input.items()}
            x_assimilated.append(model(x_assimilated[-1], **step_model_input))

    return torch.stack(x_assimilated)
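
Before moving to the case study, here is one way to exercise this loop end to end. Everything below is an illustrative toy setup (a small damped linear model observed directly), not the rainfall-runoff model used in the next section; the names toy_model and toy_obs_operator and all the numbers are assumptions chosen just for the demonstration.

# Hypothetical toy setup: a damped 2-D linear system observed through an identity operator
torch.manual_seed(0)

state_dim, y_dim, time_steps = 2, 2, 40
A = torch.tensor([[0.95, 0.10], [-0.10, 0.95]])

def toy_model(x):
    # State transition m: one linear step
    return A @ x

def toy_obs_operator(x):
    # Observation operator h: observe the state directly
    return x

# Generate a "true" trajectory and noisy observations from it
x_true = torch.zeros((time_steps, state_dim))
x_true[0] = torch.tensor([1.0, 0.0])
for t in range(1, time_steps):
    x_true[t] = toy_model(x_true[t - 1])
observations = x_true + 0.05 * torch.randn(time_steps, y_dim)

Q = 0.1 * torch.eye(state_dim)   # background-error covariance (Q in this post's notation)
R = 0.05 * torch.eye(y_dim)      # observation-error covariance

x_assimilated = vad_4dvar_loop(
    x_background=torch.tensor([0.5, -0.5]),  # deliberately poor initial guess
    observations_y=observations,
    assimilation_window=10,
    model=toy_model,
    obs_operator=toy_obs_operator,
    Q=Q,
    R=R,
    model_input={},                          # no external forcing in this toy example
    method="incremental",
    n_iter=5,
)
print(x_assimilated.shape)                   # torch.Size([40, 2])

Switching method to "standard" runs the gradient-descent variant through the same interface.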

6. Case Study

This post uses a simplified rainfall-runoff model as a case study. A description of the model can be found in the Problem Overview section of Addressing the Butterfly Effect: Data Assimilation Using Ensemble Kalman Filter, and the code for the model and data generation is available in pytorch-data-assimilation-tutorial. We assess the impact of assimilation on simulation accuracy by comparing the following:

  • Simulated runoff (model output)
  • Assimilated runoff (model output after assimilation)
  • Observed runoff (ground truth)

The next two figures illustrate these comparisons. Both the standard and incremental 4D-Var algorithms produce model states and output variables that align much more closely with the ground truth than those from pure simulation.

Trajectories of model states and outputs (true vs. simulated vs. assimilated) in the case study. The assimilation algorithm is standard 4D-Var. Source: by author.
Trajectories of model states and outputs (true vs. simulated vs. assimilated) in the case study. The assimilation algorithm is incremental 4D-Var. Source: by author.

7. Summary

In this post, you’ve learned about standard 4D-Var and its more efficient variant, incremental 4D-Var – key techniques for stabilizing weather forecasts, which are highly sensitive to the weather system’s state. The post introduced the mathematical foundation of 4D-Var, rooted in Bayesian inference and nonlinear optimization, and showed how modern computational tools with automatic differentiation, such as PyTorch, simplify the algorithm’s implementation. In real-world operational weather forecasting systems, where data dimensions and volumes are huge, integrating 4D-Var is far more complex, and many more techniques remain to be learned.



Follow me on Medium for more updates on how modern data science tools are revolutionizing traditional models in science and industry.

