-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathbuild.rs
293 lines (252 loc) · 8.94 KB
/
build.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
#[cfg(feature = "cuda")]
use std::fs;
#[cfg(any(feature = "vulkan", feature = "mps", feature = "cuda"))]
use std::{env, path::PathBuf, process::Command};
#[cfg(feature = "cuda")]
/// Locates the root directory of the CUDA toolkit.
///
/// Lookup order:
/// 1. the prefix of `bin/nvcc` as reported by `which nvcc`;
/// 2. the `CUDA_PATH`, `CUDA_HOME` or `CUDA_ROOT` environment variable, if it
///    names an existing directory;
/// 3. the default Windows install roots; when one exists, its newest
///    `vMAJOR.MINOR` subdirectory is preferred, because the toolkit's
///    `include/` and `lib/` live inside the versioned folder, not the root;
/// 4. `/usr/local/cuda` as a last resort (its existence is not checked).
///
/// The environment variables are registered with `rerun-if-env-changed` so
/// that pointing the build at a different toolkit re-runs this script.
fn find_cuda_path() -> String {
    // Returns `root/vMAJOR.MINOR` for the highest-versioned subdirectory of
    // `root`, or `None` if `root` is unreadable or has no such subdirectory.
    fn newest_versioned_subdir(root: &str) -> Option<String> {
        std::fs::read_dir(root)
            .ok()?
            .filter_map(|entry| entry.ok())
            .filter(|entry| entry.path().is_dir())
            .filter_map(|entry| {
                let name = entry.file_name().into_string().ok()?;
                let mut parts = name.strip_prefix('v')?.split('.');
                let major: u32 = parts.next()?.parse().ok()?;
                let minor: u32 = parts.next().unwrap_or("0").parse().ok()?;
                Some(((major, minor), name))
            })
            .max_by_key(|(version, _)| *version)
            .map(|(_, name)| format!("{}/{}", root, name))
    }

    const ENV_VARS: [&str; 3] = ["CUDA_PATH", "CUDA_HOME", "CUDA_ROOT"];
    for var in ENV_VARS.iter() {
        println!("cargo:rerun-if-env-changed={}", var);
    }

    // nvcc on PATH (Linux/macOS). If `which` is unavailable or finds nothing,
    // fall through to the other strategies.
    if let Ok(output) = Command::new("which").arg("nvcc").output() {
        if output.status.success() {
            if let Ok(path) = String::from_utf8(output.stdout) {
                if let Some(cuda_path) = path.trim().strip_suffix("/bin/nvcc") {
                    return cuda_path.to_string();
                }
            }
        }
    }

    // Explicit toolkit location from the environment (the Windows installer
    // sets CUDA_PATH to the versioned toolkit directory).
    for var in ENV_VARS.iter() {
        if let Ok(value) = env::var(var) {
            if !value.is_empty() && PathBuf::from(&value).is_dir() {
                return value;
            }
        }
    }

    // Default Windows install roots.
    for root in ["C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA", "C:/CUDA"].iter() {
        if PathBuf::from(root).exists() {
            return newest_versioned_subdir(root).unwrap_or_else(|| root.to_string());
        }
    }

    "/usr/local/cuda".to_string()
}
#[cfg(feature = "vulkan")]
/// Compiles the Vulkan compute shaders in `shaders/vulkan/` to SPIR-V with
/// `glslc`, then copies each `.spv` next to its `.comp` source.
///
/// Each shader is first compiled into `OUT_DIR` and then copied to
/// `shaders/vulkan/<name>.spv`. That copy location is where this script leaves
/// the final artifacts (NOTE(review): presumably the runtime loads them from
/// there — confirm). Every source is registered individually with
/// `cargo:rerun-if-changed`.
///
/// # Errors
/// Returns an error if the shader directory cannot be created, if `glslc`
/// cannot be launched or exits unsuccessfully for any shader, or if copying a
/// compiled shader fails.
///
/// # Panics
/// Panics if `OUT_DIR` is not set (i.e. when not invoked by cargo).
fn compile_vulkan_shaders() -> std::io::Result<()> {
    // (file stem, description used in the "Failed to compile ... shader" error).
    const SHADERS: [(&str, &str); 3] = [
        ("reduction", "reduction"),
        ("binary_ops", "binary operations"),
        ("matmul", "matrix multiplication"),
    ];

    let out_dir = PathBuf::from(env::var("OUT_DIR").expect("Failed to get OUT_DIR"));
    let shader_dir = PathBuf::from("shaders/vulkan");
    // Create shader directory if it doesn't exist
    std::fs::create_dir_all(&shader_dir)?;

    for &(stem, description) in SHADERS.iter() {
        println!("cargo:rerun-if-changed=shaders/vulkan/{}.comp", stem);
        let source = shader_dir.join(format!("{}.comp", stem));
        let compiled = out_dir.join(format!("{}.spv", stem));

        // Paths are passed as OsStr, so a non-UTF-8 OUT_DIR does not panic.
        let status = Command::new("glslc")
            .args(["--target-env=vulkan1.0", "-O", "-g"])
            .arg(&source)
            .arg("-o")
            .arg(&compiled)
            .status()
            .map_err(|err| {
                std::io::Error::new(
                    err.kind(),
                    format!("failed to launch glslc (is it installed and on PATH?): {}", err),
                )
            })?;
        if !status.success() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("Failed to compile {} shader", description),
            ));
        }

        // Copy the compiled shader to its final location next to the source.
        std::fs::copy(&compiled, shader_dir.join(format!("{}.spv", stem)))?;
    }

    println!("Successfully compiled and copied Vulkan shaders");
    Ok(())
}
#[cfg(all(feature = "mps", target_os = "macos"))]
/// Compiles every `*.metal` file in `shaders/metal/` to `.air` and links the
/// results into `shaders/metal/shaders.metallib` using `xcrun`.
///
/// Intermediate `.air` files are written next to the sources and deleted after
/// a successful link. If the directory is missing or holds no `.metal` files,
/// nothing is done and `Ok(())` is returned.
///
/// # Errors
/// Returns an error if the directory cannot be read, if `xcrun` cannot be
/// launched or reports failure for any compile or link step, or if an `.air`
/// file cannot be removed afterwards.
fn compile_metal_shaders() -> std::io::Result<()> {
    let shader_dir = PathBuf::from("shaders/metal");
    // Artifacts are written next to the sources rather than into OUT_DIR.
    // NOTE(review): presumably the runtime loads the metallib from this path — confirm.
    let out_dir = shader_dir.clone();
    if !shader_dir.exists() {
        return Ok(()); // Skip if metal shaders directory doesn't exist
    }
    // Create output directory if it doesn't exist
    std::fs::create_dir_all(&out_dir)?;

    // Collect the .metal sources. Sort them so the link order (and thus the
    // produced metallib) does not depend on directory iteration order.
    let mut sources: Vec<PathBuf> = Vec::new();
    for entry in shader_dir.read_dir()? {
        let path = entry?.path();
        if path.is_file() && path.extension().map_or(false, |ext| ext == "metal") {
            sources.push(path);
        }
    }
    if sources.is_empty() {
        // Invoking `metallib` with no inputs would fail the whole build.
        return Ok(());
    }
    sources.sort();

    // Compile each `<stem>.metal` to `<stem>.air`.
    let mut air_files: Vec<PathBuf> = Vec::with_capacity(sources.len());
    for source in &sources {
        let mut air_name = source.file_stem().unwrap_or_default().to_os_string();
        air_name.push(".air");
        let air_path = out_dir.join(air_name);

        let status = Command::new("xcrun")
            .args(["-sdk", "macosx", "metal", "-c"])
            .arg(source)
            .arg("-o")
            .arg(&air_path)
            .status()?;
        if !status.success() {
            let name = source.file_name().unwrap_or_default().to_string_lossy();
            return Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("Failed to compile {}", name),
            ));
        }
        air_files.push(air_path);
    }

    // Link all .air files into a single metallib.
    let status = Command::new("xcrun")
        .args(["-sdk", "macosx", "metallib", "-o"])
        .arg(out_dir.join("shaders.metallib"))
        .args(&air_files)
        .status()?;
    if !status.success() {
        return Err(std::io::Error::new(
            std::io::ErrorKind::Other,
            "Failed to create metallib",
        ));
    }

    // The .air files are only intermediates; delete them once linked.
    for air_file in &air_files {
        std::fs::remove_file(air_file)?;
    }
    Ok(())
}
/// Build-script entry point.
///
/// Depending on the enabled cargo features it:
/// - `cuda`: writes a default `.clangd` (only if none exists), builds the CUDA
///   sources with CMake and emits link directives for the CUDA runtime and
///   the produced static libraries;
/// - `vulkan`: compiles the Vulkan compute shaders;
/// - `mps` (macOS only): compiles the Metal shaders into a metallib.
///
/// # Panics
/// Panics if any of those steps fails, which aborts the build.
fn main() {
    // Always register at least one rerun trigger, so cargo does not fall back
    // to re-running this script whenever any file in the package changes.
    println!("cargo:rerun-if-changed=build.rs");

    #[cfg(feature = "cuda")]
    {
        println!("cargo:rerun-if-changed=cuda/");
        println!("cargo:rerun-if-changed=cuda-headers/");
        println!("cargo:rerun-if-changed=CMakeLists.txt");
        let cuda_path = find_cuda_path();

        // Seed a clangd config for editor support; an existing file is never overwritten.
        let clangd_path = PathBuf::from(".clangd");
        if !clangd_path.exists() {
            let clangd_content = format!(
                r#"CompileFlags:
  Remove:
    - "-forward-unknown-to-host-compiler"
    - "-rdc=*"
    - "-Xcompiler*"
    - "--options-file"
    - "--generate-code*"
  Add:
    - "-xcuda"
    - "-std=c++14"
    - "-I{}/include"
    - "-I../../cuda-headers"
    - "--cuda-gpu-arch=sm_75"
  Compiler: clang
Index:
  Background: Build
Diagnostics:
  UnusedIncludes: None"#,
                cuda_path
            );
            fs::write(&clangd_path, clangd_content).expect("Failed to write .clangd file");
        }

        // Configure and build the CMake project in Release mode, without
        // passing a named `--target` to the CMake build.
        let dst = cmake::Config::new(".")
            .define("CMAKE_BUILD_TYPE", "Release")
            .define("CUDA_PATH", &cuda_path)
            .no_build_target(true)
            .build();

        // Search paths - include both lib and lib64
        println!("cargo:rustc-link-search={}/build/lib", dst.display());
        println!("cargo:rustc-link-search={}/build", dst.display());
        println!("cargo:rustc-link-search=native={}/lib", dst.display());
        println!("cargo:rustc-link-search=native={}/lib64", cuda_path);
        println!("cargo:rustc-link-search=native={}/lib", cuda_path);

        // CUDA runtime linking - only essential libraries
        println!("cargo:rustc-link-lib=cudart");
        println!("cargo:rustc-link-lib=cuda");

        // Static libraries: only libnn_ops.a is probed; tensor_ops is assumed to
        // be produced alongside it.
        if dst.join("build/lib/libnn_ops.a").exists() {
            // NOTE(review): cargo forwards `rustc-link-arg` values separately from
            // `rustc-link-lib` libraries, so these `-Wl,--whole-archive` brackets may
            // not actually enclose the two archives. The `static:+whole-archive=`
            // link-lib modifier expresses this directly — confirm and consider
            // migrating. These flags are also specific to GNU-style linkers.
            println!("cargo:rustc-link-arg=-Wl,--whole-archive");
            println!("cargo:rustc-link-lib=static=nn_ops");
            println!("cargo:rustc-link-lib=static=tensor_ops");
            println!("cargo:rustc-link-arg=-Wl,--no-whole-archive");
        }
    }

    // Compile Vulkan shaders only if the "vulkan" feature is enabled.
    // The compiled .spv files are copied into shaders/vulkan/, so that directory
    // must not be watched as a whole. Cargo would treat the freshly written .spv
    // files as changes and re-run this script on every build.
    // compile_vulkan_shaders registers each .comp source individually instead.
    #[cfg(feature = "vulkan")]
    {
        compile_vulkan_shaders().expect("Failed to compile Vulkan shaders");
    }

    // Compile Metal shaders only if the "mps" feature is enabled and on macOS.
    // The metallib is written into shaders/metal/, so, as with Vulkan, only the
    // .metal sources are registered. A newly *added* .metal file is therefore
    // only picked up once another rerun trigger fires.
    #[cfg(all(feature = "mps", target_os = "macos"))]
    {
        if let Ok(entries) = std::fs::read_dir("shaders/metal") {
            for entry in entries.flatten() {
                let path = entry.path();
                if path.extension().map_or(false, |ext| ext == "metal") {
                    println!("cargo:rerun-if-changed={}", path.display());
                }
            }
        }
        compile_metal_shaders().expect("Failed to compile Metal shaders");
    }
}