This repository has been archived by the owner on Feb 7, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix some bugs in CPU version of BooleanMask and add GPU version
Reviewed By: akyrola Differential Revision: D5397208 fbshipit-source-id: 0314cc181e315f3b6cda846292b2e2ea73bb015b
- Loading branch information
1 parent
c340c20
commit 3a0ad3f
Showing
4 changed files
with
243 additions
and
65 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
#include "caffe2/core/context_gpu.h" | ||
#include "caffe2/operators/boolean_mask_ops.h" | ||
|
||
#include <cub/cub.cuh> | ||
|
||
namespace caffe2 { | ||
|
||
namespace { | ||
template <typename T> | ||
__global__ void BooleanMaskCopyKernel( | ||
const TIndex numOfOutput, | ||
const TIndex numBytes, | ||
const TIndex* indices, | ||
const T* src, | ||
T* dest) { | ||
for (TIndex i = blockIdx.x; i < numOfOutput; i += gridDim.x) { | ||
const auto srcBase = indices[i] * numBytes; | ||
const auto destBase = i * numBytes; | ||
for (TIndex j = threadIdx.x; j < numBytes; j += blockDim.x) { | ||
dest[destBase + j] = src[srcBase + j]; | ||
} | ||
} | ||
} | ||
} | ||
|
||
template <> | ||
class BooleanMaskOp<CUDAContext> final : public Operator<CUDAContext> { | ||
public: | ||
BooleanMaskOp(const OperatorDef& operator_def, Workspace* ws) | ||
: Operator<CUDAContext>(operator_def, ws) {} | ||
|
||
bool RunOnDevice() override { | ||
const auto& src = Input(0); | ||
const auto& mask = Input(1); | ||
auto* dest = Output(0); | ||
|
||
CAFFE_ENFORCE(src.ndim() >= 1); | ||
CAFFE_ENFORCE_EQ(mask.ndim(), 1); | ||
CAFFE_ENFORCE(src.dims()[0] == mask.dims()[0]); | ||
|
||
const auto* maskData = mask.template data<bool>(); | ||
const auto outerSize = mask.dims()[0]; | ||
indices_.Resize(outerSize); | ||
auto* indicesData = indices_.template mutable_data<TIndex>(); | ||
|
||
size_t numBytes = 0; | ||
cub::CountingInputIterator<int> itr(0); | ||
cub::DeviceSelect::Flagged( | ||
nullptr, | ||
numBytes, | ||
itr, | ||
maskData, | ||
indicesData, | ||
static_cast<TIndex*>(nullptr), | ||
outerSize, | ||
context_.cuda_stream()); | ||
|
||
auto numTIndex = | ||
static_cast<TIndex>((numBytes + sizeof(TIndex) - 1) / sizeof(TIndex)); | ||
// allocate one more TIndex at the end of scratch for storing numOfOutput | ||
scratch_.Resize(numTIndex + 1); | ||
auto* scratchData = scratch_.template mutable_data<TIndex>(); | ||
auto* numOfOutputData = scratchData + numTIndex; | ||
|
||
cub::DeviceSelect::Flagged( | ||
static_cast<void*>(scratchData), | ||
numBytes, | ||
itr, | ||
maskData, | ||
indicesData, | ||
numOfOutputData, | ||
outerSize, | ||
context_.cuda_stream()); | ||
|
||
// Copy numOfOutput from gpu to cpu | ||
TIndex numOfOutput; | ||
context_.Copy<TIndex, CUDAContext, CPUContext>( | ||
1, numOfOutputData, &numOfOutput); | ||
|
||
indices_.Resize(numOfOutput); | ||
std::vector<TIndex> dims = src.dims(); | ||
dims[0] = numOfOutput; | ||
dest->Resize(dims); | ||
auto* destData = (char*)dest->raw_mutable_data(src.meta()); | ||
const auto* srcData = (char*)src.raw_data(); | ||
if (OutputSize() == 2) { | ||
auto* indicesOut = Output(1); | ||
indicesOut->Resize(numOfOutput); | ||
indicesOut->template mutable_data<TIndex>(); | ||
} | ||
|
||
if (numOfOutput > 0) { | ||
BooleanMaskCopyKernel<<< | ||
min(numOfOutput, static_cast<TIndex>(CAFFE_MAXIMUM_NUM_BLOCKS)), | ||
CAFFE_CUDA_NUM_THREADS, | ||
0, | ||
context_.cuda_stream()>>>( | ||
numOfOutput, | ||
src.size_from_dim(1) * src.meta().itemsize(), | ||
indicesData, | ||
srcData, | ||
destData); | ||
if (OutputSize() == 2) { | ||
Output(1)->CopyFrom(indices_, &context_); | ||
} | ||
} | ||
return true; | ||
} | ||
private: | ||
Tensor<CUDAContext> indices_; | ||
Tensor<CUDAContext> scratch_; | ||
}; | ||
REGISTER_CUDA_OPERATOR(BooleanMask, BooleanMaskOp<CUDAContext>); | ||
} // caffe2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
#ifndef BOOLEAN_MASK_OPS_H | ||
#define BOOLEAN_MASK_OPS_H | ||
|
||
#include "caffe2/core/context.h" | ||
#include "caffe2/core/operator.h" | ||
#include "caffe2/core/tensor.h" | ||
|
||
namespace caffe2 { | ||
|
||
template <class Context> | ||
class BooleanMaskOp final : public Operator<Context> { | ||
public: | ||
USE_OPERATOR_CONTEXT_FUNCTIONS; | ||
BooleanMaskOp(const OperatorDef& operator_def, Workspace* ws) | ||
: Operator<Context>(operator_def, ws) {} | ||
|
||
bool RunOnDevice() override; | ||
}; | ||
} // caffe2 | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import numpy as np | ||
from hypothesis import given | ||
import hypothesis.strategies as st | ||
from caffe2.python import core | ||
import caffe2.python.hypothesis_test_util as hu | ||
|
||
|
||
class TestBooleanMaskOp(hu.HypothesisTestCase): | ||
|
||
@given(x=hu.tensor(min_dim=1, | ||
max_dim=5, | ||
elements=st.floats(min_value=0.5, max_value=1.0)), | ||
**hu.gcs) | ||
def test_boolean_mask(self, x, gc, dc): | ||
op = core.CreateOperator("BooleanMask", | ||
["data", "mask"], | ||
"masked_data") | ||
mask = np.random.choice(a=[True, False], size=x.shape[0]) | ||
|
||
def ref(x, mask): | ||
return (x[mask],) | ||
|
||
self.assertReferenceChecks(gc, op, [x, mask], ref) | ||
self.assertDeviceChecks(dc, op, [x, mask], [0]) | ||
|
||
@given(x=hu.tensor(min_dim=1, | ||
max_dim=5, | ||
elements=st.floats(min_value=0.5, max_value=1.0)), | ||
**hu.gcs) | ||
def test_boolean_mask_indices(self, x, gc, dc): | ||
op = core.CreateOperator("BooleanMask", | ||
["data", "mask"], | ||
["masked_data", "masked_indices"]) | ||
mask = np.random.choice(a=[True, False], size=x.shape[0]) | ||
|
||
def ref(x, mask): | ||
return (x[mask], np.where(mask)[0]) | ||
|
||
self.assertReferenceChecks(gc, op, [x, mask], ref) | ||
self.assertDeviceChecks(dc, op, [x, mask], [0]) |