Adds binary_cross_entropy opinfo
Per title.
Pull Request resolved: pytorch#75876
Approved by: https://github.com/ngimel
Mike Ruberry authored and pytorchmergebot committed Apr 16, 2022
1 parent 452ebd0 commit 9a0d1c5
Showing 1 changed file with 73 additions and 0 deletions.
torch/testing/_internal/common_methods_invocations.py (+73, -0)
@@ -8016,6 +8016,30 @@ def sample_inputs_pixel_unshuffle(op_info, device, dtype, requires_grad, **kwargs):
for downscale_factor in (1, 3)
]

def sample_inputs_binary_cross_entropy(op_info, device, dtype, requires_grad, logits=False, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype)
make_prob = partial(make, low=0, high=1)

reductions = ("mean", "sum", "none")

shapes_and_kwargs = [
*[(shape, None) for shape in ((), (1,), (S,), (S, S), (S, S, S))],
*[((S, S), dict(reduction=reduction)) for reduction in reductions],
*[((S, S), dict(reduction=reduction, weight=make((S, S)))) for reduction in reductions],
]

if logits:
shapes_and_kwargs.extend(
[((S, S), dict(reduction=reduction, pos_weight=make((S,), low=0))) for reduction in reductions]
)

for shape, kwargs in shapes_and_kwargs:
yield SampleInput(
(make if logits else make_prob)(shape, requires_grad=requires_grad),
args=(make_prob(shape, requires_grad=requires_grad),),
kwargs=kwargs,
)

def sample_inputs_allclose(op_info, device, dtype, requires_grad, **kwargs):
samples = []
sample_shapes = [(), (S), (S, S, S)]
@@ -12640,6 +12664,55 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
DecorateInfo(unittest.skip("We don't want to differentiate wrt running mean / std"),
"TestCommon", "test_floating_inputs_are_differentiable"),),
sample_inputs_func=sample_inputs_batch_norm),
OpInfo(
"nn.functional.binary_cross_entropy",
sample_inputs_func=sample_inputs_binary_cross_entropy,
dtypes=floating_types(),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
supports_out=False,
gradcheck_fast_mode=False,
decorators=(
# RuntimeError: expected int at position 0, but got: Tensor
DecorateInfo(
unittest.skip("Skipped!"),
"TestCudaFuserOpInfo",
"test_nvfuser_correctness",
),
# RuntimeError: expected int at position 0, but got: Tensor
DecorateInfo(
unittest.skip("Skipped!"),
"TestNNCOpInfo",
"test_nnc_correctness",
),
DecorateInfo(
toleranceOverride({torch.float32: tol(atol=1e-3, rtol=1e-3)}),
"TestJit",
"test_variant_consistency_jit",
),
),
skips=(
# RuntimeError: expected int at position 0, but got: Tensor
DecorateInfo(
unittest.expectedFailure,
"TestJit",
"test_variant_consistency_jit",
),
# NotImplementedError: the derivative for 'binary_cross_entropy_backward wrt `target`' is not implemented.
DecorateInfo(
unittest.expectedFailure,
"TestGradients",
"test_fn_gradgrad",
),
# AssertionError: Found a sampled tensor of floating-point dtype torch.float32 sampled with
# requires_grad=False.
# `weight` input does not support gradient.
DecorateInfo(
unittest.expectedFailure,
"TestCommon",
"test_floating_inputs_are_differentiable",
),
),
),
    # We have to add 2 OpInfo entries for `igamma` and `igammac`. The first is the
    # standard entry, the second runs gradcheck tests on the second argument.
BinaryUfuncInfo('igamma',
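Below is a minimal, standalone sketch (not part of the commit) of how samples like those produced by sample_inputs_binary_cross_entropy can drive torch.nn.functional.binary_cross_entropy directly. The make_prob helper and the value of S here are illustrative stand-ins for the make_tensor/partial/S machinery used in common_methods_invocations.py, not the actual OpInfo test harness.

import torch
import torch.nn.functional as F

S = 5  # small test dimension, standing in for the S constant used by the OpInfo samples

def make_prob(shape, requires_grad=False):
    # Values in [0, 1), analogous to make_tensor(..., low=0, high=1) in the diff above.
    return torch.rand(shape, requires_grad=requires_grad)

cases = [
    ((), None),
    ((S, S), dict(reduction="sum")),
    ((S, S), dict(reduction="none", weight=torch.rand(S, S))),
]

for shape, kwargs in cases:
    inp = make_prob(shape, requires_grad=True)  # predicted probabilities
    target = make_prob(shape)                   # targets are sampled without requires_grad in the OpInfo above
    loss = F.binary_cross_entropy(inp, target, **(kwargs or {}))
    # reduction="none" keeps the elementwise loss; "mean"/"sum" reduce it to a scalar.
    loss.sum().backward()                       # gradients flow back to inp only
    print(shape, kwargs, tuple(loss.shape))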
