implement jagged_unique_indices_cpu #4651

Open · wants to merge 1 commit into base: main
205 changes: 205 additions & 0 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
@@ -909,8 +909,209 @@ Tensor jagged_1d_to_truncated_values_cpu(
return truncated_values;
}

// CPU version of linearize_index_wo_infos_kernel
template <typename index_t>
void linearize_index_wo_infos_cpu(
const at::TensorAccessor<index_t, 1>& hash_size_cumsum,
const at::TensorAccessor<index_t, 1>& indices,
const at::TensorAccessor<index_t, 1>& offsets,
at::TensorAccessor<index_t, 1> linear_indices,
const int64_t total_B,
const int64_t T) {
at::parallel_for(0, total_B, 0, [&](int64_t start, int64_t end) {
for (int64_t b_t = start; b_t < end; ++b_t) {
const int32_t t = b_t / (total_B / T);

const auto hash_offset = hash_size_cumsum[t];
const auto indices_start = offsets[b_t];
const int32_t L = offsets[b_t + 1] - indices_start;

for (int32_t i = 0; i < L; ++i) {
const auto idx = indices[indices_start + i];
linear_indices[indices_start + i] = hash_offset + idx;
}
}
});
}
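
// Illustrative example (used by the comments below): T = 2 tables,
// batch_size = 1 (total_B = 2), hash_size_cumsum = [0, 8, 16],
// offsets = [0, 2, 3], indices = [7, 4, 7]. Table 1's index is shifted by
// its cumulative hash offset 8, so linear_indices = [7, 4, 15].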

// CPU version of delinearize_unique_index_kernel
template <typename index_t>
void delinearize_unique_index_cpu(
const at::TensorAccessor<index_t, 1>& indices,
const at::TensorAccessor<index_t, 1>& reverse_index,
at::TensorAccessor<index_t, 1> unique_indices) {
at::parallel_for(0, indices.size(0), 0, [&](int64_t start, int64_t end) {
for (int64_t i = start; i < end; ++i) {
const auto original_index = indices[i];
const auto pos = reverse_index[i];
unique_indices[pos] = original_index;
}
});
}
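
// Continuing the example: at::_unique on [7, 4, 15] yields sorted unique
// values [4, 7, 15] with reverse_index = [1, 0, 2]; scattering the original
// values through reverse_index gives unique_indices = [4, 7, 7], i.e. the
// table-local index behind each unique linearized value. Duplicates may
// write the same slot concurrently, but they always write the same value.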

// CPU version of unique_indices_length_kernel
template <typename index_t>
void unique_indices_length_cpu(
const at::TensorAccessor<index_t, 1>& hash_size_offsets,
const at::TensorAccessor<index_t, 1>& reverse_index,
const at::TensorAccessor<index_t, 1>& offsets,
at::TensorAccessor<index_t, 1> lengths,
const int64_t T) {
const int32_t batch_size = (offsets.size(0) - 1) / T;

at::parallel_for(0, T, 0, [&](int64_t start, int64_t end) {
for (int64_t bid = start; bid < end; ++bid) {
const auto offset_begin = hash_size_offsets[bid] * batch_size;
const auto offset_end = hash_size_offsets[bid + 1] * batch_size;
const auto num_lengths = (offset_end - offset_begin);

const auto reverse_index_begin = offsets[offset_begin];
const auto reverse_index_end = offsets[offset_end];

if (reverse_index_begin == reverse_index_end) {
continue;
}

index_t t_max = std::numeric_limits<index_t>::min();
index_t t_min = std::numeric_limits<index_t>::max();

for (index_t i = reverse_index_begin; i < reverse_index_end; ++i) {
const index_t value = reverse_index[i];
t_max = std::max(t_max, value);
t_min = std::min(t_min, value);
}

const index_t total_length = (t_max - t_min) + 1;
const index_t div_length = total_length / num_lengths;
const index_t r_length = total_length % num_lengths;

for (int32_t i = 0; i < num_lengths; ++i) {
index_t seg_length = (i < r_length) ? (div_length + 1) : div_length;
lengths[offset_begin + i] = seg_length;
}
}
});
}
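
// Because tables occupy disjoint, sorted ranges of the linearized space, the
// reverse_index values of one table's slice are contiguous, so
// (t_max - t_min) + 1 is that table's unique-index count. It is spread as
// evenly as possible over the table's batch_size entries; e.g. 5 unique
// indices over batch_size = 2 gives lengths [3, 2]. For the running example,
// output_lengths = [2, 1].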

// CPU version of compute_hash_size_kernel
template <typename index_t>
void compute_hash_size_cpu(
const at::TensorAccessor<index_t, 1>& offsets,
const at::TensorAccessor<index_t, 1>& indices,
const int64_t batch_size,
at::TensorAccessor<index_t, 1> hash_size,
const int64_t T) {
at::parallel_for(0, T, 0, [&](int64_t start, int64_t end) {
for (int64_t bid = start; bid < end; ++bid) {
const auto offset_begin = bid * batch_size;
const auto offset_end = (bid + 1) * batch_size;
const auto index_begin = offsets[offset_begin];
const auto index_end = offsets[offset_end];

if (index_begin == index_end) {
hash_size[bid] = 0;
continue;
}

index_t t_max = std::numeric_limits<index_t>::min();
for (index_t i = index_begin; i < index_end; ++i) {
const index_t value = indices[i];
t_max = std::max(t_max, value);
}

hash_size[bid] = t_max + 1;
}
});
}
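
// For the running example, table 0 sees {7, 4} and table 1 sees {7}, so
// hash_size = [8, 8] (max index + 1 per table).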

} // namespace

std::tuple<Tensor, Tensor, Tensor, Tensor> jagged_unique_indices_cpu(
const Tensor& hash_size_cumsum,
const Tensor& hash_size_offsets,
const Tensor& offsets,
const Tensor& indices) {
TENSOR_ON_CPU(hash_size_cumsum);
TENSOR_ON_CPU(hash_size_offsets);
TENSOR_ON_CPU(offsets);
TENSOR_ON_CPU(indices);

const auto total_B = offsets.size(0) - 1;
const auto T = hash_size_cumsum.size(0) - 1;

Tensor linear_indices = at::empty_like(indices);

AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "linearize_index_cpu", [&] {
linearize_index_wo_infos_cpu<index_t>(
hash_size_cumsum.accessor<index_t, 1>(),
indices.accessor<index_t, 1>(),
offsets.accessor<index_t, 1>(),
linear_indices.accessor<index_t, 1>(),
total_B,
T);
});

Tensor linear_unique_indices;
Tensor reverse_index;

std::tie(linear_unique_indices, reverse_index) =
at::_unique(linear_indices, true, true);

Tensor unique_indices = at::empty_like(linear_unique_indices);

AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(), "delinearize_unique_index_cpu", [&] {
delinearize_unique_index_cpu<index_t>(
indices.accessor<index_t, 1>(),
reverse_index.accessor<index_t, 1>(),
unique_indices.accessor<index_t, 1>());
});

Tensor output_lengths = at::zeros({total_B}, offsets.options());

AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(), "unique_indices_length_cpu", [&] {
unique_indices_length_cpu<index_t>(
hash_size_offsets.accessor<index_t, 1>(),
reverse_index.accessor<index_t, 1>(),
offsets.accessor<index_t, 1>(),
output_lengths.accessor<index_t, 1>(),
T);
});

Tensor output_offsets = asynchronous_complete_cumsum_cpu(output_lengths);

return {output_lengths, output_offsets, unique_indices, reverse_index};
}
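
// For the running example this returns output_lengths = [2, 1],
// output_offsets = [0, 2, 3], unique_indices = [4, 7, 7] and
// reverse_index = [1, 0, 2].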

std::tuple<Tensor, Tensor> jagged_hash_size_cumsum_cpu(
const Tensor& offsets,
const Tensor& indices,
const int64_t batch_size) {
TENSOR_ON_CPU(offsets);
TENSOR_ON_CPU(indices);

const auto T = (offsets.size(0) - 1) / batch_size;
Tensor hash_size = at::zeros({T}, offsets.options());

AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "compute_hash_size_cpu", [&] {
compute_hash_size_cpu<index_t>(
offsets.accessor<index_t, 1>(),
indices.accessor<index_t, 1>(),
batch_size,
hash_size.accessor<index_t, 1>(),
T);
});

Tensor hash_size_cumsum = asynchronous_complete_cumsum_cpu(hash_size);

Tensor hash_size_lengths = at::ones_like(hash_size);
Tensor hash_size_offsets =
asynchronous_complete_cumsum_cpu(hash_size_lengths);

return {hash_size_cumsum, hash_size_offsets};
}
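
// For the running example this returns hash_size_cumsum = [0, 8, 16] and
// hash_size_offsets = [0, 1, 2] (one hash-size entry per table).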

std::tuple<Tensor, Tensor> masked_select_jagged_1d(
const Tensor& values,
const Tensor& lengths,
@@ -1823,6 +2024,10 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
DISPATCH_TO_CPU(
"jagged_dense_bmm_forward", fbgemm_gpu::jagged_dense_bmm_forward);
DISPATCH_TO_CPU("jagged_slice_forward", fbgemm_gpu::jagged_slice_forward_cpu);
DISPATCH_TO_CPU(
"jagged_unique_indices", fbgemm_gpu::jagged_unique_indices_cpu);
DISPATCH_TO_CPU(
"jagged_hash_size_cumsum", fbgemm_gpu::jagged_hash_size_cumsum_cpu);
}

TORCH_LIBRARY_IMPL(fbgemm, CompositeExplicitAutograd, m) {
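For reference, a minimal sketch of how the new CPU path can be driven from Python, assuming the ops are exposed under torch.ops.fbgemm with schemas matching the C++ signatures above (values taken from the running example in the comments):

import torch
import fbgemm_gpu  # noqa: F401 -- registers the torch.ops.fbgemm operators

# Two tables (T = 2), batch_size = 1, all tensors on CPU.
offsets = torch.tensor([0, 2, 3], dtype=torch.int64)
indices = torch.tensor([7, 4, 7], dtype=torch.int64)

hash_size_cumsum, hash_size_offsets = torch.ops.fbgemm.jagged_hash_size_cumsum(
    offsets, indices, 1  # batch_size = 1
)

(
    output_lengths,
    output_offsets,
    unique_indices,
    reverse_index,
) = torch.ops.fbgemm.jagged_unique_indices(
    hash_size_cumsum, hash_size_offsets, offsets, indices
)

# reverse_index maps each original position to its slot in unique_indices,
# so gathering through it reconstructs the original table-local indices.
assert torch.equal(unique_indices[reverse_index], indices)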
16 changes: 8 additions & 8 deletions fbgemm_gpu/test/jagged/unique_indices_test.py
@@ -23,10 +23,10 @@

 if open_source:
     # pyre-ignore[21]
-    from test_utils import gpu_unavailable, optests, symint_vector_unsupported
+    from test_utils import cpu_and_maybe_gpu, optests, symint_vector_unsupported
 else:
     from fbgemm_gpu.test.test_utils import (
-        gpu_unavailable,
+        cpu_and_maybe_gpu,
         optests,
         symint_vector_unsupported,
     )
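
The tests below draw their target device from cpu_and_maybe_gpu. A minimal sketch of such a strategy, assuming it works roughly along these lines (the actual test_utils implementation may differ):

import hypothesis.strategies as st
import torch

def cpu_and_maybe_gpu() -> st.SearchStrategy[torch.device]:
    # Always test CPU; additionally test CUDA when a GPU is present.
    devices = [torch.device("cpu")]
    if torch.cuda.is_available():
        devices.append(torch.device("cuda"))
    return st.sampled_from(devices)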
@@ -58,18 +58,19 @@ def setUp(self) -> None:
         # Turn off static assumption for auto-dynamic
         torch._dynamo.config.assume_static_by_default = False

-    @unittest.skipIf(*gpu_unavailable)
     @given(
         B=st.integers(min_value=100, max_value=200),
         F=st.integers(min_value=50, max_value=100),
         max_length=st.integers(min_value=5, max_value=10),
+        device=cpu_and_maybe_gpu(),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None)
     def test_jagged_unique_indices(
         self,
         B: int,  # Batch size
         F: int,  # The number of features
         max_length: int,  # The maximum value of pooling factor
+        device: torch.device,
     ) -> None:
         hash_size_list = []
         lengths_list = []
@@ -91,7 +92,6 @@ def test_jagged_unique_indices(
         indices_list.extend(indices)
         linearized_indices_list.extend(linearized_indices)

-        device = torch.device("cuda")
         dtype = torch.int64
         hash_size = torch.as_tensor(hash_size_list, dtype=dtype, device=device)
         hash_size_offsets = torch.as_tensor(
@@ -162,18 +162,19 @@ def test_jagged_unique_indices(
             pos = reverse_index_list[each_offset]
             self.assertTrue((output_start <= pos) and (pos < output_end))

-    @unittest.skipIf(*gpu_unavailable)
     @given(
         B=st.integers(min_value=100, max_value=200),
         F=st.integers(min_value=50, max_value=100),
         max_length=st.integers(min_value=5, max_value=10),
+        device=cpu_and_maybe_gpu(),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None)
     def test_jagged_unique_indices_multi_keys(
         self,
         B: int,  # Batch size
         F: int,  # The number of features
         max_length: int,  # The maximum value of pooling factor
+        device: torch.device,
     ) -> None:
         hash_size_list = []
         lengths_list = []
@@ -195,7 +196,6 @@ def test_jagged_unique_indices_multi_keys(
         indices_list.extend(indices)
         linearized_indices_list.extend(linearized_indices)

-        device = torch.device("cuda")
         dtype = torch.int64
         hash_size = torch.as_tensor(hash_size_list, dtype=dtype, device=device)
         lengths = torch.as_tensor(lengths_list, dtype=dtype, device=device)
@@ -235,23 +235,23 @@ def test_jagged_unique_indices_multi_keys(
             pos = reverse_index_list[i]
             self.assertTrue(unique_indices_list[pos] == indices_list[i])

-    @unittest.skipIf(*gpu_unavailable)
     @given(
         B=st.integers(min_value=100, max_value=200),
         F=st.integers(min_value=50, max_value=100),
+        device=cpu_and_maybe_gpu(),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
     def test_jagged_unique_indices_empty(
         self,
         B: int,  # Batch size
         F: int,  # The number of features
+        device: torch.device,
     ) -> None:
         hash_size_cumsum_list = [0] + list(itertools.accumulate([10] * F))
         hash_size_offsets_list = [0] + list(itertools.accumulate([1] * F))
         offsets_list = [0] * (B * F + 1)
         indices_list = []

-        device = torch.device("cuda")
         dtype = torch.int64
         hash_size_cumsum = torch.as_tensor(
             hash_size_cumsum_list, device=device, dtype=dtype