
fix race in async copy #3438

Open · wants to merge 7 commits into main
Conversation

@liqiangxl (Collaborator) commented Nov 18, 2024

Fix #3428

What's in this PR?
Revise ReadAfterWriteSyncs so that:

  • cpasync_wait_before_ only handles cp.async ops that don't need thread syncs.
  • sync_before_ handles regular syncs and cp.async ops that do require thread syncs.

Why?
Before this fix, cpasync_wait_before_ also handled cp.async ops that need thread syncs. That could place cp.async.wait_all after __syncthreads(), which leads to the race condition seen in #3428. See the sketch below.
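A minimal standalone sketch of the fixed ordering (a hypothetical kernel written for illustration, not nvFuser codegen output; assumes sm_80+, blockDim = (32, 2), and the parameter names in/out):

```
#include <cstdint>

// Only threads with threadIdx.y == 0 issue the async copy into shared memory,
// but every thread in the block reads it afterwards. The cp.async.wait_all must
// come before __syncthreads(): the wait completes this thread's copies, and the
// barrier then publishes the data to the rest of the block.
__global__ void predicated_producer(const float* in, float* out) {
  __shared__ float T3[32];
  if (threadIdx.y == 0) {
    uint32_t dst = static_cast<uint32_t>(__cvta_generic_to_shared(&T3[threadIdx.x]));
    asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n"
                 :
                 : "r"(dst), "l"(in + threadIdx.x), "n"(4));
  }
  asm volatile("cp.async.wait_all;\n" ::: "memory");  // complete this thread's copies
  __syncthreads();                                    // then make them visible block-wide
  out[threadIdx.y * 32 + threadIdx.x] = T3[threadIdx.x];
}
```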

@liqiangxl (Collaborator, Author)

!test

@liqiangxl (Collaborator, Author)

!test --diff

@liqiangxl marked this pull request as ready for review November 18, 2024 21:05
@liqiangxl (Collaborator, Author)

All code diffs are due to the newly added cp.async.wait_all before __syncthreads().

@liqiangxl (Collaborator, Author)

All H100 failures are due to CI script/hardware issues; pending a rerun.

@naoyam (Collaborator) commented Nov 19, 2024

Can you please show what the typical code pattern looks like before and after this?

@liqiangxl (Collaborator, Author)

> Can you please show what the typical code pattern looks like before and after this?

The code change is simple: it just adds a cp.async.wait_all before __syncthreads() when the __syncthreads() is needed because of an async copy.


Typical code is:

__shared__ float T2[blockDim.y * blockDim.x];
__shared__ float T3[blockDim.x];

// (1) Perform an async copy to shared memory T3, only for threads with threadIdx.y == 0
if (threadIdx.y == 0) {
  cp.async from gmem to T3
}

// Only ensures that all threads have reached this sync; it doesn't mean the async copy to T3 has completed.
/////////////// need a `cp.async.wait_all` before this __syncthreads /////////////////
__syncthreads();

for (int i11 = 0; i11 < 2; ++i11) {
  // (2) Perform an async copy to shared memory T2, for all threads
  cp.async from gmem to T2

  // (3) Wait for all async copies to complete.
  // For threads not participating in a copy, there is no need to wait (my guess).
  // For example, threads with threadIdx.y != 0 don't participate in the async copy to T3, so they don't need to wait for the copy to T3 to finish.
  asm volatile("cp.async.wait_all;\n");

  // (4) Read from T3 and T2
  float T4[1LL];
  T4[0] = T3[threadIdx.x];  // Potential race here
  T5[0] = T2[i7] + T4[0];
}

@liqiangxl (Collaborator, Author)

Ref:
The mandatory .async qualifier indicates that the cp instruction will initiate the memory copy operation asynchronously and control will return to the executing thread before the copy operation is complete. The executing thread can then use cp.async.wait_all or cp.async.wait_group or mbarrier instructions to wait for completion of the asynchronous copy operation.

https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#instruction-set:~:text=The%20executing%20thread%20can%20then%20use%20cp.async.wait_all%20or%20cp.async.wait_group%20or%20mbarrier%20instructions%20to%20wait%20for%20completion%20of%20the%20asynchronous%20copy%20operation.
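For completeness, a hedged sketch of the cp.async.commit_group / cp.async.wait_group variant mentioned in that passage (a hypothetical double-buffering kernel, assuming sm_80+, blockDim.x == 128, and at least 256 elements in `in`; the kernels in this PR use cp.async.wait_all instead):

```
#include <cstdint>

__global__ void wait_group_example(const float* in, float* out) {
  __shared__ float buf[2][128];
  // Batch 0: async-copy the first tile and close it off as a commit group.
  uint32_t dst0 = static_cast<uint32_t>(__cvta_generic_to_shared(&buf[0][threadIdx.x]));
  asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n"
               :: "r"(dst0), "l"(in + threadIdx.x), "n"(4));
  asm volatile("cp.async.commit_group;\n");
  // Batch 1: async-copy the second tile into the other buffer as its own group.
  uint32_t dst1 = static_cast<uint32_t>(__cvta_generic_to_shared(&buf[1][threadIdx.x]));
  asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n"
               :: "r"(dst1), "l"(in + 128 + threadIdx.x), "n"(4));
  asm volatile("cp.async.commit_group;\n");
  // Wait until at most one group is still pending, i.e. batch 0 has landed,
  // while batch 1 may still be in flight.
  asm volatile("cp.async.wait_group 1;\n" ::: "memory");
  __syncthreads();
  out[threadIdx.x] = buf[0][threadIdx.x];
}
```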

@naoyam (Collaborator) commented Nov 19, 2024

Why did we insert the wait instruction only for T2 but not for T3?

@liqiangxl (Collaborator, Author) commented Nov 19, 2024

> Why did we insert the wait instruction only for T2 but not for T3?

The original wait is for both T2 & T3, because it waits for the completion of all previously issued load instructions.
However, when one of the load instructions needs a __syncthreads(), we need to put the wait before that sync; the load to T2 then needs an additional wait before we can read from it.

Originally, we insert wait before the first read of T3, it waits for both T2 & T3.

cp.async from gmem to T3
syncthreads()
cp.async from gmem to T2

// Here we detect a read from T3 and insert a wait; the wait ensures the loads of both T3 & T2 are complete,
// so there is no need to insert a wait separately for T2 and T3.
read from T3

After this PR, we insert wait before syncthreads(), then T2 needs an additional wait.

cp.async from gmem to T3
wait T3
syncthreads()

cp.async from gmem to T2

read from T3
wait T2
read from T2

Here is the full kernel for the newly added test NVFuserTestCpAsyncRace.isInlinedhasTIDy/inlined_false_hastidy_true

__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 2, 2> T0, Tensor<float, 1, 1> T1, Tensor<float, 2, 2> T5) {
  alignas(16) extern __shared__ char array[];
  const unsigned smem_offset = 0;
  nvfuser_index_t i0;
  i0 = ceilDiv((ceilDiv(T0.logical_size[0LL], 2LL)), 5LL);
  nvfuser_index_t i1;
  i1 = 4LL * ((nvfuser_index_t)threadIdx.x);
  nvfuser_index_t i2;
  i2 = T0.logical_size[1LL] * ((nvfuser_index_t)threadIdx.y);
  nvfuser_index_t i3;
  i3 = (2LL * T0.logical_size[1LL]) * ((nvfuser_index_t)blockIdx.y);
  float* ptr4;
  ptr4 = ((T0.data + ((nvfuser_index_t)threadIdx.x)) + i2) + i3;
  nvfuser_index_t i5;
  i5 = 10LL * T0.logical_size[1LL];
  float* T3 = reinterpret_cast<float*>(array + smem_offset + ((((2LL * T0.logical_size[1LL]) * 4LL) + 15LL) & -16LL));
  float* T2 = reinterpret_cast<float*>(array + smem_offset + 0LL);
  unsigned i6;
  i6 = (toSmem(T2) + i1) + ((4LL * T0.logical_size[1LL]) * ((nvfuser_index_t)threadIdx.y));
  nvfuser_index_t i7;
  i7 = ((nvfuser_index_t)threadIdx.x) + i2;
  nvfuser_index_t i8;
  i8 = i7 + i3;
  bool b9;
  b9 = ((nvfuser_index_t)threadIdx.y) == 0LL;
  nvfuser_index_t i10;
  i10 = ((nvfuser_index_t)threadIdx.y) + (2LL * ((nvfuser_index_t)blockIdx.y));
  if (b9) {
    asm volatile(
      "{\n"
      "  .reg .pred p0; \n"
      "  setp.ne.b32 p0, %3, 0;\n"
      "  cp.async.ca.shared.global [%0], [%1], %2, p0;\n"
      "}\n"
      :
      :"r"((uint32_t)((toSmem(T3) + i1))),
       "l"((T1.data + ((nvfuser_index_t)threadIdx.x))),
       "n"(4LL),
       "r"((uint32_t)((!b9)))
    );
  }
  asm volatile("cp.async.wait_all;\n");
  __syncthreads();
  #pragma unroll 1
  for(nvfuser_index_t i11 = 0LL; i11 < i0; ++i11) {
    nvfuser_index_t i12;
    i12 = i5 * i11;
    if (((i10 + (10LL * i11)) < T0.logical_size[0LL])) {
      asm volatile(
        "{\n"
        "  .reg .pred p0; \n"
        "  setp.ne.b32 p0, %3, 0;\n"
        "  cp.async.ca.shared.global [%0], [%1], %2, p0;\n"
        "}\n"
        :
        :"r"((uint32_t)(i6)),
         "l"((ptr4 + i12)),
         "n"(4LL),
         "n"((uint32_t)(false))
      );
      float T4[1LL];
      T4[0LL]
         = T3[((nvfuser_index_t)threadIdx.x)];
      asm volatile("cp.async.wait_all;\n");
      T5[(i8 + i12)]
        = T2[i7]
        + T4[0LL];
    }
  }
}

@liqiangxl (Collaborator, Author)

For the newly added case that doesn't require a thread predicate, NVFuserTestCpAsyncRace.isInlinedhasTIDy/inlined_false_hastidy_false, we only need one wait, and it waits for both T2 & T3.
The kernel is:

__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 2, 2> T0, Tensor<float, 1, 1> T1, Tensor<float, 2, 2> T5) {
  alignas(16) extern __shared__ char array[];
  const unsigned smem_offset = 0;
  nvfuser_index_t i0;
  i0 = ceilDiv(T0.logical_size[0LL], 5LL);
  nvfuser_index_t i1;
  i1 = 4LL * ((nvfuser_index_t)threadIdx.x);
  nvfuser_index_t i2;
  i2 = T0.logical_size[1LL] * ((nvfuser_index_t)blockIdx.y);
  float* ptr3;
  ptr3 = (T0.data + ((nvfuser_index_t)threadIdx.x)) + i2;
  nvfuser_index_t i4;
  i4 = 5LL * T0.logical_size[1LL];
  float* T3 = reinterpret_cast<float*>(array + smem_offset + (((T0.logical_size[1LL] * 4LL) + 15LL) & -16LL));
  float* T2 = reinterpret_cast<float*>(array + smem_offset + 0LL);
  unsigned i5;
  i5 = toSmem(T2) + i1;
  nvfuser_index_t i6;
  i6 = ((nvfuser_index_t)threadIdx.x) + i2;
  asm volatile(
    "{\n"
    "  .reg .pred p0; \n"
    "  setp.ne.b32 p0, %3, 0;\n"
    "  cp.async.ca.shared.global [%0], [%1], %2, p0;\n"
    "}\n"
    :
    :"r"((uint32_t)((toSmem(T3) + i1))),
     "l"((T1.data + ((nvfuser_index_t)threadIdx.x))),
     "n"(4LL),
     "n"((uint32_t)(false))
  );
  #pragma unroll 1
  for(nvfuser_index_t i7 = 0LL; i7 < i0; ++i7) {
    nvfuser_index_t i8;
    i8 = i4 * i7;
    if (((((nvfuser_index_t)blockIdx.y) + (5LL * i7)) < T0.logical_size[0LL])) {
      asm volatile(
        "{\n"
        "  .reg .pred p0; \n"
        "  setp.ne.b32 p0, %3, 0;\n"
        "  cp.async.ca.shared.global [%0], [%1], %2, p0;\n"
        "}\n"
        :
        :"r"((uint32_t)(i5)),
         "l"((ptr3 + i8)),
         "n"(4LL),
         "n"((uint32_t)(false))
      );
      float T4[1LL];
      asm volatile("cp.async.wait_all;\n");
      T4[0LL]
         = T3[((nvfuser_index_t)threadIdx.x)];
      T5[(i6 + i8)]
        = T2[((nvfuser_index_t)threadIdx.x)]
        + T4[0LL];
    }
  }
}

liqiangxl added a commit that referenced this pull request Nov 20, 2024
When total_reduction_numel <= 1024, the scheduler may use multiple
reductions per block with bdimy > 1, which leads to a race condition in
shared memory when using async copy. Adding `cp.async.wait_all` after the
1st async copy can avoid the race, but we need to figure out the root
cause before we can safely use it. So, here we set bdimy = 1 as a WAR.
Should be reverted after #3438 is merged.

race detected with:
```
NVFUSER_DUMP=scheduler_params,cuda_to_file NVFUSER_ENABLE=kernel_debug PYTORCH_NO_CUDA_MEMORY_CACHING=1 compute-sanitizer --tool racecheck --racecheck-detect-level info  ./nvfuser_tests --gtest_filter='CombinedSchedulerTest.LayerNormBackward/dtype_double_batch_216_hidden_96'
```
Priya2698 pushed a commit that referenced this pull request Nov 20, 2024 (same commit message as above).
@zasdfgbnm (Collaborator)

Why is the thread predicate relevant here? I have a feeling that this PR "fixes" the issue only because, for this specific example, the problematic schedule happens to have a thread predicate.

For example, assume there is a fusion where blockDim.x is 2. During the cp.async load, threadIdx.x = 0 writes T1[1] and threadIdx.x = 1 writes T1[0]. But when T1 is being used, threadIdx.x = 0 loads T1[0] and threadIdx.x = 1 loads T1[1]. In this example, our codegen will also generate a cp.async.wait_all after the __syncthreads(), and we will have the same problem, yet because both threads are active for the cp.async load, there would be no thread predicate.
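A minimal sketch of that scenario (hypothetical kernel, not nvFuser output; assumes sm_80+ and blockDim.x == 2): each thread async-copies into the other thread's slot and then reads its own slot, so the wait still has to precede the barrier even though no thread predicate is involved.

```
#include <cstdint>

__global__ void swizzled_cp_async(const float* in, float* out) {
  __shared__ float T1[2];
  // threadIdx.x == 0 writes T1[1]; threadIdx.x == 1 writes T1[0].
  uint32_t dst = static_cast<uint32_t>(__cvta_generic_to_shared(&T1[1 - threadIdx.x]));
  asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n"
               :: "r"(dst), "l"(in + threadIdx.x), "n"(4));
  // The wait only orders copies issued by *this* thread; __syncthreads() then
  // publishes them to the other thread. Placing the wait after the barrier
  // would let the read race with the still-in-flight copy.
  asm volatile("cp.async.wait_all;\n" ::: "memory");
  __syncthreads();
  out[threadIdx.x] = T1[threadIdx.x];  // reads the value written by the other thread
}
```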

@liqiangxl (Collaborator, Author)

> Why is the thread predicate relevant here? I have a feeling that this PR "fixes" the issue only because, for this specific example, the problematic schedule happens to have a thread predicate.

More precisely, "thread predicate" should read "thread syncs": with this PR, any asynchronous copy that requires thread synchronization will have a cp.async.wait_all inserted before the __syncthreads().
For example, the following fusion

  int m = 3, n = 2;

  TensorView* tv0 = makeContigTensor(2);
  fusion.addInput(tv0);

  auto tv1 = set(tv0);
  tv1->setMemoryType(MemoryType::Shared);
  tv1->definition()->as<LoadStoreOp>()->setOpType(LoadStoreOpType::CpAsync);
  tv1->definition()->as<LoadStoreOp>()->setCacheOp(CacheOp::Unspecified);

  auto tv2 = permute(tv1, {1, 0});
  auto tv3 = set(tv2);
  fusion.addOutput(tv3);

  for(auto tv: {tv0, tv1, tv2, tv3}) {
    tv->merge(0);
    tv->axis(0)->parallelize(ParallelType::TIDx);
  }

  inlineMost();

generates the following kernel:

__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 2, 2> T0, Tensor<float, 2, 2> T3) {
  alignas(16) extern __shared__ char array[];
  const unsigned smem_offset = 0;
  float* T1 = reinterpret_cast<float*>(array + smem_offset + 0LL);
  asm volatile(
    "{\n"
    "  .reg .pred p0; \n"
    "  setp.ne.b32 p0, %3, 0;\n"
    "  cp.async.ca.shared.global [%0], [%1], %2, p0;\n"
    "}\n"
    :
    :"r"((uint32_t)((toSmem(T1) + (4LL * ((nvfuser_index_t)threadIdx.x))))),
     "l"((T0.data + ((nvfuser_index_t)threadIdx.x))),
     "n"(4LL),
     "n"((uint32_t)(false))
  );
  asm volatile("cp.async.wait_all;\n");
  __syncthreads();
  float T2[1LL];
  T2[0LL]
     = T1[((T0.logical_size[1LL] * (((nvfuser_index_t)threadIdx.x) % T0.logical_size[0LL])) + (((nvfuser_index_t)threadIdx.x) / T0.logical_size[0LL]))];
  T3[((nvfuser_index_t)threadIdx.x)]
     = T2[0LL];
}

jacobhinkle pushed a commit that referenced this pull request Dec 3, 2024 (same commit message as above).
Linked issue: #3428, Race reported between Write access and Read access in fusion using async copy.