// tinygrad_repo/extra/gemm/max_kernels/nv.fp16_fp16_fp16.max.cu
#define INFINITY (__int_as_float(0x7f800000))
#define NAN (__int_as_float(0x7fffffff))
#include <cuda_fp16.h>
#include <cuda_pipeline.h>
#define SMEM_N_WIDTH 136
struct __align__(8) half4 { half x, y, z, w; };
__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }
struct __align__(16) half8 { half x, y, z, w, a, b, c, d; };
__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }
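// ldmatrix wrappers: one warp-wide ldmatrix.x4 loads four 8x8 fp16 tiles from
// shared memory, leaving each thread with 4 x 32-bit registers (two halves per tile)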
__device__ void __ldmatrix_a_elems(half8 *regs, half *smem) {
uint32_t reg0, reg1, reg2, reg3;
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
: "l"(__cvta_generic_to_shared(smem))
);
uint32_t *addr = reinterpret_cast<uint32_t*>(regs);
addr[0] = reg0;
addr[1] = reg1;
addr[2] = reg2;
addr[3] = reg3;
}
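// the .trans variant transposes each 8x8 tile on load so B lands in the
// column-major order mma's B operand expects; the four tiles combine into
// two k16 x n8 fragments (lo/hi) covering adjacent 8-column groups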
__device__ void __ldmatrix_b_elems(half4 *regs_lo, half4 *regs_hi, half *smem) {
uint32_t reg0, reg1, reg2, reg3;
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];"
: "=r"(reg0), "=r"(reg1), "=r"(reg2), "=r"(reg3)
: "l"(__cvta_generic_to_shared(smem))
);
uint32_t *addr_lo = reinterpret_cast<uint32_t*>(regs_lo);
uint32_t *addr_hi = reinterpret_cast<uint32_t*>(regs_hi);
addr_lo[0] = reg0;
addr_lo[1] = reg1;
addr_hi[0] = reg2;
addr_hi[1] = reg3;
}
__device__ half4 __WMMA_8_16_16_half_half(half8 a, half4 b, half4 c) {
int *a_pk = (int *) (&a), *b_pk = (int *) (&b), *c_pk = (int *) (&c);
asm( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 { %0, %1 }, { %2, %3, %4, %5 }, { %6, %7 }, { %0, %1 };"
: "+r"(c_pk[0]), "+r"(c_pk[1]): "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
return c;
}
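// per-thread register shape for mma.m16n8k16 with f16 accumulate: A is 4 regs
// (half8), B is 2 regs (half4), and C/D accumulate in place in 2 regs (half4)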
extern "C" __global__ void __launch_bounds__(256) wmma_example(half* data0, const half* data1, const half* data2, int N, int K) {
extern __shared__ char smem[];
half *smem_a_0 = (half *)(smem);
half *smem_a_1 = (half *)(smem + 16384);
half *smem_a_2 = (half *)(smem + 32768);
half *smem_b_0 = (half *)(smem + 49152);
half *smem_b_1 = (half *)(smem + 57344);
half *smem_b_2 = (half *)(smem + 65536);
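// triple-buffered tiles: 3 A buffers of 128x64 halves (16 KiB each) plus
// 3 B buffers of 32x128 halves (8 KiB each) = 72 KiB of dynamic smem, which
// the host must opt into beyond the 48 KiB default (see the sketch at end of file)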
int grid_m = blockIdx.x; /* gridDim.x == M//256 */
int grid_n = blockIdx.y; /* gridDim.y == N//128 */
int wg_threads = threadIdx.x; // lane within warp, 0-31
int wg_m = threadIdx.y; // warp row, 0-3
int wg_n = threadIdx.z; // warp column, 0-1
int threads = threadIdx.x + (threadIdx.y * 32) + (threadIdx.z * 128); /* flat thread id, 0-255 */
int num_k_blocks = K / 32;
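// each k-block covers 32 elements of K, consumed in two half-steps:
// K=0 handles k 0-15 of the block, K=1 handles k 16-31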
// ldmatrix indices - 4x loads of 8x8 matrices by 32 threads
// threads 0-7 are row starts for A, 8-15 for B, 16-23 for C, 24-31 for D
// [ A | C ]
// [ - + - ]
// [ B | D ]
// swizzled A - SMEM_A is 128 rows x 64 cols; each smem row packs a 32-wide
// k-slice from two global rows 128 apart (cols 0-31 and 32-63 pre-swizzle)
size_t global_a_off = ((grid_m * 256) * K) + ((threads % 4) * 8) + (((threads / 4) % 2) * 8 * 16 * K) + ((threads / 8) * K);
size_t store_smem_a_off = ((threads / 8) * 64) + (((threads * 8) ^ threads) & 56); // 32 rows / 64 cols per copy
size_t load_smem_a_row = ((wg_m * 16) + (threads % 16)) * 64;
size_t load_smem_a_phase = (threads / 16) % 2;
size_t load_smem_a_0_k_0 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8);
size_t load_smem_a_1_k_0 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 0) ^ (threads % 8)) * 8);
size_t load_smem_a_2_k_0 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8);
size_t load_smem_a_3_k_0 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 4) ^ (threads % 8)) * 8);
size_t load_smem_a_0_k_1 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8);
size_t load_smem_a_1_k_1 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 2) ^ (threads % 8)) * 8);
size_t load_smem_a_2_k_1 = load_smem_a_row + ( 0 * 64) + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8);
size_t load_smem_a_3_k_1 = load_smem_a_row + (64 * 64) + (((load_smem_a_phase + 6) ^ (threads % 8)) * 8);
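// XORing the 8-half column group with (threads % 8) mirrors the store-side
// swizzle, so the 8 addresses of each ldmatrix land in distinct bank groups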
// swizzled B - SMEM_B is 32 rows x 128 cols
size_t global_b_off = (grid_n * 128) + ((threads % 16) * 8) + ((threads / 16) * N);
size_t store_smem_b_off = ((threads / 16) * 128) + ((((threads / 16) % 8) * 8) ^ ((threads % 16) * 8)); // 16 rows / 128 cols per copy
size_t load_smem_b_row = (threads % 16) * 128;
size_t load_smem_b_phase = (wg_n * 2) + (wg_threads / 16);
size_t load_smem_b_0_k_0 = load_smem_b_row + (((load_smem_b_phase + 0) ^ (threads % 8)) * 8);
size_t load_smem_b_1_k_0 = load_smem_b_row + (((load_smem_b_phase + 4) ^ (threads % 8)) * 8);
size_t load_smem_b_2_k_0 = load_smem_b_row + (((load_smem_b_phase + 8) ^ (threads % 8)) * 8);
size_t load_smem_b_3_k_0 = load_smem_b_row + (((load_smem_b_phase + 12) ^ (threads % 8)) * 8);
size_t load_smem_b_0_k_1 = load_smem_b_0_k_0 + (16 * 128);
size_t load_smem_b_1_k_1 = load_smem_b_1_k_0 + (16 * 128);
size_t load_smem_b_2_k_1 = load_smem_b_2_k_0 + (16 * 128);
size_t load_smem_b_3_k_1 = load_smem_b_3_k_0 + (16 * 128);
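// the K=1 offsets sit 16 rows (k+16) below the K=0 ones; wg_n interleaves
// which 16-column chunks of the 128-wide tile a warp's 8 fragments cover
// (wg_n=0: cols 0-15, 32-47, ...; wg_n=1: cols 16-31, 48-63, ...)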
// create accs (M=4, N=8)
half4 acc_frag_0_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_0_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_1_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_2_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_0 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_1 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_2 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_3 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_4 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_5 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_6 = make_half4(0.0f,0.0f,0.0f,0.0f);
half4 acc_frag_3_7 = make_half4(0.0f,0.0f,0.0f,0.0f);
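// 4 (M) x 8 (N) accumulators per thread: each warp produces 64 output rows
// (four 16-row slices) x 64 columns, so the 4x2 warps cover the block's
// 256x128 output tile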
// create registers for block A elements
half8 a_frag_0_k_0;
half8 a_frag_1_k_0;
half8 a_frag_2_k_0;
half8 a_frag_3_k_0;
half8 a_frag_0_k_1;
half8 a_frag_1_k_1;
half8 a_frag_2_k_1;
half8 a_frag_3_k_1;
// create register for block B elements
half4 b_frag_0_k_0;
half4 b_frag_1_k_0;
half4 b_frag_2_k_0;
half4 b_frag_3_k_0;
half4 b_frag_4_k_0;
half4 b_frag_5_k_0;
half4 b_frag_6_k_0;
half4 b_frag_7_k_0;
half4 b_frag_0_k_1;
half4 b_frag_1_k_1;
half4 b_frag_2_k_1;
half4 b_frag_3_k_1;
half4 b_frag_4_k_1;
half4 b_frag_5_k_1;
half4 b_frag_6_k_1;
half4 b_frag_7_k_1;
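// k_0/k_1 register double-buffering: while the mma ops consume one half-step,
// ldmatrix can be filling the registers for the other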
__syncthreads();
// load first tile
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16);
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16);
__pipeline_memcpy_async(&smem_a_0[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16);
__pipeline_memcpy_async(&smem_b_0[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_0[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16);
__pipeline_commit();
global_a_off += 32;
global_b_off += 32 * N;
// load second tile
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16);
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16);
__pipeline_memcpy_async(&smem_a_1[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16);
__pipeline_memcpy_async(&smem_b_1[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_1[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16);
__pipeline_commit();
global_a_off += 32;
global_b_off += 32 * N;
// wait on first pre-fetch load
__pipeline_wait_prior(1);
__syncthreads();
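// __pipeline_wait_prior(1) above returned once all but the most recent commit
// completed: tile 0 is now resident in smem while tile 1 is still in flight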
// load K=0 elements for the first tile
__ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_0[load_smem_a_0_k_0]);
__ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_0[load_smem_a_1_k_0]);
__ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_0[load_smem_a_2_k_0]);
__ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_0[load_smem_a_3_k_0]);
__ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_0[load_smem_b_0_k_0]);
__ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_0[load_smem_b_1_k_0]);
__ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_0[load_smem_b_2_k_0]);
__ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_0[load_smem_b_3_k_0]);
for (int block_k = 0; block_k < num_k_blocks; block_k++) {
int phase_k = block_k % 3;
half *smem_a_curr = (phase_k == 0) ? smem_a_0 : ((phase_k == 1) ? smem_a_1 : smem_a_2);
half *smem_b_curr = (phase_k == 0) ? smem_b_0 : ((phase_k == 1) ? smem_b_1 : smem_b_2);
int next_phase_k = (block_k+1) % 3;
half *smem_a_next = (next_phase_k == 0) ? smem_a_0 : ((next_phase_k == 1) ? smem_a_1 : smem_a_2);
half *smem_b_next = (next_phase_k == 0) ? smem_b_0 : ((next_phase_k == 1) ? smem_b_1 : smem_b_2);
int store_phase_k = (block_k+2) % 3;
half *smem_a_store = (store_phase_k == 0) ? smem_a_0 : ((store_phase_k == 1) ? smem_a_1 : smem_a_2);
half *smem_b_store = (store_phase_k == 0) ? smem_b_0 : ((store_phase_k == 1) ? smem_b_1 : smem_b_2);
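// 3-stage rotation: curr is consumed this iteration, next already has its
// copy in flight, store receives the prefetch for block_k+2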
// load K=1 elements for the current tile
__ldmatrix_a_elems(&a_frag_0_k_1, &smem_a_curr[load_smem_a_0_k_1]);
__ldmatrix_a_elems(&a_frag_1_k_1, &smem_a_curr[load_smem_a_1_k_1]);
__ldmatrix_a_elems(&a_frag_2_k_1, &smem_a_curr[load_smem_a_2_k_1]);
__ldmatrix_a_elems(&a_frag_3_k_1, &smem_a_curr[load_smem_a_3_k_1]);
__ldmatrix_b_elems(&b_frag_0_k_1, &b_frag_1_k_1, &smem_b_curr[load_smem_b_0_k_1]);
__ldmatrix_b_elems(&b_frag_2_k_1, &b_frag_3_k_1, &smem_b_curr[load_smem_b_1_k_1]);
__ldmatrix_b_elems(&b_frag_4_k_1, &b_frag_5_k_1, &smem_b_curr[load_smem_b_2_k_1]);
__ldmatrix_b_elems(&b_frag_6_k_1, &b_frag_7_k_1, &smem_b_curr[load_smem_b_3_k_1]);
// MMA K=0, (M=4 x N=8)
acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_0_k_0, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_1_k_0, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_2_k_0, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_3_k_0, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_4_k_0, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_5_k_0, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_6_k_0, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_0, b_frag_7_k_0, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_0_k_0, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_1_k_0, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_2_k_0, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_3_k_0, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_4_k_0, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_5_k_0, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_6_k_0, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_0, b_frag_7_k_0, acc_frag_1_7);
acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_0_k_0, acc_frag_2_0);
acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_1_k_0, acc_frag_2_1);
acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_2_k_0, acc_frag_2_2);
acc_frag_2_3 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_3_k_0, acc_frag_2_3);
acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_4_k_0, acc_frag_2_4);
acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_5_k_0, acc_frag_2_5);
acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_6_k_0, acc_frag_2_6);
acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_0, b_frag_7_k_0, acc_frag_2_7);
acc_frag_3_0 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_0_k_0, acc_frag_3_0);
acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_1_k_0, acc_frag_3_1);
acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_2_k_0, acc_frag_3_2);
acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_3_k_0, acc_frag_3_3);
acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_4_k_0, acc_frag_3_4);
acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_5_k_0, acc_frag_3_5);
acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_6_k_0, acc_frag_3_6);
acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_0, b_frag_7_k_0, acc_frag_3_7);
// load next tile if needed
if (block_k < (num_k_blocks-2)) {
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 0)], &data1[global_a_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 32*64)], &data1[global_a_off + ( 32*K)], 16);
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 64*64)], &data1[global_a_off + ( 64*K)], 16);
__pipeline_memcpy_async(&smem_a_store[store_smem_a_off + ( 96*64)], &data1[global_a_off + ( 96*K)], 16);
__pipeline_memcpy_async(&smem_b_store[store_smem_b_off + ( 0)], &data2[global_b_off + ( 0)], 16);
__pipeline_memcpy_async(&smem_b_store[store_smem_b_off + (16*128)], &data2[global_b_off + ( 16*N)], 16);
global_a_off += 32;
global_b_off += 32 * N;
}
__pipeline_commit();
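// commit is issued even when nothing was enqueued so that wait_prior(1)
// stays aligned with one outstanding tile in the final two iterations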
// wait for the next tile to land
__pipeline_wait_prior(1);
__syncthreads();
// load K=0 elements for the next tile
__ldmatrix_a_elems(&a_frag_0_k_0, &smem_a_next[load_smem_a_0_k_0]);
__ldmatrix_a_elems(&a_frag_1_k_0, &smem_a_next[load_smem_a_1_k_0]);
__ldmatrix_a_elems(&a_frag_2_k_0, &smem_a_next[load_smem_a_2_k_0]);
__ldmatrix_a_elems(&a_frag_3_k_0, &smem_a_next[load_smem_a_3_k_0]);
__ldmatrix_b_elems(&b_frag_0_k_0, &b_frag_1_k_0, &smem_b_next[load_smem_b_0_k_0]);
__ldmatrix_b_elems(&b_frag_2_k_0, &b_frag_3_k_0, &smem_b_next[load_smem_b_1_k_0]);
__ldmatrix_b_elems(&b_frag_4_k_0, &b_frag_5_k_0, &smem_b_next[load_smem_b_2_k_0]);
__ldmatrix_b_elems(&b_frag_6_k_0, &b_frag_7_k_0, &smem_b_next[load_smem_b_3_k_0]);
// MMA K=1, (M=4 x N=8)
acc_frag_0_0 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_0_k_1, acc_frag_0_0);
acc_frag_0_1 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_1_k_1, acc_frag_0_1);
acc_frag_0_2 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_2_k_1, acc_frag_0_2);
acc_frag_0_3 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_3_k_1, acc_frag_0_3);
acc_frag_0_4 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_4_k_1, acc_frag_0_4);
acc_frag_0_5 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_5_k_1, acc_frag_0_5);
acc_frag_0_6 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_6_k_1, acc_frag_0_6);
acc_frag_0_7 = __WMMA_8_16_16_half_half(a_frag_0_k_1, b_frag_7_k_1, acc_frag_0_7);
acc_frag_1_0 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_0_k_1, acc_frag_1_0);
acc_frag_1_1 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_1_k_1, acc_frag_1_1);
acc_frag_1_2 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_2_k_1, acc_frag_1_2);
acc_frag_1_3 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_3_k_1, acc_frag_1_3);
acc_frag_1_4 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_4_k_1, acc_frag_1_4);
acc_frag_1_5 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_5_k_1, acc_frag_1_5);
acc_frag_1_6 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_6_k_1, acc_frag_1_6);
acc_frag_1_7 = __WMMA_8_16_16_half_half(a_frag_1_k_1, b_frag_7_k_1, acc_frag_1_7);
acc_frag_2_0 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_0_k_1, acc_frag_2_0);
acc_frag_2_1 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_1_k_1, acc_frag_2_1);
acc_frag_2_2 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_2_k_1, acc_frag_2_2);
acc_frag_2_3 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_3_k_1, acc_frag_2_3);
acc_frag_2_4 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_4_k_1, acc_frag_2_4);
acc_frag_2_5 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_5_k_1, acc_frag_2_5);
acc_frag_2_6 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_6_k_1, acc_frag_2_6);
acc_frag_2_7 = __WMMA_8_16_16_half_half(a_frag_2_k_1, b_frag_7_k_1, acc_frag_2_7);
acc_frag_3_0 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_0_k_1, acc_frag_3_0);
acc_frag_3_1 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_1_k_1, acc_frag_3_1);
acc_frag_3_2 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_2_k_1, acc_frag_3_2);
acc_frag_3_3 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_3_k_1, acc_frag_3_3);
acc_frag_3_4 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_4_k_1, acc_frag_3_4);
acc_frag_3_5 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_5_k_1, acc_frag_3_5);
acc_frag_3_6 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_6_k_1, acc_frag_3_6);
acc_frag_3_7 = __WMMA_8_16_16_half_half(a_frag_3_k_1, b_frag_7_k_1, acc_frag_3_7);
}
// write accumulators to output
__pipeline_wait_prior(0);
__syncthreads();
// faster epilogue: write each warp's 8x8 TC accumulator tiles to SMEM first
// - SMEM_N_WIDTH is 8 columns wider than the 128 needed, to avoid shared-memory bank conflicts
// - takes around 14 microseconds
// - check bank conflicts with ncu (may require sudo): "PYTHONPATH=. CUDA=1 GEMM_VARIATION="max" DTYPE_IN=half DTYPE_OUT=half DTYPE_ACC=half CNT=8 INPUT=ONES /usr/local/cuda/bin/ncu --section MemoryWorkloadAnalysis --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum,l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum python3 ./extra/gemm/max_matmul.py"
// epilogue with 256 threads / WG_M=4 / WG_N=2: split into 8 chunks (lo/hi halves of each of the 4 TC M fragments)
// 1) write 32 rows of 128 cols (rows 0-7, 16-23, 32-39, 48-55 in acc_frag_0_*.lo (x/y), then acc_frag_0_*.hi (z/w), etc.)
// 2) read/write 16 rows of 128 elements in 8 elem (16B) chunks
half2 *smem32_d = (half2 *)(smem);
half8 *smem128_d = (half8 *)(smem);
half8 *out128_d = (half8 *)(data0);
size_t smem32_d_write_off = (wg_m * 8 * (SMEM_N_WIDTH / 2)) + (wg_n * (16 / 2));
size_t smem32_d_thread_off = ((wg_threads / 4) * (SMEM_N_WIDTH / 2)) + (wg_threads % 4);
size_t smem128_d_read_off = ((threads / 16) * (SMEM_N_WIDTH / 8)) + (threads % 16);
size_t out128_d_off = ((grid_m * 256) * (N / 8)) + (grid_n * (128 / 8)) +
((threads / 128) * 16 * (N / 8)) + (((threads / 16) % 8) * (N / 8)) + (threads % 16);
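// the epilogue reuses the front of smem as a 32-row x SMEM_N_WIDTH staging
// buffer; smem32_d addresses it in half2 units, smem128_d / out128_d in
// half8 (16-byte) units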
// write acc_frag_0_*
// write 32 rows of 128 N elements to SMEM
__syncthreads();
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_0_0.x, acc_frag_0_0.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_0_1.x, acc_frag_0_1.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_0_2.x, acc_frag_0_2.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_0_3.x, acc_frag_0_3.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_0_4.x, acc_frag_0_4.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_0_5.x, acc_frag_0_5.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_0_6.x, acc_frag_0_6.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_0_7.x, acc_frag_0_7.y);
// each thread reads and writes two 8 element chunks
__syncthreads();
out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off];
out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
// write 32 rows of 128 N elements to SMEM
__syncthreads();
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_0_0.z, acc_frag_0_0.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_0_1.z, acc_frag_0_1.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_0_2.z, acc_frag_0_2.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_0_3.z, acc_frag_0_3.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_0_4.z, acc_frag_0_4.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_0_5.z, acc_frag_0_5.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_0_6.z, acc_frag_0_6.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_0_7.z, acc_frag_0_7.w);
// each thread reads and writes two 8 element chunks
__syncthreads();
out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off];
out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
// write acc_frag_1_*
out128_d_off += (64 * (N / 8));
// write 32 rows of 128 N elements to SMEM
__syncthreads();
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_1_0.x, acc_frag_1_0.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_1_1.x, acc_frag_1_1.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_1_2.x, acc_frag_1_2.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_1_3.x, acc_frag_1_3.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_1_4.x, acc_frag_1_4.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_1_5.x, acc_frag_1_5.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_1_6.x, acc_frag_1_6.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_1_7.x, acc_frag_1_7.y);
// each thread reads and writes two 8 element chunks
__syncthreads();
out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off];
out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
// write 32 rows of 128 N elements to SMEM
__syncthreads();
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_1_0.z, acc_frag_1_0.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_1_1.z, acc_frag_1_1.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_1_2.z, acc_frag_1_2.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_1_3.z, acc_frag_1_3.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_1_4.z, acc_frag_1_4.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_1_5.z, acc_frag_1_5.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_1_6.z, acc_frag_1_6.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_1_7.z, acc_frag_1_7.w);
// each thread reads and writes two 8 element chunks
__syncthreads();
out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off];
out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
// write acc_frag_2_*
out128_d_off += (64 * (N / 8));
// write 32 rows of 128 N elements to SMEM
__syncthreads();
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_2_0.x, acc_frag_2_0.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_2_1.x, acc_frag_2_1.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_2_2.x, acc_frag_2_2.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_2_3.x, acc_frag_2_3.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_2_4.x, acc_frag_2_4.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_2_5.x, acc_frag_2_5.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_2_6.x, acc_frag_2_6.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_2_7.x, acc_frag_2_7.y);
// each thread reads and writes two 8 element chunks
__syncthreads();
out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off];
out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
// write 32 rows of 128 N elements to SMEM
__syncthreads();
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_2_0.z, acc_frag_2_0.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_2_1.z, acc_frag_2_1.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_2_2.z, acc_frag_2_2.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_2_3.z, acc_frag_2_3.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_2_4.z, acc_frag_2_4.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_2_5.z, acc_frag_2_5.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_2_6.z, acc_frag_2_6.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_2_7.z, acc_frag_2_7.w);
// each thread reads and writes two 8 element chunks
__syncthreads();
out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off];
out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
// write acc_frag_3_*
out128_d_off += (64 * (N / 8));
// write 32 rows of 128 N elements to SMEM
__syncthreads();
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_3_0.x, acc_frag_3_0.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_3_1.x, acc_frag_3_1.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_3_2.x, acc_frag_3_2.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_3_3.x, acc_frag_3_3.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_3_4.x, acc_frag_3_4.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_3_5.x, acc_frag_3_5.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_3_6.x, acc_frag_3_6.y);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_3_7.x, acc_frag_3_7.y);
// each thread reads and writes two 8 element chunks
__syncthreads();
out128_d[out128_d_off + ( 0 * (N / 8))] = smem128_d[smem128_d_read_off];
out128_d[out128_d_off + (32 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
// write 32 rows of 128 N elements to SMEM
__syncthreads();
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 0*4)] = half2(acc_frag_3_0.z, acc_frag_3_0.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 1*4)] = half2(acc_frag_3_1.z, acc_frag_3_1.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 4*4)] = half2(acc_frag_3_2.z, acc_frag_3_2.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 5*4)] = half2(acc_frag_3_3.z, acc_frag_3_3.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 8*4)] = half2(acc_frag_3_4.z, acc_frag_3_4.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + ( 9*4)] = half2(acc_frag_3_5.z, acc_frag_3_5.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (12*4)] = half2(acc_frag_3_6.z, acc_frag_3_6.w);
smem32_d[smem32_d_write_off + smem32_d_thread_off + (13*4)] = half2(acc_frag_3_7.z, acc_frag_3_7.w);
// each thread reads and writes two 8 element chunks
__syncthreads();
out128_d[out128_d_off + ( 8 * (N / 8))] = smem128_d[smem128_d_read_off];
out128_d[out128_d_off + (40 * (N / 8))] = smem128_d[smem128_d_read_off + (16 * (SMEM_N_WIDTH / 8))];
__syncthreads();
}
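// ---------------------------------------------------------------------------
// Host-side launch sketch (illustrative only, not part of the generated
// kernel; assumes M divisible by 256, N by 128, and K by 32):
//
//   cudaFuncSetAttribute(wmma_example,
//       cudaFuncAttributeMaxDynamicSharedMemorySize, 73728);
//   dim3 grid(M / 256, N / 128);   // one block per 256x128 output tile
//   dim3 block(32, 4, 2);          // 32 lanes x WG_M=4 x WG_N=2 warps
//   wmma_example<<<grid, block, 73728>>>(data0, data1, data2, N, K);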