Skip to content

Commit

Permalink
Add a decomposition and test for triu
Browse files Browse the repository at this point in the history
  • Loading branch information
protonu committed Jan 10, 2025
1 parent 19f3d5c commit bd2e5f6
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 2 deletions.
24 changes: 24 additions & 0 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -2913,6 +2913,30 @@ def tril_sample_generator(op, device, dtype, requires_grad, **kwargs):
)
conditional_and_mask_ops.append(tril_opinfo)

triu_opinfo = OpInfo(
    ltorch.triu,
    # triu takes the same (tensor, diagonal) inputs as tril, so tril's
    # sample generator is reused directly.
    sample_input_generator=tril_sample_generator,
    torch_reference=torch.triu,
    test_directives=(
        # Not all PyTorch versions support complex32 triu
        DecorateInfo(
            pytest.mark.xfail,
            "test_core_vs_torch_consistency",
            dtypes=(datatypes.complex32,),
        ),
        # PyTorch 2.0 doesn't support CUDA bfloat16 triu
        DecorateInfo(
            pytest.mark.xfail,
            "test_core_vs_torch_consistency",
            devicetypes=(devices.DeviceType.CUDA,),
            dtypes=(datatypes.bfloat16,),
            active_if=(LooseVersion(torch.__version__) < "2.1"),
        ),
    ),
)

conditional_and_mask_ops.append(triu_opinfo)

# Puts all elementwise ternary opinfos into the "opinfos" list
opinfos.extend(conditional_and_mask_ops)

Expand Down
23 changes: 23 additions & 0 deletions thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2409,6 +2409,29 @@ def tril_(a: TensorLike, /, diagonal: int = 0, *, fill_value: None | Number = No
return prims.copy_(tril(a, diagonal, fill_value=fill_value), a)


# NOTE This mirrors the tril decomposition above; only the mask inequality
# NOTE differs, selecting the upper triangle rather than the lower one.
@torchsymbol(torch.triu, is_method=True)
def triu(a: TensorLike, /, diagonal: int = 0, *, fill_value: None | Number = None) -> TensorLike:
    """Returns the upper-triangular part of ``a``.

    Elements on or above the ``diagonal``-th diagonal are kept; all others are
    replaced with ``fill_value`` (``0`` when not given). ``a`` must have at
    least two dimensions; the operation applies to the last two.
    """
    utils.check(a.ndim >= 2, lambda: f"triu: a ({a.ndim=}) must have at least two dimensions")

    num_rows, num_cols = a.shape[-2:]

    # Broadcast a (rows x 1) column of row indices against a (1 x cols) row of
    # column indices: element (i, j) belongs to the upper triangle exactly when
    # j - i >= diagonal.
    row_idx = arange(num_rows, device=a.device).unsqueeze(-1)
    col_idx = arange(num_cols, device=a.device).unsqueeze(-2)
    keep = (col_idx - row_idx) >= diagonal

    replacement = 0 if fill_value is None else fill_value
    return _mask_tensor(a, keep, replacement)


@torchsymbol(torch.Tensor.triu_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
def triu_(a: TensorLike, /, diagonal: int = 0, *, fill_value: None | Number = None) -> TensorLike:
    """In-place variant of :func:`triu`: computes the upper-triangular part of
    ``a`` and copies the result back into ``a``."""
    upper = triu(a, diagonal, fill_value=fill_value)
    return prims.copy_(upper, a)


@torchsymbol(torch.where, is_method=True)
def where(
pred: TensorLike,
Expand Down
2 changes: 0 additions & 2 deletions thunder/torch/default_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,6 @@
torch.trapz,
torch.triangular_solve,
torch.triplet_margin_loss,
torch.triu,
torch.unbind_copy,
torch.unfold_copy,
torch.unique_consecutive,
Expand Down Expand Up @@ -609,7 +608,6 @@
torch.Tensor.tolist,
torch.Tensor.trace,
torch.Tensor.triangular_solve,
torch.Tensor.triu,
torch.Tensor.unique,
torch.Tensor.unique_consecutive,
torch.Tensor.unsafe_chunk,
Expand Down

0 comments on commit bd2e5f6

Please sign in to comment.