#define INFINITY (__int_as_float(0x7f800000))
#define NAN (__int_as_float(0x7fffffff))
// Provides half / half2 / make_half2 and float->half conversions used below.
// (The original directive was missing its header name and did not preprocess.)
#include <cuda_fp16.h>

// Four halves packed into 8 bytes. Passed as the B operand of one m16n8k16 MMA:
// reinterpreted as two 32-bit .f16x2 registers.
struct __align__(8) half4 { half x, y, z, w; };
__device__ half4 make_half4(half x, half y, half z, half w) { half4 r={x, y, z, w}; return r; }

// Eight halves packed into 16 bytes. Passed as the A operand of one m16n8k16 MMA:
// reinterpreted as four 32-bit .f16x2 registers.
struct __align__(16) half8 { half x, y, z, w, a, b, c, d; };
__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { half8 r={x, y, z, w, a, b, c, d}; return r; }

// Warp-wide tensor-core multiply-accumulate: returns A*B + c for one
// mma.sync.aligned.m16n8k16.row.col with f16 inputs and f32 accumulators.
//   a : this lane's A fragment (8 halves -> 4 x .f16x2 registers)
//   b : this lane's B fragment (4 halves -> 2 x .f16x2 registers)
//   c : this lane's accumulator fragment (4 floats); used as both C and D.
// Requirements:
//   - SM80+ (m16n8k16 with .f16 inputs).
//   - All 32 lanes of the warp must execute the call convergently (.sync.aligned).
__device__ float4 __WMMA_8_16_16_half_float(half8 a, half4 b, float4 c) {
  // View the packed halves as 32-bit registers, two halves per register.
  int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
  asm(
    "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %0, %1, %2, %3 }, { %4, %5, %6, %7 }, { %8, %9 }, { %0, %1, %2, %3 };"
    : "+f"(c.x), "+f"(c.y), "+f"(c.z), "+f"(c.w)
    : "r"(a_pk[0]), "r"(a_pk[1]), "r"(a_pk[2]), "r"(a_pk[3]), "r"(b_pk[0]), "r"(b_pk[1]) );
  return c;
}

// Tensor-core GEMM: data0 = data1 * data2, with fp32 accumulation and the result
// rounded to half on store.
//
// Shapes (implied by the hard-coded strides; NOTE(review): confirm against caller):
//   - All three matrices appear to be 4096 x 4096, row-major, leading dimension 4096.
//   - data1 is indexed [m][k] and data2 is indexed [k][n]; K = 256 iterations * 16.
//
// Launch layout (from the per-index comments below):
//   - grid  = (32, 64, 1); block = (16, 2, 4) = 128 threads, matching __launch_bounds__.
//   - Each block produces a 64-row x 128-column tile of data0:
//       rows start at blockIdx.y*64, columns at blockIdx.x*128.
//   - With this block shape, threadIdx.z selects the warp. Lane = threadIdx.x + 16*threadIdx.y.
//   - Each warp owns 64 rows x 32 columns: 4 m-tiles (16 rows) x 4 n-tiles (8 cols).
//     accN holds m-tile N/4 and n-tile N%4.
//   - Within an m-tile, the lane->row mapping is permuted relative to PTX's canonical
//     (groupID, groupID+8) rows. It uses rows (2g, 2g+1) instead. The same permutation
//     is applied to the A loads and the C stores, so the product is unaffected.
//
// Preconditions:
//   - No bounds checks: dimensions and launch configuration must match exactly.
//   - data0 and data1 are accessed through half2*, so they must be at least 4-byte aligned.
//   - Requires SM80+ (see __WMMA_8_16_16_half_float).
extern "C" __global__ void __launch_bounds__(128) wmma_example(half* data0, const half* data1, const half* data2) {
  int gidx0 = blockIdx.x; /* 32 */
  int gidx1 = blockIdx.y; /* 64 */
  int lidx0 = threadIdx.x; /* 16 */
  int lidx1 = threadIdx.y; /* 2 */
  int lidx2 = threadIdx.z; /* 4 */
  float4 cast0 = make_float4(0.0f,0.0f,0.0f,0.0f);

  // Per-thread base offsets (element units; one row = 4096 elements).
  int alu0 = (gidx0*128);          // block column offset
  int alu1 = (gidx1*262144);       // block row offset (64 rows)
  int alu2 = (lidx1*32768);        // +8 rows
  int alu3 = (lidx2*32);           // warp column offset (32 columns per warp)
  int alu4 = (lidx0/8);
  int alu5 = (alu4*16384);         // +4 rows
  int alu6 = (lidx0%2);
  int alu7 = (alu6*2);             // +2 columns / k
  int alu8 = ((lidx0/2)%2);
  int alu9 = (alu8*4);             // +4 columns / k
  int alu10 = ((lidx0/4)%2);
  int alu11 = (alu10*8192);        // +2 rows
  int alu12 = (alu1+alu0+alu7+alu9+alu11+alu5+alu2+alu3);  // data0 store base
  int alu13 = (alu1+alu7+alu9+alu11+alu5+alu2);            // data1 load base (k = 0)

  // 16 accumulator fragments: 4 m-tiles x 4 n-tiles.
  float4 acc0 = cast0; float4 acc1 = cast0; float4 acc2 = cast0; float4 acc3 = cast0;
  float4 acc4 = cast0; float4 acc5 = cast0; float4 acc6 = cast0; float4 acc7 = cast0;
  float4 acc8 = cast0; float4 acc9 = cast0; float4 acc10 = cast0; float4 acc11 = cast0;
  float4 acc12 = cast0; float4 acc13 = cast0; float4 acc14 = cast0; float4 acc15 = cast0;

  // March over K in steps of 16 (one m16n8k16 MMA depth per step).
  for (int ridx0 = 0; ridx0 < 256; ridx0++) {
    int alu14 = (ridx0*16);
    int alu15 = (alu13+alu14);
    int alu16 = (alu14+alu13);
    // data2 base for this k-slab. ridx0*65536 advances 16 rows of data2.
    int alu17 = (alu0+(alu6*8192)+(alu8*16384)+alu10+(alu4*2)+(lidx1*4)+alu3+(ridx0*65536));

    // B fragments, one per n-tile (+0/+8/+16/+24 columns). Each packs
    // rows k, k+1, k+8, k+9 (offsets 0, 4096, 32768, 36864).
    half val0 = data2[alu17+8];
    half val1 = data2[alu17+16];
    half val2 = data2[alu17+24];
    half val3 = data2[alu17+4096];
    half val4 = data2[alu17+4104];
    half val5 = data2[alu17+4112];
    half val6 = data2[alu17+4120];
    half val7 = data2[alu17+32768];
    half val8 = data2[alu17+32776];
    half val9 = data2[alu17+32784];
    half val10 = data2[alu17+32792];
    half val11 = data2[alu17+36864];
    half val12 = data2[alu17+36872];
    half4 cast1 = make_half4(val0,val4,val8,val12);    // n-tile 1
    half val13 = data2[alu17+36880];
    half4 cast2 = make_half4(val1,val5,val9,val13);    // n-tile 2
    half val14 = data2[alu17+36888];
    half4 cast3 = make_half4(val2,val6,val10,val14);   // n-tile 3
    half val15 = data2[alu17];
    half4 cast4 = make_half4(val15,val3,val7,val11);   // n-tile 0

    // A fragments, one per m-tile (+0/+16/+32/+48 rows = +0/65536/131072/196608).
    // Each packs (row r, k..k+1), (row r+1, k..k+1), (row r, k+8..k+9),
    // and (row r+1, k+8..k+9).
    half2 val16 = *((half2*)(data1+alu15+4096));
    half2 val17 = *((half2*)(data1+alu15+65536));
    half2 val18 = *((half2*)(data1+alu15+69632));
    half2 val19 = *((half2*)(data1+alu15+131072));
    half2 val20 = *((half2*)(data1+alu15+135168));
    half2 val21 = *((half2*)(data1+alu15+196608));
    half2 val22 = *((half2*)(data1+alu15+200704));
    half2 val23 = *((half2*)(data1+alu15));
    half2 val24 = *((half2*)(data1+alu16+8));
    half2 val25 = *((half2*)(data1+alu16+4104));

    // m-tile 0
    half8 cast5 = make_half8(val23.x,val23.y,val16.x,val16.y,val24.x,val24.y,val25.x,val25.y);
    float4 wmma0 = __WMMA_8_16_16_half_float(cast5, cast1, acc1);
    float4 wmma1 = __WMMA_8_16_16_half_float(cast5, cast2, acc2);
    float4 wmma2 = __WMMA_8_16_16_half_float(cast5, cast3, acc3);
    float4 wmma3 = __WMMA_8_16_16_half_float(cast5, cast4, acc0);

    // m-tile 1
    half2 val26 = *((half2*)(data1+alu16+65544));
    half2 val27 = *((half2*)(data1+alu16+69640));
    half8 cast6 = make_half8(val17.x,val17.y,val18.x,val18.y,val26.x,val26.y,val27.x,val27.y);
    float4 wmma4 = __WMMA_8_16_16_half_float(cast6, cast1, acc5);
    float4 wmma5 = __WMMA_8_16_16_half_float(cast6, cast2, acc6);
    float4 wmma6 = __WMMA_8_16_16_half_float(cast6, cast3, acc7);
    float4 wmma7 = __WMMA_8_16_16_half_float(cast6, cast4, acc4);

    // m-tile 2
    half2 val28 = *((half2*)(data1+alu16+131080));
    half2 val29 = *((half2*)(data1+alu16+135176));
    half8 cast7 = make_half8(val19.x,val19.y,val20.x,val20.y,val28.x,val28.y,val29.x,val29.y);
    float4 wmma8 = __WMMA_8_16_16_half_float(cast7, cast1, acc9);
    float4 wmma9 = __WMMA_8_16_16_half_float(cast7, cast2, acc10);
    float4 wmma10 = __WMMA_8_16_16_half_float(cast7, cast3, acc11);
    float4 wmma11 = __WMMA_8_16_16_half_float(cast7, cast4, acc8);

    // m-tile 3
    half2 val30 = *((half2*)(data1+alu16+196616));
    half2 val31 = *((half2*)(data1+alu16+200712));
    half8 cast8 = make_half8(val21.x,val21.y,val22.x,val22.y,val30.x,val30.y,val31.x,val31.y);
    float4 wmma12 = __WMMA_8_16_16_half_float(cast8, cast1, acc13);
    float4 wmma13 = __WMMA_8_16_16_half_float(cast8, cast2, acc14);
    float4 wmma14 = __WMMA_8_16_16_half_float(cast8, cast3, acc15);
    float4 wmma15 = __WMMA_8_16_16_half_float(cast8, cast4, acc12);

    acc0 = wmma3; acc1 = wmma0; acc2 = wmma1; acc3 = wmma2;
    acc4 = wmma7; acc5 = wmma4; acc6 = wmma5; acc7 = wmma6;
    acc8 = wmma11; acc9 = wmma8; acc10 = wmma9; acc11 = wmma10;
    acc12 = wmma15; acc13 = wmma12; acc14 = wmma13; acc15 = wmma14;
  }

  // Write back, rounding fp32 accumulators to half.
  // For each fragment: .x/.y go to row r, and .z/.w go to row r+1 (+4096).
  // m-tile 0
  *((half2*)(data0+alu12)) = make_half2((half)(acc0.x),(half)(acc0.y));
  *((half2*)(data0+alu12+8)) = make_half2((half)(acc1.x),(half)(acc1.y));
  *((half2*)(data0+alu12+16)) = make_half2((half)(acc2.x),(half)(acc2.y));
  *((half2*)(data0+alu12+24)) = make_half2((half)(acc3.x),(half)(acc3.y));
  *((half2*)(data0+alu12+4096)) = make_half2((half)(acc0.z),(half)(acc0.w));
  *((half2*)(data0+alu12+4104)) = make_half2((half)(acc1.z),(half)(acc1.w));
  *((half2*)(data0+alu12+4112)) = make_half2((half)(acc2.z),(half)(acc2.w));
  *((half2*)(data0+alu12+4120)) = make_half2((half)(acc3.z),(half)(acc3.w));
  // m-tile 1
  *((half2*)(data0+alu12+65536)) = make_half2((half)(acc4.x),(half)(acc4.y));
  *((half2*)(data0+alu12+65544)) = make_half2((half)(acc5.x),(half)(acc5.y));
  *((half2*)(data0+alu12+65552)) = make_half2((half)(acc6.x),(half)(acc6.y));
  *((half2*)(data0+alu12+65560)) = make_half2((half)(acc7.x),(half)(acc7.y));
  *((half2*)(data0+alu12+69632)) = make_half2((half)(acc4.z),(half)(acc4.w));
  *((half2*)(data0+alu12+69640)) = make_half2((half)(acc5.z),(half)(acc5.w));
  *((half2*)(data0+alu12+69648)) = make_half2((half)(acc6.z),(half)(acc6.w));
  *((half2*)(data0+alu12+69656)) = make_half2((half)(acc7.z),(half)(acc7.w));
  // m-tile 2
  *((half2*)(data0+alu12+131072)) = make_half2((half)(acc8.x),(half)(acc8.y));
  *((half2*)(data0+alu12+131080)) = make_half2((half)(acc9.x),(half)(acc9.y));
  *((half2*)(data0+alu12+131088)) = make_half2((half)(acc10.x),(half)(acc10.y));
  *((half2*)(data0+alu12+131096)) = make_half2((half)(acc11.x),(half)(acc11.y));
  *((half2*)(data0+alu12+135168)) = make_half2((half)(acc8.z),(half)(acc8.w));
  *((half2*)(data0+alu12+135176)) = make_half2((half)(acc9.z),(half)(acc9.w));
  *((half2*)(data0+alu12+135184)) = make_half2((half)(acc10.z),(half)(acc10.w));
  *((half2*)(data0+alu12+135192)) = make_half2((half)(acc11.z),(half)(acc11.w));
  // m-tile 3
  *((half2*)(data0+alu12+196608)) = make_half2((half)(acc12.x),(half)(acc12.y));
  *((half2*)(data0+alu12+196616)) = make_half2((half)(acc13.x),(half)(acc13.y));
  *((half2*)(data0+alu12+196624)) = make_half2((half)(acc14.x),(half)(acc14.y));
  *((half2*)(data0+alu12+196632)) = make_half2((half)(acc15.x),(half)(acc15.y));
  *((half2*)(data0+alu12+200704)) = make_half2((half)(acc12.z),(half)(acc12.w));
  *((half2*)(data0+alu12+200712)) = make_half2((half)(acc13.z),(half)(acc13.w));
  *((half2*)(data0+alu12+200720)) = make_half2((half)(acc14.z),(half)(acc14.w));
  *((half2*)(data0+alu12+200728)) = make_half2((half)(acc15.z),(half)(acc15.w));
}