diff --git a/.bazelrc b/.bazelrc new file mode 100644 index 0000000000000000000000000000000000000000..9565f775d2f967ed91553241502a89e7db9b06c8 --- /dev/null +++ b/.bazelrc @@ -0,0 +1,11 @@ +# +# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# + +# Disable Bzlmod for every Bazel command +common --enable_bzlmod=false + +# Basic build settings +build --jobs 8 diff --git a/.gitignore b/.gitignore index 4b6b113aaa7a7575621625c03872a870e7ecbb30..462475fb73a89ed8c4fe6c816214c1965b134558 100644 --- a/.gitignore +++ b/.gitignore @@ -92,6 +92,7 @@ _deps # Build directory cmake-build-*/ build/ +bazel-* ### Debug files *.dSYM/ diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 8bfb29d11ab717e598fe6b24f1d22b8bcec48de5..409a6fbd9f0b08c4f287547426f8b78417cac98c 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -30,12 +30,12 @@ build-clang: - .standard-rules stage: build script: - - cmake -G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release -DKLEIDIAI_BUILD_TESTS=ON -S . -B build/ - - cmake --build ./build + - cmake -G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release -DKLEIDIAI_BUILD_TESTS=ON -S . -B ${CI_JOB_NAME_SLUG} + - cmake --build ${CI_JOB_NAME_SLUG} artifacts: expire_in: 1 day paths: - - build/kleidiai_test + - ${CI_JOB_NAME_SLUG}/kleidiai_test build-clang-cov: extends: @@ -54,20 +54,55 @@ build-gcc: - .standard-rules stage: build script: - - cmake -G Ninja -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ -DCMAKE_BUILD_TYPE=Release -DKLEIDIAI_BUILD_TESTS=ON -S . -B build/ - - cmake --build ./build + - cmake -G Ninja -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ -DCMAKE_BUILD_TYPE=Release -DKLEIDIAI_BUILD_TESTS=ON -S . 
-B ${CI_JOB_NAME_SLUG} + - cmake --build ${CI_JOB_NAME_SLUG} + artifacts: + expire_in: 1 day + paths: + - ${CI_JOB_NAME_SLUG}/kleidiai_test + +build-gcc-bazel: + extends: + - .standard-rules + stage: build + cache: + - key: cache-bazelisk + paths: + - /cache/bazelisk + script: + - bazelisk clean --expunge + - bazelisk build -k --verbose_failures //... + - mkdir -p ${CI_JOB_NAME_SLUG} && cp bazel-bin/test/kleidiai_test ${CI_JOB_NAME_SLUG}/ + artifacts: + expire_in: 1 day + paths: + - ${CI_JOB_NAME_SLUG}/kleidiai_test + +build-clang-bazel: + extends: + - .standard-rules + stage: build + cache: + - key: cache-bazelisk + paths: + - /cache/bazelisk + script: + - bazelisk clean --expunge + # explicitly disable layering_check feature + - CC=clang bazelisk build -k --verbose_failures --compiler=clang --features=no-layering_check //... + - mkdir -p ${CI_JOB_NAME_SLUG} && cp bazel-bin/test/kleidiai_test ${CI_JOB_NAME_SLUG}/ artifacts: expire_in: 1 day paths: - - build/kleidiai_test + - ${CI_JOB_NAME_SLUG}/kleidiai_test clang-tidy-checks: extends: - .standard-rules stage: build script: - - cmake -G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release -DKLEIDIAI_BUILD_TESTS=ON -DKLEIDIAI_ENABLE_CLANG_TIDY=ON -S . -B build/ - - cmake --build ./build + - cmake -G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release -DKLEIDIAI_BUILD_TESTS=ON -DKLEIDIAI_ENABLE_CLANG_TIDY=ON -S . 
-B ${CI_JOB_NAME_SLUG} + - cmake --build ${CI_JOB_NAME_SLUG} pre-commit-hooks: variables: @@ -80,23 +115,29 @@ pre-commit-hooks: paths: - $PRE_COMMIT_HOME script: - - pre-commit run --all-files + - PRE_COMMIT_HOME=$PRE_COMMIT_HOME pre-commit run --all-files test-linux-aarch64: extends: - .standard-rules stage: test + parallel: + matrix: + - BUILD_JOB_PROVIDER: [ clang, gcc, clang-bazel, gcc-bazel ] needs: + - build-gcc + - build-gcc-bazel - build-clang + - build-clang-bazel script: - - ./build/kleidiai_test --gtest_output=xml:kleidiai_test_results.xml + - ./build-${BUILD_JOB_PROVIDER}/kleidiai_test --gtest_output=xml:kleidiai_test_results-${BUILD_JOB_PROVIDER}.xml artifacts: when: always expire_in: 1 day paths: - - kleidiai_test_results.xml + - kleidiai_test_results-${BUILD_JOB_PROVIDER}.xml reports: - junit: kleidiai_test_results.xml + junit: kleidiai_test_results-${BUILD_JOB_PROVIDER}.xml test-linux-aarch64-cov: extends: @@ -134,7 +175,7 @@ test-linux-aarch64-cov-fvp: cd '$PWD' mkdir -p artifacts/$CI_PROJECT_DIR GCOV_PREFIX=artifacts ./build/kleidiai_test --gtest_output=xml:artifacts/$CI_PROJECT_DIR/kleidiai_test_results.xml && echo 'FINISHED WITHOUT ERROR' - tar cvf artifacts.tar -C artifacts . + tar cf artifacts.tar -C artifacts . sync echo '==================================================' @@ -191,7 +232,7 @@ test-linux-aarch64-cov-fvp: |& tee output.txt - grep -q "FINISHED WITHOUT ERROR" output.txt - e2cp linux-rootfs.img:"$PWD/artifacts.tar" . - - tar xvf artifacts.tar -C / + - tar xf artifacts.tar -C / - mkdir -p build/coverage - gcovr --gcov-executable="llvm-cov gcov" --exclude-unreachable-branches --exclude=build --exclude=test --exclude-lines-by-pattern=".*KAI_(?:ASSERT|ASSUME|ERROR).*" --exclude-branches-by-pattern=".*KAI_(?:ASSERT|ASSUME).*" --json=build/coverage/linux-aarch64-fvp.json -j --root . 
build artifacts: diff --git a/BUILD.bazel b/BUILD.bazel new file mode 100644 index 0000000000000000000000000000000000000000..c4abdc25d53561cb9f090a6554aea430ab0d0536 --- /dev/null +++ b/BUILD.bazel @@ -0,0 +1,41 @@ +# +# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# + +load( + "//:kai_defs.bzl", + "kai_c_library", +) + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +exports_files([ + "LICENSES/**", +]) + +config_setting( + name = "linux", + constraint_values = ["@platforms//os:linux"], +) + +config_setting( + name = "windows", + constraint_values = ["@platforms//os:windows"], +) + +cc_library( + name = "common", + hdrs = ["kai/kai_common.h"], +) + +kai_c_library( + name = "kleidiai", + visibility = ["//visibility:public"], + deps = [ + "//kai/ukernels/matmul", + ], +) diff --git a/CMakeLists.txt b/CMakeLists.txt index 07c03a6a85faa7d83bed02ead5430b6506bf9202..80adb858238e89478ba8db2821d3757ded311d1a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -150,6 +150,7 @@ if(KLEIDIAI_BUILD_TESTS) test/common/compare.cpp test/common/matrix_portion.cpp test/common/rect.cpp + test/common/round.cpp test/common/bfloat16.cpp test/common/float16.cpp test/common/cpu_info.cpp @@ -161,7 +162,6 @@ if(KLEIDIAI_BUILD_TESTS) test/reference/pack.cpp test/reference/quantize.cpp test/reference/reduce.cpp - test/reference/round.cpp test/reference/transpose.cpp test/reference/cast.cpp ) diff --git a/WORKSPACE b/WORKSPACE new file mode 100644 index 0000000000000000000000000000000000000000..0932cdc49fd023eae25d26df36430aec955306ac --- /dev/null +++ b/WORKSPACE @@ -0,0 +1,30 @@ +# +# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# + +workspace(name = "com_arm_kleidiai") + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +http_archive( + name = "bazel_skylib", + sha256 = 
"08c0386f45821ce246bbbf77503c973246ed6ee5c3463e41efc197fa9bc3a7f4", + strip_prefix = "bazel-skylib-288731ef9f7f688932bd50e704a91a45ec185f9b", + urls = ["https://github.com/bazelbuild/bazel-skylib/archive/288731ef9f7f688932bd50e704a91a45ec185f9b.zip"], +) + +http_archive( + name = "com_google_googletest", + sha256 = "1f357c27ca988c3f7c6b4bf68a9395005ac6761f034046e9dde0896e3aba00e4", + strip_prefix = "googletest-1.14.0", + urls = ["https://github.com/google/googletest/archive/refs/tags/v1.14.0.zip"], +) + +http_archive( + name = "com_google_benchmark", + sha256 = "84c49c4c07074f36fbf8b4f182ed7d75191a6fa72756ab4a17848455499f4286", + strip_prefix = "benchmark-v1.8.4", + urls = ["https://github.com/google/benchmark/archive/refs/tags/v1.8.4.zip"], +) diff --git a/kai/ukernels/matmul/BUILD.bazel b/kai/ukernels/matmul/BUILD.bazel new file mode 100644 index 0000000000000000000000000000000000000000..013a96c30514c85ecdbead955b5cf85dcfada887 --- /dev/null +++ b/kai/ukernels/matmul/BUILD.bazel @@ -0,0 +1,168 @@ +# +# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# + +load( + "//:kai_defs.bzl", + "kai_c_library", + "kai_cpu_dotprod", + "kai_cpu_fp16", + "kai_cpu_i8mm", + "kai_cpu_neon", + "kai_cpu_scalar", + "kai_cpu_sme", +) + +package(default_visibility = ["//visibility:public"]) + +kai_c_library( + name = "clamp_f16_f16_f16p_interface", + hdrs = ["matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p_interface.h"], + cpu_uarch = kai_cpu_fp16(), +) + +kai_c_library( + name = "clamp_f16_f16_f16p", + srcs = ["matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.c"], + hdrs = ["matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla.h"], + cpu_uarch = kai_cpu_fp16(), + deps = [ + ":clamp_f16_f16_f16p_interface", + ], +) + +kai_c_library( + name = "clamp_f32_f32p_f32p", + srcs = 
["matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.c"], + hdrs = ["matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h"], + cpu_uarch = kai_cpu_sme(), +) + +cc_library( + name = "clamp_f32_qai8dxp_qsi4cxp_interface", + hdrs = ["matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h"], +) + +kai_c_library( + name = "clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod", + srcs = ["matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c"], + hdrs = ["matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h"], + cpu_uarch = kai_cpu_dotprod(), + deps = [ + ":clamp_f32_qai8dxp_qsi4cxp_interface", + ], +) + +kai_c_library( + name = "clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod", + srcs = ["matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.c"], + hdrs = ["matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h"], + cpu_uarch = kai_cpu_dotprod(), + deps = [ + ":clamp_f32_qai8dxp_qsi4cxp_interface", + ], +) + +kai_c_library( + name = "clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm", + srcs = ["matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.c"], + hdrs = ["matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm.h"], + cpu_uarch = kai_cpu_i8mm(), + deps = [ + ":clamp_f32_qai8dxp_qsi4cxp_interface", + ], +) + +kai_c_library( + name = "clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm", + srcs = ["matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c"], + hdrs = ["matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h"], + cpu_uarch = kai_cpu_i8mm(), + deps = [ + ":clamp_f32_qai8dxp_qsi4cxp_interface", + ], +) + +kai_c_library( + name = 
"clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm", + srcs = ["matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.c"], + hdrs = ["matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm.h"], + cpu_uarch = kai_cpu_i8mm(), + deps = [ + ":clamp_f32_qai8dxp_qsi4cxp_interface", + ], +) + +kai_c_library( + name = "clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm", + srcs = ["matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.c"], + hdrs = ["matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h"], + cpu_uarch = kai_cpu_i8mm(), + deps = [ + ":clamp_f32_qai8dxp_qsi4cxp_interface", + ], +) + +kai_c_library( + name = "lhs_quant_pack_qai8dxp_f32", + srcs = ["pack/kai_lhs_quant_pack_qai8dxp_f32.c"], + hdrs = ["pack/kai_lhs_quant_pack_qai8dxp_f32.h"], + cpu_uarch = kai_cpu_scalar(), +) + +kai_c_library( + name = "lhs_pack_f32p2vlx1_f32_sme", + srcs = ["pack/kai_lhs_pack_f32p2vlx1_f32_sme.c"], + hdrs = ["pack/kai_lhs_pack_f32p2vlx1_f32_sme.h"], + cpu_uarch = kai_cpu_sme(), +) + +kai_c_library( + name = "rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon", + srcs = ["pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.c"], + hdrs = ["pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon.h"], + cpu_uarch = kai_cpu_neon(), +) + +kai_c_library( + name = "rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", + srcs = ["pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.c"], + hdrs = ["pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h"], + cpu_uarch = kai_cpu_sme(), +) + +kai_c_library( + name = "rhs_pack_nxk_qsi4cxp_qsu4cxs1s0", + srcs = ["pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c"], + hdrs = ["pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h"], + cpu_uarch = kai_cpu_scalar(), +) + +kai_c_library( + name = "rhs_pack_kxn_qsi4cxp_qsu4cxs1s0", + srcs = ["pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.c"], + hdrs = 
["pack/kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0.h"], + cpu_uarch = kai_cpu_scalar(), +) + +kai_c_library( + name = "matmul", + deps = [ + ":clamp_f16_f16_f16p", + ":clamp_f32_f32p_f32p", + ":clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod", + ":clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod", + ":clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm", + ":clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm", + ":clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm", + ":clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm", + ":lhs_pack_f32p2vlx1_f32_sme", + ":lhs_quant_pack_qai8dxp_f32", + ":rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon", + ":rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme", + ":rhs_pack_kxn_qsi4cxp_qsu4cxs1s0", + ":rhs_pack_nxk_qsi4cxp_qsu4cxs1s0", + ], +) diff --git a/kai_defs.bzl b/kai_defs.bzl new file mode 100644 index 0000000000000000000000000000000000000000..a3afc5a9b0982b0cbf62e772748f2b518c5b4737 --- /dev/null +++ b/kai_defs.bzl @@ -0,0 +1,116 @@ +# +# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# + +"""Build definitions for KleidiAI""" + +# Extra warnings for GCC/CLANG C/C++ +def kai_gcc_warn_copts(): + return [ + "-Wall", + "-Wdisabled-optimization", + "-Werror", + "-Wextra", + "-Wformat-security", + "-Wformat=2", + "-Winit-self", + "-Wno-ignored-attributes", + "-Wno-misleading-indentation", + "-Wno-overlength-strings", + "-Wstrict-overflow=2", + "-Wswitch-default", + "-Wno-vla", + ] + +def kai_gcc_warn_cxxopts(): + return kai_gcc_warn_copts() + [ + "-Wctor-dtor-privacy", + "-Weffc++", + "-Woverloaded-virtual", + "-Wsign-promo", + ] + +# GCC/CLANG compiler options +def kai_gcc_std_copts(): + return ["-std=c99"] + kai_gcc_warn_copts() + +# GCC/CLANG compiler options +def kai_gcc_std_cxxopts(): + return ["-std=c++17"] + kai_gcc_warn_cxxopts() + +def kai_cpu_select(cpu_uarch): + if len(cpu_uarch) == 0: + return "armv8-a" + else: + return "armv8.2-a" + cpu_uarch + +def kai_cpu_i8mm(): + 
return "+i8mm" + +def kai_cpu_dotprod(): + return "+dotprod" + +def kai_cpu_bf16(): + return "+bf16" + +def kai_cpu_fp16(): + return "+fp16" + +def kai_cpu_neon(): + return "+simd" + +def kai_cpu_sme(): + return "+sve+sve2" + +def kai_cpu_sme2(): + return "+sve+sve2" + +def kai_cpu_scalar(): + return "" + +# MSVC compiler options +def kai_msvc_std_copts(): + return ["/Wall"] + +def kai_msvc_std_cxxopts(): + return ["/Wall"] + +def kai_copts(ua_variant): + return select({ + "//:windows": kai_msvc_std_copts(), + # Assume default to use GCC/CLANG compilers. This is a fallback case to make it + # easier for KleidiAI library users + "//conditions:default": kai_gcc_std_copts() + ["-march=" + kai_cpu_select(ua_variant)], + }) + +def kai_cxxopts(ua_variant): + return select({ + "//:windows": kai_msvc_std_cxxopts(), + # Assume default to use GCC/CLANG compilers. This is a fallback case to make it + # easier for KleidiAI library users + "//conditions:default": kai_gcc_std_cxxopts() + ["-march=" + kai_cpu_select(ua_variant)], + }) + +def kai_c_library(name, **kwargs): + native.cc_library( + name = name, + srcs = kwargs.get("srcs", []), + hdrs = kwargs.get("hdrs", []), + deps = ["//:common"] + kwargs.get("deps", []), + visibility = kwargs.get("visibility", None), + copts = kwargs.get("copts", []) + kai_copts(kwargs.get("cpu_uarch", kai_cpu_scalar())), + linkstatic = kwargs.get("linkstatic", True), + ) + +def kai_cxx_library(name, **kwargs): + native.cc_library( + name = name, + srcs = kwargs.get("srcs", []), + hdrs = kwargs.get("hdrs", []), + deps = ["//:common"] + kwargs.get("deps", []), + visibility = kwargs.get("visibility", None), + copts = kwargs.get("copts", []) + kai_cxxopts(kwargs.get("cpu_uarch", kai_cpu_scalar())), + linkstatic = kwargs.get("linkstatic", True), + ) diff --git a/test/BUILD.bazel b/test/BUILD.bazel new file mode 100644 index 0000000000000000000000000000000000000000..85d14acdb37f7eacbabb0464bd5d274a440f04fb --- /dev/null +++ b/test/BUILD.bazel @@ -0,0 
+1,70 @@ +# +# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# + +load( + "//:kai_defs.bzl", + "kai_cpu_bf16", + "kai_cpu_fp16", + "kai_cpu_sme", + "kai_cxx_library", + "kai_cxxopts", +) + +package(default_testonly = 1) + +kai_cxx_library( + name = "common", + srcs = glob( + ["common/*.cpp"], + exclude = ["common/sme.cpp"], + ), + hdrs = glob( + ["common/*.hpp"], + exclude = ["common/sme.hpp"], + ), + # compare.cpp requires fp16 and bf16 support + cpu_uarch = kai_cpu_bf16() + kai_cpu_fp16(), +) + +kai_cxx_library( + name = "common_sme", + srcs = ["common/sme.cpp"], + hdrs = ["common/sme.hpp"], + # sme.cpp is built with the kai_cpu_sme() architecture flags + cpu_uarch = kai_cpu_sme(), + deps = [ + ":common", + ], +) + +kai_cxx_library( + name = "reference", + srcs = glob(["reference/*.cpp"]), + hdrs = glob(["reference/*.hpp"]), + cpu_uarch = kai_cpu_bf16() + kai_cpu_fp16(), + deps = [ + ":common", + ], +) + +cc_test( + name = "kleidiai_test", + srcs = [ + "tests/matmul_test.cpp", + ], + copts = kai_cxxopts(kai_cpu_bf16() + kai_cpu_fp16()), + includes = [], + linkstatic = True, + visibility = ["//visibility:public"], + deps = [ + ":common", + ":common_sme", + ":reference", + "//:common", + "//:kleidiai", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/test/common/data_format.cpp b/test/common/data_format.cpp index 296994f5c27e52ae3c7a8b9717c9a952b1e1b48a..b4da18ae95cc27fad3d32f2c445daaa5eace11f3 100644 --- a/test/common/data_format.cpp +++ b/test/common/data_format.cpp @@ -11,7 +11,7 @@ #include "kai/kai_common.h" #include "test/common/data_type.hpp" -#include "test/reference/round.hpp" +#include "test/common/round.hpp" namespace kai::test { diff --git a/test/common/matrix_portion.cpp b/test/common/matrix_portion.cpp index 6b975ec7c5a6da0fb78c0fa530609ad15e451de1..5a29446bfb53e725417abe271a2985f709f71cef 100644 --- a/test/common/matrix_portion.cpp +++ b/test/common/matrix_portion.cpp @@ -11,7 +11,7
@@ #include "kai/kai_common.h" #include "test/common/rect.hpp" -#include "test/reference/round.hpp" +#include "test/common/round.hpp" namespace kai::test { diff --git a/test/reference/round.cpp b/test/common/round.cpp similarity index 95% rename from test/reference/round.cpp rename to test/common/round.cpp index 52a2c5575e48508af4a8410250a8e5fdac028ba5..bcb1cb372bbad9ae9334c67fd5a5ee0d6aa56abc 100644 --- a/test/reference/round.cpp +++ b/test/common/round.cpp @@ -4,7 +4,7 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "test/reference/round.hpp" +#include "test/common/round.hpp" #include #include diff --git a/test/reference/round.hpp b/test/common/round.hpp similarity index 100% rename from test/reference/round.hpp rename to test/common/round.hpp diff --git a/test/reference/pack.cpp b/test/reference/pack.cpp index fb0f81b19b563dbb26c700aef274b4d6568ba4cc..721157795eea4f6b40345f1f206446647eddaf4c 100644 --- a/test/reference/pack.cpp +++ b/test/reference/pack.cpp @@ -17,8 +17,8 @@ #include "kai/kai_common.h" #include "test/common/data_format.hpp" #include "test/common/data_type.hpp" +#include "test/common/round.hpp" #include "test/reference/quantize.hpp" -#include "test/reference/round.hpp" namespace kai::test { diff --git a/test/reference/quantize.cpp b/test/reference/quantize.cpp index 213328d5065cf4b1c24f20f94fb7b794bae908d1..144e7a4fab8fa96c49aeeb1f1d8d2631382c1f06 100644 --- a/test/reference/quantize.cpp +++ b/test/reference/quantize.cpp @@ -18,8 +18,8 @@ #include "test/common/int4.hpp" #include "test/common/memory.hpp" #include "test/common/numeric_limits.hpp" +#include "test/common/round.hpp" #include "test/common/type_traits.hpp" -#include "test/reference/round.hpp" namespace kai::test {