@yeahdongcn
Last active March 12, 2026 06:23
Accelerating SGLang Inference on macOS: 5× Faster with Native MLX

SGLang already runs on macOS via PyTorch's MPS (Metal Performance Shaders) backend — you can launch a server, send requests, and get responses. But performance on Apple Silicon has been underwhelming. In this post, we describe how we integrated a native MLX execution path into SGLang that delivers up to 5.3× higher throughput while using significantly less memory.

The Problem: PyTorch MPS Overhead

When SGLang runs on macOS with PyTorch MPS, every operation — matrix multiplications, attention, normalization — goes through PyTorch's MPS backend, which translates PyTorch ops into Metal Performance Shaders. This translation layer adds substantial overhead:

  1. Op dispatch overhead: Each PyTorch operation is individually dispatched to MPS, missing optimization opportunities that come from fusing operations together.
  2. Memory duplication: PyTorch loads model weights into MPS memory and allocates a large KV cache, leaving less room for actual inference workloads.
  3. No fused kernels: Operations like RMSNorm + RoPE + Attention run as separate Metal dispatches instead of fused GPU kernels.

The Solution: Native MLX Execution

MLX is Apple's machine learning framework designed specifically for Apple Silicon. Instead of translating each op through a compatibility layer, MLX provides fused, hand-optimized Metal kernels for the hot paths of transformer inference via its mx.fast API:

| Operation | PyTorch MPS | MLX |
| --- | --- | --- |
| RMS Normalization | Generic MPS kernel | `mx.fast.rms_norm` (fused Metal shader) |
| Rotary Embeddings | Multiple MPS dispatches | `mx.fast.rope` (fused Metal shader) |
| Attention | Separate Q·K, softmax, ·V | `mx.fast.scaled_dot_product_attention` (fused) |
| Layer Normalization | Generic MPS kernel | `mx.fast.layer_norm` (fused Metal shader) |
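As a concrete reference for the first row of the table, here is a dependency-free sketch of the computation that `mx.fast.rms_norm` performs in a single fused Metal dispatch; an unfused backend issues separate kernels for the square, mean, rsqrt, and multiply steps:

```python
import math

def rms_norm(x, weight, eps=1e-6):
    """Reference RMSNorm: y_i = x_i / sqrt(mean(x^2) + eps) * w_i.
    A fused kernel (e.g. mx.fast.rms_norm) does all of this in one
    GPU dispatch; an unfused backend dispatches each step separately."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]

print(rms_norm([1.0, 2.0, 3.0], [1.0, 1.0, 1.0]))  # ≈ [0.463, 0.926, 1.389]
```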

We integrated MLX as an alternative execution backend within SGLang's existing architecture. The key insight: keep SGLang's scheduler and serving infrastructure intact, but route the actual model computation through MLX instead of PyTorch MPS.
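That routing decision can be pictured as a startup-time switch. The sketch below is illustrative, not SGLang's actual internals; only the SGLANG_USE_MLX environment variable is taken from this post, and the worker names mirror the components described next:

```python
import os

def select_model_worker(env=os.environ):
    """Pick the execution backend at startup. The scheduler and serving
    stack are shared; only the model-compute worker differs."""
    if env.get("SGLANG_USE_MLX") == "1":
        return "MlxTpModelWorker"  # routes model compute through MLX
    return "TpModelWorker"         # default PyTorch (MPS) path

print(select_model_worker({"SGLANG_USE_MLX": "1"}))
```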

Architecture

+---------------------------------------------+
|             SGLang Scheduler                |
|   (request management, batching, routing)   |
+---------------------------------------------+
|            MlxTpModelWorker                 |
|  +-------------------+ +------------------+ |
|  | MlxModelRunnerStub| |  MlxModelRunner  | |
|  |  (zero-memory     | |  (actual MLX     | |
|  |   bookkeeping)    | |   inference)     | |
|  +-------------------+ +------------------+ |
+---------------------------------------------+
|  mlx-lm models -> mlx.nn -> mx.fast kernels |
|              | Metal GPU |                  |
+---------------------------------------------+

Three Key Components

1. MlxModelRunner — The actual inference engine. Loads the model via mlx-lm, manages per-request KV caches, and runs forward passes entirely in MLX. Only the final logits are bridged to PyTorch for compatibility with SGLang's sampling pipeline.

2. MlxModelRunnerStub — A lightweight ModelRunner subclass that satisfies SGLang's infrastructure requirements without consuming GPU memory. It provides:

  • A _DummyKVCache that reports zero memory usage
  • CPU-side bookkeeping pools (req_to_token_pool, token_to_kv_pool_allocator) for the scheduler
  • No PyTorch model weights loaded at all
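A toy sketch of what such a zero-memory stub cache can look like; the real _DummyKVCache interface may differ, this only illustrates satisfying the scheduler's accounting without any GPU allocation:

```python
class DummyKVCache:
    """Stands in for a real KV cache: reports zero memory usage so the
    scheduler's accounting still works, but never allocates GPU buffers
    (MLX manages the actual weights and cache on its side)."""

    def __init__(self, max_tokens):
        self.max_tokens = max_tokens

    def mem_usage_bytes(self):
        return 0  # nothing held on the PyTorch side

    def available_tokens(self):
        # Every token slot is "free" from the PyTorch allocator's view.
        return self.max_tokens

cache = DummyKVCache(max_tokens=40_960)
print(cache.mem_usage_bytes(), cache.available_tokens())  # 0 40960
```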

3. MlxTpModelWorker — Routes forward_batch_generation calls to the MLX runner instead of the standard PyTorch path. Handles both prefill and decode phases, including batched prefill for same-length sequences.
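The "batched prefill for same-length sequences" amounts to grouping pending requests by prompt length, so each group can be stacked into one rectangular batch for a single forward pass. A hypothetical sketch of that grouping step:

```python
from collections import defaultdict

def group_for_batched_prefill(prompts):
    """Group token sequences by length so each group can be stacked
    into a single rectangular batch for one prefill forward pass."""
    groups = defaultdict(list)
    for i, toks in enumerate(prompts):
        groups[len(toks)].append(i)
    return dict(groups)

prompts = [[1, 2, 3], [4, 5], [6, 7, 8], [9, 10]]
print(group_for_batched_prefill(prompts))  # {3: [0, 2], 2: [1, 3]}
```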

Zero-Memory Stub Design

A critical design decision was the zero-memory stub. In the naive approach, SGLang loads PyTorch model weights (~1.2 GB for Qwen3-0.6B) and allocates a PyTorch KV cache (~2.5 GB), even though MLX manages its own weights and cache. This "double-load" wastes ~3.7 GB of unified memory.

Our stub eliminates this entirely:

| Resource | Before (double-load) | After (stub) |
| --- | --- | --- |
| PyTorch model weights | ~1.2 GB | 0 bytes |
| PyTorch KV cache | ~2.5 GB | 0 bytes |
| Scheduler bookkeeping | CPU-side pools | CPU-side pools |
| Available KV cache tokens | 17,534 | 40,960 |

Benchmark Results

All benchmarks were run on an Apple M1 (16 GB unified memory) with Qwen3-0.6B, input length 60, output length 10.

Batch Size 1

| Metric | PyTorch MPS | MLX Native | Speedup |
| --- | --- | --- | --- |
| Prefill throughput | 229.97 tok/s | 902.99 tok/s | 3.9× |
| Decode throughput (median) | 7.87 tok/s | 31.25 tok/s | 4.0× |
| Total throughput | 44.19 tok/s | 201.89 tok/s | 4.6× |
| Prefill latency | 260.9 ms | 66.5 ms | 3.9× |
| Decode latency (median) | 127.1 ms | 32.0 ms | 4.0× |
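The speedup column is just the ratio of the two figures (for throughput, MLX over baseline; for latency, baseline over MLX):

```python
def speedup(baseline, optimized):
    """Ratio of the MLX figure to the PyTorch MPS baseline,
    rounded to one decimal as in the tables above."""
    return round(optimized / baseline, 1)

print(speedup(229.97, 902.99))  # prefill throughput -> 3.9
print(speedup(44.19, 201.89))   # total throughput  -> 4.6
print(round(260.9 / 66.5, 1))   # prefill latency   -> 3.9
```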

Batch Size 4

| Metric | PyTorch MPS | MLX Native | Speedup |
| --- | --- | --- | --- |
| Prefill throughput | 385.76 tok/s | 1,019.48 tok/s | 2.6× |
| Decode throughput (median) | 18.22 tok/s | 46.13 tok/s | 2.5× |
| Total throughput | 109.26 tok/s | 275.00 tok/s | 2.5× |

Memory Efficiency

| Metric | PyTorch MPS | MLX Native |
| --- | --- | --- |
| PyTorch weight memory | ~1.2 GB | 0 |
| PyTorch KV cache memory | ~2.5 GB | 0 |
| Max KV cache tokens (scheduler) | 23,290 | 40,960 |
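As a sanity check on these numbers, per-token KV cache cost is K plus V, per layer, per KV head. The model config below (28 layers, 8 KV heads, head dim 128, fp16) is an illustrative assumption for a Qwen3-0.6B-like model, not taken from the post; consult the model's config.json for the real values:

```python
def kv_bytes_per_token(n_layers, n_kv_heads, head_dim, dtype_bytes=2):
    """Per-token KV cache cost: 2 (K and V) x layers x KV heads x head dim."""
    return 2 * n_layers * n_kv_heads * head_dim * dtype_bytes

# Assumed (illustrative) config: 28 layers, 8 KV heads, head dim 128, fp16.
per_token = kv_bytes_per_token(28, 8, 128)
tokens_in_2_5_gib = int(2.5 * 1024**3) // per_token
print(per_token, tokens_in_2_5_gib)  # 114688 bytes/token, ~23K tokens
```

Under these assumed numbers, a ~2.5 GiB fp16 cache holds on the order of 23 K tokens, in the same ballpark as the 23,290-token scheduler budget above.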

How to Use It

Enable the MLX backend with a single environment variable:

```shell
# Benchmark
SGLANG_USE_MLX=1 python -m sglang.bench_one_batch \
  --model-path Qwen/Qwen3-0.6B \
  --trust-remote-code --disable-radix-cache --disable-cuda-graph \
  --tp-size 1 --batch-size 1 --input-len 60 --output-len 10

# Launch server (OpenAI-compatible API)
SGLANG_USE_MLX=1 python -m sglang.launch_server \
  --model-path Qwen/Qwen3-0.6B \
  --host 0.0.0.0 --port 30000 \
  --trust-remote-code --disable-radix-cache --disable-cuda-graph --tp-size 1
```

The server exposes the standard OpenAI-compatible API, so existing clients work without modification:

```shell
curl http://localhost:30000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "Qwen/Qwen3-0.6B",
       "messages": [{"role": "user", "content": "What is the capital of France?"}],
       "temperature": 0, "max_tokens": 50}'
```
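From Python, the same request can be assembled as below. The helper name is illustrative; the resulting URL, headers, and body can be sent with any HTTP client (or you can point the official openai package at the server's /v1 base URL):

```python
import json

def build_chat_request(base_url, model, user_msg, max_tokens=50):
    """Assemble an OpenAI-style chat completion request for the SGLang
    server; returns (url, headers, serialized JSON body)."""
    body = {
        "model": model,
        "messages": [{"role": "user", "content": user_msg}],
        "temperature": 0,
        "max_tokens": max_tokens,
    }
    return (f"{base_url}/v1/chat/completions",
            {"Content-Type": "application/json"},
            json.dumps(body))

url, headers, body = build_chat_request(
    "http://localhost:30000", "Qwen/Qwen3-0.6B",
    "What is the capital of France?")
print(url)
```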

Current Limitations

  • Greedy sampling only: The MLX path currently uses argmax for token selection. Full sampling (top-k, top-p, temperature, penalties) is not yet integrated.
  • No KV cache reuse across requests: Each request maintains its own independent KV cache. SGLang's radix cache and prefix sharing are not yet connected to the MLX path.
  • No continuous batching with cache reuse: While we support batched prefill and decode, the sophisticated cache management that makes SGLang efficient on CUDA is not yet ported.
  • Single model support: Speculative decoding and draft models are not yet supported through the MLX path.
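To make the first limitation concrete: greedy selection is just an argmax over the final logits. The sketch below contrasts it with the temperature sampling that is not yet wired into the MLX path (reference implementations, not SGLang's sampler):

```python
import math
import random

def greedy(logits):
    """What the MLX path does today: pick the index of the highest logit."""
    return max(range(len(logits)), key=lambda i: logits[i])

def sample_with_temperature(logits, temperature, rng=random):
    """What full sampling adds (sketch): softmax over temperature-scaled
    logits, then draw a token index from the resulting distribution."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)                              # subtract max for stability
    exps = [math.exp(l - m) for l in scaled]
    total = sum(exps)
    r = rng.random()
    acc = 0.0
    for i, e in enumerate(exps):
        acc += e / total
        if r <= acc:
            return i
    return len(exps) - 1

print(greedy([0.1, 2.5, -1.0, 0.3]))  # 1
```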

Future Plans

Short Term

  • Full sampling pipeline: Integrate SGLang's sampling parameters (temperature, top-k, top-p, repetition penalty) with the MLX forward pass.
  • Broader model support: Test and validate more model architectures beyond Qwen3 (Llama, Mistral, Gemma, etc.) through mlx-lm's extensive model library.

Medium Term

  • Paged KV cache for MLX: Implement a custom paged KV cache that integrates with SGLang's req_to_token_pool / token_to_kv_pool_allocator, enabling true cross-request cache sharing and radix cache on Apple Silicon.
  • Continuous batching with cache reuse: Connect SGLang's scheduler-level batching decisions to MLX's cache management for production-grade serving performance.
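The page-table idea behind such a paged KV cache can be sketched as a toy allocator; block size, names, and interface here are illustrative, not SGLang's actual req_to_token_pool design:

```python
class PagedKVAllocator:
    """Toy paged allocator: fixed-size token blocks handed out from a
    free list, with a per-request page table mapping logical token
    positions to physical blocks (the indirection that later enables
    cross-request sharing and radix-cache reuse)."""

    def __init__(self, num_blocks, block_size):
        self.block_size = block_size
        self.free = list(range(num_blocks))
        self.page_tables = {}  # request id -> list of block ids

    def alloc(self, req_id, num_tokens):
        """Reserve enough blocks to hold num_tokens tokens."""
        need = -(-num_tokens // self.block_size)  # ceiling division
        assert len(self.free) >= need, "out of KV blocks"
        blocks = [self.free.pop() for _ in range(need)]
        self.page_tables[req_id] = blocks
        return blocks

    def release(self, req_id):
        """Return a finished request's blocks to the free list."""
        self.free.extend(self.page_tables.pop(req_id))

pool = PagedKVAllocator(num_blocks=8, block_size=16)
print(pool.alloc("req-1", num_tokens=40))  # 3 blocks of 16 tokens each
```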

Long Term

  • Custom Metal kernels: Leverage mx.fast.metal_kernel for SGLang-specific fused operations that go beyond what mlx-lm provides out of the box.
  • Speculative decoding on Apple Silicon: Route both draft and target models through MLX for low-latency speculative decoding.
  • Tensor parallelism: Explore multi-GPU inference on Mac Studio / Mac Pro configurations using MLX's distributed primitives.

Acknowledgments

This work builds on the foundation laid by SGLang's existing macOS support for both LLM and diffusion model inference on Apple Silicon. The MLX integration extends this by providing a native execution path that eliminates the PyTorch MPS overhead entirely.

Special thanks to the MLX team at Apple for building a framework that makes this level of performance possible on Apple Silicon, and to the mlx-lm project for providing optimized model implementations that leverage mx.fast fused kernels.


To try it out, grab the latest SGLang from the main branch and set SGLANG_USE_MLX=1. Feedback and contributions welcome!
