Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PTX parser rewrite #267

Merged
merged 47 commits into from
Sep 4, 2024
Merged
Changes from 1 commit
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
a05bee9
Start rewriting PTX parser
vosen Aug 14, 2024
8d7c88c
Fully parse operands
vosen Aug 15, 2024
dbd37f9
Clean up and improve ident parsing
vosen Aug 15, 2024
ba17906
Pass parser state to instruction callbacks
vosen Aug 15, 2024
0da45ea
Add parsing of st, allow associating type with a non-alternative modi…
vosen Aug 15, 2024
0112880
Parse ld, add, ret
vosen Aug 16, 2024
91dbbb3
Move all types to a separate module
vosen Aug 16, 2024
77de5c7
Parse simplest vector add kernel
vosen Aug 18, 2024
522541d
Support simple module variables
vosen Aug 18, 2024
cb64b04
Add mul
vosen Aug 18, 2024
c08e6a6
Implement setp
vosen Aug 19, 2024
22492ec
Implement not, or, and, bra
vosen Aug 19, 2024
34b0a67
Add types for call instruction
vosen Aug 20, 2024
c21c55d
Parse call instruction
vosen Aug 20, 2024
bc1074e
Add cvt
vosen Aug 20, 2024
47f8314
Add shr, shl
vosen Aug 20, 2024
588d66b
Add cvta
vosen Aug 20, 2024
6cd18bf
Add abs, mad
vosen Aug 21, 2024
798bbf0
Add fma and sub
vosen Aug 21, 2024
fc713f2
Add min, max
vosen Aug 21, 2024
c16bae3
Add rcp, sqrt, rsqrt
vosen Aug 21, 2024
39faaa7
Add atom and atom.cas
vosen Aug 21, 2024
0760c3d
Map remaining instructions
vosen Aug 21, 2024
71e0258
Rename new crates
vosen Aug 21, 2024
1ec1ca0
Attempt #2
vosen Aug 23, 2024
12ef8db
Port first pass
vosen Aug 23, 2024
7ea990e
Work on more passes
vosen Aug 23, 2024
69175d2
Add relaxed type check information to visitors
vosen Aug 24, 2024
4e6dc07
Implement third pass
vosen Aug 24, 2024
107f1eb
Port sreg fix pass
vosen Aug 26, 2024
3e0a15a
Add stateless-to-stateful conversion
vosen Aug 26, 2024
cccd37f
Port ssa conversion
vosen Aug 26, 2024
c088cc2
Port expand_arguments
vosen Aug 26, 2024
144f8bd
Port remaining two passes
vosen Aug 27, 2024
790fe18
Emit most of SPIR-V
vosen Aug 30, 2024
2e5ad8e
Wire new parser into spvtxt tests
vosen Aug 30, 2024
32b6262
Fix PtrAdd
vosen Aug 30, 2024
16fafe5
Parse comments and vector members correctly
vosen Aug 30, 2024
aebf06a
Improve implicit conversion and handling of vectors
vosen Aug 30, 2024
0c93393
Correctly report dst in call instructions
vosen Aug 31, 2024
8d15499
More fixes
vosen Sep 3, 2024
340ad86
Emit correct float add
vosen Sep 3, 2024
7a45b44
Fix more failing tests
vosen Sep 3, 2024
6a7c871
Fix array initializers
vosen Sep 3, 2024
3f31069
Allow ftz and saturated conversions
vosen Sep 3, 2024
aa98ab9
Fix all remaining problems
vosen Sep 3, 2024
061312c
Document wtf is going on with parsing macros
vosen Sep 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add stateless-to-stateful conversion
vosen committed Aug 26, 2024
commit 3e0a15ac845679b9ecd4f12c8bc84cf16b77081c
535 changes: 535 additions & 0 deletions ptx/src/pass/convert_to_stateful_memory_access.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,535 @@
use super::*;
use ptx_parser as ast;
use std::{
collections::{BTreeSet, HashSet},
iter,
rc::Rc,
};

/*
Our goal here is to transform
.visible .entry foobar(.param .u64 input) {
.reg .b64 in_addr;
.reg .b64 in_addr2;
ld.param.u64 in_addr, [input];
cvta.to.global.u64 in_addr2, in_addr;
}
into:
.visible .entry foobar(.param .u8 input[]) {
.reg .u8 in_addr[];
.reg .u8 in_addr2[];
ld.param.u8[] in_addr, [input];
mov.u8[] in_addr2, in_addr;
}
or:
.visible .entry foobar(.reg .u8 input[]) {
.reg .u8 in_addr[];
.reg .u8 in_addr2[];
mov.u8[] in_addr, input;
mov.u8[] in_addr2, in_addr;
}
or:
.visible .entry foobar(.param ptr<u8, global> input) {
.reg ptr<u8, global> in_addr;
.reg ptr<u8, global> in_addr2;
ld.param.ptr<u8, global> in_addr, [input];
mov.ptr<u8, global> in_addr2, in_addr;
}
*/
// TODO: detect more patterns (mov, call via reg, call via param)
// TODO: don't convert to ptr if the register is not ultimately used for ld/st
// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
// argument expansion
// TODO: propagate out of calls and into calls
pub(super) fn run<'a, 'input>(
func_args: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
func_body: Vec<TypedStatement>,
id_defs: &mut NumericIdResolver<'a>,
) -> Result<
(
Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
Vec<TypedStatement>,
),
TranslateError,
> {
let mut method_decl = func_args.borrow_mut();
if !matches!(method_decl.name, ast::MethodName::Kernel(..)) {
drop(method_decl);
return Ok((func_args, func_body));
}
if Rc::strong_count(&func_args) != 1 {
return Err(error_unreachable());
}
let func_args_64bit = (*method_decl)
.input_arguments
.iter()
.filter_map(|arg| match arg.v_type {
ast::Type::Scalar(ast::ScalarType::U64)
| ast::Type::Scalar(ast::ScalarType::B64)
| ast::Type::Scalar(ast::ScalarType::S64) => Some(arg.name),
_ => None,
})
.collect::<HashSet<_>>();
let mut stateful_markers = Vec::new();
let mut stateful_init_reg = HashMap::<_, Vec<_>>::new();
for statement in func_body.iter() {
match statement {
Statement::Instruction(ast::Instruction::Cvta {
data:
ast::CvtaDetails {
state_space: ast::StateSpace::Global,
direction: ast::CvtaDirection::GenericToExplicit,
},
arguments,
}) => {
if let (TypedOperand::Reg(dst), Some(src)) =
(arguments.dst, arguments.src.underlying_register())
{
if is_64_bit_integer(id_defs, src) && is_64_bit_integer(id_defs, dst) {
stateful_markers.push((dst, src));
}
}
}
Statement::Instruction(ast::Instruction::Ld {
data:
ast::LdDetails {
state_space: ast::StateSpace::Param,
typ: ast::Type::Scalar(ast::ScalarType::U64),
..
},
arguments,
})
| Statement::Instruction(ast::Instruction::Ld {
data:
ast::LdDetails {
state_space: ast::StateSpace::Param,
typ: ast::Type::Scalar(ast::ScalarType::S64),
..
},
arguments,
})
| Statement::Instruction(ast::Instruction::Ld {
data:
ast::LdDetails {
state_space: ast::StateSpace::Param,
typ: ast::Type::Scalar(ast::ScalarType::B64),
..
},
arguments,
}) => {
if let (TypedOperand::Reg(dst), Some(src)) =
(arguments.dst, arguments.src.underlying_register())
{
if func_args_64bit.contains(&src) {
multi_hash_map_append(&mut stateful_init_reg, dst, src);
}
}
}
_ => {}
}
}
if stateful_markers.len() == 0 {
drop(method_decl);
return Ok((func_args, func_body));
}
let mut func_args_ptr = HashSet::new();
let mut regs_ptr_current = HashSet::new();
for (dst, src) in stateful_markers {
if let Some(func_args) = stateful_init_reg.get(&src) {
for a in func_args {
func_args_ptr.insert(*a);
regs_ptr_current.insert(src);
regs_ptr_current.insert(dst);
}
}
}
// BTreeSet here to have a stable order of iteration,
// unfortunately our tests rely on it
let mut regs_ptr_seen = BTreeSet::new();
while regs_ptr_current.len() > 0 {
let mut regs_ptr_new = HashSet::new();
for statement in func_body.iter() {
match statement {
Statement::Instruction(ast::Instruction::Add {
data:
ast::ArithDetails::Integer(ast::ArithInteger {
type_: ast::ScalarType::U64,
saturate: false,
}),
arguments,
})
| Statement::Instruction(ast::Instruction::Add {
data:
ast::ArithDetails::Integer(ast::ArithInteger {
type_: ast::ScalarType::S64,
saturate: false,
}),
arguments,
}) => {
// TODO: don't mark result of double pointer sub or double
// pointer add as ptr result
if let (TypedOperand::Reg(dst), Some(src1)) =
(arguments.dst, arguments.src1.underlying_register())
{
if regs_ptr_current.contains(&src1) && !regs_ptr_seen.contains(&src1) {
regs_ptr_new.insert(dst);
}
} else if let (TypedOperand::Reg(dst), Some(src2)) =
(arguments.dst, arguments.src2.underlying_register())
{
if regs_ptr_current.contains(&src2) && !regs_ptr_seen.contains(&src2) {
regs_ptr_new.insert(dst);
}
}
}

Statement::Instruction(ast::Instruction::Sub {
data:
ast::ArithDetails::Integer(ast::ArithInteger {
type_: ast::ScalarType::U64,
saturate: false,
}),
arguments,
})
| Statement::Instruction(ast::Instruction::Sub {
data:
ast::ArithDetails::Integer(ast::ArithInteger {
type_: ast::ScalarType::S64,
saturate: false,
}),
arguments,
}) => {
// TODO: don't mark result of double pointer sub or double
// pointer add as ptr result
if let (TypedOperand::Reg(dst), Some(src1)) =
(arguments.dst, arguments.src1.underlying_register())
{
if regs_ptr_current.contains(&src1) && !regs_ptr_seen.contains(&src1) {
regs_ptr_new.insert(dst);
}
} else if let (TypedOperand::Reg(dst), Some(src2)) =
(arguments.dst, arguments.src2.underlying_register())
{
if regs_ptr_current.contains(&src2) && !regs_ptr_seen.contains(&src2) {
regs_ptr_new.insert(dst);
}
}
}
_ => {}
}
}
for id in regs_ptr_current {
regs_ptr_seen.insert(id);
}
regs_ptr_current = regs_ptr_new;
}
drop(regs_ptr_current);
let mut remapped_ids = HashMap::new();
let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len());
for reg in regs_ptr_seen {
let new_id = id_defs.register_variable(
ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
ast::StateSpace::Reg,
);
result.push(Statement::Variable(ast::Variable {
align: None,
name: new_id,
array_init: Vec::new(),
v_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
state_space: ast::StateSpace::Reg,
}));
remapped_ids.insert(reg, new_id);
}
for arg in (*method_decl).input_arguments.iter_mut() {
if !func_args_ptr.contains(&arg.name) {
continue;
}
let new_id = id_defs.register_variable(
ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
ast::StateSpace::Param,
);
let old_name = arg.name;
arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global);
arg.name = new_id;
remapped_ids.insert(old_name, new_id);
}
for statement in func_body {
match statement {
l @ Statement::Label(_) => result.push(l),
c @ Statement::Conditional(_) => result.push(c),
c @ Statement::Constant(..) => result.push(c),
Statement::Variable(var) => {
if !remapped_ids.contains_key(&var.name) {
result.push(Statement::Variable(var));
}
}
Statement::Instruction(ast::Instruction::Add {
data:
ast::ArithDetails::Integer(ast::ArithInteger {
type_: ast::ScalarType::U64,
saturate: false,
}),
arguments,
})
| Statement::Instruction(ast::Instruction::Add {
data:
ast::ArithDetails::Integer(ast::ArithInteger {
type_: ast::ScalarType::S64,
saturate: false,
}),
arguments,
}) if is_add_ptr_direct(&remapped_ids, &arguments) => {
let (ptr, offset) = match arguments.src1.underlying_register() {
Some(src1) if remapped_ids.contains_key(&src1) => {
(remapped_ids.get(&src1).unwrap(), arguments.src2)
}
Some(src2) if remapped_ids.contains_key(&src2) => {
(remapped_ids.get(&src2).unwrap(), arguments.src1)
}
_ => return Err(error_unreachable()),
};
let dst = arguments.dst.unwrap_reg()?;
result.push(Statement::PtrAccess(PtrAccess {
underlying_type: ast::Type::Scalar(ast::ScalarType::U8),
state_space: ast::StateSpace::Global,
dst: *remapped_ids.get(&dst).unwrap(),
ptr_src: *ptr,
offset_src: offset,
}))
}
Statement::Instruction(ast::Instruction::Sub {
data:
ast::ArithDetails::Integer(ast::ArithInteger {
type_: ast::ScalarType::U64,
saturate: false,
}),
arguments,
})
| Statement::Instruction(ast::Instruction::Sub {
data:
ast::ArithDetails::Integer(ast::ArithInteger {
type_: ast::ScalarType::S64,
saturate: false,
}),
arguments,
}) if is_sub_ptr_direct(&remapped_ids, &arguments) => {
let (ptr, offset) = match arguments.src1.underlying_register() {
Some(ref src1) => (remapped_ids.get(src1).unwrap(), arguments.src2),
_ => return Err(error_unreachable()),
};
let offset_neg = id_defs.register_intermediate(Some((
ast::Type::Scalar(ast::ScalarType::S64),
ast::StateSpace::Reg,
)));
result.push(Statement::Instruction(ast::Instruction::Neg {
data: ast::TypeFtz {
type_: ast::ScalarType::S64,
flush_to_zero: None,
},
arguments: ast::NegArgs {
src: offset,
dst: TypedOperand::Reg(offset_neg),
},
}));
let dst = arguments.dst.unwrap_reg()?;
result.push(Statement::PtrAccess(PtrAccess {
underlying_type: ast::Type::Scalar(ast::ScalarType::U8),
state_space: ast::StateSpace::Global,
dst: *remapped_ids.get(&dst).unwrap(),
ptr_src: *ptr,
offset_src: TypedOperand::Reg(offset_neg),
}))
}
inst @ Statement::Instruction(_) => {
let mut post_statements = Vec::new();
let new_statement = inst.visit_map(&mut FnVisitor::new(
|operand, type_space, is_dst, relaxed_conversion| {
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
&mut result,
&mut post_statements,
operand,
type_space,
is_dst,
relaxed_conversion,
)
},
))?;
result.push(new_statement);
result.extend(post_statements);
}
repack @ Statement::RepackVector(_) => {
let mut post_statements = Vec::new();
let new_statement = repack.visit_map(&mut FnVisitor::new(
|operand, type_space, is_dst, relaxed_conversion| {
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
&mut result,
&mut post_statements,
operand,
type_space,
is_dst,
relaxed_conversion,
)
},
))?;
result.push(new_statement);
result.extend(post_statements);
}
_ => return Err(error_unreachable()),
}
}
drop(method_decl);
Ok((func_args, result))
}

fn is_64_bit_integer(id_defs: &NumericIdResolver, id: SpirvWord) -> bool {
match id_defs.get_typed(id) {
Ok((ast::Type::Scalar(ast::ScalarType::U64), _, _))
| Ok((ast::Type::Scalar(ast::ScalarType::S64), _, _))
| Ok((ast::Type::Scalar(ast::ScalarType::B64), _, _)) => true,
_ => false,
}
}

fn multi_hash_map_append<
K: Eq + std::hash::Hash,
V,
Collection: std::iter::Extend<V> + std::default::Default,
>(
m: &mut HashMap<K, Collection>,
key: K,
value: V,
) {
match m.entry(key) {
hash_map::Entry::Occupied(mut entry) => {
entry.get_mut().extend(iter::once(value));
}
hash_map::Entry::Vacant(entry) => {
entry.insert(Default::default()).extend(iter::once(value));
}
}
}

fn is_add_ptr_direct(
remapped_ids: &HashMap<SpirvWord, SpirvWord>,
arg: &ast::AddArgs<TypedOperand>,
) -> bool {
match arg.dst {
TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => {
return false
}
TypedOperand::Reg(dst) => {
if !remapped_ids.contains_key(&dst) {
return false;
}
if let Some(ref src1_reg) = arg.src1.underlying_register() {
if remapped_ids.contains_key(src1_reg) {
// don't trigger optimization when adding two pointers
if let Some(ref src2_reg) = arg.src2.underlying_register() {
return !remapped_ids.contains_key(src2_reg);
}
}
}
if let Some(ref src2_reg) = arg.src2.underlying_register() {
remapped_ids.contains_key(src2_reg)
} else {
false
}
}
}
}

fn is_sub_ptr_direct(
remapped_ids: &HashMap<SpirvWord, SpirvWord>,
arg: &ast::SubArgs<TypedOperand>,
) -> bool {
match arg.dst {
TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => {
return false
}
TypedOperand::Reg(dst) => {
if !remapped_ids.contains_key(&dst) {
return false;
}
match arg.src1.underlying_register() {
Some(ref src1_reg) => {
if remapped_ids.contains_key(src1_reg) {
// don't trigger optimization when subtracting two pointers
arg.src2
.underlying_register()
.map_or(true, |ref src2_reg| !remapped_ids.contains_key(src2_reg))
} else {
false
}
}
None => false,
}
}
}
}

fn convert_to_stateful_memory_access_postprocess(
id_defs: &mut NumericIdResolver,
remapped_ids: &HashMap<SpirvWord, SpirvWord>,
result: &mut Vec<TypedStatement>,
post_statements: &mut Vec<TypedStatement>,
operand: TypedOperand,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_conversion: bool,
) -> Result<TypedOperand, TranslateError> {
operand.map(|operand, _| {
Ok(match remapped_ids.get(&operand) {
Some(new_id) => {
let (new_operand_type, new_operand_space, _) = id_defs.get_typed(*new_id)?;
// TODO: readd if required
if let Some(..) = type_space {
if relaxed_conversion {
return Ok(*new_id);
}
}
let (old_operand_type, old_operand_space, _) = id_defs.get_typed(operand)?;
let converting_id = id_defs
.register_intermediate(Some((old_operand_type.clone(), old_operand_space)));
let kind = if state_is_compatible(new_operand_space, ast::StateSpace::Reg) {
ConversionKind::Default
} else {
ConversionKind::PtrToPtr
};
if is_dst {
post_statements.push(Statement::Conversion(ImplicitConversion {
src: converting_id,
dst: *new_id,
from_type: old_operand_type,
from_space: old_operand_space,
to_type: new_operand_type,
to_space: new_operand_space,
kind,
}));
converting_id
} else {
result.push(Statement::Conversion(ImplicitConversion {
src: *new_id,
dst: converting_id,
from_type: new_operand_type,
from_space: new_operand_space,
to_type: old_operand_type,
to_space: old_operand_space,
kind,
}));
converting_id
}
}
None => operand,
})
})
}

fn state_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool {
this == other
|| this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg
|| this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg
}
2 changes: 1 addition & 1 deletion ptx/src/pass/fix_special_registers.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::*;
use std::collections::HashMap;

fn run<'a, 'b, 'input>(
pub(super) fn run<'a, 'b, 'input>(
ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
typed_statements: Vec<TypedStatement>,
numeric_id_defs: &'a mut NumericIdResolver<'b>,
90 changes: 85 additions & 5 deletions ptx/src/pass/mod.rs
Original file line number Diff line number Diff line change
@@ -5,9 +5,11 @@ use std::{
cell::RefCell,
collections::{hash_map, HashMap},
ffi::CString,
marker::PhantomData,
rc::Rc,
};

mod convert_to_stateful_memory_access;
mod convert_to_typed;
mod fix_special_registers;
mod normalize_identifiers;
@@ -169,12 +171,12 @@ fn to_ssa<'input, 'b>(
let unadorned_statements = normalize_predicates::run(normalized_ids, &mut numeric_id_defs)?;
let typed_statements =
convert_to_typed::run(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
todo!()
/*
let typed_statements =
fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?;
fix_special_registers::run(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?;
let (func_decl, typed_statements) =
convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?;
convert_to_stateful_memory_access::run(func_decl, typed_statements, &mut numeric_id_defs)?;
todo!()
/*
let ssa_statements = insert_mem_ssa_statements(
typed_statements,
&mut numeric_id_defs,
@@ -1035,7 +1037,7 @@ struct FunctionPointerDetails {
src: SpirvWord,
}

#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
struct SpirvWord(spirv::Word);

impl From<spirv::Word> for SpirvWord {
@@ -1117,6 +1119,20 @@ impl TypedOperand {
TypedOperand::VecMember(reg, idx) => TypedOperand::VecMember(fn_(reg, Some(idx))?, idx),
})
}

fn underlying_register(&self) -> Option<SpirvWord> {
match self {
Self::Reg(r) | Self::RegOffset(r, _) | Self::VecMember(r, _) => Some(*r),
Self::Imm(_) => None,
}
}

fn unwrap_reg(&self) -> Result<SpirvWord, TranslateError> {
match self {
TypedOperand::Reg(reg) => Ok(*reg),
_ => Err(error_unreachable()),
}
}
}

impl ast::Operand for TypedOperand {
@@ -1126,3 +1142,67 @@ impl ast::Operand for TypedOperand {
TypedOperand::Reg(ident)
}
}

impl<Fn> ast::VisitorMap<TypedOperand, TypedOperand, TranslateError>
for FnVisitor<TypedOperand, TypedOperand, TranslateError, Fn>
where
Fn: FnMut(
TypedOperand,
Option<(&ast::Type, ast::StateSpace)>,
bool,
bool,
) -> Result<TypedOperand, TranslateError>,
{
fn visit(
&mut self,
args: TypedOperand,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<TypedOperand, TranslateError> {
(self.fn_)(args, type_space, is_dst, relaxed_type_check)
}

fn visit_ident(
&mut self,
args: SpirvWord,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
match (self.fn_)(
TypedOperand::Reg(args),
type_space,
is_dst,
relaxed_type_check,
)? {
TypedOperand::Reg(reg) => Ok(reg),
_ => Err(TranslateError::Unreachable),
}
}
}

struct FnVisitor<
T,
U,
Err,
Fn: FnMut(T, Option<(&ast::Type, ast::StateSpace)>, bool, bool) -> Result<U, Err>,
> {
fn_: Fn,
_marker: PhantomData<fn(T) -> Result<U, Err>>,
}

impl<
T,
U,
Err,
Fn: FnMut(T, Option<(&ast::Type, ast::StateSpace)>, bool, bool) -> Result<U, Err>,
> FnVisitor<T, U, Err, Fn>
{
fn new(fn_: Fn) -> Self {
Self {
fn_,
_marker: PhantomData,
}
}
}
20 changes: 5 additions & 15 deletions ptx/src/translate.rs
Original file line number Diff line number Diff line change
@@ -1608,17 +1608,13 @@ fn extract_globals<'input, 'b>(
for statement in sorted_statements {
match statement {
Statement::Variable(
var
@
ast::Variable {
var @ ast::Variable {
state_space: ast::StateSpace::Shared,
..
},
)
| Statement::Variable(
var
@
ast::Variable {
var @ ast::Variable {
state_space: ast::StateSpace::Global,
..
},
@@ -1660,9 +1656,7 @@ fn extract_globals<'input, 'b>(
)?);
}
Statement::Instruction(ast::Instruction::Atom(
details
@
ast::AtomDetails {
details @ ast::AtomDetails {
inner:
ast::AtomInnerDetails::Unsigned {
op: ast::AtomUIntOp::Inc,
@@ -1691,9 +1685,7 @@ fn extract_globals<'input, 'b>(
)?);
}
Statement::Instruction(ast::Instruction::Atom(
details
@
ast::AtomDetails {
details @ ast::AtomDetails {
inner:
ast::AtomInnerDetails::Unsigned {
op: ast::AtomUIntOp::Dec,
@@ -1722,9 +1714,7 @@ fn extract_globals<'input, 'b>(
)?);
}
Statement::Instruction(ast::Instruction::Atom(
details
@
ast::AtomDetails {
details @ ast::AtomDetails {
inner:
ast::AtomInnerDetails::Float {
op: ast::AtomFloatOp::Add,
1 change: 1 addition & 0 deletions ptx_parser/src/ast.rs
Original file line number Diff line number Diff line change
@@ -760,6 +760,7 @@ pub enum Type {
Vector(ScalarType, u8),
// .param.b32 foo[4];
Array(ScalarType, Vec<u32>),
Pointer(ScalarType, StateSpace)
}

impl Type {