TMATensorTile

struct TMATensorTile[dtype: DType, rank: Int, tile_shape: IndexList[rank], desc_shape: IndexList[rank] = tile_shape, is_k_major: Bool = True]

A hardware-accelerated tensor memory access (TMA) tile for efficient asynchronous data movement.

The TMATensorTile struct provides a high-performance interface for asynchronous data transfers between global memory and shared memory in GPU tensor operations. It encapsulates a TMA descriptor that defines the memory access pattern and provides methods for various asynchronous operations.

Performance:

  • Hardware-accelerated memory transfers using TMA instructions
  • Supports prefetching of descriptors for latency hiding
  • Enforces 128-byte alignment requirements for optimal memory access

Parameters

  • dtype (DType): The data type of the tensor elements.
  • rank (Int): The dimensionality of the tile (2, 3, 4, or 5).
  • tile_shape (IndexList[rank]): The shape of the tile in shared memory.
  • desc_shape (IndexList[rank]): The shape of the descriptor, which can differ from the tile shape to accommodate hardware requirements such as WGMMA. Defaults to tile_shape.
  • is_k_major (Bool): Whether the shared memory layout is k-major. Defaults to True.

Fields

  • descriptor (TMADescriptor): The TMA descriptor that defines the memory access pattern. This field stores the hardware descriptor that encodes information about:

    • The source tensor's memory layout and dimensions
    • The tile shape and access pattern
    • Swizzling configuration for optimal memory access

    The descriptor is used by the GPU's Tensor Memory Accelerator hardware to efficiently transfer data between global and shared memory.

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDestructible, Movable

comptime members

device_type

comptime device_type = TMATensorTile[dtype, rank, tile_shape, desc_shape, is_k_major]

The device-side type representation.

Methods

__init__

@implicit __init__(out self, descriptor: TMADescriptor)

Initializes a new TMATensorTile with the provided TMA descriptor.

Args:

  • descriptor (TMADescriptor): The TMA descriptor that defines the memory access pattern.

__init__(out self, *, copy: Self)

Copy initializes this TMATensorTile from another instance.

Args:

  • copy (Self): The other TMATensorTile instance to copy from.

get_type_name

static get_type_name() -> String

Gets this type's name, for use in error messages when passing arguments to kernels.

Returns:

String: This type's name.

prefetch_descriptor

prefetch_descriptor(self)

Prefetches the TMA descriptor into cache to reduce latency.

This method helps hide memory access latency by prefetching the descriptor before it's needed for actual data transfers.
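
For instance, a kernel can prefetch at entry so the descriptor fetch overlaps with unrelated setup work. A small sketch; issuing from a single thread is a common convention, not a requirement stated here:

```mojo
# Sketch: hide descriptor-fetch latency by prefetching at kernel entry.
if thread_idx.x == 0:
    tma_tile.prefetch_descriptor()
# ... setup work here overlaps with the descriptor fetch ...
```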

async_copy

async_copy[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: LayoutTensor[address_space=AddressSpace.SHARED, element_layout=dst.element_layout, layout_int_type=dst.layout_int_type, linear_idx_type=dst.linear_idx_type, masked=dst.masked, alignment=dst.alignment], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int])

Schedules an asynchronous copy from global memory to shared memory at specified coordinates.

This method initiates a hardware-accelerated asynchronous transfer of data from global memory to the specified destination in shared memory. The transfer is tracked by the provided memory barrier.

Constraints:

  • The destination tensor must be 128-byte aligned in shared memory.
  • The descriptor layout may be smaller than the shared memory tile shape to accommodate hardware requirements.

Parameters:

  • cta_group (Int): If the TMA is issued with cta_group == 2, only the leader CTA needs to be notified upon completion. Defaults to 1.
  • eviction_policy (CacheEviction): Optional cache eviction policy that controls how the data is handled in the cache hierarchy. Defaults to EVICT_NORMAL.

Args:

  • dst (LayoutTensor): The destination tensor in shared memory. Must be 128-byte aligned.
  • mem_barrier (SharedMemBarrier): The shared memory barrier that tracks and signals completion of the transfer.
  • coords (Tuple[Int, Int]): The 2D coordinates in the global tensor at which the copy begins.

async_copy[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: TileTensor[dtype, address_space=AddressSpace.SHARED, linear_idx_type=dst.linear_idx_type, element_size=dst.element_size], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int])

TileTensor overload for 2D async copy from global to shared memory.

Parameters:

  • cta_group (Int): If the TMA is issued with cta_group == 2, only the leader CTA needs to be notified upon completion.
  • eviction_policy (CacheEviction): Cache eviction policy. Defaults to EVICT_NORMAL.

Args:

async_copy[coord_rank: Int, //, cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: LayoutTensor[dtype, address_space=AddressSpace.SHARED, element_layout=dst.element_layout, layout_int_type=dst.layout_int_type, linear_idx_type=dst.linear_idx_type, masked=dst.masked, alignment=dst.alignment], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, coords: StaticTuple[UInt32, coord_rank])

Schedules an asynchronous copy from global memory to shared memory for N-dimensional tensors.

This is a generic dispatcher that selects the appropriate rank-specific async copy method based on the tensor rank. It provides a unified interface for initiating TMA transfers across 2D, 3D, 4D, and 5D tensors using StaticTuple coordinates.

Constraints:

  • The coord_rank must be 2, 3, 4, or 5.
  • The destination tensor must be 128-byte aligned in shared memory.

Parameters:

  • coord_rank (Int): The dimensionality of the tensor (must be 2, 3, 4, or 5).
  • cta_group (Int): If set to 2, only the leader CTA needs to be notified upon completion. Defaults to 1.
  • eviction_policy (CacheEviction): Optional cache eviction policy that controls how the data is handled in the cache hierarchy. Defaults to EVICT_NORMAL.

Args:

async_copy[coord_rank: Int, //, cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: TileTensor[dtype, address_space=AddressSpace.SHARED, linear_idx_type=dst.linear_idx_type, element_size=dst.element_size], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, coords: StaticTuple[UInt32, coord_rank])

TileTensor overload of the generic rank-dispatched async_copy. Dispatches to the rank-specific TileTensor async_copy methods.

Parameters:

  • coord_rank (Int): The dimensionality (must be 2, 3, 4, or 5).
  • cta_group (Int): CTA group configuration. Defaults to 1.
  • eviction_policy (CacheEviction): Cache eviction policy. Defaults to EVICT_NORMAL.

Args:
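
As a sketch of the rank-dispatched form: a 3D load differs from the 2D load example near the top of this page only in how the coordinates are packed. The barrier setup is assumed to be identical, and the coordinate ordering below is illustrative rather than normative.

```mojo
# Sketch: rank-dispatched async_copy with StaticTuple coordinates (3D case).
# Assumes smem_tile, mbar, and expect_bytes are set up as in the 2D example.
from utils import StaticTuple

var coords = StaticTuple[UInt32, 3](0, 0, 0)  # ordering per the descriptor
tma_tile.async_copy(smem_tile, mbar[0], coords)
```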

async_copy_elect

async_copy_elect[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: LayoutTensor[address_space=AddressSpace.SHARED, element_layout=dst.element_layout, layout_int_type=dst.layout_int_type, linear_idx_type=dst.linear_idx_type, masked=dst.masked, alignment=dst.alignment], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int], elect: Int32)

Elect-predicated overload of async_copy (2D).

Each unrolled cp_async_bulk_tensor_shared_cluster_global issue is predicated in-PTX on elect: the TMA fires only on the elected lane. All lanes follow the same PTX control flow — no warp-divergent if elect != 0: is needed at the call site.

Parameters:

  • cta_group (Int): If the TMA is issued with cta_group == 2, only the leader CTA is notified on completion. Defaults to 1.
  • eviction_policy (CacheEviction): Cache eviction policy. Defaults to EVICT_NORMAL.

Args:

async_copy_elect[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: TileTensor[dtype, address_space=AddressSpace.SHARED, linear_idx_type=dst.linear_idx_type, element_size=dst.element_size], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int], elect: Int32)

Elect-predicated TileTensor overload of async_copy (2D).

See the LayoutTensor overload of async_copy_elect for semantics.

Parameters:

  • cta_group (Int): If the TMA is issued with cta_group == 2, only the leader CTA is notified on completion. Defaults to 1.
  • eviction_policy (CacheEviction): Cache eviction policy. Defaults to EVICT_NORMAL.

Args:

async_copy_elect[coord_rank: Int, //, cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: LayoutTensor[dtype, address_space=AddressSpace.SHARED, element_layout=dst.element_layout, layout_int_type=dst.layout_int_type, linear_idx_type=dst.linear_idx_type, masked=dst.masked, alignment=dst.alignment], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, coords: StaticTuple[UInt32, coord_rank], elect: Int32)

Elect-predicated rank-dispatched overload of async_copy.

Dispatches to the rank-specific async_copy_*_elect methods. Each underlying TMA issue is predicated in-PTX on elect.

Parameters:

  • coord_rank (Int): The dimensionality (must be 2, 3, 4, or 5).
  • cta_group (Int): If set to 2, the TMA emits cta_group::2 PTX. Defaults to 1.
  • eviction_policy (CacheEviction): Cache eviction policy. Defaults to EVICT_NORMAL.

Args:

async_copy_elect[coord_rank: Int, //, cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: TileTensor[dtype, address_space=AddressSpace.SHARED, linear_idx_type=dst.linear_idx_type, element_size=dst.element_size], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, coords: StaticTuple[UInt32, coord_rank], elect: Int32)

Elect-predicated TileTensor rank-dispatched overload of async_copy. Dispatches to the rank-specific _elect methods.

Parameters:

  • coord_rank (Int): The dimensionality (must be 2, 3, 4, or 5).
  • cta_group (Int): If set to 2, the TMA emits cta_group::2 PTX. Defaults to 1.
  • eviction_policy (CacheEviction): Cache eviction policy. Defaults to EVICT_NORMAL.

Args:
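
A sketch of the call-site shape: every lane in the warp makes the same call with the same arguments, and the TMA fires only on the lane where elect is nonzero. The election helper below is hypothetical; use whatever warp-election primitive your kernel already has.

```mojo
# Sketch: elect-predicated issue with warp-uniform control flow.
var elect: Int32 = my_warp_elect()  # hypothetical warp-election helper
tma_tile.async_copy_elect(smem_tile, mbar[0], (0, 0), elect)
```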

async_copy_3d

async_copy_3d[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: LayoutTensor[dtype, address_space=AddressSpace.SHARED, element_layout=dst.element_layout, layout_int_type=dst.layout_int_type, linear_idx_type=dst.linear_idx_type, masked=dst.masked, alignment=dst.alignment], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int, Int])

Schedules an asynchronous copy from global memory to shared memory at specified 3D coordinates.

This method initiates a hardware-accelerated asynchronous transfer of data from global memory to the specified destination in shared memory for 3D tensors. The transfer is tracked by the provided memory barrier.

Constraints:

  • The destination tensor must be 128-byte aligned in shared memory.
  • The descriptor layout may be smaller than the shared memory tile shape to accommodate hardware requirements.

Parameters:

  • cta_group (Int): If set to 2, the TMA emits cta_group::2 PTX so the mbarrier arrival routes to the leader CTA's barrier — required for pair-CTA kernels that share one barrier across the pair. Defaults to 1.
  • eviction_policy (CacheEviction): Optional cache eviction policy that controls how the data is handled in the cache hierarchy. Defaults to EVICT_NORMAL.

Args:

async_copy_3d[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: TileTensor[dtype, address_space=AddressSpace.SHARED, linear_idx_type=dst.linear_idx_type, element_size=dst.element_size], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int, Int])

TileTensor overload for 3D async copy from global to shared memory.

Assumes 128B alignment (TileTensor tiles are allocated with proper alignment by the caller's SMEM layout).

Parameters:

  • cta_group (Int): If set to 2, the TMA emits cta_group::2 PTX so the mbarrier arrival routes to the leader CTA's barrier — required for pair-CTA kernels that share one barrier across the pair. Defaults to 1.
  • eviction_policy (CacheEviction): Cache eviction policy. Defaults to EVICT_NORMAL.

Args:

async_copy_3d_elect

async_copy_3d_elect[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: LayoutTensor[dtype, address_space=AddressSpace.SHARED, element_layout=dst.element_layout, layout_int_type=dst.layout_int_type, linear_idx_type=dst.linear_idx_type, masked=dst.masked, alignment=dst.alignment], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int, Int], elect: Int32)

Elect-predicated overload of async_copy_3d.

See async_copy_elect for semantics — each unrolled TMA issue is predicated in-PTX on elect.

Parameters:

  • cta_group (Int): If set to 2, the TMA emits cta_group::2 PTX so the mbarrier arrival routes to the leader CTA's barrier. Defaults to 1.
  • eviction_policy (CacheEviction): Cache eviction policy. Defaults to EVICT_NORMAL.

Args:

async_copy_3d_elect[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: TileTensor[dtype, address_space=AddressSpace.SHARED, linear_idx_type=dst.linear_idx_type, element_size=dst.element_size], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int, Int], elect: Int32)

Elect-predicated TileTensor overload of async_copy_3d.

See the LayoutTensor overload of async_copy_3d_elect for semantics.

Parameters:

  • cta_group (Int): If set to 2, the TMA emits cta_group::2 PTX. Defaults to 1.
  • eviction_policy (CacheEviction): Cache eviction policy. Defaults to EVICT_NORMAL.

Args:

async_copy_4d

async_copy_4d[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: LayoutTensor[dtype, address_space=AddressSpace.SHARED, element_layout=dst.element_layout, layout_int_type=dst.layout_int_type, linear_idx_type=dst.linear_idx_type, masked=dst.masked, alignment=dst.alignment], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int, Int, Int])

Schedules an asynchronous copy from global memory to shared memory at specified 4D coordinates.

This method initiates a hardware-accelerated asynchronous transfer of data from global memory to the specified destination in shared memory for 4D tensors. The transfer is tracked by the provided memory barrier.

Constraints:

  • The destination tensor must be 128-byte aligned in shared memory.
  • The descriptor layout may be smaller than the shared memory tile shape to accommodate hardware requirements.

Parameters:

  • cta_group (Int): If the TMA is issued with cta_group == 2, only the leader CTA needs to be notified upon completion. Defaults to 1.
  • eviction_policy (CacheEviction): Optional cache eviction policy that controls how the data is handled in the cache hierarchy. Defaults to EVICT_NORMAL.

Args:

async_copy_4d[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: TileTensor[dtype, address_space=AddressSpace.SHARED, linear_idx_type=dst.linear_idx_type, element_size=dst.element_size], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int, Int, Int])

Schedules an asynchronous copy from global memory to shared memory at specified 4D coordinates.

TileTensor overload - accepts TileTensor instead of LayoutTensor. Assumes 128B alignment (TileTensor tiles are allocated with proper alignment).

Parameters:

  • cta_group (Int): If the TMA is issued with cta_group == 2, only the leader CTA needs to be notified upon completion.
  • eviction_policy (CacheEviction): Optional cache eviction policy that controls how the data is handled in the cache hierarchy. Defaults to EVICT_NORMAL.

Args:

async_copy_4d_elect

async_copy_4d_elect[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: LayoutTensor[dtype, address_space=AddressSpace.SHARED, element_layout=dst.element_layout, layout_int_type=dst.layout_int_type, linear_idx_type=dst.linear_idx_type, masked=dst.masked, alignment=dst.alignment], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int, Int, Int], elect: Int32)

Elect-predicated overload of async_copy_4d.

See async_copy_elect for semantics — each unrolled TMA issue is predicated in-PTX on elect.

Parameters:

  • cta_group (Int): If set to 2, the TMA emits cta_group::2 PTX. Defaults to 1.
  • eviction_policy (CacheEviction): Cache eviction policy. Defaults to EVICT_NORMAL.

Args:

async_copy_4d_elect[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: TileTensor[dtype, address_space=AddressSpace.SHARED, linear_idx_type=dst.linear_idx_type, element_size=dst.element_size], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int, Int, Int], elect: Int32)

Elect-predicated TileTensor overload of async_copy_4d.

See the LayoutTensor overload of async_copy_4d_elect for semantics.

Parameters:

  • cta_group (Int): If set to 2, the TMA emits cta_group::2 PTX. Defaults to 1.
  • eviction_policy (CacheEviction): Cache eviction policy. Defaults to EVICT_NORMAL.

Args:

async_copy_5d

async_copy_5d[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: LayoutTensor[dtype, address_space=AddressSpace.SHARED, element_layout=dst.element_layout, layout_int_type=dst.layout_int_type, linear_idx_type=dst.linear_idx_type, masked=dst.masked, alignment=dst.alignment], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int, Int, Int, Int])

Schedules an asynchronous copy from global memory to shared memory at specified 5D coordinates.

This method initiates a hardware-accelerated asynchronous transfer of data from global memory to the specified destination in shared memory for 5D tensors. The transfer is tracked by the provided memory barrier.

Constraints:

  • The destination tensor must be 128-byte aligned in shared memory.
  • The descriptor layout may be smaller than the shared memory tile shape to accommodate hardware requirements.

Parameters:

  • cta_group (Int): If the TMA is issued with cta_group == 2, only the leader CTA needs to be notified upon completion. Defaults to 1.
  • eviction_policy (CacheEviction): Optional cache eviction policy that controls how the data is handled in the cache hierarchy. Defaults to EVICT_NORMAL.

Args:

async_copy_5d[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: TileTensor[dtype, address_space=AddressSpace.SHARED, linear_idx_type=dst.linear_idx_type, element_size=dst.element_size], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int, Int, Int, Int])

Schedules an asynchronous copy from global memory to shared memory at specified 5D coordinates.

TileTensor overload - accepts TileTensor instead of LayoutTensor. Assumes 128B alignment (TileTensor tiles are allocated with proper alignment).

Parameters:

  • cta_group (Int): If the TMA is issued with cta_group == 2, only the leader CTA needs to be notified upon completion.
  • eviction_policy (CacheEviction): Optional cache eviction policy that controls how the data is handled in the cache hierarchy. Defaults to EVICT_NORMAL.

Args:

async_copy_5d_elect

async_copy_5d_elect[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: LayoutTensor[dtype, address_space=AddressSpace.SHARED, element_layout=dst.element_layout, layout_int_type=dst.layout_int_type, linear_idx_type=dst.linear_idx_type, masked=dst.masked, alignment=dst.alignment], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int, Int, Int, Int], elect: Int32)

Elect-predicated overload of async_copy_5d.

See async_copy_elect for semantics — each unrolled TMA issue is predicated in-PTX on elect.

Parameters:

  • cta_group (Int): If set to 2, the TMA emits cta_group::2 PTX. Defaults to 1.
  • eviction_policy (CacheEviction): Cache eviction policy. Defaults to EVICT_NORMAL.

Args:

async_copy_5d_elect[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: TileTensor[dtype, address_space=AddressSpace.SHARED, linear_idx_type=dst.linear_idx_type, element_size=dst.element_size], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int, Int, Int, Int], elect: Int32)

Elect-predicated TileTensor overload of async_copy_5d.

See the LayoutTensor overload of async_copy_5d_elect for semantics.

Parameters:

  • cta_group (Int): If set to 2, the TMA emits cta_group::2 PTX. Defaults to 1.
  • eviction_policy (CacheEviction): Cache eviction policy. Defaults to EVICT_NORMAL.

Args:

async_copy_gather4

async_copy_gather4[cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, dst: LayoutTensor[address_space=AddressSpace.SHARED, element_layout=dst.element_layout, layout_int_type=dst.layout_int_type, linear_idx_type=dst.linear_idx_type, masked=dst.masked, alignment=dst.alignment], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, col_idx: Int32, row0: Int32, row1: Int32, row2: Int32, row3: Int32)

Schedules an asynchronous gather4 copy of 4 non-contiguous rows from global memory to shared memory.

This method uses the TMA gather4 hardware instruction (SM100/Blackwell) to load 4 rows at arbitrary row indices from a 2D tensor in global memory, placing them contiguously in shared memory. The TMA descriptor must be configured with box dim1=1 (one row per tile).

Constraints:

  • Requires rank == 2 (gather4 is 2D only).
  • Requires desc_shape[0] == 1 (gather4 hardware requirement: one row per tile).
  • The destination tensor must be 128-byte aligned in shared memory.
  • Requires SM100 (Blackwell) or newer GPU architecture.

Parameters:

  • cta_group (Int): If the TMA is issued with cta_group == 2, only the leader CTA needs to be notified upon completion. Defaults to 1.
  • eviction_policy (CacheEviction): Cache eviction policy that controls how the data is handled in the cache hierarchy. Defaults to EVICT_NORMAL.

Args:

gather4_tile_bytes

gather4_tile_bytes[tile_width: Int](self) -> Int32

Returns total expected bytes for a full gather4 tile load.

Computes tile_height * tile_width * sizeof(dtype) which is the number of bytes that async_copy_gather4_tile will transfer into shared memory. Pass this value to SharedMemBarrier.expect_bytes before issuing the tile load.

Parameters:

  • tile_width (Int): Total number of elements per row in global memory.

Returns:

Int32: The total expected transfer size in bytes as Int32.

async_copy_gather4_tile

async_copy_gather4_tile[tile_width: Int, cta_group: Int = 1, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL, d_indices_addr_space: AddressSpace = AddressSpace.GENERIC](self, smem_base: UnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, d_indices: UnsafePointer[Int32, address_space=d_indices_addr_space], start_idx: Int = 0)

Loads a full tile of tile_height rows via gather4 in 4-row chunks.

Internally loops over column groups and 4-row chunks, issuing one async_copy_gather4 call per chunk per column group. The SMEM destination layout matches the bulk TMA async_copy ordering: column groups are stored contiguously (each group holds tile_height rows of box_width elements), and within each group 4-row chunks are contiguous.

The caller must call mem_barrier.expect_bytes(self.gather4_tile_bytes[tile_width]()) before invoking this method.

Parameters:

  • tile_width (Int): Total number of elements per row in global memory.
  • cta_group (Int): CTA group configuration. Defaults to 1.
  • eviction_policy (CacheEviction): Cache eviction policy. Defaults to EVICT_NORMAL.
  • d_indices_addr_space (AddressSpace): Address space of the d_indices pointer. Defaults to GENERIC, but callers may pass SHARED pointers directly.

Args:
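
Putting the two calls together, a sketch of the required ordering (the smem_base pointer and the d_indices row-index buffer are assumed to be prepared elsewhere):

```mojo
# Sketch: gather4 tile load. Arm the barrier with gather4_tile_bytes
# *before* issuing the load, as the description above requires.
comptime tile_width = 128  # elements per row in global memory (assumption)

if thread_idx.x == 0:
    mbar[0].expect_bytes(tma_tile.gather4_tile_bytes[tile_width]())
    tma_tile.async_copy_gather4_tile[tile_width](
        smem_base,   # UnsafePointer[Scalar[dtype]] into shared memory
        mbar[0],
        d_indices,   # Int32 row indices, one per gathered row
    )
barrier()
mbar[0].wait()
```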

async_store

async_store[coord_rank: Int, //, cta_group: Int = 1](self, dst: LayoutTensor[dtype, address_space=AddressSpace.SHARED, element_layout=dst.element_layout, layout_int_type=dst.layout_int_type, linear_idx_type=dst.linear_idx_type, masked=dst.masked, alignment=dst.alignment], coords: StaticTuple[UInt32, coord_rank])

Schedules an asynchronous store from shared memory to global memory for N-dimensional tensors.

This is a generic dispatcher that selects the appropriate rank-specific async store method based on the tensor rank. It provides a unified interface for initiating TMA store operations across 2D, 3D, 4D, and 5D tensors using StaticTuple coordinates.

Constraints:

  • The coord_rank must be 2, 3, 4, or 5.
  • The source tensor must be 128-byte aligned in shared memory.

Parameters:

  • coord_rank (Int): The dimensionality of the tensor (must be 2, 3, 4, or 5).
  • cta_group (Int): CTA group configuration for the store operation. Defaults to 1.

Args:

async_store[coord_rank: Int, //, cta_group: Int = 1](self, dst: TileTensor[dtype, address_space=AddressSpace.SHARED, linear_idx_type=dst.linear_idx_type, element_size=dst.element_size], coords: StaticTuple[UInt32, coord_rank])

Schedules an asynchronous store from shared memory to global memory.

TileTensor overload of the generic rank-dispatched async_store. Dispatches to the rank-specific TileTensor async_store methods.

Parameters:

  • coord_rank (Int): The dimensionality of the tensor (must be 2 or 3).
  • cta_group (Int): CTA group configuration. Defaults to 1.

Args:

async_store(self, src: LayoutTensor[dtype, address_space=AddressSpace.SHARED, element_layout=src.element_layout, layout_int_type=src.layout_int_type, linear_idx_type=src.linear_idx_type, masked=src.masked, alignment=src.alignment], coords: Tuple[Int, Int])

Schedules an asynchronous store from shared memory to global memory.

This method initiates a hardware-accelerated asynchronous transfer of data from shared memory to global memory at the specified coordinates.

Constraints:

The source tensor must be 128-byte aligned in shared memory.

Args:

  • src (LayoutTensor): The source tensor in shared memory. Must be 128-byte aligned.
  • coords (Tuple[Int, Int]): The 2D coordinates in the global tensor at which to store the data.

async_store(self, src: TileTensor[dtype, address_space=AddressSpace.SHARED, linear_idx_type=src.linear_idx_type, element_size=src.element_size], coords: Tuple[Int, Int])

Schedules an asynchronous store from shared memory to global memory.

TileTensor overload - accepts TileTensor instead of LayoutTensor. Assumes 128B alignment (TileTensor tiles are allocated with proper alignment).

Args:
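
A sketch of the store direction: threads produce the tile in shared memory, synchronize, then one thread schedules the TMA store and drains the store group (commit_group and wait_group are documented further down this page). A production kernel may also need an async-proxy fence before the store; that detail is omitted here.

```mojo
# Sketch: write a computed shared-memory tile back to global coords (0, 0).
# smem_tile is a 128-byte-aligned shared-memory tile as in the load example.
barrier()  # make every thread's writes to smem_tile visible first
if thread_idx.x == 0:
    tma_tile.async_store(smem_tile, (0, 0))
    tma_tile.commit_group()   # close the group of scheduled stores
    tma_tile.wait_group[0]()  # block until no store groups remain pending
```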

async_multicast_load

async_multicast_load[cta_group: Int = 1](self, dst: LayoutTensor[dtype, address_space=AddressSpace.SHARED, element_layout=dst.element_layout, layout_int_type=dst.layout_int_type, linear_idx_type=dst.linear_idx_type, masked=dst.masked, alignment=dst.alignment], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int], multicast_mask: UInt16)

Schedules an asynchronous multicast load from global memory to multiple shared memory locations.

This method initiates a hardware-accelerated asynchronous transfer of data from global memory to multiple destination locations in shared memory across different CTAs (Cooperative Thread Arrays) as specified by the multicast mask.

Constraints:

The destination tensor must be 128-byte aligned in shared memory.

Parameters:

  • cta_group (Int): If the TMA is issued with cta_group == 2, only the leader CTA needs to be notified upon completion. Defaults to 1.

Args:

  • dst (LayoutTensor): The destination tensor in shared memory. Must be 128-byte aligned.
  • mem_barrier (SharedMemBarrier): The shared memory barrier that tracks and signals completion of the transfer.
  • coords (Tuple[Int, Int]): The 2D coordinates in the global tensor at which the copy begins.
  • multicast_mask (UInt16): A bitmask selecting which CTAs in the cluster receive the data.
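
A sketch, reusing the barrier setup from the load example near the top of this page; the mask has one bit per CTA rank in the cluster, so 0b11 targets ranks 0 and 1:

```mojo
# Sketch: multicast one tile into the shared memory of CTA ranks 0 and 1.
if thread_idx.x == 0:
    mbar[0].expect_bytes(64 * 64 * 4)  # as in the load example
    tma_tile.async_multicast_load(smem_tile, mbar[0], (0, 0), 0b11)
```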

async_multicast_load[cta_group: Int = 1](self, dst: TileTensor[dtype, address_space=AddressSpace.SHARED, linear_idx_type=dst.linear_idx_type, element_size=dst.element_size], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int], multicast_mask: UInt16)

Schedules an asynchronous 2D multicast load from global to shared memory.

TileTensor overload - accepts TileTensor instead of LayoutTensor. Assumes 128B alignment (TileTensor tiles are allocated with proper alignment).

Parameters:

  • cta_group (Int): If issued with cta_group == 2, only the leader CTA needs to be notified upon completion.

Args:

async_multicast_load_3d

async_multicast_load_3d[cta_group: Int = 1](self, dst: LayoutTensor[dtype, address_space=AddressSpace.SHARED, element_layout=dst.element_layout, layout_int_type=dst.layout_int_type, linear_idx_type=dst.linear_idx_type, masked=dst.masked, alignment=dst.alignment], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int, Int], multicast_mask: UInt16)

Schedules an asynchronous 3D multicast load from global memory to multiple shared memory locations.

This method initiates a hardware-accelerated asynchronous transfer of data from global memory to multiple destination locations in shared memory across different CTAs (Cooperative Thread Arrays) as specified by the multicast mask.

Constraints:

The destination tensor must be 128-byte aligned in shared memory.

Parameters:

  • cta_group (Int): If the TMA is issued with cta_group == 2, only the leader CTA needs to be notified upon completion. Defaults to 1.

Args:

async_multicast_load_3d[cta_group: Int = 1](self, dst: TileTensor[dtype, address_space=AddressSpace.SHARED, linear_idx_type=dst.linear_idx_type, element_size=dst.element_size], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int, Int], multicast_mask: UInt16)

Schedules an asynchronous 3D multicast load from global to shared memory.

TileTensor overload - accepts TileTensor instead of LayoutTensor. Assumes 128B alignment (TileTensor tiles are allocated with proper alignment).

Parameters:

  • cta_group (Int): If issued with cta_group == 2, only the leader CTA needs to be notified upon completion.

Args:

async_multicast_load_4d

async_multicast_load_4d[cta_group: Int = 1](self, dst: TileTensor[dtype, address_space=AddressSpace.SHARED, linear_idx_type=dst.linear_idx_type, element_size=dst.element_size], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, coords: Tuple[Int, Int, Int, Int], multicast_mask: UInt16)

Schedules an asynchronous 4D multicast load from global to shared memory.

TileTensor overload - accepts TileTensor instead of LayoutTensor. Assumes 128B alignment (TileTensor tiles are allocated with proper alignment).

Parameters:

  • cta_group (Int): If issued with cta_group == 2, only the leader CTA needs to be notified upon completion.

Args:

async_multicast_load_partitioned

async_multicast_load_partitioned[tma_rows: Int, tma_load_size: Int](self, dst: LayoutTensor[dtype, address_space=AddressSpace.SHARED, element_layout=dst.element_layout, layout_int_type=dst.layout_int_type, linear_idx_type=dst.linear_idx_type, masked=dst.masked, alignment=128], ref[AddressSpace._value] mem_barrier: SharedMemBarrier, cta_rank: Int, coords: Tuple[Int, Int], multicast_mask: UInt16)

Performs a partitioned multicast load where each rank loads a distinct slice of data.

This method is designed for clustered execution where different ranks (CTAs) load different, contiguous slices of the source tensor. Each rank's slice is offset by cta_rank * tma_rows in the second dimension and stored at offset cta_rank * tma_load_size in shared memory.

Note: This is typically used in matrix multiplication kernels where the input matrices are partitioned across multiple CTAs for parallel processing.

Parameters:

  • tma_rows (Int): The number of rows each rank is responsible for loading.
  • tma_load_size (Int): The size in elements of each rank's slice in shared memory.

Args:

async_store_3d

async_store_3d(self, src: LayoutTensor[dtype, address_space=AddressSpace.SHARED, element_layout=src.element_layout, layout_int_type=src.layout_int_type, linear_idx_type=src.linear_idx_type, masked=src.masked, alignment=src.alignment], coords: Tuple[Int, Int, Int])

Schedules an asynchronous store from shared memory to global memory at specified 3D coordinates.

This method initiates a hardware-accelerated asynchronous transfer of data from shared memory to the specified destination in global memory for 3D tensors.

Constraints:

  • The source tensor must be 128-byte aligned in shared memory.
  • The descriptor layout may be smaller than the shared memory tile shape to accommodate hardware requirements.

Args:

async_store_3d(self, src: TileTensor[dtype, address_space=AddressSpace.SHARED, linear_idx_type=src.linear_idx_type, element_size=src.element_size], coords: Tuple[Int, Int, Int])

Schedules an asynchronous store from shared memory to global memory at 3D coordinates.

TileTensor overload - accepts TileTensor instead of LayoutTensor. Assumes 128B alignment (TileTensor tiles are allocated with proper alignment).

Args:

async_store_4d

async_store_4d(self, src: LayoutTensor[dtype, address_space=AddressSpace.SHARED, element_layout=src.element_layout, layout_int_type=src.layout_int_type, linear_idx_type=src.linear_idx_type, masked=src.masked, alignment=src.alignment], coords: Tuple[Int, Int, Int, Int])

Schedules an asynchronous store from shared memory to global memory at specified 4D coordinates.

This method initiates a hardware-accelerated asynchronous transfer of data from shared memory to the specified destination in global memory for 4D tensors.

Constraints:

  • The source tensor must be 128-byte aligned in shared memory.
  • The descriptor layout may be smaller than the shared memory tile shape to accommodate hardware requirements.

Args:

async_store_4d(self, src: TileTensor[dtype, address_space=AddressSpace.SHARED, linear_idx_type=src.linear_idx_type, element_size=src.element_size], coords: Tuple[Int, Int, Int, Int])

Schedules an asynchronous store from shared memory to global memory at 4D coordinates.

TileTensor overload - accepts TileTensor instead of LayoutTensor. Assumes 128B alignment (TileTensor tiles are allocated with proper alignment).

Args:

async_store_5d

async_store_5d(self, src: LayoutTensor[dtype, address_space=AddressSpace.SHARED, element_layout=src.element_layout, layout_int_type=src.layout_int_type, linear_idx_type=src.linear_idx_type, masked=src.masked, alignment=src.alignment], coords: Tuple[Int, Int, Int, Int, Int])

Schedules an asynchronous store from shared memory to global memory at specified 5D coordinates.

This method initiates a hardware-accelerated asynchronous transfer of data from shared memory to the specified destination in global memory for 5D tensors.

Constraints:

  • The source tensor must be 128-byte aligned in shared memory.
  • The descriptor layout may be smaller than the shared memory tile shape to accommodate hardware requirements.

Args:

async_reduce

async_reduce[reduction_kind: ReduceOp](self, src: LayoutTensor[dtype, address_space=AddressSpace.SHARED, element_layout=src.element_layout, layout_int_type=src.layout_int_type, linear_idx_type=src.linear_idx_type, masked=src.masked, alignment=src.alignment], coords: Tuple[Int, Int])

Schedules an asynchronous reduction operation from shared memory to global memory.

This method initiates a hardware-accelerated asynchronous reduction operation that combines data from shared memory with data in global memory using the specified reduction operation. The reduction is performed element-wise at the specified coordinates in the global tensor.

Constraints:

The source tensor must be 128-byte aligned in shared memory.

Parameters:

  • reduction_kind (ReduceOp): The type of reduction operation to perform (e.g., ADD, MIN, MAX). This determines how values are combined during the reduction.

Args:

  • src (LayoutTensor): The source tensor in shared memory. Must be 128-byte aligned.
  • coords (Tuple[Int, Int]): The 2D coordinates in the global tensor at which the reduction is applied.
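
A sketch of an additive write-back; the import path for ReduceOp is an assumption, so adjust it to wherever your build exposes the enum:

```mojo
# Sketch: element-wise ADD of a shared-memory tile into global memory at
# coords (0, 0), instead of overwriting the destination.
from gpu.memory import ReduceOp  # assumed import path

if thread_idx.x == 0:
    tma_tile.async_reduce[ReduceOp.ADD](smem_tile, (0, 0))
    tma_tile.commit_group()
    tma_tile.wait_group[0]()
```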

commit_group

commit_group(self)

Commits all prior initiated but uncommitted TMA instructions into a group.

This function behaves the same as cp_async_bulk_commit_group, which creates a synchronization point for bulk TMA transfer.

wait_group

wait_group[n: Int = 0](self)

Waits until at most a specified number of the most recently committed TMA copy groups are still pending.

This function behaves the same as cp_async_bulk_wait_group, which causes the executing thread to wait until at most n of the most recent TMA copy groups are still in flight.

Parameters:

  • n (Int): The maximum number of copy groups that may remain pending. Defaults to 0, which waits for all committed groups to complete.
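
The parameter n enables simple software pipelining: commit each store as its own group, then wait with n = 1 so one transfer stays in flight while the next tile is produced. A sketch, with the double-buffered smem_tiles array assumed:

```mojo
# Sketch: double-buffered stores. After committing the store of tile i,
# wait_group[1] lets it stay in flight while tile i + 1 is being filled.
for i in range(num_tiles):
    # ... produce tile i into smem_tiles[i % 2] ...
    tma_tile.async_store(smem_tiles[i % 2], (0, i * 64))
    tma_tile.commit_group()
    tma_tile.wait_group[1]()  # at most one store group still pending
tma_tile.wait_group[0]()      # drain everything before kernel exit
```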

smem_tensormap_init

smem_tensormap_init(self, smem_tma_descriptor_ptr: UnsafePointer[TMADescriptor, address_space=AddressSpace.SHARED])

Initializes a TMA descriptor in shared memory from this tensor tile's descriptor.

This method copies the TMA descriptor from global memory to shared memory, allowing for faster access during kernel execution. The descriptor is copied in 16-byte chunks using asynchronous copy operations for efficiency.

Note:

  • Only one thread should call this method to avoid race conditions
  • The descriptor is copied in 8 chunks of 16 bytes each (total 128 bytes)

Args:

replace_tensormap_global_address_in_gmem

replace_tensormap_global_address_in_gmem[_dtype: DType](self, src_ptr: UnsafePointer[Scalar[_dtype]])

Replaces the global memory address in the TMA descriptor stored in global memory.

This method allows dynamically changing the source tensor for TMA operations without recreating the entire descriptor, which is useful for reusing descriptors with different data sources. The operation modifies the descriptor in global memory directly.

Note: A memory fence may be required after this operation to ensure visibility of the changes to other threads.

Parameters:

  • _dtype (DType): The data type of the new source tensor.

Args:

  • src_ptr (UnsafePointer[Scalar[_dtype]]): The new source tensor whose address will replace the current one in the descriptor. Must have compatible layout with the original tensor.
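
A sketch of retargeting a descriptor in global memory; new_src_ptr is assumed to point at a tensor with a layout compatible with the original:

```mojo
# Sketch: point the descriptor at a new source buffer, then publish the
# change before any thread consumes it through TMA.
tma_tile.replace_tensormap_global_address_in_gmem(new_src_ptr)
tma_tile.tensormap_fence_release()  # make the update visible
# ... later, before the next TMA operation that reads the descriptor:
tma_tile.tensormap_fence_acquire()  # warp-aligned: entire warp must call
```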

tensormap_fence_acquire

tensormap_fence_acquire(self)

Establishes a memory fence for TMA operations with acquire semantics.

This method ensures proper ordering of memory operations by creating a barrier that prevents subsequent TMA operations from executing before prior operations have completed. It is particularly important when reading from a descriptor that might have been modified by other threads or processes.

The acquire semantics ensure that all memory operations after this fence will observe any modifications made to the descriptor before the fence.

Notes:

  • The entire warp must call this function as the instruction is warp-aligned.
  • Typically used in pairs with tensormap_fence_release for proper synchronization.

tensormap_fence_release

tensormap_fence_release(self)

Establishes a memory fence for TMA operations with release semantics.

This method ensures proper ordering of memory operations by creating a barrier that ensures all prior memory operations are visible before subsequent operations can proceed. It is particularly important when modifying a TMA descriptor in global memory that might be read by other threads or processes.

The release semantics ensure that all memory operations before this fence will be visible to any thread that observes operations after the fence.

Notes:

  • Typically used after modifying a tensormap descriptor in global memory.
  • Often paired with tensormap_fence_acquire for proper synchronization.

replace_tensormap_global_address_in_shared_mem

replace_tensormap_global_address_in_shared_mem[_dtype: DType](self, smem_tma_descriptor_ptr: UnsafePointer[TMADescriptor, address_space=AddressSpace.SHARED], src_ptr: UnsafePointer[Scalar[_dtype]])

Replaces the global memory address in the TMA descriptor stored in shared memory.

This method allows dynamically changing the source tensor for TMA operations without recreating the entire descriptor, which is useful for reusing descriptors with different data sources. The operation modifies a descriptor that has been previously copied to shared memory.

Notes:

  • Only one thread should call this method to avoid race conditions.
  • A memory fence may be required after this operation to ensure visibility of the changes to other threads.
  • Typically used with descriptors previously initialized with smem_tensormap_init.

Parameters:

  • _dtype (DType): The data type of the new source tensor.

Args:

tensormap_cp_fence_release

tensormap_cp_fence_release(self, smem_tma_descriptor_ptr: UnsafePointer[TMADescriptor, address_space=AddressSpace.SHARED])

Establishes a memory fence for TMA operations with release semantics for shared memory descriptors.

This method ensures proper ordering of memory operations by creating a barrier that ensures all prior memory operations are visible before subsequent operations can proceed. It is specifically designed for synchronizing between global memory and shared memory TMA descriptors.

The release semantics ensure that all memory operations before this fence will be visible to any thread that observes operations after the fence.

Notes:

  • The entire warp must call this function as the instruction is warp-aligned
  • Typically used after modifying a tensormap descriptor in shared memory
  • More specialized than the general tensormap_fence_release for cross-memory space synchronization

Args:
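
Putting the shared-memory descriptor methods together, a sketch of the workflow implied by the notes above (smem_desc_ptr and new_src_ptr are assumed to be set up by the kernel):

```mojo
# Sketch: copy the descriptor to shared memory once, patch its global
# address, publish it with the cross-proxy release fence, then acquire.
if thread_idx.x == 0:
    tma_tile.smem_tensormap_init(smem_desc_ptr)  # one thread copies 128 B
barrier()
if thread_idx.x == 0:
    tma_tile.replace_tensormap_global_address_in_shared_mem(
        smem_desc_ptr, new_src_ptr
    )
# Warp-aligned fences: the whole warp participates in both.
tma_tile.tensormap_cp_fence_release(smem_desc_ptr)
tma_tile.tensormap_fence_acquire()
# ... subsequent TMA operations observe the patched descriptor ...
```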

replace_tensormap_global_dim_strides_in_shared_mem

replace_tensormap_global_dim_strides_in_shared_mem[_dtype: DType, only_update_dim_0: Bool, /, *, tensor_rank: Int](self, smem_tma_descriptor_ptr: UnsafePointer[TMADescriptor, address_space=AddressSpace.SHARED], gmem_dims: IndexList[tensor_rank], gmem_strides: IndexList[tensor_rank])

Replaces dimensions and strides in a TMA descriptor stored in shared memory. Note: This function is only supported for CUDA versions >= 12.5.

This function allows dynamically modifying the dimensions and strides of a TMA descriptor that has been previously initialized in shared memory. If only the first dimension (dim 0) is updated, then updating strides can be skipped.

Notes:

  • Only one thread should call this method to avoid race conditions.
  • A memory fence may be required after this operation to ensure visibility of the changes to other threads.

Parameters:

  • _dtype (DType): The data type of the new source tensor.
  • only_update_dim_0 (Bool): If true, only the first dimension (dim 0) is updated and stride updates are skipped.
  • tensor_rank (Int): The rank of the tensor.

Args:

replace_tensormap_global_dim_strides_in_shared_mem[_dtype: DType, tensor_rank: Int, dim_idx: Int](self, smem_tma_descriptor_ptr: UnsafePointer[TMADescriptor, address_space=AddressSpace.SHARED], dim_value: UInt32, dim_stride: Optional[UInt64] = None)

Replaces a single dimension (and optionally its stride) in a TMA descriptor stored in shared memory. Note: This function is only supported for CUDA versions >= 12.5.

This overload allows dynamically modifying one dimension of a TMA descriptor that has been previously initialized in shared memory. If only the first dimension is updated, updating the stride can be skipped.

Notes:

  • Only one thread should call this method to avoid race conditions.
  • A memory fence may be required after this operation to ensure visibility of the changes to other threads.

Parameters:

  • _dtype (DType): The data type of the source tensor in GMEM.
  • tensor_rank (Int): The rank of the source tensor in GMEM.
  • dim_idx (Int): The index of the dimension to be updated in the TMA descriptor with the provided dimension and stride values at runtime.

Args:
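
A sketch of the dim-0 fast path described above; the dimension and stride values are placeholders, and the stride unit convention is not restated on this page, so verify it against the descriptor encoding:

```mojo
# Sketch: shrink dim 0 of a shared-memory descriptor. With
# only_update_dim_0 = True, the stride values are not rewritten.
tma_tile.replace_tensormap_global_dim_strides_in_shared_mem[
    DType.float32, True, tensor_rank=2
](smem_desc_ptr, Index(new_rows, cols), Index(cols, 1))
```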