Skip to content

Commit

Permalink
Misc fix and update
Browse files Browse the repository at this point in the history
  • Loading branch information
cloudhan committed Dec 12, 2023
1 parent b94bdb2 commit 7624f81
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 2 deletions.
20 changes: 20 additions & 0 deletions gemm/.clangd
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
---
# Fragment without an `If` block: these flags are appended for every file.
CompileFlags:
  Add:
    # Parse sources as C++17.
    - -std=c++17
    # Clang option that skips the check of the installed CUDA SDK version.
    - --no-cuda-version-check
    # Header search path for CUTLASS.
    # NOTE(review): the `external/` prefix looks like a Bazel external
    # repository layout -- confirm against the build setup.
    - -Iexternal/com_github_nvidia_cutlass/include

---
# Fragment applied only to paths that fully match the regex, i.e. paths
# ending in `.cu` or `.cuh`.
If:
  PathMatch: .*\.cuh?
CompileFlags:
  Add:
    # `-x cu`: treat the input as CUDA source.
    - -xcu

---
# Fragment applied to files under `include/cutlass/` or `include/cute/`,
# which are likewise parsed as CUDA source.
If:
  PathMatch: .*/include/(cutlass|cute)/.*
CompileFlags:
  Add:
    - -xcu
2 changes: 1 addition & 1 deletion gemm/cpu/matmul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ void make_string_impl(std::ostringstream& oss, const T& head, Ts... tail) {
/// Builds a std::string from an arbitrary list of arguments.
///
/// All arguments are handed to ::detail::make_string_impl together with a
/// local std::ostringstream (presumably it streams each argument in order --
/// confirm against its definition), and the stream contents are returned.
///
/// @param args  Values to pass on to ::detail::make_string_impl.
/// @return      The text accumulated in the stream.
template <typename... Ts>
std::string make_string(Ts... args) {
  std::ostringstream oss;
  // Exactly one call: invoking the helper a second time would append every
  // argument twice. The leading `::` makes name lookup start at the global
  // namespace, so the call still resolves to the global `detail` even if a
  // nested namespace of the same name is visible at the point of use.
  ::detail::make_string_impl(oss, args...);
  return oss.str();
}

Expand Down
18 changes: 17 additions & 1 deletion gemm/cuda_cute/benchmark_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import matmul
from matmul import *


Expand Down Expand Up @@ -54,7 +56,7 @@ def __str__(self):


def run_benchmark(f, size, *, min_iterations=20, max_iterations=100) -> stat:
benchmark_seconds= (0.100 / (4096 ** 3)) * size ** 3
benchmark_seconds = (0.100 / (4096**3)) * size**3
a = np.zeros((size, size), dtype=np.float32)
b = np.zeros((size, size), dtype=np.float32)
c = np.zeros((size, size), dtype=np.float32)
Expand Down Expand Up @@ -141,3 +143,17 @@ def new_benchmark_plot(*dataframes, **kwargs):
chart = chart.configure_axis(labelFontSize=12, titleFontSize=14)
chart = chart.configure_legend(labelLimit=0)
return chart


if __name__ == "__main__":
    # Command-line entry point: `benchmark_driver.py [function_name] [size]`.
    # Candidate functions are the attributes of the `matmul` module whose
    # names start with "matmul_" or "launch_".
    available = [name for name in dir(matmul) if name.startswith(("matmul_", "launch_"))]

    print("Available tests:")
    for name in available:
        print(" ", name)

    # First argument selects the function; default to the first candidate.
    if len(sys.argv) > 1:
        selected = sys.argv[1]
    elif available:
        selected = available[0]
    else:
        # Nothing to fall back on: fail with a message instead of IndexError.
        sys.exit("No matmul_* or launch_* functions found in module 'matmul'.")

    # Any attribute of `matmul` may be requested explicitly, but an unknown
    # name gets a readable error instead of a bare AttributeError traceback.
    f = getattr(matmul, selected, None)
    if f is None:
        sys.exit(f"Unknown test {selected!r}; choose one of the names listed above.")

    # Second argument is the square matrix dimension; defaults to 1024.
    size = int(sys.argv[2]) if len(sys.argv) > 2 else 1024
    print(f"Running {selected} for [{size}x{size}] * [{size}x{size}] ...")
    print(run_benchmark(f, size))

0 comments on commit 7624f81

Please sign in to comment.