Skip to content

Commit

Permalink
fix memory allocation fail with FlushMemorySize + StridedBatched/Batc…
Browse files Browse the repository at this point in the history
…hed cases (#1881)

- multiply batch count size when calculating array size
  • Loading branch information
nakajee authored Feb 9, 2024
1 parent 8fdd9cd commit 1489d85
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions Tensile/Source/client/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
*
* MIT License
*
* Copyright (C) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2019-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -503,11 +503,13 @@ size_t calculate_flush_count(size_t arg_fl
size_t cached_size = 0;

for(auto const& problem : problemFactory.problems())
cached_size = std::max(
cached_size,
problem.a().sizes()[0] * problem.a().sizes()[1] * problem.a().elementBytes()
+ problem.b().sizes()[0] * problem.b().sizes()[1] * problem.b().elementBytes()
+ problem.c().sizes()[0] * problem.c().sizes()[1] * problem.c().elementBytes());
{
size_t aSize = problem.a().elementBytes() * problem.a().totalLogicalElements();
size_t bSize = problem.b().elementBytes() * problem.b().totalLogicalElements();
size_t cSize = problem.c().elementBytes() * problem.c().totalLogicalElements();

cached_size = std::max(cached_size, aSize + bSize + cSize);
}

if(arg_flush_count != default_arg_flush_count
&& arg_flush_memory_size != default_arg_flush_memory_size)
Expand Down

0 comments on commit 1489d85

Please sign in to comment.