Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PTX parser rewrite #267

Merged
merged 47 commits into from
Sep 4, 2024
Merged
Changes from 1 commit
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
a05bee9
Start rewriting PTX parser
vosen Aug 14, 2024
8d7c88c
Fully parse operands
vosen Aug 15, 2024
dbd37f9
Clean up and improve ident parsing
vosen Aug 15, 2024
ba17906
Pass parser state to instruction callbacks
vosen Aug 15, 2024
0da45ea
Add parsing of st, allow associating type with a non-alternative modi…
vosen Aug 15, 2024
0112880
Parse ld, add, ret
vosen Aug 16, 2024
91dbbb3
Move all types to a separate module
vosen Aug 16, 2024
77de5c7
Parse simplest vector add kernel
vosen Aug 18, 2024
522541d
Support simple module variables
vosen Aug 18, 2024
cb64b04
Add mul
vosen Aug 18, 2024
c08e6a6
Implement setp
vosen Aug 19, 2024
22492ec
Implement not, or, and, bra
vosen Aug 19, 2024
34b0a67
Add types for call instruction
vosen Aug 20, 2024
c21c55d
Parse call instruction
vosen Aug 20, 2024
bc1074e
Add cvt
vosen Aug 20, 2024
47f8314
Add shr, shl
vosen Aug 20, 2024
588d66b
Add cvta
vosen Aug 20, 2024
6cd18bf
Add abs, mad
vosen Aug 21, 2024
798bbf0
Add fma and sub
vosen Aug 21, 2024
fc713f2
Add min, max
vosen Aug 21, 2024
c16bae3
Add rcp, sqrt, rsqrt
vosen Aug 21, 2024
39faaa7
Add atom and atom.cas
vosen Aug 21, 2024
0760c3d
Map remaining instructions
vosen Aug 21, 2024
71e0258
Rename new crates
vosen Aug 21, 2024
1ec1ca0
Attempt #2
vosen Aug 23, 2024
12ef8db
Port first pass
vosen Aug 23, 2024
7ea990e
Work on more passes
vosen Aug 23, 2024
69175d2
Add relaxed type check information to visitors
vosen Aug 24, 2024
4e6dc07
Implement third pass
vosen Aug 24, 2024
107f1eb
Port sreg fix pass
vosen Aug 26, 2024
3e0a15a
Add stateless-to-stateful conversion
vosen Aug 26, 2024
cccd37f
Port ssa conversion
vosen Aug 26, 2024
c088cc2
Port expand_arguments
vosen Aug 26, 2024
144f8bd
Port remaining two passes
vosen Aug 27, 2024
790fe18
Emit most of SPIR-V
vosen Aug 30, 2024
2e5ad8e
Wire new parser into spvtxt tests
vosen Aug 30, 2024
32b6262
Fix PtrAdd
vosen Aug 30, 2024
16fafe5
Parse comments and vector members correctly
vosen Aug 30, 2024
aebf06a
Improve implicit conversion and handling of vectors
vosen Aug 30, 2024
0c93393
Correctly report dst in call instructions
vosen Aug 31, 2024
8d15499
More fixes
vosen Sep 3, 2024
340ad86
Emit correct float add
vosen Sep 3, 2024
7a45b44
Fix more failing tests
vosen Sep 3, 2024
6a7c871
Fix array initializers
vosen Sep 3, 2024
3f31069
Allow ftz and saturated conversions
vosen Sep 3, 2024
aa98ab9
Fix all remaining problems
vosen Sep 3, 2024
061312c
Document wtf is going on with parsing macros
vosen Sep 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Port sreg fix pass
vosen committed Aug 26, 2024
commit 107f1eb17f680dbdfccdebd3828b38b6ec0897aa
183 changes: 183 additions & 0 deletions ptx/src/pass/fix_special_registers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
use super::*;
use std::collections::HashMap;

fn run<'a, 'b, 'input>(
ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
typed_statements: Vec<TypedStatement>,
numeric_id_defs: &'a mut NumericIdResolver<'b>,
) -> Result<Vec<TypedStatement>, TranslateError> {
let result = Vec::with_capacity(typed_statements.len());
let mut sreg_sresolver = SpecialRegisterResolver {
ptx_impl_imports,
numeric_id_defs,
result,
};
for statement in typed_statements {
let statement = statement.visit_map(&mut sreg_sresolver)?;
sreg_sresolver.result.push(statement);
}
Ok(sreg_sresolver.result)
}

struct SpecialRegisterResolver<'a, 'b, 'input> {
ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
numeric_id_defs: &'a mut NumericIdResolver<'b>,
result: Vec<TypedStatement>,
}

impl<'a, 'b, 'input> ast::VisitorMap<TypedOperand, TypedOperand, TranslateError>
for SpecialRegisterResolver<'a, 'b, 'input>
{
fn visit(
&mut self,
operand: TypedOperand,
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
is_dst: bool,
_relaxed_type_check: bool,
) -> Result<TypedOperand, TranslateError> {
operand.map(|name, vector_index| self.replace_sreg(name, is_dst, vector_index))
}

fn visit_ident(
&mut self,
args: SpirvWord,
_type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
is_dst: bool,
_relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
self.replace_sreg(args, is_dst, None)
}
}

impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> {
fn replace_sreg(
&mut self,
name: SpirvWord,
is_dst: bool,
vector_index: Option<u8>,
) -> Result<SpirvWord, TranslateError> {
if let Some(sreg) = self.numeric_id_defs.special_registers.get(name) {
if is_dst {
return Err(TranslateError::MismatchedType);
}
let input_arguments = match (vector_index, sreg.get_function_input_type()) {
(Some(idx), Some(inp_type)) => {
if inp_type != ast::ScalarType::U8 {
return Err(TranslateError::Unreachable);
}
let constant = self.numeric_id_defs.register_intermediate(Some((
ast::Type::Scalar(inp_type),
ast::StateSpace::Reg,
)));
self.result.push(Statement::Constant(ConstantDefinition {
dst: constant,
typ: inp_type,
value: ast::ImmediateValue::U64(idx as u64),
}));
vec![(
TypedOperand::Reg(constant),
ast::Type::Scalar(inp_type),
ast::StateSpace::Reg,
)]
}
(None, None) => Vec::new(),
_ => return Err(TranslateError::MismatchedType),
};
let ocl_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
let return_type = sreg.get_function_return_type();
let fn_result = self.numeric_id_defs.register_intermediate(Some((
ast::Type::Scalar(return_type),
ast::StateSpace::Reg,
)));
let return_arguments = vec![(
fn_result,
ast::Type::Scalar(return_type),
ast::StateSpace::Reg,
)];
let fn_call = register_external_fn_call(
self.numeric_id_defs,
self.ptx_impl_imports,
ocl_fn_name.to_string(),
return_arguments.iter().map(|(_, typ, space)| (typ, *space)),
input_arguments.iter().map(|(_, typ, space)| (typ, *space)),
)?;
let data = ast::CallDetails {
uniform: false,
return_arguments: return_arguments
.iter()
.map(|(_, typ, space)| (typ.clone(), *space))
.collect(),
input_arguments: input_arguments
.iter()
.map(|(_, typ, space)| (typ.clone(), *space))
.collect(),
};
let arguments = ast::CallArgs {
return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(),
func: fn_call,
input_arguments: input_arguments.iter().map(|(name, _, _)| *name).collect(),
};
self.result
.push(Statement::Instruction(ast::Instruction::Call {
data,
arguments,
}));
Ok(fn_result)
} else {
Ok(name)
}
}
}

fn register_external_fn_call<'a>(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
name: String,
return_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
input_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
) -> Result<SpirvWord, TranslateError> {
match ptx_impl_imports.entry(name) {
hash_map::Entry::Vacant(entry) => {
let fn_id = id_defs.register_intermediate(None);
let return_arguments = fn_arguments_to_variables(id_defs, return_arguments);
let input_arguments = fn_arguments_to_variables(id_defs, input_arguments);
let func_decl = ast::MethodDeclaration::<SpirvWord> {
return_arguments,
name: ast::MethodName::Func(fn_id),
input_arguments,
shared_mem: None,
};
let func = Function {
func_decl: Rc::new(RefCell::new(func_decl)),
globals: Vec::new(),
body: None,
import_as: Some(entry.key().clone()),
tuning: Vec::new(),
linkage: ast::LinkingDirective::EXTERN,
};
entry.insert(Directive::Method(func));
Ok(fn_id)
}
hash_map::Entry::Occupied(entry) => match entry.get() {
Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
ast::MethodName::Func(fn_id) => Ok(fn_id),
ast::MethodName::Kernel(_) => Err(error_unreachable()),
},
_ => Err(error_unreachable()),
},
}
}

fn fn_arguments_to_variables<'a>(
id_defs: &mut NumericIdResolver,
args: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
) -> Vec<ast::Variable<SpirvWord>> {
args.map(|(typ, space)| ast::Variable {
align: None,
v_type: typ.clone(),
state_space: space,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
})
.collect::<Vec<_>>()
}
247 changes: 245 additions & 2 deletions ptx/src/pass/mod.rs
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@ use std::{
};

mod convert_to_typed;
mod fix_special_registers;
mod normalize_identifiers;
mod normalize_predicates;

@@ -735,6 +736,235 @@ enum Statement<I, P: ast::Operand> {
FunctionPointer(FunctionPointerDetails),
}

impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
fn visit_map<To: ast::Operand<Ident = SpirvWord>, Err>(
self,
visitor: &mut impl ast::VisitorMap<T, To, Err>,
) -> std::result::Result<Statement<ast::Instruction<To>, T>, Err> {
Ok(match self {
Statement::Instruction(i) => {
return ast::visit_map(i, visitor).map(Statement::Instruction)
}
Statement::Label(label) => {
Statement::Label(visitor.visit_ident(label, None, false, false)?)
}
Statement::Variable(var) => {
let name = visitor.visit_ident(
var.name,
Some((&var.v_type, var.state_space)),
true,
false,
)?;
Statement::Variable(ast::Variable {
align: var.align,
v_type: var.v_type,
state_space: var.state_space,
name,
array_init: var.array_init,
})
}
Statement::Conditional(conditional) => {
let predicate = visitor.visit_ident(conditional.predicate, None, false, false)?;
let if_true = visitor.visit_ident(conditional.if_true, None, false, false)?;
let if_false = visitor.visit_ident(conditional.if_false, None, false, false)?;
Statement::Conditional(BrachCondition {
predicate,
if_true,
if_false,
})
}
Statement::LoadVar(LoadVarDetails {
arg,
typ,
member_index,
}) => {
let dst = visitor.visit_ident(
arg.dst,
Some((&typ, ast::StateSpace::Reg)),
true,
false,
)?;
let src = visitor.visit_ident(
arg.src,
Some((&typ, ast::StateSpace::Local)),
false,
false,
)?;
Statement::LoadVar(LoadVarDetails {
arg: ast::LdArgs { dst, src },
typ,
member_index,
})
}
Statement::StoreVar(StoreVarDetails {
arg,
typ,
member_index,
}) => {
let src1 = visitor.visit_ident(
arg.src1,
Some((&typ, ast::StateSpace::Local)),
false,
false,
)?;
let src2 = visitor.visit_ident(
arg.src2,
Some((&typ, ast::StateSpace::Reg)),
false,
false,
)?;
Statement::StoreVar(StoreVarDetails {
arg: ast::StArgs { src1, src2 },
typ,
member_index,
})
}
Statement::Conversion(ImplicitConversion {
src,
dst,
from_type,
to_type,
from_space,
to_space,
kind,
}) => {
let dst = visitor.visit_ident(
dst,
Some((&to_type, ast::StateSpace::Reg)),
true,
false,
)?;
let src = visitor.visit_ident(
src,
Some((&from_type, ast::StateSpace::Reg)),
false,
false,
)?;
Statement::Conversion(ImplicitConversion {
src,
dst,
from_type,
to_type,
from_space,
to_space,
kind,
})
}
Statement::Constant(ConstantDefinition { dst, typ, value }) => {
let dst = visitor.visit_ident(
dst,
Some((&typ.into(), ast::StateSpace::Reg)),
true,
false,
)?;
Statement::Constant(ConstantDefinition { dst, typ, value })
}
Statement::RetValue(data, value) => {
// TODO:
// We should report type here
let value = visitor.visit_ident(value, None, false, false)?;
Statement::RetValue(data, value)
}
Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
dst,
ptr_src,
offset_src,
}) => {
let dst =
visitor.visit_ident(dst, Some((&underlying_type, state_space)), true, false)?;
let ptr_src = visitor.visit_ident(
ptr_src,
Some((&underlying_type, state_space)),
false,
false,
)?;
Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
dst,
ptr_src,
offset_src,
})
}
Statement::RepackVector(RepackVectorDetails {
is_extract,
typ,
packed,
unpacked,
relaxed_type_check,
}) => {
let (packed, unpacked) = if is_extract {
let unpacked = unpacked
.into_iter()
.map(|ident| {
visitor.visit_ident(
ident,
Some((&typ.into(), ast::StateSpace::Reg)),
true,
relaxed_type_check,
)
})
.collect::<Result<Vec<_>, _>>()?;
let packed = visitor.visit_ident(
packed,
Some((
&ast::Type::Vector(typ, unpacked.len() as u8),
ast::StateSpace::Reg,
)),
false,
false,
)?;
(packed, unpacked)
} else {
let packed = visitor.visit_ident(
packed,
Some((
&ast::Type::Vector(typ, unpacked.len() as u8),
ast::StateSpace::Reg,
)),
true,
false,
)?;
let unpacked = unpacked
.into_iter()
.map(|ident| {
visitor.visit_ident(
ident,
Some((&typ.into(), ast::StateSpace::Reg)),
false,
relaxed_type_check,
)
})
.collect::<Result<Vec<_>, _>>()?;
(packed, unpacked)
};
Statement::RepackVector(RepackVectorDetails {
is_extract,
typ,
packed,
unpacked,
relaxed_type_check,
})
}
Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => {
let dst = visitor.visit_ident(
dst,
Some((
&ast::Type::Scalar(ast::ScalarType::U64),
ast::StateSpace::Reg,
)),
true,
false,
)?;
let src = visitor.visit_ident(src, None, false, false)?;
Statement::FunctionPointer(FunctionPointerDetails { dst, src })
}
})
}
}

struct BrachCondition {
predicate: SpirvWord,
if_true: SpirvWord,
@@ -743,7 +973,6 @@ struct BrachCondition {
struct LoadVarDetails {
arg: ast::LdArgs<SpirvWord>,
typ: ast::Type,
state_space: ast::StateSpace,
// (index, vector_width)
// HACK ALERT
// For some reason IGC explodes when you try to load from builtin vectors
@@ -798,7 +1027,7 @@ struct RepackVectorDetails {
typ: ast::ScalarType,
packed: SpirvWord,
unpacked: Vec<SpirvWord>,
relaxed_type_check: bool
relaxed_type_check: bool,
}

struct FunctionPointerDetails {
@@ -876,6 +1105,20 @@ enum TypedOperand {
VecMember(SpirvWord, u8),
}

impl TypedOperand {
fn map<Err>(
self,
fn_: impl FnOnce(SpirvWord, Option<u8>) -> Result<SpirvWord, Err>,
) -> Result<Self, Err> {
Ok(match self {
TypedOperand::Reg(reg) => TypedOperand::Reg(fn_(reg, None)?),
TypedOperand::RegOffset(reg, off) => TypedOperand::RegOffset(fn_(reg, None)?, off),
TypedOperand::Imm(imm) => TypedOperand::Imm(imm),
TypedOperand::VecMember(reg, idx) => TypedOperand::VecMember(fn_(reg, Some(idx))?, idx),
})
}
}

impl ast::Operand for TypedOperand {
type Ident = SpirvWord;