
[QST] How to avoid too many resources requested #1166

Closed
YSF-A opened this issue Nov 2, 2023 · 10 comments

Comments

YSF-A commented Nov 2, 2023

What is your question?

I am trying to use cutlass::conv::device::Convolution with fixed ThreadblockShape, WarpShape, and InstructionShape, but the kernel fails with a "too many resources requested for launch" error. Changing the ThreadblockShape or WarpShape might help, but are there any other solutions? For example, __launch_bounds__ can be useful in this situation for a plain CUDA kernel.

hwu36 (Collaborator) commented Nov 2, 2023

What shapes do you want to use? We avoid requesting too many resources, since that can hurt performance.

YSF-A (Author) commented Nov 3, 2023

@hwu36 Hi, the ThreadblockShape is [128, 128, 32], the WarpShape is [32, 32, 32], and the InstructionShape is [8, 8, 4].

hwu36 (Collaborator) commented Nov 3, 2023

What is data type and architecture?

YSF-A (Author) commented Nov 3, 2023

> What is data type and architecture?

The data type is float16, with float32 internal accumulation. The GPU is a 2080 Ti.

hwu36 (Collaborator) commented Nov 3, 2023

Your threadblock size is 128x128 and your warp size is 32x32, so you need (128/32) x (128/32) = 16 warps. We usually use 4 or 8 warps, so you'd better use a 64x64 warp size. If I am not wrong, the 2080 Ti is a Turing card, so it is better to use the 16x8x8 instruction shape.
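The warp-count arithmetic above can be sanity-checked with a short script (the 4- and 8-warp targets are the heuristics mentioned in the comment, not hard hardware limits):

```python
def warps_per_threadblock(tb_m, tb_n, warp_m, warp_n):
    """Warps per threadblock = (threadblock M / warp M) * (threadblock N / warp N)."""
    assert tb_m % warp_m == 0 and tb_n % warp_n == 0, "warp tile must divide threadblock tile"
    return (tb_m // warp_m) * (tb_n // warp_n)

# Original config: 128x128 threadblock with 32x32 warps -> 16 warps (too many).
print(warps_per_threadblock(128, 128, 32, 32))  # 16
# Suggested config: 128x128 threadblock with 64x64 warps -> 4 warps.
print(warps_per_threadblock(128, 128, 64, 64))  # 4
```

With 16 warps of 32 threads each, the threadblock launches 512 threads, which multiplies per-thread register and shared-memory demand and triggers the resource error; the 64x64 warp shape brings it down to 128 threads.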

So you could use threadblock shape 128x128x32, warp shape 64x64x32, and instruction shape 16x8x8. You can find more plausible tile sizes in our profiler: https://github.com/NVIDIA/cutlass/blob/main/python/cutlass_library/generator.py#L1375-L1383

wxthu commented Nov 13, 2023

How can I convert these tile descriptions to TVM tile-size shapes? Thanks.

hwu36 (Collaborator) commented Nov 13, 2023

> how can I convert these tile descriptions to TVM tile size shape? thanks

@masahi

masahi (Contributor) commented Nov 13, 2023

What do you mean by "TVM tile size shape"?


This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.


This issue has been labeled inactive-90d due to no recent activity in the past 90 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.
