Skip to content

Commit

Permalink
Cleanup input and output visitors
Browse files Browse the repository at this point in the history
  • Loading branch information
natgavrilenko committed Jul 15, 2024
1 parent b5f1a77 commit 9a02956
Show file tree
Hide file tree
Showing 18 changed files with 354 additions and 347 deletions.
12 changes: 5 additions & 7 deletions dartagnan/src/main/antlr4/Spirv.g4
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,13 @@ assertionList
;

assertion
: ModeHeader_LPar assertionValue ModeHeader_RPar
| ModeHeader_LPar assertion ModeHeader_RPar
| ModeHeader_AssertionNot assertion
| assertion ModeHeader_AssertionAnd assertion
| assertion ModeHeader_AssertionOr assertion
| assertionBasic
: ModeHeader_LPar assertion ModeHeader_RPar # assertionParenthesis
| ModeHeader_AssertionNot assertion # assertionNot
| assertion ModeHeader_AssertionAnd assertion # assertionAnd
| assertion ModeHeader_AssertionOr assertion # assertionOr
| assertionValue assertionCompare assertionValue # assertionBasic
;

assertionBasic : assertionValue assertionCompare assertionValue;
assertionCompare
: ModeHeader_EqualEqual
| ModeHeader_NotEqual
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

Expand All @@ -23,42 +24,66 @@ public class VisitorSpirv extends SpirvBaseVisitor<Program> {
private VisitorOpsConstant specConstantVisitor;
private ProgramBuilder builder;

static String parseOpName(SpirvParser.OpContext ctx) {
ParseTree innerCtx = ctx.getChild(0);
if ("Op".equals(innerCtx.getChild(0).getText())) {
return "Op" + innerCtx.getChild(1).getText();
}
if ("SpecConstantOp".equals(innerCtx.getChild(3).getText())) {
return "Op" + innerCtx.getChild(5).getText();
}
return "Op" + innerCtx.getChild(3).getText();
@Override
public Program visitSpv(SpirvParser.SpvContext ctx) {
this.builder = createBuilder(ctx);
this.initializeVisitors();
this.specConstantVisitor = getSpecConstantVisitor();
ctx.spvHeaders().accept(new VisitorSpirvInput(builder));
visitSpvInstructions(ctx.spvInstructions());
ctx.spvHeaders().accept(new VisitorSpirvOutput(builder));
return builder.build();
}

static boolean isSpecConstantOp(SpirvParser.OpContext ctx) {
ParseTree innerCtx = ctx.getChild(0);
if ("Op".equals(innerCtx.getChild(0).getText())) {
return false;
@Override
public Program visitSpvInstructions(SpirvParser.SpvInstructionsContext ctx) {
this.visitChildren(ctx);
return null;
}

@Override
public Program visitOp(SpirvParser.OpContext ctx) {
String name = parseOpName(ctx);
if (builder.getNextOps() != null) {
if (!builder.getNextOps().contains(name)) {
throw new ParsingException("Unexpected operation '%s'", name);
}
builder.clearNextOps();
}
return "SpecConstantOp".equals(innerCtx.getChild(3).getText());
SpirvBaseVisitor<?> visitor = visitors.get(name);
if (visitor == null) {
throw new ParsingException("Unsupported operation '%s'", name);
}
Object result = ctx.accept(visitor);
if (isSpecConstantOp(ctx)) {
if (result instanceof Register register) {
specConstantVisitor.visitOpSpecConstantOp(register);
} else {
throw new ParsingException(
"Illegal result type for OpSpecConstantOp '%s'", name);
}
}
return null;
}

private static Set<Class<?>> getChildVisitors() {
return Set.of(
VisitorOpsAnnotation.class,
VisitorOpsArithmetic.class,
VisitorOpsAtomic.class,
VisitorOpsBarrier.class,
VisitorOpsBits.class,
VisitorOpsConstant.class,
VisitorOpsControlFlow.class,
VisitorOpsDebug.class,
VisitorOpsExtension.class,
VisitorOpsFunction.class,
VisitorOpsLogical.class,
VisitorOpsMemory.class,
VisitorOpsSetting.class,
VisitorOpsType.class
);
private ProgramBuilder createBuilder(SpirvParser.SpvContext ctx) {
ThreadGrid grid = new ThreadGrid(1, 1, 1, 1);
boolean hasConfig = false;
for (SpirvParser.SpvHeaderContext header : ctx.spvHeaders().spvHeader()) {
SpirvParser.ConfigHeaderContext cfgCtx = header.configHeader();
if (cfgCtx != null) {
if (hasConfig) {
throw new ParsingException("Multiple config headers are not allowed");
}
hasConfig = true;
List<SpirvParser.LiteranHeaderUnsignedIntegerContext> literals = cfgCtx.literanHeaderUnsignedInteger();
int sg = Integer.parseInt(literals.get(0).getText());
int wg = Integer.parseInt(literals.get(1).getText());
int qf = Integer.parseInt(literals.get(2).getText());
grid = new ThreadGrid(sg, wg, qf, 1);
}
}
return new ProgramBuilder(grid);
}

private void initializeVisitors() {
Expand All @@ -85,6 +110,25 @@ private void initializeVisitors() {
}
}

String parseOpName(SpirvParser.OpContext ctx) {
ParseTree innerCtx = ctx.getChild(0);
if ("Op".equals(innerCtx.getChild(0).getText())) {
return "Op" + innerCtx.getChild(1).getText();
}
if ("SpecConstantOp".equals(innerCtx.getChild(3).getText())) {
return "Op" + innerCtx.getChild(5).getText();
}
return "Op" + innerCtx.getChild(3).getText();
}

boolean isSpecConstantOp(SpirvParser.OpContext ctx) {
ParseTree innerCtx = ctx.getChild(0);
if ("Op".equals(innerCtx.getChild(0).getText())) {
return false;
}
return "SpecConstantOp".equals(innerCtx.getChild(3).getText());
}

private VisitorOpsConstant getSpecConstantVisitor() {
return visitors.values().stream()
.filter(VisitorOpsConstant.class::isInstance)
Expand All @@ -94,77 +138,22 @@ private VisitorOpsConstant getSpecConstantVisitor() {
"Missing visitor " + VisitorOpsConstant.class.getSimpleName()));
}

@Override
public Program visitSpv(SpirvParser.SpvContext ctx) {
this.builder = createBuilder(ctx);
this.initializeVisitors();
this.specConstantVisitor = getSpecConstantVisitor();
visitSpvInstructions(ctx.spvInstructions());
visitSpvHeaders(ctx.spvHeaders());
return builder.build();
}

private ProgramBuilder createBuilder(SpirvParser.SpvContext ctx) {
ThreadGrid grid = new ThreadGrid(1, 1, 1, 1);
VisitorSpirvInput visitor = new VisitorSpirvInput();
boolean hasConfig = false;
for (SpirvParser.SpvHeaderContext header : ctx.spvHeaders().spvHeader()) {
if (header.inputHeader() != null && header.inputHeader().initList() != null) {
visitor.visitInitList(header.inputHeader().initList());
}
if (header.configHeader() != null) {
if (hasConfig) {
throw new ParsingException("Multiple config headers are not allowed");
}
hasConfig = true;
int threadAmount = Integer.parseInt(header.configHeader().literanHeaderUnsignedInteger().get(0).getText());
int subGroupAmount = Integer.parseInt(header.configHeader().literanHeaderUnsignedInteger().get(1).getText());
int workGroupAmount = Integer.parseInt(header.configHeader().literanHeaderUnsignedInteger().get(2).getText());
grid = new ThreadGrid(threadAmount, subGroupAmount, workGroupAmount, 1);
}
}
return new ProgramBuilder(grid, visitor.getInputs());
}

@Override
public Program visitSpvHeaders(SpirvParser.SpvHeadersContext ctx) {
VisitorSpirvOutput visitor = new VisitorSpirvOutput(builder);
for (SpirvParser.SpvHeaderContext header : ctx.spvHeader()) {
if (header.outputHeader() != null && header.outputHeader().assertionList() != null) {
visitor.visitAssertionList(header.outputHeader().assertionList());
}
}
return null;
}

@Override
public Program visitSpvInstructions(SpirvParser.SpvInstructionsContext ctx) {
this.visitChildren(ctx);
return null;
}

@Override
public Program visitOp(SpirvParser.OpContext ctx) {
String name = parseOpName(ctx);
if (builder.getNextOps() != null) {
if (!builder.getNextOps().contains(name)) {
throw new ParsingException("Unexpected operation '%s'", name);
}
builder.clearNextOps();
}
SpirvBaseVisitor<?> visitor = visitors.get(name);
if (visitor == null) {
throw new ParsingException("Unsupported operation '%s'", name);
}
Object result = ctx.accept(visitor);
if (isSpecConstantOp(ctx)) {
if (result instanceof Register register) {
specConstantVisitor.visitOpSpecConstantOp(register);
} else {
throw new ParsingException(
"Illegal result type for OpSpecConstantOp '%s'", name);
}
}
return null;
private Set<Class<?>> getChildVisitors() {
return Set.of(
VisitorOpsAnnotation.class,
VisitorOpsArithmetic.class,
VisitorOpsAtomic.class,
VisitorOpsBarrier.class,
VisitorOpsBits.class,
VisitorOpsConstant.class,
VisitorOpsControlFlow.class,
VisitorOpsDebug.class,
VisitorOpsExtension.class,
VisitorOpsFunction.class,
VisitorOpsLogical.class,
VisitorOpsMemory.class,
VisitorOpsSetting.class,
VisitorOpsType.class
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,33 @@
import com.dat3m.dartagnan.expression.type.TypeFactory;
import com.dat3m.dartagnan.parsers.SpirvBaseVisitor;
import com.dat3m.dartagnan.parsers.SpirvParser;

import java.util.HashMap;
import java.util.Map;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.utils.ProgramBuilder;

public class VisitorSpirvInput extends SpirvBaseVisitor<Expression> {
private static final TypeFactory types = TypeFactory.getInstance();
private static final ExpressionFactory expressions = ExpressionFactory.getInstance();

private final Map<String, Expression> inputs = new HashMap<>();
private final ProgramBuilder builder;

public Map<String, Expression> getInputs() {
return inputs;
public VisitorSpirvInput(ProgramBuilder builder) {
this.builder = builder;
}

@Override
public Expression visitSpvHeaders(SpirvParser.SpvHeadersContext ctx) {
for (SpirvParser.SpvHeaderContext header : ctx.spvHeader()) {
if (header.inputHeader() != null && header.inputHeader().initList() != null) {
visitInitList(header.inputHeader().initList());
}
}
return null;
}

@Override
public Expression visitInit(SpirvParser.InitContext ctx) {
String varName = ctx.varName().getText();
Expression expr = visit(ctx.initValue());
addInput(varName, expr);
String id = ctx.varName().getText();
Expression value = visit(ctx.initValue());
builder.addInput(id, value);
return null;
}

Expand All @@ -45,11 +53,4 @@ public Expression visitInitCollectionValue(SpirvParser.InitCollectionValueContex
.map(this::visitInitValue)
.toList());
}

private void addInput(String name, Expression value) {
if (inputs.containsKey(name)) {
throw new ParsingException("Duplicated definition '%s'", name);
}
inputs.put(name, value);
}
}
Loading

0 comments on commit 9a02956

Please sign in to comment.