Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MSELoss #435

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open

MSELoss #435

wants to merge 4 commits into from

Conversation

awayzjj
Copy link
Collaborator

@awayzjj awayzjj commented Jan 29, 2025

PR Category

Type of Change

Description

Issue

Closes: #396

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change is responded to an issue.
  • Change is fully covered by a UT.

Performance

On NVIDIA A10
image

@0x45f 0x45f self-assigned this Feb 7, 2025
Copy link
Collaborator

@StrongSpoon StrongSpoon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good performance!

if reduction == Reduction.NONE.value:
return func(inp, target)

M = inp.numel()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggest making inp and target contiguous

M = inp.numel()
dtype = inp.dtype
if dtype is torch.bool:
inp = inp.to(torch.int64)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about target? does it need to be upcasted unto torch.int64?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fould that torch mse do not support bool and int.

RuntimeError: "mse_cuda" not implemented for 'Bool'
RuntimeError: "mse_cuda" not implemented for 'Bool'

@triton.jit
def kernel_1(inp, target, mid, M, BLOCK_SIZE: tl.constexpr, reduction: tl.constexpr):
if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
inp.dtype.element_ty == tl.bfloat16
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dtype of input tensor will be fixed as constant when compiling. it's not necessary to specify it explicitly.

):
cdtype = tl.float32
else:
cdtype = inp.dtype.element_ty
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assigning cdtype as tl.float32 is more simple.

):
cdtype = tl.float32
else:
cdtype = mid.dtype.element_ty
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

mid_size = triton.cdiv(M, block_size)
block_mid = triton.next_power_of_2(mid_size)

mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

initializing mid as torch.float32 might improve the precision.

def test_accuracy_mse_loss(shape, dtype, reduction):
dim = 1
inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
target = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May requires_grad be set to false?

inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
target = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)

ref_inp = to_reference(inp, True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why upcast=True?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the dtype is float and the operation involves reduction, setting upcast=True is necessary to obtain a higher precision for reference ?

@awayzjj awayzjj requested review from StrongSpoon and 0x45f February 8, 2025 07:31
@awayzjj
Copy link
Collaborator Author

awayzjj commented Feb 8, 2025

@StrongSpoon @0x45f Hi, I don not understand why the converage CI failed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Code Contribution: [Hard] [Operator Development] mse_loss
3 participants