
Fix PT2.6 linting issues (#1002)
cbalioglu authored Feb 8, 2025
1 parent 89104df commit acf9fb4
Showing 8 changed files with 22 additions and 11 deletions.
8 changes: 7 additions & 1 deletion src/fairseq2/models/transformer/_factory.py
@@ -14,6 +14,7 @@
 from fairseq2.models.transformer._model import TransformerModel
 from fairseq2.nn import (
     Embedding,
+    Linear,
     PositionEncoder,
     Projection,
     SinusoidalPositionEncoder,
@@ -163,4 +164,9 @@ def create_decoder_layer(self) -> TransformerDecoderLayer:
         )

     def create_final_proj(self, embed: Embedding) -> Projection:
-        return TiedProjection(embed.weight, bias=None)
+        config = self._config
+
+        if isinstance(embed, StandardEmbedding):
+            return TiedProjection(embed.weight, bias=None)
+
+        return Linear(config.model_dim, config.vocab_info.size, bias=False)
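
The new `create_final_proj` ties the output projection to the embedding weight only when the embedding is a `StandardEmbedding`, and otherwise builds an untied linear layer. A minimal standalone sketch of the same pattern, using plain `torch.nn` modules instead of the fairseq2 `TiedProjection`/`Linear` classes, might look like this:

from torch import nn


def create_final_proj(embed: nn.Module, model_dim: int, vocab_size: int) -> nn.Module:
    # Tie the output projection to the embedding weight when the embedding
    # exposes a plain weight matrix; otherwise fall back to an untied layer.
    if isinstance(embed, nn.Embedding):
        proj = nn.Linear(model_dim, vocab_size, bias=False)
        proj.weight = embed.weight  # shares the parameter (weight tying)
        return proj

    return nn.Linear(model_dim, vocab_size, bias=False)


embed = nn.Embedding(num_embeddings=1000, embedding_dim=64)
proj = create_final_proj(embed, model_dim=64, vocab_size=1000)
assert proj.weight is embed.weight
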
8 changes: 5 additions & 3 deletions src/fairseq2/models/wav2vec2/_position_encoder.py
@@ -7,7 +7,7 @@
 from __future__ import annotations

 import warnings
-from typing import final
+from typing import cast, final
 from warnings import catch_warnings

 import torch.nn as nn
@@ -130,8 +130,10 @@ def reset_parameters(self) -> None:

         weight_norm(self, dim=2)

-        self.weight_v.requires_grad_(weight.requires_grad)
-        self.weight_g.requires_grad_(weight.requires_grad)
+        requires_grad = cast(bool, weight.requires_grad)
+
+        self.weight_v.requires_grad_(requires_grad)
+        self.weight_g.requires_grad_(requires_grad)

         if self.bias is not None:
             nn.init.constant_(self.bias, 0.0)
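
The `cast(bool, ...)` added above is purely a type-checking aid with no runtime effect. A small sketch of the pattern, outside the wav2vec2 position encoder and using throwaway tensors, could look like this:

from typing import cast

import torch

weight = torch.empty(3, 3, requires_grad=True)
other = torch.empty(3, 3)

# cast() does nothing at runtime; it only tells the type checker to treat
# the attribute as a plain bool before it is passed to requires_grad_().
requires_grad = cast(bool, weight.requires_grad)

other.requires_grad_(requires_grad)
assert other.requires_grad
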
4 changes: 2 additions & 2 deletions src/fairseq2/nn/transformer/_decoder.py
@@ -9,7 +9,7 @@
 from abc import ABC, abstractmethod
 from collections import OrderedDict
 from collections.abc import Iterable, Iterator
-from typing import Protocol, final
+from typing import Protocol, cast, final

 import torch
 from torch import Generator, Tensor
@@ -193,7 +193,7 @@ def __init__(
         if not layer_list:
             raise ValueError("`layers` must be non-empty.")

-        model_dim = layer_list[0].model_dim
+        model_dim = cast(int, layer_list[0].model_dim)

         super().__init__(model_dim)

4 changes: 2 additions & 2 deletions src/fairseq2/nn/transformer/_encoder.py
@@ -9,7 +9,7 @@
 from abc import ABC, abstractmethod
 from collections import OrderedDict
 from collections.abc import Iterable, Iterator
-from typing import Any, Protocol, final
+from typing import Any, Protocol, cast, final

 import torch
 from torch import Generator, Tensor
@@ -169,7 +169,7 @@ def __init__(
         if not layer_list:
             raise ValueError("`layers` must be non-empty.")

-        model_dim = layer_list[0].model_dim
+        model_dim = cast(int, layer_list[0].model_dim)

         super().__init__(model_dim)

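
The decoder and encoder hunks above apply the same narrowing: `layer_list[0].model_dim` is no longer inferred as an `int` by the type checker, so the value is wrapped in `cast(int, ...)`. A rough, self-contained illustration of the idiom with a stand-in layer class (not the real `TransformerDecoderLayer`/`TransformerEncoderLayer`) follows:

from typing import Any, cast


class FakeLayer:
    # Deliberately typed as Any to mimic what the type checker sees.
    model_dim: Any = 512


layer_list = [FakeLayer(), FakeLayer()]

if not layer_list:
    raise ValueError("`layers` must be non-empty.")

# cast() has no runtime cost; it only narrows the static type to int so the
# value can be passed on to code that expects an int.
model_dim = cast(int, layer_list[0].model_dim)

print(model_dim)  # 512
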
2 changes: 2 additions & 0 deletions src/fairseq2/recipes/lm/_preference_finetune/_common.py
@@ -19,6 +19,7 @@


 def _gather_lprobs(output: SequenceModelOutput, target: SequenceBatch) -> Tensor:
+    assert target.target_mask is not None
     logprobs = torch.log_softmax(output.logits, dim=-1)
     chosen_logps = torch.gather(logprobs, -1, target.seqs.unsqueeze(-1)).squeeze(-1)
     chosen_logps = (chosen_logps * target.target_mask).sum(dim=-1)  # [Batch, 1]
@@ -29,6 +30,7 @@ def _gather_lprobs(output: SequenceModelOutput, target: SequenceBatch) -> Tensor
 def _gather_lprobs_avg(
     output: SequenceModelOutput, target: SequenceBatch
 ) -> tuple[Tensor, Tensor]:
+    assert target.target_mask is not None
     logprobs = torch.log_softmax(output.logits, dim=-1)
     per_token_logps = torch.gather(logprobs, -1, target.seqs.unsqueeze(-1)).squeeze(-1)
     total_logps = (per_token_logps * target.target_mask).sum(dim=-1)  # [Batch, 1]
1 change: 1 addition & 0 deletions src/fairseq2/recipes/lm/_preference_finetune/_dpo.py
@@ -168,6 +168,7 @@ def __call__(self, batch: PreferenceBatch) -> tuple[Tensor, int]:
     def _gather_lprobs(
         self, output: SequenceModelOutput, target: SequenceBatch
     ) -> tuple[Tensor, Tensor]:
+        assert target.target_mask is not None
         logprobs = torch.log_softmax(output.logits, dim=-1)
         per_token_logps = torch.gather(logprobs, -1, target.seqs.unsqueeze(-1)).squeeze(
             -1
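
Both preference-finetuning hunks add an `assert target.target_mask is not None` before the mask is used in arithmetic, which narrows the optional mask to a `Tensor` for the type checker and fails fast if it is genuinely missing. A minimal sketch of that narrowing pattern with generic tensors (not the fairseq2 `SequenceBatch` type):

from typing import Optional

import torch
from torch import Tensor


def masked_logp_sum(per_token_logps: Tensor, mask: Optional[Tensor]) -> Tensor:
    # The assert narrows Optional[Tensor] to Tensor for the type checker and
    # raises immediately at runtime if no mask was provided.
    assert mask is not None

    return (per_token_logps * mask).sum(dim=-1)


logps = torch.randn(2, 4)
mask = torch.ones(2, 4)
print(masked_logp_sum(logps, mask))  # shape: [2]
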
2 changes: 1 addition & 1 deletion src/fairseq2/recipes/trainer.py
@@ -804,7 +804,7 @@ def _maybe_no_sync(
         self, batch_nr: int, num_batches: int
     ) -> AbstractContextManager[None]:
         if batch_nr < num_batches - 1 and self._gangs.dp.size > 1:
-            return self._model.no_sync()  # type: ignore[no-any-return]
+            return self._model.no_sync()  # type: ignore[no-any-return, operator]

         return nullcontext()

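
The trainer keeps returning `no_sync()` while micro-batches are being accumulated; the change only adds `operator` to the existing `type: ignore` comment, so runtime behaviour is unchanged. A sketch of the underlying gradient-accumulation pattern with a plain DDP-wrapped model (not the fairseq2 trainer API) might be:

from contextlib import AbstractContextManager, nullcontext

from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


def maybe_no_sync(
    model: nn.Module, batch_nr: int, num_batches: int, world_size: int
) -> AbstractContextManager[None]:
    # Skip gradient synchronization on all but the last micro-batch so the
    # all-reduce happens only once per accumulated optimizer step.
    if batch_nr < num_batches - 1 and world_size > 1 and isinstance(model, DDP):
        return model.no_sync()

    return nullcontext()
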
4 changes: 2 additions & 2 deletions src/fairseq2/recipes/wav2vec2/asr/_train.py
@@ -8,7 +8,7 @@

 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import final
+from typing import cast, final

 import torch
 from torch import Tensor
@@ -399,7 +399,7 @@ def set_step_nr(self, step_nr: int) -> None:
         if isinstance(self._model, Wav2Vec2AsrModel):
             model = self._model
         else:
-            model = self._model.module  # DDP or FSDP
+            model = cast(Wav2Vec2AsrModel, self._model.module)  # DDP or FSDP

         if step_nr <= self._freeze_encoder_for_n_steps:
             if step_nr == 1:
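
The last hunk casts `self._model.module` back to `Wav2Vec2AsrModel` because the DDP/FSDP wrappers only expose the wrapped module as a generic `nn.Module`. A rough sketch of that unwrapping pattern with a placeholder model class (not the real wav2vec2 ASR model):

from typing import cast

from torch import nn


class AsrModel(nn.Module):
    def freeze_encoder(self) -> None:
        ...  # placeholder for model-specific behaviour


def unwrap(model: nn.Module) -> AsrModel:
    # DDP (and FSDP) keep the original module under `.module`, but its static
    # type is just nn.Module; cast() restores the concrete model type.
    if isinstance(model, AsrModel):
        return model

    return cast(AsrModel, model.module)  # model is a DDP or FSDP wrapper
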
