
[BUG] Incorrect copy behavior from 3.3.0 #1202

Closed
cloudhan opened this issue Nov 19, 2023 · 5 comments
Labels
bug (Something isn't working) · CuTe (CuTe Functionality)
Comments

@cloudhan

Describe the bug

A behavior change was introduced by commit c008b4a (CUTLASS 3.3.0). It should be regarded as a (critical) bug.

Steps/Code to reproduce bug

#include <vector>

#include <cute/pointer.hpp>
#include <cute/tensor.hpp>

using namespace cute;

__global__ void kernel(int m, int n, float* c, int ldc) {
  auto mC = make_tensor(make_gmem_ptr(c), make_layout(make_shape(m, n), make_stride(_1{}, ldc)));

  const auto CtaShape = make_shape(Int<128>{}, Int<128>{}, Int<8>{});
  const auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);

  auto ctaC = local_tile(mC, CtaShape, cta_coord, make_step(_1{}, _1{}, _));

  auto acc = make_fragment_like<float>(make_layout(make_shape(_8{}, _8{})));
  for (int i=0; i<size(acc); i++) {
    acc(i) = float(i) * 0.01;
  }

  auto threadC = local_tile(ctaC, make_tile(Int<8>{}, Int<8>{}), make_coord(0, 0));
  // NOTE: only 1 thread; this thread produces one 8x8 output block. The fringe must not be touched.
  copy(acc, threadC);
}

int main() {
  float* dev_buffer;
  cudaMalloc(&dev_buffer, sizeof(float) * 9 * 9);
  cudaMemset(dev_buffer, 0, sizeof(float) * 9 * 9);
  kernel<<<1, 1>>>(9, 9, dev_buffer, 9);
  cudaDeviceSynchronize();
  std::vector<float> host_buffer(9 * 9);  // float, not int: the buffer holds floats
  cudaMemcpy(host_buffer.data(), dev_buffer, sizeof(float) * 9 * 9, cudaMemcpyDeviceToHost);

  auto t = make_tensor(host_buffer.data(), make_layout(make_shape(9, 9)));
  print_tensor(t);
}

Expected behavior
Before the commit, the program outputs:

raw_ptr_32b(0x55e35e0d95f0) o (9,9):(_1,9):
  0.00e+00  8.00e-02  1.60e-01  2.40e-01  3.20e-01  4.00e-01  4.80e-01  5.60e-01  0.00e+00
  1.00e-02  9.00e-02  1.70e-01  2.50e-01  3.30e-01  4.10e-01  4.90e-01  5.70e-01  0.00e+00
  2.00e-02  1.00e-01  1.80e-01  2.60e-01  3.40e-01  4.20e-01  5.00e-01  5.80e-01  0.00e+00
  3.00e-02  1.10e-01  1.90e-01  2.70e-01  3.50e-01  4.30e-01  5.10e-01  5.90e-01  0.00e+00
  4.00e-02  1.20e-01  2.00e-01  2.80e-01  3.60e-01  4.40e-01  5.20e-01  6.00e-01  0.00e+00
  5.00e-02  1.30e-01  2.10e-01  2.90e-01  3.70e-01  4.50e-01  5.30e-01  6.10e-01  0.00e+00
  6.00e-02  1.40e-01  2.20e-01  3.00e-01  3.80e-01  4.60e-01  5.40e-01  6.20e-01  0.00e+00
  7.00e-02  1.50e-01  2.30e-01  3.10e-01  3.90e-01  4.70e-01  5.50e-01  6.30e-01  0.00e+00
  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00  0.00e+00

With the commit:

ptr[32b](0x56284f8183b0) o (9,9):(_1,9):
  0.00e+00  9.00e-02  1.80e-01  2.70e-01  3.60e-01  4.50e-01  5.40e-01  6.30e-01  0.00e+00
  1.00e-02  1.00e-01  1.90e-01  2.80e-01  3.70e-01  4.60e-01  5.50e-01  0.00e+00  0.00e+00
  2.00e-02  1.10e-01  2.00e-01  2.90e-01  3.80e-01  4.70e-01  5.60e-01  0.00e+00  0.00e+00
  3.00e-02  1.20e-01  2.10e-01  3.00e-01  3.90e-01  4.80e-01  5.70e-01  0.00e+00  0.00e+00
  4.00e-02  1.30e-01  2.20e-01  3.10e-01  4.00e-01  4.90e-01  5.80e-01  0.00e+00  0.00e+00
  5.00e-02  1.40e-01  2.30e-01  3.20e-01  4.10e-01  5.00e-01  5.90e-01  0.00e+00  0.00e+00
  6.00e-02  1.50e-01  2.40e-01  3.30e-01  4.20e-01  5.10e-01  6.00e-01  0.00e+00  0.00e+00
  7.00e-02  1.60e-01  2.50e-01  3.40e-01  4.30e-01  5.20e-01  6.10e-01  0.00e+00  0.00e+00
  8.00e-02  1.70e-01  2.60e-01  3.50e-01  4.40e-01  5.30e-01  6.20e-01  0.00e+00  0.00e+00

It is obvious that the old behavior is correct and should be maintained.

Environment details (please complete the following information):

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Aug_15_22:02:13_PDT_2023
Cuda compilation tools, release 12.2, V12.2.140
Build cuda_12.2.r12.2/compiler.33191640_0
@cloudhan cloudhan added ? - Needs Triage bug Something isn't working labels Nov 19, 2023
@cloudhan (Author) commented Nov 19, 2023

@ccecka I am quite sure this time it is a bug, not a feature =)

@ccecka commented Nov 19, 2023

Thanks, that does look concerning.

I suspect the Tensors acc and threadC have the same layouts between updates, correct? I also have a hunch that this works when all of those 9s are 10s (or maybe 12s), correct?

I'll run this as soon as I can and also get it into our unit tests.

As an intermediate workaround, you should be able to replace

copy(acc, threadC);

with

copy_vec<float>(acc, threadC);
// or
copy_if(TrivialPredTensor{}, acc, threadC);

to disable auto-vectorization.

@cloudhan (Author) commented:

> I suspect the Tensors acc and threadC have the same layouts between updates, correct?

Yeah, they are the same; the code snippet is extracted from a complete kernel.

> I have a hunch that this works when all of those 9s are 10s (or maybe 12s), correct?

If acc is 8x8, the copy works correctly when threadC's underlying sizes (the 9s above) are multiples of 8. In the complete kernel, if acc is 4x4, they must be multiples of 4.

@ccecka commented Nov 21, 2023

I've got a fix for this now internally. Not sure when it can be merged/released, but cute::copy should be much safer soon! Thanks for finding this.

@mnicely mnicely added this to the CUTLASS 3.3 milestone Nov 27, 2023
@mnicely mnicely added the CuTe CuTe Functionality label Dec 6, 2023
@mnicely mnicely modified the milestones: CUTLASS 3.3, CUTLASS 3.4 Dec 6, 2023
@mnicely (Collaborator) commented Dec 7, 2023

Fixed in 3.3 tagging

@mnicely mnicely closed this as completed Dec 7, 2023