diff --git a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c
index 9af7841d5af3616087c94dc71ab3b2738a50553a..35c7bf1734fad392af79ff3a9095f65ad6d3aed4 100644
--- a/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c
+++ b/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.c
@@ -5,6 +5,10 @@
 //
 #include "kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h"
 
+#if defined(__ARM_NEON) || defined(_M_ARM64)
+#include <arm_neon.h>
+#endif  // defined(__ARM_NEON) || defined(_M_ARM64)
+
 #include <stddef.h>
 #include <stdint.h>
 #include <string.h>
@@ -178,81 +182,238 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
             size_t k0_idx_i = dst_qblock_idx * bl;
 
-            for (size_t dst_byte_idx = 0; dst_byte_idx < num_bytes_per_block_k; dst_byte_idx += 16) {
-                for (size_t segment_idx = 0; segment_idx < 16 / block_length_in_bytes; ++segment_idx) {
-                    for (size_t nr_idx = 0; nr_idx < nr; ++nr_idx) {
-                        const size_t n0_idx = dst_row_idx + nr_idx;
-
-                        // Two int4 values are stored in one byte.
-                        // The lower order part of the byte (low) holds the first nibble (K-index + 0).
-                        // The higher order of the byte holds the second nibble (K-index + 16).
-                        size_t k0_idx = k0_idx_i;
-                        size_t k1_idx = k0_idx_i + 16;
-
-                        // Clamp the index to avoid out-of-bound reads
-                        const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1);
-                        float d = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx]);
-
-                        int32_t partial_sum = 0;
-
-                        size_t src_addr_byte0 = (k0_idx / 2) + n0_valid_idx * rhs_stride;
-
-                        for (size_t block_byte_idx = 0; block_byte_idx < block_length_in_bytes; block_byte_idx += 2) {
-                            // Initialize the byte with the zero-point (8)
-                            // e.g. uint8_t byte0 = 8 | 8 << 4
-                            uint8_t byte0 = 136;
-                            uint8_t byte1 = 136;
-                            uint8_t byte2 = 136;
-                            uint8_t byte3 = 136;
-
-                            if (k0_idx < k) {
-                                byte0 = rhs[src_addr_byte0];
-                            }
-
-                            if (k1_idx < k) {
-                                byte1 = rhs[src_addr_byte0 + 8];
-                            }
+#if defined(__ARM_NEON) || defined(_M_ARM64)
+            const uint8x8_t top_mask = vdup_n_u8(0xF0);
+            const uint8x8_t bottom_mask = vdup_n_u8(0x0F);
+            const uint8x8_t zero_point_conversion_mask = vdup_n_u8(0x88);
+
+            if (block_length_in_bytes == 8) {
+                for (size_t dst_byte_idx = 0; dst_byte_idx < num_bytes_per_block_k; dst_byte_idx += 16) {
+                    for (size_t nr_idx = 0; nr_idx < nr; nr_idx += 4) {
+                        // Clamp the indices to avoid out-of-bound reads
+                        const size_t n0_idx = KAI_MIN(dst_row_idx + nr_idx, n - 1);
+                        const size_t n1_idx = KAI_MIN(n0_idx + 1, n - 1);
+                        const size_t n2_idx = KAI_MIN(n0_idx + 2, n - 1);
+                        const size_t n3_idx = KAI_MIN(n0_idx + 3, n - 1);
+
+                        const float d0 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 0]);
+                        const float d1 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 1]);
+                        const float d2 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 2]);
+                        const float d3 = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx + 3]);
+
+                        // Take zero-point (-8) into account
+                        int32_t partial_sum0 = -(32 * 8);
+                        int32_t partial_sum1 = -(32 * 8);
+                        int32_t partial_sum2 = -(32 * 8);
+                        int32_t partial_sum3 = -(32 * 8);
+
+                        const uint8_t* src_block_base = rhs + ((k0_idx_i / 2) + dst_byte_idx);
+
+                        // Load elements as uint64_ts to calculate sums more efficiently
+                        uint64_t ld0_0 = *(const uint64_t*)(src_block_base + n0_idx * rhs_stride);
+                        uint64_t ld0_1 = *(const uint64_t*)(src_block_base + n0_idx * rhs_stride + 8);
+
+                        uint64_t ld1_0 = *(const uint64_t*)(src_block_base + n1_idx * rhs_stride);
+                        uint64_t ld1_1 = *(const uint64_t*)(src_block_base + n1_idx * rhs_stride + 8);
+
+                        uint64_t ld2_0 = *(const uint64_t*)(src_block_base + n2_idx * rhs_stride);
+                        uint64_t ld2_1 = *(const uint64_t*)(src_block_base + n2_idx * rhs_stride + 8);
+
+                        uint64_t ld3_0 = *(const uint64_t*)(src_block_base + n3_idx * rhs_stride);
+                        uint64_t ld3_1 = *(const uint64_t*)(src_block_base + n3_idx * rhs_stride + 8);
+
+                        // Copy to vector registers
+                        const uint8x8_t vld0_0 = vcreate_u8(ld0_0);
+                        const uint8x8_t vld0_1 = vcreate_u8(ld0_1);
+
+                        const uint8x8_t vld1_0 = vcreate_u8(ld1_0);
+                        const uint8x8_t vld1_1 = vcreate_u8(ld1_1);
+
+                        const uint8x8_t vld2_0 = vcreate_u8(ld2_0);
+                        const uint8x8_t vld2_1 = vcreate_u8(ld2_1);
+
+                        const uint8x8_t vld3_0 = vcreate_u8(ld3_0);
+                        const uint8x8_t vld3_1 = vcreate_u8(ld3_1);
+
+                        // Calculate sums
+                        for (size_t idx = 0; idx < 16; ++idx) {
+                            const int32_t e0_0 = (int32_t)(ld0_0 & 0x0F);
+                            const int32_t e0_1 = (int32_t)(ld0_1 & 0x0F);
+                            partial_sum0 += e0_0 + e0_1;
+                            ld0_0 = ld0_0 >> 4;
+                            ld0_1 = ld0_1 >> 4;
+
+                            const int32_t e1_0 = (int32_t)(ld1_0 & 0x0F);
+                            const int32_t e1_1 = (int32_t)(ld1_1 & 0x0F);
+                            partial_sum1 += e1_0 + e1_1;
+                            ld1_0 = ld1_0 >> 4;
+                            ld1_1 = ld1_1 >> 4;
+
+                            const int32_t e2_0 = (int32_t)(ld2_0 & 0x0F);
+                            const int32_t e2_1 = (int32_t)(ld2_1 & 0x0F);
+                            partial_sum2 += e2_0 + e2_1;
+                            ld2_0 = ld2_0 >> 4;
+                            ld2_1 = ld2_1 >> 4;
+
+                            const int32_t e3_0 = (int32_t)(ld3_0 & 0x0F);
+                            const int32_t e3_1 = (int32_t)(ld3_1 & 0x0F);
+                            partial_sum3 += e3_0 + e3_1;
+                            ld3_0 = ld3_0 >> 4;
+                            ld3_1 = ld3_1 >> 4;
+                        }
 
-                            if (k0_idx + 1 < k) {
-                                byte2 = byte0;
-                            }
+                        const uint8x8_t vld0_s1s = vand_u8(vld0_0, bottom_mask);
+                        const uint8x8_t vld0_s0s = vshr_n_u8(vld0_0, 4);
+                        const uint8x8_t vld0_s17s = vshl_n_u8(vld0_1, 4);
+                        const uint8x8_t vld0_s16s = vand_u8(vld0_1, top_mask);
+
+                        const uint8x8_t vld0_s16s0s_lower =
+                            vorr_u8(vzip1_u8(vld0_s1s, vld0_s0s), vzip1_u8(vld0_s17s, vld0_s16s));
+                        const uint8x8_t vld0_s16s0s_upper =
+                            vorr_u8(vzip2_u8(vld0_s1s, vld0_s0s), vzip2_u8(vld0_s17s, vld0_s16s));
+
+                        const uint8x8_t vld1_s1s = vand_u8(vld1_0, bottom_mask);
+                        const uint8x8_t vld1_s0s = vshr_n_u8(vld1_0, 4);
+                        const uint8x8_t vld1_s17s = vshl_n_u8(vld1_1, 4);
+                        const uint8x8_t vld1_s16s = vand_u8(vld1_1, top_mask);
+
+                        const uint8x8_t vld1_s16s0s_lower =
+                            vorr_u8(vzip1_u8(vld1_s1s, vld1_s0s), vzip1_u8(vld1_s17s, vld1_s16s));
+                        const uint8x8_t vld1_s16s0s_upper =
+                            vorr_u8(vzip2_u8(vld1_s1s, vld1_s0s), vzip2_u8(vld1_s17s, vld1_s16s));
+
+                        const uint8x8_t vld2_s1s = vand_u8(vld2_0, bottom_mask);
+                        const uint8x8_t vld2_s0s = vshr_n_u8(vld2_0, 4);
+                        const uint8x8_t vld2_s17s = vshl_n_u8(vld2_1, 4);
+                        const uint8x8_t vld2_s16s = vand_u8(vld2_1, top_mask);
+
+                        const uint8x8_t vld2_s16s0s_lower =
+                            vorr_u8(vzip1_u8(vld2_s1s, vld2_s0s), vzip1_u8(vld2_s17s, vld2_s16s));
+                        const uint8x8_t vld2_s16s0s_upper =
+                            vorr_u8(vzip2_u8(vld2_s1s, vld2_s0s), vzip2_u8(vld2_s17s, vld2_s16s));
+
+                        const uint8x8_t vld3_s1s = vand_u8(vld3_0, bottom_mask);
+                        const uint8x8_t vld3_s0s = vshr_n_u8(vld3_0, 4);
+                        const uint8x8_t vld3_s17s = vshl_n_u8(vld3_1, 4);
+                        const uint8x8_t vld3_s16s = vand_u8(vld3_1, top_mask);
+
+                        const uint8x8_t vld3_s16s0s_lower =
+                            vorr_u8(vzip1_u8(vld3_s1s, vld3_s0s), vzip1_u8(vld3_s17s, vld3_s16s));
+                        const uint8x8_t vld3_s16s0s_upper =
+                            vorr_u8(vzip2_u8(vld3_s1s, vld3_s0s), vzip2_u8(vld3_s17s, vld3_s16s));
+
+                        // Convert to unsigned int4 and store repacked values
+                        vst1_u8((uint8_t*)dst_row, veor_u8(vld0_s16s0s_lower, zero_point_conversion_mask));
+                        vst1_u8((uint8_t*)dst_row + 8, veor_u8(vld1_s16s0s_lower, zero_point_conversion_mask));
+                        vst1_u8((uint8_t*)dst_row + 16, veor_u8(vld2_s16s0s_lower, zero_point_conversion_mask));
+                        vst1_u8((uint8_t*)dst_row + 24, veor_u8(vld3_s16s0s_lower, zero_point_conversion_mask));
+
+                        vst1_u8(
+                            (uint8_t*)dst_row + (nr * block_length_in_bytes),
+                            veor_u8(vld0_s16s0s_upper, zero_point_conversion_mask));
+                        vst1_u8(
+                            (uint8_t*)dst_row + (nr * block_length_in_bytes) + 8,
+                            veor_u8(vld1_s16s0s_upper, zero_point_conversion_mask));
+                        vst1_u8(
+                            (uint8_t*)dst_row + (nr * block_length_in_bytes) + 16,
+                            veor_u8(vld2_s16s0s_upper, zero_point_conversion_mask));
+                        vst1_u8(
+                            (uint8_t*)dst_row + (nr * block_length_in_bytes) + 24,
+                            veor_u8(vld3_s16s0s_upper, zero_point_conversion_mask));
+
+                        // Add to row sums
+                        // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
+                        sums[nr_idx + 0] += (float)partial_sum0 * d0;
+                        sums[nr_idx + 1] += (float)partial_sum1 * d1;
+                        sums[nr_idx + 2] += (float)partial_sum2 * d2;
+                        sums[nr_idx + 3] += (float)partial_sum3 * d3;
+                        // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
 
-                            if (k1_idx + 1 < k) {
-                                byte3 = byte1;
+                        dst_row += block_length_in_bytes * 4;
+                    }
+                    // Skip to end of qblock
+                    dst_row += nr * block_length_in_bytes;
+                }
+            } else {
+#endif  // defined(__ARM_NEON) || defined(_M_ARM64)
+                for (size_t dst_byte_idx = 0; dst_byte_idx < num_bytes_per_block_k; dst_byte_idx += 16) {
+                    for (size_t segment_idx = 0; segment_idx < 16 / block_length_in_bytes; ++segment_idx) {
+                        for (size_t nr_idx = 0; nr_idx < nr; ++nr_idx) {
+                            const size_t n0_idx = dst_row_idx + nr_idx;
+
+                            // Two int4 values are stored in one byte.
+                            // The lower order part of the byte (low) holds the first nibble (K-index + 0).
+                            // The higher order of the byte holds the second nibble (K-index + 16).
+                            size_t k0_idx = k0_idx_i;
+                            size_t k1_idx = k0_idx_i + 16;
+
+                            // Clamp the index to avoid out-of-bound reads
+                            const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1);
+                            float d = kai_cast_f32_bf16(((uint16_t*)rhs_packed_scale)[nr_idx]);
+
+                            int32_t partial_sum = 0;
+
+                            size_t src_addr_byte0 = (k0_idx / 2) + n0_valid_idx * rhs_stride;
+
+                            for (size_t block_byte_idx = 0; block_byte_idx < block_length_in_bytes;
+                                 block_byte_idx += 2) {
+                                // Initialize the byte with the zero-point (8)
+                                // e.g. uint8_t byte0 = 8 | 8 << 4
+                                uint8_t byte0 = 136;
+                                uint8_t byte1 = 136;
+                                uint8_t byte2 = 136;
+                                uint8_t byte3 = 136;
+
+                                if (k0_idx < k) {
+                                    byte0 = rhs[src_addr_byte0];
+                                }
+
+                                if (k1_idx < k) {
+                                    byte1 = rhs[src_addr_byte0 + 8];
+                                }
+
+                                if (k0_idx + 1 < k) {
+                                    byte2 = byte0;
+                                }
+
+                                if (k1_idx + 1 < k) {
+                                    byte3 = byte1;
+                                }
+
+                                k0_idx += 2;
+                                k1_idx += 2;
+
+                                const uint8_t src_x0_lo = byte0 & 0x0F;
+                                const uint8_t src_x0_hi = byte1 & 0x0F;
+                                const uint8_t src_x1_lo = (byte2 >> 4) & 0x0F;
+                                const uint8_t src_x1_hi = (byte3 >> 4) & 0x0F;
+
+                                partial_sum += (int32_t)src_x0_lo;
+                                partial_sum += (int32_t)src_x0_hi;
+                                partial_sum += (int32_t)src_x1_lo;
+                                partial_sum += (int32_t)src_x1_hi;
+                                partial_sum -= 32;  // 4 * zero_point (8)
+
+                                const uint16_t dst_q =
+                                    ((src_x0_lo)) | ((src_x0_hi) << 4) | ((src_x1_lo) << 8) | ((src_x1_hi) << 12);
+
+                                *((uint16_t*)dst_row) = dst_q ^ 0x8888;
+
+                                dst_row += 2;
+                                src_addr_byte0 += 1;
+                            }
-
-                            k0_idx += 2;
-                            k1_idx += 2;
-
-                            const uint8_t src_x0_lo = byte0 & 0x0F;
-                            const uint8_t src_x0_hi = byte1 & 0x0F;
-                            const uint8_t src_x1_lo = (byte2 >> 4) & 0x0F;
-                            const uint8_t src_x1_hi = (byte3 >> 4) & 0x0F;
-
-                            partial_sum += (int32_t)src_x0_lo;
-                            partial_sum += (int32_t)src_x0_hi;
-                            partial_sum += (int32_t)src_x1_lo;
-                            partial_sum += (int32_t)src_x1_hi;
-                            partial_sum -= 32;  // 4 * zero_point (8)
-
-                            const uint16_t dst_q =
-                                ((src_x0_lo)) | ((src_x0_hi) << 4) | ((src_x1_lo) << 8) | ((src_x1_hi) << 12);
-
-                            *((uint16_t*)dst_row) = dst_q ^ 0x8888;
-
-                            dst_row += 2;
-                            src_addr_byte0 += 1;
+                            // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
+                            sums[nr_idx] += (float)partial_sum * d;
+                            // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
                         }
-                        // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
-                        sums[nr_idx] += (float)partial_sum * d;
-                        // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
+                        k0_idx_i += block_length_in_bytes;
                     }
-
-                    k0_idx_i += block_length_in_bytes;
+                    k0_idx_i += 16;
                 }
-                k0_idx_i += 16;
+#if defined(__ARM_NEON) || defined(_M_ARM64)
             }
-            // Move the pointer after scales
+#endif  // defined(__ARM_NEON) || defined(_M_ARM64)
+            // Move the pointer after scales
             dst_row += num_bytes_multiplier_rhs * nr;
         }
@@ -269,6 +430,7 @@ void kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
                 ((float*)dst_row)[i] = bias[src_row_idx];
             }
         }
+        // Move the pointer after the row sum
        dst_row += kai_num_bytes_bias * nr;
     }
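
Reviewer note (not part of the patch): the packing scheme above pairs the int4 value at K-index k with the value at K-index k + 16 of the same 32-element quantization block in one output byte, and converts each nibble from unsigned with zero-point 8 to signed int4 by flipping its sign bit; that is what the `^ 0x8888` in the scalar path and the `veor_u8` with the 0x88 mask in the NEON path do. The NEON fast path also accumulates the per-row sums (starting at -(32 * 8)) while the data is in scalar registers. The sketch below is a minimal scalar reference for one row and one block; the function and buffer names are illustrative only, and the interleaving of the four nr rows and of the lower/upper 8-byte halves is handled by the surrounding loops in the kernel, not here.

#include <stdint.h>

// Scalar reference for the per-row repacking done by the NEON fast path:
// src holds 32 unsigned 4-bit values (zero-point 8) in s1s0 nibble order,
// dst receives 16 bytes where byte i packs K-index i (low nibble) with
// K-index i + 16 (high nibble), both converted to signed int4.
static void repack_qblock_row_ref(const uint8_t src[16], uint8_t dst[16], int32_t* row_sum) {
    uint8_t v[32];
    int32_t sum = 0;

    for (int i = 0; i < 16; ++i) {
        v[2 * i + 0] = src[i] & 0x0F;  // value at K-index 2i
        v[2 * i + 1] = src[i] >> 4;    // value at K-index 2i + 1
    }

    for (int i = 0; i < 16; ++i) {
        // XOR-ing a nibble with 8 maps the unsigned value u (0..15) to the
        // two's-complement encoding of u - 8, i.e. the signed int4 value.
        dst[i] = (uint8_t)((v[i] ^ 8) | ((v[i + 16] ^ 8) << 4));
        sum += (int32_t)v[i] - 8;       // same running total as partial_sum0..3,
        sum += (int32_t)v[i + 16] - 8;  // which start at -(32 * 8) in the patch
    }

    *row_sum = sum;
}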