diff --git a/Tensile/Source/client/main.cpp b/Tensile/Source/client/main.cpp index 792db82b2..89393d9a5 100644 --- a/Tensile/Source/client/main.cpp +++ b/Tensile/Source/client/main.cpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2019-2022 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2019-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -503,11 +503,13 @@ size_t calculate_flush_count(size_t arg_fl size_t cached_size = 0; for(auto const& problem : problemFactory.problems()) - cached_size = std::max( - cached_size, - problem.a().sizes()[0] * problem.a().sizes()[1] * problem.a().elementBytes() - + problem.b().sizes()[0] * problem.b().sizes()[1] * problem.b().elementBytes() - + problem.c().sizes()[0] * problem.c().sizes()[1] * problem.c().elementBytes()); + { + size_t aSize = problem.a().elementBytes() * problem.a().totalLogicalElements(); + size_t bSize = problem.b().elementBytes() * problem.b().totalLogicalElements(); + size_t cSize = problem.c().elementBytes() * problem.c().totalLogicalElements(); + + cached_size = std::max(cached_size, aSize + bSize + cSize); + } if(arg_flush_count != default_arg_flush_count && arg_flush_memory_size != default_arg_flush_memory_size)