Adds OpInfos for l1_loss and smooth_l1_loss
Per title
Pull Request resolved: pytorch#75877
Approved by: https://github.com/ngimel
Mike Ruberry authored and pytorchmergebot committed Apr 15, 2022
1 parent 1c0a01e commit 9d17157
Showing 1 changed file with 50 additions and 0 deletions: torch/testing/_internal/common_methods_invocations.py
@@ -8021,6 +8021,24 @@ def sample_inputs_allclose(op_info, device, dtype, requires_grad, **kwargs):

    return samples

def sample_inputs_l1_loss(op_info, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_loss(op_info, device, dtype, requires_grad, **kwargs)

    # In addition to the regular test cases, we add two for mixed floating point and complex inputs
    if dtype.is_complex:
        make = partial(make_tensor, (), device=device, requires_grad=requires_grad)
        yield SampleInput(make(dtype=dtype), args=(make(dtype=torch.double),))
        yield SampleInput(make(dtype=torch.double), args=(make(dtype=dtype),))
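
For context, a minimal sketch (not part of this diff) of what the two mixed-dtype samples exercise: nn.functional.l1_loss with one complex and one double operand, where the elementwise |input - target| is real-valued. The scalar values below are illustrative assumptions, not taken from the commit.

import torch
import torch.nn.functional as F

inp = torch.tensor(1.0 + 2.0j, dtype=torch.complex128)
tgt = torch.tensor(0.5, dtype=torch.double)

# The complex/real difference has a real-valued absolute value, so the loss is real.
loss = F.l1_loss(inp, tgt)
manual = (inp - tgt).abs()  # scalar inputs, so the default mean reduction is a no-op
print(loss, manual)         # both ~2.0616 = sqrt(0.5**2 + 2**2)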

def sample_inputs_smooth_l1_loss(op_info, device, dtype, requires_grad, **kwargs):
    yield from sample_inputs_loss(op_info, device, dtype, requires_grad, **kwargs)

    make = partial(make_tensor, (S, S), device=device, dtype=dtype, requires_grad=requires_grad)

    # This test case always triggers the smooth condition, since the absolute difference
    # between input and target is always smaller than beta
    yield SampleInput(make(low=0, high=2), args=(make(low=-2, high=0),), kwargs=dict(beta=5))
    yield SampleInput(make(), args=(make(),), kwargs=dict(beta=0))
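
As a hedged illustration of the two regimes these samples target (shapes and values below are arbitrary): smooth_l1_loss uses the quadratic branch 0.5 * (x - y)**2 / beta where |x - y| < beta, and |x - y| - 0.5 * beta elsewhere. With inputs in [0, 2) and targets in [-2, 0), the difference lies in (0, 4), so beta=5 always takes the quadratic branch, while beta=0 degenerates to plain L1.

import torch
import torch.nn.functional as F

x = torch.rand(5, 5) * 2        # in [0, 2), mirroring make(low=0, high=2)
y = torch.rand(5, 5) * 2 - 2    # in [-2, 0), so 0 < x - y < 4

# Every |x - y| < beta=5, so only the smooth (quadratic) branch is exercised.
smooth = F.smooth_l1_loss(x, y, beta=5, reduction="none")
assert torch.allclose(smooth, 0.5 * (x - y) ** 2 / 5)

# beta=0 disables the quadratic branch entirely: the op reduces to L1.
plain = F.smooth_l1_loss(x, y, beta=0, reduction="none")
assert torch.allclose(plain, (x - y).abs())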

def sample_inputs_kl_div(op_info, device, dtype, requires_grad, **kwargs):
    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -15401,6 +15419,38 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
                   DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_normal',
                                active_if=TEST_SCIPY and LooseVersion(scipy.__version__) < "1.4.0"),
               )),
OpInfo("nn.functional.smooth_l1_loss",
ref=reference_smooth_l1_loss,
sample_inputs_func=sample_inputs_smooth_l1_loss,
dtypes=floating_types_and(torch.float16, torch.bfloat16),
backward_dtypesIfCPU=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.float16),
backward_dtypesIfCUDA=floating_types_and(torch.float16),
supports_out=False,
supports_forward_ad=True,
skips=(
# RuntimeError: input->type()->kind() == TypeKind::OptionalTypeINTERNAL ASSERT FAILED
# at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, please report a bug to PyTorch.
DecorateInfo(unittest.expectedFailure, "TestJit", "test_variant_consistency_jit"),)),
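
The ref above, reference_smooth_l1_loss, is defined elsewhere in this file and not shown in this diff. A plausible NumPy sketch of the elementwise formula such a reference has to match; the name, signature, and the omitted reduction handling are assumptions here.

import numpy as np

def smooth_l1_elementwise(input, target, beta=1.0):
    diff = np.abs(input - target)
    # np.where evaluates both branches, so beta == 0 emits a divide warning
    # even though the (correct) L1 branch is selected everywhere.
    return np.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
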
    OpInfo(
        "nn.functional.l1_loss",
        ref=loss_reference_reduction_wrapper(lambda input, target: np.abs(input - target)),
        sample_inputs_func=sample_inputs_l1_loss,
        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
        backward_dtypes=all_types_and(torch.float16, torch.bfloat16),
        supports_out=False,
        supports_forward_ad=True,
        skips=(
            # RuntimeError: input->type()->kind() == TypeKind::OptionalTypeINTERNAL ASSERT FAILED
            # at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":270, please report a bug to PyTorch.
            DecorateInfo(
                unittest.expectedFailure,
                "TestJit",
                "test_variant_consistency_jit",
                dtypes=(torch.float32,),
            ),
        ),
    ),
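
loss_reference_reduction_wrapper is likewise defined outside this diff; the ref above relies on it to lift the elementwise np.abs(input - target) into a loss that honors the reduction keyword. A hedged sketch of that contract follows; the real helper may differ in details such as keyword handling.

import numpy as np

def reduction_wrapper_sketch(fn):
    def wrapper(input, target, reduction="mean"):
        result = fn(input, target)   # elementwise loss values
        if reduction == "mean":
            return result.mean()
        if reduction == "sum":
            return result.sum()
        return result                # reduction == "none"
    return wrapper

l1_ref = reduction_wrapper_sketch(lambda i, t: np.abs(i - t))
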
    UnaryUfuncInfo('lgamma',
                   ref=reference_lgamma if TEST_SCIPY else _NOTHING,
                   aliases=('special.gammaln', ),