
[RFC] Mask-aware reduction across LossModule subclasses #3866

@theap06

Background

#3695 added SliceSampler(pad_output=True), which returns sequence batches with a ("collector", "mask") boolean key that distinguishes real positions (True) from duplicated-last-step padding (False). #3850 wired the LossModule base class to honor this mask in its reduction helper and adopted it in BCLoss as the reference case.

Most other loss modules still call _reduce(loss, reduction=self.reduction) directly, which silently averages padded positions into the gradient. Each remaining loss needs a near-mechanical one-line migration to route through self._reduce_loss(loss, tensordict=tensordict) instead.

Behavior is byte-identical when no mask key is present, so this is a safe back-compatible change. Together these migrations close the loss-side gap from the sequence-RL composability work in #3695.
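To make the effect concrete, here is a minimal plain-Python sketch of the arithmetic, not the TorchRL implementation. The function reduce_loss and its list-based signature are hypothetical stand-ins for illustration only. The real helper operates on tensors and reads the mask from the tensordict.

```python
from typing import Optional, Sequence


def reduce_loss(
    loss: Sequence[float],
    mask: Optional[Sequence[bool]] = None,
    reduction: str = "mean",
) -> float:
    """Hypothetical stand-in for a mask-aware reduction (illustration only)."""
    if mask is None:
        # No mask supplied: every position counts, as in an unmasked reduction.
        kept = list(loss)
    else:
        if len(mask) != len(loss):
            raise ValueError("mask and loss must have the same length")
        # Keep only the positions flagged as real (True).
        kept = [value for value, keep in zip(loss, mask) if keep]

    if reduction == "sum":
        return sum(kept)
    if reduction == "mean":
        if not kept:
            raise ValueError("cannot take the mean of zero valid positions")
        return sum(kept) / len(kept)
    raise ValueError(f"unsupported reduction: {reduction!r}")


# Three real steps followed by two padded copies of the last step.
per_step_loss = [1.0, 2.0, 3.0, 3.0, 3.0]
mask = [True, True, True, False, False]

naive_mean = reduce_loss(per_step_loss)          # averages all 5 positions
masked_mean = reduce_loss(per_step_loss, mask)   # averages the 3 real positions
all_valid_mean = reduce_loss(per_step_loss, [True] * 5)
```

In this toy case the unmasked mean is 12 / 5 while the masked mean is 6 / 3, so the padded copies of the last step pull the result toward that step's loss. With no mask, or with a mask that is all True, the two paths agree, which is the back-compatible case described above.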

The pattern

The reference diff is PR #3850. Open it side by side with your target loss while migrating.

Metadata

Labels

enhancement: New feature or request
