diff --git a/tests/helpers/influence/common.py b/tests/helpers/influence/common.py index 1369d96d4c..7ae1376e11 100644 --- a/tests/helpers/influence/common.py +++ b/tests/helpers/influence/common.py @@ -23,18 +23,16 @@ from torch.utils.data import DataLoader, Dataset -# pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. -def _isSorted(x, key=lambda x: x, descending=True): +def _isSorted(x, key=lambda x: x, descending=True) -> bool: if descending: - return all([key(x[i]) >= key(x[i + 1]) for i in range(len(x) - 1)]) + return all(key(x[i]) >= key(x[i + 1]) for i in range(len(x) - 1)) else: - return all([key(x[i]) <= key(x[i + 1]) for i in range(len(x) - 1)]) + return all(key(x[i]) <= key(x[i + 1]) for i in range(len(x) - 1)) -# pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. -def _wrap_model_in_dataparallel(net): +def _wrap_model_in_dataparallel(net) -> Module: alt_device_ids = [0] + [x for x in range(torch.cuda.device_count() - 1, 0, -1)] net = net.cuda() return torch.nn.DataParallel(net, device_ids=alt_device_ids) @@ -60,9 +58,7 @@ def __init__( def __len__(self) -> int: return len(self.samples) - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def __getitem__(self, idx): + def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor]: return (self.samples[idx], self.labels[idx]) @@ -83,8 +79,7 @@ def __len__(self) -> int: return len(self.samples[0]) # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def __getitem__(self, idx): + def __getitem__(self, idx: int): """ The signature of the returning item is: List[List], where the contents are: [sample_0, sample_1, ...] + [labels] (two lists concacenated). @@ -98,10 +93,8 @@ def __init__( num_features: int, use_gpu: bool = False, ) -> None: - # pyre-fixme[4]: Attribute must be annotated. - self.samples = torch.diag(torch.ones(num_features)) - # pyre-fixme[4]: Attribute must be annotated. - self.labels = torch.zeros(num_features).unsqueeze(1) + self.samples: Tensor = torch.diag(torch.ones(num_features)) + self.labels: Tensor = torch.zeros(num_features).unsqueeze(1) if use_gpu: self.samples = self.samples.cuda() self.labels = self.labels.cuda() @@ -115,14 +108,14 @@ def __init__( num_features: int, use_gpu: bool = False, ) -> None: - # pyre-fixme[4]: Attribute must be annotated. - self.samples = ( + self.samples: Tensor = ( torch.arange(start=low, end=high, dtype=torch.float) .repeat(num_features, 1) .transpose(1, 0) ) - # pyre-fixme[4]: Attribute must be annotated. - self.labels = torch.arange(start=low, end=high, dtype=torch.float).unsqueeze(1) + self.labels: Tensor = torch.arange( + start=low, end=high, dtype=torch.float + ).unsqueeze(1) if use_gpu: self.samples = self.samples.cuda() self.labels = self.labels.cuda() @@ -130,8 +123,7 @@ def __init__( class BinaryDataset(ExplicitDataset): def __init__(self, use_gpu: bool = False) -> None: - # pyre-fixme[4]: Attribute must be annotated. - self.samples = F.normalize( + self.samples: Tensor = F.normalize( torch.stack( ( torch.Tensor([1, 1]), @@ -161,8 +153,7 @@ def __init__(self, use_gpu: bool = False) -> None: ) ) ) - # pyre-fixme[4]: Attribute must be annotated. - self.labels = torch.cat( + self.labels: Tensor = torch.cat( ( torch.Tensor([1]).repeat(12, 1), torch.Tensor([-1]).repeat(12, 1), @@ -350,13 +341,10 @@ def get_random_model_and_data( tmpdir, # pyre-fixme[2]: Parameter must be annotated. unpack_inputs, - # pyre-fixme[2]: Parameter must be annotated. - return_test_data=True, + return_test_data: bool = True, gpu_setting: Optional[str] = None, - # pyre-fixme[2]: Parameter must be annotated. - return_hessian_data=False, - # pyre-fixme[2]: Parameter must be annotated. - model_type="random", + return_hessian_data: bool = False, + model_type: str = "random", ): """ returns a model, training data, and optionally data for computing the hessian @@ -534,10 +522,9 @@ def generate_symmetric_matrix_given_eigenvalues( return torch.matmul(Q, torch.matmul(torch.diag(torch.tensor(eigenvalues)), Q.T)) -# pyre-fixme[3]: Return type must be annotated. def generate_assymetric_matrix_given_eigenvalues( eigenvalues: Union[Tensor, List[float]] -): +) -> Tensor: """ following https://github.com/google-research/jax-influence/blob/74bd321156b5445bb35b9594568e4eaaec1a76a3/jax_influence/test_utils.py#L105 # noqa: E501 generate assymetric random matrix with specified eigenvalues. this is used in