Skip to content

Fix LayerNorm crash when model.half() is used#2729

Open
OWU-4f5755 wants to merge 1 commit into
openai:mainfrom
OWU-4f5755:fix/layernorm-dtype-defense
Open

Fix LayerNorm crash when model.half() is used#2729
OWU-4f5755 wants to merge 1 commit into
openai:mainfrom
OWU-4f5755:fix/layernorm-dtype-defense

Conversation

@OWU-4f5755

Copy link
Copy Markdown

LayerNorm.forward() casts the input to fp32 but doesn't cast its own weight/bias, so calling model.half() before transcription causes:

RuntimeError: expected scalar type Float but found Half

Linear and Conv1d in the same file already guard against this by casting their weights to match the input dtype. This PR does the same for LayerNorm.

# Before — weight/bias stay fp16 after model.half()
class LayerNorm(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:
        return super().forward(x.float()).type(x.dtype)

# After — explicitly cast weight/bias to fp32
class LayerNorm(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:
        return F.layer_norm(
            x.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        ).type(x.dtype)

Observed no overhead in the normal case (.float() on an fp32 tensor is a no-op). Tested with model.half() + fp16=True and the standard path — both work.

…as to fp32 in LayerNorm.forward(), matching the pattern already used in Linear and Conv1d. Thus, RuntimeError is prevented ('expected scalar type Float but found Half') when you call model.half() prior to transcription. Tested: model.half() + fp16=True transcription works. Standard path no.half() also works and isn't affected.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant