Skip to content

Commit

Permalink
Fix shared munging pass and add fix cuModuleLoadData
Browse files Browse the repository at this point in the history
  • Loading branch information
vosen committed Sep 29, 2021
1 parent 0172dc5 commit 816365e
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 15 deletions.
1 change: 1 addition & 0 deletions ptx/src/test/spirv_run/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32]);
test_ptx!(activemask, [0u32], [1u32]);
test_ptx!(membar, [152731u32], [152731u32]);
test_ptx!(shared_unify_extern, [7681u64, 7682u64], [15363u64]);
test_ptx!(shared_unify_local, [16752u64, 714u64], [17466u64]);

test_ptx!(assertfail);
test_ptx!(func_ptr);
Expand Down
43 changes: 43 additions & 0 deletions ptx/src/test/spirv_run/shared_unify_local.ptx
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
.version 6.5
.target sm_30
.address_size 64

.extern .shared .b32 shared_ex[];

.func (.reg .b64 out) add(.reg .u64 temp2)
{
.shared .align 4 .u64 shared_mod;
.reg .u64 temp1;
st.shared.u64 [shared_mod], temp2;
ld.shared.u64 temp1, [shared_mod];
ld.shared.u64 temp2, [shared_ex];
add.u64 out, temp2, temp1;
ret;
}

.func (.reg .b64 out) set_shared_temp1(.reg .b64 temp1, .reg .u64 temp2)
{
st.shared.u64 [shared_ex], temp1;
call (out), add, (temp2);
ret;
}

.visible .entry shared_unify_local(
.param .u64 input,
.param .u64 output
)
{
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .u64 temp1;
.reg .u64 temp2;

ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];

ld.global.u64 temp1, [in_addr];
ld.global.u64 temp2, [in_addr+8];
call (temp2), set_shared_temp1, (temp1, temp2);
st.u64 [out_addr], temp2;
ret;
}
117 changes: 117 additions & 0 deletions ptx/src/test/spirv_run/shared_unify_local.spvtxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
OpCapability GenericPointer
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
OpCapability Int8
OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
OpCapability DenormFlushToZero
OpExtension "SPV_KHR_float_controls"
OpExtension "SPV_KHR_no_integer_wrap_decoration"
%64 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %31 "shared_unify_local" %1 %5
OpExecutionMode %31 ContractionOff
OpDecorate %5 Alignment 4
%void = OpTypeVoid
%uint = OpTypeInt 32 0
%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint
%1 = OpVariable %_ptr_Workgroup_uint Workgroup
%ulong = OpTypeInt 64 0
%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong
%5 = OpVariable %_ptr_Workgroup_ulong Workgroup
%70 = OpTypeFunction %ulong %ulong %_ptr_Workgroup_uint %_ptr_Workgroup_ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%72 = OpTypeFunction %ulong %ulong %ulong %_ptr_Workgroup_uint %_ptr_Workgroup_ulong
%73 = OpTypeFunction %void %ulong %ulong
%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
%ulong_8 = OpConstant %ulong 8
%uchar = OpTypeInt 8 0
%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
%2 = OpFunction %ulong None %70
%7 = OpFunctionParameter %ulong
%60 = OpFunctionParameter %_ptr_Workgroup_uint
%61 = OpFunctionParameter %_ptr_Workgroup_ulong
%17 = OpLabel
%4 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_ulong Function
OpStore %4 %7
%8 = OpLoad %ulong %4
OpStore %61 %8 Aligned 8
%9 = OpLoad %ulong %61 Aligned 8
OpStore %6 %9
%15 = OpBitcast %_ptr_Workgroup_ulong %60
%10 = OpLoad %ulong %15 Aligned 8
OpStore %4 %10
%12 = OpLoad %ulong %4
%13 = OpLoad %ulong %6
%16 = OpIAdd %ulong %12 %13
%11 = OpCopyObject %ulong %16
OpStore %3 %11
%14 = OpLoad %ulong %3
OpReturnValue %14
OpFunctionEnd
%18 = OpFunction %ulong None %72
%22 = OpFunctionParameter %ulong
%23 = OpFunctionParameter %ulong
%62 = OpFunctionParameter %_ptr_Workgroup_uint
%63 = OpFunctionParameter %_ptr_Workgroup_ulong
%30 = OpLabel
%20 = OpVariable %_ptr_Function_ulong Function
%21 = OpVariable %_ptr_Function_ulong Function
%19 = OpVariable %_ptr_Function_ulong Function
OpStore %20 %22
OpStore %21 %23
%24 = OpLoad %ulong %20
%28 = OpBitcast %_ptr_Workgroup_ulong %62
%29 = OpCopyObject %ulong %24
OpStore %28 %29 Aligned 8
%26 = OpLoad %ulong %21
%25 = OpFunctionCall %ulong %2 %26 %62 %63
OpStore %19 %25
%27 = OpLoad %ulong %19
OpReturnValue %27
OpFunctionEnd
%31 = OpFunction %void None %73
%38 = OpFunctionParameter %ulong
%39 = OpFunctionParameter %ulong
%58 = OpLabel
%32 = OpVariable %_ptr_Function_ulong Function
%33 = OpVariable %_ptr_Function_ulong Function
%34 = OpVariable %_ptr_Function_ulong Function
%35 = OpVariable %_ptr_Function_ulong Function
%36 = OpVariable %_ptr_Function_ulong Function
%37 = OpVariable %_ptr_Function_ulong Function
OpStore %32 %38
OpStore %33 %39
%40 = OpLoad %ulong %32 Aligned 8
OpStore %34 %40
%41 = OpLoad %ulong %33 Aligned 8
OpStore %35 %41
%43 = OpLoad %ulong %34
%53 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %43
%42 = OpLoad %ulong %53 Aligned 8
OpStore %36 %42
%45 = OpLoad %ulong %34
%54 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %45
%77 = OpBitcast %_ptr_CrossWorkgroup_uchar %54
%78 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %77 %ulong_8
%52 = OpBitcast %_ptr_CrossWorkgroup_ulong %78
%44 = OpLoad %ulong %52 Aligned 8
OpStore %37 %44
%47 = OpLoad %ulong %36
%48 = OpLoad %ulong %37
%56 = OpCopyObject %ulong %47
%55 = OpFunctionCall %ulong %18 %56 %48 %1 %5
%46 = OpCopyObject %ulong %55
OpStore %37 %46
%49 = OpLoad %ulong %35
%50 = OpLoad %ulong %37
%57 = OpConvertUToPtr %_ptr_Generic_ulong %49
OpStore %57 %50 Aligned 8
OpReturn
OpFunctionEnd
25 changes: 19 additions & 6 deletions ptx/src/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,14 @@ impl<'input> MethodsCallMap<'input> {
})
}

fn methods(
&self,
) -> impl Iterator<Item = (ast::MethodName<'input, spirv::Word>, &HashSet<spirv::Word>)> {
self.map
.iter()
.map(|(method, children)| (*method, children))
}

fn visit_callees(
&self,
method: ast::MethodName<'input, spirv::Word>,
Expand Down Expand Up @@ -1102,18 +1110,23 @@ fn resolve_indirect_uses_of_globals_shared<'input>(
kernels_methods_call_map: &MethodsCallMap<'input>,
) -> HashMap<ast::MethodName<'input, spirv::Word>, BTreeSet<spirv::Word>> {
let mut result = HashMap::new();
for (method, direct_globals) in methods_use_of_globals_shared.iter() {
let mut indirect_globals = direct_globals.iter().copied().collect::<BTreeSet<_>>();
kernels_methods_call_map.visit_callees(*method, |func| {
for (method, callees) in kernels_methods_call_map.methods() {
let mut indirect_globals = methods_use_of_globals_shared
.get(&method)
.into_iter()
.flatten()
.copied()
.collect::<BTreeSet<_>>();
for &callee in callees {
indirect_globals.extend(
methods_use_of_globals_shared
.get(&ast::MethodName::Func(func))
.get(&ast::MethodName::Func(callee))
.into_iter()
.flatten()
.copied(),
);
});
result.insert(*method, indirect_globals);
}
result.insert(method, indirect_globals);
}
result
}
Expand Down
16 changes: 7 additions & 9 deletions zluda/src/impl/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,16 @@ impl SpirvModule {
}

pub(crate) fn load(module: *mut CUmodule, fname: *const i8) -> Result<(), hipError_t> {
let length = (0..)
.position(|i| unsafe { *fname.add(i) == 0 })
.ok_or(hipError_t::hipErrorInvalidValue)?;
let file_name = CStr::from_bytes_with_nul(unsafe { slice::from_raw_parts(fname as _, length) })
.map_err(|_| hipError_t::hipErrorInvalidValue)?;
let valid_file_name = file_name
let file_name = unsafe { CStr::from_ptr(fname) }
.to_str()
.map_err(|_| hipError_t::hipErrorInvalidValue)?;
let mut file = File::open(valid_file_name).map_err(|_| hipError_t::hipErrorFileNotFound)?;
let mut file = File::open(file_name).map_err(|_| hipError_t::hipErrorFileNotFound)?;
let mut file_buffer = Vec::new();
file.read_to_end(&mut file_buffer)
.map_err(|_| hipError_t::hipErrorUnknown)?;
drop(file);
load_data(module, file_buffer.as_ptr() as _)
let result = load_data(module, file_buffer.as_ptr() as _);
drop(file_buffer);
result
}

pub(crate) fn load_data(
Expand Down Expand Up @@ -201,6 +197,8 @@ pub(crate) fn compile_amd<'a>(
.arg("-nogpulib")
.arg("-mno-wavefrontsize64")
.arg("-O3")
.arg("-Xclang")
.arg("-O3")
.arg("-Xlinker")
.arg("--no-undefined")
.arg("-target")
Expand Down

0 comments on commit 816365e

Please sign in to comment.