Adds margin_ranking_loss opinfo
per title
Pull Request resolved: pytorch#75887
Approved by: https://github.com/ngimel
Mike Ruberry authored and pytorchmergebot committed Apr 15, 2022
1 parent 5dcbcc6 commit 2da43ec
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -2799,6 +2799,24 @@ def sample_inputs_randint_like(self, device, dtype, requires_grad, **kwargs):
            kwargs=sample.kwargs))
    return tuple(samples)

# TODO: add reduction kwargs
def sample_inputs_margin_ranking_loss(op_info, device, dtype, requires_grad, **kwargs):
    _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    shapes = (
        (),
        (S,),
        (S, S),
        (S, S, S),
    )

    for shape in shapes:
        for kwargs in [{}, {'margin': 1.0}]:
            yield SampleInput(_make_tensor(shape),
                              args=(_make_tensor(shape, requires_grad=False),
                                    _make_tensor(shape, requires_grad=False)),
                              kwargs=kwargs)

def sample_inputs_new_fns(self, device, dtype, requires_grad, **kwargs):
    inputs = [
        ((), (), {}),
@@ -11899,6 +11917,19 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
            DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
        ),
        supports_out=False),
    OpInfo(
        "nn.functional.margin_ranking_loss",
        ref=_NOTHING,
        dtypes=all_types_and(torch.bfloat16),
        dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16),
        supports_out=False,
        sample_inputs_func=sample_inputs_margin_ranking_loss,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        skips=(
            # target doesn't require grad
            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_floating_inputs_are_differentiable'),
        )),
    OpInfo(
        "nn.functional.multi_margin_loss",
        ref=_NOTHING,
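For context, a minimal sketch (not part of the diff) of the call pattern the new sample inputs exercise. torch.nn.functional.margin_ranking_loss takes two input tensors and a target of matching shape, plus an optional margin; only the inputs are expected to require grad, which is why the sample inputs build the target tensors with requires_grad=False and the OpInfo skips test_floating_inputs_are_differentiable. Variable names below are illustrative.

import torch
import torch.nn.functional as F

input1 = torch.randn(5, requires_grad=True)
input2 = torch.randn(5, requires_grad=True)
target = torch.sign(torch.randn(5))  # +1/-1 labels; the target does not require grad

# Mirror the two kwargs variants yielded by sample_inputs_margin_ranking_loss.
loss_default = F.margin_ranking_loss(input1, input2, target)                 # margin defaults to 0.0
loss_with_margin = F.margin_ranking_loss(input1, input2, target, margin=1.0)

loss_with_margin.backward()  # gradients flow to input1 and input2 only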
