diff --git a/cub/agent/agent_merge_sort.cuh b/cub/agent/agent_merge_sort.cuh new file mode 100644 index 0000000000..42d5d8976b --- /dev/null +++ b/cub/agent/agent_merge_sort.cuh @@ -0,0 +1,752 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include "../config.cuh" +#include "../util_type.cuh" +#include "../block/block_load.cuh" +#include "../block/block_store.cuh" +#include "../block/block_merge_sort.cuh" + +#include + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +template < + int _BLOCK_THREADS, + int _ITEMS_PER_THREAD = 1, + cub::BlockLoadAlgorithm _LOAD_ALGORITHM = cub::BLOCK_LOAD_DIRECT, + cub::CacheLoadModifier _LOAD_MODIFIER = cub::LOAD_LDG, + cub::BlockStoreAlgorithm _STORE_ALGORITHM = cub::BLOCK_STORE_DIRECT> +struct AgentMergeSortPolicy +{ + static constexpr int BLOCK_THREADS = _BLOCK_THREADS; + static constexpr int ITEMS_PER_THREAD = _ITEMS_PER_THREAD; + static constexpr int ITEMS_PER_TILE = BLOCK_THREADS * ITEMS_PER_THREAD; + + static constexpr cub::BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM; + static constexpr cub::CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; + static constexpr cub::BlockStoreAlgorithm STORE_ALGORITHM = _STORE_ALGORITHM; +}; + +/// \brief This agent is responsible for the initial in-tile sorting. 
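As a quick illustration of how this policy drives the agents below, a hypothetical instantiation might look like the following sketch (the thread/item counts and the load/store algorithms are assumptions chosen for the example, not the tuning actually selected by the dispatch layer):

// A hypothetical policy instantiation, for illustration only; the real
// tuning values are picked elsewhere by the dispatch layer.
using ExamplePolicy = AgentMergeSortPolicy<256,                        // block threads (assumed)
                                           11,                         // items per thread (assumed)
                                           BLOCK_LOAD_WARP_TRANSPOSE,
                                           LOAD_DEFAULT,
                                           BLOCK_STORE_WARP_TRANSPOSE>;

// Each block-sorted tile covers BLOCK_THREADS * ITEMS_PER_THREAD keys.
static_assert(ExamplePolicy::ITEMS_PER_TILE == 256 * 11,
              "ITEMS_PER_TILE is the product of the two counts above");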
+template +struct AgentBlockSort +{ + //--------------------------------------------------------------------- + // Types and constants + //--------------------------------------------------------------------- + + static constexpr bool KEYS_ONLY = Equals::VALUE; + + using BlockMergeSortT = + BlockMergeSort; + + using KeysLoadIt = typename thrust::cuda_cub::core::LoadIterator::type; + using ItemsLoadIt = typename thrust::cuda_cub::core::LoadIterator::type; + + using BlockLoadKeys = typename cub::BlockLoadType::type; + using BlockLoadItems = typename cub::BlockLoadType::type; + + using BlockStoreKeysIt = typename cub::BlockStoreType::type; + using BlockStoreItemsIt = typename cub::BlockStoreType::type; + using BlockStoreKeysRaw = typename cub::BlockStoreType::type; + using BlockStoreItemsRaw = typename cub::BlockStoreType::type; + + union _TempStorage + { + typename BlockLoadKeys::TempStorage load_keys; + typename BlockLoadItems::TempStorage load_items; + typename BlockStoreKeysIt::TempStorage store_keys_it; + typename BlockStoreItemsIt::TempStorage store_items_it; + typename BlockStoreKeysRaw::TempStorage store_keys_raw; + typename BlockStoreItemsRaw::TempStorage store_items_raw; + typename BlockMergeSortT::TempStorage block_merge; + }; + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + static constexpr int BLOCK_THREADS = Policy::BLOCK_THREADS; + static constexpr int ITEMS_PER_THREAD = Policy::ITEMS_PER_THREAD; + static constexpr int ITEMS_PER_TILE = Policy::ITEMS_PER_TILE; + static constexpr int SHARED_MEMORY_SIZE = + static_cast(sizeof(TempStorage)); + + //--------------------------------------------------------------------- + // Per thread data + //--------------------------------------------------------------------- + + bool ping; + _TempStorage &storage; + KeysLoadIt keys_in; + ItemsLoadIt items_in; + OffsetT keys_count; + KeyIteratorT keys_out_it; + ValueIteratorT items_out_it; + KeyT *keys_out_raw; + ValueT *items_out_raw; + CompareOpT compare_op; + + __device__ __forceinline__ AgentBlockSort(bool ping_, + TempStorage &storage_, + KeysLoadIt keys_in_, + ItemsLoadIt items_in_, + OffsetT keys_count_, + KeyIteratorT keys_out_it_, + ValueIteratorT items_out_it_, + KeyT *keys_out_raw_, + ValueT *items_out_raw_, + CompareOpT compare_op_) + : ping(ping_) + , storage(storage_.Alias()) + , keys_in(keys_in_) + , items_in(items_in_) + , keys_count(keys_count_) + , keys_out_it(keys_out_it_) + , items_out_it(items_out_it_) + , keys_out_raw(keys_out_raw_) + , items_out_raw(items_out_raw_) + , compare_op(compare_op_) + { + } + + __device__ __forceinline__ void Process() + { + auto tile_idx = static_cast(blockIdx.x); + auto num_tiles = static_cast(gridDim.x); + auto tile_base = tile_idx * ITEMS_PER_TILE; + int items_in_tile = (cub::min)(keys_count - tile_base, int{ITEMS_PER_TILE}); + + if (tile_idx < num_tiles - 1) + { + consume_tile(tile_base, ITEMS_PER_TILE); + } + else + { + consume_tile(tile_base, items_in_tile); + } + } + + template + __device__ __forceinline__ void consume_tile(OffsetT tile_base, + int num_remaining) + { + ValueT items_local[ITEMS_PER_THREAD]; + if (!KEYS_ONLY) + { + if (IS_LAST_TILE) + { + BlockLoadItems(storage.load_items) + .Load(items_in + tile_base, + items_local, + num_remaining, + *(items_in + tile_base)); + } + else + { + BlockLoadItems(storage.load_items).Load(items_in + tile_base, items_local); + } + + CTA_SYNC(); + } + + KeyT keys_local[ITEMS_PER_THREAD]; + if (IS_LAST_TILE) + { + 
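      // Guarded load for the partial (last) tile: only the first
      // num_remaining keys are read from global memory; the remaining slots
      // of keys_local are filled with the tile's first key, so every
      // register holds an initialized value before the block-wide sort.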
BlockLoadKeys(storage.load_keys) + .Load(keys_in + tile_base, + keys_local, + num_remaining, + *(keys_in + tile_base)); + } + else + { + BlockLoadKeys(storage.load_keys) + .Load(keys_in + tile_base, keys_local); + } + + CTA_SYNC(); + + if (IS_LAST_TILE) + { + BlockMergeSortT(storage.block_merge) + .Sort(keys_local, items_local, compare_op, num_remaining, keys_local[0]); + } + else + { + BlockMergeSortT(storage.block_merge).Sort(keys_local, items_local, compare_op); + } + + CTA_SYNC(); + + if (ping) + { + if (IS_LAST_TILE) + { + BlockStoreKeysIt(storage.store_keys_it) + .Store(keys_out_it + tile_base, keys_local, num_remaining); + } + else + { + BlockStoreKeysIt(storage.store_keys_it) + .Store(keys_out_it + tile_base, keys_local); + } + + if (!KEYS_ONLY) + { + CTA_SYNC(); + + if (IS_LAST_TILE) + { + BlockStoreItemsIt(storage.store_items_it) + .Store(items_out_it + tile_base, items_local, num_remaining); + } + else + { + BlockStoreItemsIt(storage.store_items_it) + .Store(items_out_it + tile_base, items_local); + } + } + } + else + { + if (IS_LAST_TILE) + { + BlockStoreKeysRaw(storage.store_keys_raw) + .Store(keys_out_raw + tile_base, keys_local, num_remaining); + } + else + { + BlockStoreKeysRaw(storage.store_keys_raw) + .Store(keys_out_raw + tile_base, keys_local); + } + + if (!KEYS_ONLY) + { + CTA_SYNC(); + + if (IS_LAST_TILE) + { + BlockStoreItemsRaw(storage.store_items_raw) + .Store(items_out_raw + tile_base, items_local, num_remaining); + } + else + { + BlockStoreItemsRaw(storage.store_items_raw) + .Store(items_out_raw + tile_base, items_local); + } + } + } + } +}; + +/** + * \brief This agent is responsible for partitioning a merge path into equal segments + * + * There are two sorted arrays to be merged into one array. If the first array + * is partitioned between parallel workers by slicing it into ranges of equal + * size, there could be a significant workload imbalance. The imbalance is + * caused by the fact that the distribution of elements from the second array + * is unknown beforehand. Instead, the MergePath is partitioned between workers. + * This approach guarantees an equal amount of work being assigned to each worker. + * + * This approach is outlined in the paper: + * Odeh et al, "Merge Path - Parallel Merging Made Simple" + * doi:10.1109/IPDPSW.2012.202 + */ +template < + typename KeyIteratorT, + typename OffsetT, + typename CompareOpT, + typename KeyT> +struct AgentPartition +{ + bool ping; + KeyIteratorT keys_ping; + KeyT *keys_pong; + OffsetT keys_count; + OffsetT partition_idx; + OffsetT *merge_partitions; + CompareOpT compare_op; + OffsetT target_merged_tiles_number; + int items_per_tile; + + __device__ __forceinline__ AgentPartition(bool ping, + KeyIteratorT keys_ping, + KeyT *keys_pong, + OffsetT keys_count, + OffsetT partition_idx, + OffsetT *merge_partitions, + CompareOpT compare_op, + OffsetT target_merged_tiles_number, + int items_per_tile) + : ping(ping) + , keys_ping(keys_ping) + , keys_pong(keys_pong) + , keys_count(keys_count) + , partition_idx(partition_idx) + , merge_partitions(merge_partitions) + , compare_op(compare_op) + , target_merged_tiles_number(target_merged_tiles_number) + , items_per_tile(items_per_tile) + {} + + __device__ __forceinline__ void Process() + { + OffsetT merged_tiles_number = target_merged_tiles_number / 2; + + // target_merged_tiles_number is a power of two. 
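    // Because it is a power of two, (target_merged_tiles_number - 1) acts as
    // a bit mask: clearing the low bits of partition_idx yields the first
    // tile of the tile group being merged, while keeping them yields the
    // position inside that group. For example, with
    // target_merged_tiles_number = 4 and partition_idx = 6:
    //   mask = 3, list = ~3 & 6 = 4, local_tile_idx = 3 & 6 = 2,
    // i.e. partition 6 is the third tile of the group starting at tile 4.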
+ OffsetT mask = target_merged_tiles_number - 1; + + // The first tile number in the tiles group being merged, equal to: + // target_merged_tiles_number * (partition_idx / target_merged_tiles_number) + OffsetT list = ~mask & partition_idx; + OffsetT start = items_per_tile * list; + OffsetT size = items_per_tile * merged_tiles_number; + + // Tile number within the tile group being merged, equal to: + // partition_idx / target_merged_tiles_number + OffsetT local_tile_idx = mask & partition_idx; + + OffsetT keys1_beg = (cub::min)(keys_count, start); + OffsetT keys1_end = (cub::min)(keys_count, start + size); + OffsetT keys2_beg = keys1_end; + OffsetT keys2_end = (cub::min)(keys_count, keys2_beg + size); + + OffsetT partition_at = (cub::min)(keys2_end - keys1_beg, + items_per_tile * local_tile_idx); + + OffsetT partition_diag = ping ? MergePath(keys_ping + keys1_beg, + keys_ping + keys2_beg, + keys1_end - keys1_beg, + keys2_end - keys2_beg, + partition_at, + compare_op) + : MergePath(keys_pong + keys1_beg, + keys_pong + keys2_beg, + keys1_end - keys1_beg, + keys2_end - keys2_beg, + partition_at, + compare_op); + + merge_partitions[partition_idx] = keys1_beg + partition_diag; + } +}; + +/// \brief The agent is responsible for merging N consecutive sorted arrays into N/2 sorted arrays. +template < + typename Policy, + typename KeyIteratorT, + typename ValueIteratorT, + typename OffsetT, + typename CompareOpT, + typename KeyT, + typename ValueT> +struct AgentMerge +{ + + //--------------------------------------------------------------------- + // Types and constants + //--------------------------------------------------------------------- + using KeysLoadPingIt = typename thrust::cuda_cub::core::LoadIterator::type; + using ItemsLoadPingIt = typename thrust::cuda_cub::core::LoadIterator::type; + using KeysLoadPongIt = typename thrust::cuda_cub::core::LoadIterator::type; + using ItemsLoadPongIt = typename thrust::cuda_cub::core::LoadIterator::type; + + using KeysOutputPongIt = KeyIteratorT; + using ItemsOutputPongIt = ValueIteratorT; + using KeysOutputPingIt = KeyT*; + using ItemsOutputPingIt = ValueT*; + + using BlockStoreKeysPong = typename BlockStoreType::type; + using BlockStoreItemsPong = typename BlockStoreType::type; + using BlockStoreKeysPing = typename BlockStoreType::type; + using BlockStoreItemsPing = typename BlockStoreType::type; + + /// Parameterized BlockReduce primitive + + union _TempStorage + { + typename BlockStoreKeysPing::TempStorage store_keys_ping; + typename BlockStoreItemsPing::TempStorage store_items_ping; + typename BlockStoreKeysPong::TempStorage store_keys_pong; + typename BlockStoreItemsPong::TempStorage store_items_pong; + + KeyT keys_shared[Policy::ITEMS_PER_TILE + 1]; + ValueT items_shared[Policy::ITEMS_PER_TILE + 1]; + }; + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : Uninitialized<_TempStorage> {}; + + static constexpr bool KEYS_ONLY = Equals::VALUE; + static constexpr int BLOCK_THREADS = Policy::BLOCK_THREADS; + static constexpr int ITEMS_PER_THREAD = Policy::ITEMS_PER_THREAD; + static constexpr int ITEMS_PER_TILE = Policy::ITEMS_PER_TILE; + static constexpr int SHARED_MEMORY_SIZE = + static_cast(sizeof(TempStorage)); + + //--------------------------------------------------------------------- + // Per thread data + //--------------------------------------------------------------------- + + bool ping; + _TempStorage& storage; + + KeysLoadPingIt keys_in_ping; + ItemsLoadPingIt items_in_ping; + KeysLoadPongIt keys_in_pong; + 
ItemsLoadPongIt items_in_pong; + + OffsetT keys_count; + + KeysOutputPongIt keys_out_pong; + ItemsOutputPongIt items_out_pong; + KeysOutputPingIt keys_out_ping; + ItemsOutputPingIt items_out_ping; + + CompareOpT compare_op; + OffsetT *merge_partitions; + OffsetT target_merged_tiles_number; + + //--------------------------------------------------------------------- + // Utility functions + //--------------------------------------------------------------------- + + /** + * \brief Concatenates up to ITEMS_PER_THREAD elements from input{1,2} into output array + * + * Reads data in a coalesced fashion [BLOCK_THREADS * item + tid] and + * stores the result in output[item]. + */ + template + __device__ __forceinline__ void + gmem_to_reg(T (&output)[ITEMS_PER_THREAD], + It1 input1, + It2 input2, + int count1, + int count2) + { + if (IS_FULL_TILE) + { +#pragma unroll + for (int item = 0; item < ITEMS_PER_THREAD; ++item) + { + int idx = BLOCK_THREADS * item + threadIdx.x; + output[item] = (idx < count1) ? input1[idx] : input2[idx - count1]; + } + } + else + { +#pragma unroll + for (int item = 0; item < ITEMS_PER_THREAD; ++item) + { + int idx = BLOCK_THREADS * item + threadIdx.x; + if (idx < count1 + count2) + { + output[item] = (idx < count1) ? input1[idx] : input2[idx - count1]; + } + } + } + } + + /// \brief Stores data in a coalesced fashion in[item] -> out[BLOCK_THREADS * item + tid] + template + __device__ __forceinline__ void + reg_to_shared(It output, + T (&input)[ITEMS_PER_THREAD]) + { +#pragma unroll + for (int item = 0; item < ITEMS_PER_THREAD; ++item) + { + int idx = BLOCK_THREADS * item + threadIdx.x; + output[idx] = input[item]; + } + } + + template + __device__ __forceinline__ void + consume_tile(int tid, OffsetT tile_idx, OffsetT tile_base, int count) + { + OffsetT partition_beg = merge_partitions[tile_idx + 0]; + OffsetT partition_end = merge_partitions[tile_idx + 1]; + + // target_merged_tiles_number is a power of two. 
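    // partition_beg/partition_end bound this tile's slice of the first sorted
    // sequence (keys1). The matching slice of the second sequence (keys2) is
    // recovered just below from the merge-path identity: along a diagonal d,
    // (keys consumed from keys1) + (keys consumed from keys2) == d, so
    //   keys2_beg = (start + size) + diag                  - (partition_beg - start)
    //   keys2_end = (start + size) + diag + ITEMS_PER_TILE - (partition_end - start)
    // (both clamped to keys_count), which is exactly the
    // 2 * start + size + diag - partition_{beg,end} expression used below.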
+ OffsetT merged_tiles_number = target_merged_tiles_number / 2; + + OffsetT mask = target_merged_tiles_number - 1; + + // The first tile number in the tiles group being merged, equal to: + // target_merged_tiles_number * (tile_idx / target_merged_tiles_number) + OffsetT list = ~mask & tile_idx; + OffsetT start = ITEMS_PER_TILE * list; + OffsetT size = ITEMS_PER_TILE * merged_tiles_number; + + OffsetT diag = ITEMS_PER_TILE * tile_idx - start; + + OffsetT keys1_beg = partition_beg; + OffsetT keys1_end = partition_end; + OffsetT keys2_beg = (cub::min)(keys_count, 2 * start + size + diag - partition_beg); + OffsetT keys2_end = (cub::min)(keys_count, 2 * start + size + diag + ITEMS_PER_TILE - partition_end); + + // Check if it's the last tile in the tile group being merged + if (mask == (mask & tile_idx)) + { + keys1_end = (cub::min)(keys_count, start + size); + keys2_end = (cub::min)(keys_count, start + size * 2); + } + + // number of keys per tile + // + int num_keys1 = static_cast(keys1_end - keys1_beg); + int num_keys2 = static_cast(keys2_end - keys2_beg); + + // load keys1 & keys2 + KeyT keys_local[ITEMS_PER_THREAD]; + if (ping) + { + gmem_to_reg(keys_local, + keys_in_ping + keys1_beg, + keys_in_ping + keys2_beg, + num_keys1, + num_keys2); + } + else + { + gmem_to_reg(keys_local, + keys_in_pong + keys1_beg, + keys_in_pong + keys2_beg, + num_keys1, + num_keys2); + } + reg_to_shared(&storage.keys_shared[0], keys_local); + + // preload items into registers already + // + ValueT items_local[ITEMS_PER_THREAD]; + if (!KEYS_ONLY) + { + if (ping) + { + gmem_to_reg(items_local, + items_in_ping + keys1_beg, + items_in_ping + keys2_beg, + num_keys1, + num_keys2); + } + else + { + gmem_to_reg(items_local, + items_in_pong + keys1_beg, + items_in_pong + keys2_beg, + num_keys1, + num_keys2); + } + } + + CTA_SYNC(); + + // use binary search in shared memory + // to find merge path for each of thread + // we can use int type here, because the number of + // items in shared memory is limited + // + int diag0_local = (cub::min)(num_keys1 + num_keys2, ITEMS_PER_THREAD * tid); + + int keys1_beg_local = MergePath(&storage.keys_shared[0], + &storage.keys_shared[num_keys1], + num_keys1, + num_keys2, + diag0_local, + compare_op); + int keys1_end_local = num_keys1; + int keys2_beg_local = diag0_local - keys1_beg_local; + int keys2_end_local = num_keys2; + + int num_keys1_local = keys1_end_local - keys1_beg_local; + int num_keys2_local = keys2_end_local - keys2_beg_local; + + // perform serial merge + // + int indices[ITEMS_PER_THREAD]; + + SerialMerge(&storage.keys_shared[0], + keys1_beg_local, + keys2_beg_local + num_keys1, + num_keys1_local, + num_keys2_local, + keys_local, + indices, + compare_op); + + CTA_SYNC(); + + // write keys + // + if (ping) + { + if (IS_FULL_TILE) + { + BlockStoreKeysPing(storage.store_keys_ping) + .Store(keys_out_ping + tile_base, keys_local); + } + else + { + BlockStoreKeysPing(storage.store_keys_ping) + .Store(keys_out_ping + tile_base, keys_local, num_keys1 + num_keys2); + } + } + else + { + if (IS_FULL_TILE) + { + BlockStoreKeysPong(storage.store_keys_pong) + .Store(keys_out_pong + tile_base, keys_local); + } + else + { + BlockStoreKeysPong(storage.store_keys_pong) + .Store(keys_out_pong + tile_base, keys_local, num_keys1 + num_keys2); + } + } + + // if items are provided, merge them + if (!KEYS_ONLY) + { + CTA_SYNC(); + + reg_to_shared(&storage.items_shared[0], items_local); + + CTA_SYNC(); + + // gather items from shared mem + // +#pragma unroll + for (int item = 0; item < 
ITEMS_PER_THREAD; ++item) + { + items_local[item] = storage.items_shared[indices[item]]; + } + + CTA_SYNC(); + + // write from reg to gmem + // + if (ping) + { + if (IS_FULL_TILE) + { + BlockStoreItemsPing(storage.store_items_ping) + .Store(items_out_ping + tile_base, items_local); + } + else + { + BlockStoreItemsPing(storage.store_items_ping) + .Store(items_out_ping + tile_base, items_local, count); + } + } + else + { + if (IS_FULL_TILE) + { + BlockStoreItemsPong(storage.store_items_pong) + .Store(items_out_pong + tile_base, items_local); + } + else + { + BlockStoreItemsPong(storage.store_items_pong) + .Store(items_out_pong + tile_base, items_local, count); + } + } + } + } + + __device__ __forceinline__ AgentMerge(bool ping_, + TempStorage &storage_, + KeysLoadPingIt keys_in_ping_, + ItemsLoadPingIt items_in_ping_, + KeysLoadPongIt keys_in_pong_, + ItemsLoadPongIt items_in_pong_, + OffsetT keys_count_, + KeysOutputPingIt keys_out_ping_, + ItemsOutputPingIt items_out_ping_, + KeysOutputPongIt keys_out_pong_, + ItemsOutputPongIt items_out_pong_, + CompareOpT compare_op_, + OffsetT *merge_partitions_, + OffsetT target_merged_tiles_number_) + : ping(ping_) + , storage(storage_.Alias()) + , keys_in_ping(keys_in_ping_) + , items_in_ping(items_in_ping_) + , keys_in_pong(keys_in_pong_) + , items_in_pong(items_in_pong_) + , keys_count(keys_count_) + , keys_out_pong(keys_out_pong_) + , items_out_pong(items_out_pong_) + , keys_out_ping(keys_out_ping_) + , items_out_ping(items_out_ping_) + , compare_op(compare_op_) + , merge_partitions(merge_partitions_) + , target_merged_tiles_number(target_merged_tiles_number_) + {} + + __device__ __forceinline__ void Process() + { + int tile_idx = static_cast(blockIdx.x); + int num_tiles = static_cast(gridDim.x); + OffsetT tile_base = OffsetT(tile_idx) * ITEMS_PER_TILE; + int tid = static_cast(threadIdx.x); + int items_in_tile = static_cast( + (cub::min)(static_cast(ITEMS_PER_TILE), keys_count - tile_base)); + + if (tile_idx < num_tiles - 1) + { + consume_tile(tid, tile_idx, tile_base, ITEMS_PER_TILE); + } + else + { + consume_tile(tid, tile_idx, tile_base, items_in_tile); + } + } +}; + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/cub/block/block_load.cuh b/cub/block/block_load.cuh index d689954e0b..96cf32e494 100644 --- a/cub/block/block_load.cuh +++ b/cub/block/block_load.cuh @@ -34,6 +34,7 @@ #pragma once #include +#include #include "block_exchange.cuh" #include "../iterator/cache_modified_input_iterator.cuh" @@ -1288,6 +1289,17 @@ public: }; +template ::value_type> +struct BlockLoadType +{ + using type = cub::BlockLoad; +}; + } // CUB namespace CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/cub/block/block_merge_sort.cuh b/cub/block/block_merge_sort.cuh new file mode 100644 index 0000000000..f9bf979fee --- /dev/null +++ b/cub/block/block_merge_sort.cuh @@ -0,0 +1,582 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include "../util_ptx.cuh" +#include "../util_type.cuh" +#include "../util_math.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +// Implementation of the MergePath algorithm, as described in: +// Odeh et al, "Merge Path - Parallel Merging Made Simple" +// doi:10.1109/IPDPSW.2012.202 +template +__device__ __forceinline__ OffsetT MergePath(KeyIteratorT keys1, + KeyIteratorT keys2, + OffsetT keys1_count, + OffsetT keys2_count, + OffsetT diag, + BinaryPred binary_pred) +{ + OffsetT keys1_begin = diag < keys2_count ? 0 : diag - keys2_count; + OffsetT keys1_end = (cub::min)(diag, keys1_count); + + while (keys1_begin < keys1_end) + { + OffsetT mid = cub::MidPoint(keys1_begin, keys1_end); + KeyT key1 = keys1[mid]; + KeyT key2 = keys2[diag - 1 - mid]; + bool pred = binary_pred(key2, key1); + + if (pred) + { + keys1_end = mid; + } + else + { + keys1_begin = mid + 1; + } + } + return keys1_begin; +} + +template +__device__ __forceinline__ void SerialMerge(KeyT *keys_shared, + int keys1_beg, + int keys2_beg, + int keys1_count, + int keys2_count, + KeyT (&output)[ITEMS_PER_THREAD], + int (&indices)[ITEMS_PER_THREAD], + CompareOp compare_op) +{ + int keys1_end = keys1_beg + keys1_count; + int keys2_end = keys2_beg + keys2_count; + + KeyT key1 = keys_shared[keys1_beg]; + KeyT key2 = keys_shared[keys2_beg]; + +#pragma unroll + for (int item = 0; item < ITEMS_PER_THREAD; ++item) + { + bool p = (keys2_beg < keys2_end) && + ((keys1_beg >= keys1_end) + || compare_op(key2, key1)); + + output[item] = p ? key2 : key1; + indices[item] = p ? keys2_beg++ : keys1_beg++; + + if (p) + { + key2 = keys_shared[keys2_beg]; + } + else + { + key1 = keys_shared[keys1_beg]; + } + } +} + +/** + * \brief The BlockMergeSort class provides methods for sorting items partitioned across a CUDA thread block using a merge sorting method. 
+ * \ingroup BlockModule + * + * \tparam KeyT KeyT type + * \tparam BLOCK_DIM_X The thread block length in threads along the X dimension + * \tparam ITEMS_PER_THREAD The number of items per thread + * \tparam ValueT [optional] ValueT type (default: cub::NullType, which indicates a keys-only sort) + * \tparam BLOCK_DIM_Y [optional] The thread block length in threads along the Y dimension (default: 1) + * \tparam BLOCK_DIM_Z [optional] The thread block length in threads along the Z dimension (default: 1) + * + * \par Overview + * BlockMergeSort arranges items into ascending order using a comparison + * functor with less-than semantics. Merge sort can handle arbitrary types + * and comparison functors, but is slower than BlockRadixSort when sorting + * arithmetic types into ascending/descending order. + * + * \par A Simple Example + * \blockcollective{BlockMergeSort} + * \par + * The code snippet below illustrates a sort of 512 integer keys that are + * partitioned across 128 threads * where each thread owns 4 consecutive items. + * \par + * \code + * #include // or equivalently + * + * struct CustomLess + * { + * template + * __device__ bool operator()(const DataType &lhs, const DataType &rhs) + * { + * return lhs < rhs; + * } + * }; + * + * __global__ void ExampleKernel(...) + * { + * // Specialize BlockMergeSort for a 1D block of 128 threads owning 4 integer items each + * typedef cub::BlockMergeSort BlockMergeSort; + * + * // Allocate shared memory for BlockMergeSort + * __shared__ typename BlockMergeSort::TempStorage temp_storage_shuffle; + * + * // Obtain a segment of consecutive items that are blocked across threads + * int thread_keys[4]; + * ... + * + * BlockMergeSort(temp_storage_shuffle).Sort(thread_data, CustomLess()); + * ... + * } + * \endcode + * \par + * Suppose the set of input \p thread_keys across the block of threads is + * { [0,511,1,510], [2,509,3,508], [4,507,5,506], ..., [254,257,255,256] }. + * The corresponding output \p thread_keys in those threads will be + * { [0,1,2,3], [4,5,6,7], [8,9,10,11], ..., [508,509,510,511] }. 
+ * + */ +template < + typename KeyT, + int BLOCK_DIM_X, + int ITEMS_PER_THREAD, + typename ValueT = NullType, + int BLOCK_DIM_Y = 1, + int BLOCK_DIM_Z = 1> +class BlockMergeSort +{ +private: + + // The thread block size in threads + static constexpr int BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z; + static constexpr int ITEMS_PER_TILE = ITEMS_PER_THREAD * BLOCK_THREADS; + + // Whether or not there are values to be trucked along with keys + static constexpr bool KEYS_ONLY = Equals::VALUE; + + /// Shared memory type required by this thread block + union _TempStorage + { + KeyT keys_shared[ITEMS_PER_TILE + 1]; + ValueT items_shared[ITEMS_PER_TILE + 1]; + }; // union TempStorage + + /// Internal storage allocator + __device__ __forceinline__ _TempStorage& PrivateStorage() + { + __shared__ _TempStorage private_storage; + return private_storage; + } + + /// Shared storage reference + _TempStorage &temp_storage; + + /// Linear thread-id + unsigned int linear_tid; + +public: + + /// \smemstorage{BlockMergeSort} + struct TempStorage : Uninitialized<_TempStorage> {}; + + __device__ __forceinline__ BlockMergeSort() + : temp_storage(PrivateStorage()) + , linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + + __device__ __forceinline__ BlockMergeSort(TempStorage &temp_storage) + : temp_storage(temp_storage.Alias()) + , linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z)) + {} + +private: + + template + __device__ __forceinline__ void Swap(T &lhs, T &rhs) + { + T temp = lhs; + lhs = rhs; + rhs = temp; + } + + template + __device__ __forceinline__ void + StableOddEvenSort(KeyT (&keys)[ITEMS_PER_THREAD], + ValueT (&items)[ITEMS_PER_THREAD], + CompareOp compare_op) + { +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; ++i) + { +#pragma unroll + for (int j = 1 & i; j < ITEMS_PER_THREAD - 1; j += 2) + { + if (compare_op(keys[j + 1], keys[j])) + { + Swap(keys[j], keys[j + 1]); + if (!KEYS_ONLY) + { + Swap(items[j], items[j + 1]); + } + } + } // inner loop + } // outer loop + } + +public: + + /** + * \brief Sorts items partitioned across a CUDA thread block using a merge sorting method. + * + * \par + * - Sort is not guaranteed to be stable. That is, suppose that i and j are + * equivalent: neither one is less than the other. It is not guaranteed + * that the relative order of these two elements will be preserved by sort. + * + * \tparam CompareOp functor type having member bool operator()(KeyT lhs, KeyT rhs) + * CompareOp is a model of Strict Weak Ordering. + */ + template + __device__ __forceinline__ void + Sort(KeyT (&keys)[ITEMS_PER_THREAD], ///< [in-out] Keys to sort + CompareOp compare_op) ///< [in] Comparison function object which returns + ///< true if the first argument is ordered before + ///< the second + { + ValueT items[ITEMS_PER_THREAD]; + Sort(keys, items, compare_op, ITEMS_PER_TILE, keys[0]); + } + + /** + * \brief Sorts items partitioned across a CUDA thread block using a merge sorting method. + * + * \par + * - Sort is not guaranteed to be stable. That is, suppose that i and j are + * equivalent: neither one is less than the other. It is not guaranteed + * that the relative order of these two elements will be preserved by sort. + * - The value of \p oob_default is assigned to all elements that are out of + * \p valid_items boundaries. It's expected that \p oob_default is ordered + * after any value in the \p valid_items boundaries. The algorithm always + * sorts a fixed amount of elements, which is equal to ITEMS_PER_THREAD * BLOCK_THREADS. 
+ * If there is a value that is ordered after \p oob_default, it won't be + * placed within \p valid_items boundaries. + * + * \tparam CompareOp functor type having member bool operator()(KeyT lhs, KeyT rhs) + * CompareOp is a model of Strict Weak Ordering. + */ + template + __device__ __forceinline__ void + Sort(KeyT (&keys)[ITEMS_PER_THREAD], ///< [in-out] Keys to sort + CompareOp compare_op, ///< [in] Comparison function object which returns true if the first argument is ordered before the second + int valid_items, ///< [in] Number of valid items to sort + KeyT oob_default) ///< [in] Default value to assign out-of-bound items + { + ValueT items[ITEMS_PER_THREAD]; + Sort(keys, items, compare_op, valid_items, oob_default); + } + + /** + * \brief Sorts items partitioned across a CUDA thread block using a merge sorting method. + * + * \par + * - Sort is not guaranteed to be stable. That is, suppose that i and j are + * equivalent: neither one is less than the other. It is not guaranteed + * that the relative order of these two elements will be preserved by sort. + * + * \tparam CompareOp functor type having member bool operator()(KeyT lhs, KeyT rhs) + * CompareOp is a model of Strict Weak Ordering. + */ + template + __device__ __forceinline__ void + Sort(KeyT (&keys)[ITEMS_PER_THREAD], ///< [in-out] Keys to sort + ValueT (&items)[ITEMS_PER_THREAD], ///< [in-out] Values to sort + CompareOp compare_op) ///< [in] Comparison function object which returns true if the first argument is ordered before the second + { + Sort(keys, items, compare_op, ITEMS_PER_TILE, keys[0]); + } + + /** + * \brief Sorts items partitioned across a CUDA thread block using a merge sorting method. + * + * \par + * - Sort is not guaranteed to be stable. That is, suppose that i and j are + * equivalent: neither one is less than the other. It is not guaranteed + * that the relative order of these two elements will be preserved by sort. + * - The value of \p oob_default is assigned to all elements that are out of + * \p valid_items boundaries. It's expected that \p oob_default is ordered + * after any value in the \p valid_items boundaries. The algorithm always + * sorts a fixed amount of elements, which is equal to ITEMS_PER_THREAD * BLOCK_THREADS. + * If there is a value that is ordered after \p oob_default, it won't be + * placed within \p valid_items boundaries. + * + * \tparam CompareOp functor type having member bool operator()(KeyT lhs, KeyT rhs) + * CompareOp is a model of Strict Weak Ordering. + * \tparam IS_LAST_TILE True if valid_items isn't equal to the ITEMS_PER_TILE + */ + template + __device__ __forceinline__ void + Sort(KeyT (&keys)[ITEMS_PER_THREAD], ///< [in-out] Keys to sort + ValueT (&items)[ITEMS_PER_THREAD], ///< [in-out] Values to sort + CompareOp compare_op, ///< [in] Comparison function object which returns true if the first argument is ordered before the second + int valid_items, ///< [in] Number of valid items to sort + KeyT oob_default) ///< [in] Default value to assign out-of-bound items + { + if (IS_LAST_TILE) + { + // if last tile, find valid max_key + // and fill the remaining keys with it + // + KeyT max_key = oob_default; +#pragma unroll + for (int item = 1; item < ITEMS_PER_THREAD; ++item) + { + if (ITEMS_PER_THREAD * linear_tid + item < valid_items) + { + max_key = compare_op(max_key, keys[item]) ? 
keys[item] : max_key; + } + else + { + keys[item] = max_key; + } + } + } + + // if first element of thread is in input range, stable sort items + // + if (!IS_LAST_TILE || ITEMS_PER_THREAD * linear_tid < valid_items) + { + StableOddEvenSort(keys, items, compare_op); + } + + // each thread has sorted keys + // merge sort keys in shared memory + // +#pragma unroll + for (int target_merged_threads_number = 2; + target_merged_threads_number <= BLOCK_THREADS; + target_merged_threads_number *= 2) + { + int merged_threads_number = target_merged_threads_number / 2; + int mask = target_merged_threads_number - 1; + + CTA_SYNC(); + + // store keys in shmem + // +#pragma unroll + for (int item = 0; item < ITEMS_PER_THREAD; ++item) + { + int idx = ITEMS_PER_THREAD * linear_tid + item; + temp_storage.keys_shared[idx] = keys[item]; + } + + CTA_SYNC(); + + int indices[ITEMS_PER_THREAD]; + + int first_thread_idx_in_thread_group_being_merged = ~mask & linear_tid; + int start = ITEMS_PER_THREAD * first_thread_idx_in_thread_group_being_merged; + int size = ITEMS_PER_THREAD * merged_threads_number; + + int thread_idx_in_thread_group_being_merged = mask & linear_tid; + + int diag = + (cub::min)(valid_items, + ITEMS_PER_THREAD * thread_idx_in_thread_group_being_merged); + + int keys1_beg = (cub::min)(valid_items, start); + int keys1_end = (cub::min)(valid_items, keys1_beg + size); + int keys2_beg = keys1_end; + int keys2_end = (cub::min)(valid_items, keys2_beg + size); + + int keys1_count = keys1_end - keys1_beg; + int keys2_count = keys2_end - keys2_beg; + + int partition_diag = MergePath(&temp_storage.keys_shared[keys1_beg], + &temp_storage.keys_shared[keys2_beg], + keys1_count, + keys2_count, + diag, + compare_op); + + int keys1_beg_loc = keys1_beg + partition_diag; + int keys1_end_loc = keys1_end; + int keys2_beg_loc = keys2_beg + diag - partition_diag; + int keys2_end_loc = keys2_end; + int keys1_count_loc = keys1_end_loc - keys1_beg_loc; + int keys2_count_loc = keys2_end_loc - keys2_beg_loc; + SerialMerge(&temp_storage.keys_shared[0], + keys1_beg_loc, + keys2_beg_loc, + keys1_count_loc, + keys2_count_loc, + keys, + indices, + compare_op); + + if (!KEYS_ONLY) + { + CTA_SYNC(); + + // store keys in shmem + // +#pragma unroll + for (int item = 0; item < ITEMS_PER_THREAD; ++item) + { + int idx = ITEMS_PER_THREAD * linear_tid + item; + temp_storage.items_shared[idx] = items[item]; + } + + CTA_SYNC(); + + // gather items from shmem + // +#pragma unroll + for (int item = 0; item < ITEMS_PER_THREAD; ++item) + { + items[item] = temp_storage.items_shared[indices[item]]; + } + } + } + } // func block_merge_sort + + /** + * \brief Sorts items partitioned across a CUDA thread block using a merge sorting method. + * + * \par + * - StableSort is stable: it preserves the relative ordering of equivalent + * elements. That is, if x and y are elements such that x precedes y, + * and if the two elements are equivalent (neither x < y nor y < x) then + * a postcondition of stable_sort is that x still precedes y. + * + * \tparam CompareOp functor type having member bool operator()(KeyT lhs, KeyT rhs) + * CompareOp is a model of Strict Weak Ordering. + */ + template + __device__ __forceinline__ void + StableSort(KeyT (&keys)[ITEMS_PER_THREAD], ///< [in-out] Keys to sort + CompareOp compare_op) ///< [in] Comparison function object which returns true if the first argument is ordered before the second + { + Sort(keys, compare_op); + } + + /** + * \brief Sorts items partitioned across a CUDA thread block using a merge sorting method. 
+ * + * \par + * - StableSort is stable: it preserves the relative ordering of equivalent + * elements. That is, if x and y are elements such that x precedes y, + * and if the two elements are equivalent (neither x < y nor y < x) then + * a postcondition of stable_sort is that x still precedes y. + * + * \tparam CompareOp functor type having member bool operator()(KeyT lhs, KeyT rhs) + * CompareOp is a model of Strict Weak Ordering. + */ + template + __device__ __forceinline__ void + StableSort(KeyT (&keys)[ITEMS_PER_THREAD], ///< [in-out] Keys to sort + ValueT (&items)[ITEMS_PER_THREAD], ///< [in-out] Values to sort + CompareOp compare_op) ///< [in] Comparison function object which returns true if the first argument is ordered before the second + { + Sort(keys, items, compare_op); + } + + /** + * \brief Sorts items partitioned across a CUDA thread block using a merge sorting method. + * + * \par + * - StableSort is stable: it preserves the relative ordering of equivalent + * elements. That is, if x and y are elements such that x precedes y, + * and if the two elements are equivalent (neither x < y nor y < x) then + * a postcondition of stable_sort is that x still precedes y. + * - The value of \p oob_default is assigned to all elements that are out of + * \p valid_items boundaries. It's expected that \p oob_default is ordered + * after any value in the \p valid_items boundaries. The algorithm always + * sorts a fixed amount of elements, which is equal to ITEMS_PER_THREAD * BLOCK_THREADS. + * If there is a value that is ordered after \p oob_default, it won't be + * placed within \p valid_items boundaries. + * + * \tparam CompareOp functor type having member bool operator()(KeyT lhs, KeyT rhs) + * CompareOp is a model of Strict Weak Ordering. + */ + template + __device__ __forceinline__ void + StableSort(KeyT (&keys)[ITEMS_PER_THREAD], ///< [in-out] Keys to sort + CompareOp compare_op, ///< [in] Comparison function object which returns true if the first argument is ordered before the second + int valid_items, ///< [in] Number of valid items to sort + KeyT oob_default) ///< [in] Default value to assign out-of-bound items + { + Sort(keys, compare_op, valid_items, oob_default); + } + + /** + * \brief Sorts items partitioned across a CUDA thread block using a merge sorting method. + * + * \par + * - StableSort is stable: it preserves the relative ordering of equivalent + * elements. That is, if x and y are elements such that x precedes y, + * and if the two elements are equivalent (neither x < y nor y < x) then + * a postcondition of stable_sort is that x still precedes y. + * - The value of \p oob_default is assigned to all elements that are out of + * \p valid_items boundaries. It's expected that \p oob_default is ordered + * after any value in the \p valid_items boundaries. The algorithm always + * sorts a fixed amount of elements, which is equal to ITEMS_PER_THREAD * BLOCK_THREADS. + * If there is a value that is ordered after \p oob_default, it won't be + * placed within \p valid_items boundaries. + * + * \tparam CompareOp functor type having member bool operator()(KeyT lhs, KeyT rhs) + * CompareOp is a model of Strict Weak Ordering. 
+ * \tparam IS_LAST_TILE True if valid_items isn't equal to the ITEMS_PER_TILE + */ + template + __device__ __forceinline__ void + StableSort(KeyT (&keys)[ITEMS_PER_THREAD], ///< [in-out] Keys to sort + ValueT (&items)[ITEMS_PER_THREAD], ///< [in-out] Values to sort + CompareOp compare_op, ///< [in] Comparison function object which returns true if the first argument is ordered before the second + int valid_items, ///< [in] Number of valid items to sort + KeyT oob_default) ///< [in] Default value to assign out-of-bound items + { + Sort(keys, + items, + compare_op, + valid_items, + oob_default); + } +}; + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/cub/block/block_store.cuh b/cub/block/block_store.cuh index cb00ec7287..ddfd22f061 100644 --- a/cub/block/block_store.cuh +++ b/cub/block/block_store.cuh @@ -34,6 +34,7 @@ #pragma once #include +#include #include "block_exchange.cuh" #include "../config.cuh" @@ -1050,6 +1051,16 @@ public: } }; +template ::value_type> +struct BlockStoreType +{ + using type = cub::BlockStore; +}; } // CUB namespace CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/cub/cub.cuh b/cub/cub.cuh index a71d78fe0d..2d2f7b10c9 100644 --- a/cub/cub.cuh +++ b/cub/cub.cuh @@ -43,12 +43,14 @@ #include "block/block_load.cuh" #include "block/block_radix_rank.cuh" #include "block/block_radix_sort.cuh" +#include "block/block_merge_sort.cuh" #include "block/block_reduce.cuh" #include "block/block_scan.cuh" #include "block/block_store.cuh" //#include "block/block_shift.cuh" // Device +#include "device/device_merge_sort.cuh" #include "device/device_histogram.cuh" #include "device/device_partition.cuh" #include "device/device_radix_sort.cuh" diff --git a/cub/device/device_merge_sort.cuh b/cub/device/device_merge_sort.cuh new file mode 100644 index 0000000000..d27c03ea34 --- /dev/null +++ b/cub/device/device_merge_sort.cuh @@ -0,0 +1,602 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include "../config.cuh" +#include "dispatch/dispatch_merge_sort.cuh" + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +/** + * \brief DeviceMergeSort provides device-wide, parallel operations for computing a merge sort across a sequence of data items residing within device-accessible memory. + * \ingroup SingleModule + * + * \par Overview + * - DeviceMergeSort arranges items into ascending order using a comparison + * functor with less-than semantics. Merge sort can handle arbitrary types (as + * long as a value of these types is a model of + * LessThan Comparable) + * and comparison functors, but is slower than DeviceRadixSort when sorting + * arithmetic types into ascending/descending order. + * - Another difference from RadixSort is the fact that DeviceMergeSort can + * handle arbitrary random-access iterators, as shown below. + * + * \par A Simple Example + * \par + * The code snippet below illustrates a thrust reverse iterator usage. + * \par + * \code + * #include // or equivalently + * + * struct CustomLess + * { + * template + * __device__ bool operator()(const DataType &lhs, const DataType &rhs) + * { + * return lhs < rhs; + * } + * }; + * + * // Declare, allocate, and initialize device-accessible pointers for sorting data + * thrust::device_vector d_keys(num_items); + * thrust::device_vector d_values(num_items); + * // ... + * + * // Initialize iterator + * using KeyIterator = typename thrust::device_vector::iterator; + * thrust::reverse_iterator reverse_iter(d_keys.end()); + * + * // Determine temporary device storage requirements + * std::size_t temp_storage_bytes = 0; + * cub::DeviceMergeSort::SortPairs( + * nullptr, + * temp_storage_bytes, + * reverse_iter, + * thrust::raw_pointer_cast(d_values.data()), + * num_items, + * CustomLess()); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run sorting operation + * cub::DeviceMergeSort::SortPairs( + * d_temp_storage, + * temp_storage_bytes, + * reverse_iter, + * thrust::raw_pointer_cast(d_values.data()), + * num_items, + * CustomLess()); + * \endcode + */ +struct DeviceMergeSort +{ + + /** + * \brief Sorts items using a merge sorting method. + * + * \par + * SortPairs is not guaranteed to be stable. That is, suppose that i and j are + * equivalent: neither one is less than the other. It is not guaranteed + * that the relative order of these two elements will be preserved by sort. + * + * \par Snippet + * The code snippet below illustrates the sorting of a device vector of \p int keys + * with associated vector of \p int values. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for sorting data + * int num_items; // e.g., 7 + * int *d_keys; // e.g., [8, 6, 6, 5, 3, 0, 9] + * int *d_values; // e.g., [0, 1, 2, 3, 4, 5, 6] + * ... 
+ * + * // Initialize comparator + * CustomOpT custom_op; + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceMergeSort::SortPairs(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items, custom_op); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run sorting operation + * cub::DeviceMergeSort::SortPairs(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items, custom_op); + * + * // d_keys <-- [0, 3, 5, 6, 6, 8, 9] + * // d_values <-- [5, 4, 3, 2, 1, 0, 6] + * + * \endcode + * + * \tparam KeyIteratorT is a model of Random Access Iterator, + * \p KeyIteratorT is mutable, and \p KeyIteratorT's \c value_type is + * a model of LessThan Comparable, + * and the ordering relation on \p KeyIteratorT's \c value_type is a strict weak ordering, as defined in the + * LessThan Comparable requirements. + * \tparam ValueIteratorT is a model of Random Access Iterator, + * \p ValueIteratorT is mutable, and \p ValueIteratorT's \c value_type is + * a model of LessThan Comparable, + * and the ordering relation on \p ValueIteratorT's \c value_type is a strict weak ordering, as defined in the + * LessThan Comparable requirements. + * \tparam OffsetT is an integer type for global offsets. + * \tparam CompareOpT functor type having member bool operator()(KeyT lhs, KeyT rhs) + * CompareOpT is a model of Strict Weak Ordering. + */ + template + CUB_RUNTIME_FUNCTION static cudaError_t SortPairs(void *d_temp_storage, ///< [in] Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + std::size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + KeyIteratorT d_keys, ///< [in,out] Pointer to the input sequence of unsorted input keys + ValueIteratorT d_items, ///< [in,out] Pointer to the input sequence of unsorted input values + OffsetT num_items, ///< [in] Number of items to sort + CompareOpT compare_op, ///< [in] Comparison function object which returns true if the first argument is ordered before the second + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + using DispatchMergeSortT = DispatchMergeSort; + + return DispatchMergeSortT::Dispatch(d_temp_storage, + temp_storage_bytes, + d_keys, + d_items, + d_keys, + d_items, + num_items, + compare_op, + stream, + debug_synchronous); + } + + /** + * \brief Sorts items using a merge sorting method. + * + * \par + * - SortPairs is not guaranteed to be stable. That is, suppose that i and j are + * equivalent: neither one is less than the other. It is not guaranteed + * that the relative order of these two elements will be preserved by sort. + * - Input arrays d_input_keys and d_input_items are not modified. + * + * \par Snippet + * The code snippet below illustrates the sorting of a device vector of \p int keys + * with associated vector of \p int values. 
+ * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for sorting data + * int num_items; // e.g., 7 + * int *d_keys; // e.g., [8, 6, 6, 5, 3, 0, 9] + * int *d_values; // e.g., [0, 1, 2, 3, 4, 5, 6] + * ... + * + * // Initialize comparator + * CustomOpT custom_op; + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceMergeSort::SortPairsCopy(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items, custom_op); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run sorting operation + * cub::DeviceMergeSort::SortPairsCopy(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items, custom_op); + * + * // d_keys <-- [0, 3, 5, 6, 6, 8, 9] + * // d_values <-- [5, 4, 3, 2, 1, 0, 6] + * + * \endcode + * + * \tparam KeyInputIteratorT is a model of Random Access Iterator, + * \p KeyInputIteratorT is mutable, and \p KeyInputIteratorT's \c value_type is + * a model of LessThan Comparable, + * and the ordering relation on \p KeyInputIteratorT's \c value_type is a strict weak ordering, as defined in the + * LessThan Comparable requirements. + * \tparam ValueInputIteratorT is a model of Random Access Iterator, + * \p ValueInputIteratorT is mutable, and \p ValueInputIteratorT's \c value_type is + * a model of LessThan Comparable, + * and the ordering relation on \p ValueInputIteratorT's \c value_type is a strict weak ordering, as defined in the + * LessThan Comparable requirements. + * \tparam KeyIteratorT is a model of Random Access Iterator, + * \p KeyIteratorT is mutable, and \p KeyIteratorT's \c value_type is + * a model of LessThan Comparable, + * and the ordering relation on \p KeyIteratorT's \c value_type is a strict weak ordering, as defined in the + * LessThan Comparable requirements. + * \tparam ValueIteratorT is a model of Random Access Iterator, + * \p ValueIteratorT is mutable, and \p ValueIteratorT's \c value_type is + * a model of LessThan Comparable, + * and the ordering relation on \p ValueIteratorT's \c value_type is a strict weak ordering, as defined in the + * LessThan Comparable requirements. + * \tparam OffsetT is an integer type for global offsets. + * \tparam CompareOpT functor type having member bool operator()(KeyT lhs, KeyT rhs) + * CompareOpT is a model of Strict Weak Ordering. + */ + template + CUB_RUNTIME_FUNCTION static cudaError_t SortPairsCopy(void *d_temp_storage, ///< [in] Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + std::size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + KeyInputIteratorT d_input_keys, ///< [in] Pointer to the input sequence of unsorted input keys + ValueInputIteratorT d_input_items, ///< [in] Pointer to the input sequence of unsorted input values + KeyIteratorT d_output_keys, ///< [out] Pointer to the output sequence of sorted input keys + ValueIteratorT d_output_items, ///< [out] Pointer to the output sequence of sorted input values + OffsetT num_items, ///< [in] Number of items to sort + CompareOpT compare_op, ///< [in] Comparison function object which returns true if the first argument is ordered before the second + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. 
+ bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + using DispatchMergeSortT = DispatchMergeSort; + + return DispatchMergeSortT::Dispatch(d_temp_storage, + temp_storage_bytes, + d_input_keys, + d_input_items, + d_output_keys, + d_output_items, + num_items, + compare_op, + stream, + debug_synchronous); + } + + /** + * \brief Sorts items using a merge sorting method. + * + * \par + * - SortKeys is not guaranteed to be stable. That is, suppose that i and j are + * equivalent: neither one is less than the other. It is not guaranteed + * that the relative order of these two elements will be preserved by sort. + * + * \par Snippet + * The code snippet below illustrates the sorting of a device vector of \p int keys. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for sorting data + * int num_items; // e.g., 7 + * int *d_keys; // e.g., [8, 6, 7, 5, 3, 0, 9] + * ... + * + * // Initialize comparator + * CustomOpT custom_op; + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceMergeSort::SortKeys(d_temp_storage, temp_storage_bytes, d_keys, num_items, custom_op); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run sorting operation + * cub::DeviceMergeSort::SortKeys(d_temp_storage, temp_storage_bytes, d_keys, num_items, custom_op); + * + * // d_keys <-- [0, 3, 5, 6, 7, 8, 9] + * + * \endcode + * + * \tparam KeyIteratorT is a model of Random Access Iterator, + * \p KeyIteratorT is mutable, and \p KeyIteratorT's \c value_type is + * a model of LessThan Comparable, + * and the ordering relation on \p KeyIteratorT's \c value_type is a strict weak ordering, as defined in the + * LessThan Comparable requirements. + * \tparam OffsetT is an integer type for global offsets. + * \tparam CompareOpT functor type having member bool operator()(KeyT lhs, KeyT rhs) + * CompareOpT is a model of Strict Weak Ordering. + */ + template + CUB_RUNTIME_FUNCTION static cudaError_t SortKeys(void *d_temp_storage, ///< [in] Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + std::size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + KeyIteratorT d_keys, ///< [in,out] Pointer to the input sequence of unsorted input keys + OffsetT num_items, ///< [in] Number of items to sort + CompareOpT compare_op, ///< [in] Comparison function object which returns true if the first argument is ordered before the second + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + using DispatchMergeSortT = DispatchMergeSort; + + return DispatchMergeSortT::Dispatch(d_temp_storage, + temp_storage_bytes, + d_keys, + static_cast(nullptr), + d_keys, + static_cast(nullptr), + num_items, + compare_op, + stream, + debug_synchronous); + } + + /** + * \brief Sorts items using a merge sorting method. 
+ * + * \par + * - SortKeys is not guaranteed to be stable. That is, suppose that i and j are + * equivalent: neither one is less than the other. It is not guaranteed + * that the relative order of these two elements will be preserved by sort. + * - Input array d_input_keys is not modified. + * + * \par Snippet + * The code snippet below illustrates the sorting of a device vector of \p int keys. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for sorting data + * int num_items; // e.g., 7 + * int *d_keys; // e.g., [8, 6, 7, 5, 3, 0, 9] + * ... + * + * // Initialize comparator + * CustomOpT custom_op; + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceMergeSort::SortKeysCopy(d_temp_storage, temp_storage_bytes, d_keys, num_items, custom_op); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run sorting operation + * cub::DeviceMergeSort::SortKeysCopy(d_temp_storage, temp_storage_bytes, d_keys, num_items, custom_op); + * + * // d_keys <-- [0, 3, 5, 6, 7, 8, 9] + * + * \endcode + * + * \tparam KeyInputIteratorT is a model of Random Access Iterator, + * \p KeyInputIteratorT is mutable, and \p KeyInputIteratorT's \c value_type is + * a model of LessThan Comparable, + * and the ordering relation on \p KeyInputIteratorT's \c value_type is a strict weak ordering, as defined in the + * LessThan Comparable requirements. + * \tparam KeyIteratorT is a model of Random Access Iterator, + * \p KeyIteratorT is mutable, and \p KeyIteratorT's \c value_type is + * a model of LessThan Comparable, + * and the ordering relation on \p KeyIteratorT's \c value_type is a strict weak ordering, as defined in the + * LessThan Comparable requirements. + * \tparam OffsetT is an integer type for global offsets. + * \tparam CompareOpT functor type having member bool operator()(KeyT lhs, KeyT rhs) + * CompareOpT is a model of Strict Weak Ordering. + */ + template + CUB_RUNTIME_FUNCTION static cudaError_t SortKeysCopy(void *d_temp_storage, ///< [in] Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + std::size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + KeyInputIteratorT d_input_keys, ///< [in] Pointer to the input sequence of unsorted input keys + KeyIteratorT d_output_keys, ///< [out] Pointer to the output sequence of sorted input keys + OffsetT num_items, ///< [in] Number of items to sort + CompareOpT compare_op, ///< [in] Comparison function object which returns true if the first argument is ordered before the second + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + using DispatchMergeSortT = DispatchMergeSort; + + return DispatchMergeSortT::Dispatch(d_temp_storage, + temp_storage_bytes, + d_input_keys, + static_cast(nullptr), + d_output_keys, + static_cast(nullptr), + num_items, + compare_op, + stream, + debug_synchronous); + } + + /** + * \brief Sorts items using a merge sorting method. 
+ * + * \par + * StableSortPairs is stable: it preserves the relative ordering of equivalent + * elements. That is, if x and y are elements such that x precedes y, + * and if the two elements are equivalent (neither x < y nor y < x) then + * a postcondition of stable_sort is that x still precedes y. + * + * \par Snippet + * The code snippet below illustrates the sorting of a device vector of \p int keys + * with associated vector of \p int values. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for sorting data + * int num_items; // e.g., 7 + * int *d_keys; // e.g., [8, 6, 6, 5, 3, 0, 9] + * int *d_values; // e.g., [0, 1, 2, 3, 4, 5, 6] + * ... + * + * // Initialize comparator + * CustomOpT custom_op; + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceMergeSort::StableSortPairs(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items, custom_op); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run sorting operation + * cub::DeviceMergeSort::StableSortPairs(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items, custom_op); + * + * // d_keys <-- [0, 3, 5, 6, 6, 8, 9] + * // d_values <-- [5, 4, 3, 1, 2, 0, 6] + * + * \endcode + * + * \tparam KeyIteratorT is a model of Random Access Iterator, + * \p KeyIteratorT is mutable, and \p KeyIteratorT's \c value_type is + * a model of LessThan Comparable, + * and the ordering relation on \p KeyIteratorT's \c value_type is a strict weak ordering, as defined in the + * LessThan Comparable requirements. + * \tparam ValueIteratorT is a model of Random Access Iterator, + * \p ValueIteratorT is mutable, and \p ValueIteratorT's \c value_type is + * a model of LessThan Comparable, + * and the ordering relation on \p ValueIteratorT's \c value_type is a strict weak ordering, as defined in the + * LessThan Comparable requirements. + * \tparam OffsetT is an integer type for global offsets. + * \tparam CompareOpT functor type having member bool operator()(KeyT lhs, KeyT rhs) + * CompareOpT is a model of Strict Weak Ordering. + */ + template + CUB_RUNTIME_FUNCTION static cudaError_t + StableSortPairs(void *d_temp_storage, ///< [in] Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + std::size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + KeyIteratorT d_keys, ///< [in,out] Pointer to the input sequence of unsorted input keys + ValueIteratorT d_items, ///< [in,out] Pointer to the input sequence of unsorted input values + OffsetT num_items, ///< [in] Number of items to sort + CompareOpT compare_op, ///< [in] Comparison function object which returns true if the first argument is ordered before the second + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + return SortPairs( + d_temp_storage, + temp_storage_bytes, + d_keys, + d_items, + num_items, + compare_op, + stream, + debug_synchronous); + } + + /** + * \brief Sorts items using a merge sorting method. 
+ * + * \par + * StableSortKeys is stable: it preserves the relative ordering of equivalent + * elements. That is, if x and y are elements such that x precedes y, + * and if the two elements are equivalent (neither x < y nor y < x) then + * a postcondition of stable_sort is that x still precedes y. + * + * \par Snippet + * The code snippet below illustrates the sorting of a device vector of \p int keys. + * \par + * \code + * #include // or equivalently + * + * // Declare, allocate, and initialize device-accessible pointers for sorting data + * int num_items; // e.g., 7 + * int *d_keys; // e.g., [8, 6, 7, 5, 3, 0, 9] + * ... + * + * // Initialize comparator + * CustomOpT custom_op; + * + * // Determine temporary device storage requirements + * void *d_temp_storage = NULL; + * size_t temp_storage_bytes = 0; + * cub::DeviceMergeSort::StableSortKeys(d_temp_storage, temp_storage_bytes, d_keys, num_items, custom_op); + * + * // Allocate temporary storage + * cudaMalloc(&d_temp_storage, temp_storage_bytes); + * + * // Run sorting operation + * cub::DeviceMergeSort::StableSortKeys(d_temp_storage, temp_storage_bytes, d_keys, num_items, custom_op); + * + * // d_keys <-- [0, 3, 5, 6, 7, 8, 9] + * + * \endcode + * + * \tparam KeyIteratorT is a model of Random Access Iterator, + * \p KeyIteratorT is mutable, and \p KeyIteratorT's \c value_type is + * a model of LessThan Comparable, + * and the ordering relation on \p KeyIteratorT's \c value_type is a strict weak ordering, as defined in the + * LessThan Comparable requirements. + * \tparam OffsetT is an integer type for global offsets. + * \tparam CompareOpT functor type having member bool operator()(KeyT lhs, KeyT rhs) + * CompareOpT is a model of Strict Weak Ordering. + */ + template + CUB_RUNTIME_FUNCTION static cudaError_t + StableSortKeys(void *d_temp_storage, ///< [in] Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + std::size_t &temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + KeyIteratorT d_keys, ///< [in,out] Pointer to the input sequence of unsorted input keys + OffsetT num_items, ///< [in] Number of items to sort + CompareOpT compare_op, ///< [in] Comparison function object which returns true if the first argument is ordered before the second + cudaStream_t stream = 0, ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous = false) ///< [in] [optional] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + { + return SortKeys(d_temp_storage, + temp_storage_bytes, + d_keys, + num_items, + compare_op, + stream, + debug_synchronous); + } +}; + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) + diff --git a/cub/device/dispatch/dispatch_merge_sort.cuh b/cub/device/dispatch/dispatch_merge_sort.cuh new file mode 100644 index 0000000000..2cd727950c --- /dev/null +++ b/cub/device/dispatch/dispatch_merge_sort.cuh @@ -0,0 +1,790 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include "../../util_math.cuh" +#include "../../util_device.cuh" +#include "../../agent/agent_merge_sort.cuh" + +#include +#include + +/// Optional outer namespace(s) +CUB_NS_PREFIX + +/// CUB namespace +namespace cub { + +template +void __global__ __launch_bounds__(ActivePolicyT::BLOCK_THREADS) +DeviceMergeSortBlockSortKernel(bool ping, + KeyInputIteratorT keys_in, + ValueInputIteratorT items_in, + KeyIteratorT keys_out, + ValueIteratorT items_out, + OffsetT keys_count, + KeyT *tmp_keys_out, + ValueT *tmp_items_out, + CompareOpT compare_op, + char *vshmem) +{ + extern __shared__ char shmem[]; + + using AgentBlockSortT = AgentBlockSort; + + const OffsetT vshmem_offset = blockIdx.x * + AgentBlockSortT::SHARED_MEMORY_SIZE; + + typename AgentBlockSortT::TempStorage &storage = + *reinterpret_cast( + UseVShmem ? 
vshmem + vshmem_offset : shmem); + + AgentBlockSortT agent(ping, + storage, + thrust::cuda_cub::core::make_load_iterator(ActivePolicyT(), keys_in), + thrust::cuda_cub::core::make_load_iterator(ActivePolicyT(), items_in), + keys_count, + keys_out, + items_out, + tmp_keys_out, + tmp_items_out, + compare_op); + + agent.Process(); +} + +template +__global__ void DeviceMergeSortPartitionKernel(bool ping, + KeyIteratorT keys_ping, + KeyT *keys_pong, + OffsetT keys_count, + OffsetT num_partitions, + OffsetT *merge_partitions, + CompareOpT compare_op, + OffsetT target_merged_tiles_number, + int items_per_tile) +{ + OffsetT partition_idx = blockDim.x * blockIdx.x + threadIdx.x; + + if (partition_idx < num_partitions) + { + AgentPartition agent( + ping, + keys_ping, + keys_pong, + keys_count, + partition_idx, + merge_partitions, + compare_op, + target_merged_tiles_number, + items_per_tile); + + agent.Process(); + } +} + +template < + bool UseVShmem, + typename ActivePolicyT, + typename KeyIteratorT, + typename ValueIteratorT, + typename OffsetT, + typename CompareOpT, + typename KeyT, + typename ValueT> +void __global__ __launch_bounds__(ActivePolicyT::BLOCK_THREADS) +DeviceMergeSortMergeKernel(bool ping, + KeyIteratorT keys_ping, + ValueIteratorT items_ping, + OffsetT keys_count, + KeyT *keys_pong, + ValueT *items_pong, + CompareOpT compare_op, + OffsetT *merge_partitions, + OffsetT target_merged_tiles_number, + char *vshmem + ) +{ + extern __shared__ char shmem[]; + + using AgentMergeT = AgentMerge; + + const OffsetT vshmem_offset = blockIdx.x * AgentMergeT::SHARED_MEMORY_SIZE; + + typename AgentMergeT::TempStorage &storage = + *reinterpret_cast( + UseVShmem ? vshmem + vshmem_offset : shmem); + + AgentMergeT agent( + ping, + storage, + thrust::cuda_cub::core::make_load_iterator(ActivePolicyT(), keys_ping), + thrust::cuda_cub::core::make_load_iterator(ActivePolicyT(), items_ping), + thrust::cuda_cub::core::make_load_iterator(ActivePolicyT(), keys_pong), + thrust::cuda_cub::core::make_load_iterator(ActivePolicyT(), items_pong), + keys_count, + keys_pong, + items_pong, + keys_ping, + items_ping, + compare_op, + merge_partitions, + target_merged_tiles_number); + + agent.Process(); +} + +/****************************************************************************** + * Policy + ******************************************************************************/ + +template +struct DeviceMergeSortPolicy +{ + using KeyT = typename std::iterator_traits::value_type; + + //------------------------------------------------------------------------------ + // Architecture-specific tuning policies + //------------------------------------------------------------------------------ + + struct Policy300 : ChainedPolicy<300, Policy300, Policy300> + { + using MergeSortPolicy = + AgentMergeSortPolicy<128, + Nominal4BItemsToItems(7), + cub::BLOCK_LOAD_WARP_TRANSPOSE, + cub::LOAD_DEFAULT, + cub::BLOCK_STORE_WARP_TRANSPOSE>; + }; + + struct Policy350 : ChainedPolicy<350, Policy350, Policy300> + { + using MergeSortPolicy = + AgentMergeSortPolicy<256, + Nominal4BItemsToItems(11), + cub::BLOCK_LOAD_WARP_TRANSPOSE, + cub::LOAD_LDG, + cub::BLOCK_STORE_WARP_TRANSPOSE>; + }; + + struct Policy520 : ChainedPolicy<520, Policy520, Policy350> + { + using MergeSortPolicy = + AgentMergeSortPolicy<512, + Nominal4BItemsToItems(15), + cub::BLOCK_LOAD_WARP_TRANSPOSE, + cub::LOAD_LDG, + cub::BLOCK_STORE_WARP_TRANSPOSE>; + }; + + struct Policy600 : ChainedPolicy<600, Policy600, Policy520> + { + using MergeSortPolicy = + AgentMergeSortPolicy<256, 
+ Nominal4BItemsToItems(17), + cub::BLOCK_LOAD_WARP_TRANSPOSE, + cub::LOAD_DEFAULT, + cub::BLOCK_STORE_WARP_TRANSPOSE>; + }; + + + /// MaxPolicy + using MaxPolicy = Policy600; +}; + +template +struct BlockSortLauncher +{ + OffsetT num_tiles; + int block_sort_shmem_size; + bool ping; + + KeyInputIteratorT d_input_keys; + ValueInputIteratorT d_input_items; + KeyIteratorT d_output_keys; + ValueIteratorT d_output_items; + OffsetT num_items; + CompareOpT compare_op; + cudaStream_t stream; + + KeyT *keys_buffer; + ValueT *items_buffer; + char* vshmem_ptr; + + CUB_RUNTIME_FUNCTION __forceinline__ + BlockSortLauncher(OffsetT num_tiles, + int block_sort_shmem_size, + bool ping, + KeyInputIteratorT d_input_keys, + ValueInputIteratorT d_input_items, + KeyIteratorT d_output_keys, + ValueIteratorT d_output_items, + OffsetT num_items, + CompareOpT compare_op, + cudaStream_t stream, + KeyT *keys_buffer, + ValueT *items_buffer, + char *vshmem_ptr) + : num_tiles(num_tiles) + , block_sort_shmem_size(block_sort_shmem_size) + , ping(ping) + , d_input_keys(d_input_keys) + , d_input_items(d_input_items) + , d_output_keys(d_output_keys) + , d_output_items(d_output_items) + , num_items(num_items) + , compare_op(compare_op) + , stream(stream) + , keys_buffer(keys_buffer) + , items_buffer(items_buffer) + , vshmem_ptr(vshmem_ptr) + {} + + CUB_RUNTIME_FUNCTION __forceinline__ + void launch() const + { + if (vshmem_ptr) + { + launch_impl(); + } + else + { + launch_impl(); + } + } + + template + CUB_RUNTIME_FUNCTION __forceinline__ void launch_impl() const + { + thrust::cuda_cub::launcher::triple_chevron(num_tiles, + MergePolicyT::BLOCK_THREADS, + block_sort_shmem_size, + stream) + .doit(DeviceMergeSortBlockSortKernel, + ping, + d_input_keys, + d_input_items, + d_output_keys, + d_output_items, + num_items, + keys_buffer, + items_buffer, + compare_op, + vshmem_ptr); + } +}; + +template < + typename KeyIteratorT, + typename ValueIteratorT, + typename OffsetT, + typename MergePolicyT, + typename CompareOpT, + typename KeyT, + typename ValueT> +struct MergeLauncher +{ + OffsetT num_tiles; + int merge_shmem_size; + + KeyIteratorT d_keys; + ValueIteratorT d_items; + OffsetT num_items; + CompareOpT compare_op; + OffsetT *merge_partitions; + cudaStream_t stream; + + KeyT *keys_buffer; + ValueT *items_buffer; + char *vshmem_ptr; + + CUB_RUNTIME_FUNCTION __forceinline__ MergeLauncher(OffsetT num_tiles, + int merge_shmem_size, + KeyIteratorT d_keys, + ValueIteratorT d_items, + OffsetT num_items, + CompareOpT compare_op, + OffsetT *merge_partitions, + cudaStream_t stream, + KeyT *keys_buffer, + ValueT *items_buffer, + char *vshmem_ptr) + : num_tiles(num_tiles) + , merge_shmem_size(merge_shmem_size) + , d_keys(d_keys) + , d_items(d_items) + , num_items(num_items) + , compare_op(compare_op) + , merge_partitions(merge_partitions) + , stream(stream) + , keys_buffer(keys_buffer) + , items_buffer(items_buffer) + , vshmem_ptr(vshmem_ptr) + {} + + CUB_RUNTIME_FUNCTION __forceinline__ void + launch(bool ping, OffsetT target_merged_tiles_number) const + { + if (vshmem_ptr) + { + launch_impl(ping, target_merged_tiles_number); + } + else + { + launch_impl(ping, target_merged_tiles_number); + } + } + + template + CUB_RUNTIME_FUNCTION __forceinline__ void + launch_impl(bool ping, OffsetT target_merged_tiles_number) const + { + thrust::cuda_cub::launcher::triple_chevron(num_tiles, + MergePolicyT::BLOCK_THREADS, + merge_shmem_size, + stream) + .doit(DeviceMergeSortMergeKernel, + ping, + d_keys, + d_items, + num_items, + keys_buffer, + 
items_buffer, + compare_op, + merge_partitions, + target_merged_tiles_number, + vshmem_ptr); + } +}; + +template > +struct DispatchMergeSort : SelectedPolicy +{ + using KeyT = typename std::iterator_traits::value_type; + using ValueT = typename std::iterator_traits::value_type; + + // Whether or not there are values to be trucked along with keys + static constexpr bool KEYS_ONLY = Equals::VALUE; + + //------------------------------------------------------------------------------ + // Problem state + //------------------------------------------------------------------------------ + + void *d_temp_storage; ///< [in] Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. + std::size_t &temp_storage_bytes; ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + KeyInputIteratorT d_input_keys; ///< [in] Pointer to the input sequence of unsorted input keys + ValueInputIteratorT d_input_items;///< [in] Pointer to the input sequence of unsorted input values + KeyIteratorT d_output_keys; ///< [out] Pointer to the output sequence of sorted input keys + ValueIteratorT d_output_items; ///< [out] Pointer to the output sequence of sorted input values + OffsetT num_items; ///< [in] Number of items to sort + CompareOpT compare_op; ///< [in] Comparison function object which returns true if the first argument is ordered before the second + cudaStream_t stream; ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. + bool debug_synchronous; ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false. + int ptx_version; + + //------------------------------------------------------------------------------ + // Constructor + //------------------------------------------------------------------------------ + + CUB_RUNTIME_FUNCTION __forceinline__ std::size_t + vshmem_size(std::size_t max_shmem, + std::size_t shmem_per_block, + std::size_t num_blocks) + { + if (shmem_per_block > max_shmem) + { + return shmem_per_block * num_blocks; + } + else + { + return 0; + } + } + + /// Constructor + CUB_RUNTIME_FUNCTION __forceinline__ + DispatchMergeSort(void *d_temp_storage, + std::size_t &temp_storage_bytes, + KeyInputIteratorT d_input_keys, + ValueInputIteratorT d_input_items, + KeyIteratorT d_output_keys, + ValueIteratorT d_output_items, + OffsetT num_items, + CompareOpT compare_op, + cudaStream_t stream, + bool debug_synchronous, + int ptx_version) + : d_temp_storage(d_temp_storage) + , temp_storage_bytes(temp_storage_bytes) + , d_input_keys(d_input_keys) + , d_input_items(d_input_items) + , d_output_keys(d_output_keys) + , d_output_items(d_output_items) + , num_items(num_items) + , compare_op(compare_op) + , stream(stream) + , debug_synchronous(debug_synchronous) + , ptx_version(ptx_version) + {} + + /// Invocation + template + CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t Invoke() + { + using MergePolicyT = typename ActivePolicyT::MergeSortPolicy; + + using BlockSortAgentT = AgentBlockSort; + + using MergeAgentT = AgentMerge; + + cudaError error = cudaSuccess; + + if (num_items == 0) + return error; + + do + { + // Get device ordinal + int device_ordinal = 0; + if (CubDebug(error = cudaGetDevice(&device_ordinal))) + { + break; + } + + // Get shared memory size + int max_shmem = 0; + if (CubDebug(error = cudaDeviceGetAttribute(&max_shmem, + 
cudaDevAttrMaxSharedMemoryPerBlock, + device_ordinal))) + { + break; + } + + int tile_size = MergePolicyT::ITEMS_PER_TILE; + OffsetT num_tiles = cub::DivideAndRoundUp(num_items, tile_size); + + std::size_t block_sort_shmem_size = BlockSortAgentT::SHARED_MEMORY_SIZE; + std::size_t merge_shmem_size = MergeAgentT::SHARED_MEMORY_SIZE; + + std::size_t merge_partitions_size = (1 + num_tiles) * sizeof(OffsetT); + std::size_t temporary_keys_storage_size = num_items * sizeof(KeyT); + std::size_t temporary_values_storage_size = num_items * sizeof(ValueT) * !KEYS_ONLY; + std::size_t virtual_shared_memory_size = + vshmem_size(max_shmem, + (cub::max)(block_sort_shmem_size, merge_shmem_size), + num_tiles); + + void* allocations[4] = {nullptr, nullptr, nullptr, nullptr}; + std::size_t allocation_sizes[4] = {merge_partitions_size, + temporary_keys_storage_size, + temporary_values_storage_size, + virtual_shared_memory_size}; + + if (CubDebug(error = AliasTemporaries(d_temp_storage, + temp_storage_bytes, + allocations, + allocation_sizes))) + { + break; + } + + if (d_temp_storage == nullptr) + { + // Return if the caller is simply requesting the size of the storage allocation + break; + } + + int num_passes = static_cast(thrust::detail::log2_ri(num_tiles)); + + /* + * The algorithm consists of stages. At each stage, there are input and + * output arrays. There are two pairs of arrays allocated (keys and items). + * One pair is from function arguments and another from temporary storage. + * Ping is a helper variable that controls which of these two pairs of + * arrays is an input and which is an output for a current stage. If the + * ping is true - the current stage stores its result in the temporary + * storage. The temporary storage acts as input data otherwise. + * + * Block sort is executed before the main loop. It stores its result in + * the pair of arrays that will be an input of the next stage. The initial + * value of the ping variable is selected so that the result of the final + * stage is stored in the input arrays. + */ + bool ping = num_passes % 2 == 0; + + auto merge_partitions = (OffsetT *)allocations[0]; + auto keys_buffer = (KeyT *)allocations[1]; + auto items_buffer = (ValueT *)allocations[2]; + + char *vshmem_ptr = virtual_shared_memory_size > 0 ? (char *)allocations[3] + : nullptr; + + // Invoke DeviceReduceKernel + BlockSortLauncher + block_sort_launcher(num_tiles, + virtual_shared_memory_size > 0 + ? 0 + : block_sort_shmem_size, + ping, + d_input_keys, + d_input_items, + d_output_keys, + d_output_items, + num_items, + compare_op, + stream, + keys_buffer, + items_buffer, + vshmem_ptr); + + block_sort_launcher.launch(); + + if (debug_synchronous) + { + if (CubDebug(error = SyncStream(stream))) + { + break; + } + } + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) break; + + std::size_t num_partitions = num_tiles + 1; + const int threads_per_partition_block = 256; + const std::size_t partition_grid_size = + cub::DivideAndRoundUp(num_partitions, threads_per_partition_block); + + MergeLauncher + merge_launcher(num_tiles, + virtual_shared_memory_size > 0 ? 
0 : merge_shmem_size, + d_output_keys, + d_output_items, + num_items, + compare_op, + merge_partitions, + stream, + keys_buffer, + items_buffer, + vshmem_ptr); + + for (int pass = 0; pass < num_passes; ++pass, ping = !ping) + { + OffsetT target_merged_tiles_number = OffsetT(2) << pass; + + // Partition + thrust::cuda_cub::launcher::triple_chevron(partition_grid_size, + threads_per_partition_block, + 0, + stream) + .doit(DeviceMergeSortPartitionKernel, + ping, + d_output_keys, + keys_buffer, + num_items, + num_partitions, + merge_partitions, + compare_op, + target_merged_tiles_number, + tile_size); + + if (debug_synchronous) + { + if (CubDebug(error = SyncStream(stream))) + { + break; + } + } + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) + { + break; + } + + // Merge + merge_launcher.launch(ping, target_merged_tiles_number); + + if (debug_synchronous) + { + if (CubDebug(error = SyncStream(stream))) + { + break; + } + } + + // Check for failure to launch + if (CubDebug(error = cudaPeekAtLastError())) + { + break; + } + } + } + while (0); + + return error; + } + + CUB_RUNTIME_FUNCTION __forceinline__ static cudaError_t + Dispatch(void *d_temp_storage, + std::size_t &temp_storage_bytes, + KeyInputIteratorT d_input_keys, + ValueInputIteratorT d_input_items, + KeyIteratorT d_output_keys, + ValueIteratorT d_output_items, + OffsetT num_items, + CompareOpT compare_op, + cudaStream_t stream, + bool debug_synchronous) + { + using MaxPolicyT = typename DispatchMergeSort::MaxPolicy; + + cudaError error = cudaSuccess; + do + { + // Get PTX version + int ptx_version = 0; + if (CubDebug(error = PtxVersion(ptx_version))) + { + break; + } + + // Create dispatch functor + DispatchMergeSort dispatch(d_temp_storage, + temp_storage_bytes, + d_input_keys, + d_input_items, + d_output_keys, + d_output_items, + num_items, + compare_op, + stream, + debug_synchronous, + ptx_version); + + // Dispatch to chained policy + if (CubDebug(error = MaxPolicyT::Invoke(ptx_version, dispatch))) + { + break; + } + } while (0); + + return error; + } +}; + + +} // CUB namespace +CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/cub/util_macro.cuh b/cub/util_macro.cuh index ff86365422..697944fed8 100644 --- a/cub/util_macro.cuh +++ b/cub/util_macro.cuh @@ -34,6 +34,8 @@ #include "util_namespace.cuh" +#include + /// Optional outer namespace(s) CUB_NS_PREFIX @@ -56,6 +58,24 @@ namespace cub { #endif #endif +#define CUB_PREVENT_MACRO_SUBSTITUTION + +template +constexpr __host__ __device__ auto min CUB_PREVENT_MACRO_SUBSTITUTION(T &&t, + U &&u) + -> decltype(t < u ? std::forward(t) : std::forward(u)) +{ + return t < u ? std::forward(t) : std::forward(u); +} + +template +constexpr __host__ __device__ auto max CUB_PREVENT_MACRO_SUBSTITUTION(T &&t, + U &&u) + -> decltype(t < u ? std::forward(u) : std::forward(t)) +{ + return t < u ? std::forward(u) : std::forward(t); +} + #ifndef CUB_MAX /// Select maximum(a, b) #define CUB_MAX(a, b) (((b) > (a)) ? (b) : (a)) diff --git a/cub/util_math.cuh b/cub/util_math.cuh index 21bf843e12..61b0932ea0 100644 --- a/cub/util_math.cuh +++ b/cub/util_math.cuh @@ -71,5 +71,43 @@ DivideAndRoundUp(NumeratorT n, DenominatorT d) return static_cast(n / d + (n % d != 0 ? 
1 : 0)); } +template +constexpr __device__ __host__ int +Nominal4BItemsToItems(int nominal_4b_items_per_thread) +{ + constexpr int type_size = static_cast(sizeof(T)); + + return (cub::min)( + nominal_4b_items_per_thread, + (cub::max)(1, (nominal_4b_items_per_thread * 4 / type_size))); +} + +template +constexpr __device__ __host__ int +Nominal8BItemsToItems(int nominal_8b_items_per_thread) +{ + constexpr int input_size = sizeof(ItemT); + return input_size <= 8 + ? nominal_8b_items_per_thread + : (cub::min)(nominal_8b_items_per_thread, + (cub::max)(1, + ((nominal_8b_items_per_thread * 8) + + input_size - 1) / + input_size)); +} + +/** + * \brief Computes the midpoint of the integers + * + * Extra operation is performed in order to prevent overflow. + * + * \return Half the sum of \p begin and \p end + */ +template +constexpr __device__ __host__ T MidPoint(T begin, T end) +{ + return begin + (end - begin) / 2; +} + } // namespace cub CUB_NS_POSTFIX // Optional outer namespace(s) diff --git a/test/test_block_merge_sort.cu b/test/test_block_merge_sort.cu new file mode 100644 index 0000000000..5022d3dbf1 --- /dev/null +++ b/test/test_block_merge_sort.cu @@ -0,0 +1,444 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +/****************************************************************************** + * Test of BlockMergeSort utilities + ******************************************************************************/ + +// Ensure printing of CUDA runtime errors to console +#define CUB_STDERR + +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "test_util.h" + +using namespace cub; + +struct CustomType +{ + std::uint8_t key; + std::uint64_t count; + + __device__ __host__ CustomType() + : key(0) + , count(0) + {} + + __device__ __host__ CustomType(std::uint64_t value) + : key(value) // overflow + , count(value) + {} + + __device__ __host__ void operator=(std::uint64_t value) + { + key = value; // overflow + count = value; + } +}; + + +struct CustomLess +{ + template + __device__ bool operator()(DataType &lhs, DataType &rhs) + { + return lhs < rhs; + } + + __device__ bool operator()(CustomType &lhs, CustomType &rhs) + { + return lhs.key < rhs.key; + } +}; + +template < + typename DataType, + unsigned int ThreadsInBlock, + unsigned int ItemsPerThread, + bool Stable = false> +__global__ void BlockMergeSortTestKernel(DataType *data, unsigned int valid_items) +{ + using BlockMergeSort = + cub::BlockMergeSort; + + __shared__ typename BlockMergeSort::TempStorage temp_storage_shuffle; + + DataType thread_data[ItemsPerThread]; + + const unsigned int thread_offset = threadIdx.x * ItemsPerThread; + + for (unsigned int item = 0; item < ItemsPerThread; item++) + { + const unsigned int idx = thread_offset + item; + thread_data[item] = idx < valid_items ? data[idx] : DataType(); + } + __syncthreads(); + + // Tests below use sequence to fill the data. + // Therefore the following value should be greater than any that + // is present in the input data. + const DataType oob_default = + static_cast(ThreadsInBlock * ItemsPerThread + 1); + + if (Stable) + { + if (valid_items == ThreadsInBlock * ItemsPerThread) + { + BlockMergeSort(temp_storage_shuffle).StableSort( + thread_data, + CustomLess()); + } + else + { + BlockMergeSort(temp_storage_shuffle).StableSort( + thread_data, + CustomLess(), + valid_items, + oob_default); + } + } + else + { + if (valid_items == ThreadsInBlock * ItemsPerThread) + { + BlockMergeSort(temp_storage_shuffle).Sort( + thread_data, + CustomLess()); + } + else + { + BlockMergeSort(temp_storage_shuffle).Sort( + thread_data, + CustomLess(), + valid_items, + oob_default); + } + } + + for (unsigned int item = 0; item < ItemsPerThread; item++) + { + const unsigned int idx = thread_offset + item; + + if (idx >= valid_items) + break; + + data[idx] = thread_data[item]; + } +} + +template < + typename KeyType, + typename ValueType, + unsigned int ThreadsInBlock, + unsigned int ItemsPerThread, + bool Stable = false> +__global__ void BlockMergeSortTestKernel(KeyType *keys, + ValueType *values, + unsigned int valid_items) +{ + using BlockMergeSort = + cub::BlockMergeSort; + + __shared__ typename BlockMergeSort::TempStorage temp_storage_shuffle; + + KeyType thread_keys[ItemsPerThread]; + ValueType thread_values[ItemsPerThread]; + + const unsigned int thread_offset = threadIdx.x * ItemsPerThread; + + for (unsigned int item = 0; item < ItemsPerThread; item++) + { + const unsigned int idx = thread_offset + item; + thread_keys[item] = idx < valid_items ? keys[idx] : KeyType(); + thread_values[item] = idx < valid_items ? 
values[idx] : ValueType(); + } + __syncthreads(); + + // Tests below use sequence to fill the data. + // Therefore the following value should be greater than any that + // is present in the input data. + const KeyType oob_default = ThreadsInBlock * ItemsPerThread + 1; + + if (Stable) + { + if (valid_items == ThreadsInBlock * ItemsPerThread) + { + BlockMergeSort(temp_storage_shuffle).StableSort( + thread_keys, + thread_values, + CustomLess()); + } + else + { + BlockMergeSort(temp_storage_shuffle).StableSort( + thread_keys, + thread_values, + CustomLess(), + valid_items, + oob_default); + } + } + else + { + if (valid_items == ThreadsInBlock * ItemsPerThread) + { + BlockMergeSort(temp_storage_shuffle).Sort( + thread_keys, + thread_values, + CustomLess()); + } + else + { + BlockMergeSort(temp_storage_shuffle).Sort( + thread_keys, + thread_values, + CustomLess(), + valid_items, + oob_default); + } + } + + for (unsigned int item = 0; item < ItemsPerThread; item++) + { + const unsigned int idx = thread_offset + item; + + if (idx >= valid_items) + break; + + keys[idx] = thread_keys[item]; + values[idx] = thread_values[item]; + } +} + +template< + typename DataType, + unsigned int ItemsPerThread, + unsigned int ThreadsInBlock, + bool Stable = false> +void BlockMergeSortTest(DataType *data, unsigned int valid_items) +{ + BlockMergeSortTestKernel + <<<1, ThreadsInBlock>>>(data, valid_items); + + CubDebugExit(cudaPeekAtLastError()); + CubDebugExit(cudaDeviceSynchronize()); +} + +template< + typename KeyType, + typename ValueType, + unsigned int ItemsPerThread, + unsigned int ThreadsInBlock> +void BlockMergeSortTest(KeyType *keys, ValueType *values, unsigned int valid_items) +{ + BlockMergeSortTestKernel + <<<1, ThreadsInBlock>>>(keys, values, valid_items); + + CubDebugExit(cudaPeekAtLastError()); + CubDebugExit(cudaDeviceSynchronize()); +} + +template +bool CheckResult(int num_items, + thrust::device_vector &d_data, + thrust::host_vector &h_data) +{ + thrust::copy_n(d_data.begin(), num_items, h_data.begin()); + + for (int i = 0; i < num_items; i++) + { + if (h_data[i] != i) + { + return false; + } + } + + return true; +} + +template < + typename DataType, + unsigned int ItemsPerThread, + unsigned int ThreadsInBlock> +void Test(unsigned int num_items, + thrust::default_random_engine &rng, + thrust::device_vector &d_data, + thrust::host_vector &h_data) +{ + thrust::sequence(d_data.begin(), d_data.end()); + thrust::shuffle(d_data.begin(), d_data.end(), rng); + + BlockMergeSortTest( + thrust::raw_pointer_cast(d_data.data()), num_items); + + AssertTrue(CheckResult(num_items, d_data, h_data)); +} + +template < + typename KeyType, + typename ValueType, + unsigned int ItemsPerThread, + unsigned int ThreadsInBlock> +void Test(unsigned int num_items, + thrust::default_random_engine &rng, + thrust::device_vector &d_keys, + thrust::device_vector &d_values, + thrust::host_vector &h_data) +{ + thrust::sequence(d_keys.begin(), d_keys.end()); + thrust::shuffle(d_keys.begin(), d_keys.end(), rng); + thrust::copy_n(d_keys.begin(), num_items, d_values.begin()); + + BlockMergeSortTest( + thrust::raw_pointer_cast(d_keys.data()), + thrust::raw_pointer_cast(d_values.data()), + num_items); + + AssertTrue(CheckResult(num_items, d_values, h_data)); +} + +template < + typename KeyType, + typename ValueType, + unsigned int ItemsPerThread, + unsigned int ThreadsInBlock> +void Test(thrust::default_random_engine &rng) +{ + for (unsigned int num_items = ItemsPerThread * ThreadsInBlock; + num_items > 1; + num_items /= 2) + { + 
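+    // The first pass uses num_items == ItemsPerThread * ThreadsInBlock (a full tile);
+    // every following pass halves num_items, exercising the guarded Sort/StableSort
+    // overloads that take valid_items and an out-of-bounds default key.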
thrust::device_vector d_keys(num_items); + thrust::device_vector d_values(num_items); + thrust::host_vector h_keys(num_items); + thrust::host_vector h_values(num_items); + + Test(num_items, + rng, + d_keys, + h_keys); + + Test(num_items, + rng, + d_keys, + d_values, + h_values); + } +} + +template +void Test(thrust::default_random_engine &rng) +{ + Test(rng); + Test(rng); + + // Mixed types + Test(rng); + Test(rng); +} + +template +void Test(thrust::default_random_engine &rng) +{ + Test(rng); + Test(rng); +} + +struct CountToType +{ + __device__ __host__ CustomType operator()(std::uint64_t val) + { + return { val }; + } +}; + +struct CountComparator +{ + __device__ __host__ bool operator()(const CustomType &lhs, const CustomType &rhs) + { + if (lhs.key == rhs.key) + return lhs.count < rhs.count; + + return lhs.key < rhs.key; + } +}; + +void TestStability() +{ + constexpr unsigned int items_per_thread = 10; + constexpr unsigned int threads_per_block = 128; + constexpr unsigned int elements = items_per_thread * threads_per_block; + constexpr bool stable = true; + + thrust::device_vector d_keys(elements); + thrust::device_vector d_counts(elements); + thrust::sequence(d_counts.begin(), d_counts.end()); + thrust::transform(d_counts.begin(), d_counts.end(), d_keys.begin(), CountToType{}); + + // Sort keys + BlockMergeSortTest( + thrust::raw_pointer_cast(d_keys.data()), + elements); + + // Check counts + AssertTrue(thrust::is_sorted(d_keys.begin(), d_keys.end(), CountComparator{})); +} + +int main(int argc, char** argv) +{ + CommandLineArgs args(argc, argv); + + // Initialize device + CubDebugExit(args.DeviceInit()); + + thrust::default_random_engine rng; + + Test<1>(rng); + Test<2>(rng); + Test<10>(rng); + Test<15>(rng); + + Test(rng); + Test(rng); + + TestStability(); + + return 0; +} diff --git a/test/test_device_merge_sort.cu b/test/test_device_merge_sort.cu new file mode 100644 index 0000000000..895ce4c74c --- /dev/null +++ b/test/test_device_merge_sort.cu @@ -0,0 +1,344 @@ +/****************************************************************************** + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/****************************************************************************** + * Test of DeviceMergeSort utilities + ******************************************************************************/ + +// Ensure printing of CUDA runtime errors to console +#define CUB_STDERR + +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "test_util.h" + +using namespace cub; + +struct CustomLess +{ + template + __device__ bool operator()(DataType &lhs, DataType &rhs) + { + return lhs < rhs; + } +}; + +template +bool CheckResult(thrust::device_vector &d_data) +{ + const bool is_sorted = thrust::is_sorted(d_data.begin(), d_data.end(), CustomLess()); + return is_sorted; +} + +template +struct ValueToKey +{ + __device__ __host__ KeyType operator()(const ValueType &val) + { + return val; + } +}; + +template +struct ValueToKey +{ + __device__ __host__ HugeDataType operator()(const ValueType &val) + { + return HugeDataType(val); + } +}; + +template +void Test(std::int64_t num_items, + thrust::default_random_engine &rng, + thrust::device_vector &d_keys, + thrust::device_vector &d_values) +{ + thrust::sequence(d_values.begin(), d_values.end()); + thrust::shuffle(d_values.begin(), d_values.end(), rng); + + thrust::transform(d_values.begin(), + d_values.end(), + d_keys.begin(), + ValueToKey()); + + thrust::device_vector d_keys_before_sort(d_keys); + thrust::device_vector d_values_before_sort(d_values); + + thrust::device_vector d_keys_before_sort_copy(d_keys); + thrust::device_vector d_values_before_sort_copy(d_values); + + size_t temp_size = 0; + CubDebugExit(cub::DeviceMergeSort::SortPairs( + nullptr, + temp_size, + thrust::raw_pointer_cast(d_keys.data()), + thrust::raw_pointer_cast(d_values.data()), + num_items, + CustomLess(), + 0, + true)); + + thrust::device_vector tmp(temp_size); + + CubDebugExit(cub::DeviceMergeSort::SortPairs( + thrust::raw_pointer_cast(tmp.data()), + temp_size, + thrust::raw_pointer_cast(d_keys.data()), + thrust::raw_pointer_cast(d_values.data()), + num_items, + CustomLess(), + 0, + true)); + + thrust::device_vector d_keys_after_sort_copy(d_keys); + thrust::device_vector d_values_after_sort_copy(d_values); + + AssertTrue(CheckResult(d_values)); + + CubDebugExit(cub::DeviceMergeSort::SortPairsCopy( + thrust::raw_pointer_cast(tmp.data()), + temp_size, + thrust::raw_pointer_cast(d_keys_before_sort.data()), + thrust::raw_pointer_cast(d_values_before_sort.data()), + thrust::raw_pointer_cast(d_keys.data()), + thrust::raw_pointer_cast(d_values.data()), + num_items, + CustomLess(), + 0, + true)); + + AssertEquals(d_keys, d_keys_after_sort_copy); + AssertEquals(d_values, d_values_after_sort_copy); + AssertEquals(d_keys_before_sort, d_keys_before_sort_copy); + AssertEquals(d_values_before_sort, d_values_before_sort_copy); + + // At the moment stable sort is an alias to sort, so it's safe to 
use + // temp_size storage allocated before + CubDebugExit(cub::DeviceMergeSort::StableSortPairs( + thrust::raw_pointer_cast(tmp.data()), + temp_size, + thrust::raw_pointer_cast(d_keys.data()), + thrust::raw_pointer_cast(d_values.data()), + num_items, + CustomLess(), + 0, + true)); + + AssertTrue(CheckResult(d_values)); + + CubDebugExit(cub::DeviceMergeSort::SortPairsCopy( + thrust::raw_pointer_cast(tmp.data()), + temp_size, + thrust::constant_iterator(KeyType(42)), + thrust::counting_iterator(DataType(0)), + thrust::raw_pointer_cast(d_keys.data()), + thrust::raw_pointer_cast(d_values.data()), + num_items, + CustomLess(), + 0, + true)); + + thrust::sequence(d_values_before_sort.begin(), d_values_before_sort.end()); + + AssertEquals(d_values, d_values_before_sort); +} + +template +void TestKeys(std::int64_t num_items, + thrust::default_random_engine &rng, + thrust::device_vector &d_keys, + thrust::device_vector &d_values) +{ + thrust::sequence(d_values.begin(), d_values.end()); + thrust::shuffle(d_values.begin(), d_values.end(), rng); + + thrust::transform(d_values.begin(), + d_values.end(), + d_keys.begin(), + ValueToKey()); + + thrust::device_vector d_before_sort(d_keys); + thrust::device_vector d_before_sort_copy(d_keys); + + size_t temp_size = 0; + cub::DeviceMergeSort::SortKeys( + nullptr, + temp_size, + thrust::raw_pointer_cast(d_keys.data()), + num_items, + CustomLess()); + + thrust::device_vector tmp(temp_size); + + CubDebugExit(cub::DeviceMergeSort::SortKeys( + thrust::raw_pointer_cast(tmp.data()), + temp_size, + thrust::raw_pointer_cast(d_keys.data()), + num_items, + CustomLess(), + 0, + true)); + + thrust::device_vector d_after_sort(d_keys); + + AssertTrue(CheckResult(d_keys)); + + CubDebugExit(cub::DeviceMergeSort::SortKeysCopy( + thrust::raw_pointer_cast(tmp.data()), + temp_size, + thrust::raw_pointer_cast(d_before_sort.data()), + thrust::raw_pointer_cast(d_keys.data()), + num_items, + CustomLess(), + 0, + true)); + + AssertTrue(d_keys == d_after_sort); + AssertTrue(d_before_sort == d_before_sort_copy); + + // At the moment stable sort is an alias to sort, so it's safe to use + // temp_size storage allocated before + CubDebugExit(cub::DeviceMergeSort::StableSortKeys( + thrust::raw_pointer_cast(tmp.data()), + temp_size, + thrust::raw_pointer_cast(d_keys.data()), + num_items, + CustomLess(), + 0, + true)); + + AssertTrue(CheckResult(d_keys)); +} + +template +struct TestHelper +{ + template + static void AllocateAndTest(thrust::default_random_engine &rng, unsigned int num_items) + { + thrust::device_vector d_keys(num_items); + thrust::device_vector d_values(num_items); + + Test(num_items, rng, d_keys, d_values); + TestKeys(num_items, rng, d_keys, d_values); + } +}; + +template <> +struct TestHelper +{ + template + static void AllocateAndTest(thrust::default_random_engine &, unsigned int) + {} +}; + +template +void Test(thrust::default_random_engine &rng, unsigned int num_items) +{ + TestHelper::template AllocateAndTest(rng, num_items); + TestHelper::template AllocateAndTest(rng, num_items); + TestHelper::template AllocateAndTest(rng, num_items); +} + +template +void AllocateAndTestIterators(unsigned int num_items) +{ + thrust::device_vector d_keys(num_items); + thrust::device_vector d_values(num_items); + + thrust::sequence(d_keys.begin(), d_keys.end()); + thrust::sequence(d_values.begin(), d_values.end()); + + thrust::reverse(d_values.begin(), d_values.end()); + + using KeyIterator = typename thrust::device_vector::iterator; + thrust::reverse_iterator reverse_iter(d_keys.end()); 
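+  // d_keys holds the sequence 0..num_items-1 and d_values holds the same sequence
+  // reversed, so the keys appear in descending order when read through reverse_iter.
+  // Sorting the (key, value) pairs through that reverse view therefore permutes
+  // d_values back into ascending order, which CheckResult verifies below.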
+ + size_t temp_size = 0; + cub::DeviceMergeSort::SortPairs( + nullptr, + temp_size, + reverse_iter, + thrust::raw_pointer_cast(d_values.data()), + num_items, + CustomLess()); + + thrust::device_vector tmp(temp_size); + + cub::DeviceMergeSort::SortPairs( + thrust::raw_pointer_cast(tmp.data()), + temp_size, + reverse_iter, + thrust::raw_pointer_cast(d_values.data()), + num_items, + CustomLess()); + + AssertTrue(CheckResult(d_values)); +} + +template +void Test(thrust::default_random_engine &rng) +{ + for (int pow2 = 9; pow2 < 22; pow2 += 2) + { + const int num_items = std::pow(2, pow2); + AllocateAndTestIterators(num_items); + TestHelper::AllocateAndTest(rng, num_items); + Test(rng, num_items); + } +} + +int main(int argc, char** argv) +{ + CommandLineArgs args(argc, argv); + + // Initialize device + CubDebugExit(args.DeviceInit()); + + thrust::default_random_engine rng; + + Test(rng); + Test(rng); + + return 0; +} diff --git a/test/test_util.h b/test/test_util.h index 609ef6c744..d6a3f1a85f 100644 --- a/test/test_util.h +++ b/test/test_util.h @@ -85,6 +85,8 @@ T SafeBitCast(const U& in) */ #define AssertEquals(a, b) if ((a) != (b)) { std::cerr << "\n(" << __FILE__ << ": " << __LINE__ << ")\n"; exit(1);} +#define AssertTrue(a) if (!(a)) { std::cerr << "\n(" << __FILE__ << ": " << __LINE__ << ")\n"; exit(1);} + /****************************************************************************** * Command-line parsing functionality @@ -1718,3 +1720,77 @@ struct GpuTimer return elapsed; } }; + +struct HugeDataType +{ + static constexpr int ELEMENTS_PER_OBJECT = 128; + + __device__ __host__ HugeDataType() + { + for (int i = 0; i < ELEMENTS_PER_OBJECT; i++) + { + data[i] = 0; + } + } + + __device__ __host__ HugeDataType(const HugeDataType&rhs) + { + for (int i = 0; i < ELEMENTS_PER_OBJECT; i++) + { + data[i] = rhs.data[i]; + } + } + + explicit __device__ __host__ HugeDataType(int val) + { + for (int i = 0; i < ELEMENTS_PER_OBJECT; i++) + { + data[i] = val; + } + } + + int data[ELEMENTS_PER_OBJECT]; +}; + +inline __device__ __host__ bool operator==(const HugeDataType &lhs, + const HugeDataType &rhs) +{ + for (int i = 0; i < HugeDataType::ELEMENTS_PER_OBJECT; i++) + { + if (lhs.data[i] != rhs.data[i]) + { + return false; + } + } + + return true; +} + +inline __device__ __host__ bool operator<(const HugeDataType &lhs, + const HugeDataType &rhs) +{ + for (int i = 0; i < HugeDataType::ELEMENTS_PER_OBJECT; i++) + { + if (lhs.data[i] < rhs.data[i]) + { + return true; + } + } + + return false; +} + +template +__device__ __host__ bool operator!=(const HugeDataType &lhs, + const DataType &rhs) +{ + for (int i = 0; i < HugeDataType::ELEMENTS_PER_OBJECT; i++) + { + if (lhs.data[i] != rhs) + { + return true; + } + } + + return false; +}