-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkai_defs.bzl
208 lines (166 loc) · 6.08 KB
/
kai_defs.bzl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
#
# SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
#
"""Build definitions for KleidiAI"""
load(
"@bazel_tools//tools/build_defs/repo:utils.bzl",
"update_attrs",
"workspace_and_buildfile",
)
# Extra warnings for GCC/CLANG C/C++
def kai_gcc_warn_copts():
    """Return the extra warning options shared by C and C++ GCC/Clang builds

    Returns:
        list of string: a newly created list of warning flags
    """
    return [
        "-Wall",
        "-Wdisabled-optimization",
        "-Wextra",
        "-Wformat-security",
        "-Wformat=2",
        "-Winit-self",
        "-Wstrict-overflow=2",
        "-Wswitch-default",
        "-Wno-vla",
        "-Wcast-qual",
    ]
# GCC/CLANG C only warning options
def kai_gcc_warn_conlyopts():
    """Return the warning options that are only used for C sources

    Returns:
        list of string: a newly created list of warning flags
    """
    return [
        "-Wmissing-prototypes",
        "-Wstrict-prototypes",
        "-Wpedantic",
    ]
# GCC/CLANG C++ only warning options
def kai_gcc_warn_cxxopts():
    """Return the warning options that are only used for C++ sources

    Returns:
        list of string: a newly created list of warning flags
    """
    return [
        "-Wctor-dtor-privacy",
        "-Weffc++",
        "-Woverloaded-virtual",
        "-Wsign-promo",
        "-Wmissing-declarations",
    ]
# GCC/CLANG C compiler options
def kai_gcc_std_copts():
    """Return the default GCC/Clang options for C sources

    Returns:
        list of string: "-std=c99" followed by the common warning options
        and the C only warning options
    """
    return ["-std=c99"] + kai_gcc_warn_copts() + kai_gcc_warn_conlyopts()
# GCC/CLANG C++ compiler options
def kai_gcc_std_cxxopts():
    """Return the default GCC/Clang options for C++ sources

    Returns:
        list of string: "-std=c++17" followed by the common warning options
        and the C++ only warning options
    """
    return ["-std=c++17"] + kai_gcc_warn_copts() + kai_gcc_warn_cxxopts()
def kai_cpu_select(cpu_uarch):
    """Build the architecture string used as value of the -march option

    Args:
        cpu_uarch (list): CPU feature names, for example ["dotprod", "i8mm"]

    Returns:
        string: "armv8-a" when no feature is requested, otherwise
        "armv8.2-a+" followed by the feature names joined with "+"
    """

    # No feature requested: fall back to the baseline architecture
    if not cpu_uarch:
        return "armv8-a"
    return "armv8.2-a+" + "+".join(cpu_uarch)
def kai_cpu_i8mm():
    """Return the CPU feature list that requests the "i8mm" extension

    Returns:
        list of string: a newly created list with the feature name
    """
    return ["i8mm"]
def kai_cpu_dotprod():
    """Return the CPU feature list that requests the "dotprod" extension

    Returns:
        list of string: a newly created list with the feature name
    """
    return ["dotprod"]
def kai_cpu_bf16():
    """Return the CPU feature list that requests the "bf16" extension

    Returns:
        list of string: a newly created list with the feature name
    """
    return ["bf16"]
def kai_cpu_fp16():
    """Return the CPU feature list that requests the "fp16" extension

    Returns:
        list of string: a newly created list with the feature name
    """
    return ["fp16"]
def kai_cpu_neon():
    """Return the CPU feature list for NEON builds

    No extra feature is requested, so kai_cpu_select maps this list to the
    baseline "armv8-a" architecture.

    Returns:
        list of string: a newly created empty list
    """
    return []
def kai_cpu_sve():
    """Return the CPU feature list that requests the "sve" extension

    Returns:
        list of string: a newly created list with the feature name
    """
    return ["sve"]
def kai_cpu_sve2():
    """Return the CPU feature list that requests the "sve2" extension

    Returns:
        list of string: a newly created list with the feature name
    """
    return ["sve2"]
def kai_cpu_sme():
    """Return the CPU feature list that requests the "sme" extension

    Returns:
        list of string: a newly created list with the feature name
    """
    return ["sme"]
def kai_cpu_sme2():
    """Return the CPU feature list that requests the "sme2" extension

    Returns:
        list of string: a newly created list with the feature name
    """
    return ["sme2"]
# MSVC C compiler options
def kai_msvc_std_copts():
    """Return the default MSVC options for C sources

    Returns:
        list of string: a newly created list of compiler flags
    """
    return ["/Wall"]
# MSVC C++ compiler options
def kai_msvc_std_cxxopts():
    """Return the default MSVC options for C++ sources

    Returns:
        list of string: a newly created list of compiler flags
    """
    return ["/Wall"]
def kai_copts(ua_variant):
    """Return the C compiler options for the requested CPU features

    Args:
        ua_variant (list): CPU feature names, forwarded to kai_cpu_select to
            build the -march value

    Returns:
        select: MSVC C options when "//:windows" matches, otherwise the
        GCC/Clang C options followed by the -march option
    """
    return select({
        "//:windows": kai_msvc_std_copts(),
        # Assume default to use GCC/CLANG compilers. This is a fallback case to make it
        # easier for KleidiAI library users
        "//conditions:default": kai_gcc_std_copts() + ["-march=" + kai_cpu_select(ua_variant)],
    })
def kai_cxxopts(ua_variant):
    """Return the C++ compiler options for the requested CPU features

    Args:
        ua_variant (list): CPU feature names, forwarded to kai_cpu_select to
            build the -march value

    Returns:
        select: MSVC C++ options when "//:windows" matches, otherwise the
        GCC/Clang C++ options followed by the -march option
    """
    return select({
        "//:windows": kai_msvc_std_cxxopts(),
        # Assume default to use GCC/CLANG compilers. This is a fallback case to make it
        # easier for KleidiAI library users
        "//conditions:default": kai_gcc_std_cxxopts() + ["-march=" + kai_cpu_select(ua_variant)],
    })
def _kai_list_check(predicate, sub_list, super_list):
    """Check whether any or all elements of the first list are in the second one

    Args:
        predicate (function): predicate to check. For example 'all' or 'any'
        sub_list (list): first list, its elements are looked up
        super_list (list): second list, searched for each element

    Returns:
        result of predicate applied to the list of per-element membership
        results (one boolean for each element of sub_list)
    """
    return predicate([item in super_list for item in sub_list])
def _kai_c_cxx_common(name, copts_def_func, **kwargs):
    """Common C/C++ native cc_library wrapper with custom parameters and defaults

    Args:
        name (string): name of target library
        copts_def_func (function): function to get C or C++ respective defaults
        **kwargs (dict): other arguments like srcs, hdrs, deps. The custom
            cpu_uarch argument (list of CPU features, kai_cpu_neon() when
            absent) is consumed here and is not forwarded to cc_library
    """

    # Take the custom cpu_uarch parameter out of kwargs so it is not passed to
    # cc_library. The list belongs to the caller (it may be shared between
    # targets or frozen), so it is only read here and never modified in place.
    requested_uarch = kwargs.pop("cpu_uarch", kai_cpu_neon())
    sme_features = kai_cpu_sme() + kai_cpu_sme2()

    extra_copts = []

    # Indicate if SME flags should be replaced since a toolchain may not support it
    replace_sme_flags = _kai_list_check(any, sme_features, requested_uarch)

    if replace_sme_flags:
        # Drop every SME/SME2 entry, keeping the other features in their order
        cpu_uarch = [uarch for uarch in requested_uarch if uarch not in sme_features]

        # Replace SME/SME2 with SVE+SVE2, but disable compiler vectorization
        cpu_uarch.extend(kai_cpu_sve())
        cpu_uarch.extend(kai_cpu_sve2())
        extra_copts.append("-fno-tree-vectorize")
    else:
        cpu_uarch = list(requested_uarch)

    # User copts come first, then the defaults for the selected features
    kwargs["copts"] = kwargs.get("copts", []) + copts_def_func(cpu_uarch) + extra_copts
    kwargs["deps"] = ["//:common"] + kwargs.get("deps", [])
    kwargs["linkstatic"] = kwargs.get("linkstatic", True)

    native.cc_library(
        name = name,
        **kwargs
    )
def kai_c_library(name, **kwargs):
    """C native cc_library wrapper with custom parameters and defaults

    Args:
        name (string): name of target library
        **kwargs (dict): other arguments like srcs, hdrs, deps
    """

    # kai_copts supplies the C specific compiler options
    _kai_c_cxx_common(name, kai_copts, **kwargs)
def kai_cxx_library(name, **kwargs):
    """C++ native cc_library wrapper with custom parameters and defaults

    Args:
        name (string): name of target library
        **kwargs (dict): other arguments like srcs, hdrs, deps
    """

    # kai_cxxopts supplies the C++ specific compiler options
    _kai_c_cxx_common(name, kai_cxxopts, **kwargs)
def _kai_local_archive_impl(ctx):
    """Implementation of the kai_local_archive rule.

    Args:
        ctx (repository_ctx): context of the repository rule

    Returns:
        result of update_attrs for the attributes declared in
        _kai_local_archive_attrs, with no overrides
    """

    # Unpack the archive into the repository directory
    ctx.extract(
        ctx.attr.archive,
        stripPrefix = ctx.attr.strip_prefix,
    )

    # Helper loaded from bazel_tools utils.bzl; it receives the whole context,
    # which carries the build_file*/workspace_file* attributes
    workspace_and_buildfile(ctx)

    return update_attrs(ctx.attr, _kai_local_archive_attrs.keys(), {})
# Attributes accepted by the kai_local_archive repository rule
_kai_local_archive_attrs = {
    # The only mandatory attribute: the archive to extract
    "archive": attr.label(
        mandatory = True,
        allow_single_file = True,
        doc = "Path to local archive relative to workspace",
    ),
    "strip_prefix": attr.string(
        doc = "Strip prefix from archive internal content",
    ),
    # BUILD file, given either as a file label or as inline content
    "build_file": attr.label(
        allow_single_file = True,
        doc = "Name of BUILD file for extracted repository",
    ),
    "build_file_content": attr.string(
        doc = "Content of BUILD file for extracted repository",
    ),
    # WORKSPACE file, given either as a file label or as inline content
    "workspace_file": attr.label(
        doc = "Name of WORKSPACE file for extracted repository",
    ),
    "workspace_file_content": attr.string(
        doc = "Content of WORKSPACE file for extracted repository",
    ),
}
# Repository rule that exposes the content of a local archive as a repository
kai_local_archive = repository_rule(
    doc = "Rule to use repository from compressed local archive",
    attrs = _kai_local_archive_attrs,
    implementation = _kai_local_archive_impl,
    # The archive is read from the local system, not downloaded
    local = True,
)