carrot/tinygrad_repo/extra/gemm/max_kernels/nv.fp16_fp32_fp16.hcopt.cu

157 lines
7.9 KiB
Plaintext
Raw Normal View History

2025-04-18 20:38:55 +09:00
// Float special values spelled as raw IEEE-754 binary32 bit patterns and
// reinterpreted with the device intrinsic __int_as_float:
//   0x7f800000 -> +infinity (exponent all ones, mantissa zero)
//   0x7fffffff -> a quiet NaN (exponent all ones, all mantissa bits set)
// NOTE(review): presumably part of a fixed prelude emitted ahead of every
// generated kernel; neither macro is referenced by the kernel in this file.
#define INFINITY (__int_as_float(0x7f800000))
#define NAN (__int_as_float(0x7fffffff))
#include <cuda_fp16.h>
// Four halves packed into 8 bytes with 8-byte alignment. The mma helper views
// one of these as two 32-bit registers (the B-operand fragment of a lane).
struct __align__(8) half4 {
  half x, y, z, w;
};

// Builds a half4 from four scalar halves, stored in member order x, y, z, w.
__device__ half4 make_half4(half x, half y, half z, half w) {
  return half4{x, y, z, w};
}
// Eight halves packed into 16 bytes with 16-byte alignment. The mma helper views
// one of these as four 32-bit registers (the A-operand fragment of a lane).
struct __align__(16) half8 {
  half x, y, z, w, a, b, c, d;
};

// Builds a half8 from eight scalar halves, stored in member order
// x, y, z, w, a, b, c, d.
__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) {
  return half8{x, y, z, w, a, b, c, d};
}
// One warp-wide tensor-core multiply-accumulate, D = A * B + C, using
// mma.sync.aligned.m16n8k16.row.col with fp16 A/B operands and fp32 C/D
// (the helper's name lists the shape as 8_16_16; the instruction is m16n8k16).
//   a : this lane's eight A halves, passed to the instruction as four 32-bit regs.
//   b : this lane's four B halves, passed as two 32-bit regs.
//   c : this lane's four fp32 accumulator values; returned updated.
// mma.sync is warp-collective: every lane of the warp must execute this call
// together. The m16n8k16 f16 form requires SM80 or newer.
__device__ float4 __WMMA_8_16_16_half_float(half8 a, half4 b, float4 c) {
  // Each 32-bit register carries two consecutive halves of the fragment.
  const int *areg = reinterpret_cast<const int *>(&a);
  const int *breg = reinterpret_cast<const int *>(&b);
  // The accumulator is read and written in place ("+f"); %0..%3 appear as both
  // the D and C operand lists of the instruction.
  asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
      "{ %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };"
      : "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w)
      : "r"(areg[0]), "r"(areg[1]), "r"(areg[2]), "r"(areg[3]),
        "r"(breg[0]), "r"(breg[1]));
  return c;
}
// wmma_example: tensor-core matrix multiply with fp16 inputs, fp32 accumulation
// and fp16 output (half loads -> float4 accumulators -> half2 stores).
// NOTE(review): judging by the file path this appears to be a tinygrad-generated
// kernel kept as a tuned reference; aluN/valN/castN names are generator output.
//
// Buffers:
//   data0 - output; only written (half2 stores after the K loop).
//   data1 - left operand ("A"); loaded as half2 pairs into 16x16 fragments.
//   data2 - right operand ("B"); loaded one half at a time into 16x8 fragments.
// The index constants (a row step of 4096 elements, 256 K-steps of 16) suggest
// M = N = K = 4096 with all three matrices row-major -- NOTE(review): confirm
// against the host-side launcher.
//
// Expected launch (from the extent comments below and __launch_bounds__(128)):
//   gridDim = 32 x 64 blocks (x, y); blockDim = (16, 2, 4) = 128 threads.
//   Because blockDim.x * blockDim.y == 32, each threadIdx.z slice is one warp.
// There are no bounds checks: all indices assume exactly these extents, so any
// other launch shape reads and writes out of range.
// mma.sync is warp-collective; the kernel has no divergent control flow, so
// every lane of each warp reaches every __WMMA_8_16_16_half_float call.
//
// Work split (taking the row step as 4096): block (gidx0, gidx1) owns the
// 64-row x 128-column output tile starting at row 64*gidx1, column 128*gidx0.
// Warp lidx2 owns columns [32*lidx2, 32*lidx2 + 32) of that tile, i.e. a 64x32
// sub-tile held in 16 float4 accumulators: acc(4*i + j) covers row tile i
// (16 rows) x column tile j (8 columns).
//
// Lane layout: with g = (lidx0/4)%2 + 2*(lidx0/8) + 4*lidx1 (the lane's groupID)
// a lane holds columns 2*(lidx0%4) + {0,1} of each 16x8 tile; accN.x/.y sit on
// tile row 2*g and accN.z/.w on tile row 2*g + 1. The A-fragment loads below use
// the same row assignment (row 2*g for the first and third registers, 2*g + 1
// for the second and fourth), so A rows and C rows line up.
extern "C" __global__ void __launch_bounds__(128) wmma_example(half* data0, const half* data1, const half* data2) {
int gidx0 = blockIdx.x; /* 32 */
int gidx1 = blockIdx.y; /* 64 */
int lidx0 = threadIdx.x; /* 16 */
int lidx1 = threadIdx.y; /* 2 */
int lidx2 = threadIdx.z; /* 4 */
float4 cast0 = make_float4(0.0f,0.0f,0.0f,0.0f);
// Offset terms below are in elements; "rows" assume a 4096-element row step.
// Block column offset: 128 columns per block.
int alu0 = (gidx0*128);
// Block row offset: 262144 = 64 * 4096, i.e. 64 rows per block.
int alu1 = (gidx1*262144);
// lidx1 selects rows +8 (32768 = 8 * 4096).
int alu2 = (lidx1*32768);
// Warp (lidx2) column offset: 32 columns per warp.
int alu3 = (lidx2*32);
// lidx0 bit 3 selects rows +4 (16384 = 4 * 4096).
int alu4 = (lidx0/8);
int alu5 = (alu4*16384);
// lidx0 bit 0 selects columns +2.
int alu6 = (lidx0%2);
int alu7 = (alu6*2);
// lidx0 bit 1 selects columns +4.
int alu8 = ((lidx0/2)%2);
int alu9 = (alu8*4);
// lidx0 bit 2 selects rows +2 (8192 = 2 * 4096).
int alu10 = ((lidx0/4)%2);
int alu11 = (alu10*8192);
// alu12: this lane's first output element in data0 (row part + block/warp
// column offsets + lane column 2*(lidx0%4)).
int alu12 = (alu1+alu0+alu7+alu9+alu11+alu5+alu2+alu3);
// alu13: the same row part plus the lane column 2*(lidx0%4), used as the base
// into data1, where that column acts as a K index. It omits alu0/alu3 because
// the A rows do not depend on the output column.
int alu13 = (alu1+alu7+alu9+alu11+alu5+alu2);
// 16 fp32 accumulators, all zero: accN with N = 4*i + j is row tile i
// (16*i rows down) x column tile j (8*j columns across).
float4 acc0 = cast0;
float4 acc1 = cast0;
float4 acc2 = cast0;
float4 acc3 = cast0;
float4 acc4 = cast0;
float4 acc5 = cast0;
float4 acc6 = cast0;
float4 acc7 = cast0;
float4 acc8 = cast0;
float4 acc9 = cast0;
float4 acc10 = cast0;
float4 acc11 = cast0;
float4 acc12 = cast0;
float4 acc13 = cast0;
float4 acc14 = cast0;
float4 acc15 = cast0;
// K loop: 256 steps of 16 along K; each step issues one m16n8k16 mma per
// (row tile, column tile) pair.
for (int ridx0 = 0; ridx0 < 256; ridx0++) {
// K offset of this step within a data1 row.
int alu14 = (ridx0*16);
// alu15 and alu16 hold the same value (same operands, different order).
int alu15 = (alu13+alu14);
int alu16 = (alu14+alu13);
// alu17: this lane's base into data2 for this step.
//   K rows: 16*ridx0 (65536 = 16 * 4096) + 2*(lidx0%4) (8192*bit0 + 16384*bit1).
//   Column: block + warp offset + g ((lidx0/4)%2 + 2*(lidx0/8) + 4*lidx1).
int alu17 = (alu0+(alu6*8192)+(alu8*16384)+alu10+(alu4*2)+(lidx1*4)+alu3+(ridx0*65536));
// B fragments from data2: 16 scalar half loads per step. Relative to alu17:
//   +0/+8/+16/+24 select column tile j = 0..3;
//   +4096 is K row k+1, +32768 is k+8, +36864 is k+9.
// Each castN packs (k, k+1, k+8, k+9) for one column tile, i.e. the two 32-bit
// B registers of the mma: cast4 = tile 0 (+0), cast1 = tile 1 (+8),
// cast2 = tile 2 (+16), cast3 = tile 3 (+24).
half val0 = data2[alu17+8];
half val1 = data2[alu17+16];
half val2 = data2[alu17+24];
half val3 = data2[alu17+4096];
half val4 = data2[alu17+4104];
half val5 = data2[alu17+4112];
half val6 = data2[alu17+4120];
half val7 = data2[alu17+32768];
half val8 = data2[alu17+32776];
half val9 = data2[alu17+32784];
half val10 = data2[alu17+32792];
half val11 = data2[alu17+36864];
half val12 = data2[alu17+36872];
half4 cast1 = make_half4(val0,val4,val8,val12);
half val13 = data2[alu17+36880];
half4 cast2 = make_half4(val1,val5,val9,val13);
half val14 = data2[alu17+36888];
half4 cast3 = make_half4(val2,val6,val10,val14);
half val15 = data2[alu17];
half4 cast4 = make_half4(val15,val3,val7,val11);
// A fragments from data1: half2 loads (two consecutive K values each).
// Row tile i is at +65536*i (16 rows each); +4096 is the next row (2*g + 1) and
// +8 is K + 8. Each castN packs {(2g, k), (2g+1, k), (2g, k+8), (2g+1, k+8)}:
// cast5 = row tile 0, cast6 = tile 1, cast7 = tile 2, cast8 = tile 3.
// Several loads for tiles 1..3 (val17..val22) are issued ahead of the first
// mma group.
half2 val16 = *((half2*)(data1+alu15+4096));
half2 val17 = *((half2*)(data1+alu15+65536));
half2 val18 = *((half2*)(data1+alu15+69632));
half2 val19 = *((half2*)(data1+alu15+131072));
half2 val20 = *((half2*)(data1+alu15+135168));
half2 val21 = *((half2*)(data1+alu15+196608));
half2 val22 = *((half2*)(data1+alu15+200704));
half2 val23 = *((half2*)(data1+alu15));
half2 val24 = *((half2*)(data1+alu16+8));
half2 val25 = *((half2*)(data1+alu16+4104));
half8 cast5 = make_half8(val23.x,val23.y,val16.x,val16.y,val24.x,val24.y,val25.x,val25.y);
// Row tile 0 x column tiles 1, 2, 3, 0.
float4 wmma0 = __WMMA_8_16_16_half_float(cast5, cast1, acc1);
float4 wmma1 = __WMMA_8_16_16_half_float(cast5, cast2, acc2);
float4 wmma2 = __WMMA_8_16_16_half_float(cast5, cast3, acc3);
float4 wmma3 = __WMMA_8_16_16_half_float(cast5, cast4, acc0);
half2 val26 = *((half2*)(data1+alu16+65544));
half2 val27 = *((half2*)(data1+alu16+69640));
half8 cast6 = make_half8(val17.x,val17.y,val18.x,val18.y,val26.x,val26.y,val27.x,val27.y);
// Row tile 1 x column tiles 1, 2, 3, 0.
float4 wmma4 = __WMMA_8_16_16_half_float(cast6, cast1, acc5);
float4 wmma5 = __WMMA_8_16_16_half_float(cast6, cast2, acc6);
float4 wmma6 = __WMMA_8_16_16_half_float(cast6, cast3, acc7);
float4 wmma7 = __WMMA_8_16_16_half_float(cast6, cast4, acc4);
half2 val28 = *((half2*)(data1+alu16+131080));
half2 val29 = *((half2*)(data1+alu16+135176));
half8 cast7 = make_half8(val19.x,val19.y,val20.x,val20.y,val28.x,val28.y,val29.x,val29.y);
// Row tile 2 x column tiles 1, 2, 3, 0.
float4 wmma8 = __WMMA_8_16_16_half_float(cast7, cast1, acc9);
float4 wmma9 = __WMMA_8_16_16_half_float(cast7, cast2, acc10);
float4 wmma10 = __WMMA_8_16_16_half_float(cast7, cast3, acc11);
float4 wmma11 = __WMMA_8_16_16_half_float(cast7, cast4, acc8);
half2 val30 = *((half2*)(data1+alu16+196616));
half2 val31 = *((half2*)(data1+alu16+200712));
half8 cast8 = make_half8(val21.x,val21.y,val22.x,val22.y,val30.x,val30.y,val31.x,val31.y);
// Row tile 3 x column tiles 1, 2, 3, 0.
float4 wmma12 = __WMMA_8_16_16_half_float(cast8, cast1, acc13);
float4 wmma13 = __WMMA_8_16_16_half_float(cast8, cast2, acc14);
float4 wmma14 = __WMMA_8_16_16_half_float(cast8, cast3, acc15);
float4 wmma15 = __WMMA_8_16_16_half_float(cast8, cast4, acc12);
// Write results back. Within each row tile the mmas were issued for column
// tiles 1, 2, 3, 0, so e.g. wmma3 feeds acc0 and wmma0 feeds acc1.
acc0 = wmma3;
acc1 = wmma0;
acc2 = wmma1;
acc3 = wmma2;
acc4 = wmma7;
acc5 = wmma4;
acc6 = wmma5;
acc7 = wmma6;
acc8 = wmma11;
acc9 = wmma8;
acc10 = wmma9;
acc11 = wmma10;
acc12 = wmma15;
acc13 = wmma12;
acc14 = wmma13;
acc15 = wmma14;
}
// Epilogue: convert each accumulator to half with the (half) cast and store as
// half2 pairs. For accN with N = 4*i + j:
//   .x/.y go to alu12 + 65536*i + 8*j (row 2*g of the tile);
//   .z/.w go 4096 further (row 2*g + 1).
// acc0.x/.y is stored last; the order does not matter because every store in
// this thread targets a distinct address.
*((half2*)(data0+alu12+8)) = make_half2((half)(acc1.x),(half)(acc1.y));
*((half2*)(data0+alu12+16)) = make_half2((half)(acc2.x),(half)(acc2.y));
*((half2*)(data0+alu12+24)) = make_half2((half)(acc3.x),(half)(acc3.y));
*((half2*)(data0+alu12+4096)) = make_half2((half)(acc0.z),(half)(acc0.w));
*((half2*)(data0+alu12+4104)) = make_half2((half)(acc1.z),(half)(acc1.w));
*((half2*)(data0+alu12+4112)) = make_half2((half)(acc2.z),(half)(acc2.w));
*((half2*)(data0+alu12+4120)) = make_half2((half)(acc3.z),(half)(acc3.w));
*((half2*)(data0+alu12+65536)) = make_half2((half)(acc4.x),(half)(acc4.y));
*((half2*)(data0+alu12+65544)) = make_half2((half)(acc5.x),(half)(acc5.y));
*((half2*)(data0+alu12+65552)) = make_half2((half)(acc6.x),(half)(acc6.y));
*((half2*)(data0+alu12+65560)) = make_half2((half)(acc7.x),(half)(acc7.y));
*((half2*)(data0+alu12+69632)) = make_half2((half)(acc4.z),(half)(acc4.w));
*((half2*)(data0+alu12+69640)) = make_half2((half)(acc5.z),(half)(acc5.w));
*((half2*)(data0+alu12+69648)) = make_half2((half)(acc6.z),(half)(acc6.w));
*((half2*)(data0+alu12+69656)) = make_half2((half)(acc7.z),(half)(acc7.w));
*((half2*)(data0+alu12+131072)) = make_half2((half)(acc8.x),(half)(acc8.y));
*((half2*)(data0+alu12+131080)) = make_half2((half)(acc9.x),(half)(acc9.y));
*((half2*)(data0+alu12+131088)) = make_half2((half)(acc10.x),(half)(acc10.y));
*((half2*)(data0+alu12+131096)) = make_half2((half)(acc11.x),(half)(acc11.y));
*((half2*)(data0+alu12+135168)) = make_half2((half)(acc8.z),(half)(acc8.w));
*((half2*)(data0+alu12+135176)) = make_half2((half)(acc9.z),(half)(acc9.w));
*((half2*)(data0+alu12+135184)) = make_half2((half)(acc10.z),(half)(acc10.w));
*((half2*)(data0+alu12+135192)) = make_half2((half)(acc11.z),(half)(acc11.w));
*((half2*)(data0+alu12+196608)) = make_half2((half)(acc12.x),(half)(acc12.y));
*((half2*)(data0+alu12+196616)) = make_half2((half)(acc13.x),(half)(acc13.y));
*((half2*)(data0+alu12+196624)) = make_half2((half)(acc14.x),(half)(acc14.y));
*((half2*)(data0+alu12+196632)) = make_half2((half)(acc15.x),(half)(acc15.y));
*((half2*)(data0+alu12+200704)) = make_half2((half)(acc12.z),(half)(acc12.w));
*((half2*)(data0+alu12+200712)) = make_half2((half)(acc13.z),(half)(acc13.w));
*((half2*)(data0+alu12+200720)) = make_half2((half)(acc14.z),(half)(acc14.w));
*((half2*)(data0+alu12+200728)) = make_half2((half)(acc15.z),(half)(acc15.w));
*((half2*)(data0+alu12)) = make_half2((half)(acc0.x),(half)(acc0.y));
}