From acf9fb4f15416568de02dc8085a5e451fadf9b6f Mon Sep 17 00:00:00 2001
From: Can Balioglu
Date: Fri, 7 Feb 2025 21:51:21 -0500
Subject: [PATCH] Fix PT2.6 linting issues (#1002)

---
 src/fairseq2/models/transformer/_factory.py             | 8 +++++++-
 src/fairseq2/models/wav2vec2/_position_encoder.py       | 8 +++++---
 src/fairseq2/nn/transformer/_decoder.py                 | 4 ++--
 src/fairseq2/nn/transformer/_encoder.py                 | 4 ++--
 src/fairseq2/recipes/lm/_preference_finetune/_common.py | 2 ++
 src/fairseq2/recipes/lm/_preference_finetune/_dpo.py    | 1 +
 src/fairseq2/recipes/trainer.py                         | 2 +-
 src/fairseq2/recipes/wav2vec2/asr/_train.py             | 4 ++--
 8 files changed, 22 insertions(+), 11 deletions(-)

diff --git a/src/fairseq2/models/transformer/_factory.py b/src/fairseq2/models/transformer/_factory.py
index b99ed905e..70c8cdd0b 100644
--- a/src/fairseq2/models/transformer/_factory.py
+++ b/src/fairseq2/models/transformer/_factory.py
@@ -14,6 +14,7 @@
 from fairseq2.models.transformer._model import TransformerModel
 from fairseq2.nn import (
     Embedding,
+    Linear,
     PositionEncoder,
     Projection,
     SinusoidalPositionEncoder,
@@ -163,4 +164,9 @@ def create_decoder_layer(self) -> TransformerDecoderLayer:
         )

     def create_final_proj(self, embed: Embedding) -> Projection:
-        return TiedProjection(embed.weight, bias=None)
+        config = self._config
+
+        if isinstance(embed, StandardEmbedding):
+            return TiedProjection(embed.weight, bias=None)
+
+        return Linear(config.model_dim, config.vocab_info.size, bias=False)
diff --git a/src/fairseq2/models/wav2vec2/_position_encoder.py b/src/fairseq2/models/wav2vec2/_position_encoder.py
index 13cdd7e25..c44bcc030 100644
--- a/src/fairseq2/models/wav2vec2/_position_encoder.py
+++ b/src/fairseq2/models/wav2vec2/_position_encoder.py
@@ -7,7 +7,7 @@
 from __future__ import annotations

 import warnings
-from typing import final
+from typing import cast, final
 from warnings import catch_warnings

 import torch.nn as nn
@@ -130,8 +130,10 @@ def reset_parameters(self) -> None:

         weight_norm(self, dim=2)

-        self.weight_v.requires_grad_(weight.requires_grad)
-        self.weight_g.requires_grad_(weight.requires_grad)
+        requires_grad = cast(bool, weight.requires_grad)
+
+        self.weight_v.requires_grad_(requires_grad)
+        self.weight_g.requires_grad_(requires_grad)

         if self.bias is not None:
             nn.init.constant_(self.bias, 0.0)
diff --git a/src/fairseq2/nn/transformer/_decoder.py b/src/fairseq2/nn/transformer/_decoder.py
index 1cca32c7e..9840cf2ae 100644
--- a/src/fairseq2/nn/transformer/_decoder.py
+++ b/src/fairseq2/nn/transformer/_decoder.py
@@ -9,7 +9,7 @@
 from abc import ABC, abstractmethod
 from collections import OrderedDict
 from collections.abc import Iterable, Iterator
-from typing import Protocol, final
+from typing import Protocol, cast, final

 import torch
 from torch import Generator, Tensor
@@ -193,7 +193,7 @@ def __init__(
         if not layer_list:
             raise ValueError("`layers` must be non-empty.")

-        model_dim = layer_list[0].model_dim
+        model_dim = cast(int, layer_list[0].model_dim)

         super().__init__(model_dim)

diff --git a/src/fairseq2/nn/transformer/_encoder.py b/src/fairseq2/nn/transformer/_encoder.py
index 4354fbe6a..14c6682fc 100644
--- a/src/fairseq2/nn/transformer/_encoder.py
+++ b/src/fairseq2/nn/transformer/_encoder.py
@@ -9,7 +9,7 @@
 from abc import ABC, abstractmethod
 from collections import OrderedDict
 from collections.abc import Iterable, Iterator
-from typing import Any, Protocol, final
+from typing import Any, Protocol, cast, final

 import torch
 from torch import Generator, Tensor
@@ -169,7 +169,7 @@ def __init__(
         if not layer_list:
             raise ValueError("`layers` must be non-empty.")

-        model_dim = layer_list[0].model_dim
+        model_dim = cast(int, layer_list[0].model_dim)

         super().__init__(model_dim)

diff --git a/src/fairseq2/recipes/lm/_preference_finetune/_common.py b/src/fairseq2/recipes/lm/_preference_finetune/_common.py
index d94644caa..f9771f626 100644
--- a/src/fairseq2/recipes/lm/_preference_finetune/_common.py
+++ b/src/fairseq2/recipes/lm/_preference_finetune/_common.py
@@ -19,6 +19,7 @@


 def _gather_lprobs(output: SequenceModelOutput, target: SequenceBatch) -> Tensor:
+    assert target.target_mask is not None
     logprobs = torch.log_softmax(output.logits, dim=-1)
     chosen_logps = torch.gather(logprobs, -1, target.seqs.unsqueeze(-1)).squeeze(-1)
     chosen_logps = (chosen_logps * target.target_mask).sum(dim=-1)  # [Batch, 1]
@@ -29,6 +30,7 @@ def _gather_lprobs_avg(
     output: SequenceModelOutput, target: SequenceBatch
 ) -> tuple[Tensor, Tensor]:
+    assert target.target_mask is not None
     logprobs = torch.log_softmax(output.logits, dim=-1)
     per_token_logps = torch.gather(logprobs, -1, target.seqs.unsqueeze(-1)).squeeze(-1)
     total_logps = (per_token_logps * target.target_mask).sum(dim=-1)  # [Batch, 1]

diff --git a/src/fairseq2/recipes/lm/_preference_finetune/_dpo.py b/src/fairseq2/recipes/lm/_preference_finetune/_dpo.py
index 33e37e266..6364e7814 100644
--- a/src/fairseq2/recipes/lm/_preference_finetune/_dpo.py
+++ b/src/fairseq2/recipes/lm/_preference_finetune/_dpo.py
@@ -168,6 +168,7 @@ def __call__(self, batch: PreferenceBatch) -> tuple[Tensor, int]:
     def _gather_lprobs(
         self, output: SequenceModelOutput, target: SequenceBatch
     ) -> tuple[Tensor, Tensor]:
+        assert target.target_mask is not None
         logprobs = torch.log_softmax(output.logits, dim=-1)
         per_token_logps = torch.gather(logprobs, -1, target.seqs.unsqueeze(-1)).squeeze(
             -1
diff --git a/src/fairseq2/recipes/trainer.py b/src/fairseq2/recipes/trainer.py
index 038a05191..4e2d6c46a 100644
--- a/src/fairseq2/recipes/trainer.py
+++ b/src/fairseq2/recipes/trainer.py
@@ -804,7 +804,7 @@ def _maybe_no_sync(
         self, batch_nr: int, num_batches: int
     ) -> AbstractContextManager[None]:
         if batch_nr < num_batches - 1 and self._gangs.dp.size > 1:
-            return self._model.no_sync()  # type: ignore[no-any-return]
+            return self._model.no_sync()  # type: ignore[no-any-return, operator]

         return nullcontext()

diff --git a/src/fairseq2/recipes/wav2vec2/asr/_train.py b/src/fairseq2/recipes/wav2vec2/asr/_train.py
index e47606bd8..5cb499543 100644
--- a/src/fairseq2/recipes/wav2vec2/asr/_train.py
+++ b/src/fairseq2/recipes/wav2vec2/asr/_train.py
@@ -8,7 +8,7 @@

 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import final
+from typing import cast, final

 import torch
 from torch import Tensor
@@ -399,7 +399,7 @@ def set_step_nr(self, step_nr: int) -> None:
         if isinstance(self._model, Wav2Vec2AsrModel):
             model = self._model
         else:
-            model = self._model.module  # DDP or FSDP
+            model = cast(Wav2Vec2AsrModel, self._model.module)  # DDP or FSDP

         if step_nr <= self._freeze_encoder_for_n_steps:
             if step_nr == 1: