diff --git a/test/tests/matmul_test.cpp b/test/tests/matmul_test.cpp index c828f56dc89f2ccf2a9605dd6faf5e85e5df85eb..a09f87cc15156ecc6bfd3d82151519a2a55abc70 100644 --- a/test/tests/matmul_test.cpp +++ b/test/tests/matmul_test.cpp @@ -342,36 +342,10 @@ protected: method.dst_format.data_type(), // info.m, info.n, info.k, false, false); - float clamp_min = 0.0F; - float clamp_max = 0.0F; - constexpr float clamp_ratio = 0.8F; - - switch (method.dst_format.data_type()) { - case DataType::FP32: { - const auto [min_value, max_value] = - find_clamp_range(ref_dst.data(), info.m * info.n, clamp_ratio); - ref_dst = clamp(ref_dst.data(), info.m * info.n, min_value, max_value); - - clamp_min = min_value; - clamp_max = max_value; - - break; - } - - case DataType::FP16: { - const auto [min_value, max_value] = - find_clamp_range(ref_dst.data(), info.m * info.n, clamp_ratio); - ref_dst = clamp(ref_dst.data(), info.m * info.n, min_value, max_value); - - clamp_min = static_cast(min_value); - clamp_max = static_cast(max_value); - - break; - } - - default: - KAI_ERROR("Unsupported data type!"); - } + static constexpr float clamp_ratio = 0.8F; + const auto [clamp_min, clamp_max] = + find_clamp_range(method.dst_format.data_type(), ref_dst.data(), info.m * info.n, clamp_ratio); + ref_dst = clamp(method.dst_format.data_type(), ref_dst.data(), info.m * info.n, clamp_min, clamp_max); auto& data = _data[data_id] = {}; data.lhs = std::move(lhs); @@ -705,18 +679,19 @@ INSTANTIATE_TEST_SUITE_P( testing::Combine( testing::ValuesIn(get_vecmul_methods()), testing::Values( - MatMulShape{1, 16, 16}, // - MatMulShape{1, 1, 20}, // - MatMulShape{1, 16, 32}, // - MatMulShape{1, 32, 17}, // - MatMulShape{1, 33, 23}, // - MatMulShape{1, 93, 56} // + MatMulShape{1, 16, 16}, // + MatMulShape{1, 1, 20}, // + MatMulShape{1, 16, 32}, // + MatMulShape{1, 32, 17}, // + MatMulShape{1, 33, 23}, // + MatMulShape{1, 1500, 20}, // + MatMulShape{1, 93, 56} // ), testing::Values( MatrixPortion(0, 0, 1, 1), // Full row. MatrixPortion(0, 0, 1, 0.5), // First half - MatrixPortion(0, .25, 1, 0.5), // mid row-section. - MatrixPortion(0, 0.75, 1, 1) // right row section + MatrixPortion(0, .4, 1, 0.3), // mid row-section. + MatrixPortion(0, 0.75, 1, .25) // right row section )), testing::PrintToStringParamName());