MSELoss #435
base: master
Conversation
good performance!
if reduction == Reduction.NONE.value:
    return func(inp, target)

M = inp.numel()
suggest making inp and target contiguous
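A minimal sketch of what that suggestion could look like at the top of the Python wrapper (the surrounding function and variable names follow the quoted snippet; the exact call site is an assumption):

```python
# Hypothetical sketch: force a dense layout before the kernel launch, so the
# flat 1-D offsets computed inside the Triton kernel address the right elements.
inp = inp.contiguous()
target = target.contiguous()
```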
src/flag_gems/ops/mse_loss.py
Outdated
M = inp.numel()
dtype = inp.dtype
if dtype is torch.bool:
    inp = inp.to(torch.int64)
What about target? Does it need to be upcast to torch.int64 as well?
I found that torch's MSE loss does not support bool and int dtypes:
RuntimeError: "mse_cuda" not implemented for 'Bool'
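For reference, a minimal reproduction of that error (assumes a CUDA device is available; the message is the one quoted above):

```python
import torch
import torch.nn.functional as F

# Bool inputs are rejected by the eager CUDA kernel:
# RuntimeError: "mse_cuda" not implemented for 'Bool'
a = torch.zeros(4, dtype=torch.bool, device="cuda")
b = torch.ones(4, dtype=torch.bool, device="cuda")
try:
    F.mse_loss(a, b)
except RuntimeError as e:
    print(e)
```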
src/flag_gems/ops/mse_loss.py
Outdated
@triton.jit
def kernel_1(inp, target, mid, M, BLOCK_SIZE: tl.constexpr, reduction: tl.constexpr):
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
The dtype of the input tensor is fixed as a constant when the kernel is compiled, so it's not necessary to specify it explicitly.
src/flag_gems/ops/mse_loss.py
Outdated
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty
Assigning cdtype as tl.float32 unconditionally is simpler.
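A sketch of the simplification these comments point at, assuming a block-wise partial-sum kernel behind the quoted signature (the body below is illustrative, not the PR's actual code):

```python
import triton
import triton.language as tl


@triton.jit
def kernel_1(inp, target, mid, M, BLOCK_SIZE: tl.constexpr, reduction: tl.constexpr):
    # The input dtype is specialized at compile time, so no explicit
    # float16/bfloat16 branch is needed: just accumulate in float32.
    cdtype = tl.float32
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < M
    x = tl.load(inp + offsets, mask=mask, other=0.0).to(cdtype)
    y = tl.load(target + offsets, mask=mask, other=0.0).to(cdtype)
    diff = x - y
    # One partial sum of squared differences per program instance.
    tl.store(mid + pid, tl.sum(diff * diff, axis=0))
```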
src/flag_gems/ops/mse_loss.py
Outdated
    ):
        cdtype = tl.float32
    else:
        cdtype = mid.dtype.element_ty
ditto
src/flag_gems/ops/mse_loss.py
Outdated
mid_size = triton.cdiv(M, block_size)
block_mid = triton.next_power_of_2(mid_size)

mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
Initializing mid as torch.float32 might improve the precision.
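A sketch of that suggestion, reusing the variables from the quoted snippet:

```python
# Hypothetical sketch: keep the per-block partial results in float32 even when
# the inputs are float16/bfloat16, and cast back only after the final reduction.
mid = torch.empty((mid_size,), dtype=torch.float32, device=inp.device)
```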
tests/test_reduction_ops.py
Outdated
def test_accuracy_mse_loss(shape, dtype, reduction):
    dim = 1
    inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
    target = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
Could requires_grad be set to False?
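For a forward-only accuracy test the tensors would not need autograd tracking, e.g.:

```python
# Hypothetical sketch: requires_grad defaults to False, so it can simply be
# omitted unless the test also exercises the backward pass.
inp = torch.randn(shape, dtype=dtype, device="cuda")
target = torch.randn(shape, dtype=dtype, device="cuda")
```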
inp = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)
target = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True)

ref_inp = to_reference(inp, True)
Why upcast=True?
Since the dtype is float and the operation involves a reduction, setting upcast=True is necessary to obtain a higher-precision reference?
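The idea can be illustrated without the test helper: compute the reference in a wider dtype so that rounding error in the reference itself does not dominate the comparison (the shapes and tolerances below are illustrative, not the suite's actual values):

```python
import torch
import torch.nn.functional as F

inp = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
target = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")

# High-precision reference: upcast to float64 before the reduction.
ref = F.mse_loss(inp.to(torch.float64), target.to(torch.float64))
# Result under test, computed in the original (low-precision) dtype.
res = F.mse_loss(inp, target)

torch.testing.assert_close(res.to(torch.float64), ref, rtol=1e-3, atol=1e-3)
```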
@StrongSpoon @0x45f Hi, I don't understand why the coverage CI failed.
PR Category
Type of Change
Description
Issue
Closes: #396
Progress
Performance
On NVIDIA A10
![benchmark results](https://private-user-images.githubusercontent.com/38181615/409803972-ca9d1b7b-4a4f-4b5e-b738-43f0c621926b.png)