diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index dfc252d4..f5dfa640 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -219,6 +219,7 @@ test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32]); test_ptx!(activemask, [0u32], [1u32]); test_ptx!(membar, [152731u32], [152731u32]); test_ptx!(shared_unify_extern, [7681u64, 7682u64], [15363u64]); +test_ptx!(shared_unify_local, [16752u64, 714u64], [17466u64]); test_ptx!(assertfail); test_ptx!(func_ptr); diff --git a/ptx/src/test/spirv_run/shared_unify_local.ptx b/ptx/src/test/spirv_run/shared_unify_local.ptx new file mode 100644 index 00000000..84f3a50f --- /dev/null +++ b/ptx/src/test/spirv_run/shared_unify_local.ptx @@ -0,0 +1,43 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.extern .shared .b32 shared_ex[]; + +.func (.reg .b64 out) add(.reg .u64 temp2) +{ + .shared .align 4 .u64 shared_mod; + .reg .u64 temp1; + st.shared.u64 [shared_mod], temp2; + ld.shared.u64 temp1, [shared_mod]; + ld.shared.u64 temp2, [shared_ex]; + add.u64 out, temp2, temp1; + ret; +} + +.func (.reg .b64 out) set_shared_temp1(.reg .b64 temp1, .reg .u64 temp2) +{ + st.shared.u64 [shared_ex], temp1; + call (out), add, (temp2); + ret; +} + +.visible .entry shared_unify_local( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp1; + .reg .u64 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.global.u64 temp1, [in_addr]; + ld.global.u64 temp2, [in_addr+8]; + call (temp2), set_shared_temp1, (temp1, temp2); + st.u64 [out_addr], temp2; + ret; +} diff --git a/ptx/src/test/spirv_run/shared_unify_local.spvtxt b/ptx/src/test/spirv_run/shared_unify_local.spvtxt new file mode 100644 index 00000000..dc00c2f6 --- /dev/null +++ b/ptx/src/test/spirv_run/shared_unify_local.spvtxt @@ -0,0 +1,117 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + OpCapability DenormFlushToZero + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" + %64 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %31 "shared_unify_local" %1 %5 + OpExecutionMode %31 ContractionOff + OpDecorate %5 Alignment 4 + %void = OpTypeVoid + %uint = OpTypeInt 32 0 +%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint + %1 = OpVariable %_ptr_Workgroup_uint Workgroup + %ulong = OpTypeInt 64 0 +%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong + %5 = OpVariable %_ptr_Workgroup_ulong Workgroup + %70 = OpTypeFunction %ulong %ulong %_ptr_Workgroup_uint %_ptr_Workgroup_ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %72 = OpTypeFunction %ulong %ulong %ulong %_ptr_Workgroup_uint %_ptr_Workgroup_ulong + %73 = OpTypeFunction %void %ulong %ulong +%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong + %ulong_8 = OpConstant %ulong 8 + %uchar = OpTypeInt 8 0 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %2 = OpFunction %ulong None %70 + %7 = OpFunctionParameter %ulong + %60 = OpFunctionParameter %_ptr_Workgroup_uint + %61 = OpFunctionParameter %_ptr_Workgroup_ulong + %17 = OpLabel + %4 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + OpStore %4 %7 + %8 = OpLoad %ulong %4 + OpStore %61 %8 Aligned 8 + %9 = OpLoad %ulong %61 Aligned 8 + OpStore %6 %9 + %15 = OpBitcast %_ptr_Workgroup_ulong %60 + %10 = OpLoad %ulong %15 Aligned 8 + OpStore %4 %10 + %12 = OpLoad %ulong %4 + %13 = OpLoad %ulong %6 + %16 = OpIAdd %ulong %12 %13 + %11 = OpCopyObject %ulong %16 + OpStore %3 %11 + %14 = OpLoad %ulong %3 + OpReturnValue %14 + OpFunctionEnd + %18 = OpFunction %ulong None %72 + %22 = OpFunctionParameter %ulong + %23 = OpFunctionParameter %ulong + %62 = OpFunctionParameter %_ptr_Workgroup_uint + %63 = OpFunctionParameter %_ptr_Workgroup_ulong + %30 = OpLabel + %20 = OpVariable %_ptr_Function_ulong Function + %21 = OpVariable %_ptr_Function_ulong Function + %19 = OpVariable %_ptr_Function_ulong Function + OpStore %20 %22 + OpStore %21 %23 + %24 = OpLoad %ulong %20 + %28 = OpBitcast %_ptr_Workgroup_ulong %62 + %29 = OpCopyObject %ulong %24 + OpStore %28 %29 Aligned 8 + %26 = OpLoad %ulong %21 + %25 = OpFunctionCall %ulong %2 %26 %62 %63 + OpStore %19 %25 + %27 = OpLoad %ulong %19 + OpReturnValue %27 + OpFunctionEnd + %31 = OpFunction %void None %73 + %38 = OpFunctionParameter %ulong + %39 = OpFunctionParameter %ulong + %58 = OpLabel + %32 = OpVariable %_ptr_Function_ulong Function + %33 = OpVariable %_ptr_Function_ulong Function + %34 = OpVariable %_ptr_Function_ulong Function + %35 = OpVariable %_ptr_Function_ulong Function + %36 = OpVariable %_ptr_Function_ulong Function + %37 = OpVariable %_ptr_Function_ulong Function + OpStore %32 %38 + OpStore %33 %39 + %40 = OpLoad %ulong %32 Aligned 8 + OpStore %34 %40 + %41 = OpLoad %ulong %33 Aligned 8 + OpStore %35 %41 + %43 = OpLoad %ulong %34 + %53 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %43 + %42 = OpLoad %ulong %53 Aligned 8 + OpStore %36 %42 + %45 = OpLoad %ulong %34 + %54 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %45 + %77 = OpBitcast %_ptr_CrossWorkgroup_uchar %54 + %78 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %77 %ulong_8 + %52 = OpBitcast %_ptr_CrossWorkgroup_ulong %78 + %44 = OpLoad %ulong %52 Aligned 8 + OpStore %37 %44 + %47 = OpLoad %ulong %36 + %48 = OpLoad %ulong %37 + %56 = OpCopyObject %ulong %47 + %55 = OpFunctionCall %ulong %18 %56 %48 %1 %5 + %46 = OpCopyObject %ulong %55 + OpStore %37 %46 + %49 = OpLoad %ulong %35 + %50 = OpLoad %ulong %37 + %57 = OpConvertUToPtr %_ptr_Generic_ulong %49 + OpStore %57 %50 Aligned 8 + OpReturn + OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 165997e1..db1063b6 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -789,6 +789,14 @@ impl<'input> MethodsCallMap<'input> { }) } + fn methods( + &self, + ) -> impl Iterator, &HashSet)> { + self.map + .iter() + .map(|(method, children)| (*method, children)) + } + fn visit_callees( &self, method: ast::MethodName<'input, spirv::Word>, @@ -1102,18 +1110,23 @@ fn resolve_indirect_uses_of_globals_shared<'input>( kernels_methods_call_map: &MethodsCallMap<'input>, ) -> HashMap, BTreeSet> { let mut result = HashMap::new(); - for (method, direct_globals) in methods_use_of_globals_shared.iter() { - let mut indirect_globals = direct_globals.iter().copied().collect::>(); - kernels_methods_call_map.visit_callees(*method, |func| { + for (method, callees) in kernels_methods_call_map.methods() { + let mut indirect_globals = methods_use_of_globals_shared + .get(&method) + .into_iter() + .flatten() + .copied() + .collect::>(); + for &callee in callees { indirect_globals.extend( methods_use_of_globals_shared - .get(&ast::MethodName::Func(func)) + .get(&ast::MethodName::Func(callee)) .into_iter() .flatten() .copied(), ); - }); - result.insert(*method, indirect_globals); + } + result.insert(method, indirect_globals); } result } diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs index 9732ec9e..24fa88a8 100644 --- a/zluda/src/impl/module.rs +++ b/zluda/src/impl/module.rs @@ -54,20 +54,16 @@ impl SpirvModule { } pub(crate) fn load(module: *mut CUmodule, fname: *const i8) -> Result<(), hipError_t> { - let length = (0..) - .position(|i| unsafe { *fname.add(i) == 0 }) - .ok_or(hipError_t::hipErrorInvalidValue)?; - let file_name = CStr::from_bytes_with_nul(unsafe { slice::from_raw_parts(fname as _, length) }) - .map_err(|_| hipError_t::hipErrorInvalidValue)?; - let valid_file_name = file_name + let file_name = unsafe { CStr::from_ptr(fname) } .to_str() .map_err(|_| hipError_t::hipErrorInvalidValue)?; - let mut file = File::open(valid_file_name).map_err(|_| hipError_t::hipErrorFileNotFound)?; + let mut file = File::open(file_name).map_err(|_| hipError_t::hipErrorFileNotFound)?; let mut file_buffer = Vec::new(); file.read_to_end(&mut file_buffer) .map_err(|_| hipError_t::hipErrorUnknown)?; - drop(file); - load_data(module, file_buffer.as_ptr() as _) + let result = load_data(module, file_buffer.as_ptr() as _); + drop(file_buffer); + result } pub(crate) fn load_data( @@ -201,6 +197,8 @@ pub(crate) fn compile_amd<'a>( .arg("-nogpulib") .arg("-mno-wavefrontsize64") .arg("-O3") + .arg("-Xclang") + .arg("-O3") .arg("-Xlinker") .arg("--no-undefined") .arg("-target")