Skip to content

Commit

Permalink
Fixed spirv memory operands parser (#737)
Browse files Browse the repository at this point in the history
  • Loading branch information
natgavrilenko authored and hernan-poncedeleon committed Sep 16, 2024
1 parent 1de57bb commit 22d8a02
Show file tree
Hide file tree
Showing 59 changed files with 584 additions and 16,363 deletions.
18 changes: 10 additions & 8 deletions dartagnan/src/main/antlr4/Spirv.g4
Original file line number Diff line number Diff line change
Expand Up @@ -3230,14 +3230,16 @@ loopControl
| Unroll
;

memoryAccess
: AliasScopeINTELMask idRef
| Aligned literalInteger
| MakePointerAvailable idScope
| MakePointerAvailableKHR idScope
| MakePointerVisible idScope
| MakePointerVisibleKHR idScope
| NoAliasINTELMask idRef
memoryAccess : memoryAccessTag (Pipe memoryAccessTag)* literalInteger? idRef*;

memoryAccessTag
: AliasScopeINTELMask
| Aligned
| MakePointerAvailable
| MakePointerAvailableKHR
| MakePointerVisible
| MakePointerVisibleKHR
| NoAliasINTELMask
| NonPrivatePointer
| NonPrivatePointerKHR
| None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@
import com.dat3m.dartagnan.program.event.Event;
import com.dat3m.dartagnan.program.event.EventFactory;
import com.dat3m.dartagnan.program.event.Tag;
import org.antlr.v4.runtime.tree.ParseTree;
import org.antlr.v4.runtime.RuleContext;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

Expand Down Expand Up @@ -187,32 +186,14 @@ private void visitOpAccessChain(String id, String typeId, String baseId,
}

private Set<String> parseMemoryAccessTags(SpirvParser.MemoryAccessContext ctx) {
if (ctx == null || ctx.None() != null) {
return Set.of();
if (ctx != null) {
List<String> operands = ctx.memoryAccessTag().stream().map(RuleContext::getText).toList();
Integer alignment = ctx.literalInteger() != null ? Integer.parseInt(ctx.literalInteger().getText()) : null;
List<String> paramIds = ctx.idRef().stream().map(RuleContext::getText).toList();
List<Expression> paramsValues = ctx.idRef().stream().map(c -> builder.getExpression(c.getText())).toList();
return HelperTags.parseMemoryOperandsTags(operands, alignment, paramIds, paramsValues);
}
if (ctx.Volatile() != null) {
return Set.of(Tag.Spirv.MEM_VOLATILE);
}
if (ctx.Nontemporal() != null) {
return Set.of(Tag.Spirv.MEM_NON_TEMPORAL);
}
if (ctx.NonPrivatePointer() != null || ctx.NonPrivatePointerKHR() != null) {
return Set.of(Tag.Spirv.MEM_NON_PRIVATE);
}
if (ctx.idScope() != null) {
String scopeId = ctx.idScope().getText();
String scopeTag = HelperTags.parseScope(scopeId, builder.getExpression(scopeId));
Set<String> tags = new HashSet<>(Set.of(scopeTag, Tag.Spirv.MEM_NON_PRIVATE));
if (ctx.MakePointerAvailable() != null || ctx.MakePointerAvailableKHR() != null) {
tags.add(Tag.Spirv.MEM_AVAILABLE);
}
if (ctx.MakePointerVisible() != null || ctx.MakePointerVisibleKHR() != null) {
tags.add(Tag.Spirv.MEM_VISIBLE);
}
return tags;
}
throw new ParsingException("Unsupported memory access tag '%s'",
String.join(" ", ctx.children.stream().map(ParseTree::getText).toList()));
return Set.of();
}

private String getScope(String storageClass) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.dat3m.dartagnan.exception.ParsingException;
import com.dat3m.dartagnan.expression.Expression;
import com.dat3m.dartagnan.expression.integers.IntLiteral;
import com.dat3m.dartagnan.program.event.Tag;
import com.google.common.collect.Sets;

import java.util.*;
Expand Down Expand Up @@ -37,6 +38,72 @@ public static Set<String> parseMemorySemanticsTags(String id, Expression expr) {
return tags;
}

public static Set<String> parseMemoryOperandsTags(List<String> operands, Integer alignment,
List<String> paramIds, List<Expression> paramsValues) {
List<String> tagList = parseTagList(operands, alignment);
Set<String> tagSet = new HashSet<>(tagList);
if (tagList.size() != tagSet.size()) {
throwDuplicatesException(operands);
}
int i = 0;
for (String tag : List.of(Tag.Spirv.MEM_AVAILABLE, Tag.Spirv.MEM_VISIBLE)) {
if (tagSet.contains(tag)) {
if (paramIds.size() <= i) {
throwIllegalParametersException(operands);
}
String scopeTag = HelperTags.parseScope(paramIds.get(i), paramsValues.get(i));
tagSet.add(scopeTag);
i++;
}
}
if (i != paramsValues.size()) {
throwIllegalParametersException(operands);
}
// TODO: Implementation: this is a legal combination for OpCopyMemory and OpCopyMemorySized
if (tagSet.contains(Tag.Spirv.MEM_VISIBLE) && tagSet.contains(Tag.Spirv.MEM_AVAILABLE)) {
throw new ParsingException("Unsupported combination of memory operands '%s'",
String.join("|", operands));
}
return tagSet;
}

private static List<String> parseTagList(List<String> operands, Integer alignment) {
boolean isNone = false;
boolean isAligned = false;
List<String> tagList = new LinkedList<>();
for (String tag : operands) {
switch (tag) {
case "None" -> {
if (isNone) {
throwDuplicatesException(operands);
}
isNone = true;
}
case "Aligned" -> {
if (isAligned) {
throwDuplicatesException(operands);
}
isAligned = true;
}
case "Volatile" -> tagList.add(Tag.Spirv.MEM_VOLATILE);
case "Nontemporal" -> tagList.add(Tag.Spirv.MEM_NONTEMPORAL);
case "MakePointerAvailable", "MakePointerAvailableKHR" -> tagList.add(Tag.Spirv.MEM_AVAILABLE);
case "MakePointerVisible", "MakePointerVisibleKHR" -> tagList.add(Tag.Spirv.MEM_VISIBLE);
case "NonPrivatePointer", "NonPrivatePointerKHR" -> tagList.add(Tag.Spirv.MEM_NON_PRIVATE);
case "AliasScopeINTELMask", "NoAliasINTELMask" ->
throw new ParsingException("Unsupported memory operand '%s'", tag);
default -> throw new ParsingException("Unexpected memory operand '%s'", tag);
}
}
if (isNone && (isAligned || !tagList.isEmpty())) {
throw new ParsingException("Memory operand 'None' cannot be combined with other operands");
}
if (isAligned && alignment == null || !isAligned && alignment != null) {
throwIllegalParametersException(operands);
}
return tagList;
}

public static String parseScope(String id, Expression expr) {
int value = getIntValue(id, expr);
if (value >= 0 && value < scopes.size()) {
Expand All @@ -63,6 +130,16 @@ public static String parseStorageClass(String cls) {
};
}

private static void throwDuplicatesException(List<String> operands) {
throw new ParsingException("Duplicated memory operands definition(s) in '%s'",
String.join("|", operands));
}

private static void throwIllegalParametersException(List<String> operands) {
throw new ParsingException("Illegal parameter(s) in memory operands definition '%s'",
String.join("|", operands));
}

private static int getIntValue(String id, Expression expr) {
if (expr instanceof IntLiteral iValue) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ public static final class Spirv {

// Memory access (non-atomic)
public static final String MEM_VOLATILE = "SPV_MEM_VOLATILE";
public static final String MEM_NON_TEMPORAL = "SPV_MEM_NON_TEMPORAL";
public static final String MEM_NONTEMPORAL = "SPV_MEM_NONTEMPORAL";
public static final String MEM_NON_PRIVATE = "SPV_MEM_NON_PRIVATE";
public static final String MEM_AVAILABLE = "SPV_MEM_AVAILABLE";
public static final String MEM_VISIBLE = "SPV_MEM_VISIBLE";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ private String toVulkanTag(String tag) {

// Memory access (non-atomic)
case Tag.Spirv.MEM_VOLATILE,
Tag.Spirv.MEM_NON_TEMPORAL -> null;
Tag.Spirv.MEM_NONTEMPORAL -> null;
case Tag.Spirv.MEM_NON_PRIVATE -> Tag.Vulkan.NON_PRIVATE;
case Tag.Spirv.MEM_AVAILABLE -> Tag.Vulkan.AVAILABLE;
case Tag.Spirv.MEM_VISIBLE -> Tag.Vulkan.VISIBLE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,21 @@ public GrammarSpirvTest(String filename) {

@Parameterized.Parameters(name = "{index}: {0}")
public static Iterable<Object[]> data() throws IOException {
List<Object[]> data = new LinkedList<>();
listFiles(Paths.get(getTestResourcePath("parsers/program/spirv")), data);
listFiles(Paths.get(getTestResourcePath("spirv")), data);
return data;
return listFiles(Paths.get(getTestResourcePath("spirv")));
}

private static void listFiles(Path path, List<Object[]> result) throws IOException {
private static List<Object[]> listFiles(Path path) throws IOException {
List<Object[]> result = new LinkedList<>();
try (DirectoryStream<Path> files = Files.newDirectoryStream(path)) {
for (Path file : files) {
if (Files.isDirectory(file)) {
listFiles(file, result);
result.addAll(listFiles(file));
} else {
result.add(new Object[]{file.toAbsolutePath().toString()});
}
}
}
return result;
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,35 @@ public class ParserSpirvTest {

@Test
public void testParsingProgram() throws IOException {
String path = Paths.get(getTestResourcePath("parsers/program/spirv/valid/fibonacci.spv.dis")).toString();
doTestParsingValidProgram("fibonacci.spv.dis");
doTestParsingValidProgram("mp-memory-operands.spv.dis");
}

@Test
public void testInvalidControlFlow() throws IOException {
String error = "Unexpected operation 'OpLogicalNot'";
doTestParsingInvalidProgram("control-flow/malformed-selection-merge-label.spv.dis", error);
doTestParsingInvalidProgram("control-flow/malformed-selection-merge.spv.dis", error);
doTestParsingInvalidProgram("control-flow/malformed-loop-merge.spv.dis", error);
doTestParsingInvalidProgram("control-flow/malformed-loop-merge-true-label.spv.dis", error);
}

@Test
public void testInvalidMemoryOperands() throws IOException {
doTestParsingInvalidProgram("memory-operands/illegal-parameter-order-1.spv.dis", null);
doTestParsingInvalidProgram("memory-operands/illegal-parameter-order-2.spv.dis", null);
doTestParsingInvalidProgram("memory-operands/missing-alignment.spv.dis", null);
doTestParsingInvalidProgram("memory-operands/missing-scope-1.spv.dis", null);
doTestParsingInvalidProgram("memory-operands/missing-scope-2.spv.dis", null);
doTestParsingInvalidProgram("memory-operands/unnecessary-alignment-1.spv.dis", null);
doTestParsingInvalidProgram("memory-operands/unnecessary-alignment-2.spv.dis", null);
doTestParsingInvalidProgram("memory-operands/unnecessary-alignment-3.spv.dis", null);
doTestParsingInvalidProgram("memory-operands/unnecessary-scope-1.spv.dis", null);
doTestParsingInvalidProgram("memory-operands/unnecessary-scope-2.spv.dis", null);
}

private void doTestParsingValidProgram(String file) throws IOException {
String path = Paths.get(getTestResourcePath("parsers/program/spirv/valid/" + file)).toString();
try (FileInputStream stream = new FileInputStream(path)) {
CharStream charStream = CharStreams.fromStream(stream);
ParserSpirv parser = new ParserSpirv();
Expand All @@ -26,15 +54,7 @@ public void testParsingProgram() throws IOException {
}
}

@Test
public void testParsingInvalidProgram() throws IOException {
doTestParsingInvalidProgram("malformed-selection-merge-label.spv.dis");
doTestParsingInvalidProgram("malformed-selection-merge.spv.dis");
doTestParsingInvalidProgram("malformed-loop-merge.spv.dis");
doTestParsingInvalidProgram("malformed-loop-merge-true-label.spv.dis");
}

private void doTestParsingInvalidProgram(String file) throws IOException {
private void doTestParsingInvalidProgram(String file, String error) throws IOException {
String path = Paths.get(getTestResourcePath("parsers/program/spirv/invalid/" + file)).toString();
try (FileInputStream stream = new FileInputStream(path)) {
CharStream charStream = CharStreams.fromStream(stream);
Expand All @@ -43,7 +63,9 @@ private void doTestParsingInvalidProgram(String file) throws IOException {
parser.parse(charStream);
fail("Should throw exception");
} catch (ParsingException e) {
assertEquals("Unexpected operation 'OpLogicalNot'", e.getMessage());
if (error != null) {
assertEquals(error, e.getMessage());
}
}
}
}
Expand Down
Loading

0 comments on commit 22d8a02

Please sign in to comment.