
Run LLM inference using Apple Hardware

Unlock Apple GPU Power for LLM Inference with MLX



We are in a position to run inference and fine-tune our own LLMs using Apple’s native hardware. This article will cover the setup for creating your own experiments and running inference. In the future I will be making an article on how to fine-tune these LLMs (again using Apple hardware).

If you haven’t checked out my previous articles, I suggest doing so as I make a case for why you should consider hosting (and fine-tuning) your own open-source LLM. I also cover strategies as to how you can optimise the process to reduce inference and training times. I will brush over topics such as quantisation as these are covered in depth in the aforementioned articles.

I will be using the mlx framework in combination with Meta’s Llama2 model. In-depth information on how to access the models can be found in my previous article. However, I’ll briefly explain how to do so in this article as well.

Let’s get started.


  1. A machine with an M-series chip (M1/M2/M3)
  2. macOS >= 13.0
  3. Python between 3.8–3.11

For my personal hardware set-up, I am using a MacBook Pro with an M1 Max chip – 64GB RAM // 10-Core CPU // 32-Core GPU.

My OS is Sonoma 14.3 // Python is 3.11.6

As long as you meet the 3 requirements listed above, you should be able to follow along. If you have around 16GB RAM, I suggest sticking with the 7B models. Inference times etc. will of course vary depending on your hardware specs.
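To get a rough feel for why 16GB of RAM pairs well with the 7B models, the weight footprint can be estimated from the parameter count and the number of bits stored per weight. This is a back-of-the-envelope sketch only: the helper name is hypothetical, the parameter counts are nominal, and it ignores activations, the KV cache and the small overhead of quantisation metadata.

```python
def weight_footprint_gib(n_params: float, bits_per_weight: int) -> float:
    """Approximate memory needed for the weights alone, in GiB."""
    total_bytes = n_params * bits_per_weight / 8
    return total_bytes / 2**30


# Nominal parameter counts; the real checkpoints differ slightly.
for name, n_params in [("7B", 7e9), ("13B", 13e9)]:
    fp16 = weight_footprint_gib(n_params, 16)
    q4 = weight_footprint_gib(n_params, 4)
    print(f"{name}: ~{fp16:.1f} GiB at 16-bit, ~{q4:.1f} GiB at 4-bit")
```

By this estimate the 7B weights at 16-bit already take roughly 13 GiB, which is why quantisation matters so much on smaller machines.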


Feel free to follow along and set up a directory where you will store all files relating to this article. It will make the process a lot easier if they’re all in one place. I’m calling mine mlx.

First we need to ensure you are running a native arm version of Python. Otherwise we will be unable to install mlx. You can do so by running the following command in your terminal:

python -c "import platform; print(platform.processor())"

The output of the above command should be arm.

If it is, skip to the install requirements section. If not, read below:

Mine was initially returning i386. This essentially means you are using a non-native Python version built for the Intel x86 architecture, which was used in older Macs. Depending on which tools you use to handle your Python environments, you will want to set up a new environment with a native arm version of Python. I’m using pyenv to handle my Python environments, so I simply type the command:

pyenv virtualenv 3.11.6 mlx-env

This creates a new python environment named mlx-env with my chosen version of Python. Then activate this environment using:

pyenv activate mlx-env

If you use anaconda/miniconda for managing your environments, here are instructions as to how you can update it.
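Once the new environment is active, the architecture check can be repeated from inside Python. The sketch below uses platform.machine() rather than platform.processor(); the helper name is hypothetical, and the assumption is that a native build on Apple silicon reports arm64 while an Intel build running under translation reports x86_64.

```python
import platform
import sys


def is_native_arm(machine: str) -> bool:
    """True when the reported machine type is an arm64 build."""
    return machine.lower() in {"arm64", "aarch64"}


machine = platform.machine()
print(f"Python {sys.version.split()[0]} reports machine type: {machine}")
if not is_native_arm(machine):
    print("This interpreter is not a native arm build; mlx is unlikely to install.")
```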

Install requirements:

Below is a list of all packages you will require to run inference on the Llama-2 model:


Add these to a requirements.txt file within your mlx directory and run:

pip install -r requirements.txt
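After installing, it can be worth confirming that the packages are visible to the interpreter you just set up. The package names below are an assumption inferred from the imports in the two scripts used later in this article (mlx, torch, sentencepiece, with numpy needed for the tensor conversion); adjust the list to match your own requirements.txt.

```python
import importlib.util

# Assumed package list, inferred from the scripts' imports.
REQUIRED = ["mlx", "torch", "sentencepiece", "numpy"]


def missing_packages(names):
    """Return the names that cannot be found in the current environment."""
    return [n for n in names if importlib.util.find_spec(n) is None]


missing = missing_packages(REQUIRED)
print("Missing packages:", missing or "none")
```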

Model Download:

Next, we need to download the required Llama-2 model. Here you can choose whichever you wish to use. Again, make a decision based upon your hardware specifications. The larger the model, the more resources you will require. I have been working with the 13B parameter version. We will quantise the model regardless to decrease inference times. I’ll quickly go over the download process. If anything is unclear, please see my previous article:

  1. Request access
  2. Download the script and move it to your mlx directory
  3. Access email sent to you by Meta and copy link in point 2 of the "How to Download Models" section
  4. Run the download script from your terminal using sh
  5. Paste the link you copied from step 3 into the terminal prompt
  6. The prompt will ask you which models you wish to download. For example, if you want to download the 7B and 13B models, simply type: 7B,13B
  7. Your models will then download into your mlx folder.


Now we move to the cool bit. MLX provide great documentation as to how you can get up and running with different models. As stated, I will concentrate just on running inference. Fine-tuning will come in a later article.

Download the convert.py and llama.py files provided by MLX into your mlx folder. At this stage, I highly recommend reading through both files to get an insight into what is happening under the hood. I’ll include them below with a brief summary.

I wish to state here that full attribution for the code goes to the MLX researchers.

convert.py – this script converts the original PyTorch-based Llama-2 model into MLX format. MLX arrays share many similarities with their PyTorch and TensorFlow counterparts, with some slight nuances. The script also handles the input of command line arguments where we have the option of adding quantisation:

import argparse
import collections
import copy
import glob
import json
import shutil
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import torch
from llama import Llama, ModelArgs, sanitize_config
from mlx.utils import tree_flatten, tree_map, tree_unflatten

def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
    # bfloat16 is not numpy convertible. Upcast to float32 to avoid precision loss
    a = a.to(torch.float32) if dtype == "bfloat16" else a.to(getattr(torch, dtype))
    return mx.array(a.numpy(), getattr(mx, dtype))

def llama(model_path, *, dtype: str):
    SHARD_FIRST = ["wv", "wq", "wk", "w1", "w3", "output"]
    SHARD_SECOND = ["tok_embeddings", "wo", "w2"]
    SHARD_WEIGHTS = set(SHARD_FIRST + SHARD_SECOND)

    def shard_key(k):
        keys = k.split(".")
        if len(keys) < 2:
            return None
        return keys[-2]

    def unshard(k, v):
        wn = shard_key(k)
        if wn not in SHARD_WEIGHTS:
            return v
        elif wn in SHARD_FIRST:
            axis = 0
        elif wn in SHARD_SECOND:
            axis = 1
        else:
            raise ValueError("Invalid weight name")
        return mx.concatenate(v, axis=axis)

    torch_files = glob.glob(str(model_path / "consolidated.*.pth"))
    weights = collections.defaultdict(list)
    for wf in torch_files:
        state = torch.load(wf, map_location=torch.device("cpu"))
        for k, v in state.items():
            v = torch_to_mx(v, dtype=dtype)
            state[k] = None  # free memory
            if shard_key(k) in SHARD_WEIGHTS:
                # Sharded weights are collected per file and concatenated later
                weights[k].append(v)
            else:
                weights[k] = v

    for k, v in weights.items():
        weights[k] = unshard(k, v)
    with open(model_path / "params.json", "r") as f:
        params = json.loads(f.read())
    return weights, params

def tiny_llama(model_path, *, dtype: str):
    try:
        import transformers
    except ImportError:
        print("The transformers package must be installed for this model conversion:")
        print("pip install transformers")
        exit(0)

    model = transformers.AutoModelForCausalLM.from_pretrained(
        str(model_path)
    ).state_dict()
    config = transformers.AutoConfig.from_pretrained(model_path)

    # things to change
    # 1. there's no "model." in the weight names
    model = {k.replace("model.", ""): v for k, v in model.items()}

    # 2. mlp is called feed_forward
    model = {k.replace("mlp", "feed_forward"): v for k, v in model.items()}

    # 3. up_proj, down_proj, gate_proj
    model = {k.replace("down_proj", "w2"): v for k, v in model.items()}
    model = {k.replace("up_proj", "w3"): v for k, v in model.items()}
    model = {k.replace("gate_proj", "w1"): v for k, v in model.items()}

    # 4. layernorms
    model = {
        k.replace("input_layernorm", "attention_norm"): v for k, v in model.items()
    }
    model = {
        k.replace("post_attention_layernorm", "ffn_norm"): v for k, v in model.items()
    }

    # 5. lm head
    model = {k.replace("lm_head", "output"): v for k, v in model.items()}

    # 6. token emb
    model = {k.replace("embed_tokens", "tok_embeddings"): v for k, v in model.items()}

    # 7. attention
    model = {k.replace("self_attn", "attention"): v for k, v in model.items()}
    model = {k.replace("q_proj", "wq"): v for k, v in model.items()}
    model = {k.replace("k_proj", "wk"): v for k, v in model.items()}
    model = {k.replace("v_proj", "wv"): v for k, v in model.items()}
    model = {k.replace("o_proj", "wo"): v for k, v in model.items()}

    params = {}
    params["dim"] = config.hidden_size
    params["hidden_dim"] = config.intermediate_size
    params["n_heads"] = config.num_attention_heads
    if hasattr(config, "num_key_value_heads"):
        params["n_kv_heads"] = config.num_key_value_heads
    params["n_layers"] = config.num_hidden_layers
    params["vocab_size"] = config.vocab_size
    params["norm_eps"] = config.rms_norm_eps
    params["rope_traditional"] = False
    weights = {k: torch_to_mx(v, dtype=dtype) for k, v in model.items()}

    return weights, params

def quantize(weights, config, args):
    quantized_config = copy.deepcopy(config)

    # Load the model:
    config = sanitize_config(config, weights)
    model = Llama(ModelArgs(**config))
    weights = tree_map(mx.array, weights)
    model.update(tree_unflatten(list(weights.items())))

    # Quantize the model:
    nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)

    # Update the config:
    quantized_config["quantization"] = {
        "group_size": args.q_group_size,
        "bits": args.q_bits,
    }
    quantized_weights = dict(tree_flatten(model.parameters()))

    return quantized_weights, quantized_config

def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
    max_file_size_bytes = max_file_size_gibibyte << 30
    shards = []
    shard, shard_size = {}, 0
    for k, v in weights.items():
        if shard_size + v.nbytes > max_file_size_bytes:
            # Current shard is full: store it and start a new one
            shards.append(shard)
            shard, shard_size = {}, 0
        shard[k] = v
        shard_size += v.nbytes
    shards.append(shard)
    return shards

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
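    parser.add_argument(
        "--torch-path",
        type=str,
        help="Path to the PyTorch model.",
    )
    parser.add_argument(
        "--mlx-path",
        type=str,
        default="mlx_model",
        help="Path to save the MLX model.",
    )
    parser.add_argument(
        "--model-name",
        help=(
            "Name of the model to convert. Use 'llama' for models in the "
            "Llama family distributed by Meta including Llama 1, Llama 2, "
            "Code Llama, and Llama chat."
        ),
        choices=["tiny_llama", "llama"],
        default="llama",
    )
    parser.add_argument(
        "-q",
        "--quantize",
        help="Generate a quantized model.",
        action="store_true",
    )
    parser.add_argument(
        "--q-group-size",
        help="Group size for quantization.",
        type=int,
        default=64,
    )
    parser.add_argument(
        "--q-bits",
        help="Bits per weight for quantization.",
        type=int,
        default=4,
    )
    parser.add_argument(
        "--dtype",
        help="dtype for loading the torch model and input for quantization or saving the converted model. "
        "The original weights are stored in bfloat16.",
        type=str,
        default="float16",
    )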

    args = parser.parse_args()

    torch_path = Path(args.torch_path)
    mlx_path = Path(args.mlx_path)
    mlx_path.mkdir(parents=True, exist_ok=True)

    print("[INFO] Loading")
    weights, params = globals()[args.model_name](torch_path, dtype=args.dtype)
    params["model_type"] = "llama"
    if args.quantize:
        print("[INFO] Quantizing")
        weights, params = quantize(weights, params, args)

    print("[INFO] Saving")
    shutil.copyfile(
        str(torch_path / "tokenizer.model"),
        str(mlx_path / "tokenizer.model"),
    )
    shards = make_shards(weights)
    if len(shards) == 1:
        mx.savez(str(mlx_path / "weights.npz"), **shards[0])
    else:
        for i, shard in enumerate(shards):
            mx.savez(str(mlx_path / f"weights.{i:02d}.npz"), **shard)
    with open(mlx_path / "config.json", "w") as fid:
        json.dump(params, fid, indent=4)

llama.py – this script handles the heavy lifting of running inference on our downloaded Llama2 model as well as handling tokenisation of our input. Whilst the script does not actually edit our model, it contains the custom classes RMSNorm, Attention, FeedForward and TransformerBlock, which implement the transformer architecture that allows the responses to our prompts to be coherent:

import argparse
import glob
import json
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
from sentencepiece import SentencePieceProcessor

@dataclass
class ModelArgs:
    dim: int
    n_layers: int
    head_dim: int
    hidden_dim: int
    n_heads: int
    n_kv_heads: int
    norm_eps: float
    vocab_size: int
    rope_theta: float
    rope_traditional: bool = True

class RMSNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def _norm(self, x):
        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)

    def __call__(self, x):
        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
        return self.weight * output

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.n_heads: int = args.n_heads
        self.n_kv_heads: int = args.n_kv_heads

        self.repeats = self.n_heads // self.n_kv_heads

        self.scale = self.args.head_dim**-0.5

        self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
        self.rope = nn.RoPE(
            args.head_dim, traditional=args.rope_traditional, base=args.rope_theta
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        B, L, D = x.shape

        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        def repeat(a):
            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
            return a.reshape([B, self.n_heads, L, -1])

        keys, values = map(repeat, (keys, values))

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores += mask
        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.wo(output), (keys, values)

class FeedForward(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
        self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
        self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        return self.w2(nn.silu(self.w1(x)) * self.w3(x))

class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args=args)
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r, cache = self.attention(self.attention_norm(x), mask, cache)
        h = x + r
        r = self.feed_forward(self.ffn_norm(h))
        out = h + r
        return out, cache

class Llama(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

    def __call__(self, x):
        mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
        mask = mask.astype(self.tok_embeddings.weight.dtype)

        x = self.tok_embeddings(x)
        for l in self.layers:
            x, _ = l(x, mask)
        x = self.norm(x)
        return self.output(x)

    def generate(self, x, temp=1.0):
        def sample(logits):
            if temp == 0:
                return mx.argmax(logits, axis=-1)
            else:
                return mx.random.categorical(logits * (1 / temp))

        cache = []

        # Make an additive causal mask. We will need that to process the prompt.
        mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
        mask = mask.astype(self.tok_embeddings.weight.dtype)

        # First we process the prompt x the same was as in __call__ but
        # save the caches in cache
        x = self.tok_embeddings(x)
        for l in self.layers:
            x, c = l(x, mask=mask)
            # We store the per layer cache in a simple python list
            cache.append(c)
        x = self.norm(x)
        # We only care about the last logits that generate the next token
        y = self.output(x[:, -1])
        y = sample(y)

        # y now has size [1]
        # Since MLX is lazily evaluated nothing is computed yet.
        # Calling y.item() would force the computation to happen at
        # this point but we can also choose not to do that and let the
        # user choose when to start the computation.
        yield y

        # Now we parsed the prompt and generated the first token we
        # need to feed it back into the model and loop to generate the
        # rest.
        while True:
            # Unsqueezing the last dimension to add a sequence length
            # dimension of 1
            x = y[:, None]

            x = self.tok_embeddings(x)
            for i in range(len(cache)):
                # We are overwriting the arrays in the cache list. When
                # the computation will happen, MLX will be discarding the
                # old cache the moment it is not needed anymore.
                x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
            x = self.norm(x)
            y = sample(self.output(x[:, -1]))

            yield y

def tic():
    return time.time()

def toc(msg, start):
    end = time.time()
    return f"[INFO] {msg}: {end - start:.3f} s"

def generate(args):
    input("Press enter to start generation")
    x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)])
    skip = 0
    prompt_processing = None
    tokens = []
    start = tic()
    for token in model.generate(x, args.temp):
        tokens.append(token)

        if len(tokens) == 1:
            # Actually perform the computation to measure the prompt processing time
            mx.eval(token)
            prompt_processing = toc("Prompt processing", start)

        if len(tokens) >= args.max_tokens:
            break

        elif (len(tokens) % args.write_every) == 0:
            # It is perfectly ok to eval things we have already eval-ed.
            mx.eval(tokens)
            s = tokenizer.decode([t.item() for t in tokens])
            print(s[skip:], end="", flush=True)
            skip = len(s)

    mx.eval(tokens)
    full_gen = toc("Full generation", start)
    s = tokenizer.decode([t.item() for t in tokens])
    print(s[skip:], flush=True)
    print("------")
    print(prompt_processing)
    print(full_gen)

def few_shot_generate(args):
    def possible_end(s):
        word = "[Instruction]"
        for i in range(len(word) - 1, 0, -1):
            if s[-i:] == word[:i]:
                return 0
        if s[-len(word) :] == word:
            return 1
        return -1

    def generate(question):
        x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(question)])
        skip = 0
        prompt_processing = None
        tokens = []
        start = tic()
        for token in model.generate(x, args.temp):
            tokens.append(token)

            if len(tokens) == 1:
                # Actually perform the computation to measure the prompt processing time
                mx.eval(token)
                prompt_processing = toc("Prompt processing", start)

            if len(tokens) >= args.max_tokens:
                break

            mx.eval(tokens)
            token_list = [t.item() for t in tokens]
            s = tokenizer.decode(token_list)

            end = possible_end(s)
            if end == 0:
                # The text may be the start of "[Instruction]": wait for more tokens
                continue
            if end == 1:
                # The model started a new instruction: stop generating
                skip = len(s)
                break

            print(s[skip:], end="", flush=True)
            skip = len(s)
            if token_list[-1] == tokenizer.eos_id():
                break

        mx.eval(tokens)
        full_gen = toc("Full generation", start)
        s = tokenizer.decode([t.item() for t in tokens])
        print(s[skip:], end="", flush=True)

    print("[INFO] Loading few-shot examples from: {}".format(args.few_shot))
    prompt = open(args.few_shot).read().strip()
    while True:
        question = input("Ask a question: ")
        generate(prompt.replace("{}", question))

def sanitize_config(config, weights):
    config.pop("model_type", None)
    n_heads = config["n_heads"]
    if "n_kv_heads" not in config:
        config["n_kv_heads"] = n_heads
    if "head_dim" not in config:
        config["head_dim"] = config["dim"] // n_heads
    if "hidden_dim" not in config:
        config["hidden_dim"] = weights["layers.0.feed_forward.w1.weight"].shape[0]
    if config.get("vocab_size", -1) < 0:
        config["vocab_size"] = weights["output.weight"].shape[-1]
    if "rope_theta" not in config:
        config["rope_theta"] = 10000
    unused = ["multiple_of", "ffn_dim_multiplier"]
    for k in unused:
        config.pop(k, None)
    return config

def load_model(model_path):
    model_path = Path(model_path)

    unsharded_weights_path = Path(model_path / "weights.npz")
    if unsharded_weights_path.is_file():
        print("[INFO] Loading model from {}.".format(unsharded_weights_path))
        weights = mx.load(str(unsharded_weights_path))
    else:
        sharded_weights_glob = str(model_path / "weights.*.npz")
        weight_files = glob.glob(sharded_weights_glob)
        print("[INFO] Loading model from {}.".format(sharded_weights_glob))

        if len(weight_files) == 0:
            raise FileNotFoundError("No weights found in {}".format(model_path))

        weights = {}
        for wf in weight_files:
            weights.update(mx.load(wf).items())

    with open(model_path / "config.json", "r") as f:
        config = sanitize_config(json.loads(f.read()), weights)
        quantization = config.pop("quantization", None)
    model = Llama(ModelArgs(**config))
    if quantization is not None:
        nn.QuantizedLinear.quantize_module(model, **quantization)
    model.update(tree_unflatten(list(weights.items())))
    tokenizer = SentencePieceProcessor(model_file=str(model_path / "tokenizer.model"))
    return model, tokenizer

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Llama inference script")
    parser.add_argument(
        "--model-path",
        help="Path to the model weights and tokenizer",
        default="mlx_model",
    )
    parser.add_argument(
        "--prompt",
        help="The message to be processed by the model. Ignored when --few-shot is provided.",
        default="In the beginning the Universe was created.",
    )
    parser.add_argument(
        "--few-shot",
        help="Read a few shot prompt from a file (as in `sample_prompt.txt`).",
    )
    parser.add_argument(
        "--max-tokens", "-m", type=int, default=100, help="How many tokens to generate"
    )
    parser.add_argument(
        "--write-every", type=int, default=1, help="After how many tokens to detokenize"
    )
    parser.add_argument(
        "--temp", type=float, default=0.0, help="The sampling temperature"
    )
    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")

    args = parser.parse_args()

    mx.random.seed(args.seed)

    model, tokenizer = load_model(args.model_path)
    if args.few_shot:
        few_shot_generate(args)
    else:
        generate(args)
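To make one of the building blocks above more concrete, here is a minimal pure-Python sketch of what the RMSNorm class computes for a single vector: each value is divided by the root mean square of the vector and then multiplied by a learned weight. It is an illustration of the arithmetic only, not the MLX implementation; the function name and sample values are made up for the example.

```python
import math


def rms_norm(x, weight, eps=1e-5):
    """Scale x by the reciprocal of its root mean square, then by a per-element weight."""
    mean_square = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_square + eps)
    return [w * v * inv_rms for v, w in zip(x, weight)]


x = [1.0, -2.0, 3.0, -4.0]
y = rms_norm(x, weight=[1.0] * len(x))
print(y)
```

With all weights set to one, the output keeps the signs and relative sizes of the input but has a mean square very close to 1.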

You then need to add the respective tokenizer.model file to the model’s folder. For example, if you downloaded the llama-2-7B model from Meta, you should have a llama-2-7B directory within your mlx directory. You then need to place the Llama tokeniser into that model directory. To do so, visit HuggingFace and download the tokenizer.model file. Paste it into the model directory. The same tokeniser can be used regardless of model size – i.e. you can use the file from the above link even if you downloaded the 13B version of the model.
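Before converting, a quick check that the model folder contains everything the conversion script reads can save a failed run. The file names below are taken from convert.py (consolidated.*.pth, params.json and tokenizer.model); the helper itself and the folder name in the usage line are hypothetical.

```python
from pathlib import Path


def missing_model_files(model_dir):
    """List what the conversion step expects but cannot find in model_dir."""
    model_dir = Path(model_dir)
    missing = [
        name
        for name in ("params.json", "tokenizer.model")
        if not (model_dir / name).is_file()
    ]
    if not list(model_dir.glob("consolidated.*.pth")):
        missing.append("consolidated.*.pth")
    return missing


# Hypothetical folder name; replace with your own download directory.
print(missing_model_files("llama-2-7B"))
```

An empty list means the folder is ready for conversion.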


Let’s run some inference.

Open your terminal at the mlx folder and enter the following command:

python convert.py --torch-path ./llama-2-model-folder -q

Replace llama-2-model-folder with the name of your downloaded model folder, e.g. llama-2-7B. The -q parameter applies 4-bit quantisation to speed up inference. Again, make sure you’ve pasted the tokenizer.model file from the previous section into the respective model folder.

The script prints [INFO] Loading, [INFO] Quantizing and [INFO] Saving as it works through each stage.

Your model tensors have now been converted to MLX arrays and the model has been quantised.
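For intuition about what quantisation does to the weights, here is a minimal pure-Python sketch of affine quantisation applied to one small group of values. It illustrates the general technique only and is not MLX's implementation; the function names, the group of eight sample weights and the rounding scheme are all assumptions made for the example.

```python
def quantise_group(values, bits=4):
    """Affine-quantise one group of floats to integer codes in [0, 2**bits - 1]."""
    levels = 2**bits - 1
    lo, hi = min(values), max(values)
    scale = (hi - lo) / levels or 1.0  # fall back to 1.0 for a constant group
    codes = [round((v - lo) / scale) for v in values]
    return codes, scale, lo


def dequantise_group(codes, scale, lo):
    """Recover approximate floats from the integer codes."""
    return [c * scale + lo for c in codes]


weights = [0.12, -0.40, 0.33, 0.05, -0.21, 0.48, -0.07, 0.26]
codes, scale, lo = quantise_group(weights, bits=4)
restored = dequantise_group(codes, scale, lo)
max_err = max(abs(a - b) for a, b in zip(weights, restored))
print(codes)
print(f"largest round-trip error: {max_err:.4f}")
```

Each weight is stored as a 4-bit code plus a shared scale and offset per group, so the round-trip error is bounded by half the scale. That small loss of precision is the price paid for the smaller memory footprint.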

Now enter:

python llama.py --prompt "A great hobby to have is" -m 30

Here we execute the script, input our prompt and set a max-tokens value with the -m parameter. Have a play with these inputs. Further details on the parameters can be found when entering:

python llama.py --help
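One input worth experimenting with is --temp. In the generate method shown earlier, a temperature of 0 picks the highest-scoring token every time, while any other value samples from the logits divided by the temperature. The sketch below mirrors that logic in plain Python so it can be tried without a model; the function name and the logits are made up for the example.

```python
import math
import random


def sample(logits, temp=0.0, rng=random):
    """Greedy pick when temp == 0, otherwise sample from softmax(logits / temp)."""
    if temp == 0:
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [value / temp for value in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    probs = [math.exp(s - peak) for s in scaled]
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.1]
print(sample(logits, temp=0.0))  # greedy: always the index of the largest logit
random.seed(0)
print([sample(logits, temp=1.0) for _ in range(10)])
```

Higher temperatures flatten the distribution, so less likely tokens are picked more often and the output becomes more varied.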

Below is the output of our prompt after pressing enter to start generation:

Notice the output is truncated due to our max-tokens limit.
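The truncation comes from how the script consumes the model's output: generation is an endless stream of tokens, and the calling loop simply stops once it has collected the requested number. A minimal sketch of that pattern, with a hypothetical stand-in for the model's token stream:

```python
from itertools import count


def fake_token_stream():
    """Stand-in for model.generate: yields token ids forever."""
    for i in count():
        yield i


def take_tokens(stream, max_tokens):
    """Collect tokens until the limit is reached, then stop consuming."""
    tokens = []
    for token in stream:
        tokens.append(token)
        if len(tokens) >= max_tokens:
            break
    return tokens


print(take_tokens(fake_token_stream(), max_tokens=5))
```

Because the stream is a generator, nothing beyond the limit is ever requested from it.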

Let’s benchmark this against the un-quantised equivalent:

As expected (and as we learned from my previous article!), inference takes longer on the un-quantised version.
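If you want to compare runs more systematically, tokens per second is an easier figure to compare than raw wall-clock time. The helper below is a generic timing sketch: dummy_generate is a placeholder workload, and in practice you would swap in whatever callable produces your tokens.

```python
import time


def tokens_per_second(generate_fn, n_tokens):
    """Time a callable that produces n_tokens tokens and return the throughput."""
    start = time.perf_counter()
    generate_fn(n_tokens)
    elapsed = time.perf_counter() - start
    return n_tokens / elapsed if elapsed > 0 else float("inf")


def dummy_generate(n_tokens):
    # Placeholder workload standing in for a real generation call.
    for _ in range(n_tokens):
        time.sleep(0.001)


print(f"{tokens_per_second(dummy_generate, 30):.1f} tokens/s")
```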

If you run into any errors when swapping between quantised/unquantised models or models of different sizes, simply delete the mlx_model folder that is created when running the script. Then re-run the convert.py script with your chosen parameters.
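If you find yourself doing this often, the clean-up can be scripted. This sketch assumes the converted model lives in a folder called mlx_model inside your working directory; the helper name is hypothetical, and the function deletes the folder permanently, so double-check the path before calling it.

```python
import shutil
from pathlib import Path


def reset_converted_model(folder="mlx_model"):
    """Delete the converted-model folder if it exists; return True if it was removed."""
    path = Path(folder)
    if path.is_dir():
        shutil.rmtree(path)
        return True
    return False
```

Calling reset_converted_model() from your mlx directory before re-running the conversion gives you a clean starting point.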

GPU usage:

Finally, as evidence that we are actually using the Apple hardware – let’s have a look at the GPU read-out when running inference. Press cmd + spacebar and search for "Activity Monitor".

Then, along the top of your screen you will see various menu options. Select Window > GPU History. Then start running some inference. You will see it spiking whenever you run inference on your LLM:

GPU Usage Spike


So here are my closing thoughts: if you are heavily invested in developing, training and fine-tuning LLM solutions, I still believe you’ll be better off using NVIDIA hardware – mainly due to software/library support.

Nevertheless, the development work from Apple’s MLX team is very welcome and really makes LLMs much more approachable to hobbyists and professionals of all abilities and backgrounds. I’m excited to see how the project develops.

As stated previously, I will be creating a further article on LLM fine-tuning using the MLX library. It will build upon what I’ve covered in this article so make sure to follow me and keep an eye out for that.

As always, let me know if you have any questions.

All images belong to the author unless otherwise stated.
