From 0a4c4448df6459a2dd744ab77a539d1749f3b9d3 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Tue, 10 Dec 2024 16:03:57 +0000 Subject: [PATCH 1/6] Add MSVC support for BFloat16 Signed-off-by: Viet-Hoa Do --- CMakeLists.txt | 6 +++ kai/common/internal/assembly.h | 45 +++++++++++++++++++++++ test/common/bfloat16.cpp | 16 ++++++++ test/common/bfloat16.hpp | 67 +++++++++++----------------------- test/common/bfloat16_asm.S | 18 +++++++++ test/reference/pack.cpp | 10 ++--- test/tests/bfloat16_test.cpp | 33 +++++++++++++++++ 7 files changed, 145 insertions(+), 50 deletions(-) create mode 100644 kai/common/internal/assembly.h create mode 100644 test/common/bfloat16_asm.S create mode 100644 test/tests/bfloat16_test.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index cf3ef4d8..c73627b6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -226,6 +226,8 @@ if(KLEIDIAI_BUILD_TESTS) if(MSVC) add_library(kleidiai_test_framework test/common/data_type.cpp + test/common/bfloat16.cpp + test/common/bfloat16_asm.S test/common/float16.cpp test/common/float16.S test/common/cpu_info.cpp @@ -242,6 +244,7 @@ if(KLEIDIAI_BUILD_TESTS) test/common/rect.cpp test/common/round.cpp test/common/bfloat16.cpp + test/common/bfloat16_asm.S test/common/float16.cpp test/common/float16.S test/common/cpu_info.cpp @@ -268,6 +271,7 @@ if(KLEIDIAI_BUILD_TESTS) ) if(MSVC) + set_source_files_properties(test/common/bfloat16_asm.S PROPERTIES LANGUAGE ASM_MARMASM) set_source_files_properties(test/common/float16.S PROPERTIES LANGUAGE ASM_MARMASM) endif() @@ -279,10 +283,12 @@ if(KLEIDIAI_BUILD_TESTS) if(MSVC) add_executable(kleidiai_test + test/tests/bfloat16_test.cpp test/tests/float16_test.cpp ) else() add_executable(kleidiai_test + test/tests/bfloat16_test.cpp test/tests/float16_test.cpp test/tests/matmul_test.cpp test/tests/matmul_clamp_f32_f32_f32p_test.cpp diff --git a/kai/common/internal/assembly.h b/kai/common/internal/assembly.h new file mode 100644 index 00000000..b20e8347 --- /dev/null +++ 
b/kai/common/internal/assembly.h @@ -0,0 +1,45 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#ifndef KAI_COMMON_INTERNAL_ASSEMBLY_H +#define KAI_COMMON_INTERNAL_ASSEMBLY_H + +// clang-format off + +#ifdef _MSC_VER + +#define KAI_ASM_HEADER AREA |.text|, CODE, READONLY, ALIGN=4 +#define KAI_ASM_LABEL(label) |label| +#define KAI_ASM_FUNCTION(label) |label| +#define KAI_ASM_EXPORT(label) global label +#define KAI_ASM_FOOTER end +#define KAI_ASM_INST(num) dcd num + +#else // _MSC_VER + +#define KAI_ASM_HEADER .text +#define KAI_ASM_LABEL(label) label: + +#ifdef __APPLE__ +#define KAI_ASM_FUNCTION(label) _##label: +#define KAI_ASM_EXPORT(label) \ + .global _##label; \ + .type _##label, %function +#else // __APPLE__ +#define KAI_ASM_FUNCTION(label) label: +#define KAI_ASM_EXPORT(label) \ + .global label; \ + .type label, %function +#endif // __APPLE__ + +#define KAI_ASM_FOOTER +#define KAI_ASM_INST(num) .inst num + +#endif // _MSC_VER + +// clang-format on + +#endif // KAI_COMMON_INTERNAL_ASSEMBLY_H diff --git a/test/common/bfloat16.cpp b/test/common/bfloat16.cpp index d9581b2f..26e9259a 100644 --- a/test/common/bfloat16.cpp +++ b/test/common/bfloat16.cpp @@ -7,9 +7,25 @@ #include "test/common/bfloat16.hpp" #include +#include namespace kai::test { +static_assert(sizeof(BFloat16) == 2); + +static_assert(std::is_trivially_destructible_v); +static_assert(std::is_nothrow_destructible_v); + +static_assert(std::is_trivially_copy_constructible_v); +static_assert(std::is_trivially_copy_assignable_v); +static_assert(std::is_trivially_move_constructible_v); +static_assert(std::is_trivially_move_assignable_v); + +static_assert(std::is_nothrow_copy_constructible_v); +static_assert(std::is_nothrow_copy_assignable_v); +static_assert(std::is_nothrow_move_constructible_v); +static_assert(std::is_nothrow_move_assignable_v); + std::ostream& operator<<(std::ostream& os, BFloat16 value) { return os << 
static_cast(value); } diff --git a/test/common/bfloat16.hpp b/test/common/bfloat16.hpp index 9291918a..ac416746 100644 --- a/test/common/bfloat16.hpp +++ b/test/common/bfloat16.hpp @@ -12,6 +12,17 @@ #include "test/common/type_traits.hpp" +extern "C" { + +/// Converts single-precision floating-point to brain floating-point. +/// +/// @params[in] value The single-precision floating-point value. +/// +/// @return The brain floating-point value reinterpreted as 16-bit unsigned integer. +uint16_t kai_test_bfloat16_from_float(float value); + +} // extern "C" + namespace kai::test { /// Half-precision brain floating-point. @@ -22,72 +33,39 @@ public: /// Constructor. BFloat16() = default; - /// Destructor. - ~BFloat16() = default; - - /// Copy constructor. - BFloat16(const BFloat16&) = default; - - /// Copy assignment. - BFloat16& operator=(const BFloat16&) = default; - - /// Move constructor. - BFloat16(BFloat16&&) = default; - - /// Move assignment. - BFloat16& operator=(BFloat16&&) = default; - /// Creates a new object from the specified numeric value. - BFloat16(float value) : _data(0) { -#ifdef __ARM_FEATURE_BF16 - __asm__ __volatile__("bfcvt %h[output], %s[input]" : [output] "=w"(_data) : [input] "w"(value)); -#else - const uint32_t* value_i32 = reinterpret_cast(&value); - _data = (*value_i32 >> 16); -#endif + explicit BFloat16(float value) : m_data(kai_test_bfloat16_from_float(value)) { } /// Assigns to the specified numeric value which will be converted to `bfloat16_t`. template , bool> = true> BFloat16& operator=(T value) { const auto value_f32 = static_cast(value); -#ifdef __ARM_FEATURE_BF16 - __asm__ __volatile__("bfcvt %h[output], %s[input]" : [output] "=w"(_data) : [input] "w"(value_f32)); -#else - const uint32_t* value_i32 = reinterpret_cast(&value_f32); - _data = (*value_i32 >> 16); -#endif + m_data = kai_test_bfloat16_from_float(value_f32); return *this; } - /// Converts to floating-point. 
- operator float() const { + /// Converts to single-precision floating-point. + explicit operator float() const { union { float f32; uint32_t u32; } data; - data.u32 = static_cast(_data) << 16; + data.u32 = static_cast(m_data) << 16; return data.f32; } +private: /// Equality operator. - bool operator==(BFloat16 rhs) const { - return _data == rhs._data; + [[nodiscard]] friend bool operator==(BFloat16 lhs, BFloat16 rhs) { + return lhs.m_data == rhs.m_data; } /// Unequality operator. - bool operator!=(BFloat16 rhs) const { - return _data != rhs._data; - } - - uint16_t data() const { - return _data; - } - - void set_data(uint16_t data) { - _data = data; + [[nodiscard]] friend bool operator!=(BFloat16 lhs, BFloat16 rhs) { + return lhs.m_data != rhs.m_data; } /// Writes the value to the output stream. @@ -98,8 +76,7 @@ public: /// @return The output stream. friend std::ostream& operator<<(std::ostream& os, BFloat16 value); -private: - uint16_t _data; + uint16_t m_data; }; } // namespace kai::test diff --git a/test/common/bfloat16_asm.S b/test/common/bfloat16_asm.S new file mode 100644 index 00000000..28bcd908 --- /dev/null +++ b/test/common/bfloat16_asm.S @@ -0,0 +1,18 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "kai/common/internal/assembly.h" + + KAI_ASM_HEADER + + KAI_ASM_EXPORT(kai_test_bfloat16_from_float) + +KAI_ASM_FUNCTION(kai_test_bfloat16_from_float) + KAI_ASM_INST(0x1e634000) // bfcvt h0, s0 + fmov w0, h0 + ret + + KAI_ASM_FOOTER diff --git a/test/reference/pack.cpp b/test/reference/pack.cpp index 69b756c6..dc44ad10 100644 --- a/test/reference/pack.cpp +++ b/test/reference/pack.cpp @@ -26,14 +26,14 @@ namespace kai::test { namespace { -uint16_t convert(const uint8_t* src_ptr_elm, DataType src_dtype, DataType dst_dtype) { +BFloat16 convert(const uint8_t* src_ptr_elm, DataType src_dtype, DataType dst_dtype) { KAI_ASSUME((src_dtype == DataType::FP32 || src_dtype == 
DataType::FP16) && dst_dtype == DataType::BF16); switch (src_dtype) { case DataType::FP32: - return BFloat16(*reinterpret_cast(src_ptr_elm)).data(); + return BFloat16(*reinterpret_cast(src_ptr_elm)); case DataType::FP16: - return BFloat16(static_cast(*reinterpret_cast(src_ptr_elm))).data(); + return BFloat16(static_cast(*reinterpret_cast(src_ptr_elm))); default: KAI_ERROR("Unsupported Data Type"); } @@ -77,7 +77,7 @@ std::vector pack_block( x_element) * src_esize; - const uint16_t src_value = convert(src_ptr_elm, src_dtype, dst_dtype); + const BFloat16 src_value = convert(src_ptr_elm, src_dtype, dst_dtype); memcpy(dst_ptr, &src_value, dst_esize); } } @@ -149,7 +149,7 @@ std::vector pack_bias_per_row( x_element) * src_esize; - const uint16_t dst_value = convert(src_ptr_elm, src_dtype, dst_dtype); + const BFloat16 dst_value = convert(src_ptr_elm, src_dtype, dst_dtype); memcpy(dst_ptr, &dst_value, dst_esize); } } diff --git a/test/tests/bfloat16_test.cpp b/test/tests/bfloat16_test.cpp new file mode 100644 index 00000000..a8d4bb88 --- /dev/null +++ b/test/tests/bfloat16_test.cpp @@ -0,0 +1,33 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test/common/bfloat16.hpp" + +#include + +#include "test/common/cpu_info.hpp" + +namespace kai::test { + +TEST(BFloat16, SimpleTest) { + if (!cpu_has_bf16()) { + GTEST_SKIP(); + } + + ASSERT_EQ(static_cast(BFloat16()), 0.0F); + ASSERT_EQ(static_cast(BFloat16(1.25F)), 1.25F); + ASSERT_EQ(static_cast(BFloat16(3)), 3.0F); + + ASSERT_FALSE(BFloat16(1.25F) == BFloat16(2.0F)); + ASSERT_TRUE(BFloat16(1.25F) == BFloat16(1.25F)); + ASSERT_FALSE(BFloat16(2.0F) == BFloat16(1.25F)); + + ASSERT_TRUE(BFloat16(1.25F) != BFloat16(2.0F)); + ASSERT_FALSE(BFloat16(1.25F) != BFloat16(1.25F)); + ASSERT_TRUE(BFloat16(2.0F) != BFloat16(1.25F)); +} + +} // namespace kai::test -- GitLab From ee56e126fe7caaee8b5b1914f858f1f1810fd83b Mon Sep 17 00:00:00 2001 From: 
Viet-Hoa Do Date: Tue, 10 Dec 2024 17:13:45 +0000 Subject: [PATCH 2/6] Fix documentation Signed-off-by: Viet-Hoa Do --- test/common/bfloat16.hpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/common/bfloat16.hpp b/test/common/bfloat16.hpp index ac416746..671fd4cf 100644 --- a/test/common/bfloat16.hpp +++ b/test/common/bfloat16.hpp @@ -14,11 +14,11 @@ extern "C" { -/// Converts single-precision floating-point to brain floating-point. +/// Converts single-precision floating-point to half-precision brain floating-point. /// /// @params[in] value The single-precision floating-point value. /// -/// @return The brain floating-point value reinterpreted as 16-bit unsigned integer. +/// @return The half-precision brain floating-point value reinterpreted as 16-bit unsigned integer. uint16_t kai_test_bfloat16_from_float(float value); } // extern "C" @@ -26,8 +26,6 @@ uint16_t kai_test_bfloat16_from_float(float value); namespace kai::test { /// Half-precision brain floating-point. -/// -/// This class encapsulates `bfloat16_t` data type provided by `arm_bf16.h`. class BFloat16 { public: /// Constructor. -- GitLab From 0d6dbb125ef7bfb7628d1ba8079e1889271d5cc7 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Tue, 10 Dec 2024 17:39:48 +0000 Subject: [PATCH 3/6] Fix build error and documentation Signed-off-by: Viet-Hoa Do --- BUILD.bazel | 1 + test/common/bfloat16.hpp | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/BUILD.bazel b/BUILD.bazel index c4abdc25..f730ce92 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -30,6 +30,7 @@ config_setting( cc_library( name = "common", hdrs = ["kai/kai_common.h"], + textual_hdrs = ["kai/common/internal/assembly.h"], ) kai_c_library( diff --git a/test/common/bfloat16.hpp b/test/common/bfloat16.hpp index 671fd4cf..da529e03 100644 --- a/test/common/bfloat16.hpp +++ b/test/common/bfloat16.hpp @@ -61,7 +61,7 @@ private: return lhs.m_data == rhs.m_data; } - /// Unequality operator. 
+ /// Inequality operator. [[nodiscard]] friend bool operator!=(BFloat16 lhs, BFloat16 rhs) { return lhs.m_data != rhs.m_data; } -- GitLab From 0b576e99928c9c1d9ea3892bdc08d9594dd57a2d Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Thu, 2 Jan 2025 15:05:02 +0000 Subject: [PATCH 4/6] Move assembly header file to test framework Signed-off-by: Viet-Hoa Do --- {kai/common/internal => test/common}/assembly.h | 10 ++++++---- test/common/bfloat16_asm.S | 4 ++-- 2 files changed, 8 insertions(+), 6 deletions(-) rename {kai/common/internal => test/common}/assembly.h (72%) diff --git a/kai/common/internal/assembly.h b/test/common/assembly.h similarity index 72% rename from kai/common/internal/assembly.h rename to test/common/assembly.h index b20e8347..09497093 100644 --- a/kai/common/internal/assembly.h +++ b/test/common/assembly.h @@ -1,11 +1,11 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // -#ifndef KAI_COMMON_INTERNAL_ASSEMBLY_H -#define KAI_COMMON_INTERNAL_ASSEMBLY_H +#ifndef KAI_TEST_COMMON_ASSEMBLY_H +#define KAI_TEST_COMMON_ASSEMBLY_H // clang-format off @@ -13,6 +13,7 @@ #define KAI_ASM_HEADER AREA |.text|, CODE, READONLY, ALIGN=4 #define KAI_ASM_LABEL(label) |label| +#define KAI_ASM_TARGET(label, direction) |label| #define KAI_ASM_FUNCTION(label) |label| #define KAI_ASM_EXPORT(label) global label #define KAI_ASM_FOOTER end @@ -22,6 +23,7 @@ #define KAI_ASM_HEADER .text #define KAI_ASM_LABEL(label) label: +#define KAI_ASM_TARGET(label, direction) label##direction #ifdef __APPLE__ #define KAI_ASM_FUNCTION(label) _##label: @@ -42,4 +44,4 @@ // clang-format on -#endif // KAI_COMMON_INTERNAL_ASSEMBLY_H +#endif // KAI_TEST_COMMON_ASSEMBLY_H diff --git a/test/common/bfloat16_asm.S b/test/common/bfloat16_asm.S index 28bcd908..9f16cda4 100644 --- a/test/common/bfloat16_asm.S +++ b/test/common/bfloat16_asm.S @@ 
-1,10 +1,10 @@ // -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates // // SPDX-License-Identifier: Apache-2.0 // -#include "kai/common/internal/assembly.h" +#include "test/common/assembly.h" KAI_ASM_HEADER -- GitLab From 98614660b71e0c8f80545c0782f2e446edcb2dc7 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Thu, 2 Jan 2025 16:02:43 +0000 Subject: [PATCH 5/6] Update bazel build Signed-off-by: Viet-Hoa Do --- BUILD.bazel | 3 +-- test/BUILD.bazel | 5 ++++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index f730ce92..f57d92c8 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates # # SPDX-License-Identifier: Apache-2.0 # @@ -30,7 +30,6 @@ config_setting( cc_library( name = "common", hdrs = ["kai/kai_common.h"], - textual_hdrs = ["kai/common/internal/assembly.h"], ) kai_c_library( diff --git a/test/BUILD.bazel b/test/BUILD.bazel index 834520f1..9019364a 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates # # SPDX-License-Identifier: Apache-2.0 # @@ -30,6 +30,9 @@ kai_cxx_library( ), # compare.cpp requires fp16 and bf16 support cpu_uarch = kai_cpu_bf16() + kai_cpu_fp16(), + textual_hdrs = [ + "common/assembly.h", + ], ) kai_cxx_library( -- GitLab From 66eff7bdac1e6ab3a15f7b62182774bd3e669471 Mon Sep 17 00:00:00 2001 From: Viet-Hoa Do Date: Thu, 2 Jan 2025 16:31:11 +0000 Subject: [PATCH 6/6] Remove BUILD.bazel from the commit changes * This file is no longer changed so the copyright year change is not needed anymore. 
* This change effectively removes BUILD.bazel from the final squashed commit. Signed-off-by: Viet-Hoa Do --- BUILD.bazel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/BUILD.bazel b/BUILD.bazel index f57d92c8..c4abdc25 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates # # SPDX-License-Identifier: Apache-2.0 # -- GitLab