From f4e30c3bc220a1de0df1926fcfc6dbb0340b9001 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 11 Jun 2026 16:41:37 +0100 Subject: [PATCH 1/2] [Performance] Shared-memory signaling for ParallelEnv and ring-buffer transport for MultiAsyncCollector Two IPC optimizations targeting per-step syscalls and per-batch copies: 1. ParallelEnv worker_wait="adaptive"|"spin" (default "block", unchanged): payload-free hot-path commands (step, step_and_maybe_reset) are written as opcodes into a shared-memory RawArray that workers spin-poll instead of being pickled and sent through a pipe (one syscall per worker per step). This mirrors the existing shm done-flags in the other direction. In adaptive mode workers spin for spin_for seconds then fall back to a blocking pipe wait, advertising a sleep state so the parent wakes them through the pipe; a short poll recheck covers the theoretical lost-wake window. Payload-carrying commands (resets, seeds, non-tensor data, RNN passthrough) keep using the pipe, and no-buffer mode falls back to "block" with a warning. ~30% step-throughput gain on 4 workers with a fast env. 2. MultiAsyncCollector buffer_depth=K (default 1, unchanged): workers copy each rollout into one of K rotating shared-memory slots and the queue message shrinks to (idx, slot); the main process yields zero-copy views instead of cloning every batch. A yielded batch stays valid until the same worker has collected K-1 further batches; K=2 covers the standard iteration pattern since the continue-before-yield handshake keeps workers at most one rollout ahead. First use of a slot ships the buffer ref through the queue (re-sent if the put times out). Rejected for MultiSyncCollector, replay_buffer mode and use_buffers=False. Both knobs are exposed in configure_parallel and the Hydra configs (BatchedEnvConfig, MultiSyncCollectorConfig, MultiAsyncCollectorConfig), covered by tests (TestWorkerWait, TestBufferDepth) and parametrized benchmarks (test_parallel_worker_wait, test_async_buffer_depth). Co-Authored-By: Claude Fable 5 --- benchmarks/test_collectors_benchmark.py | 23 ++ benchmarks/test_envs_benchmark.py | 12 + test/envs/test_parallel.py | 87 +++++++ test/test_collectors.py | 98 ++++++++ torchrl/collectors/_multi_async.py | 22 ++ torchrl/collectors/_multi_base.py | 45 ++++ torchrl/collectors/_runner.py | 112 +++++++-- torchrl/envs/batched_envs.py | 226 ++++++++++++++++-- .../trainers/algorithms/configs/collectors.py | 2 + torchrl/trainers/algorithms/configs/envs.py | 10 + 10 files changed, 589 insertions(+), 48 deletions(-) diff --git a/benchmarks/test_collectors_benchmark.py b/benchmarks/test_collectors_benchmark.py index f15bf2bf1e5..3ea89756e5a 100644 --- a/benchmarks/test_collectors_benchmark.py +++ b/benchmarks/test_collectors_benchmark.py @@ -255,6 +255,29 @@ def test_async(benchmark): benchmark(execute_collector, c) +@pytest.mark.parametrize("buffer_depth", [None, 2]) +def test_async_buffer_depth(benchmark, buffer_depth): + device = "cuda:0" if torch.cuda.device_count() else "cpu" + env = EnvCreator( + lambda: TransformedEnv( + DMControlEnv("cheetah", "run", device=device), StepCounter(50) + ) + ) + c = MultiAsyncCollector( + [env, env], + RandomPolicy(env().action_spec), + total_frames=-1, + frames_per_batch=100, + device=device, + buffer_depth=buffer_depth, + ) + c = iter(c) + for i, _ in enumerate(c): + if i == 10: + break + benchmark(execute_collector, c) + + @pytest.mark.skipif(not torch.cuda.device_count(), reason="no rendering without cuda") def test_single_pixels(benchmark): (c,), _ = single_collector_setup_pixels() diff --git a/benchmarks/test_envs_benchmark.py b/benchmarks/test_envs_benchmark.py index ac27fc84190..9877ead1cfe 100644 --- a/benchmarks/test_envs_benchmark.py +++ b/benchmarks/test_envs_benchmark.py @@ -97,6 +97,18 @@ def test_parallel(benchmark): benchmark(execute_env, c) +@pytest.mark.parametrize("worker_wait", ["block", "adaptive", "spin"]) +def test_parallel_worker_wait(benchmark, worker_wait): + device = "cuda:0" if torch.cuda.device_count() else "cpu" + env = ParallelEnv( + 3, + lambda: DMControlEnv("cheetah", "run", device=device), + worker_wait=worker_wait, + ) + env.rollout(3) + benchmark(execute_env, env) + + @pytest.mark.parametrize("nested", [True, False]) @pytest.mark.parametrize("keep_other", [True, False]) @pytest.mark.parametrize("exclude_reward", [True, False]) diff --git a/test/envs/test_parallel.py b/test/envs/test_parallel.py index a9c08a78d27..5798237e8df 100644 --- a/test/envs/test_parallel.py +++ b/test/envs/test_parallel.py @@ -48,6 +48,7 @@ ContinuousActionVecMockEnv, CountingEnv, CountingEnvCountPolicy, + CountingEnvWithString, DiscreteActionConvMockEnv, DiscreteActionVecMockEnv, MockBatchedLockedEnv, @@ -1334,3 +1335,89 @@ def test_stackable(): assert not _stackable(*stack) stack = [TensorDict({"a": "a string"}, []), TensorDict({"a": "another string"}, [])] assert _stackable(*stack) + + +class TestWorkerWait: + """Tests for the shared-memory command signaling of ParallelEnv (worker_wait).""" + + @staticmethod + def _rollout(env_cls, parallel_env_cls, n_steps=20, **kwargs): + env = parallel_env_cls(2, lambda: env_cls(max_steps=10), **kwargs) + try: + env.set_seed(0) + torch.manual_seed(0) + return env.rollout(n_steps, break_when_any_done=False) + finally: + env.close() + del env + + @pytest.mark.parametrize( + "worker_wait,spin_for", + [ + ("adaptive", 1e-3), + # tiny spin window: forces the sleep/wake fallback path + ("adaptive", 1e-6), + ("spin", 1e-3), + ], + ) + @pytest.mark.parametrize("env_cls", [CountingEnv, NestedCountingEnv]) + def test_worker_wait_rollout_parity( + self, worker_wait, spin_for, env_cls, maybe_fork_ParallelEnv + ): + r_block = self._rollout(env_cls, maybe_fork_ParallelEnv) + r_fast = self._rollout( + env_cls, + maybe_fork_ParallelEnv, + worker_wait=worker_wait, + spin_for=spin_for, + ) + assert_allclose_td(r_block, r_fast) + + def test_worker_wait_payload_fallback(self, maybe_fork_ParallelEnv): + # Non-tensor keys put a payload on every step command, which must + # transparently fall back to the pipe path. + env = maybe_fork_ParallelEnv(2, CountingEnvWithString, worker_wait="adaptive") + try: + r = env.rollout(10, break_when_any_done=False) + assert r["string"] is not None + assert r["observation"].shape == (2, 10, 1) + finally: + env.close() + del env + + def test_worker_wait_no_buffers_fallback(self, maybe_fork_ParallelEnv): + env = maybe_fork_ParallelEnv( + 2, + lambda: CountingEnv(max_steps=10), + use_buffers=False, + worker_wait="adaptive", + ) + try: + with pytest.warns(UserWarning, match="requires use_buffers=True"): + r = env.rollout(10, break_when_any_done=False) + assert r["observation"].shape == (2, 10, 1) + finally: + env.close() + del env + + def test_worker_wait_validation(self): + with pytest.raises(ValueError, match="worker_wait must be one of"): + ParallelEnv(2, lambda: CountingEnv(), worker_wait="bogus") + with pytest.raises(ValueError, match="spin_for must be a positive float"): + ParallelEnv(2, lambda: CountingEnv(), spin_for=0.0) + with pytest.raises(TypeError, match="Cannot use worker_wait"): + SerialEnv(2, lambda: CountingEnv(), worker_wait="spin") + + def test_worker_wait_configure_parallel(self, maybe_fork_ParallelEnv): + env = maybe_fork_ParallelEnv(2, lambda: CountingEnv(max_steps=10)) + env.configure_parallel(worker_wait="adaptive", spin_for=1e-4) + assert env.worker_wait == "adaptive" + assert env.spin_for == 1e-4 + try: + env.set_seed(0) + torch.manual_seed(0) + r = env.rollout(10, break_when_any_done=False) + assert r["observation"].shape == (2, 10, 1) + finally: + env.close() + del env diff --git a/test/test_collectors.py b/test/test_collectors.py index f57b4e6f70c..02d2ba2b34e 100644 --- a/test/test_collectors.py +++ b/test/test_collectors.py @@ -7534,6 +7534,104 @@ def env_fn(): assert ("collector", "traj_ids") in rb.sample(1).keys(True) +class TestBufferDepth: + """Tests for the ring-buffer transport (buffer_depth > 1) of MultiAsyncCollector.""" + + def test_ring_rotation_and_validity_window(self): + # A single worker makes slot rotation deterministic: batch j lands in + # slot j % buffer_depth. + collector = MultiAsyncCollector( + [functools.partial(CountingEnv, max_steps=20)], + policy=None, + frames_per_batch=16, + total_frames=16 * 6, + buffer_depth=2, + ) + try: + views = [] + clones = [] + for data in collector: + views.append(data) + clones.append(data.clone()) + finally: + collector.shutdown() + assert len(views) == 6 + + # Batches 0, 2, 4 share storage (slot 0); 1, 3, 5 share storage (slot 1). + ptrs = [v["observation"].data_ptr() for v in views] + assert ptrs[0] == ptrs[2] == ptrs[4] + assert ptrs[1] == ptrs[3] == ptrs[5] + assert ptrs[0] != ptrs[1] + + # Validity window: slot 0 was last rewritten by batch 4, so the view of + # batch 0 now holds batch 4's data; the most recent batch is intact. + assert (views[0]["observation"] == clones[4]["observation"]).all() + assert (views[-1]["observation"] == clones[-1]["observation"]).all() + + @pytest.mark.parametrize("env_cls", [CountingEnv, NestedCountingEnv]) + def test_buffer_depth_content_parity(self, env_cls): + # With a single worker and a deterministic policy the batch stream is + # deterministic, so depth-1 (cloned yields) and depth-2 (view yields) + # must coincide. NestedCountingEnv exercises nested action keys. + probe = env_cls(max_steps=20) + policy = CountingEnvCountPolicy( + probe.full_action_spec[probe.action_key], probe.action_key + ) + + def collect(buffer_depth): + collector = MultiAsyncCollector( + [functools.partial(env_cls, max_steps=20)], + policy=policy, + frames_per_batch=16, + total_frames=64, + buffer_depth=buffer_depth, + ) + try: + return [d.clone() for d in collector] + finally: + collector.shutdown() + + for d1, d2 in zip(collect(1), collect(2), strict=True): + assert_allclose_td(d1, d2) + + def test_buffer_depth_validation(self): + env_fn = functools.partial(CountingEnv, max_steps=20) + with pytest.raises(ValueError, match="buffer_depth must be >= 1"): + MultiAsyncCollector( + [env_fn], + policy=None, + frames_per_batch=16, + total_frames=64, + buffer_depth=0, + ) + with pytest.raises(ValueError, match="only supported by MultiAsyncCollector"): + MultiSyncCollector( + [env_fn], + policy=None, + frames_per_batch=16, + total_frames=64, + buffer_depth=2, + ) + with pytest.raises(ValueError, match="replay_buffer"): + MultiAsyncCollector( + [env_fn], + policy=None, + frames_per_batch=16, + total_frames=64, + buffer_depth=2, + replay_buffer=ReplayBuffer(storage=LazyTensorStorage(100)), + ) + with pytest.raises(ValueError, match="use_buffers=True"): + MultiAsyncCollector( + [env_fn], + policy=None, + frames_per_batch=16, + total_frames=64, + buffer_depth=2, + use_buffers=False, + ) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main( diff --git a/torchrl/collectors/_multi_async.py b/torchrl/collectors/_multi_async.py index db77c4367aa..bbed314155a 100644 --- a/torchrl/collectors/_multi_async.py +++ b/torchrl/collectors/_multi_async.py @@ -124,6 +124,8 @@ class MultiAsyncCollector(MultiCollector): __doc__ += MultiCollector.__doc__ + _supports_buffer_depth = True + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.out_tensordicts = defaultdict(lambda: None) @@ -195,6 +197,20 @@ def frames_per_batch_worker(self, *, worker_idx: int | None = None) -> int: def _get_from_queue(self, timeout=None) -> tuple[int, int, TensorDictBase]: new_data, j = self.queue_out.get(timeout=timeout) use_buffers = self._use_buffers + if self.buffer_depth > 1: + # Ring-buffer transport: the first message for each (worker, slot) + # pair carries the shared-memory buffer itself, later messages only + # carry the indices. The returned tensordict is a view of the slot: + # it is NOT cloned, since the worker rotates across ``buffer_depth`` + # slots and will not rewrite this one before ``buffer_depth - 1`` + # further rollouts. + if len(new_data) == 3: + data, idx, slot = new_data + self.out_tensordicts[(idx, slot)] = data + else: + idx, slot = new_data + out = self.out_tensordicts[(idx, slot)] + return idx, j, out if self.replay_buffer is not None: idx = new_data elif j == 0 or not use_buffers: @@ -294,6 +310,12 @@ def _shutdown_main(self, *args, **kwargs) -> None: return super()._shutdown_main(*args, **kwargs) def reset(self, reset_idx: Sequence[bool] | None = None) -> None: + if self.running and self.buffer_depth == 2: + warnings.warn( + "Calling reset() while iterating grants workers one extra rollout " + "of lookahead: with buffer_depth=2, tensordicts yielded before this " + "call may be overwritten. Clone them if needed, or use buffer_depth=3." + ) super().reset(reset_idx) if self.queue_out.full(): time.sleep(_TIMEOUT) # wait until queue is empty diff --git a/torchrl/collectors/_multi_base.py b/torchrl/collectors/_multi_base.py index b1fa70c00ce..1e9ae9189f7 100644 --- a/torchrl/collectors/_multi_base.py +++ b/torchrl/collectors/_multi_base.py @@ -276,6 +276,24 @@ class MultiCollector(BaseCollector, metaclass=_MultiCollectorMeta): use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data. This isn't compatible with environments with dynamic specs. Defaults to ``True`` for envs without dynamic specs, ``False`` for others. + buffer_depth (int, optional): number of rotating shared-memory buffers + per worker. Currently only supported by + :class:`~torchrl.collectors.MultiAsyncCollector`, and incompatible + with ``replay_buffer`` (which bypasses buffer transport entirely) + and ``use_buffers=False``. + + With the default (``None``, i.e. ``1``), each yielded batch is cloned + on the main process so the single per-worker buffer can be reused. + With ``buffer_depth=K > 1``, workers write each rollout into one of + ``K`` rotating shared-memory slots and the main process yields + zero-copy views instead: a yielded batch remains valid until the + same worker has collected ``K - 1`` further batches. In the steady + state a worker collects at most one batch ahead, so ``K=2`` is safe + for the standard ``for data in collector: ...`` pattern; calling + :meth:`reset` while iterating grants workers one extra rollout of + lookahead, so use ``K=3`` if you reset mid-collection. Clone the + yielded tensordict to keep it indefinitely, and treat it as + read-only. replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts but populate the buffer instead. Defaults to ``None``. extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not @@ -393,6 +411,9 @@ class MultiCollector(BaseCollector, metaclass=_MultiCollectorMeta): """ + # Whether this collector supports buffer_depth > 1 (ring-buffer transport). + _supports_buffer_depth = False + def __init__( self, create_env_fn: Sequence[Callable[[], EnvBase]], @@ -425,6 +446,7 @@ def __init__( cat_results: str | int | None = None, set_truncated: bool = False, use_buffers: bool | None = None, + buffer_depth: int | None = None, replay_buffer: ReplayBuffer | None = None, extend_buffer: bool = True, trust_policy: bool | None = None, @@ -505,6 +527,28 @@ def __init__( self._use_buffers = use_buffers self.replay_buffer = replay_buffer + # Set up ring-buffer transport depth + buffer_depth = 1 if buffer_depth is None else int(buffer_depth) + if buffer_depth < 1: + raise ValueError(f"buffer_depth must be >= 1, got {buffer_depth}.") + if buffer_depth > 1: + if not self._supports_buffer_depth: + raise ValueError( + f"buffer_depth > 1 is currently only supported by MultiAsyncCollector, " + f"not {type(self).__name__}." + ) + if replay_buffer is not None: + raise ValueError( + "buffer_depth > 1 has no effect when a replay_buffer is provided: " + "workers write directly into the buffer. Remove one of the two options." + ) + if use_buffers is False: + raise ValueError( + "buffer_depth > 1 requires buffer-based transport (use_buffers=True)." + ) + self._use_buffers = True + self.buffer_depth = buffer_depth + # Set up policy and weights if trust_policy is None: trust_policy = policy is not None and isinstance(policy, CudaGraphModule) @@ -1345,6 +1389,7 @@ def _run_processes(self) -> None: "interruptor": self.interruptor, "set_truncated": self.set_truncated, "use_buffers": self._use_buffers, + "buffer_depth": self.buffer_depth, "replay_buffer": self.replay_buffer, "extend_buffer": self.extend_buffer, "traj_pool": self._traj_pool, diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py index cc97f19ac68..21304b74cc8 100644 --- a/torchrl/collectors/_runner.py +++ b/torchrl/collectors/_runner.py @@ -31,6 +31,52 @@ from torchrl.weight_update import WeightSyncScheme from torchrl.weight_update.utils import _resolve_model +_MPS_SHARE_ERROR = ( + "tensors on mps device cannot be put in shared memory. Make sure " + "the shared device (aka storing_device) is set to CPU." +) + + +def _share_tensordict_for_transport(td: TensorDictBase) -> None: + """Place a collected tensordict in shared memory for queue transport. + + CPU and CUDA tensordicts are shared wholesale; device-less tensordicts are + shared per-tensor (non-CPU leaves are assumed shareable already). MPS + tensors cannot be shared and raise an error. + """ + if td.device is not None: + # placeholder in case we need different behaviors + if td.device.type in ("cpu",): + td.share_memory_() + elif td.device.type in ("mps",): + raise RuntimeError(_MPS_SHARE_ERROR) + elif td.device.type == "cuda": + td.share_memory_() + else: + raise NotImplementedError( + f"Device {td.device} is not supported in multi-collectors yet." + ) + else: + # make sure each cpu tensor is shared - assuming non-cpu devices are shared + def cast_tensor(x): + if x.device.type in ("cpu",): + x.share_memory_() + if x.device.type in ("mps",): + raise RuntimeError(_MPS_SHARE_ERROR) + + td.apply(cast_tensor, filter_empty=True) + + +def _td_has_cuda(td: TensorDictBase) -> bool: + """Whether any leaf tensor of ``td`` lives on a CUDA device.""" + has_cuda = [False] + + def look_for_cuda(tensor, has_cuda=has_cuda): + has_cuda[0] = has_cuda[0] or tensor.is_cuda + + td.apply(look_for_cuda, filter_empty=True) + return has_cuda[0] + def _main_async_collector( pipe_child: connection.Connection, @@ -51,6 +97,7 @@ def _main_async_collector( interruptor=None, set_truncated: bool = False, use_buffers: bool | None = None, + buffer_depth: int = 1, replay_buffer: ReplayBuffer | None = None, extend_buffer: bool = True, traj_pool: _TrajectoryPool = None, @@ -160,6 +207,11 @@ def _main_async_collector( scheme.model = actual_model use_buffers = inner_collector._use_buffers + if buffer_depth > 1 and not use_buffers: + raise RuntimeError( + "buffer_depth > 1 requires buffer-based transport, but the inner " + "collector resolved use_buffers=False (e.g. because of dynamic specs)." + ) if verbose: torchrl_logger.debug("Sync data collector created") @@ -188,6 +240,15 @@ def _main_async_collector( has_timed_out = False counter = 0 run_free = False + # Ring-buffer transport (buffer_depth > 1): rollouts are copied into + # rotating shared-memory slots so the main process can yield views without + # cloning, and so the worker can collect ahead without overwriting data + # the main process still holds. + ring_buffers = [None] * buffer_depth + # Whether the slot's buffer ref has reached the main process (a put can + # time out, in which case the ref must be re-sent on the next attempt). + ring_shipped = [False] * buffer_depth + ring_has_cuda = False while True: _timeout = _TIMEOUT if not has_timed_out else 1e-3 if not run_free and pipe_child.poll(_timeout): @@ -329,7 +390,28 @@ def _main_async_collector( has_timed_out = True continue - if j == 0 or not use_buffers: + if buffer_depth > 1: + if storing_device is not None and next_data.device != storing_device: + raise RuntimeError( + f"expected device to be {storing_device} but got {next_data.device}" + ) + slot = j % buffer_depth + buf = ring_buffers[slot] + if buf is None: + # First use of this slot: copy the rollout into a fresh + # shared-memory buffer and ship the buffer itself. Later + # uses only send (idx, slot). + buf = next_data.clone() + _share_tensordict_for_transport(buf) + ring_buffers[slot] = buf + ring_has_cuda = ring_has_cuda or _td_has_cuda(buf) + else: + buf.update_(next_data, non_blocking=True) + if ring_has_cuda and not no_cuda_sync: + # Make the slot contents visible to the main process. + torch.cuda.synchronize() + data = (buf, idx, slot) if not ring_shipped[slot] else (idx, slot) + elif j == 0 or not use_buffers: collected_tensordict = next_data if ( storing_device is not None @@ -343,31 +425,7 @@ def _main_async_collector( # if policy is on cuda and env on cuda, we are fine with this # If policy is on cuda and env on cpu (or opposite) we put tensors that # are on cpu in shared mem. - MPS_ERROR = ( - "tensors on mps device cannot be put in shared memory. Make sure " - "the shared device (aka storing_device) is set to CPU." - ) - if collected_tensordict.device is not None: - # placeholder in case we need different behaviors - if collected_tensordict.device.type in ("cpu",): - collected_tensordict.share_memory_() - elif collected_tensordict.device.type in ("mps",): - raise RuntimeError(MPS_ERROR) - elif collected_tensordict.device.type == "cuda": - collected_tensordict.share_memory_() - else: - raise NotImplementedError( - f"Device {collected_tensordict.device} is not supported in multi-collectors yet." - ) - else: - # make sure each cpu tensor is shared - assuming non-cpu devices are shared - def cast_tensor(x, MPS_ERROR=MPS_ERROR): - if x.device.type in ("cpu",): - x.share_memory_() - if x.device.type in ("mps",): - raise RuntimeError(MPS_ERROR) - - collected_tensordict.apply(cast_tensor, filter_empty=True) + _share_tensordict_for_transport(collected_tensordict) data = (collected_tensordict, idx) else: if next_data is not collected_tensordict: @@ -377,6 +435,8 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR): data = idx # flag the worker that has sent its data try: queue_out.put((data, j), timeout=_TIMEOUT) + if buffer_depth > 1: + ring_shipped[slot] = True if verbose: torchrl_logger.debug(f"mp worker {idx} successfully sent data") j += 1 diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index b1ba4d29bf8..5b5f37a545f 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -19,7 +19,7 @@ from multiprocessing import connection from multiprocessing.connection import wait as connection_wait from multiprocessing.synchronize import Lock as MpLock -from typing import Any +from typing import Any, Literal from warnings import warn import torch @@ -62,6 +62,21 @@ "`consolidate` keyword argument of the ParallelEnv constructor." ) +# Opcodes written by the parent into the shared command flags when +# ``worker_wait`` is "adaptive" or "spin". Only payload-free hot-path commands +# are dispatched this way; everything else goes through the pipe. +_CMD_NONE = 0 +_CMD_TO_STR = {1: "step", 2: "step_and_maybe_reset"} +_STR_TO_CMD = {v: k for k, v in _CMD_TO_STR.items()} +# Worker sleep states for the "adaptive" wait strategy. +_WORKER_AWAKE = 0 +_WORKER_SLEEPING = 1 +# How long a sleeping worker blocks on the pipe before re-checking the shared +# command flag. This only matters in the (theoretical) lost-wake race where the +# parent reads the sleep state before it becomes visible: the worker then picks +# the command up at the next recheck instead of waiting forever. +_WAKE_RECHECK_INTERVAL = 0.05 + def _to_device_mps_safe( tensor: torch.Tensor, @@ -328,6 +343,28 @@ class BatchedEnvBase(EnvBase): daemon (bool, optional): whether the processes should be daemonized. This is only applicable to parallel environments such as :class:`~torchrl.envs.ParallelEnv`. Defaults to ``False``. + worker_wait (str, optional): how workers wait for the next command. + To be used only with :class:`~torchrl.envs.ParallelEnv` subclasses, + and only effective when ``use_buffers=True``. + + - ``"block"`` (default): workers block on the pipe. Each step incurs + one pipe message (a syscall and a small pickle) per worker. + - ``"adaptive"``: payload-free hot-path commands (``step`` and + ``step_and_maybe_reset``) are written as opcodes to a shared-memory + flag that workers spin-poll, eliminating the per-step syscalls. + After ``spin_for`` seconds without a command, workers fall back to + blocking on the pipe (and the parent wakes them through it), so + long gaps between steps (e.g. a slow policy) don't burn CPU. + - ``"spin"``: workers spin indefinitely. Lowest latency, but each + worker keeps one CPU core busy at all times; only use when workers + do not outnumber the available cores. + + Commands that carry a payload (resets, seeds, non-tensor data, RNN + key passthrough) always travel through the pipe, whatever the mode. + spin_for (float, optional): how long (in seconds) workers spin before + falling back to a blocking pipe wait when ``worker_wait="adaptive"``. + Should roughly cover the typical time between two consecutive + commands (e.g. the policy forward pass). Defaults to ``1e-3``. auto_wrap_envs (bool, optional): if ``True`` (default), lambda functions passed as ``create_env_fn`` will be automatically wrapped in an :class:`~torchrl.envs.EnvCreator` to enable pickling for multiprocessing with the ``spawn`` start method. @@ -458,6 +495,8 @@ def __init__( use_buffers: bool | None = None, consolidate: bool = True, daemon: bool = False, + worker_wait: Literal["block", "adaptive", "spin"] = "block", + spin_for: float = 1e-3, ): super().__init__(device=device) self.serial_for_single = serial_for_single @@ -468,6 +507,14 @@ def __init__( self._use_buffers = use_buffers self.consolidate = consolidate self.daemon = daemon + if worker_wait not in ("block", "adaptive", "spin"): + raise ValueError( + f"worker_wait must be one of 'block', 'adaptive' or 'spin', got {worker_wait!r}." + ) + if spin_for <= 0: + raise ValueError(f"spin_for must be a positive float, got {spin_for}.") + self.worker_wait = worker_wait + self.spin_for = spin_for self._single_task = callable(create_env_fn) or (len(set(create_env_fn)) == 1) if callable(create_env_fn): @@ -525,6 +572,10 @@ def __init__( raise TypeError( f"Cannot use mp_start_method={mp_start_method} with envs of type {type(self)}." ) + if worker_wait != "block" and not isinstance(self, ParallelEnv): + raise TypeError( + f"Cannot use worker_wait={worker_wait} with envs of type {type(self)}." + ) self._mp_start_method = mp_start_method is_spec_locked = EnvBase.is_spec_locked @@ -540,6 +591,8 @@ def configure_parallel( num_sub_threads: int | None = None, non_blocking: bool | None = None, daemon: bool | None = None, + worker_wait: Literal["block", "adaptive", "spin"] | None = None, + spin_for: float | None = None, ) -> BatchedEnvBase: """Configure parallel execution parameters before the environment starts. @@ -560,6 +613,13 @@ def configure_parallel( non_blocking (bool, optional): if ``True``, device moves will be done using the ``non_blocking=True`` option. daemon (bool, optional): whether the processes should be daemonized. + worker_wait (str, optional): how workers wait for the next command. + One of ``"block"`` (wait on the pipe, default), ``"adaptive"`` + (spin on a shared-memory flag for ``spin_for`` seconds, then fall + back to the pipe) or ``"spin"`` (always spin). See the + :class:`~torchrl.envs.ParallelEnv` documentation for details. + spin_for (float, optional): how long (in seconds) workers spin before + falling back to a blocking pipe wait when ``worker_wait="adaptive"``. Returns: self: Returns self for method chaining. @@ -594,6 +654,20 @@ def configure_parallel( self._non_blocking = non_blocking if daemon is not None: self.daemon = daemon + if worker_wait is not None: + if worker_wait not in ("block", "adaptive", "spin"): + raise ValueError( + f"worker_wait must be one of 'block', 'adaptive' or 'spin', got {worker_wait!r}." + ) + if worker_wait != "block" and not isinstance(self, ParallelEnv): + raise TypeError( + f"Cannot use worker_wait={worker_wait} with envs of type {type(self)}." + ) + self.worker_wait = worker_wait + if spin_for is not None: + if spin_for <= 0: + raise ValueError(f"spin_for must be a positive float, got {spin_for}.") + self.spin_for = spin_for return self def select_and_clone(self, name, tensor, selected_keys=None): @@ -1770,6 +1844,24 @@ def look_for_cuda(tensor, has_cuda=has_cuda): # Eliminates futex syscalls from mp.Event on the critical path. self._shm_done_flags = mp.RawArray("b", _num_workers) + # Shared-memory command flags (parent -> worker direction): the parent + # writes an opcode that workers spin-poll, removing the per-step pipe + # send. Requires buffers since the opcode carries no payload. + worker_wait = self.worker_wait + if worker_wait != "block" and not self._use_buffers: + warn( + f"worker_wait={worker_wait!r} requires use_buffers=True; " + "falling back to worker_wait='block'." + ) + worker_wait = "block" + self._worker_wait = worker_wait + if worker_wait != "block": + self._shm_cmd_flags = mp.RawArray("b", _num_workers) + self._shm_worker_states = mp.RawArray("b", _num_workers) + else: + self._shm_cmd_flags = None + self._shm_worker_states = None + kwargs = [{"mp_event": self._events[i]} for i in range(_num_workers)] if self._use_buffers: for i in range(_num_workers): @@ -1777,6 +1869,10 @@ def look_for_cuda(tensor, has_cuda=has_cuda): { "worker_idx": i, "shm_done_flags": self._shm_done_flags, + "shm_cmd_flags": self._shm_cmd_flags, + "shm_worker_states": self._shm_worker_states, + "worker_wait": worker_wait, + "spin_for": self.spin_for, } ) with clear_mpi_env_vars(): @@ -2022,7 +2118,7 @@ def step_and_maybe_reset( self._sync_m2w() for i, _data in zip(workers_range, data): - self.parent_channels[i].send(("step_and_maybe_reset", _data)) + self._send_hot_cmd(i, "step_and_maybe_reset", _data) self._wait_for_workers(workers_range) if self._non_tensor_keys: @@ -2132,6 +2228,24 @@ def select_and_transfer(x, y): return tensordict, tensordict_ + def _send_hot_cmd(self, i: int, cmd: str, data: dict) -> None: + """Dispatch a hot-path command to worker ``i``. + + When shared-memory command flags are enabled (``worker_wait`` is + "adaptive" or "spin") and the command carries no payload, the command + is published as an opcode in shared memory (no syscall, no pickling). + If the worker advertised that it fell back to a blocking pipe wait, a + "wake" message is additionally sent through the pipe. + + Commands with a payload always go through the pipe. + """ + if self._shm_cmd_flags is not None and not data: + self._shm_cmd_flags[i] = _STR_TO_CMD[cmd] + if self._shm_worker_states[i] == _WORKER_SLEEPING: + self.parent_channels[i].send(("wake", None)) + else: + self.parent_channels[i].send((cmd, data)) + def _wait_for_workers(self, workers_range): """Wait for all workers to signal completion. @@ -2392,7 +2506,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: self.event.synchronize() for i in workers_range: - self.parent_channels[i].send(("step", data[i])) + self._send_hot_cmd(i, "step", data[i]) self._wait_for_workers(workers_range) @@ -2652,6 +2766,9 @@ def _shutdown_workers(self) -> None: self._cuda_events = None self._events = None self.event = None + self._shm_done_flags = None + self._shm_cmd_flags = None + self._shm_worker_states = None @_check_start def set_seed( @@ -2765,6 +2882,10 @@ def _run_worker_pipe_shared_mem( filter_warnings: bool = False, worker_idx: int | None = None, shm_done_flags=None, + shm_cmd_flags=None, + shm_worker_states=None, + worker_wait: Literal["block", "adaptive", "spin"] = "block", + spin_for: float = 1e-3, ) -> None: pid = os.getpid() # Handle warning filtering (moved from _ProcessNoWarn) @@ -2813,7 +2934,6 @@ def look_for_cuda(tensor, has_cuda=has_cuda): _last_cmd = "N/A" # Create a timeit instance to track elapsed time since worker start # Use shared memory for done signaling (avoids futex syscalls). - # Command delivery still goes through pipes (kernel wakeup is efficient). if shm_done_flags is not None and worker_idx is not None: def _signal_done(): @@ -2824,28 +2944,90 @@ def _signal_done(): def _signal_done(): mp_event.set() + def _raise_timeout(): + torchrl_logger.debug( + f"batched_env worker {pid}: TIMEOUT after {_timeout}s waiting for cmd, " + f"elapsed_since_start={_worker_timer.elapsed():.1f}s, " + f"last_cmd={_last_cmd}, cmd_count={_cmd_count}" + ) + raise TimeoutError( + f"Worker timed out after {_timeout}s, " + f"increase timeout if needed through the BATCHED_PIPE_TIMEOUT environment variable." + ) + + # Command delivery: with worker_wait="block" (default), commands arrive + # through the pipe (kernel wakeup). With "adaptive"/"spin", payload-free + # hot-path commands arrive as opcodes in shared memory that we spin-poll, + # removing the per-step syscall + pickling; the pipe is still checked + # periodically for cold-path commands (reset, seed, close, attributes...). + if shm_cmd_flags is None or worker_idx is None or worker_wait == "block": + + def _wait_cmd(): + if child_pipe.poll(_timeout): + return child_pipe.recv() + _raise_timeout() + + else: + _adaptive = worker_wait == "adaptive" + + def _sleep_until_cmd(t_start): + # Advertise that we are about to block on the pipe so the parent + # sends a "wake" message in addition to the shared-memory opcode. + # The short poll timeout covers the (theoretical) window where the + # parent reads the state before our write becomes visible. + shm_worker_states[worker_idx] = _WORKER_SLEEPING + try: + while True: + op = shm_cmd_flags[worker_idx] + if op != _CMD_NONE: + shm_cmd_flags[worker_idx] = _CMD_NONE + return _CMD_TO_STR[op], None + if child_pipe.poll(_WAKE_RECHECK_INTERVAL): + cmd_data = child_pipe.recv() + if cmd_data[0] != "wake": + return cmd_data + # "wake": loop back and consume the opcode. + continue + if time.time() - t_start > _timeout: + _raise_timeout() + finally: + shm_worker_states[worker_idx] = _WORKER_AWAKE + + def _wait_cmd(): + t_start = time.time() + n_spins = 0 + while True: + op = shm_cmd_flags[worker_idx] + if op != _CMD_NONE: + shm_cmd_flags[worker_idx] = _CMD_NONE + return _CMD_TO_STR[op], None + n_spins += 1 + # Periodically check the pipe (cold-path commands, stale + # wakes, EOF on parent death) and the timeout. + if (n_spins & 0x3FF) == 0: + if child_pipe.poll(0): + cmd_data = child_pipe.recv() + if cmd_data[0] != "wake": + return cmd_data + # Stale wake from an already-consumed opcode: ignore. + continue + elapsed = time.time() - t_start + if _adaptive and elapsed > spin_for: + return _sleep_until_cmd(t_start) + if elapsed > _timeout: + _raise_timeout() + _worker_timer = timeit(f"batched_env_worker/{pid}/lifetime").start() while True: try: - if child_pipe.poll(_timeout): - cmd, data = child_pipe.recv() - _cmd_count += 1 - _last_cmd = cmd - # Log every 1000 commands - if _cmd_count % 1000 == 0: - torchrl_logger.debug( - f"batched_env worker {pid}: cmd_count={_cmd_count}, " - f"elapsed={_worker_timer.elapsed():.1f}s, last_cmd={cmd}" - ) - else: + cmd, data = _wait_cmd() + _cmd_count += 1 + _last_cmd = cmd + # Log every 1000 commands + if _cmd_count % 1000 == 0: torchrl_logger.debug( - f"batched_env worker {pid}: TIMEOUT after {_timeout}s waiting for cmd, " - f"elapsed_since_start={_worker_timer.elapsed():.1f}s, " - f"last_cmd={_last_cmd}, cmd_count={_cmd_count}" - ) - raise TimeoutError( - f"Worker timed out after {_timeout}s, " - f"increase timeout if needed through the BATCHED_PIPE_TIMEOUT environment variable." + f"batched_env worker {pid}: cmd_count={_cmd_count}, " + f"elapsed={_worker_timer.elapsed():.1f}s, last_cmd={cmd}" ) except EOFError as err: torchrl_logger.debug( diff --git a/torchrl/trainers/algorithms/configs/collectors.py b/torchrl/trainers/algorithms/configs/collectors.py index dac802a0e67..5ac080f7acc 100644 --- a/torchrl/trainers/algorithms/configs/collectors.py +++ b/torchrl/trainers/algorithms/configs/collectors.py @@ -161,6 +161,7 @@ class MultiSyncCollectorConfig(BaseCollectorConfig): cat_results: Any = None set_truncated: bool = False use_buffers: bool = False + buffer_depth: int | None = None replay_buffer: ConfigBase | None = None extend_buffer: bool = False trust_policy: bool = True @@ -223,6 +224,7 @@ class MultiAsyncCollectorConfig(BaseCollectorConfig): cat_results: Any = None set_truncated: bool = False use_buffers: bool = False + buffer_depth: int | None = None replay_buffer: ConfigBase | None = None extend_buffer: bool = False trust_policy: bool = True diff --git a/torchrl/trainers/algorithms/configs/envs.py b/torchrl/trainers/algorithms/configs/envs.py index 2d325f557d0..3830d6eb4ef 100644 --- a/torchrl/trainers/algorithms/configs/envs.py +++ b/torchrl/trainers/algorithms/configs/envs.py @@ -34,6 +34,10 @@ class BatchedEnvConfig(EnvConfig): create_env_kwargs: dict = field(default_factory=dict) batched_env_type: str = "parallel" device: str | None = None + # How ParallelEnv workers wait for commands ("block", "adaptive" or "spin"); + # only forwarded when set. See torchrl.envs.ParallelEnv. + worker_wait: str | None = None + spin_for: float | None = None # batched_env_type: Literal["parallel", "serial", "async"] = "parallel" _target_: str = "torchrl.trainers.algorithms.configs.envs.make_batched_env" @@ -94,6 +98,12 @@ def env_fn(env_instance=env_instance): if device is not None: kwargs["device"] = device + # worker_wait / spin_for are ParallelEnv-only; drop the unset defaults so + # SerialEnv / AsyncEnvPool don't receive unexpected kwargs. + for key in ("worker_wait", "spin_for"): + if kwargs.get(key, None) is None: + kwargs.pop(key, None) + if batched_env_type == "parallel": return ParallelEnv(num_workers, env_fn, **kwargs) elif batched_env_type == "serial": From 5363e83ceb8dd62acbb0c83afcd7697065ba9829 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 11 Jun 2026 17:36:57 +0100 Subject: [PATCH 2/2] [Performance] Close env in worker_wait benchmark to avoid leaking spinning workers The benchmark files share one pytest session; without an explicit close the "spin" case leaves three busy-waiting workers burning cores for the rest of the session, which can skew subsequent benchmarks. Co-Authored-By: Claude Fable 5 --- benchmarks/test_envs_benchmark.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/benchmarks/test_envs_benchmark.py b/benchmarks/test_envs_benchmark.py index 9877ead1cfe..ff028938161 100644 --- a/benchmarks/test_envs_benchmark.py +++ b/benchmarks/test_envs_benchmark.py @@ -105,8 +105,14 @@ def test_parallel_worker_wait(benchmark, worker_wait): lambda: DMControlEnv("cheetah", "run", device=device), worker_wait=worker_wait, ) - env.rollout(3) - benchmark(execute_env, env) + try: + env.rollout(3) + benchmark(execute_env, env) + finally: + # Close eagerly: in "spin" mode idle workers busy-wait at 100% CPU and + # would otherwise keep burning cores for the rest of the pytest session, + # skewing every subsequent benchmark. + env.close() @pytest.mark.parametrize("nested", [True, False])