
MLX vs MPS vs CUDA: a Benchmark

A first benchmark of Apple's new ML framework MLX

Photo by Javier Allegue Barros on Unsplash

If you’re a Mac user and a deep learning enthusiast, you’ve probably wished at some point that your Mac could handle those heavy models, right? Well, guess what? Apple just released MLX, a framework for running ML models efficiently on Apple Silicon.

The recent introduction of the MPS backend in PyTorch 1.12 was already a bold step, but with the announcement of MLX, it seems that Apple wants to make a significant leap into open source deep learning.

In this article, we’ll put these new approaches through their paces, benchmarking them against the traditional CPU backend on three different Apple Silicon chips, as well as against two CUDA-enabled GPUs. By doing so, we aim to reveal just how much these novel Mac-compatible methods can be used in 2024 for deep learning experiments.

As a GNN-oriented researcher, I’ll focus the benchmark on a Graph Convolutional Network (GCN) model. But since this model mainly consists of linear layers, our findings could be insightful even for those not specifically in the GNN sphere.


Crafting an environment

To build an environment for MLX, we have to specify whether we are using the i386 or arm architecture. With conda, this can be done using:

CONDA_SUBDIR=osx-arm64 conda create -n mlx python=3.10 numpy pytorch scipy requests -c conda-forge
conda activate mlx

To check if your env is actually using arm, the output of the following command should be arm, not i386:

python -c "import platform; print(platform.processor())"

Now simply install MLX using pip, and you’re all set to start exploring:

pip install mlx
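As an optional sanity check, you can also ask MLX which device it defaults to; on Apple Silicon this should report the GPU (something like Device(gpu, 0)):

python -c "import mlx.core as mx; print(mx.default_device())"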

GCN implementation

The GCN model, a type of Graph Neural Network (GNN), works with an adjacency matrix (representing the graph structure) and node features. It calculates node embeddings by gathering info from neighboring nodes. Specifically, each node gets the average of its neighbors’ features. This averaging is done by multiplying the node features with the normalized adjacency matrix, adjusted by node degree. To learn this process, the features are first projected into an embedding space via a linear layer.

In our version, we normalize the adjacency matrix just like in the original paper: during the preprocessing step. While this article won’t go into the preprocessing code, you can find the full code in this GitHub repo:

GitHub – TristanBilot/mlx-GCN
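For intuition, here is a minimal sketch of what that symmetric normalization typically looks like, written with plain NumPy; the repo’s actual preprocessing may differ in details (e.g. sparse matrices):

import numpy as np

def normalize_adjacency(adj: np.ndarray) -> np.ndarray:
    # A_hat = D^{-1/2} (A + I) D^{-1/2}, as in the original GCN paper
    adj = adj + np.eye(adj.shape[0])        # add self-loops
    deg = adj.sum(axis=1)                   # node degrees
    d_inv_sqrt = np.power(deg, -0.5)
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0  # guard against isolated nodes
    d_mat = np.diag(d_inv_sqrt)
    return d_mat @ adj @ d_mat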

We’ll now walk through implementing a GCN layer and a GCN model using MLX:

import mlx.nn as nn

class GCNLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias)

    def __call__(self, x, adj):
        x = self.linear(x)
        return adj @ x

class GCN(nn.Module):
    def __init__(self, x_dim, h_dim, out_dim, nb_layers=2, dropout=0.5, bias=True):
        super(GCN, self).__init__()

        layer_sizes = [x_dim] + [h_dim] * nb_layers + [out_dim]
        self.gcn_layers = [
            GCNLayer(in_dim, out_dim, bias)
            for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]
        self.dropout = nn.Dropout(p=dropout)

    def __call__(self, x, adj):
        for layer in self.gcn_layers[:-1]:
            x = nn.relu(layer(x, adj))
            x = self.dropout(x)

        x = self.gcn_layers[-1](x, adj)
        return x

At a glance, MLX code closely resembles PyTorch code, with one notable difference: here we instantiate self.gcn_layers as a plain Python list of modules, whereas in PyTorch you would typically wrap them in an nn.ModuleList (or nn.Sequential) so that their parameters get registered.
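For comparison, a rough PyTorch equivalent of the same layer might look like this (a sketch for illustration, not the exact implementation used in the benchmark):

import torch.nn as nn

class TorchGCNLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x, adj):
        # same idea: project the features, then aggregate over neighbors
        x = self.linear(x)
        return adj @ x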

The code starts to become quite different within the training loop:

import mlx.core as mx
import mlx.optimizers as optim

gcn = GCN(
    x_dim=x.shape[-1],
    h_dim=args.hidden_dim,
    out_dim=args.nb_classes,
    nb_layers=args.nb_layers,
    dropout=args.dropout,
    bias=args.bias,
)
mx.eval(gcn.parameters())

optimizer = optim.Adam(learning_rate=args.lr)
loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)

# Training loop
for epoch in range(args.epochs):

    # Loss
    (loss, y_hat), grads = loss_and_grad_fn(
        gcn, x, adj, y, train_mask, args.weight_decay
    )
    optimizer.update(gcn, grads)
    mx.eval(gcn.parameters(), optimizer.state)

    # Validation
    val_loss = loss_fn(y_hat[val_mask], y[val_mask])
    val_acc = eval_fn(y_hat[val_mask], y[val_mask])

Immediately apparent is the use of mx.eval(). In MLX, computations are lazy: operations only build a graph, and eval() is what actually materializes the updated model parameters after each step. Another key function, nn.value_and_grad(), builds a function that returns both the loss and its gradients with respect to the model parameters. Its first argument is the model holding the current parameters, and the second is a callable that runs the forward pass and computes the loss; the function it returns takes the same arguments as that callable (in this case, forward_fn).
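To make the laziness concrete, here is a tiny standalone sketch (not from the repo): nothing is computed until eval() is called on the result, or until the values are otherwise needed.

import mlx.core as mx

a = mx.random.uniform(shape=(1024, 1024))
b = a @ a      # only records the operation in the graph
mx.eval(b)     # forces the actual computation; b now holds concrete values

Back to the training loop: the forward function passed to nn.value_and_grad() can be defined as follows: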

def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
    y_hat = gcn(x, adj)
    loss = loss_fn(y_hat[train_mask], y[train_mask], weight_decay, gcn.parameters())
    return loss, y_hat

It simply runs a forward pass and computes the loss. loss_fn() and eval_fn() are defined as follows:

from mlx.utils import tree_flatten

def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
    l = mx.mean(nn.losses.cross_entropy(y_hat, y))

    if weight_decay != 0.0:
        assert parameters is not None, "Model parameters missing for L2 reg."

        l2_reg = sum(mx.sum(p[1] ** 2) for p in tree_flatten(parameters)).sqrt()
        return l + weight_decay * l2_reg

    return l

def eval_fn(x, y):
    return mx.mean(mx.argmax(x, axis=1) == y)

You might observe that the loss function appears quite extensive, but it essentially calculates the cross-entropy between predictions and labels, and includes L2 regularization. Since L2 regularization isn’t a built-in feature yet, I’ve implemented it manually.

One cool thing here is that there is no need to explicitly assign objects to a specific device, as we often do in PyTorch with .cuda() and .to(device). Thanks to the unified memory architecture of Apple Silicon chips, all variables coexist in the same space, which removes slow data transfers between CPU and GPU and eliminates those pesky runtime errors related to device mismatches.
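To illustrate, here is the kind of device boilerplate that disappears; both lines below are minimal sketches, not taken from the benchmark code:

import torch
import mlx.core as mx

# PyTorch: data must be moved to the chosen device explicitly
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
x_torch = torch.ones((4, 4)).to(device)

# MLX: arrays live in unified memory, so there is no device to manage
x_mlx = mx.ones((4, 4))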

Benchmark

In our benchmark, we’ll be comparing MLX alongside MPS, CPU, and GPU devices, using a PyTorch implementation. Our testbed is a 2-layer GCN model, applied to the Cora dataset, which includes 2708 nodes and 5429 edges.

For the MLX, MPS, and CPU tests, we benchmark the M1 Pro, M2 Ultra, and M3 Max chips. Meanwhile, the GPU benchmarks are carried out on two NVIDIA Tesla models: the V100 PCIe and the V100 NVLINK.

Image by author: Benchmark of GCN running time on MLX and other backends (in ms)

MPS: more than 2x faster than CPU on M1 Pro, not bad. On the two other chips, we see a 30–50% improvement over CPU.

MLX: 2.34x faster than MPS on M1 Pro. On the M2 Ultra we get a 24% improvement over MPS. There is no real improvement between MPS and MLX on the M3 Max, though.

CUDA V100 PCIe & NVLINK: only 23% and 34% faster than the M3 Max running MLX, this is some serious stuff!

MLX stands out as a game changer when compared to CPU and MPS, and it even comes close to the performance of a TESLA V100. This initial benchmark highlights MLX’s significant potential to emerge as a popular Mac-based deep learning framework. It’s also worth noting that MLX has only recently been released to the public, and we can expect further enhancements from the open-source community in the coming years. We can also expect even more powerful Apple Silicon chips in the near future, taking performance of MLX to a whole new level.

To recap

Cool things:

  • We can now run Deep Learning models locally by leveraging the full power of Apple Silicon.
  • The syntax is very similar to PyTorch’s, with some inspiration from Jax.
  • No more device, everything lives in unified memory!

What’s missing:

  • The framework is very young and many features are still missing. In particular for Graph ML, sparse operations and scatter APIs are not available at the moment, making it complicated to build message-passing GNNs on top of MLX for now.
  • As a new project, it’s worth noting that both the documentation and community discussions for MLX are somewhat limited at present.

In conclusion, MLX made a surprisingly impactful entrance upon its release and demonstrates serious potential. I believe this framework could become a staple for daily research experiments. We’re also eager to see additional experiments, as the GCN tests primarily showcase MLX’s performance on basic linear layers. More comprehensive testing could reveal its full capabilities.

Thanks for reading!

