Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions test/envs/test_special.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
_td_to_device_mps_safe,
_to_device_mps_safe,
)
from torchrl.envs.env_creator import get_env_metadata
from torchrl.envs.transforms import StepCounter, TransformedEnv
from torchrl.envs.transforms.transforms import Tokenizer
from torchrl.envs.utils import check_env_specs
Expand Down Expand Up @@ -999,3 +1000,160 @@ def test_parallel_env_no_buffers_mps_rollout(self):
assert td["observation"].dtype == torch.float32
finally:
env.close(raise_if_closed=False)


_MPS_USE_BUFFERS_WARNING = (
"The environment specs have leaves on an MPS device, which cannot be placed "
"in shared memory"
)
_MPS_USE_BUFFERS_ERROR = (
"use_buffers=True is incompatible with environments whose specs have leaves "
"on an MPS device"
)


class TestParallelEnvMPSBuffers:
"""ParallelEnv use_buffers checks for MPS sub-envs (issue #3066).

These tests fake the device map reported by the env metadata, so they run
on CPU-only machines too.
"""

@staticmethod
def _patch_device_map_to_mps(monkeypatch):
def get_env_metadata_mps(*args, **kwargs):
meta_data = get_env_metadata(*args, **kwargs)
meta_data.device_map = {
key: torch.device("mps") for key in meta_data.device_map
}
return meta_data

monkeypatch.setattr(
"torchrl.envs.batched_envs.get_env_metadata", get_env_metadata_mps
)

def test_parallel_env_mps_leaves_default_use_buffers_false(self, monkeypatch):
self._patch_device_map_to_mps(monkeypatch)
with pytest.warns(UserWarning, match=_MPS_USE_BUFFERS_WARNING):
env = ParallelEnv(2, ContinuousActionVecMockEnv)
assert env._use_buffers is False

def test_parallel_env_mps_leaves_use_buffers_true_raises(self, monkeypatch):
self._patch_device_map_to_mps(monkeypatch)
with pytest.raises(RuntimeError, match=_MPS_USE_BUFFERS_ERROR):
ParallelEnv(2, ContinuousActionVecMockEnv, use_buffers=True)

def test_parallel_env_mps_leaves_configure_parallel_raises(self, monkeypatch):
self._patch_device_map_to_mps(monkeypatch)
with pytest.warns(UserWarning, match=_MPS_USE_BUFFERS_WARNING):
env = ParallelEnv(2, ContinuousActionVecMockEnv)
with pytest.raises(RuntimeError, match=_MPS_USE_BUFFERS_ERROR):
env.configure_parallel(use_buffers=True)

def test_parallel_env_mps_leaves_explicit_use_buffers_false(self, monkeypatch):
self._patch_device_map_to_mps(monkeypatch)
env = ParallelEnv(2, ContinuousActionVecMockEnv, use_buffers=False)
assert env._use_buffers is False

def test_serial_env_mps_leaves_keeps_buffers(self, monkeypatch):
# SerialEnv runs in-process, so MPS buffers are fine there
self._patch_device_map_to_mps(monkeypatch)
env = SerialEnv(2, ContinuousActionVecMockEnv)
assert env._use_buffers is True


@pytest.mark.skipif(not _has_mps(), reason="MPS device not available")
class TestMPSSubEnvs:
"""ParallelEnv and collectors over sub-envs living on MPS (issue #3066)."""

class _MPSObsEnv(EnvBase):
"""Minimal env with all spec leaves on MPS.

The observation mirrors the last action so that the parent-worker
round-trip can be checked end-to-end.
"""

def __init__(self, device="mps"):
super().__init__(device=device)
self.observation_spec = Composite(
observation=Unbounded(shape=(3,), device=device), device=device
)
self.action_spec = Unbounded(shape=(1,), device=device)
self.reward_spec = Unbounded(shape=(1,), device=device)

def _reset(self, tensordict):
return TensorDict(
{"observation": torch.zeros(3, device=self.device)},
batch_size=[],
device=self.device,
)

def _step(self, tensordict):
return TensorDict(
{
"observation": tensordict["action"].expand(3).clone(),
"reward": torch.zeros(1, device=self.device),
"done": torch.zeros(1, dtype=torch.bool, device=self.device),
},
batch_size=[],
device=self.device,
)

def _set_seed(self, seed):
return seed

def test_parallel_env_mps_sub_envs_default_warns_and_runs(self):
with pytest.warns(UserWarning, match=_MPS_USE_BUFFERS_WARNING):
env = ParallelEnv(2, self._MPSObsEnv)
try:
assert env._use_buffers is False
td = env.reset()
assert td.device.type == "mps"
assert td["observation"].device.type == "mps"
policy = RandomPolicy(env.action_spec)
rollout = env.rollout(max_steps=3, policy=policy)
assert rollout.device.type == "mps"
# the worker must have seen the actions sampled in the parent
assert (rollout["next", "observation"] == rollout["action"]).all()
finally:
env.close(raise_if_closed=False)

def test_parallel_env_mps_sub_envs_use_buffers_true_raises(self):
with pytest.raises(RuntimeError, match=_MPS_USE_BUFFERS_ERROR):
ParallelEnv(2, self._MPSObsEnv, use_buffers=True)

@pytest.mark.parametrize("consolidate", [True, False])
def test_parallel_env_mps_sub_envs_no_buffers_rollout(self, consolidate):
env = ParallelEnv(
2, self._MPSObsEnv, use_buffers=False, consolidate=consolidate
)
try:
policy = RandomPolicy(env.action_spec)
rollout = env.rollout(max_steps=3, policy=policy)
assert rollout.device.type == "mps"
assert (rollout["next", "observation"] == rollout["action"]).all()
finally:
env.close(raise_if_closed=False)

def test_collector_parallel_env_mps_sub_envs(self):
# the setup reported in issue #3066
with pytest.warns(UserWarning, match=_MPS_USE_BUFFERS_WARNING):
collector = Collector(
lambda: ParallelEnv(2, self._MPSObsEnv),
frames_per_batch=4,
total_frames=8,
)
try:
for data in collector:
assert data.numel() == 4
finally:
collector.shutdown()

def test_serial_env_mps_sub_envs_buffers(self):
env = SerialEnv(2, self._MPSObsEnv)
try:
assert env._use_buffers is True
rollout = env.rollout(max_steps=3)
assert rollout.device.type == "mps"
finally:
env.close(raise_if_closed=False)
1 change: 1 addition & 0 deletions test/transforms/test_action_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2515,6 +2515,7 @@ def test_trailing_dim_enforced(self):
with pytest.raises(ValueError, match="immediately follow"):
t(TensorDict({"action": torch.randn(2, 4, 3)}, batch_size=[2, 4]))

@pytest.mark.skipif(IS_WIN, reason="windows tests do not support compile")
def test_compile_build_chunk(self):
t = ActionChunkTransform(chunk_size=3)
action = torch.randn(2, 5, 2)
Expand Down
Loading
Loading