IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /docs/manual/basics.md). For the complete Mojo documentation index, see llms.txt.
Version: 1.0

twophase_reduce_kernel

twophase_reduce_kernel[rank: Int, axis: Int, num_reductions: Int, BLOCK_SIZE: Int, input_fn: def[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], output_fn: def[dtype: DType, width: Int, rank: Int](IndexList[rank], StaticTuple[SIMD[dtype, width], num_reductions]) capturing -> None, reduce_fn: def[ty: DType, width: Int, reduction_idx: Int](SIMD[ty, width], SIMD[ty, width]) capturing -> SIMD[ty, width], dtype: DType, simd_width: Int, accum_type: DType = get_accum_type[dtype]()](shape: IndexList[rank], init: StaticTuple[Scalar[dtype], num_reductions], partials: UnsafePointer[Scalar[accum_type], MutAnyOrigin], counters: UnsafePointer[Int32, MutAnyOrigin], blocks_per_row: Int)

GPU kernel for reductions in which there are too few rows to saturate the device when each row gets a single block. It assigns multiple blocks to each row (blocks_per_row) and uses a two-phase approach. First, each block reduces one chunk of its row with a cooperative block-level reduction and writes a partial result. Then the last block to finish, detected via a per-row atomic counter, reduces all the partial results for its row.
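The two-phase scheme can be illustrated with a sequential CPU simulation. This is a sketch under stated assumptions, not the Mojo kernel itself, and every name in it is hypothetical.

- twophase_reduce_sim, its contiguous chunking of each row, and the row-major layout of the partials list are choices made for illustration.
- The atomic increment is emulated by reading the counter and then incrementing it.
- A strided loop over block_size "threads" stands in for the cooperative block-level reduction.

```python
import math
from typing import Callable, List, Sequence


def twophase_reduce_sim(
    rows: Sequence[Sequence[float]],
    blocks_per_row: int,
    block_size: int,
    reduce_fn: Callable[[float, float], float],
    init: float,
) -> List[float]:
    """Hypothetical CPU simulation of a two-phase, multi-block-per-row reduction."""
    num_rows = len(rows)
    # One partial result per (row, block); layout chosen for illustration only.
    partials = [init] * (num_rows * blocks_per_row)
    # One counter per row, counting how many blocks have finished that row.
    counters = [0] * num_rows
    out = [init] * num_rows

    # Blocks "finish" in reverse order, to show that the last block to
    # finish does phase 2, whichever block that happens to be.
    schedule = [(r, b) for r in range(num_rows) for b in reversed(range(blocks_per_row))]
    for row, block in schedule:
        row_len = len(rows[row])
        chunk = math.ceil(row_len / blocks_per_row)
        start, stop = block * chunk, min((block + 1) * chunk, row_len)

        # Phase 1: emulate block_size threads with a strided loop, then fold
        # the per-thread accumulators (stand-in for a block-level reduction).
        thread_acc = [init] * block_size
        for i in range(start, stop):
            t = (i - start) % block_size
            thread_acc[t] = reduce_fn(thread_acc[t], rows[row][i])
        acc = init
        for v in thread_acc:
            acc = reduce_fn(acc, v)
        partials[row * blocks_per_row + block] = acc

        # Emulated atomic fetch-add: read the previous value, then increment.
        prev = counters[row]
        counters[row] += 1
        if prev == blocks_per_row - 1:
            # Phase 2: this block finished last, so every partial is written.
            total = init
            for b in range(blocks_per_row):
                total = reduce_fn(total, partials[row * blocks_per_row + b])
            out[row] = total
    return out


rows = [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [5, -1, 7]]
print(twophase_reduce_sim(rows, blocks_per_row=4, block_size=2,
                          reduce_fn=lambda a, b: a + b, init=0))
# [55, 11]
```

In this simulation, init is folded in once per thread and once per block, and a block whose chunk is empty contributes only init. The initial value therefore has to be an identity for reduce_fn (0 for a sum, negative infinity for a max). On real GPU hardware, this "last block finishes the row" pattern generally also needs a memory fence between writing the partial result and incrementing the counter. CUDA's threadFenceReduction sample illustrates the same idea.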

Parameters:

  • rank (Int): The tensor rank.
  • axis (Int): The axis along which to reduce.
  • num_reductions (Int): The number of fused reductions to perform.
  • BLOCK_SIZE (Int): The number of threads per block.
  • input_fn (def[dtype: DType, width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width]): The lambda to load input elements.
  • output_fn (def[dtype: DType, width: Int, rank: Int](IndexList[rank], StaticTuple[SIMD[dtype, width], num_reductions]) capturing -> None): The lambda to store output elements.
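  • reduce_fn (def[ty: DType, width: Int, reduction_idx: Int](SIMD[ty, width], SIMD[ty, width]) capturing -> SIMD[ty, width]): The binary reduction function. The reduction_idx parameter identifies which of the num_reductions fused reductions is being combined.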
  • dtype (DType): The data type of the elements.
  • simd_width (Int): The SIMD vector width.
  • accum_type (DType): The accumulator data type, also the element type of the partials buffer. Defaults to get_accum_type[dtype]().
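The num_reductions, init, and reduce_fn parameters let several reductions share one pass over the input. The plain-Python sketch below shows only that idea, under assumed names. It is not the Mojo API.

- fused_reduce_row and NUM_REDUCTIONS are hypothetical.
- The sketch assumes reduction 0 is a sum and reduction 1 is a max.
- A runtime integer stands in for the compile-time reduction_idx parameter.

```python
from typing import List, Sequence, Tuple

# Hypothetical: two fused reductions, index 0 = sum, index 1 = max.
NUM_REDUCTIONS = 2


def reduce_fn(reduction_idx: int, a: float, b: float) -> float:
    """One binary function that dispatches on which fused reduction it combines."""
    if reduction_idx == 0:
        return a + b
    if reduction_idx == 1:
        return max(a, b)
    raise ValueError(f"unknown reduction_idx {reduction_idx}")


def fused_reduce_row(row: Sequence[float], init: Tuple[float, ...]) -> Tuple[float, ...]:
    """Reduce one row with every fused reduction in a single pass."""
    assert len(init) == NUM_REDUCTIONS
    # One accumulator per reduction, seeded from init (the analogue of
    # init: StaticTuple[Scalar[dtype], num_reductions]).
    acc: List[float] = list(init)
    for x in row:  # each element is loaded once and fed to every reduction
        for idx in range(NUM_REDUCTIONS):
            acc[idx] = reduce_fn(idx, acc[idx], x)
    return tuple(acc)


print(fused_reduce_row([3, -2, 8, 1], (0, float("-inf"))))
# (10, 8)
```

The result tuple plays a role similar to the StaticTuple that output_fn receives: one value per fused reduction for each output position.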

Args: