Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions benchmarks/test_collectors_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,29 @@ def test_async(benchmark):
benchmark(execute_collector, c)


@pytest.mark.parametrize("buffer_depth", [None, 2])
def test_async_buffer_depth(benchmark, buffer_depth):
    """Benchmark a two-worker ``MultiAsyncCollector`` for each ``buffer_depth``.

    The test is parametrized over ``buffer_depth=None`` and ``buffer_depth=2``;
    the value is forwarded unchanged to the collector constructor. A few
    batches are consumed before the timed section, and the collector is shut
    down explicitly once the benchmark has run (or failed).

    Args:
        benchmark: the benchmark fixture; it is called with
            ``execute_collector`` and the collector iterator.
        buffer_depth: value passed to ``MultiAsyncCollector(buffer_depth=...)``.
    """
    device = "cuda:0" if torch.cuda.device_count() else "cpu"
    env = EnvCreator(
        lambda: TransformedEnv(
            DMControlEnv("cheetah", "run", device=device), StepCounter(50)
        )
    )
    collector = MultiAsyncCollector(
        [env, env],
        RandomPolicy(env().action_spec),
        total_frames=-1,
        frames_per_batch=100,
        device=device,
        buffer_depth=buffer_depth,
    )
    try:
        c = iter(collector)
        # Consume batches 0..10 before timing (presumably a warm-up so that
        # start-up costs are not measured -- confirm against the sibling
        # benchmarks in this file).
        for i, _ in enumerate(c):
            if i == 10:
                break
        benchmark(execute_collector, c)
    finally:
        # Keep a handle on the collector (not only on its iterator) so it can
        # be shut down deterministically instead of relying on garbage
        # collection to tear the workers down later in the pytest session.
        collector.shutdown()


@pytest.mark.skipif(not torch.cuda.device_count(), reason="no rendering without cuda")
def test_single_pixels(benchmark):
(c,), _ = single_collector_setup_pixels()
Expand Down
18 changes: 18 additions & 0 deletions benchmarks/test_envs_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,24 @@ def test_parallel(benchmark):
benchmark(execute_env, c)


@pytest.mark.parametrize("worker_wait", ["block", "adaptive", "spin"])
def test_parallel_worker_wait(benchmark, worker_wait):
    """Benchmark a three-worker ``ParallelEnv`` for each ``worker_wait`` mode.

    Args:
        benchmark: the benchmark fixture; it is called with ``execute_env``
            and the parallel environment.
        worker_wait: one of ``"block"``, ``"adaptive"`` or ``"spin"``,
            forwarded to ``ParallelEnv(worker_wait=...)``.
    """
    if torch.cuda.device_count():
        device = "cuda:0"
    else:
        device = "cpu"

    parallel_env = ParallelEnv(
        3,
        lambda: DMControlEnv("cheetah", "run", device=device),
        worker_wait=worker_wait,
    )
    try:
        # One short rollout before the timed section.
        parallel_env.rollout(3)
        benchmark(execute_env, parallel_env)
    finally:
        # Close right away rather than at interpreter exit: idle workers in
        # "spin" mode busy-wait at 100% CPU, and leaving them alive would keep
        # cores busy for the remainder of the pytest session and distort
        # every benchmark that runs afterwards.
        parallel_env.close()


@pytest.mark.parametrize("nested", [True, False])
@pytest.mark.parametrize("keep_other", [True, False])
@pytest.mark.parametrize("exclude_reward", [True, False])
Expand Down
87 changes: 87 additions & 0 deletions test/envs/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
ContinuousActionVecMockEnv,
CountingEnv,
CountingEnvCountPolicy,
CountingEnvWithString,
DiscreteActionConvMockEnv,
DiscreteActionVecMockEnv,
MockBatchedLockedEnv,
Expand Down Expand Up @@ -1334,3 +1335,89 @@ def test_stackable():
assert not _stackable(*stack)
stack = [TensorDict({"a": "a string"}, []), TensorDict({"a": "another string"}, [])]
assert _stackable(*stack)


class TestWorkerWait:
    """Tests for the shared-memory command signaling of ParallelEnv (worker_wait)."""

    @staticmethod
    def _rollout(env_cls, parallel_env_cls, n_steps=20, **kwargs):
        """Run a seeded rollout on a two-worker batched env and close it.

        Args:
            env_cls: environment class; each worker builds ``env_cls(max_steps=10)``.
            parallel_env_cls: batched-env constructor called with ``2`` workers.
            n_steps: number of rollout steps.
            **kwargs: extra keyword arguments forwarded to ``parallel_env_cls``.

        Returns:
            The result of ``env.rollout(n_steps, break_when_any_done=False)``.
        """
        env = parallel_env_cls(2, lambda: env_cls(max_steps=10), **kwargs)
        try:
            # Seed both the env and torch so that two calls are comparable.
            env.set_seed(0)
            torch.manual_seed(0)
            trajectory = env.rollout(n_steps, break_when_any_done=False)
        finally:
            env.close()
            del env
        return trajectory

    @pytest.mark.parametrize(
        "worker_wait,spin_for",
        [
            ("adaptive", 1e-3),
            # tiny spin window: forces the sleep/wake fallback path
            ("adaptive", 1e-6),
            ("spin", 1e-3),
        ],
    )
    @pytest.mark.parametrize("env_cls", [CountingEnv, NestedCountingEnv])
    def test_worker_wait_rollout_parity(
        self, worker_wait, spin_for, env_cls, maybe_fork_ParallelEnv
    ):
        """A rollout with default settings must match one with the given wait mode."""
        wait_options = {"worker_wait": worker_wait, "spin_for": spin_for}
        reference = self._rollout(env_cls, maybe_fork_ParallelEnv)
        candidate = self._rollout(env_cls, maybe_fork_ParallelEnv, **wait_options)
        assert_allclose_td(reference, candidate)

    def test_worker_wait_payload_fallback(self, maybe_fork_ParallelEnv):
        """Rollout with an env carrying a string entry under ``worker_wait="adaptive"``."""
        # Non-tensor keys attach a payload to every step command; such
        # commands must transparently fall back to the pipe path.
        env = maybe_fork_ParallelEnv(2, CountingEnvWithString, worker_wait="adaptive")
        try:
            rollout = env.rollout(10, break_when_any_done=False)
            assert rollout["string"] is not None
            assert rollout["observation"].shape == (2, 10, 1)
        finally:
            env.close()
            del env

    def test_worker_wait_no_buffers_fallback(self, maybe_fork_ParallelEnv):
        """With ``use_buffers=False`` a ``UserWarning`` is expected and the rollout still works."""
        env = maybe_fork_ParallelEnv(
            2,
            lambda: CountingEnv(max_steps=10),
            use_buffers=False,
            worker_wait="adaptive",
        )
        try:
            with pytest.warns(UserWarning, match="requires use_buffers=True"):
                rollout = env.rollout(10, break_when_any_done=False)
            assert rollout["observation"].shape == (2, 10, 1)
        finally:
            env.close()
            del env

    def test_worker_wait_validation(self):
        """Invalid ``worker_wait`` / ``spin_for`` arguments raise at construction."""
        # (expected exception, message pattern, env class, offending kwargs),
        # checked in this order.
        invalid_cases = [
            (
                ValueError,
                "worker_wait must be one of",
                ParallelEnv,
                {"worker_wait": "bogus"},
            ),
            (
                ValueError,
                "spin_for must be a positive float",
                ParallelEnv,
                {"spin_for": 0.0},
            ),
            (
                TypeError,
                "Cannot use worker_wait",
                SerialEnv,
                {"worker_wait": "spin"},
            ),
        ]
        for expected_exc, pattern, batched_cls, bad_kwargs in invalid_cases:
            with pytest.raises(expected_exc, match=pattern):
                batched_cls(2, lambda: CountingEnv(), **bad_kwargs)

    def test_worker_wait_configure_parallel(self, maybe_fork_ParallelEnv):
        """``configure_parallel`` sets ``worker_wait`` / ``spin_for`` before a rollout."""
        env = maybe_fork_ParallelEnv(2, lambda: CountingEnv(max_steps=10))
        env.configure_parallel(worker_wait="adaptive", spin_for=1e-4)
        assert env.worker_wait == "adaptive"
        assert env.spin_for == 1e-4
        try:
            env.set_seed(0)
            torch.manual_seed(0)
            rollout = env.rollout(10, break_when_any_done=False)
            assert rollout["observation"].shape == (2, 10, 1)
        finally:
            env.close()
            del env
98 changes: 98 additions & 0 deletions test/test_collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7534,6 +7534,104 @@ def env_fn():
assert ("collector", "traj_ids") in rb.sample(1).keys(True)


class TestBufferDepth:
    """Tests for the ring-buffer transport (buffer_depth > 1) of MultiAsyncCollector."""

    def test_ring_rotation_and_validity_window(self):
        """Yielded batches alternate between two storages and older views get rewritten."""
        # With exactly one worker the slot rotation is deterministic:
        # batch j is written to slot j % buffer_depth.
        collector = MultiAsyncCollector(
            [functools.partial(CountingEnv, max_steps=20)],
            policy=None,
            frames_per_batch=16,
            total_frames=16 * 6,
            buffer_depth=2,
        )
        views, clones = [], []
        try:
            for batch in collector:
                views.append(batch)
                clones.append(batch.clone())
        finally:
            collector.shutdown()
        assert len(views) == 6

        # Even-indexed batches (0, 2, 4) share slot 0's storage and
        # odd-indexed batches (1, 3, 5) share slot 1's storage.
        ptrs = [view["observation"].data_ptr() for view in views]
        assert len(set(ptrs[0::2])) == 1
        assert len(set(ptrs[1::2])) == 1
        assert ptrs[0] != ptrs[1]

        # Validity window: batch 4 was the last writer of slot 0, so the view
        # kept from batch 0 now shows batch 4's content, while the most recent
        # batch still matches its own clone.
        oldest_view = views[0]["observation"]
        newest_view = views[-1]["observation"]
        assert (oldest_view == clones[4]["observation"]).all()
        assert (newest_view == clones[-1]["observation"]).all()

    @pytest.mark.parametrize("env_cls", [CountingEnv, NestedCountingEnv])
    def test_buffer_depth_content_parity(self, env_cls):
        """Depth-1 and depth-2 collection must produce the same batches."""
        # A single worker with a deterministic policy gives a deterministic
        # batch stream, so depth-1 (cloned yields) and depth-2 (view yields)
        # have to agree. NestedCountingEnv exercises nested action keys.
        probe = env_cls(max_steps=20)
        action_key = probe.action_key
        policy = CountingEnvCountPolicy(probe.full_action_spec[action_key], action_key)

        def collect(buffer_depth):
            """Collect every batch at the given depth, cloning each one."""
            collector = MultiAsyncCollector(
                [functools.partial(env_cls, max_steps=20)],
                policy=policy,
                frames_per_batch=16,
                total_frames=64,
                buffer_depth=buffer_depth,
            )
            batches = []
            try:
                for batch in collector:
                    batches.append(batch.clone())
            finally:
                collector.shutdown()
            return batches

        depth_one = collect(1)
        depth_two = collect(2)
        for expected, actual in zip(depth_one, depth_two, strict=True):
            assert_allclose_td(expected, actual)

    def test_buffer_depth_validation(self):
        """Unsupported ``buffer_depth`` configurations raise ``ValueError``."""
        env_fn = functools.partial(CountingEnv, max_steps=20)
        # Constructor arguments shared by every case below.
        shared = {"policy": None, "frames_per_batch": 16, "total_frames": 64}

        with pytest.raises(ValueError, match="buffer_depth must be >= 1"):
            MultiAsyncCollector([env_fn], buffer_depth=0, **shared)

        with pytest.raises(ValueError, match="only supported by MultiAsyncCollector"):
            MultiSyncCollector([env_fn], buffer_depth=2, **shared)

        with pytest.raises(ValueError, match="replay_buffer"):
            MultiAsyncCollector(
                [env_fn],
                buffer_depth=2,
                replay_buffer=ReplayBuffer(storage=LazyTensorStorage(100)),
                **shared,
            )

        with pytest.raises(ValueError, match="use_buffers=True"):
            MultiAsyncCollector(
                [env_fn],
                buffer_depth=2,
                use_buffers=False,
                **shared,
            )


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main(
Expand Down
22 changes: 22 additions & 0 deletions torchrl/collectors/_multi_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ class MultiAsyncCollector(MultiCollector):

__doc__ += MultiCollector.__doc__

_supports_buffer_depth = True

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.out_tensordicts = defaultdict(lambda: None)
Expand Down Expand Up @@ -195,6 +197,20 @@ def frames_per_batch_worker(self, *, worker_idx: int | None = None) -> int:
def _get_from_queue(self, timeout=None) -> tuple[int, int, TensorDictBase]:
new_data, j = self.queue_out.get(timeout=timeout)
use_buffers = self._use_buffers
if self.buffer_depth > 1:
# Ring-buffer transport: the first message for each (worker, slot)
# pair carries the shared-memory buffer itself, later messages only
# carry the indices. The returned tensordict is a view of the slot:
# it is NOT cloned, since the worker rotates across ``buffer_depth``
# slots and will not rewrite this one before ``buffer_depth - 1``
# further rollouts.
if len(new_data) == 3:
data, idx, slot = new_data
self.out_tensordicts[(idx, slot)] = data
else:
idx, slot = new_data
out = self.out_tensordicts[(idx, slot)]
return idx, j, out
if self.replay_buffer is not None:
idx = new_data
elif j == 0 or not use_buffers:
Expand Down Expand Up @@ -294,6 +310,12 @@ def _shutdown_main(self, *args, **kwargs) -> None:
return super()._shutdown_main(*args, **kwargs)

def reset(self, reset_idx: Sequence[bool] | None = None) -> None:
if self.running and self.buffer_depth == 2:
warnings.warn(
"Calling reset() while iterating grants workers one extra rollout "
"of lookahead: with buffer_depth=2, tensordicts yielded before this "
"call may be overwritten. Clone them if needed, or use buffer_depth=3."
)
super().reset(reset_idx)
if self.queue_out.full():
time.sleep(_TIMEOUT) # wait until queue is empty
Expand Down
45 changes: 45 additions & 0 deletions torchrl/collectors/_multi_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,24 @@ class MultiCollector(BaseCollector, metaclass=_MultiCollectorMeta):
use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data.
This isn't compatible with environments with dynamic specs. Defaults to ``True``
for envs without dynamic specs, ``False`` for others.
buffer_depth (int, optional): number of rotating shared-memory buffers
per worker. Currently only supported by
:class:`~torchrl.collectors.MultiAsyncCollector`, and incompatible
with ``replay_buffer`` (which bypasses buffer transport entirely)
and ``use_buffers=False``.

With the default (``None``, i.e. ``1``), each yielded batch is cloned
on the main process so the single per-worker buffer can be reused.
With ``buffer_depth=K > 1``, workers write each rollout into one of
``K`` rotating shared-memory slots and the main process yields
zero-copy views instead: a yielded batch remains valid until the
same worker has collected ``K - 1`` further batches. In the steady
state a worker collects at most one batch ahead, so ``K=2`` is safe
for the standard ``for data in collector: ...`` pattern; calling
:meth:`reset` while iterating grants workers one extra rollout of
lookahead, so use ``K=3`` if you reset mid-collection. Clone the
yielded tensordict to keep it indefinitely, and treat it as
read-only.
replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
but populate the buffer instead. Defaults to ``None``.
extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not
Expand Down Expand Up @@ -393,6 +411,9 @@ class MultiCollector(BaseCollector, metaclass=_MultiCollectorMeta):

"""

# Whether this collector supports buffer_depth > 1 (ring-buffer transport).
_supports_buffer_depth = False

def __init__(
self,
create_env_fn: Sequence[Callable[[], EnvBase]],
Expand Down Expand Up @@ -425,6 +446,7 @@ def __init__(
cat_results: str | int | None = None,
set_truncated: bool = False,
use_buffers: bool | None = None,
buffer_depth: int | None = None,
replay_buffer: ReplayBuffer | None = None,
extend_buffer: bool = True,
trust_policy: bool | None = None,
Expand Down Expand Up @@ -505,6 +527,28 @@ def __init__(
self._use_buffers = use_buffers
self.replay_buffer = replay_buffer

# Set up ring-buffer transport depth
buffer_depth = 1 if buffer_depth is None else int(buffer_depth)
if buffer_depth < 1:
raise ValueError(f"buffer_depth must be >= 1, got {buffer_depth}.")
if buffer_depth > 1:
if not self._supports_buffer_depth:
raise ValueError(
f"buffer_depth > 1 is currently only supported by MultiAsyncCollector, "
f"not {type(self).__name__}."
)
if replay_buffer is not None:
raise ValueError(
"buffer_depth > 1 has no effect when a replay_buffer is provided: "
"workers write directly into the buffer. Remove one of the two options."
)
if use_buffers is False:
raise ValueError(
"buffer_depth > 1 requires buffer-based transport (use_buffers=True)."
)
self._use_buffers = True
self.buffer_depth = buffer_depth

# Set up policy and weights
if trust_policy is None:
trust_policy = policy is not None and isinstance(policy, CudaGraphModule)
Expand Down Expand Up @@ -1345,6 +1389,7 @@ def _run_processes(self) -> None:
"interruptor": self.interruptor,
"set_truncated": self.set_truncated,
"use_buffers": self._use_buffers,
"buffer_depth": self.buffer_depth,
"replay_buffer": self.replay_buffer,
"extend_buffer": self.extend_buffer,
"traj_pool": self._traj_pool,
Expand Down
Loading
Loading