
[Experiment] ROCm backend#2300

Open
NripeshN wants to merge 313 commits into ml-explore:main from NripeshN:rocm-support

Conversation

@NripeshN

@NripeshN NripeshN commented Jun 16, 2025

Contributor

Experiment with ROCm backend.

Install MLX with the ROCm backend using:

mkdir build && cd build
cmake -DMLX_BUILD_ROCM=ON \
      -DCMAKE_PREFIX_PATH=/opt/rocm \
      -DCMAKE_HIP_ARCHITECTURES="gfx90a;gfx1100" \
      ..
make -j$(nproc)

closes #2556

Inspired by @zcbenz

@NripeshN NripeshN changed the title [Experiment] ROCm backend initial push [Experiment] ROCm backend Jun 16, 2025
@lin72h

lin72h commented Jun 17, 2025


What an unexpected and amazing surprise! I'm absolutely thrilled.

@NripeshN

Contributor Author

@awni
What do you think of this PR? Does this have the potential to be merged into main? I can turn this PR from experimental to WIP if so.

@angeloskath

Member

I think this is good to stay as an experiment branch for some time while we work on core and CUDA. I don't think we have the bandwidth to merge this for a few months at least. Sorry if this is disappointing, @NripeshN. I don't mean to discourage you from working on it.

@akshat2602


I would love to see the ROCm backend get more traction. AMD's new AI series of processors has a similar advantage to Apple Silicon with unified memory, and getting MLX to run on those processors would be neat.

@countradooku


Stole my idea :(

@goniz

goniz commented Jan 22, 2026


How is it even possible for such an awesome PR to be left like this?

Copilot AI review requested due to automatic review settings January 24, 2026 17:08

Copilot AI left a comment


Pull request overview

This PR adds experimental ROCm backend support to MLX, enabling execution on AMD GPUs. The implementation mirrors the CUDA backend structure, providing HIP-based implementations of core operations, memory management, and device handling.

Changes:

  • Added ROCm backend infrastructure with device management, memory allocation, and stream handling
  • Implemented HIP kernels for unary, binary, ternary operations, reductions, normalization (softmax, layer_norm, rms_norm), RoPE, and sorting
  • Updated build system (CMake) to support ROCm compilation with configurable GPU architectures

Reviewed changes

Copilot reviewed 59 out of 59 changed files in this pull request and generated 13 comments.

File                             Description
CMakeLists.txt                   Added MLX_BUILD_ROCM option and ROCm library detection
mlx/CMakeLists.txt               Integrated ROCm backend build configuration
mlx/device.cpp                   Added ROCm device availability checks
mlx/backend/rocm/*.hip           HIP kernel implementations for various operations
mlx/backend/rocm/device.*        ROCm device and stream management
mlx/backend/rocm/allocator.*     ROCm-specific memory allocator using HIP unified memory
mlx/backend/rocm/worker.*        Async task execution worker for stream synchronization
mlx/backend/rocm/utils.*         HIP utility functions and error handling
mlx/backend/rocm/jit_module.*    JIT compilation support using HIPRTC
mlx/backend/rocm/device/*.hpp    Device-side utility functions and type definitions
mlx/backend/rocm/CMakeLists.txt  ROCm backend build configuration


Comment thread mlx/backend/rocm/softmax.hip Outdated
Comment thread mlx/backend/rocm/device.cpp Outdated
Comment thread mlx/backend/rocm/layer_norm.hip Outdated
Comment thread mlx/backend/rocm/rope.hip Outdated
Comment thread mlx/backend/rocm/softmax.hip Outdated
Comment thread mlx/backend/rocm/allocator.cpp Outdated
Comment thread CMakeLists.txt Outdated
Comment thread mlx/backend/rocm/binary.hip Outdated
Comment thread mlx/backend/rocm/rms_norm.hip Outdated
Comment thread mlx/backend/rocm/layer_norm.hip Outdated
@goniz

goniz commented Jan 24, 2026


👑👑👑

@NripeshN

Contributor Author

Can anyone run

CMAKE_ARGS="-DMLX_BUILD_ROCM=ON" pip install -e .
CMAKE_ARGS="-DMLX_BUILD_ROCM=ON -DMLX_ROCM_ARCHITECTURES={based on your GPU}" pip install -e .

Replace {based on your GPU} with your GPU architecture (for example, gfx1100).

You can run

rocm-smi

to get your GPU information
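
The value needed for `MLX_ROCM_ARCHITECTURES` is the `gfx` target name. A minimal sketch of pulling it out of tool output, assuming the `rocminfo` utility that ships with ROCm is installed and lists agent names such as `gfx1151`; the helper name `parse_gfx_targets` and the sample text are hypothetical and only illustrate the parsing step:

```python
import re


def parse_gfx_targets(rocminfo_output: str) -> list[str]:
    """Return the unique gfx target names found in rocminfo-style text, in order."""
    seen = []
    for name in re.findall(r"\bgfx[0-9a-f]+\b", rocminfo_output):
        if name not in seen:
            seen.append(name)
    return seen


# Illustrative sample only, not captured tool output.
sample = """
  Name:                    AMD Ryzen (CPU agent)
  Name:                    gfx1151
  Name:                    amdgcn-amd-amdhsa--gfx1151
"""

# Join with ';' so the result can be pasted into a CMake architecture list.
print(";".join(parse_gfx_targets(sample)))
```

On a real machine the input could come from `subprocess.run(["rocminfo"], capture_output=True, text=True).stdout`, provided `rocminfo` is on the PATH.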

@goniz

goniz commented Jan 24, 2026


I'm getting this CMake error:

CMAKE_ARGS="-DMLX_BUILD_ROCM=ON -DMLX_ROCM_ARCHITECTURES=gfx1151" pip install -e .

      -- Configuring done (4.8s)
      CMake Error: The following variables are used in this project, but they are set to NOTFOUND.
      Please set them or make sure they are set and tested correctly in the CMake files:
      /home/goniz/Work/mlx/LAPACK_INCLUDE_DIRS
         used as include directory in directory /home/goniz/Work/mlx
      
      CMake Error in CMakeLists.txt:
        HIP_ARCHITECTURES is empty for target "mlx".
      
      
      CMake Error in CMakeLists.txt:
        HIP_ARCHITECTURES is empty for target "mlx".
      
      
      -- Generating done (0.0s)
      CMake Generate step failed.  Build files cannot be regenerated correctly.

Running on Strix Halo (gfx1151)

@NripeshN

Contributor Author

I'm getting this CMake error:

CMAKE_ARGS="-DMLX_BUILD_ROCM=ON -DMLX_ROCM_ARCHITECTURES=gfx1151" pip install -e .
     -- Configuring done (4.8s)
     CMake Error: The following variables are used in this project, but they are set to NOTFOUND.
     Please set them or make sure they are set and tested correctly in the CMake files:
     /home/goniz/Work/mlx/LAPACK_INCLUDE_DIRS
        used as include directory in directory /home/goniz/Work/mlx
     
     CMake Error in CMakeLists.txt:
       HIP_ARCHITECTURES is empty for target "mlx".
     
     
     CMake Error in CMakeLists.txt:
       HIP_ARCHITECTURES is empty for target "mlx".
     
     
     -- Generating done (0.0s)
     CMake Generate step failed.  Build files cannot be regenerated correctly.

Running on Strix Halo (gfx1151)

Could you retry with the latest push please (p.s. keep your fingers crossed while it compiles, worked for me 138th time)😅

@goniz

goniz commented Jan 25, 2026

  Created wheel for mlx: filename=mlx-0.30.4.dev20260125+cadf18c1-0.editable-cp314-cp314-linux_x86_64.whl size=4722 sha256=72c664adbfc4fb9ec317522a8d83b84f85d599d08bd691d7fec3abfdb6f3a5e9
  Stored in directory: /tmp/pip-ephem-wheel-cache-nt7w6bq0/wheels/8a/63/d1/d7d629a5ff73457822bb71aa527c083674bb19ca314735cd05
Successfully built mlx
Installing collected packages: mlx
Successfully installed mlx-0.30.4.dev20260125+cadf18c1

Now what can I test? 😍

@goniz

goniz commented Jan 25, 2026


I'm getting this:

ImportError: /home/goniz/Work/mlx/python/mlx/lib/libmlx.so: undefined symbol: _ZN3mlx4core11Convolution8eval_gpuERKSt6vectorINS0_5arrayESaIS3_EERS3_
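
The missing symbol is an Itanium-ABI mangled C++ name, in which each identifier of a nested name is written as its length followed by its characters. A small sketch that decodes just that nested-name prefix (the helper `nested_name_parts` is hypothetical and ignores argument types and qualifiers; `c++filt` does the full job):

```python
import re


def nested_name_parts(mangled: str) -> list[str]:
    """Extract the length-prefixed identifiers from a simple `_ZN...E` name."""
    if not mangled.startswith("_ZN"):
        raise ValueError("not a nested Itanium-mangled name")
    pos, parts = 3, []
    # Each component is <decimal length><identifier>; stop at the first non-digit.
    while pos < len(mangled) and mangled[pos].isdigit():
        digits = re.match(r"\d+", mangled[pos:]).group()
        start = pos + len(digits)
        length = int(digits)
        parts.append(mangled[start:start + length])
        pos = start + length
    return parts


symbol = "_ZN3mlx4core11Convolution8eval_gpuERKSt6vectorINS0_5arrayESaIS3_EERS3_"
print("::".join(nested_name_parts(symbol)))  # mlx::core::Convolution::eval_gpu
```

Read this way, the loader is reporting that `libmlx.so` references `mlx::core::Convolution::eval_gpu` but no definition of it was linked into the library.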

@NripeshN

Contributor Author

I'm getting this:

ImportError: /home/goniz/Work/mlx/python/mlx/lib/libmlx.so: undefined symbol: _ZN3mlx4core11Convolution8eval_gpuERKSt6vectorINS0_5arrayESaIS3_EERS3_

I forgot to test the Python build, my bad. Can you try it now?

Unfortunately I might not be able to help after it compiles, I don't have an AMD GPU to run tests😔 I've tried replicating most things from cuda, so hopefully it works

@goniz

goniz commented Jan 26, 2026


Now fails on load with this:

>>> import mlx.core
Traceback (most recent call last):
  File "<python-input-0>", line 1, in <module>
    import mlx.core
ImportError: /home/goniz/Work/mlx/python/mlx/lib/libmlx.so: undefined symbol: hiprtcCompileProgram

@goniz

goniz commented Jan 26, 2026


Unfortunately I might not be able to help after it compiles, I don't have an AMD GPU to run tests😔 I've tried replicating most things from cuda, so hopefully it works

Omg I don't believe you did it without AMD card 😱😱

@NripeshN

NripeshN commented Jan 26, 2026

Contributor Author

Now fails on load with this:

The latest push hopefully fixes the undefined symbol error. Found the issue, working on the fix😩

Omg I don't believe you did it without AMD card 😱😱

Haha docker literally saves me and humbles me at the same time

@goniz

goniz commented Jan 26, 2026

image

@goniz

goniz commented Jan 26, 2026


I might have gotten overexcited:
image

@NripeshN

Contributor Author

Wait it works?😅

Ah unfortunately unless a magic fairy sends me a PC with AMD GPU I cannot help after this😭 With the ram prices I doubt the magic fairy has the funds either🥲

@goniz

goniz commented Jan 26, 2026

Copy link
Copy Markdown

Latest commit broke something:
image

@NripeshN

Contributor Author

Lemme try adding a fix for both the issues above actually. I had just made a stub implementation earlier.

@NripeshN

Contributor Author

@goniz give the last push a try maybe. It might not work but you will definitely not have the same error at least☺️

@goniz

goniz commented Jan 26, 2026


mlx rocm-support ? ❯︎ python3 qwen3.py 
Fetching 9 files: 100%|██████| 9/9 [00:00<00:00, 201864.90it/s]
Download complete: : 0.00B [00:00, ?B/s]              ?, ?it/s]
==========
Traceback (most recent call last):
  File "/home/goniz/Work/mlx/qwen3.py", line 15, in <module>
    text = generate(model, tokenizer, prompt=prompt, verbose=True)
  File "/home/goniz/Work/mlx/venv/lib/python3.14/site-packages/mlx_lm/generate.py", line 762, in generate
    for response in stream_generate(model, tokenizer, prompt, **kwargs):
                    ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/goniz/Work/mlx/venv/lib/python3.14/site-packages/mlx_lm/generate.py", line 699, in stream_generate
    for n, (token, logprobs, from_draft) in enumerate(token_generator):
                                            ~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/goniz/Work/mlx/venv/lib/python3.14/site-packages/mlx_lm/generate.py", line 689, in <genexpr>
    (token, logprobs, False) for token, logprobs in token_generator
                                                    ^^^^^^^^^^^^^^^
  File "/home/goniz/Work/mlx/venv/lib/python3.14/site-packages/mlx_lm/generate.py", line 432, in generate_step
    mx.eval([c.state for c in prompt_cache])
    ~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Unsupported dtype for affine_dequantize

@NripeshN

Contributor Author

Might fix it(????)

@goniz

goniz commented Jan 26, 2026


mlx rocm-support ? ❯︎ python3 qwen3.py 
Fetching 9 files: 100%|███████| 9/9 [00:00<00:00, 28575.88it/s]
Download complete: : 0.00B [00:00, ?B/s]              ?, ?it/s]
==========
Traceback (most recent call last):
  File "/home/goniz/Work/mlx/qwen3.py", line 15, in <module>
    text = generate(model, tokenizer, prompt=prompt, verbose=True)
  File "/home/goniz/Work/mlx/venv/lib/python3.14/site-packages/mlx_lm/generate.py", line 762, in generate
    for response in stream_generate(model, tokenizer, prompt, **kwargs):
                    ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/goniz/Work/mlx/venv/lib/python3.14/site-packages/mlx_lm/generate.py", line 699, in stream_generate
    for n, (token, logprobs, from_draft) in enumerate(token_generator):
                                            ~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/goniz/Work/mlx/venv/lib/python3.14/site-packages/mlx_lm/generate.py", line 689, in <genexpr>
    (token, logprobs, False) for token, logprobs in token_generator
                                                    ^^^^^^^^^^^^^^^
  File "/home/goniz/Work/mlx/venv/lib/python3.14/site-packages/mlx_lm/generate.py", line 432, in generate_step
    mx.eval([c.state for c in prompt_cache])
    ~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: QuantizedMatmul has no ROCm implementation.

Geramy added 30 commits June 20, 2026 18:29
…idge

add_kernel_node_ex now copies arg VALUES into a heap pack kept alive through
commit() (HIP graph nodes reference kernelParams until instantiate/exec-update,
after which the pack is cleared) — fixes dangling kernelParams. The per-op
micro-capture bridge in launch_kernel is now behind MLX_HIP_GRAPH_BRIDGE.
graphs-OFF (default) unchanged.
Diagnostic: pure add_kernel_node kernel-node graphs launch correctly on this
ROCm build (model-load evals pass). Remaining graphs-ON blockers are the
non-kernel residuals only: library GEMM (aborts/crashes under graph) and the
child-graph bridge nodes. graphs-OFF (default) unaffected.

…lifetime

graphs-ON (MLX_USE_HIP_GRAPHS, default OFF) now RUNS end-to-end on the ROCm
7.13 runtime (7.12 segfaulted hipGraphLaunch). launch_kernel graph-splits
un-graphable residuals (JIT module kernels, GEMM, memsets): flush+launch the
accumulated kernel-node graph, run the residual immediately on the same stream,
start a fresh graph. hipBLASLt forced to rocBLAS in graph mode (its lazy init
aborts under graph activity). kernelParams arg-packs freed at synchronize (exec
references them through async launch). KNOWN WIP: graphs-ON output is incorrect
(incomplete set_input/output_array dependency edges -> races) and slower than
eager due to graph-split fragmentation. Default graphs-OFF unchanged (41 tok/s).
graphs-ON (default OFF): graph nodes serialized into a linear chain in
submission order (matches eager stream order; robust vs incomplete
set_input/output_array edges) and arg-packs freed at synchronize. Runs on the
7.13 runtime without crashing but output is still incorrect (an unisolated
race) and slower than eager due to graph-split fragmentation. Default
graphs-OFF eager unchanged (41 tok/s coherent).
Bisection of the graphs-ON correctness bug (all on 7.13 runtime, graphs-OFF
default unaffected, eager 41 tok/s coherent):
- 1 node/graph is ALSO wrong -> not multi-node dependency/race.
- exec-cache keyed by node-type-only collided distinct kernel sequences ->
  hipGraphExecUpdate mis-reused execs -> garbage. Now key by func ptr + dims.
- fresh hipGraphInstantiate per commit + destroy-at-synchronize (no reuse) ->
  segfaults; ExecUpdate-reuse -> runs but garbage. Both point to a deeper
  hipGraph instantiate/exec instability for this GDN+MoE workload on ROCm 7.13.
graphs-ON still not correct; eager + 7.13 is the working path.

…solated

Standalone repro proved hipGraphAddKernelNode + tuple-marshaling are correct on
7.13 (identical to hipLaunchKernel). Bisection of full-forward graphs-ON:
- BUG1 buffer lifetime: graph nodes execute at commit, but the allocator frees
  intermediates at eval time -> reused before the graph runs -> segfault.
  Deferring frees (graph_active) prevents the segfault but balloons memory.
- BUG2 computation: even with buffers kept alive/non-aliased, output is garbage
  -> a remaining error in the full multi-kernel forward not reproduced by the
  single-kernel repro. Needs per-kernel eager-vs-graph output bisection.
Default graphs-OFF eager unchanged (41 tok/s coherent on 7.12 and 7.13).
Root cause of the graphs-ON garbage: a ROCm CLR per-node kernarg corruption bug
(hip#3887 / clr#138) that produces WRONG results once one instantiated graph holds
>~3 heterogeneous kernel nodes. Verified: 3-node graphs match eager BIT-FOR-BIT;
4+ nodes -> garbage. Found via per-op eager-vs-graph checksum (identical for all
9636 ops when force-executed) + batch-size bisection + standalone HIP repros
(10-node chains, 4-node packs all correct in isolation -> not our code, not HIP
deps, not param marshaling).

Fix: cap max_ops_per_graph at 2 (graphs <=3 nodes, the verified-correct range), and
destroy each exec via a completion handler after its async launch (instead of
retaining until synchronize, which OOM'd over a long generation). Result:
MLX_USE_HIP_GRAPHS=1 generates a full coherent 1000-token story, 19.9 tok/s. Speed
is below eager (41) because the CLR bug forces tiny graphs, killing the batching
win — that ceiling lifts only when AMD fixes CLR. Default (graphs-OFF) eager
unchanged.
The graphs-ON garbage was ONE bad op, not a node-count limit: per-op force-execute
bisection (MLX_GRAPH_FORCE_FROM in eval.cpp) pinned it to Concatenate, whose
multi-copy kernels (one per input slice, all writing one output buffer at offsets)
corrupt when co-grouped with neighbors in a HIP graph. Fix: isolate Concatenate via
a commit-split in gpu::eval (is_graph_split_op). Graphs now match eager BIT-IDENTICAL
for 200 tokens at cap=50 (raised from the cap=2 CLR workaround).

Speed note: graphs are still ~slower than eager (21 vs 25 tok/s) because fresh
hipGraphInstantiate per commit + per-GDN-layer Concatenate splits outweigh the
launch-overhead savings — decode on this APU is not launch-bound enough for HIP
graphs to win. Default remains graphs-OFF (eager).
A/B'd graph perf levers: MLX_DISABLE_COMPILE (fewer JIT splits) 21->31 tok/s;
MLX_GRAPH_REUSE (hipGraphExecUpdate instead of fresh instantiate) 31->34.5 but
corrupts (in-flight ExecUpdate on a same-topology exec reused within one token —
needs a per-key exec pool to fix). CONCLUSION: HIP graphs do not beat eager on
gfx1151 for this MoE (best 34.5 vs eager 41) — decode is not launch-bound enough;
instantiate + forced JIT/library/Concatenate splits cost more than the launch
savings. All gated OFF by default (graphs-OFF eager is the path).
Auto graph-batching path (gated behind MLX_USE_HIP_GRAPHS) — experimental/WIP.

clr#138 root cause: ROCm hipGraphExecUpdate/SetParams store kernel-node
kernelParams BY POINTER and never deep-copy them (CUDA deep-copies — that is the
exact divergence). A cached exec reused via hipGraphExecUpdate keeps reading our
arg Packs + the source graph by address, so freeing them per-token (or destroying
build_graph_ per-commit) corrupts it -> garbage. Fresh-instantiate avoids it only
because it is never reused across the free.

Fix (eliminates the garbage): per-topology ExecSlot that OWNS its source
hipGraph_t + arg Packs for the exec's life; only a drained slot (inflight==0) is
reused, so the in-flight ones keep their memory. insert_graph_dependencies builds
real data-dependency edges (node_map_ buffer->producer), mirroring the CUDA
backend, which ExecUpdate needs to re-map node params. arange now registers its
output array (a missing graph edge).
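
The slot rule above (a slot owns its arg packs, and only a drained slot is reused) can be modelled in a few lines. This is a Python sketch of the bookkeeping only, not the HIP code; `ExecSlot` is named in the commit message, while `SlotPool`, `acquire`, and `complete` are hypothetical names:

```python
from dataclasses import dataclass, field


@dataclass
class ExecSlot:
    """Owns everything a cached graph exec would read by address."""
    topology: tuple
    arg_packs: list = field(default_factory=list)  # kept alive with the slot
    inflight: int = 0                              # launches not yet completed


class SlotPool:
    def __init__(self):
        self.slots: list[ExecSlot] = []

    def acquire(self, topology: tuple, arg_packs: list) -> ExecSlot:
        # Reuse only a drained slot with the same topology, so an in-flight
        # launch never sees its argument memory replaced underneath it.
        for slot in self.slots:
            if slot.topology == topology and slot.inflight == 0:
                slot.arg_packs = list(arg_packs)
                slot.inflight += 1
                return slot
        slot = ExecSlot(topology, list(arg_packs), inflight=1)
        self.slots.append(slot)
        return slot

    def complete(self, slot: ExecSlot) -> None:
        # Stand-in for the completion callback of an async graph launch.
        slot.inflight -= 1


pool = SlotPool()
first = pool.acquire(("rms_norm", "qmv"), [0x1000])
second = pool.acquire(("rms_norm", "qmv"), [0x2000])  # first is busy: new slot
pool.complete(first)
third = pool.acquire(("rms_norm", "qmv"), [0x3000])   # first is drained: reused
```

The cost of this design is memory: every concurrently in-flight launch of a topology keeps its own slot alive.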

128-bit (uint4 b128) q-weight loads: removes the "uint4 miscompiles 8-bit affine
on RDNA 3.5" workaround — verified bit-identical on gfx1151 for q4/q6/q8, it was a
misdiagnosis. load_weight_vec_streaming widened from scalar to vector transactions.

Enable (eager/default is UNCHANGED — the correct, fastest path at ~41 tok/s):
  LD_LIBRARY_PATH=/opt/rocm/core-7.13/lib   # 7.13 runtime REQUIRED
                                            # (7.12 segfaults hipGraphLaunch)
  MLX_USE_HIP_GRAPHS=1                       # turn on graph auto-batching
  MLX_DISABLE_COMPILE=1                      # avoid JIT-compiled graph splits
Guards / diagnostics:
  MLX_GRAPH_NO_REUSE=1       fresh hipGraphInstantiate each commit (no ExecUpdate)
  MLX_MAX_OPS_PER_BUFFER=N   cap kernel nodes per graph
  MLX_NO_CONCAT_SPLIT=1      keep Concatenate in-graph
  MLX_GRAPH_FORCE_FROM=N     force ops >= N to eager (bisection)
  MLX_HIP_GRAPH_DUMP=1 / MLX_GRAPH_DEBUG=1   node/edge dump + tracing

Status: graphs-ON reuse is WIP — the real-dep graph still needs complete kernel
I/O registration to be bit-correct under ExecUpdate. Eager (graphs OFF) unaffected.
Two correctness/perf fixes for the HIP-graph auto-batch decode path:

- Submission-order chain edge: insert_graph_dependencies now serializes each
  node behind the previously-inserted one (deduped vs real data deps). The
  documented last_node_ backstop was never wired, so kernels that register no
  I/O raced their producers -> garbage. Graph output is now a superset of
  eager's serial-on-one-stream order.

- Module kernels as graph nodes: ROCm 7.13 hipGraphAddKernelNode accepts a
  hipFunction_t in hipKernelNodeParams.func, so Compiled* JIT-fused kernels and
  CustomKernel no longer have to graph-split. New add_module_kernel_node /
  launch_module_kernel route them through kernel nodes. This eliminates ~320
  inline graph-splits per token (all of which were JIT/Custom, none GEMMs),
  collapsing the graph count and taking coherent graph decode 22.7 -> 37.1 tps.
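
The edge construction described in the first item can be sketched as follows: real data dependencies come from a buffer-to-producer map, and a submission-order chain edge is added behind every node, with a set providing the dedup. This is a hypothetical Python model (`build_edges` and the tuple layout are assumptions), not the `insert_graph_dependencies` implementation:

```python
def build_edges(nodes):
    """nodes: list of (inputs, outputs) buffer-name tuples in submission order.

    Returns a sorted list of (producer_index, consumer_index) edges.
    """
    producer = {}   # buffer -> index of the node that last wrote it
    edges = set()   # a set dedupes chain edges against data edges
    for i, (inputs, outputs) in enumerate(nodes):
        for buf in inputs:
            if buf in producer:
                edges.add((producer[buf], i))   # real data dependency
        if i > 0:
            edges.add((i - 1, i))               # submission-order chain edge
        for buf in outputs:
            producer[buf] = i
    return sorted(edges)


nodes = [
    ((), ("a",)),        # node 0 writes a
    ((), ()),            # node 1 registers no I/O at all
    (("a",), ("b",)),    # node 2 reads a, writes b
    (("b",), ("c",)),    # node 3 reads b
]
print(build_edges(nodes))  # [(0, 1), (0, 2), (1, 2), (2, 3)]
```

Node 1 registers no buffers, so without the chain edge it would have no ordering at all; with it, the graph is at least as ordered as a single eager stream.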

Reuse: deterministic per-node hipGraphExecKernelNodeSetParams refresh (replaces
hipGraphExecUpdate, which returns success but mis-maps params in the model's
complex DAG). Opt-in via MLX_GRAPH_SETPARAMS while a kernarg-corruption on
by-value-struct kernels is debugged; default is reinstantiate-into-slot.

Diagnostics (env-gated): MLX_GRAPH_REUSE_STATS (+inline-by-op breakdown),
MLX_GRAPH_REINST_SLOT, MLX_GRAPH_SETPARAMS, MLX_EVENT_BLOCKING.

…leak

- Enable HIP-graph batching by default (opt out with MLX_USE_HIP_GRAPHS=0). The
  win is collapsing thousands of per-token kernel-launch submissions into a few
  graph launches — launch/PCIe traffic and latency on a discrete GPU over a slow
  link (TB5 eGPU), independent of local APU throughput.

- Fix is_hipblaslt_available(): it forced the graph-safe rocBLAS path by reading
  getenv("MLX_USE_HIP_GRAPHS") directly. With graphs now default-on (env unset)
  that check flipped and let hipBLASLt run under graph mode, aborting the process
  ("operation not permitted when stream is capturing"). Use use_hip_graphs().

- Fix the deferred-free leak: graph_active() is true for the whole auto-batch
  session, so flush_graph_deferred_frees() -> free() just re-deferred every
  buffer forever (graph-mode load peaked 30.7GB vs eager 19.9GB). Add
  free(Buffer, bool force); the flush forces a real release.

- Per-generation deferred-free reclaim infrastructure (graph_current_gen /
  free_graph_generation) is in place but OFF by default (MLX_GRAPH_FREE_LAG):
  per-chunk reclaim during a forward still races a later chunk's reference. The
  correct fix is deterministic 100% buffer reuse, not freeing.

- Precise exec-pool key (full grid/block dims, not the product) so reuse only
  matches structurally identical graphs. Reuse mechanisms behind env flags
  (MLX_GRAPH_SETPARAMS / RELAUNCH / EXECUPDATE / REINST_SLOT) + diagnostics
  (MLX_GRAPH_REUSE_STATS). MLX_EVENT_BLOCKING for a blocking event wait.

Note: graph-mode still holds a full forward's intermediates (~30GB) -> OOMs a
32GB R9700 during generation; needs the 100%-reuse memory work. Eager unaffected.
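
To see why the exec-pool key needs the full grid/block dims rather than their product, here is a small Python sketch. The key layout is an assumption for illustration (kernels are identified by name here, whereas the backend would use the function pointer), and `product_key` / `precise_key` are hypothetical names:

```python
def product_key(kernels):
    # Coarse key: kernel identity plus total thread count only.
    return tuple(
        (name, gx * gy * gz * bx * by * bz)
        for name, (gx, gy, gz), (bx, by, bz) in kernels
    )


def precise_key(kernels):
    # Precise key: kernel identity plus the full grid and block dims.
    return tuple((name, grid, block) for name, grid, block in kernels)


graph_a = [("copy", (4, 2, 1), (64, 1, 1))]
graph_b = [("copy", (2, 4, 1), (64, 1, 1))]

print(product_key(graph_a) == product_key(graph_b))  # True: false match
print(precise_key(graph_a) == precise_key(graph_b))  # False: kept apart
```

With the coarse key the two graphs would share a cached exec even though their launch geometry differs.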
Diagnostics for the graph-mode memory work, both env-gated and off by default:
- MLX_GRAPH_POISON_FREE: per-gen reclaim overwrites the buffer with a sentinel
  and leaks it (kept mapped) so a read-after-free shows as garbage, not a crash.
- MLX_GRAPH_NODEFER: skip the defer-all-frees and rely on add_temporary (which
  holds each graph input's array::Data to its chunk's completion), to test
  whether the blanket deferral is redundant for the auto-batch path.
MLX's scheduler runs a thread per stream, so the graph encoder/allocator are hit
concurrently (incl. cross-stream AtomicEvent::signal -> commit during weight
materialization, and worker-thread frees during async decode). Default-on graphs
exposed two data races that crash on the discrete R9700 (gfx1201):

- Device::get_command_encoder: encoders_ map find/emplace/rehash was unguarded;
  a concurrent insert handed back a garbage encoder -> SIGSEGV in
  add_kernel_node_ex during materialize_weights. Serialize with encoders_mtx_.
  (This fixes the load-time crash; the R9700 now loads under graphs.)

- malloc_async/free_async did hipMallocAsync/hipFreeAsync OUTSIDE the allocator
  mutex, so the eval thread's alloc raced the worker thread's free on the pool.
  Serialize the pool ops with pool_mutex().

- MLX_ROCM_NO_ASYNC_POOL env to fall back off the async pool (diagnostic for the
  remaining gfx1201 hipMallocAsync/graph interaction).

APU (gfx1151) graph decode stable at ~38 tok/s. R9700 still has further
device-specific graph issues during decode (a hipMallocAsync fault inside the
driver with the async pool; an "invalid device function" in hipGraphAddKernelNode
without it) — not yet resolved. Eager unaffected on both.
Graph mode defers buffer frees until they're safe, but hoarding a whole forward's
intermediates pushed the discrete R9700 (34GB) to 34.17GB during decode -> the
async pool exhausted VRAM and faulted inside hipMallocAsync (looked like a crash;
was OOM).

Track the deferred-free backlog (g_deferred_bytes) and, when it exceeds a cap,
drain the stream (all in-flight graphs complete, so nothing still references the
freed buffers) and flush. Race-free, and the sync only costs latency when memory
is actually high. MLX_GRAPH_DEFER_MAX_MB (default 1024; 0 disables).
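
A minimal Python model of the backlog cap described above, assuming the cap is compared in bytes and that a value of 0 turns the check off. `DeferredFrees` and its fields are hypothetical names; `sync` stands in for draining the stream:

```python
class DeferredFrees:
    def __init__(self, cap_mb: int = 1024, sync=lambda: None):
        self.cap_bytes = cap_mb * 1024 * 1024  # 0 disables the cap
        self.sync = sync                       # stand-in for a stream drain
        self.pending: list[int] = []           # sizes of buffers awaiting release
        self.deferred_bytes = 0
        self.released_bytes = 0

    def defer(self, nbytes: int) -> None:
        self.pending.append(nbytes)
        self.deferred_bytes += nbytes
        if self.cap_bytes and self.deferred_bytes > self.cap_bytes:
            # After the drain, no in-flight graph can still reference these
            # buffers, so releasing them is race-free.
            self.sync()
            self.flush()

    def flush(self) -> None:
        self.released_bytes += self.deferred_bytes
        self.pending.clear()
        self.deferred_bytes = 0


syncs = []
frees = DeferredFrees(cap_mb=1, sync=lambda: syncs.append("drain"))
frees.defer(600 * 1024)   # under the 1 MB cap: stays deferred
frees.defer(600 * 1024)   # crosses the cap: drain, then flush
```

The drain only happens when the backlog is large, which matches the note that the sync costs latency only when memory is actually high.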

Result: R9700 graph decode peaks at 24GB (was 34.2/OOM), loads at 19.9GB (was
30.4), and runs coherent + stable at ~35.5 tok/s (3/3). APU unaffected (~38 tok/s).
Also add MLX_ROCM_NO_ASYNC_POOL diagnostic.
The deferred-free backlog was bounded by a blocking hipStreamSynchronize + flush,
which drains the pipeline. Replace it for the common (pool-buffer) case with the
clean approach: reclaim each chunk's stream-ordered pool buffers via hipFreeAsync
queued right after the chunk's hipGraphLaunch on the same stream. The free is
stream-ordered, so it retires after the graph that reads the buffer — no
use-after-free and no drain. Only the non-stream-ordered remainder (unified/slab,
rare on the discrete path) falls back to the sync+flush cap.

R9700 graph decode peak drops to 20.4GB (was 24 with sync, 34/OOM before),
coherent + stable, ~35.7 tok/s. APU unaffected.

…ral hack)

Two latent errors that a fresh build hits (the in-tree chat build masked them
with stale objects); surfaced building the Python mlx bindings from scratch.

- add_kernel_node_ex: std::make_shared<Pack>() default-constructs Pack, which
  requires every decayed kernel-arg type to be default-constructible — some
  aren't (deleted default ctor). Construct the value tuple in place instead.

- load_weight_vec_streaming: __builtin_nontemporal_load on uint2/uint4 vector
  types is rejected by the clang toolchain ("nontemporal builtin not valid on
  vector types"), making the function ill-formed and cascading into "no matching
  function" all over the tiled-quant kernels. Drop the vector-load hack and use
  scalar nontemporal loads (the compiler coalesces adjacent ones).
Add a strict per-node signature check before hipGraphExecKernelNodeSetParams
reuse: if any src_nodes[i] kernel/dims differ from this commit's, fall back to
reinstantiate instead of writing mismatched params (correctness backstop).
[sp-sig-mismatch] stat exposes how often it fires. Diagnosis: with the guard
in place the counter stays 0 yet copy_gg_byval still faults under SETPARAMS,
isolating the OOB to CLR by-value-struct kernarg marshaling (not our keying).
Capture an FNV-1a hash of each kernel node's arg values at build time: typed
nodes hash the arg tuple in add_kernel_node_ex; JIT/module nodes hash
KernelArgs::storage_ via arg_hash() plumbed through add_module_kernel_node.
On slot reuse, compare against the slot's previous hashes to count how many
nodes' kernargs actually change token-to-token ([arg-change] stat). Reveals
~98.7% of decode-graph nodes get fresh buffer addresses every token.
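
For reference, FNV-1a is a byte-wise XOR-then-multiply hash. The sketch below uses the standard 64-bit parameters; the commit message does not say which width the backend uses, and the argument layout (`hash_kernel_args`, `count_changed`) is hypothetical:

```python
import struct

FNV64_OFFSET = 0xCBF29CE484222325
FNV64_PRIME = 0x100000001B3


def fnv1a_64(data: bytes) -> int:
    h = FNV64_OFFSET
    for byte in data:
        h ^= byte
        h = (h * FNV64_PRIME) & 0xFFFFFFFFFFFFFFFF  # wrap to 64 bits
    return h


def hash_kernel_args(pointers, scalars) -> int:
    """Hash a node's argument values: 64-bit addresses, then 32-bit ints."""
    blob = struct.pack(f"<{len(pointers)}Q{len(scalars)}i", *pointers, *scalars)
    return fnv1a_64(blob)


def count_changed(prev_hashes, new_hashes) -> int:
    """Number of nodes whose argument hash differs from the previous token."""
    return sum(1 for p, n in zip(prev_hashes, new_hashes) if p != n)


token0 = [hash_kernel_args([0x1000, 0x2000], [128]), hash_kernel_args([0x3000], [64])]
token1 = [hash_kernel_args([0x1800, 0x2000], [128]), hash_kernel_args([0x3000], [64])]
print(count_changed(token0, token1))  # 1
```

Hashing by value means a node counts as changed whenever any buffer address moves, which is exactly what the [arg-change] statistic is meant to expose.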
…s them

hipGraphExecUpdate returns success but does NOT refresh kernel params nested
inside child-graph nodes, so reusing an exec via ExecUpdate produced coherent-
but-wrong (non-bit-identical) decode output. QuantizedMatmul wraps each kernel in
a single-kernel child graph, so almost every decode-layer graph carried child
nodes and broke reuse.

Fix: add_child_graph_node now flattens the child's kernels into build_graph_ as
top-level kernel nodes (topologically ordered) via a new add_kernel_node_kp()
helper that preserves the FULL params struct (both kernelParams and extra — the
child kernels use the extra kernarg form). With no child nodes, ExecUpdate
refreshes every node: MLX_GRAPH_EXECUPDATE output is now bit-identical to eager
at temp 0, update=N reinst=0, in both default and single-graph configs.

Also: per-node kernarg change-tracking diagnostic (gated on MLX_GRAPH_REUSE_STATS)
and a SetParams per-node sig guard, used to isolate the root cause.
A single-token (decode) forward now accrues into ONE graph instead of being
chopped by the per-graph op/byte caps and Concatenate splits. needs_commit()
returns false in decode-mode so the forward commits once at finalize; concat is
kept as a graph node; the cached exec is refreshed via hipGraphExecUpdate (now
correct after the child-graph flatten) and launched once per token. Prefill
leaves decode-mode off so its large intermediates stay bounded by the caps.

set/get via set_graph_decode_mode(); the generation loop sets it per step.
MLX_GRAPH_DECODE=0 disables. Result on Qwen3.6-35B-A3B-4bit: decode = 1 launch/
token, bit-identical to eager, beats eager +2% (gfx1151) / +9% (gfx1201);
replaces the previously net-negative default graph mode.

…ll crash)

The fp8 dequant GEMM tuned algos by running each candidate and timing it with
hipStreamSynchronize/hipEventSynchronize. When the GEMM is recorded inside a HIP
graph capture (long-prompt prefill on gfx1201, where M>=64 takes the fp8 path),
synchronizing the stream is illegal: it invalidates the capture, after which
every subsequent hipblasLtMatmul fails (status 6) and the run segfaults.

Fix: when the stream is in capture, skip the timing benchmark and take the
heuristic's top algo (res[0]) — it records into the captured graph and executes
on replay. Non-capture (eager) path still benchmarks. Result on R9700: long-prompt
prefill no longer crashes, bit-identical to eager, and ~13% faster (332 vs 292 pp/s).
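
The selection rule in this fix reduces to a small branch. A Python sketch, with hypothetical names (`pick_algo`, `time_candidate`); in the real backend the timing call is what synchronizes the stream:

```python
def pick_algo(candidates, stream_is_capturing, time_candidate):
    """candidates: heuristic-ordered list; time_candidate(c) returns seconds."""
    if not candidates:
        raise ValueError("no candidate algorithms")
    if stream_is_capturing:
        # Timing would synchronize the stream and invalidate the capture,
        # so trust the heuristic's first choice instead.
        return candidates[0]
    return min(candidates, key=time_candidate)


timings = {"algo_a": 3.0, "algo_b": 1.0}
timed = []


def fake_timer(candidate):
    timed.append(candidate)
    return timings[candidate]


print(pick_algo(["algo_a", "algo_b"], True, fake_timer))   # algo_a, nothing timed
print(pick_algo(["algo_a", "algo_b"], False, fake_timer))  # algo_b
```

The capture path may pick a slower algorithm than benchmarking would, which is the trade accepted to keep the capture valid.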
Drop the MLX_GRAPH_REUSE_STATS machinery (reuse/launch/timer counters, per-node
arg-hash change tracking incl. KernelArgs::arg_hash, inline-by-op, the stats
report), the dead alternate reuse mechanisms (SetParams + per-node sig guard,
relaunch, MLX_GRAPH_NO_REUSE, the eu_min/max bisection gate, MLX_HIP_GRAPH_DUMP,
MLX_GRAPH_FORCE_FROM/TRACE), and the MLX_GRAPH_EXECUPDATE env opt-in (ExecUpdate
is driven by decode-mode). Reuse path is now just: ExecUpdate (decode-mode) else
reinstantiate, else new instantiate. Bit-identical decode on gfx1151 + gfx1201.
Remove the unused stream-capture path now that its only callers are gone (engine
Phase 2): CommandEncoder begin_capture/end_capture/replay/reset_graph/has_graph/
capturing() + the capturing_/graph_/graph_exec_/capture_held_/graph_execs_ members
and the if(capturing_) blocks in commit()/synchronize(); launch_kernel drops the
!capturing_ guard; event.hip drops the capture branch; eval.cpp drops the
gpu_arena_*/gpu_graph_* shims (DecodeArena class + kv-pos kernels still present,
removed in a follow-up). Shipping graphs (hipGraphAddKernelNode + ExecUpdate) are
untouched. NOTE: under test post-commit.
The capture-once allocator arena (DecodeArena) and the in-place device-position
KV kernels (gpu_kv_pos_set/increment, gpu_kv_row_write) were only used by the
removed capture-once decode path. Drop the DecodeArena class + arena_ member +
the arena_.active() fast-paths in malloc/malloc_async/free (always false now),
and the kv-pos/kv-row-write kernels in indexing.hip. Verified bit-identical
decode on gfx1151 + gfx1201 and clean long-prompt prefill on gfx1201 (after a
clean backend rebuild — incremental .hip objects go stale on device.h changes).
qmm.hip: v_dot4_i32_iu8 integer-dot quantized matmul for 4-bit affine decode,
on both the dense QuantizedMatmul path and the MoE GatherQMM path. Activation
is quantized to int8 (per-group scale + exact sum) in a prepass; weight nibbles
feed __builtin_amdgcn_sudot4 with int32 accumulate, then affine dequant
(scale*dx*<dot> + bias*Sx). Gated by MLX_QMV_IDOT (dense) and
MLX_QMV_IDOT_GATHER (gather, experimental); both off by default. Coherent
(~0.27% L2 vs scalar). +7% on dense 4B; neutral on MoE (gather is latency-bound).
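
The dequant formula follows from the affine weight form. With w_i = scale*q_i + bias and x_i ≈ dx*xq_i, the group dot product is sum(w_i*x_i) = scale*sum(q_i*x_i) + bias*sum(x_i) ≈ scale*dx*sum(q_i*xq_i) + bias*Sx, where Sx is the exact activation sum. A pure-Python sketch of one group, assuming symmetric int8 quantization with dx = max|x|/127 (the kernel's actual scale choice is not stated in the commit message):

```python
def int8_quantize(x):
    """Symmetric int8 quantization of one activation group: (dx, xq, Sx)."""
    amax = max(abs(v) for v in x)
    dx = amax / 127 if amax else 1.0
    xq = [round(v / dx) for v in x]
    return dx, xq, sum(x)   # Sx is the exact float sum, not the quantized one


def idot_group(q, scale, bias, x):
    """Dot of affine 4-bit weights (w_i = scale*q_i + bias) with x via an integer dot."""
    dx, xq, sx = int8_quantize(x)
    dot = sum(qi * xi for qi, xi in zip(q, xq))   # integer accumulate
    return scale * dx * dot + bias * sx


q = [i % 16 for i in range(32)]                       # 4-bit weight codes
x = [((i * 37) % 23 - 11) / 7 for i in range(32)]     # arbitrary activations
scale, bias = 0.02, -0.15

exact = sum((scale * qi + bias) * xi for qi, xi in zip(q, x))
approx = idot_group(q, scale, bias, x)
print(abs(approx - exact))
```

Because Sx is exact, the bias term carries no quantization error; the only error is in the scale*dx*dot term, bounded by scale * sum(q_i) * dx / 2.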

copy.hip: env-gated (MLX_COUNT_COPIES) histogram of general copies by shape,
to attribute copy_gg/copy_g kernels to model ops. No overhead when disabled.
gather_qmv_wide_kernel: flat per-(batch,col) warp with 128-bit (uint4 = 32
nibbles) weight loads instead of the per-group 32-bit loop, for 4-bit affine
MoE-expert decode. Per-group affine applied inline; activation read from global
(L2-cached). Requires K%32==0, group_size in {32,64,128}. Gated MLX_QMV_WIDE,
off by default. Microbench: scattered-expert DRAM read 130->295 GB/s on R9700
(2.3x), 98->150 on gfx1151. Coherent. In-engine TPS-neutral on the 35B (experts
are the smaller matmul share; dense qmv_tiled already uses wide loads + dot4),
kept as a measured, available path.

…% dense

qmv_wide_kernel: one warp per output column, flat full-K loop with 128-bit
(uint4 = 32 nibbles) weight loads, per-group affine applied inline, activation
from global. Replaces qmv_tiled on the 4-bit affine decode path under
MLX_QMV_WIDE. Root cause: qmv_tiled achieves only ~108 GB/s on the real 8MB
per-layer matmul while a simple wide-load streaming kernel sustains ~175 (a
kernel inefficiency, not matmul size or interleaving — verified by rotating-DRAM
microbench). Requires K%32==0, group_size in {32,64,128}.
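
The "uint4 = 32 nibbles" arithmetic is four 32-bit words times eight 4-bit values per word. A short Python sketch of the unpacking, assuming the lowest nibble of each word holds the first weight (an assumption about packing order, labelled as such); both function names are hypothetical:

```python
def unpack_nibbles(word: int) -> list[int]:
    """Split one 32-bit word into eight 4-bit values, lowest nibble first."""
    return [(word >> (4 * i)) & 0xF for i in range(8)]


def unpack_b128(words) -> list[int]:
    """A 128-bit load is four 32-bit words, i.e. 32 nibbles."""
    if len(words) != 4:
        raise ValueError("expected four 32-bit words")
    out = []
    for w in words:
        out.extend(unpack_nibbles(w))
    return out


print(unpack_nibbles(0x76543210))                  # [0, 1, 2, 3, 4, 5, 6, 7]
print(len(unpack_b128([0, 0, 0, 0xFFFFFFFF])))     # 32
```

This also shows where the K%32==0 requirement comes from: each wide load consumes exactly 32 weights along K.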

Measured (gfx1151, greedy decode), MLX_QMV_WIDE on vs off:
  dense Qwen3.5-4B:   54.2 -> 61.9 tok/s (+14%)
  dense Qwen3.6-27B:  11.3 -> 13.5 tok/s (+20%)
  35B-A3B MoE:        38.4 -> 40.3 (+5%, dense proj is smaller share)
All coherent. Win scales with dense matmul share. MLX_QMV_WIDE now covers both
this dense path and the gather path.
The dense qmv_wide path is a measured +14-20% win on dense 4-bit decode and
never regresses, so enable it by default for the qualifying decode case (4-bit
affine, K%32==0, group_size in {32,64,128}, M<=8). Both the dense qmv_wide and
the gather_qmv_wide paths now default on; set MLX_QMV_NO_WIDE=1 to revert to
qmv_tiled / the warp-shared gather.
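
The qualifying conditions listed above can be written as a single predicate. A Python sketch only: `use_wide_qmv` is a hypothetical name, and treating exactly the value "1" as the opt-out is an assumption about how `MLX_QMV_NO_WIDE` is parsed:

```python
def use_wide_qmv(mode, bits, k, group_size, m, env):
    """Return True when the wide-load decode kernel would be selected."""
    if env.get("MLX_QMV_NO_WIDE") == "1":   # assumed opt-out parsing
        return False
    return (
        mode == "affine"
        and bits == 4
        and k % 32 == 0
        and group_size in (32, 64, 128)
        and m <= 8
    )


print(use_wide_qmv("affine", 4, 4096, 64, 1, {}))                        # True
print(use_wide_qmv("affine", 4, 4096, 64, 1, {"MLX_QMV_NO_WIDE": "1"}))  # False
print(use_wide_qmv("affine", 8, 4096, 64, 1, {}))                        # False
```

Passing `os.environ` as `env` would apply the check to the current process.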

Development

Successfully merging this pull request may close these issues.

Add ROCm Support for AMD GPUs