From bed0f68cc3837e31d1966b7b32373f6235a8bb34 Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Sun, 29 Dec 2024 17:38:28 -0800 Subject: [PATCH] Fix custom modules pyre fix me issues Differential Revision: D67706756 --- captum/attr/_utils/custom_modules.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/captum/attr/_utils/custom_modules.py b/captum/attr/_utils/custom_modules.py index a666cfce6a..6593bc33c8 100644 --- a/captum/attr/_utils/custom_modules.py +++ b/captum/attr/_utils/custom_modules.py @@ -2,6 +2,7 @@ # pyre-strict import torch.nn as nn +from torch import Tensor class Addition_Module(nn.Module): @@ -12,7 +13,5 @@ class Addition_Module(nn.Module): def __init__(self) -> None: super().__init__() - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def forward(self, x1, x2): + def forward(self, x1: Tensor, x2: Tensor) -> Tensor: return x1 + x2