Graphcast: How to Get Things Done

A guide on how to make predictions using Google’s latest tool, from fetching data to formatting and so much more.

Abhinav Kumar
Towards Data Science



Weather prediction is a very complex problem. Numerical Weather Prediction (NWP) models, such as the Weather Research and Forecasting (WRF) model, have been used to tackle it; however, their accuracy and precision are sometimes found lacking.

Being the complex problem it is, it has attracted interest from everyone, from data scientists and data science enthusiasts to meteorological engineers. Solutions have been found; consistency and uniformity, however, have not. The solution varies from area to area, from mountain to plateau, from swamp to tundra. From my own personal experience, and I am sure from others’ experiences too, weather prediction has proven to be a tough cookie to crack. Quoting a certain shrimp billionaire:

It is like a box of chocolates, you never know what you’re gonna get.

Recently, DeepMind released a new tool, Graphcast, an AI model for faster and more accurate global weather forecasting, taking a shot at making this particular box of chocolates tastier and more efficient. On a Google TPU v4 machine, Graphcast can produce 10-day predictions at a 0.25 degree spatial resolution in less than a minute. It solves a lot of issues one might face when predicting with conventional methods:

  • predictions are generated for all coordinates at once,
  • there is no longer any need to tailor the logic to each coordinate,
  • mind-boggling efficiency and response time.

What isn’t so mind-boggling is the data preparation required to fetch predictions using the aforementioned tool.


However, worry not: I shall be your knight in dark and gloomy armor and, in this article, explain the steps required to prepare and format the data and, finally, fetch predictions using Graphcast.

Note: The usage of the word “AI” nowadays reminds me very much of how “quantum” is used in Marvel movies.

The process of getting predictions can be divided into the following steps:

  1. Fetching the input data.
  2. Creating the targets.
  3. Creating the forcing data.
  4. Processing and formatting the data into a suitable format.
  5. Bringing them all together and making predictions.

Graphcast states that, using the current weather data and the data from 6 hours earlier, one can make predictions 6 hours into the future. To put it simply, with an example:

  • if predictions are required for: 2024-01-01 18:00,
  • then input data to be put forth: 2024-01-01 12:00 & 2024-01-01 06:00.

It is important to note that 2024-01-01 18:00 will be the first prediction fetched. Graphcast can go on to predict up to 10 days ahead, with a 6-hour gap between consecutive predictions. So, the other timestamps for which predictions can be fetched are:

  • 2024-01-02 00:00, 06:00, 12:00, 18:00,
  • 2024-01-03 00:00, 06:00 and so on, until
  • 2024-01-11 06:00, 12:00.

To summarize, data for 40 timestamps can be predicted using the input of two timestamps.
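To make the schedule concrete, here is a minimal standard-library sketch that derives the two input timestamps and all 40 prediction timestamps from the first prediction time, using the 6-hour step described above. predictionSchedule is a hypothetical helper written for illustration; it is not part of Graphcast.

```python
import datetime

GAP = datetime.timedelta(hours=6)  # Step between consecutive Graphcast predictions.

def predictionSchedule(first_prediction: datetime.datetime, steps: int = 40):
    """Return (input_times, prediction_times) for a hypothetical rollout."""
    # The two inputs are 12 and 6 hours before the first prediction.
    input_times = [first_prediction - 2 * GAP, first_prediction - GAP]
    # Each prediction lies 6 hours after the previous one.
    prediction_times = [first_prediction + step * GAP for step in range(steps)]
    return input_times, prediction_times

inputs, predictions = predictionSchedule(datetime.datetime(2024, 1, 1, 18, 0))
print([t.isoformat(sep=' ') for t in inputs])
print(predictions[0], predictions[-1], len(predictions))
```

Counting the last timestamp this way (first prediction plus 39 steps of 6 hours) is an easy way to double-check the end of the 10-day window.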

Assumptions and important parameters

For the code I will present in this article, I have assigned the following values to certain parameters that dictate how quickly you get the predictions and how much memory is used.

  • Input timestamps: 2024-01-01 06:00, 12:00.
  • First prediction timestamp: 2024-01-01 18:00.
  • Number of predictions: 4.
  • Spatial resolution: 1 degree.
  • Pressure levels: 13.
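The spatial resolution and the number of pressure levels largely determine how much data flows through the pipeline. As a rough back-of-the-envelope sketch, assuming a regular global latitude/longitude grid that includes both poles (as the ERA5 grids requested here do), the snippet below counts grid points and the approximate number of rows in a pressure-level targets table at 1 degree, 13 levels and 4 prediction steps. gridPoints is a hypothetical helper.

```python
def gridPoints(resolution_deg: float) -> tuple[int, int]:
    """Number of latitudes (pole to pole, inclusive) and longitudes on a regular grid."""
    n_lat = int(round(180 / resolution_deg)) + 1  # -90 .. 90 inclusive
    n_lon = int(round(360 / resolution_deg))      # 0 .. (360 - resolution)
    return n_lat, n_lon

for res in (1.0, 0.25):
    n_lat, n_lon = gridPoints(res)
    print(f'{res} deg: {n_lat} x {n_lon} = {n_lat * n_lon} grid points')

# Rows in a targets table: every grid point x 13 levels x 4 prediction steps.
n_lat, n_lon = gridPoints(1.0)
target_rows = n_lat * n_lon * 13 * 4
print(target_rows)
```

Moving from 1 degree to 0.25 degrees multiplies the number of grid points by roughly 16, which is why the coarser resolution is used in this article.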

Below is the code for importing the required packages, initializing the lists of fields required for the inputs and predictions, and defining other variables that will come in handy.

import cdsapi
import datetime
import functools
from graphcast import autoregressive, casting, checkpoint, data_utils as du, graphcast, normalization, rollout
import haiku as hk
import isodate
import jax
import math
import numpy as np
import pandas as pd
from pysolar.radiation import get_radiation_direct
from pysolar.solar import get_altitude
import pytz
import scipy
from typing import Dict
import xarray

client = cdsapi.Client() # Making a connection to CDS, to fetch data.

# The fields to be fetched from the single-level source.
singlelevelfields = [
'10m_u_component_of_wind',
'10m_v_component_of_wind',
'2m_temperature',
'geopotential',
'land_sea_mask',
'mean_sea_level_pressure',
'toa_incident_solar_radiation',
'total_precipitation'
]

# The fields to be fetched from the pressure-level source.
pressurelevelfields = [
'u_component_of_wind',
'v_component_of_wind',
'geopotential',
'specific_humidity',
'temperature',
'vertical_velocity'
]

# The 13 pressure levels.
pressure_levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]

# Initializing other required constants.
pi = math.pi
gap = 6 # There is a gap of 6 hours between each graphcast prediction.
predictions_steps = 4 # Predicting for 4 timestamps.
watts_to_joules = 3600
first_prediction = datetime.datetime(2024, 1, 1, 18, 0) # Timestamp of the first prediction.
lat_range = range(-180, 181, 1) # Latitude range.
lon_range = range(0, 360, 1) # Longitude range.

# A utility function used for ease of coding.
# Converting the variable to a datetime object.
def toDatetime(dt) -> datetime.datetime:
    # datetime.datetime (and pandas.Timestamp, a subclass of it) is returned as-is.
    if isinstance(dt, datetime.datetime):
        return dt

    # A plain date becomes midnight of that day.
    elif isinstance(dt, datetime.date):
        return datetime.datetime.combine(dt, datetime.datetime.min.time())

    # ISO 8601 strings, with or without a time part.
    elif isinstance(dt, str):
        if 'T' in dt:
            return isodate.parse_datetime(dt)
        else:
            return datetime.datetime.combine(isodate.parse_date(dt), datetime.datetime.min.time())

    raise TypeError(f'Cannot convert {type(dt).__name__} to datetime.datetime.')

Inputs

When it comes to machine learning, in order to get predictions, you have to give the model some data, from which it produces a prediction. For example, when predicting whether a person is Batman, the input data might be:

  • How much sleep do they get?
  • Do they have a tan line on their face?
  • Do they sleep during early morning meetings?
  • How much is their net worth?

Similarly, Graphcast takes certain inputs, which we obtain from the Copernicus Climate Data Store (CDS) using its Python library, cdsapi. Currently, the data publisher uses the Creative Commons Attribution 4.0 License, which means that anyone can copy, distribute, transmit, and adapt the work as long as the original author is given credit.

However, authentication is required before making requests with cdsapi; the instructions for it are provided by CDS and are pretty straightforward.

Assuming you are now CDS-approved, inputs can be created, which involves the following steps:

  1. Getting the single-level values: These depend on the coordinates and time. One of the input fields required is total_precipitation_6hr. As the name suggests, it is the precipitation accumulated over the 6 hours up to that particular timestamp. Hence, instead of getting the values for just the two input timestamps, we have to get values for a range of timestamps, in our case from 2024-01-01 00:00 to 12:00.
  2. Getting the pressure-level values: In addition to the coordinates and time, these also depend on the pressure level. Hence, when requesting data, we specify the pressure levels we need the data for. In this case, we get values for the two input timestamps only.
  3. Merging the single and pressure values: An inner-merge operation is carried out on the aforementioned data on the basis of time, latitude and longitude.
  4. Integrating year and day progress: In addition to the single and pressure fields, four more fields need to be added to the input data: year_progress_sin, year_progress_cos, day_progress_sin and day_progress_cos. This can be done using functions provided by the graphcast package.
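The 6-hour accumulation in step 1 can be illustrated on toy data. Each ERA5 hourly total_precipitation value covers the preceding hour, so summing a 6-row window per grid point gives the 6 hours ending at each timestamp. The sketch below uses made-up values at two hypothetical grid points and writes the grouped rolling sum with transform(), which keeps the original index, instead of the rolling().reset_index() chain used in the code that follows; both are meant to compute the same quantity.

```python
import pandas as pd

# Hypothetical hourly precipitation at two grid points, 00:00 to 12:00.
# Values are 0, 1, ..., 12 at each point so the window sums are easy to check.
times = pd.date_range('2024-01-01 00:00', periods=13, freq=pd.offsets.Hour(1))
index = pd.MultiIndex.from_product(
    [[10.0, 20.0], [0.0], times], names=['latitude', 'longitude', 'time'])
df = pd.DataFrame({'total_precipitation': [float(h) for h in range(13)] * 2}, index=index)

# Sum the current hour and the 5 before it, separately for each grid point.
# transform() returns a result aligned with df's original index.
df['total_precipitation_6hr'] = (
    df.groupby(level=['latitude', 'longitude'])['total_precipitation']
      .transform(lambda s: s.rolling(window=6, min_periods=1).sum())
)

# The 6-hour totals at one grid point, indexed by time.
print(df.loc[(10.0, 0.0), 'total_precipitation_6hr'])
```

With min_periods=1 the earliest timestamps still get a (partial) sum, which is why the request starts a few hours before the first input timestamp.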

Other small steps include:

  • Renaming the columns after they are fetched from CDS, because CDS outputs short names for the weather variables.
  • Renaming the geopotential variable to geopotential_at_surface for the single-level data, since the pressure-level data has a field with the same name.
  • Using math functions to calculate the sin and cos values after the progress value is obtained from graphcast.
  • Renaming latitude to lat and longitude to lon, and introducing another index, batch, which is assigned the value 0.
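The year and day progress features are angles on a circle. As a hedged, self-contained sketch of the idea (not the graphcast implementation, whose exact constants may differ), the functions below compute the fraction of the year and of the local solar day elapsed at a given UTC time, then encode each fraction as sin/cos so that, for example, 23:59 and 00:00 end up close together. The 365.24219-day year length and the helper names yearProgress, dayProgress and encode are assumptions.

```python
import datetime
import math

SECONDS_PER_DAY = 86400
DAYS_PER_YEAR = 365.24219  # Assumed average year length.

def yearProgress(seconds_since_epoch: float) -> float:
    # Fraction of the (average) year elapsed, in [0, 1).
    return (seconds_since_epoch / SECONDS_PER_DAY / DAYS_PER_YEAR) % 1.0

def dayProgress(seconds_since_epoch: float, longitude_deg: float) -> float:
    # Fraction of the UTC day elapsed, shifted by longitude (360 degrees = 1 day).
    greenwich = (seconds_since_epoch % SECONDS_PER_DAY) / SECONDS_PER_DAY
    return (greenwich + longitude_deg / 360.0) % 1.0

def encode(progress: float) -> tuple[float, float]:
    # Map a fraction in [0, 1) to a point on the unit circle.
    return math.sin(2 * math.pi * progress), math.cos(2 * math.pi * progress)

t = datetime.datetime(2024, 1, 1, 12, 0, tzinfo=datetime.timezone.utc).timestamp()
print(dayProgress(t, 0.0))   # Noon at Greenwich.
print(dayProgress(t, 90.0))  # 90 degrees east: 6 hours later in local solar time.
print(encode(dayProgress(t, 0.0)))
```

Because day progress depends on longitude, it has to be computed per longitude and per timestamp, whereas year progress depends on the timestamp alone.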

The code for creating the input data is as follows.

# ERA5 short names returned by CDS, mapped to the full field names.
# Mapping by name avoids relying on the order of the columns in the file.
singlelevelnames = {
    'u10': '10m_u_component_of_wind',
    'v10': '10m_v_component_of_wind',
    't2m': '2m_temperature',
    'z': 'geopotential_at_surface', # The single-level geopotential is the surface geopotential.
    'lsm': 'land_sea_mask',
    'msl': 'mean_sea_level_pressure',
    'tisr': 'toa_incident_solar_radiation',
    'tp': 'total_precipitation'
}
pressurelevelnames = {
    'u': 'u_component_of_wind',
    'v': 'v_component_of_wind',
    'z': 'geopotential',
    'q': 'specific_humidity',
    't': 'temperature',
    'w': 'vertical_velocity'
}

# Getting the single and pressure level values.
def getSingleAndPressureValues():

    client.retrieve(
        'reanalysis-era5-single-levels',
        {
            'product_type': 'reanalysis',
            'variable': singlelevelfields,
            'grid': '1.0/1.0',
            'year': [2024],
            'month': [1],
            'day': [1],
            'time': ['00:00', '01:00', '02:00', '03:00', '04:00', '05:00', '06:00', '07:00', '08:00', '09:00', '10:00', '11:00', '12:00'],
            'format': 'netcdf'
        },
        'single-level.nc'
    )
    singlelevel = xarray.open_dataset('single-level.nc', engine = scipy.__name__).to_dataframe()
    singlelevel = singlelevel.rename(columns = singlelevelnames)

    # Calculating the sum of the last 6 hours of rainfall, separately for each grid point.
    # The group keys that groupby().rolling() adds to the index are dropped so the result aligns.
    singlelevel = singlelevel.sort_index()
    singlelevel['total_precipitation_6hr'] = singlelevel.groupby(level = ['latitude', 'longitude'])['total_precipitation'].rolling(window = 6, min_periods = 1).sum().reset_index(level = [0, 1], drop = True)
    singlelevel.pop('total_precipitation')

    client.retrieve(
        'reanalysis-era5-pressure-levels',
        {
            'product_type': 'reanalysis',
            'variable': pressurelevelfields,
            'grid': '1.0/1.0',
            'year': [2024],
            'month': [1],
            'day': [1],
            'time': ['06:00', '12:00'],
            'pressure_level': pressure_levels,
            'format': 'netcdf'
        },
        'pressure-level.nc'
    )
    pressurelevel = xarray.open_dataset('pressure-level.nc', engine = scipy.__name__).to_dataframe()
    pressurelevel = pressurelevel.rename(columns = pressurelevelnames)

    return singlelevel, pressurelevel

# Adding sin and cos of the year progress.
def addYearProgress(secs, data):

    progress = du.get_year_progress(secs)
    data['year_progress_sin'] = math.sin(2 * pi * progress)
    data['year_progress_cos'] = math.cos(2 * pi * progress)

    return data

# Adding sin and cos of the day progress.
def addDayProgress(secs, lon:str, data:pd.DataFrame):

    lons = data.index.get_level_values(lon).unique()
    progress:np.ndarray = du.get_day_progress(secs, np.array(lons))
    prxlon = dict(zip(list(lons), progress.tolist())) # Longitude -> day progress.
    data['day_progress_sin'] = data.index.get_level_values(lon).map(lambda x: math.sin(2 * pi * prxlon[x]))
    data['day_progress_cos'] = data.index.get_level_values(lon).map(lambda x: math.cos(2 * pi * prxlon[x]))

    return data

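# Adding day and year progress.
# Each timestamp needs its own progress values, so the rows are processed
# one timestamp at a time and then concatenated back together.
def integrateProgress(data:pd.DataFrame):

    lon = 'longitude' if 'longitude' in data.index.names else 'lon'
    times = data.index.get_level_values('time')
    parts = []
    for dt in times.unique():
        seconds_since_epoch = toDatetime(dt).timestamp()
        part = data[times == dt].copy()
        part = addYearProgress(seconds_since_epoch, part)
        part = addDayProgress(seconds_since_epoch, lon, part)
        parts.append(part)

    return pd.concat(parts)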

# Adding batch field and renaming some others.
def formatData(data:pd.DataFrame) -> pd.DataFrame:

    data = data.rename_axis(index = {'latitude': 'lat', 'longitude': 'lon'})
    if 'batch' not in data.index.names:
        data['batch'] = 0
        data = data.set_index('batch', append = True)

    return data

if __name__ == '__main__':

    # DataFrames for now; they are converted to xarray.Dataset later.
    values:Dict[str, pd.DataFrame] = {}

    single, pressure = getSingleAndPressureValues()
    values['inputs'] = pd.merge(pressure, single, left_index = True, right_index = True, how = 'inner')
    values['inputs'] = integrateProgress(values['inputs'])
    values['inputs'] = formatData(values['inputs'])

Targets

There are 11 prediction fields:

  • u_component_of_wind,
  • v_component_of_wind,
  • geopotential,
  • specific_humidity,
  • temperature,
  • vertical_velocity,
  • 10m_u_component_of_wind,
  • 10m_v_component_of_wind,
  • 2m_temperature,
  • mean_sea_level_pressure,
  • total_precipitation_6hr.

The targets passed to the model are essentially an empty (NaN-filled) xarray for all the prediction fields at:

  • every coordinate,
  • prediction timestamps and
  • pressure level.

The code to do so is shared below.

# Includes the packages imported and constants assigned.
# The functions created for the inputs also go here.

predictionFields = [
'u_component_of_wind',
'v_component_of_wind',
'geopotential',
'specific_humidity',
'temperature',
'vertical_velocity',
'10m_u_component_of_wind',
'10m_v_component_of_wind',
'2m_temperature',
'mean_sea_level_pressure',
'total_precipitation_6hr'
]

# Creating an array full of nan values.
def nans(*args) -> np.ndarray:
    return np.full(args, np.nan)

# Adding or subtracting time.
def deltaTime(dt, **delta) -> datetime.datetime:
    return dt + datetime.timedelta(**delta)

def getTargets(dt, data:pd.DataFrame):

    # Creating an array consisting of unique values of each index.
    lat = sorted(data.index.get_level_values('lat').unique().tolist())
    lon = sorted(data.index.get_level_values('lon').unique().tolist())
    levels = sorted(data.index.get_level_values('level').unique().tolist())
    batch = data.index.get_level_values('batch').unique().tolist()

    # One timestamp per prediction step, each 6 hours after the previous one.
    time = [deltaTime(dt, hours = step * gap) for step in range(predictions_steps)]

    # Creating an empty dataset using latitude, longitude, the pressure levels and each prediction timestamp.
    target = xarray.Dataset(
        {field: (['lat', 'lon', 'level', 'time'], nans(len(lat), len(lon), len(levels), len(time))) for field in predictionFields},
        coords = {'lat': lat, 'lon': lon, 'level': levels, 'time': time, 'batch': batch}
    )

    return target.to_dataframe()

if __name__ == '__main__':

    # The code for creating inputs will be here.

    values['targets'] = getTargets(first_prediction, values['inputs'])
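Since the targets are only a template, what matters is their shape and coordinates, not their values. A minimal sketch on a tiny hypothetical grid shows the same structure directly in pandas, building a NaN-filled frame over the Cartesian product of the coordinates with MultiIndex.from_product. It illustrates the layout only and is not a drop-in replacement for getTargets.

```python
import datetime

import numpy as np
import pandas as pd

# Tiny hypothetical grid: 2 latitudes, 3 longitudes, 2 levels, 4 prediction steps.
lat = [-1.0, 0.0]
lon = [0.0, 1.0, 2.0]
levels = [500, 850]
first_prediction = datetime.datetime(2024, 1, 1, 18, 0)
time = [first_prediction + datetime.timedelta(hours=6 * step) for step in range(4)]
fields = ['temperature', 'geopotential']

# One row per (lat, lon, level, time, batch) combination, every value NaN.
index = pd.MultiIndex.from_product([lat, lon, levels, time, [0]],
                                   names=['lat', 'lon', 'level', 'time', 'batch'])
template = pd.DataFrame(np.nan, index=index, columns=fields)

print(template.shape)
```

The row count is simply the product of the coordinate lengths, so the full 1-degree targets grow quickly with the number of prediction steps.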

Forcings

As was the case with the targets, forcings also contain values for every coordinate and prediction timestamp, but not for the pressure levels. The fields in forcings are:

  • toa_incident_solar_radiation,
  • year_progress_sin,
  • year_progress_cos,
  • day_progress_sin,
  • day_progress_cos.

It is important to note that the above values are assigned with respect to the prediction timestamps. As was the case when processing the inputs, year and day progress depend only on the timestamp (and, for day progress, the longitude), while the solar radiation was previously fetched from the single-level source. However, since one is making predictions, i.e., getting values for the future, the solar radiation values for the forcings will not be available in the CDS dataset. Hence, we simulate the solar radiation values using the pysolar library.

# Includes the packages imported and constants assigned.
# The functions created for the inputs and targets also go here.

# Adding a timezone to datetime.datetime variables.
def addTimezone(dt, tz = pytz.UTC) -> datetime.datetime:
    dt = toDatetime(dt)
    if dt.tzinfo is None:
        return pytz.UTC.localize(dt).astimezone(tz)
    else:
        return dt.astimezone(tz)

# Getting the solar radiation value wrt longitude, latitude and timestamp.
def getSolarRadiation(longitude, latitude, dt):

    # pysolar works with timezone-aware datetimes, so both calls get the same aware value.
    dt = addTimezone(dt)
    altitude_degrees = get_altitude(latitude, longitude, dt)
    solar_radiation = get_radiation_direct(dt, altitude_degrees) if altitude_degrees > 0 else 0

    return solar_radiation * watts_to_joules

# Calculating the solar radiation values for timestamps to be predicted.
def integrateSolarRadiation(data:pd.DataFrame):

    dates = list(data.index.get_level_values('time').unique())
    # Using the grid's own coordinates keeps the merge keys identical in value and dtype.
    lats = data.index.get_level_values('lat').unique()
    lons = data.index.get_level_values('lon').unique()
    values = []

    # For each timestamp, getting the solar radiation value at every coordinate.
    for dt in dates:
        for lat in lats:
            for lon in lons:
                values.append({'time': dt, 'lon': lon, 'lat': lat, 'toa_incident_solar_radiation': getSolarRadiation(lon, lat, dt)})

    # Setting indices.
    values = pd.DataFrame(values).set_index(keys = ['lat', 'lon', 'time'])

    # The forcings dataset will now contain the solar radiation values.
    return pd.merge(data, values, left_index = True, right_index = True, how = 'inner')

def getForcings(data:pd.DataFrame):

    # Forcings do not depend on the pressure level, so the level index is dropped.
    # So are all the columns, since forcings only have 5, which will be created below.
    forcingdf = data.reset_index(level = 'level', drop = True).drop(labels = predictionFields, axis = 1)

    # Keeping only the unique indices.
    forcingdf = pd.DataFrame(index = forcingdf.index.drop_duplicates(keep = 'first'))

    # Adding the sin and cos of day and year progress.
    # Functions are included in the creation of inputs data section.
    forcingdf = integrateProgress(forcingdf)

    # Integrating the solar radiation values.
    forcingdf = integrateSolarRadiation(forcingdf)

    return forcingdf

if __name__ == '__main__':

    # The code for creating inputs and targets will be here.

    values['forcings'] = getForcings(values['targets'])
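One caveat about the pysolar-based values above: pysolar's get_radiation_direct estimates clear-sky direct radiation at the surface, whereas ERA5's toa_incident_solar_radiation is the incoming radiation at the top of the atmosphere. If you want something closer to the latter, one rough alternative is to compute top-of-atmosphere irradiance from solar geometry. The sketch below is such an approximation and an assumption on my part, not what Graphcast or ERA5 use internally: it takes a solar constant of 1361 W/m², a Cooper-style declination formula, ignores the equation of time, and multiplies the instantaneous value by 3600 s instead of integrating over the hour. toaIncidentRadiation is a hypothetical helper.

```python
import datetime
import math

SOLAR_CONSTANT = 1361.0  # W/m², assumed mean total solar irradiance.
SECONDS_PER_HOUR = 3600

def toaIncidentRadiation(latitude: float, longitude: float, when: datetime.datetime) -> float:
    """Approximate TOA incident solar energy over one hour, in J/m².

    Expects a timezone-aware datetime (hypothetical helper, rough approximation).
    """
    when = when.astimezone(datetime.timezone.utc)
    day_of_year = when.timetuple().tm_yday
    # Solar declination (Cooper's approximation), in radians.
    declination = math.radians(23.45) * math.sin(2 * math.pi * (284 + day_of_year) / 365)
    # Local solar time from UTC and longitude, ignoring the equation of time.
    utc_hours = when.hour + when.minute / 60 + when.second / 3600
    solar_time = (utc_hours + longitude / 15.0) % 24
    hour_angle = math.radians(15.0 * (solar_time - 12.0))
    lat = math.radians(latitude)
    cos_zenith = (math.sin(lat) * math.sin(declination)
                  + math.cos(lat) * math.cos(declination) * math.cos(hour_angle))
    # Earth-Sun distance correction (about +/- 3.3 % over the year).
    distance_factor = 1 + 0.033 * math.cos(2 * math.pi * day_of_year / 365)
    # No radiation when the Sun is below the horizon.
    irradiance = SOLAR_CONSTANT * distance_factor * max(cos_zenith, 0.0)  # W/m²
    return irradiance * SECONDS_PER_HOUR  # J/m², if held constant for an hour.

noon = datetime.datetime(2024, 3, 20, 12, 0, tzinfo=datetime.timezone.utc)
midnight = datetime.datetime(2024, 3, 20, 0, 0, tzinfo=datetime.timezone.utc)
print(toaIncidentRadiation(0.0, 0.0, noon))
print(toaIncidentRadiation(0.0, 0.0, midnight))
```

Whichever estimate you use, it is worth comparing it against the ERA5 values for a past date before trusting it in the forcings.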

Post-processing the inputs, targets and forcings

Now that the three pillars of Graphcast have been created, we enter the home stretch. Like a team in the NBA Finals that has won 3 games, we now proceed to the nittiest-grittiest part: getting it done.

As Kobe Bryant once said,

Job’s not finished.


An xarray Dataset holds two main types of data:

  • Coordinates, the indices: lat, lon, time and so on, and
  • Data variables, the columns: land_sea_mask, geopotential et cetera.

Every value that a data variable contains has certain coordinates assigned to it: the coordinates on which the value of that data variable depends. Taking an example from our own data,

  • land_sea_mask depends solely on the latitude and longitude, which are its coordinates.
  • geopotential’s coordinates are batch, latitude, longitude, time and pressure level.
  • In stark contrast, though it makes sense, the coordinates of geopotential_at_surface are only latitude and longitude.

Hence, before we proceed to predicting the weather, we make sure each data variable is assigned to its right coordinates, the code for which is presented below.

# Includes the packages imported and constants assigned.
# The functions created for the inputs, targets and forcings also go here.

# A dictionary created, containing each coordinate a data variable requires.
class AssignCoordinates:

    coordinates = {
        '2m_temperature': ['batch', 'lon', 'lat', 'time'],
        'mean_sea_level_pressure': ['batch', 'lon', 'lat', 'time'],
        '10m_v_component_of_wind': ['batch', 'lon', 'lat', 'time'],
        '10m_u_component_of_wind': ['batch', 'lon', 'lat', 'time'],
        'total_precipitation_6hr': ['batch', 'lon', 'lat', 'time'],
        'temperature': ['batch', 'lon', 'lat', 'level', 'time'],
        'geopotential': ['batch', 'lon', 'lat', 'level', 'time'],
        'u_component_of_wind': ['batch', 'lon', 'lat', 'level', 'time'],
        'v_component_of_wind': ['batch', 'lon', 'lat', 'level', 'time'],
        'vertical_velocity': ['batch', 'lon', 'lat', 'level', 'time'],
        'specific_humidity': ['batch', 'lon', 'lat', 'level', 'time'],
        'toa_incident_solar_radiation': ['batch', 'lon', 'lat', 'time'],
        'year_progress_cos': ['batch', 'time'],
        'year_progress_sin': ['batch', 'time'],
        'day_progress_cos': ['batch', 'lon', 'time'],
        'day_progress_sin': ['batch', 'lon', 'time'],
        'geopotential_at_surface': ['lon', 'lat'],
        'land_sea_mask': ['lon', 'lat'],
    }

def modifyCoordinates(data:xarray.Dataset):

    # Parsing through each data variable and removing unneeded indices
    # by keeping only the first entry along each of them.
    for var in list(data.data_vars):
        varArray:xarray.DataArray = data[var]
        nonIndices = list(set(list(varArray.coords)).difference(set(AssignCoordinates.coordinates[var])))
        data[var] = varArray.isel(**{coord: 0 for coord in nonIndices})

    # The batch coordinate is dropped once, after all variables are processed.
    data = data.drop_vars('batch')

    return data

def makeXarray(data:pd.DataFrame) -> xarray.Dataset:

    # Converting to xarray.
    data = data.to_xarray()
    data = modifyCoordinates(data)

    return data

if __name__ == '__main__':

    # The code for creating inputs, targets and forcings will be here.

    values = {value:makeXarray(values[value]) for value in values}
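To see what the isel step in modifyCoordinates does on a small scale, here is a toy sketch with a hypothetical land_sea_mask that arrived with extra batch and time dimensions. Selecting index 0 along every dimension the variable should not have leaves a (lat, lon) array, mirroring the idea behind the code above on made-up data.

```python
import numpy as np
import xarray

# A hypothetical static field carrying dimensions it should not depend on.
land_sea_mask = xarray.DataArray(
    np.ones((1, 2, 3, 4)),
    dims=['batch', 'time', 'lat', 'lon'],
    coords={'batch': [0], 'time': [0, 1], 'lat': [0.0, 1.0, 2.0], 'lon': [0.0, 1.0, 2.0, 3.0]},
)
required = ['lon', 'lat']

# Select index 0 along every dimension the variable should not have.
extra = [dim for dim in land_sea_mask.dims if dim not in required]
reduced = land_sea_mask.isel({dim: 0 for dim in extra})

print(reduced.dims)
```

Note that xarray keeps the selected batch and time values attached as scalar (non-dimension) coordinates; the dimensions themselves are gone.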

Predictions using Graphcast

Having calculated, processed and assembled the inputs, targets and forcings, it is now time to make predictions.

We now require the model weights and normalization statistics files, which are provided by DeepMind.

The files to be downloaded are:

  • stats/diffs_stddev_by_level.nc,
  • stats/stddev_by_level.nc,
  • stats/mean_by_level.nc and
  • params/GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz.

The relative paths of the aforementioned files with respect to the prediction file are depicted below. It is important to maintain this structure so that the required files can be found and read successfully.

.
├── prediction.py
└── model
    ├── params
    │   └── GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz
    └── stats
        ├── diffs_stddev_by_level.nc
        ├── mean_by_level.nc
        └── stddev_by_level.nc
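Because the loading code opens these files by relative path, a quick pre-flight check can save a confusing traceback. A minimal sketch using pathlib, with a hypothetical helper name and the same paths as the layout above, lists any required file missing relative to a base directory.

```python
from pathlib import Path

PARAMS = ('GraphCast_small - ERA5 1979-2015 - resolution 1.0 - '
          'pressure levels 13 - mesh 2to5 - precipitation input and output.npz')
REQUIRED_FILES = [
    Path('model') / 'params' / PARAMS,
    Path('model') / 'stats' / 'diffs_stddev_by_level.nc',
    Path('model') / 'stats' / 'mean_by_level.nc',
    Path('model') / 'stats' / 'stddev_by_level.nc',
]

def missingModelFiles(base_dir: Path = Path('.')) -> list[Path]:
    """Return the required files that do not exist under base_dir."""
    return [path for path in REQUIRED_FILES if not (base_dir / path).is_file()]

missing = missingModelFiles()
if missing:
    print('Missing files:', *missing, sep='\n  ')
```

Running this from the directory containing prediction.py should print nothing once the files are in place.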

With the prediction code provided by DeepMind, all the above functions culminate in the predictions being made using the snippet below.

# Includes the packages imported and constants assigned.
# The functions created for the inputs, targets and forcings also go here.

with open(r'model/params/GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz', 'rb') as model:
    ckpt = checkpoint.load(model, graphcast.CheckPoint)
    params = ckpt.params
    state = {}
    model_config = ckpt.model_config
    task_config = ckpt.task_config

with open(r'model/stats/diffs_stddev_by_level.nc', 'rb') as f:
    diffs_stddev_by_level = xarray.load_dataset(f).compute()

with open(r'model/stats/mean_by_level.nc', 'rb') as f:
    mean_by_level = xarray.load_dataset(f).compute()

with open(r'model/stats/stddev_by_level.nc', 'rb') as f:
    stddev_by_level = xarray.load_dataset(f).compute()

def construct_wrapped_graphcast(model_config:graphcast.ModelConfig, task_config:graphcast.TaskConfig):
    predictor = graphcast.GraphCast(model_config, task_config)
    predictor = casting.Bfloat16Cast(predictor)
    predictor = normalization.InputsAndResiduals(predictor, diffs_stddev_by_level = diffs_stddev_by_level, mean_by_level = mean_by_level, stddev_by_level = stddev_by_level)
    predictor = autoregressive.Predictor(predictor, gradient_checkpointing = True)
    return predictor

@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
    predictor = construct_wrapped_graphcast(model_config, task_config)
    return predictor(inputs, targets_template = targets_template, forcings = forcings)

def with_configs(fn):
    return functools.partial(fn, model_config = model_config, task_config = task_config)

def with_params(fn):
    return functools.partial(fn, params = params, state = state)

def drop_state(fn):
    return lambda **kw: fn(**kw)[0]

run_forward_jitted = drop_state(with_params(jax.jit(with_configs(run_forward.apply))))

class Predictor:

    @classmethod
    def predict(cls, inputs, targets, forcings) -> xarray.Dataset:
        predictions = rollout.chunked_prediction(run_forward_jitted, rng = jax.random.PRNGKey(0), inputs = inputs, targets_template = targets, forcings = forcings)
        return predictions

if __name__ == '__main__':

    # The code for creating inputs, targets, forcings & processing will be here.

    predictions = Predictor.predict(values['inputs'], values['targets'], values['forcings'])
    predictions.to_dataframe().to_csv('predictions.csv', sep = ',')

Conclusion

Above, I have provided the code for each process that will be undertaken:

  • creating the inputs, targets and forcings,
  • processing the above data to a viable format and then finally
  • bringing them together and making predictions.

When running the code, it is important to bring all the processes together into one script for a seamless implementation.

For simplicity, I have uploaded the code along with the Docker image and container files, which can be used to create an environment to execute the prediction program.

In the universe of weather prediction, we currently have contributors like AccuWeather, IBM and multiple Meteomatics models. Graphcast proves to be an interesting and, in many cases, more efficient addition to this collection. However, it also has some attributes that are far from optimal. In a rare moment of thought, I came up with the following insights:

  • Graphcast is far more efficient and faster compared to other weather prediction services, fetching predictions for the whole world in a matter of minutes.
  • This makes hundreds of API calls for hundreds of geographies unnecessary.
  • However, to do the above in minutes, one needs a very powerful machine, such as a Google TPU v4 or better, which isn’t readily available. Even if one chooses to use a VM from AWS, Google or Azure, the costs can rack up.
  • Currently, there are no provisions to use data for a small geography or a subset of coordinates and get predictions for the same. Data for all the coordinates is always required.
  • CDS provides data with a 5-day latency, which means that on date ‘x’, CDS can provide data only up to date ‘x-5’. This makes predicting the actual future a little complicated, since the first predictions have to cover the latency period before any of them reach beyond the present.
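To put a number on the latency point: when the inputs end at the latest available timestamp, every 6-hour step spent covering the gap is a step not spent forecasting the actual future. A small standard-library sketch, taking the 5-day latency from the text above and using a hypothetical helper name, counts the steps needed to reach a target time.

```python
import datetime
import math

STEP = datetime.timedelta(hours=6)

def stepsToReach(last_input: datetime.datetime, target: datetime.datetime) -> int:
    """Number of 6-hour steps needed after last_input to cover target."""
    if target <= last_input:
        return 0
    return math.ceil((target - last_input) / STEP)

today = datetime.datetime(2024, 1, 10, 12, 0)
last_input = today - datetime.timedelta(days=5)  # Assumed 5-day CDS latency.
latency_steps = stepsToReach(last_input, today)
print(latency_steps, 40 - latency_steps)
```

Under these assumptions, half of a 40-step (10-day) rollout is spent catching up to the present, leaving roughly 5 days of genuine forecast.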

It is important to note that Graphcast is a fairly new addition to the weather prediction scene, so changes and additions will very likely be made to improve its ease of access and usability. Given the lead it has with respect to efficiency and performance, DeepMind is sure to capitalize on it.

Best of luck on your journey in data science and thank you for reading :)


Machine Learning Engineer at Fyllo | Data science enthusiast | I like to write and roll