From cf90baafdbceb8f188fc1301410bb6ca5c70d75c Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Wed, 2 Aug 2023 15:33:40 -0400 Subject: [PATCH 01/32] WIP: experimental recovergence tests * Partial framework implemented --- .../reconvergence/reconvergence.spec.ts | 692 ++++++++++++++++++ 1 file changed, 692 insertions(+) create mode 100644 src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts new file mode 100644 index 000000000000..5416c7851447 --- /dev/null +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -0,0 +1,692 @@ +export const description = `Experimental reconvergence tests based on the Vulkan reconvergence tests at: +https://github.com/KhronosGroup/VK-GL-CTS/blob/main/external/vulkancts/modules/vulkan/reconvergence/vktReconvergenceTests.cpp`; + +import { makeTestGroup } from '../../../../common/framework/test_group.js'; +import { GPUTest } from '../../../gpu_test.js'; +import { assert, unreachable } from '../../../../common/util/util.js'; +import { PRNG } from '../../../util/prng.js'; + +export const g = makeTestGroup(GPUTest); + +// Returns a bitmask where bits [0,size) are 1s. +function getMask(size: number): bigint { + return (1n << BigInt(size)) - 1n; +} + +// Returns a bitmask where submask is repeated every size bits for total bits. +function getReplicatedMask(submask: bigint, size: number, total: number = 128): bigint { + const reps = total / size; + var mask: bigint = submask; + for (var i = 1; i < reps; i++) { + mask |= (mask << BigInt(size)); + } + return mask; +} + +function any(value: bigint): boolean { + return value !== 0n; +} + +function all(value: bigint, stride: number): boolean { + return value === ((1n << BigInt(stride) - 1n)); +} + +enum Style { + Workgroup, + Subgroup, + Maximal, +}; + +enum OpType { + OpBallot, + + OpStore, + + OpIfMask, + OpElseMask, + OpEndIf, + + OpIfLoopCount, + OpElseLoopCount, + + OpIfLid, + OpElseLid, + + OpBreak, + OpContinue, + + OpForUniform, + OpEndForUniform, + + OpReturn, + + OpMAX, +} + +enum IfType { + IfMask, + IfUniform, + IfLoopCount, + IfLid, +}; + +class Op { + op : OpType; + value : number; + caseValue : number; + + constructor(op : OpType, value: number = 0, caseValue: number = 0) { + this.op = op; + this.value = value; + this.caseValue = caseValue; + } +}; + +class Program { + prng: PRNG; + ops : Op[]; + style: Style; + minCount: number; + maxNesting: number; + nesting: number; + loopNesting: number; + loopNestingThisFunction: number; + numMasks: number; + masks: number[]; + curFunc: number; + functions: string[]; + indents: number[]; + + constructor(style : Style = Style.Workgroup, seed: number = 1) { + this.prng = new PRNG(seed); + this.ops = []; + this.style = style; + this.minCount = 5; // 30; + this.maxNesting = 5; // this.getRandomUint(70) + 30; // [30,100) + this.nesting = 0; + this.loopNesting = 0; + this.loopNestingThisFunction = 0; + this.numMasks = 10; + this.masks = []; + this.masks.push(0xffffffff); + this.masks.push(0xffffffff); + this.masks.push(0xffffffff); + this.masks.push(0xffffffff); + for (var i = 1; i < this.numMasks; i++) { + this.masks.push(this.getRandomUint(0xffffffff)); + this.masks.push(this.getRandomUint(0xffffffff)); + this.masks.push(this.getRandomUint(0xffffffff)); + this.masks.push(this.getRandomUint(0xffffffff)); + } + this.curFunc = 0; + this.functions = []; + this.functions.push(``); + this.indents = []; + this.indents.push(2); + } + + getRandomFloat(): number { + return this.prng.random(); + } + + getRandomUint(max: number): number { + return this.prng.randomU32() % max; + } + + pickOp(count : number) { + for (var i = 0; i < count; i++) { + //optBallot(); + if (this.nesting < this.maxNesting) { + const r = this.getRandomUint(12); + switch (r) { + case 0: { + if (this.loopNesting > 0) { + this.genIf(IfType.IfLoopCount); + break; + } + this.genIf(IfType.IfLid); + break; + } + case 1: { + this.genIf(IfType.IfLid); + break; + } + case 2: { + this.genIf(IfType.IfMask); + break; + } + case 3: { + this.genIf(IfType.IfUniform); + break; + } + case 4: { + if (this.loopNesting <= 3) { + const r2 = this.getRandomUint(3); + switch (r2) { + case 0: this.genForUniform(); break; + case 2: + default: { + break; + } + } + } + break; + } + case 5: { + this.genBreak(); + break; + } + case 6: { + this.genContinue(); + break; + } + default: { + break; + } + } + } + } + } + + genIf(type: IfType) { + let maskIdx = this.getRandomUint(this.numMasks); + if (type == IfType.IfUniform) + maskIdx = 0; + + const lid = this.getRandomUint(128); + if (type == IfType.IfLid) { + this.ops.push(new Op(OpType.OpIfLid, lid)); + } else if (type == IfType.IfLoopCount) { + this.ops.push(new Op(OpType.OpIfLoopCount, 0)); + } else { + this.ops.push(new Op(OpType.OpIfMask, maskIdx)); + } + + this.nesting++; + + let beforeSize = this.ops.length; + this.pickOp(2); + let afterSize = this.ops.length; + + const randElse = this.getRandomFloat(); + if (randElse < 0.5) { + if (type == IfType.IfLid) { + this.ops.push(new Op(OpType.OpElseLid, lid)); + } else if (type == IfType.IfLoopCount) { + this.ops.push(new Op(OpType.OpElseLoopCount, 0)); + } else { + this.ops.push(new Op(OpType.OpElseMask, maskIdx)); + } + + // Sometimes make the else identical to the if. + if (randElse < 0.1 && beforeSize != afterSize) { + for (var i = beforeSize; i < afterSize; i++) { + const op = this.ops[i]; + this.ops.push(new Op(op.op, op.value, op.caseValue)); + } + } else { + this.pickOp(2); + } + } + this.ops.push(new Op(OpType.OpEndIf, 0)); + + this.nesting--; + } + + genForUniform() { + const n = this.getRandomUint(5) + 1; // [1, 5] + this.ops.push(new Op(OpType.OpForUniform, n)); + const header = this.ops.length - 1; + this.nesting++; + this.loopNesting++; + this.loopNestingThisFunction++; + this.pickOp(2); + this.ops.push(new Op(OpType.OpEndForUniform, header)); + this.loopNestingThisFunction--; + this.loopNesting--; + this.nesting--; + } + + genBreak() { + if (this.loopNestingThisFunction > 0) + { + // Sometimes put the break in a divergent if + if (this.getRandomFloat() < 0.1) { + const r = this.getRandomUint(this.numMasks-1) + 1; + this.ops.push(new Op(OpType.OpIfMask, r)); + this.ops.push(new Op(OpType.OpBreak, 0)); + this.ops.push(new Op(OpType.OpElseMask, r)); + this.ops.push(new Op(OpType.OpBreak, 0)); + this.ops.push(new Op(OpType.OpEndIf, 0)); + } else { + this.ops.push(new Op(OpType.OpBreak, 0)); + } + } + } + + genContinue() { + } + + genCode(): string { + for (var i = 0; i < this.ops.length; i++) { + const op = this.ops[i]; + this.genIndent() + this.addCode(`// ops[${i}] = ${op.op}\n`); + switch (op.op) { + //case OpType.OpBallot: { + // break; + //} + //case OpType.OpStore: { + // break; + //} + default: { + this.genIndent(); + this.addCode(`/* missing op ${op.op} */\n`); + break; + } + case OpType.OpIfMask: { + this.genIndent(); + if (op.value == 0) { + const idx = this.getRandomUint(4); + this.addCode(`if inputs[${idx}] == ${idx} {\n`); + } else { + const idx = op.value; + const x = this.masks[4*idx]; + const y = this.masks[4*idx+1]; + const z = this.masks[4*idx+2]; + const w = this.masks[4*idx+3]; + this.addCode(`if testBit(vec4u(${x},${y},${z},${w}), subgroup_id) {\n`); + } + this.increaseIndent(); + break; + } + case OpType.OpIfLid: { + this.genIndent(); + this.addCode(`if subgroup_id < inputs[${op.value}] {\n`); + this.increaseIndent(); + break; + } + case OpType.OpIfLoopCount: { + this.genIndent(); + this.addCode(`if subgroup_id == i${this.loopNesting-1} {\n`); + this.increaseIndent(); + break; + } + case OpType.OpElseMask: + case OpType.OpElseLid: + case OpType.OpElseLoopCount: { + this.decreaseIndent(); + this.genIndent(); + this.addCode(`} else {\n`); + this.increaseIndent(); + break; + } + case OpType.OpEndIf: { + this.decreaseIndent(); + this.genIndent(); + this.addCode(`}\n`); + break; + } + case OpType.OpForUniform: { + this.genIndent(); + const iter = `i${this.loopNesting}`; + this.addCode(`for (var ${iter} = 0u; ${iter} < inputs[${op.value}]; ${iter}++) {\n`); + this.increaseIndent(); + this.loopNesting++; + break; + } + case OpType.OpEndForUniform: { + this.loopNesting--; + this.decreaseIndent(); + this.genIndent(); + this.addCode(`}\n`); + break; + } + case OpType.OpBreak: { + this.genIndent(); + this.addCode(`break;\n`); + break; + } + } + } + + let code = ``; + for (var i = 0; i < this.functions.length; i++) { + code += ` +fn f${i}() { +${this.functions[i]} +} +`; + } + return code; + } + + genIndent() { + this.functions[this.curFunc] += ' '.repeat(this.indents[this.curFunc]); + } + increaseIndent() { + this.indents[this.curFunc] += 2; + } + decreaseIndent() { + this.indents[this.curFunc] -= 2; + } + addCode(code: string) { + this.functions[this.curFunc] += code; + } + + simulate(countOnly: boolean, size: number, stride: number = 128): number { + class State { + activeMask: bigint; + continueMask: bigint; + header: number; + isLoop: boolean; + tripCount: number; + isCall: boolean; + isSwitch: boolean; + + constructor() { + this.activeMask = 0n; + this.continueMask = 0n; + this.header = 0; + this.isLoop = false; + this.tripCount = 0; + this.isCall = false; + this.isSwitch = false; + } + + copy(other: State) { + this.activeMask = other.activeMask; + this.continueMask = other.continueMask; + this.header = other.header; + this.isLoop = other.isLoop; + this.tripCount = other.tripCount; + this.isCall = other.isCall; + this.isSwitch = other.isSwitch; + } + }; + var stack = new Array(); + stack.push(new State()); + stack[0].activeMask = (1n << 128n) - 1n; + //for (var i = 0; i < 10; i++) { + // stack[i] = new State(); + //} + //stack[0].activeMask = (1n << 128n) - 1n; + + var nesting = 0; + var loopNesting = 0; + var locs = new Array(stride); + locs.fill(0); + + var i = 0; + while (i < this.ops.length) { + const op = this.ops[i]; + console.log(`ops[${i}] = ${op.op}, nesting = ${nesting}`); + console.log(` mask = ${stack[nesting].activeMask.toString(16)}`); + //for (var j = 0; j <= nesting; j++) { + // console.log(` mask[${j}] = ${stack[j].activeMask.toString(16)}`); + //} + switch (op.op) { + case OpType.OpIfMask: { + nesting++; + stack.push(new State()); + const cur = stack[nesting]; + cur.copy(stack[nesting-1]); + cur.header = i; + // O is always uniform true. + if (op.value != 0) { + cur.activeMask &= this.getValueMask(op.value); + } + break; + } + case OpType.OpElseMask: { + // 0 is always uniform true so the else will never be taken. + const cur = stack[nesting]; + if (op.value == 0) { + cur.activeMask = 0n; + } else { + const prev = stack[nesting-1]; + cur.activeMask = prev.activeMask & ~this.getValueMask(op.value); + } + break; + } + case OpType.OpIfLid: { + nesting++; + stack.push(new State()); + const cur = stack[nesting]; + cur.copy(stack[nesting-1]); + cur.header = i; + // All invocations with subgroup invocation id less than op.value are active. + cur.activeMask &= getReplicatedMask(getMask(op.value), size, stride); + break; + } + case OpType.OpElseLid: { + const prev = stack[nesting-1]; + // All invocations with a subgroup invocation id greater or equal to op.value are active. + stack[nesting].activeMask = prev.activeMask; + stack[nesting].activeMask &= ~getReplicatedMask(getMask(op.value), size, stride); + break; + } + case OpType.OpIfLoopCount: { + let n = nesting; + while (!stack[n].isLoop) { + n--; + } + + nesting++; + stack.push(new State()); + const cur = stack[nesting]; + cur.copy(stack[nesting-1]); + cur.header = i; + cur.isLoop = 0; + cur.isSwitch = 0; + cur.activeMask &= getReplicatedMask(BigInt(1 << stack[n].tripCount), size, stride); + break; + } + case OpType.OpElseLoopCount: { + let n = nesting; + while (!stack[n].isLoop) { + n--; + } + + stack[nesting].activeMask = stack[nesting-1].activeMask; + stack[nesting].activeMask &= ~getReplicatedMask(BigInt(1 << stack[n].tripCount), size, stride); + break; + } + case OpType.OpEndIf: { + nesting--; + stack.pop(); + break; + } + case OpType.OpForUniform: { + nesting++; + loopNesting++; + stack.push(new State()); + const cur = stack[nesting]; + cur.header = i; + cur.isLoop = true; + cur.activeMask = stack[nesting-1].activeMask; + break; + } + case OpType.OpEndForUniform: { + const cur = stack[nesting]; + cur.tripCount++; + cur.activeMask |= stack[nesting].continueMask; + cur.continueMask = 0n; + // Loop if there are any invocations left with iterations to perform. + if (cur.tripCount < this.ops[cur.header].value && + any(cur.activeMask)) { + i = cur.header + 1; + continue; + } else { + loopNesting--; + nesting--; + stack.pop(); + } + break; + } + case OpType.OpBreak: { + var n = nesting; + var mask: bigint = stack[nesting].activeMask; + while (true) { + stack[n].activeMask &= ~mask; + if (stack[n].isLoop || stack[n].isSwitch) { + break; + } + + n--; + } + break; + } + default: { + unreachable(`Unhandled op ${op.op}`); + } + } + i++; + } + + assert(stack.length == 1); + + var maxLoc = 0; + for (var j = 0; j < stride; j++) { + maxLoc = Math.max(maxLoc, locs[j]); + } + return maxLoc; + } + + // Returns an active mask for the mask at the given index. + getValueMask(idx: number): bigint { + const x = this.masks[4*idx]; + const y = this.masks[4*idx+1]; + const z = this.masks[4*idx+2]; + const w = this.masks[4*idx+3]; + var mask: bigint = 0n; + mask |= BigInt(x); + mask |= BigInt(y) << 32n; + mask |= BigInt(z) << 64n; + mask |= BigInt(w) << 96n; + return mask; + } +}; + +function generateProgram(program: Program): string { + while (program.ops.length < program.minCount) { + program.pickOp(1); + } + + return program.genCode(); +}; + +function generateSeeds(numCases: number): number[] { + var prng: PRNG = new PRNG(1); + var output: number[] = new Array(numCases); + for (var i = 0; i < numCases; i++) { + output[i] = prng.randomU32(); + } + return output; +} + +g.test('reconvergence') + .desc(`Test reconvergence`) + .params(u => + u + .combine('style', [Style.Workgroup, Style.Subgroup, Style.Maximal] as const) + .combine('seed', generateSeeds(5)) + .filter(u => { + if (u.style == Style.Workgroup) { + return true; + } + return false; + }) + .beginSubcases() + ) + .fn(t => { + const invocations = 128; // t.device.limits.maxSubgroupSize; + + let wgsl = ` +//enable chromium_experimental_subgroups; + +const stride = ${invocations}; + +@group(0) @binding(0) +var inputs : array; +@group(0) @binding(1) +var ballots : array; + +var subgroup_id : u32; + +@compute @workgroup_size(${invocations},1,1) +fn main( + //@builtin(local_invocation_index) id : u32, +) { + _ = inputs[0]; + _ = ballots[0]; + subgroup_id = 0; // id; + + f0(); +} + +fn testBit(mask : vec4u, id : u32) -> bool { + let xbit = extractBits(mask.x, id, 1); + let ybit = extractBits(mask.y, id - 32, 1); + let zbit = extractBits(mask.z, id - 64, 1); + let wbit = extractBits(mask.w, id - 96, 1); + let lt32 = id < 32; + let lt64 = id < 64; + let lt96 = id < 96; + let sela = select(wbit, xbit, lt96); + let selb = select(zbit, ybit, lt64); + return select(selb, sela, lt32) == 1; +} +`; + + let program : Program = new Program(t.params.style, t.params.seed); + wgsl += generateProgram(program); + console.log(wgsl); + + const num = program.simulate(true, 16, invocations); + + const pipeline = t.device.createComputePipeline({ + layout: 'auto', + compute: { + module: t.device.createShaderModule({ + code: wgsl, + }), + entryPoint: 'main', + }, + }); + + //// Helper to create a `size`-byte buffer with binding number `binding`. + //function createBuffer(size: number, binding: number) { + // const buffer = t.device.createBuffer({ + // size, + // usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC, + // }); + // t.trackForCleanup(buffer); + + // bindGroupEntries.push({ + // binding, + // resource: { + // buffer, + // }, + // }); + + // return buffer; + //} + + //const bindGroupEntries: GPUBindGroupEntry[] = []; + //const inputBuffer = createBuffer(16, 0); + //const ballotBuffer = createBuffer(16, 1); + + //const bindGroup = t.device.createBindGroup({ + // layout: pipeline.getBindGroupLayout(0), + // entries: bindGroupEntries, + //}); + + //const encoder = t.device.createCommandEncoder(); + //const pass = encoder.beginComputePass(); + //pass.setPipeline(pipeline); + //pass.setBindGroup(0, bindGroup); + //pass.dispatchWorkgroups(1,1,1); + //pass.end(); + //t.queue.submit([encoder.finish()]); + }); From a25b6dc54d1cd6e47ba4c0ea196b966fe3d716d9 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Wed, 9 Aug 2023 12:12:20 -0400 Subject: [PATCH 02/32] Refactoring and implementation * Added more ops * Moved infrastructure into a utility file --- .../reconvergence/reconvergence.spec.ts | 623 +-------------- .../shader/execution/reconvergence/util.ts | 717 ++++++++++++++++++ 2 files changed, 721 insertions(+), 619 deletions(-) create mode 100644 src/webgpu/shader/execution/reconvergence/util.ts diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index 5416c7851447..9ea786b91cc5 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -3,588 +3,10 @@ https://github.com/KhronosGroup/VK-GL-CTS/blob/main/external/vulkancts/modules/v import { makeTestGroup } from '../../../../common/framework/test_group.js'; import { GPUTest } from '../../../gpu_test.js'; -import { assert, unreachable } from '../../../../common/util/util.js'; -import { PRNG } from '../../../util/prng.js'; +import { Style, Program, generateSeeds } from './util.js' export const g = makeTestGroup(GPUTest); -// Returns a bitmask where bits [0,size) are 1s. -function getMask(size: number): bigint { - return (1n << BigInt(size)) - 1n; -} - -// Returns a bitmask where submask is repeated every size bits for total bits. -function getReplicatedMask(submask: bigint, size: number, total: number = 128): bigint { - const reps = total / size; - var mask: bigint = submask; - for (var i = 1; i < reps; i++) { - mask |= (mask << BigInt(size)); - } - return mask; -} - -function any(value: bigint): boolean { - return value !== 0n; -} - -function all(value: bigint, stride: number): boolean { - return value === ((1n << BigInt(stride) - 1n)); -} - -enum Style { - Workgroup, - Subgroup, - Maximal, -}; - -enum OpType { - OpBallot, - - OpStore, - - OpIfMask, - OpElseMask, - OpEndIf, - - OpIfLoopCount, - OpElseLoopCount, - - OpIfLid, - OpElseLid, - - OpBreak, - OpContinue, - - OpForUniform, - OpEndForUniform, - - OpReturn, - - OpMAX, -} - -enum IfType { - IfMask, - IfUniform, - IfLoopCount, - IfLid, -}; - -class Op { - op : OpType; - value : number; - caseValue : number; - - constructor(op : OpType, value: number = 0, caseValue: number = 0) { - this.op = op; - this.value = value; - this.caseValue = caseValue; - } -}; - -class Program { - prng: PRNG; - ops : Op[]; - style: Style; - minCount: number; - maxNesting: number; - nesting: number; - loopNesting: number; - loopNestingThisFunction: number; - numMasks: number; - masks: number[]; - curFunc: number; - functions: string[]; - indents: number[]; - - constructor(style : Style = Style.Workgroup, seed: number = 1) { - this.prng = new PRNG(seed); - this.ops = []; - this.style = style; - this.minCount = 5; // 30; - this.maxNesting = 5; // this.getRandomUint(70) + 30; // [30,100) - this.nesting = 0; - this.loopNesting = 0; - this.loopNestingThisFunction = 0; - this.numMasks = 10; - this.masks = []; - this.masks.push(0xffffffff); - this.masks.push(0xffffffff); - this.masks.push(0xffffffff); - this.masks.push(0xffffffff); - for (var i = 1; i < this.numMasks; i++) { - this.masks.push(this.getRandomUint(0xffffffff)); - this.masks.push(this.getRandomUint(0xffffffff)); - this.masks.push(this.getRandomUint(0xffffffff)); - this.masks.push(this.getRandomUint(0xffffffff)); - } - this.curFunc = 0; - this.functions = []; - this.functions.push(``); - this.indents = []; - this.indents.push(2); - } - - getRandomFloat(): number { - return this.prng.random(); - } - - getRandomUint(max: number): number { - return this.prng.randomU32() % max; - } - - pickOp(count : number) { - for (var i = 0; i < count; i++) { - //optBallot(); - if (this.nesting < this.maxNesting) { - const r = this.getRandomUint(12); - switch (r) { - case 0: { - if (this.loopNesting > 0) { - this.genIf(IfType.IfLoopCount); - break; - } - this.genIf(IfType.IfLid); - break; - } - case 1: { - this.genIf(IfType.IfLid); - break; - } - case 2: { - this.genIf(IfType.IfMask); - break; - } - case 3: { - this.genIf(IfType.IfUniform); - break; - } - case 4: { - if (this.loopNesting <= 3) { - const r2 = this.getRandomUint(3); - switch (r2) { - case 0: this.genForUniform(); break; - case 2: - default: { - break; - } - } - } - break; - } - case 5: { - this.genBreak(); - break; - } - case 6: { - this.genContinue(); - break; - } - default: { - break; - } - } - } - } - } - - genIf(type: IfType) { - let maskIdx = this.getRandomUint(this.numMasks); - if (type == IfType.IfUniform) - maskIdx = 0; - - const lid = this.getRandomUint(128); - if (type == IfType.IfLid) { - this.ops.push(new Op(OpType.OpIfLid, lid)); - } else if (type == IfType.IfLoopCount) { - this.ops.push(new Op(OpType.OpIfLoopCount, 0)); - } else { - this.ops.push(new Op(OpType.OpIfMask, maskIdx)); - } - - this.nesting++; - - let beforeSize = this.ops.length; - this.pickOp(2); - let afterSize = this.ops.length; - - const randElse = this.getRandomFloat(); - if (randElse < 0.5) { - if (type == IfType.IfLid) { - this.ops.push(new Op(OpType.OpElseLid, lid)); - } else if (type == IfType.IfLoopCount) { - this.ops.push(new Op(OpType.OpElseLoopCount, 0)); - } else { - this.ops.push(new Op(OpType.OpElseMask, maskIdx)); - } - - // Sometimes make the else identical to the if. - if (randElse < 0.1 && beforeSize != afterSize) { - for (var i = beforeSize; i < afterSize; i++) { - const op = this.ops[i]; - this.ops.push(new Op(op.op, op.value, op.caseValue)); - } - } else { - this.pickOp(2); - } - } - this.ops.push(new Op(OpType.OpEndIf, 0)); - - this.nesting--; - } - - genForUniform() { - const n = this.getRandomUint(5) + 1; // [1, 5] - this.ops.push(new Op(OpType.OpForUniform, n)); - const header = this.ops.length - 1; - this.nesting++; - this.loopNesting++; - this.loopNestingThisFunction++; - this.pickOp(2); - this.ops.push(new Op(OpType.OpEndForUniform, header)); - this.loopNestingThisFunction--; - this.loopNesting--; - this.nesting--; - } - - genBreak() { - if (this.loopNestingThisFunction > 0) - { - // Sometimes put the break in a divergent if - if (this.getRandomFloat() < 0.1) { - const r = this.getRandomUint(this.numMasks-1) + 1; - this.ops.push(new Op(OpType.OpIfMask, r)); - this.ops.push(new Op(OpType.OpBreak, 0)); - this.ops.push(new Op(OpType.OpElseMask, r)); - this.ops.push(new Op(OpType.OpBreak, 0)); - this.ops.push(new Op(OpType.OpEndIf, 0)); - } else { - this.ops.push(new Op(OpType.OpBreak, 0)); - } - } - } - - genContinue() { - } - - genCode(): string { - for (var i = 0; i < this.ops.length; i++) { - const op = this.ops[i]; - this.genIndent() - this.addCode(`// ops[${i}] = ${op.op}\n`); - switch (op.op) { - //case OpType.OpBallot: { - // break; - //} - //case OpType.OpStore: { - // break; - //} - default: { - this.genIndent(); - this.addCode(`/* missing op ${op.op} */\n`); - break; - } - case OpType.OpIfMask: { - this.genIndent(); - if (op.value == 0) { - const idx = this.getRandomUint(4); - this.addCode(`if inputs[${idx}] == ${idx} {\n`); - } else { - const idx = op.value; - const x = this.masks[4*idx]; - const y = this.masks[4*idx+1]; - const z = this.masks[4*idx+2]; - const w = this.masks[4*idx+3]; - this.addCode(`if testBit(vec4u(${x},${y},${z},${w}), subgroup_id) {\n`); - } - this.increaseIndent(); - break; - } - case OpType.OpIfLid: { - this.genIndent(); - this.addCode(`if subgroup_id < inputs[${op.value}] {\n`); - this.increaseIndent(); - break; - } - case OpType.OpIfLoopCount: { - this.genIndent(); - this.addCode(`if subgroup_id == i${this.loopNesting-1} {\n`); - this.increaseIndent(); - break; - } - case OpType.OpElseMask: - case OpType.OpElseLid: - case OpType.OpElseLoopCount: { - this.decreaseIndent(); - this.genIndent(); - this.addCode(`} else {\n`); - this.increaseIndent(); - break; - } - case OpType.OpEndIf: { - this.decreaseIndent(); - this.genIndent(); - this.addCode(`}\n`); - break; - } - case OpType.OpForUniform: { - this.genIndent(); - const iter = `i${this.loopNesting}`; - this.addCode(`for (var ${iter} = 0u; ${iter} < inputs[${op.value}]; ${iter}++) {\n`); - this.increaseIndent(); - this.loopNesting++; - break; - } - case OpType.OpEndForUniform: { - this.loopNesting--; - this.decreaseIndent(); - this.genIndent(); - this.addCode(`}\n`); - break; - } - case OpType.OpBreak: { - this.genIndent(); - this.addCode(`break;\n`); - break; - } - } - } - - let code = ``; - for (var i = 0; i < this.functions.length; i++) { - code += ` -fn f${i}() { -${this.functions[i]} -} -`; - } - return code; - } - - genIndent() { - this.functions[this.curFunc] += ' '.repeat(this.indents[this.curFunc]); - } - increaseIndent() { - this.indents[this.curFunc] += 2; - } - decreaseIndent() { - this.indents[this.curFunc] -= 2; - } - addCode(code: string) { - this.functions[this.curFunc] += code; - } - - simulate(countOnly: boolean, size: number, stride: number = 128): number { - class State { - activeMask: bigint; - continueMask: bigint; - header: number; - isLoop: boolean; - tripCount: number; - isCall: boolean; - isSwitch: boolean; - - constructor() { - this.activeMask = 0n; - this.continueMask = 0n; - this.header = 0; - this.isLoop = false; - this.tripCount = 0; - this.isCall = false; - this.isSwitch = false; - } - - copy(other: State) { - this.activeMask = other.activeMask; - this.continueMask = other.continueMask; - this.header = other.header; - this.isLoop = other.isLoop; - this.tripCount = other.tripCount; - this.isCall = other.isCall; - this.isSwitch = other.isSwitch; - } - }; - var stack = new Array(); - stack.push(new State()); - stack[0].activeMask = (1n << 128n) - 1n; - //for (var i = 0; i < 10; i++) { - // stack[i] = new State(); - //} - //stack[0].activeMask = (1n << 128n) - 1n; - - var nesting = 0; - var loopNesting = 0; - var locs = new Array(stride); - locs.fill(0); - - var i = 0; - while (i < this.ops.length) { - const op = this.ops[i]; - console.log(`ops[${i}] = ${op.op}, nesting = ${nesting}`); - console.log(` mask = ${stack[nesting].activeMask.toString(16)}`); - //for (var j = 0; j <= nesting; j++) { - // console.log(` mask[${j}] = ${stack[j].activeMask.toString(16)}`); - //} - switch (op.op) { - case OpType.OpIfMask: { - nesting++; - stack.push(new State()); - const cur = stack[nesting]; - cur.copy(stack[nesting-1]); - cur.header = i; - // O is always uniform true. - if (op.value != 0) { - cur.activeMask &= this.getValueMask(op.value); - } - break; - } - case OpType.OpElseMask: { - // 0 is always uniform true so the else will never be taken. - const cur = stack[nesting]; - if (op.value == 0) { - cur.activeMask = 0n; - } else { - const prev = stack[nesting-1]; - cur.activeMask = prev.activeMask & ~this.getValueMask(op.value); - } - break; - } - case OpType.OpIfLid: { - nesting++; - stack.push(new State()); - const cur = stack[nesting]; - cur.copy(stack[nesting-1]); - cur.header = i; - // All invocations with subgroup invocation id less than op.value are active. - cur.activeMask &= getReplicatedMask(getMask(op.value), size, stride); - break; - } - case OpType.OpElseLid: { - const prev = stack[nesting-1]; - // All invocations with a subgroup invocation id greater or equal to op.value are active. - stack[nesting].activeMask = prev.activeMask; - stack[nesting].activeMask &= ~getReplicatedMask(getMask(op.value), size, stride); - break; - } - case OpType.OpIfLoopCount: { - let n = nesting; - while (!stack[n].isLoop) { - n--; - } - - nesting++; - stack.push(new State()); - const cur = stack[nesting]; - cur.copy(stack[nesting-1]); - cur.header = i; - cur.isLoop = 0; - cur.isSwitch = 0; - cur.activeMask &= getReplicatedMask(BigInt(1 << stack[n].tripCount), size, stride); - break; - } - case OpType.OpElseLoopCount: { - let n = nesting; - while (!stack[n].isLoop) { - n--; - } - - stack[nesting].activeMask = stack[nesting-1].activeMask; - stack[nesting].activeMask &= ~getReplicatedMask(BigInt(1 << stack[n].tripCount), size, stride); - break; - } - case OpType.OpEndIf: { - nesting--; - stack.pop(); - break; - } - case OpType.OpForUniform: { - nesting++; - loopNesting++; - stack.push(new State()); - const cur = stack[nesting]; - cur.header = i; - cur.isLoop = true; - cur.activeMask = stack[nesting-1].activeMask; - break; - } - case OpType.OpEndForUniform: { - const cur = stack[nesting]; - cur.tripCount++; - cur.activeMask |= stack[nesting].continueMask; - cur.continueMask = 0n; - // Loop if there are any invocations left with iterations to perform. - if (cur.tripCount < this.ops[cur.header].value && - any(cur.activeMask)) { - i = cur.header + 1; - continue; - } else { - loopNesting--; - nesting--; - stack.pop(); - } - break; - } - case OpType.OpBreak: { - var n = nesting; - var mask: bigint = stack[nesting].activeMask; - while (true) { - stack[n].activeMask &= ~mask; - if (stack[n].isLoop || stack[n].isSwitch) { - break; - } - - n--; - } - break; - } - default: { - unreachable(`Unhandled op ${op.op}`); - } - } - i++; - } - - assert(stack.length == 1); - - var maxLoc = 0; - for (var j = 0; j < stride; j++) { - maxLoc = Math.max(maxLoc, locs[j]); - } - return maxLoc; - } - - // Returns an active mask for the mask at the given index. - getValueMask(idx: number): bigint { - const x = this.masks[4*idx]; - const y = this.masks[4*idx+1]; - const z = this.masks[4*idx+2]; - const w = this.masks[4*idx+3]; - var mask: bigint = 0n; - mask |= BigInt(x); - mask |= BigInt(y) << 32n; - mask |= BigInt(z) << 64n; - mask |= BigInt(w) << 96n; - return mask; - } -}; - -function generateProgram(program: Program): string { - while (program.ops.length < program.minCount) { - program.pickOp(1); - } - - return program.genCode(); -}; - -function generateSeeds(numCases: number): number[] { - var prng: PRNG = new PRNG(1); - var output: number[] = new Array(numCases); - for (var i = 0; i < numCases; i++) { - output[i] = prng.randomU32(); - } - return output; -} - g.test('reconvergence') .desc(`Test reconvergence`) .params(u => @@ -602,48 +24,11 @@ g.test('reconvergence') .fn(t => { const invocations = 128; // t.device.limits.maxSubgroupSize; - let wgsl = ` -//enable chromium_experimental_subgroups; - -const stride = ${invocations}; - -@group(0) @binding(0) -var inputs : array; -@group(0) @binding(1) -var ballots : array; - -var subgroup_id : u32; - -@compute @workgroup_size(${invocations},1,1) -fn main( - //@builtin(local_invocation_index) id : u32, -) { - _ = inputs[0]; - _ = ballots[0]; - subgroup_id = 0; // id; - - f0(); -} - -fn testBit(mask : vec4u, id : u32) -> bool { - let xbit = extractBits(mask.x, id, 1); - let ybit = extractBits(mask.y, id - 32, 1); - let zbit = extractBits(mask.z, id - 64, 1); - let wbit = extractBits(mask.w, id - 96, 1); - let lt32 = id < 32; - let lt64 = id < 64; - let lt96 = id < 96; - let sela = select(wbit, xbit, lt96); - let selb = select(zbit, ybit, lt64); - return select(selb, sela, lt32) == 1; -} -`; - - let program : Program = new Program(t.params.style, t.params.seed); - wgsl += generateProgram(program); + let program: Program = new Program(t.params.style, t.params.seed, invocations); + let wgsl = program.generate(); console.log(wgsl); - const num = program.simulate(true, 16, invocations); + const num = program.simulate(true, 16); const pipeline = t.device.createComputePipeline({ layout: 'auto', diff --git a/src/webgpu/shader/execution/reconvergence/util.ts b/src/webgpu/shader/execution/reconvergence/util.ts new file mode 100644 index 000000000000..5661da62a931 --- /dev/null +++ b/src/webgpu/shader/execution/reconvergence/util.ts @@ -0,0 +1,717 @@ +import { assert, unreachable } from '../../../../common/util/util.js'; +import { PRNG } from '../../../util/prng.js'; + +/** @returns A bitmask where bits [0,size) are 1s. */ +function getMask(size: number): bigint { + return (1n << BigInt(size)) - 1n; +} + +/** @returns A bitmask where submask is repeated every size bits for total bits. */ +function getReplicatedMask(submask: bigint, size: number, total: number = 128): bigint { + const reps = total / size; + var mask: bigint = submask; + for (var i = 1; i < reps; i++) { + mask |= (mask << BigInt(size)); + } + return mask; +} + +/** @returns true if any bit in value is 1. */ +function any(value: bigint): boolean { + return value !== 0n; +} + +/** @returns true if all bits in value from [0, size) are 1. */ +function all(value: bigint, size: number): boolean { + return value === ((1n << BigInt(size) - 1n)); +} + + +export enum Style { + Workgroup, + Subgroup, + Maximal, +}; + +enum OpType { + OpBallot, + + OpStore, + + OpIfMask, + OpElseMask, + OpEndIf, + + OpIfLoopCount, + OpElseLoopCount, + + OpIfLid, + OpElseLid, + + OpBreak, + OpContinue, + + OpForUniform, + OpEndForUniform, + + OpReturn, + + OpMAX, +} + +enum IfType { + IfMask, + IfUniform, + IfLoopCount, + IfLid, +}; + +class Op { + type : OpType; + value : number; + caseValue : number; + + constructor(type : OpType, value: number = 0, caseValue: number = 0) { + this.type = type; + this.value = value; + this.caseValue = caseValue; + } +}; + +export class Program { + private invocations: number; + private readonly prng: PRNG; + private ops : Op[]; + private readonly style: Style; + private readonly minCount: number; + private readonly maxNesting: number; + private nesting: number; + private loopNesting: number; + private loopNestingThisFunction: number; + private numMasks: number; + private readonly masks: number[]; + private curFunc: number; + private functions: string[]; + private indents: number[]; + private storeBase: number; + + /** + * constructor + * + * @param style Enum indicating the type of reconvergence being tested + * @param seed Value used to seed the PRNG + */ + constructor(style : Style = Style.Workgroup, seed: number = 1, invocations: number = 128) { + this.invocations = invocations; + this.prng = new PRNG(seed); + this.ops = []; + this.style = style; + this.minCount = 5; // 30; + this.maxNesting = 5; // this.getRandomUint(70) + 30; // [30,100) + this.nesting = 0; + this.loopNesting = 0; + this.loopNestingThisFunction = 0; + this.numMasks = 10; + this.masks = []; + this.masks.push(0xffffffff); + this.masks.push(0xffffffff); + this.masks.push(0xffffffff); + this.masks.push(0xffffffff); + for (var i = 1; i < this.numMasks; i++) { + this.masks.push(this.getRandomUint(0xffffffff)); + this.masks.push(this.getRandomUint(0xffffffff)); + this.masks.push(this.getRandomUint(0xffffffff)); + this.masks.push(this.getRandomUint(0xffffffff)); + } + this.curFunc = 0; + this.functions = []; + this.functions.push(``); + this.indents = []; + this.indents.push(2); + this.storeBase = 0x10000; + } + + /** @returns A random float between 0 and 1 */ + private getRandomFloat(): number { + return this.prng.random(); + } + + /** @returns A random 32-bit integer between 0 and max. */ + private getRandomUint(max: number): number { + return this.prng.randomU32() % max; + } + + private pickOp(count : number) { + for (var i = 0; i < count; i++) { + //this.genBallot(); + if (this.nesting < this.maxNesting) { + const r = this.getRandomUint(12); + switch (r) { + case 0: { + if (this.loopNesting > 0) { + this.genIf(IfType.IfLoopCount); + break; + } + this.genIf(IfType.IfLid); + break; + } + case 1: { + this.genIf(IfType.IfLid); + break; + } + case 2: { + this.genIf(IfType.IfMask); + break; + } + case 3: { + this.genIf(IfType.IfUniform); + break; + } + case 4: { + if (this.loopNesting <= 3) { + const r2 = this.getRandomUint(3); + switch (r2) { + case 0: this.genForUniform(); break; + case 2: + default: { + break; + } + } + } + break; + } + case 5: { + this.genBreak(); + break; + } + case 6: { + this.genContinue(); + break; + } + default: { + break; + } + } + } + } + } + + private genBallot() { + // Optionally insert ballots, stores, and noise. + // Ballots and stores are used to determine correctness. + if (this.getRandomFloat() < 0.2) { + const cur_length = this.ops.length; + if (cur_length < 2 || + !(this.ops[cur_length - 1].type == OpType.OpBallot || + (this.ops[cur_length-1].type == OpType.OpStore && this.ops[cur_length - 2].type == OpType.OpBallot))) { + // Perform a store with each ballot so the results can be correlated. + if (this.style != Style.Maximal) + this.ops.push(new Op(OpType.OpStore, cur_length + this.storeBase)); + this.ops.push(new Op(OpType.OpBallot, 0)); + } + } + + if (this.getRandomFloat() < 0.1) { + const cur_length = this.ops.length; + if (cur_length < 2 || + !(this.ops[cur_length - 1].type == OpType.OpStore || + (this.ops[cur_length - 1].type == OpType.OpBallot && this.ops[cur_length - 2].type == OpType.OpStore))) { + // Subgroup and workgroup styles do a store with every ballot. + // Don't bloat the code by adding more. + if (this.style == Style.Maximal) + this.ops.push(new Op(OpType.OpStore, cur_length + this.storeBase)); + } + } + + //deUint32 r = this.getRandomUint(10000); + //if (r < 3) { + // ops.push_back({OP_NOISE, 0}); + //} else if (r < 10) { + // ops.push_back({OP_NOISE, 1}); + //} + } + + private genIf(type: IfType) { + let maskIdx = this.getRandomUint(this.numMasks); + if (type == IfType.IfUniform) + maskIdx = 0; + + const lid = this.getRandomUint(128); + if (type == IfType.IfLid) { + this.ops.push(new Op(OpType.OpIfLid, lid)); + } else if (type == IfType.IfLoopCount) { + this.ops.push(new Op(OpType.OpIfLoopCount, 0)); + } else { + this.ops.push(new Op(OpType.OpIfMask, maskIdx)); + } + + this.nesting++; + + let beforeSize = this.ops.length; + this.pickOp(2); + let afterSize = this.ops.length; + + const randElse = this.getRandomFloat(); + if (randElse < 0.5) { + if (type == IfType.IfLid) { + this.ops.push(new Op(OpType.OpElseLid, lid)); + } else if (type == IfType.IfLoopCount) { + this.ops.push(new Op(OpType.OpElseLoopCount, 0)); + } else { + this.ops.push(new Op(OpType.OpElseMask, maskIdx)); + } + + // Sometimes make the else identical to the if. + if (randElse < 0.1 && beforeSize != afterSize) { + for (var i = beforeSize; i < afterSize; i++) { + const op = this.ops[i]; + this.ops.push(new Op(op.type, op.value, op.caseValue)); + } + } else { + this.pickOp(2); + } + } + this.ops.push(new Op(OpType.OpEndIf, 0)); + + this.nesting--; + } + + private genForUniform() { + const n = this.getRandomUint(5) + 1; // [1, 5] + this.ops.push(new Op(OpType.OpForUniform, n)); + const header = this.ops.length - 1; + this.nesting++; + this.loopNesting++; + this.loopNestingThisFunction++; + this.pickOp(2); + this.ops.push(new Op(OpType.OpEndForUniform, header)); + this.loopNestingThisFunction--; + this.loopNesting--; + this.nesting--; + } + + private genBreak() { + if (this.loopNestingThisFunction > 0) + { + // Sometimes put the break in a divergent if + if (this.getRandomFloat() < 0.1) { + const r = this.getRandomUint(this.numMasks-1) + 1; + this.ops.push(new Op(OpType.OpIfMask, r)); + this.ops.push(new Op(OpType.OpBreak, 0)); + this.ops.push(new Op(OpType.OpElseMask, r)); + this.ops.push(new Op(OpType.OpBreak, 0)); + this.ops.push(new Op(OpType.OpEndIf, 0)); + } else { + this.ops.push(new Op(OpType.OpBreak, 0)); + } + } + } + + private genContinue() { + // TODO: need to avoid infinite loops + if (this.loopNestingThisFunction > 0) + { + // Sometimes put the continue in a divergent if + if (this.getRandomFloat() < 0.1) { + const r = this.getRandomUint(this.numMasks-1) + 1; + this.ops.push(new Op(OpType.OpIfMask, r)); + this.ops.push(new Op(OpType.OpContinue, 0)); + this.ops.push(new Op(OpType.OpElseMask, r)); + this.ops.push(new Op(OpType.OpBreak, 0)); + this.ops.push(new Op(OpType.OpEndIf, 0)); + } else { + this.ops.push(new Op(OpType.OpContinue, 0)); + } + } + } + + private genCode(): string { + for (var i = 0; i < this.ops.length; i++) { + const op = this.ops[i]; + this.genIndent() + this.addCode(`// ops[${i}] = ${op.type}\n`); + switch (op.type) { + case OpType.OpBallot: { + this.genIndent(); + this.addCode(`ballots[stride * output_loc + local_id] = subgroupBallot();\n`); + this.genIndent(); + this.addCode(`output_loc++;\n`); + break; + } + case OpType.OpStore: { + this.genIndent(); + this.addCode(`locations[local_id]++;\n`); + this.genIndent(); + this.addCode(`ballots[stride * output_loc + local_id] = vec4u(${op.value},0,0,0);\n`); + this.genIndent(); + this.addCode(`output_loc++;\n`); + break; + } + default: { + this.genIndent(); + this.addCode(`/* missing op ${op.type} */\n`); + break; + } + case OpType.OpIfMask: { + this.genIndent(); + if (op.value == 0) { + const idx = this.getRandomUint(4); + this.addCode(`if inputs[${idx}] == ${idx} {\n`); + } else { + const idx = op.value; + const x = this.masks[4*idx]; + const y = this.masks[4*idx+1]; + const z = this.masks[4*idx+2]; + const w = this.masks[4*idx+3]; + this.addCode(`if testBit(vec4u(${x},${y},${z},${w}), subgroup_id) {\n`); + } + this.increaseIndent(); + break; + } + case OpType.OpIfLid: { + this.genIndent(); + this.addCode(`if subgroup_id < inputs[${op.value}] {\n`); + this.increaseIndent(); + break; + } + case OpType.OpIfLoopCount: { + this.genIndent(); + this.addCode(`if subgroup_id == i${this.loopNesting-1} {\n`); + this.increaseIndent(); + break; + } + case OpType.OpElseMask: + case OpType.OpElseLid: + case OpType.OpElseLoopCount: { + this.decreaseIndent(); + this.genIndent(); + this.addCode(`} else {\n`); + this.increaseIndent(); + break; + } + case OpType.OpEndIf: { + this.decreaseIndent(); + this.genIndent(); + this.addCode(`}\n`); + break; + } + case OpType.OpForUniform: { + this.genIndent(); + const iter = `i${this.loopNesting}`; + this.addCode(`for (var ${iter} = 0u; ${iter} < inputs[${op.value}]; ${iter}++) {\n`); + this.increaseIndent(); + this.loopNesting++; + break; + } + case OpType.OpEndForUniform: { + this.loopNesting--; + this.decreaseIndent(); + this.genIndent(); + this.addCode(`}\n`); + break; + } + case OpType.OpBreak: { + this.genIndent(); + this.addCode(`break;\n`); + break; + } + case OpType.OpBreak: { + this.genIndent(); + this.addCode(`continue;\n`); + break; + } + } + } + + let code: string = ` +//enable chromium_experimental_subgroups; + +const stride = ${this.invocations}; + +@group(0) @binding(0) +var inputs : array; +@group(0) @binding(1) +var ballots : array; +@group(0) @binding(2) +var locations : array; + +var subgroup_id : u32; +var local_id : u32; +var output_loc : u32 = 0; + +@compute @workgroup_size(${this.invocations},1,1) +fn main( + @builtin(local_invocation_index) lid : u32, + //@builtin(subgroup_invocation_id) sid : u32, +) { + _ = inputs[0]; + _ = ballots[0]; + subgroup_id = 0; // sid; + local_id = lid; + + f0(); +} + +fn testBit(mask : vec4u, id : u32) -> bool { + let xbit = extractBits(mask.x, id, 1); + let ybit = extractBits(mask.y, id - 32, 1); + let zbit = extractBits(mask.z, id - 64, 1); + let wbit = extractBits(mask.w, id - 96, 1); + let lt32 = id < 32; + let lt64 = id < 64; + let lt96 = id < 96; + let sela = select(wbit, xbit, lt96); + let selb = select(zbit, ybit, lt64); + return select(selb, sela, lt32) == 1; +} +`; + + for (var i = 0; i < this.functions.length; i++) { + code += ` +fn f${i}() { +${this.functions[i]} +} +`; + } + return code; + } + + private genIndent() { + this.functions[this.curFunc] += ' '.repeat(this.indents[this.curFunc]); + } + private increaseIndent() { + this.indents[this.curFunc] += 2; + } + private decreaseIndent() { + this.indents[this.curFunc] -= 2; + } + private addCode(code: string) { + this.functions[this.curFunc] += code; + } + + public simulate(countOnly: boolean, size: number): number { + class State { + activeMask: bigint; + continueMask: bigint; + header: number; + isLoop: boolean; + tripCount: number; + isCall: boolean; + isSwitch: boolean; + + constructor() { + this.activeMask = 0n; + this.continueMask = 0n; + this.header = 0; + this.isLoop = false; + this.tripCount = 0; + this.isCall = false; + this.isSwitch = false; + } + + copy(other: State) { + this.activeMask = other.activeMask; + this.continueMask = other.continueMask; + this.header = other.header; + this.isLoop = other.isLoop; + this.tripCount = other.tripCount; + this.isCall = other.isCall; + this.isSwitch = other.isSwitch; + } + }; + var stack = new Array(); + stack.push(new State()); + stack[0].activeMask = (1n << 128n) - 1n; + + var nesting = 0; + var loopNesting = 0; + var locs = new Array(this.invocations); + locs.fill(0); + + var i = 0; + while (i < this.ops.length) { + const op = this.ops[i]; + console.log(`ops[${i}] = ${op.type}, nesting = ${nesting}`); + console.log(` mask = ${stack[nesting].activeMask.toString(16)}`); + //for (var j = 0; j <= nesting; j++) { + // console.log(` mask[${j}] = ${stack[j].activeMask.toString(16)}`); + //} + switch (op.type) { + case OpType.OpBallot: { + break; + } + case OpType.OpStore: { + break; + } + case OpType.OpIfMask: { + nesting++; + stack.push(new State()); + const cur = stack[nesting]; + cur.copy(stack[nesting-1]); + cur.header = i; + // O is always uniform true. + if (op.value != 0) { + cur.activeMask &= this.getValueMask(op.value); + } + break; + } + case OpType.OpElseMask: { + // 0 is always uniform true so the else will never be taken. + const cur = stack[nesting]; + if (op.value == 0) { + cur.activeMask = 0n; + } else { + const prev = stack[nesting-1]; + cur.activeMask = prev.activeMask & ~this.getValueMask(op.value); + } + break; + } + case OpType.OpIfLid: { + nesting++; + stack.push(new State()); + const cur = stack[nesting]; + cur.copy(stack[nesting-1]); + cur.header = i; + // All invocations with subgroup invocation id less than op.value are active. + cur.activeMask &= getReplicatedMask(getMask(op.value), size, this.invocations); + break; + } + case OpType.OpElseLid: { + const prev = stack[nesting-1]; + // All invocations with a subgroup invocation id greater or equal to op.value are active. + stack[nesting].activeMask = prev.activeMask; + stack[nesting].activeMask &= ~getReplicatedMask(getMask(op.value), size, this.invocations); + break; + } + case OpType.OpIfLoopCount: { + let n = nesting; + while (!stack[n].isLoop) { + n--; + } + + nesting++; + stack.push(new State()); + const cur = stack[nesting]; + cur.copy(stack[nesting-1]); + cur.header = i; + cur.isLoop = 0; + cur.isSwitch = 0; + cur.activeMask &= getReplicatedMask(BigInt(1 << stack[n].tripCount), size, this.invocations); + break; + } + case OpType.OpElseLoopCount: { + let n = nesting; + while (!stack[n].isLoop) { + n--; + } + + stack[nesting].activeMask = stack[nesting-1].activeMask; + stack[nesting].activeMask &= ~getReplicatedMask(BigInt(1 << stack[n].tripCount), size, this.invocations); + break; + } + case OpType.OpEndIf: { + nesting--; + stack.pop(); + break; + } + case OpType.OpForUniform: { + nesting++; + loopNesting++; + stack.push(new State()); + const cur = stack[nesting]; + cur.header = i; + cur.isLoop = true; + cur.activeMask = stack[nesting-1].activeMask; + break; + } + case OpType.OpEndForUniform: { + const cur = stack[nesting]; + cur.tripCount++; + cur.activeMask |= stack[nesting].continueMask; + cur.continueMask = 0n; + // Loop if there are any invocations left with iterations to perform. + if (cur.tripCount < this.ops[cur.header].value && + any(cur.activeMask)) { + i = cur.header + 1; + continue; + } else { + loopNesting--; + nesting--; + stack.pop(); + } + break; + } + case OpType.OpBreak: { + var n = nesting; + var mask: bigint = stack[nesting].activeMask; + while (true) { + stack[n].activeMask &= ~mask; + if (stack[n].isLoop || stack[n].isSwitch) { + break; + } + + n--; + } + break; + } + case OpType.OpContinue: { + var n = nesting; + var mask: bigint = stack[nesting].activeMask; + while (true) { + stack[n].activeMask &= ~mask; + if (stack[n].isLoop) { + stack[n].continueMask |= mask; + break; + } + n--; + } + break; + } + default: { + unreachable(`Unhandled op ${op.type}`); + } + } + i++; + } + + assert(stack.length == 1); + + var maxLoc = 0; + for (var j = 0; j < this.invocations; j++) { + maxLoc = Math.max(maxLoc, locs[j]); + } + return maxLoc; + } + + // Returns an active mask for the mask at the given index. + private getValueMask(idx: number): bigint { + const x = this.masks[4*idx]; + const y = this.masks[4*idx+1]; + const z = this.masks[4*idx+2]; + const w = this.masks[4*idx+3]; + var mask: bigint = 0n; + mask |= BigInt(x); + mask |= BigInt(y) << 32n; + mask |= BigInt(z) << 64n; + mask |= BigInt(w) << 96n; + return mask; + } + + /** @returns a randomized program */ + public generate(): string { + while (this.ops.length < this.minCount) { + this.pickOp(1); + } + + return this.genCode(); + } +}; + +export function generateSeeds(numCases: number): number[] { + var prng: PRNG = new PRNG(1); + var output: number[] = new Array(numCases); + for (var i = 0; i < numCases; i++) { + output[i] = prng.randomU32(); + } + return output; +} From ad268fe4ebed7fb461b4ac192ae93b07e20d0848 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Thu, 10 Aug 2023 14:17:46 -0400 Subject: [PATCH 03/32] More implementation * Refactor executing the program into a helper function * Add some predefined testcases in a separate group * Add comments and improve enum names * Add ballot, store, and return ops --- .../reconvergence/reconvergence.spec.ts | 164 ++++-- .../shader/execution/reconvergence/util.ts | 468 ++++++++++++++---- 2 files changed, 473 insertions(+), 159 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index 9ea786b91cc5..895e5e4b6b5f 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -3,12 +3,120 @@ https://github.com/KhronosGroup/VK-GL-CTS/blob/main/external/vulkancts/modules/v import { makeTestGroup } from '../../../../common/framework/test_group.js'; import { GPUTest } from '../../../gpu_test.js'; +import { iterRange, unreachable } from '../../../../common/util/util.js'; import { Style, Program, generateSeeds } from './util.js' export const g = makeTestGroup(GPUTest); -g.test('reconvergence') - .desc(`Test reconvergence`) +function testProgram(t: GPUTest, program: Program) { + let wgsl = program.genCode(); + console.log(wgsl); + + let num = program.simulate(true, 16); + console.log(`Max locations = ${num}`); + + num = program.simulate(true, 32); + console.log(`Max locations = ${num}`); + + const pipeline = t.device.createComputePipeline({ + layout: 'auto', + compute: { + module: t.device.createShaderModule({ + code: wgsl, + }), + entryPoint: 'main', + }, + }); + + const inputBuffer = t.makeBufferWithContents( + new Uint32Array([...iterRange(128, x => x)]), + GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST + ); + t.trackForCleanup(inputBuffer); + + const ballotBuffer = t.device.createBuffer({ + size: 128 * 4, // TODO: FIXME + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC, + }); + t.trackForCleanup(ballotBuffer); + + const locationBuffer = t.makeBufferWithContents( + new Uint32Array([...iterRange(program.invocations, x => 0)]), + GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC + ); + t.trackForCleanup(locationBuffer); + + const bindGroup = t.device.createBindGroup({ + layout: pipeline.getBindGroupLayout(0), + entries: [ + { + binding: 0, + resource: { + buffer: inputBuffer + }, + }, + { + binding: 1, + resource: { + buffer: ballotBuffer + }, + }, + { + binding: 2, + resource: { + buffer: locationBuffer + }, + }, + ], + }); + + const encoder = t.device.createCommandEncoder(); + const pass = encoder.beginComputePass(); + pass.setPipeline(pipeline); + pass.setBindGroup(0, bindGroup); + pass.dispatchWorkgroups(1,1,1); + pass.end(); + t.queue.submit([encoder.finish()]); +} + +g.test('predefined_reconvergence') + .desc(`Test reconvergence using some predefined programs`) + .params(u => + u + .combine('test', [...iterRange(3, x => x)] as const) + .beginSubcases() + ) + .fn(t => { + const invocations = 128; // t.device.limits.maxSubgroupSize; + + let program: Program; + switch (t.params.test) { + case 0: { + program = new Program(Style.Workgroup, 1, invocations); + program.predefinedProgram1(); + break; + } + case 1: { + program = new Program(Style.Subgroup, 1, invocations); + program.predefinedProgram1(); + break; + } + case 2: { + program = new Program(Style.Subgroup, 1, invocations); + program.predefinedProgram2(); + break; + } + default: { + program = new Program(); + unreachable('Unhandled testcase'); + } + } + + testProgram(t, program); + }); + +g.test('random_reconvergence') + .desc(`Test reconvergence using randomly generated programs`) .params(u => u .combine('style', [Style.Workgroup, Style.Subgroup, Style.Maximal] as const) @@ -25,53 +133,7 @@ g.test('reconvergence') const invocations = 128; // t.device.limits.maxSubgroupSize; let program: Program = new Program(t.params.style, t.params.seed, invocations); - let wgsl = program.generate(); - console.log(wgsl); - - const num = program.simulate(true, 16); - - const pipeline = t.device.createComputePipeline({ - layout: 'auto', - compute: { - module: t.device.createShaderModule({ - code: wgsl, - }), - entryPoint: 'main', - }, - }); - - //// Helper to create a `size`-byte buffer with binding number `binding`. - //function createBuffer(size: number, binding: number) { - // const buffer = t.device.createBuffer({ - // size, - // usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC, - // }); - // t.trackForCleanup(buffer); - - // bindGroupEntries.push({ - // binding, - // resource: { - // buffer, - // }, - // }); - - // return buffer; - //} - - //const bindGroupEntries: GPUBindGroupEntry[] = []; - //const inputBuffer = createBuffer(16, 0); - //const ballotBuffer = createBuffer(16, 1); - - //const bindGroup = t.device.createBindGroup({ - // layout: pipeline.getBindGroupLayout(0), - // entries: bindGroupEntries, - //}); - - //const encoder = t.device.createCommandEncoder(); - //const pass = encoder.beginComputePass(); - //pass.setPipeline(pipeline); - //pass.setBindGroup(0, bindGroup); - //pass.dispatchWorkgroups(1,1,1); - //pass.end(); - //t.queue.submit([encoder.finish()]); + program.generate(); + + testProgram(t, program); }); diff --git a/src/webgpu/shader/execution/reconvergence/util.ts b/src/webgpu/shader/execution/reconvergence/util.ts index 5661da62a931..2b00a1c800a2 100644 --- a/src/webgpu/shader/execution/reconvergence/util.ts +++ b/src/webgpu/shader/execution/reconvergence/util.ts @@ -16,6 +16,11 @@ function getReplicatedMask(submask: bigint, size: number, total: number = 128): return mask; } +/** @returns true if bit |bit| is set to 1. */ +function testBit(mask: bigint, bit: number): boolean { + return ((mask >> BigInt(bit)) & 0x1n) == 1n; +} + /** @returns true if any bit in value is 1. */ function any(value: bigint): boolean { return value !== 0n; @@ -26,70 +31,93 @@ function all(value: bigint, size: number): boolean { return value === ((1n << BigInt(size) - 1n)); } - export enum Style { + // Workgroup uniform control flow Workgroup, + + // Subgroup uniform control flow Subgroup, + + // Maximal uniformity Maximal, }; enum OpType { - OpBallot, + // Store a ballot. + // During simulation, uniform is set to false if the + // ballot is not fully uniform for the given style. + Ballot, - OpStore, + // Store a literal. + Store, - OpIfMask, - OpElseMask, - OpEndIf, + // if (testBit(mask, subgroup_id)) + // Special case if value == 0: if (inputs[idx] == idx) + IfMask, + ElseMask, + EndIf, - OpIfLoopCount, - OpElseLoopCount, + // Conditional based on loop iteration + // if (subgroup_id == iN) + IfLoopCount, + ElseLoopCount, - OpIfLid, - OpElseLid, + // if (subgroup_id < inputs[value]) + IfLid, + ElseLid, - OpBreak, - OpContinue, + // Break/continue + Break, + Continue, - OpForUniform, - OpEndForUniform, + // for (var i = 0u; i < inputs[value]; i++) + ForUniform, + EndForUniform, - OpReturn, + // Function return + Return, - OpMAX, + MAX, } enum IfType { - IfMask, - IfUniform, - IfLoopCount, - IfLid, + Mask, + Uniform, + LoopCount, + Lid, }; +/** + * Operation in a Program. + * + * Includes the type of operations, an operation specific value and whether or + * not the operation is uniform. + */ class Op { type : OpType; value : number; - caseValue : number; + uniform : boolean; - constructor(type : OpType, value: number = 0, caseValue: number = 0) { + constructor(type : OpType, value: number = 0, uniform: boolean = true) { this.type = type; this.value = value; - this.caseValue = caseValue; + this.uniform = uniform; } }; export class Program { - private invocations: number; + public invocations: number; private readonly prng: PRNG; private ops : Op[]; - private readonly style: Style; + public readonly style: Style; private readonly minCount: number; private readonly maxNesting: number; private nesting: number; private loopNesting: number; private loopNestingThisFunction: number; + private callNesting: number; private numMasks: number; - private readonly masks: number[]; + private masks: number[]; private curFunc: number; private functions: string[]; private indents: number[]; @@ -111,6 +139,7 @@ export class Program { this.nesting = 0; this.loopNesting = 0; this.loopNestingThisFunction = 0; + this.callNesting = 0; this.numMasks = 10; this.masks = []; this.masks.push(0xffffffff); @@ -143,28 +172,28 @@ export class Program { private pickOp(count : number) { for (var i = 0; i < count; i++) { - //this.genBallot(); + this.genBallot(); if (this.nesting < this.maxNesting) { const r = this.getRandomUint(12); switch (r) { case 0: { if (this.loopNesting > 0) { - this.genIf(IfType.IfLoopCount); + this.genIf(IfType.LoopCount); break; } - this.genIf(IfType.IfLid); + this.genIf(IfType.Lid); break; } case 1: { - this.genIf(IfType.IfLid); + this.genIf(IfType.Lid); break; } case 2: { - this.genIf(IfType.IfMask); + this.genIf(IfType.Mask); break; } case 3: { - this.genIf(IfType.IfUniform); + this.genIf(IfType.Uniform); break; } case 4: { @@ -188,6 +217,12 @@ export class Program { this.genContinue(); break; } + case 7: { + // Calls and returns. + // TODO: calls + this.genReturn(); + break; + } default: { break; } @@ -202,24 +237,24 @@ export class Program { if (this.getRandomFloat() < 0.2) { const cur_length = this.ops.length; if (cur_length < 2 || - !(this.ops[cur_length - 1].type == OpType.OpBallot || - (this.ops[cur_length-1].type == OpType.OpStore && this.ops[cur_length - 2].type == OpType.OpBallot))) { + !(this.ops[cur_length - 1].type == OpType.Ballot || + (this.ops[cur_length-1].type == OpType.Store && this.ops[cur_length - 2].type == OpType.Ballot))) { // Perform a store with each ballot so the results can be correlated. if (this.style != Style.Maximal) - this.ops.push(new Op(OpType.OpStore, cur_length + this.storeBase)); - this.ops.push(new Op(OpType.OpBallot, 0)); + this.ops.push(new Op(OpType.Store, cur_length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); } } if (this.getRandomFloat() < 0.1) { const cur_length = this.ops.length; if (cur_length < 2 || - !(this.ops[cur_length - 1].type == OpType.OpStore || - (this.ops[cur_length - 1].type == OpType.OpBallot && this.ops[cur_length - 2].type == OpType.OpStore))) { + !(this.ops[cur_length - 1].type == OpType.Store || + (this.ops[cur_length - 1].type == OpType.Ballot && this.ops[cur_length - 2].type == OpType.Store))) { // Subgroup and workgroup styles do a store with every ballot. // Don't bloat the code by adding more. if (this.style == Style.Maximal) - this.ops.push(new Op(OpType.OpStore, cur_length + this.storeBase)); + this.ops.push(new Op(OpType.Store, cur_length + this.storeBase)); } } @@ -233,16 +268,16 @@ export class Program { private genIf(type: IfType) { let maskIdx = this.getRandomUint(this.numMasks); - if (type == IfType.IfUniform) + if (type == IfType.Uniform) maskIdx = 0; const lid = this.getRandomUint(128); - if (type == IfType.IfLid) { - this.ops.push(new Op(OpType.OpIfLid, lid)); - } else if (type == IfType.IfLoopCount) { - this.ops.push(new Op(OpType.OpIfLoopCount, 0)); + if (type == IfType.Lid) { + this.ops.push(new Op(OpType.IfLid, lid)); + } else if (type == IfType.LoopCount) { + this.ops.push(new Op(OpType.IfLoopCount, 0)); } else { - this.ops.push(new Op(OpType.OpIfMask, maskIdx)); + this.ops.push(new Op(OpType.IfMask, maskIdx)); } this.nesting++; @@ -253,38 +288,38 @@ export class Program { const randElse = this.getRandomFloat(); if (randElse < 0.5) { - if (type == IfType.IfLid) { - this.ops.push(new Op(OpType.OpElseLid, lid)); - } else if (type == IfType.IfLoopCount) { - this.ops.push(new Op(OpType.OpElseLoopCount, 0)); + if (type == IfType.Lid) { + this.ops.push(new Op(OpType.ElseLid, lid)); + } else if (type == IfType.LoopCount) { + this.ops.push(new Op(OpType.ElseLoopCount, 0)); } else { - this.ops.push(new Op(OpType.OpElseMask, maskIdx)); + this.ops.push(new Op(OpType.ElseMask, maskIdx)); } // Sometimes make the else identical to the if. if (randElse < 0.1 && beforeSize != afterSize) { for (var i = beforeSize; i < afterSize; i++) { const op = this.ops[i]; - this.ops.push(new Op(op.type, op.value, op.caseValue)); + this.ops.push(new Op(op.type, op.value, op.uniform)); } } else { this.pickOp(2); } } - this.ops.push(new Op(OpType.OpEndIf, 0)); + this.ops.push(new Op(OpType.EndIf, 0)); this.nesting--; } private genForUniform() { const n = this.getRandomUint(5) + 1; // [1, 5] - this.ops.push(new Op(OpType.OpForUniform, n)); + this.ops.push(new Op(OpType.ForUniform, n)); const header = this.ops.length - 1; this.nesting++; this.loopNesting++; this.loopNestingThisFunction++; this.pickOp(2); - this.ops.push(new Op(OpType.OpEndForUniform, header)); + this.ops.push(new Op(OpType.EndForUniform, header)); this.loopNestingThisFunction--; this.loopNesting--; this.nesting--; @@ -296,13 +331,13 @@ export class Program { // Sometimes put the break in a divergent if if (this.getRandomFloat() < 0.1) { const r = this.getRandomUint(this.numMasks-1) + 1; - this.ops.push(new Op(OpType.OpIfMask, r)); - this.ops.push(new Op(OpType.OpBreak, 0)); - this.ops.push(new Op(OpType.OpElseMask, r)); - this.ops.push(new Op(OpType.OpBreak, 0)); - this.ops.push(new Op(OpType.OpEndIf, 0)); + this.ops.push(new Op(OpType.IfMask, r)); + this.ops.push(new Op(OpType.Break, 0)); + this.ops.push(new Op(OpType.ElseMask, r)); + this.ops.push(new Op(OpType.Break, 0)); + this.ops.push(new Op(OpType.EndIf, 0)); } else { - this.ops.push(new Op(OpType.OpBreak, 0)); + this.ops.push(new Op(OpType.Break, 0)); } } } @@ -314,31 +349,51 @@ export class Program { // Sometimes put the continue in a divergent if if (this.getRandomFloat() < 0.1) { const r = this.getRandomUint(this.numMasks-1) + 1; - this.ops.push(new Op(OpType.OpIfMask, r)); - this.ops.push(new Op(OpType.OpContinue, 0)); - this.ops.push(new Op(OpType.OpElseMask, r)); - this.ops.push(new Op(OpType.OpBreak, 0)); - this.ops.push(new Op(OpType.OpEndIf, 0)); + this.ops.push(new Op(OpType.IfMask, r)); + this.ops.push(new Op(OpType.Continue, 0)); + this.ops.push(new Op(OpType.ElseMask, r)); + this.ops.push(new Op(OpType.Break, 0)); + this.ops.push(new Op(OpType.EndIf, 0)); } else { - this.ops.push(new Op(OpType.OpContinue, 0)); + this.ops.push(new Op(OpType.Continue, 0)); } } } - private genCode(): string { + private genReturn() { + const r = this.getRandomFloat(); + if (this.nesting > 0 && + (r < 0.05 || + (this.callNesting > 0 && this.loopNestingThisFunction > 0 && r < 0.2) || + (this.callNesting > 0 && this.loopNestingThisFunction > 1 && r < 0.5))) { + this.genBallot(); + if (this.getRandomFloat() < 0.1) { + this.ops.push(new Op(OpType.IfMask, 0)); + this.ops.push(new Op(OpType.Return, 0)); + this.ops.push(new Op(OpType.ElseMask, 0)); + this.ops.push(new Op(OpType.Return, 0)); + this.ops.push(new Op(OpType.EndIf, 0)); + } else { + this.ops.push(new Op(OpType.Return, 0)); + } + } + } + + public genCode(): string { for (var i = 0; i < this.ops.length; i++) { const op = this.ops[i]; this.genIndent() this.addCode(`// ops[${i}] = ${op.type}\n`); switch (op.type) { - case OpType.OpBallot: { + case OpType.Ballot: { this.genIndent(); - this.addCode(`ballots[stride * output_loc + local_id] = subgroupBallot();\n`); + // TODO: FIXME + this.addCode(`\/\/ ballots[stride * output_loc + local_id] = subgroupBallot();\n`); this.genIndent(); this.addCode(`output_loc++;\n`); break; } - case OpType.OpStore: { + case OpType.Store: { this.genIndent(); this.addCode(`locations[local_id]++;\n`); this.genIndent(); @@ -352,7 +407,7 @@ export class Program { this.addCode(`/* missing op ${op.type} */\n`); break; } - case OpType.OpIfMask: { + case OpType.IfMask: { this.genIndent(); if (op.value == 0) { const idx = this.getRandomUint(4); @@ -368,34 +423,34 @@ export class Program { this.increaseIndent(); break; } - case OpType.OpIfLid: { + case OpType.IfLid: { this.genIndent(); this.addCode(`if subgroup_id < inputs[${op.value}] {\n`); this.increaseIndent(); break; } - case OpType.OpIfLoopCount: { + case OpType.IfLoopCount: { this.genIndent(); this.addCode(`if subgroup_id == i${this.loopNesting-1} {\n`); this.increaseIndent(); break; } - case OpType.OpElseMask: - case OpType.OpElseLid: - case OpType.OpElseLoopCount: { + case OpType.ElseMask: + case OpType.ElseLid: + case OpType.ElseLoopCount: { this.decreaseIndent(); this.genIndent(); this.addCode(`} else {\n`); this.increaseIndent(); break; } - case OpType.OpEndIf: { + case OpType.EndIf: { this.decreaseIndent(); this.genIndent(); this.addCode(`}\n`); break; } - case OpType.OpForUniform: { + case OpType.ForUniform: { this.genIndent(); const iter = `i${this.loopNesting}`; this.addCode(`for (var ${iter} = 0u; ${iter} < inputs[${op.value}]; ${iter}++) {\n`); @@ -403,23 +458,28 @@ export class Program { this.loopNesting++; break; } - case OpType.OpEndForUniform: { + case OpType.EndForUniform: { this.loopNesting--; this.decreaseIndent(); this.genIndent(); this.addCode(`}\n`); break; } - case OpType.OpBreak: { + case OpType.Break: { this.genIndent(); this.addCode(`break;\n`); break; } - case OpType.OpBreak: { + case OpType.Continue: { this.genIndent(); this.addCode(`continue;\n`); break; } + case OpType.Return: { + this.genIndent(); + this.addCode(`return;\n`); + break; + } } } @@ -446,6 +506,7 @@ fn main( ) { _ = inputs[0]; _ = ballots[0]; + _ = locations[0]; subgroup_id = 0; // sid; local_id = lid; @@ -476,20 +537,36 @@ ${this.functions[i]} return code; } + /** + * Adds indentation to the code for the current function. + */ private genIndent() { this.functions[this.curFunc] += ' '.repeat(this.indents[this.curFunc]); } + + /** + * Increase the amount of indenting for the current function. + */ private increaseIndent() { this.indents[this.curFunc] += 2; } + + /** + * Decrease the amount of indenting for the current function. + */ private decreaseIndent() { this.indents[this.curFunc] -= 2; } + + /** + * Adds 'code' to the current function + */ private addCode(code: string) { this.functions[this.curFunc] += code; } - public simulate(countOnly: boolean, size: number): number { + // TODO: Reconvergence guarantees are not as strong as this simulation. + public simulate(countOnly: boolean, subgroupSize: number): number { class State { activeMask: bigint; continueMask: bigint; @@ -537,13 +614,51 @@ ${this.functions[i]} // console.log(` mask[${j}] = ${stack[j].activeMask.toString(16)}`); //} switch (op.type) { - case OpType.OpBallot: { + case OpType.Ballot: { + const curMask = stack[nesting].activeMask; + // Flag if this ballot is not workgroup uniform. + if (this.style == Style.Workgroup && any(curMask) && !all(curMask, this.invocations)) { + op.uniform = false; + } + + // Flag if this ballot is not subgroup uniform. + if (this.style == Style.Subgroup) { + for (var id = 0; id < this.invocations; id += subgroupSize) { + const subgroupMask = (curMask >> BigInt(id)) & getMask(subgroupSize); + if (subgroupMask != 0n && !all(subgroupMask, subgroupSize)) { + op.uniform = false; + } + } + } + + for (var id = 0; id < this.invocations; id++) { + if (testBit(curMask, id)) { + if (countOnly) { + locs[id]++; + } else { + // if (op.caseValue == 1) { + // // Emit a magic value to indicate that we shouldn't validate this ballot + // ref[(outLoc[id]++)*invocationStride + id] = bitsetToU64(0x12345678, subgroupSize, id); + // } else { + // ref[(outLoc[id]++)*invocationStride + id] = bitsetToU64(stateStack[nesting].activeMask, subgroupSize, id); + // } + } + } + } break; } - case OpType.OpStore: { + case OpType.Store: { + for (var id = 0; id < 128; id++) { + if (testBit(stack[nesting].activeMask, id)) { + if (countOnly) + locs[id]++; + //else + // ref[(outLoc[id]++)*invocationStride + id] = ops[i].value; + } + } break; } - case OpType.OpIfMask: { + case OpType.IfMask: { nesting++; stack.push(new State()); const cur = stack[nesting]; @@ -555,7 +670,7 @@ ${this.functions[i]} } break; } - case OpType.OpElseMask: { + case OpType.ElseMask: { // 0 is always uniform true so the else will never be taken. const cur = stack[nesting]; if (op.value == 0) { @@ -566,24 +681,25 @@ ${this.functions[i]} } break; } - case OpType.OpIfLid: { + case OpType.IfLid: { nesting++; stack.push(new State()); const cur = stack[nesting]; cur.copy(stack[nesting-1]); cur.header = i; // All invocations with subgroup invocation id less than op.value are active. - cur.activeMask &= getReplicatedMask(getMask(op.value), size, this.invocations); + cur.activeMask &= getReplicatedMask(getMask(op.value), subgroupSize, this.invocations); break; } - case OpType.OpElseLid: { + case OpType.ElseLid: { const prev = stack[nesting-1]; // All invocations with a subgroup invocation id greater or equal to op.value are active. stack[nesting].activeMask = prev.activeMask; - stack[nesting].activeMask &= ~getReplicatedMask(getMask(op.value), size, this.invocations); + stack[nesting].activeMask &= ~getReplicatedMask(getMask(op.value), subgroupSize, this.invocations); break; } - case OpType.OpIfLoopCount: { + case OpType.IfLoopCount: { + // Branch based on the subgroup invocation id == loop iteration. let n = nesting; while (!stack[n].isLoop) { n--; @@ -596,27 +712,32 @@ ${this.functions[i]} cur.header = i; cur.isLoop = 0; cur.isSwitch = 0; - cur.activeMask &= getReplicatedMask(BigInt(1 << stack[n].tripCount), size, this.invocations); + cur.activeMask &= getReplicatedMask(BigInt(1 << stack[n].tripCount), subgroupSize, this.invocations); break; } - case OpType.OpElseLoopCount: { + case OpType.ElseLoopCount: { + // Execute the else of the loop count conditional. It includes all + // invocations whose subgroup invocation id does not match the + // current iteration count. let n = nesting; while (!stack[n].isLoop) { n--; } stack[nesting].activeMask = stack[nesting-1].activeMask; - stack[nesting].activeMask &= ~getReplicatedMask(BigInt(1 << stack[n].tripCount), size, this.invocations); - break; + stack[nesting].activeMask &= ~getReplicatedMask(BigInt(1 << stack[n].tripCount), subgroupSize, this.invocations); + break; } - case OpType.OpEndIf: { + case OpType.EndIf: { + // End the current if. nesting--; stack.pop(); break; } - case OpType.OpForUniform: { + case OpType.ForUniform: { + // New uniform for loop. nesting++; - loopNesting++; + loopNesting++; stack.push(new State()); const cur = stack[nesting]; cur.header = i; @@ -624,12 +745,12 @@ ${this.functions[i]} cur.activeMask = stack[nesting-1].activeMask; break; } - case OpType.OpEndForUniform: { + case OpType.EndForUniform: { + // Determine which invocations have another iteration of the loop to execute. const cur = stack[nesting]; cur.tripCount++; cur.activeMask |= stack[nesting].continueMask; cur.continueMask = 0n; - // Loop if there are any invocations left with iterations to perform. if (cur.tripCount < this.ops[cur.header].value && any(cur.activeMask)) { i = cur.header + 1; @@ -641,9 +762,10 @@ ${this.functions[i]} } break; } - case OpType.OpBreak: { - var n = nesting; - var mask: bigint = stack[nesting].activeMask; + case OpType.Break: { + // Remove this active mask from all stack entries for the current loop/switch. + let n = nesting; + let mask: bigint = stack[nesting].activeMask; while (true) { stack[n].activeMask &= ~mask; if (stack[n].isLoop || stack[n].isSwitch) { @@ -654,9 +776,11 @@ ${this.functions[i]} } break; } - case OpType.OpContinue: { - var n = nesting; - var mask: bigint = stack[nesting].activeMask; + case OpType.Continue: { + // Remove this active mask from stack entries in this loop. + // Add this mask to the loop's continue mask for the next iteration. + let n = nesting; + let mask: bigint = stack[nesting].activeMask; while (true) { stack[n].activeMask &= ~mask; if (stack[n].isLoop) { @@ -667,6 +791,17 @@ ${this.functions[i]} } break; } + case OpType.Return: { + // Remove this active mask from all stack entries for this function. + let mask: bigint = stack[nesting].activeMask; + for (var n = nesting; n >= 0; n--) { + stack[n].activeMask &= ~mask; + if (stack[n].isCall) { + break; + } + } + break; + } default: { unreachable(`Unhandled op ${op.type}`); } @@ -698,12 +833,129 @@ ${this.functions[i]} } /** @returns a randomized program */ - public generate(): string { + public generate() { while (this.ops.length < this.minCount) { this.pickOp(1); } + } + + /** + * Equivalent to: + * + * ballot(); // fully uniform + * if (inputs[1] == 1) { + * ballot(); // fullly uniform + * for (var i = 0; i < 3; i++) { + * ballot(); // Simulation expects fully uniform, WGSL does not. + * if (testBit(vec4u(0xaaaaaaaa,0xaaaaaaa,0xaaaaaaaa,0xaaaaaaaa), subgroup_id)) { + * ballot(); // non-uniform + * continue; + * } + * ballot(); // non-uniform + * } + * ballot(); // fully uniform + * } + * ballot(); // fully uniform + */ + public predefinedProgram1() { + // Set the mask for index 1 + this.masks[4*1 + 0] = 0xaaaaaaaa + this.masks[4*1 + 1] = 0xaaaaaaaa + this.masks[4*1 + 2] = 0xaaaaaaaa + this.masks[4*1 + 3] = 0xaaaaaaaa + + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.IfMask, 0)); + + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.ForUniform, 3)); + + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.IfMask, 1)); + + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.Continue, 0)); + + this.ops.push(new Op(OpType.EndIf, 0)); + + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.EndForUniform, 0)); + + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.EndIf, 0)); + + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + } + + /** + * Equivalent to: + * + * ballot(); // uniform + * if (subgroup_id < 16) { + * ballot(); // 0xffff + * if (testbit(vec4u(0x00ff00ff,00ff00ff,00ff00ff,00ff00ff), subgroup_id)) { + * ballot(); // 0xff + * if (inputs[1] == 1) { + * ballot(); // 0xff + * } + * ballot(); // 0xff + * } else { + * ballot(); // 0xF..0000 + * return; + * } + * ballot; // 0xffff + * + * In this program, subgroups larger than 16 invocations diverge at the first if. + * Subgroups larger than 8 diverge at the second if. + * No divergence at the third if. + * The else of the first if returns, so the final ballot is only uniform for subgroups <= 16. + */ + public predefinedProgram2() { + // Set the mask for index 1 + this.masks[4*1 + 0] = 0x00ff00ff + this.masks[4*1 + 1] = 0x00ff00ff + this.masks[4*1 + 2] = 0x00ff00ff + this.masks[4*1 + 3] = 0x00ff00ff + + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.IfLid, 16)); + + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.IfMask, 1)); + + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.IfMask, 0)); + + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.EndIf, 0)); + + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.EndIf, 0)); + + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.ElseLid, 16)); + + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.Return, 16)); + + this.ops.push(new Op(OpType.EndIf, 16)); - return this.genCode(); + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); } }; From 59fd1b1d17edf190db03eeaf55b617b8bb46b9b9 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Mon, 14 Aug 2023 15:24:49 -0400 Subject: [PATCH 04/32] More implementation * Very basic result checking --- .../reconvergence/reconvergence.spec.ts | 127 +++++++++++++++++- .../shader/execution/reconvergence/util.ts | 84 +++++++++--- 2 files changed, 187 insertions(+), 24 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index 895e5e4b6b5f..9c9f45432053 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -3,20 +3,93 @@ https://github.com/KhronosGroup/VK-GL-CTS/blob/main/external/vulkancts/modules/v import { makeTestGroup } from '../../../../common/framework/test_group.js'; import { GPUTest } from '../../../gpu_test.js'; -import { iterRange, unreachable } from '../../../../common/util/util.js'; +import { + assert, + iterRange, + TypedArrayBufferViewConstructor, + unreachable +} from '../../../../common/util/util.js'; import { Style, Program, generateSeeds } from './util.js' export const g = makeTestGroup(GPUTest); +/** + * @returns The population count of input. + */ +function popcount(input: number): number { + let n = input; + n = n - ((n >> 1) & 0x55555555) + n = (n & 0x33333333) + ((n >> 2) & 0x33333333) + return ((n + (n >> 4) & 0xF0F0F0F) * 0x1010101) >> 24 +} + +class SizeReference { + private x: number; + constructor(n: number = 0) { + this.x = n; + } + set value(n : number) { + this.x = n; + } + get value(): number { + return this.x; + } +}; + +/** + * Checks that subgroup size reported by the shader is consistent. + * + * @param data GPUBuffer that stores the builtin value and uniform ballot count. + * @param min The device reported minimum subgroup size + * @param max The device reported maximum subgroup size + * + * @returns an error if either the builtin value or ballot count is outside [min, max], + * not a a power of 2, or they do not match. + */ +function checkSubgroupSizeConsistency(data: Uint32Array, min: number, max: number, sizeRef: SizeReference): Error | undefined { + const builtin: number = data[0]; + const ballot: number = data[1]; + sizeRef.value = builtin; + if (popcount(builtin) != 1) + return new Error(`Subgroup size builtin value (${builtin}) is not a power of two`); + if (builtin < min) + return new Error(`Subgroup size builtin value (${builtin}) is less than device minimum ${min}`); + if (max < builtin) + return new Error(`Subgroup size builtin value (${builtin}) is greater than device maximum ${max}`); + + if (popcount(ballot) != 1) + return new Error(`Subgroup size ballot value (${builtin}) is not a power of two`); + if (ballot < min) + return new Error(`Subgroup size ballot value (${ballot}) is less than device minimum ${min}`); + if (max < ballot) + return new Error(`Subgroup size ballot value (${ballot}) is greater than device maximum ${max}`); + + if (builtin != ballot) { + return new Error(`Subgroup size mismatch: + - builtin value = ${builtin} + - ballot = ${ballot} +`); + } + return undefined; +} + function testProgram(t: GPUTest, program: Program) { let wgsl = program.genCode(); console.log(wgsl); - let num = program.simulate(true, 16); - console.log(`Max locations = ${num}`); + // TODO: query the device + const minSubgroupSize = 4; + const maxSubgroupSize = 128; - num = program.simulate(true, 32); - console.log(`Max locations = ${num}`); + let numLocs = 0; + const locMap = new Map(); + for (var size = minSubgroupSize; size <= maxSubgroupSize; size *= 2) { + let num = program.simulate(true, size); + locMap.set(size, num); + numLocs = Math.max(num, numLocs); + } + // Add 1 to ensure there are no extraneous writes. + numLocs++; const pipeline = t.device.createComputePipeline({ layout: 'auto', @@ -35,7 +108,8 @@ function testProgram(t: GPUTest, program: Program) { t.trackForCleanup(inputBuffer); const ballotBuffer = t.device.createBuffer({ - size: 128 * 4, // TODO: FIXME + // Each location stores 16 bytes per invocation. + size: numLocs * program.invocations * 4 * 4, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC, }); t.trackForCleanup(ballotBuffer); @@ -46,6 +120,12 @@ function testProgram(t: GPUTest, program: Program) { ); t.trackForCleanup(locationBuffer); + const sizeBuffer = t.makeBufferWithContents( + new Uint32Array([...iterRange(2, x => 0)]), + GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC + ); + t.trackForCleanup(sizeBuffer); + const bindGroup = t.device.createBindGroup({ layout: pipeline.getBindGroupLayout(0), entries: [ @@ -67,6 +147,12 @@ function testProgram(t: GPUTest, program: Program) { buffer: locationBuffer }, }, + { + binding: 3, + resource: { + buffer: sizeBuffer + }, + }, ], }); @@ -77,6 +163,29 @@ function testProgram(t: GPUTest, program: Program) { pass.dispatchWorkgroups(1,1,1); pass.end(); t.queue.submit([encoder.finish()]); + + const actualSize = new SizeReference(0); + + t.expectGPUBufferValuesPassCheck( + sizeBuffer, + a => checkSubgroupSizeConsistency(a, minSubgroupSize, maxSubgroupSize, actualSize), + { + srcByteOffset: 0, + type: Uint32Array, + typedLength: 2, + method: 'copy', + mode: 'fail', + } + ); + + console.log(`Actual subgroup size = ${actualSize.value}`); + for (var i = minSubgroupSize; i <= maxSubgroupSize; i *= 2) { + console.log(` Simulated locs for size ${i} = ${locMap.get(i)}`); + } + program.sizeRefData(locMap.get(actualSize.value)); + console.log(`RefData length = ${program.refData.length}`); + //let num = program.simulate(false, actualSize.value); + //assert(num === locMap.get(actualSize.value)); } g.test('predefined_reconvergence') @@ -86,6 +195,9 @@ g.test('predefined_reconvergence') .combine('test', [...iterRange(3, x => x)] as const) .beginSubcases() ) + //.beforeAllSubcases(t => { + // t.selectDeviceOrSkipTestCase({ requiredFeatures: ['chromium-experimental-subgroups'] }); + //}) .fn(t => { const invocations = 128; // t.device.limits.maxSubgroupSize; @@ -129,6 +241,9 @@ g.test('random_reconvergence') }) .beginSubcases() ) + //.beforeAllSubcases(t => { + // t.selectDeviceOrSkipTestCase({requiredFeatures: ['chromium-experimental-subgroups']}); + //}) .fn(t => { const invocations = 128; // t.device.limits.maxSubgroupSize; diff --git a/src/webgpu/shader/execution/reconvergence/util.ts b/src/webgpu/shader/execution/reconvergence/util.ts index 2b00a1c800a2..4e9f2c7323ec 100644 --- a/src/webgpu/shader/execution/reconvergence/util.ts +++ b/src/webgpu/shader/execution/reconvergence/util.ts @@ -16,6 +16,25 @@ function getReplicatedMask(submask: bigint, size: number, total: number = 128): return mask; } +/** + * Produce the subgroup mask for local invocation |id| within |fullMask| + * + * @param fullMask The active mask for the full workgroup + * @param size The subgroup size + * @param id The local invocation index + * + * @returns A Uint32Array with 4 elements containing the subgroup mask. + */ +function getSubgroupMask(fullMask: bigint, size: number, id: number): Uint32Array { + const arr: Uint32Array = new Uint32Array(4); + let mask: bigint = fullMask >> BigInt((id / size) * size); + arr[0] = Number(BigInt.asUintN(32, mask)); + arr[1] = Number(BigInt.asUintN(32, mask >> 32n)); + arr[2] = Number(BigInt.asUintN(32, mask >> 64n)); + arr[3] = Number(BigInt.asUintN(32, mask >> 96n)); + return arr; +} + /** @returns true if bit |bit| is set to 1. */ function testBit(mask: bigint, bit: number): boolean { return ((mask >> BigInt(bit)) & 0x1n) == 1n; @@ -122,6 +141,7 @@ export class Program { private functions: string[]; private indents: number[]; private storeBase: number; + public refData: Uint32Array; /** * constructor @@ -158,6 +178,7 @@ export class Program { this.indents = []; this.indents.push(2); this.storeBase = 0x10000; + this.refData = new Uint32Array(); } /** @returns A random float between 0 and 1 */ @@ -387,8 +408,7 @@ export class Program { switch (op.type) { case OpType.Ballot: { this.genIndent(); - // TODO: FIXME - this.addCode(`\/\/ ballots[stride * output_loc + local_id] = subgroupBallot();\n`); + this.addCode(`ballots[stride * output_loc + local_id] = subgroupBallot();\n`); this.genIndent(); this.addCode(`output_loc++;\n`); break; @@ -484,7 +504,7 @@ export class Program { } let code: string = ` -//enable chromium_experimental_subgroups; +enable chromium_experimental_subgroups; const stride = ${this.invocations}; @@ -494,6 +514,8 @@ var inputs : array; var ballots : array; @group(0) @binding(2) var locations : array; +@group(0) @binding(3) +var size : array; var subgroup_id : u32; var local_id : u32; @@ -502,14 +524,24 @@ var output_loc : u32 = 0; @compute @workgroup_size(${this.invocations},1,1) fn main( @builtin(local_invocation_index) lid : u32, - //@builtin(subgroup_invocation_id) sid : u32, + @builtin(subgroup_invocation_id) sid : u32, + @builtin(subgroup_size) sg_size : u32, ) { _ = inputs[0]; _ = ballots[0]; _ = locations[0]; - subgroup_id = 0; // sid; + subgroup_id = sid; local_id = lid; + // Store the subgroup size from the built-in value and ballot to check for + // consistency. + let b = subgroupBallot(); + if lid == 0 { + size[0] = sg_size; + let count = countOneBits(b); + size[1] = count.x + count.y + count.z + count.w; + } + f0(); } @@ -565,6 +597,11 @@ ${this.functions[i]} this.functions[this.curFunc] += code; } + public sizeRefData(locs: number) { + this.refData = new Uint32Array(locs * 4 * this.invocations); + this.refData.fill(0); + } + // TODO: Reconvergence guarantees are not as strong as this simulation. public simulate(countOnly: boolean, subgroupSize: number): number { class State { @@ -605,6 +642,7 @@ ${this.functions[i]} var locs = new Array(this.invocations); locs.fill(0); + console.log(`Simulating subgroup size = ${subgroupSize}`); var i = 0; while (i < this.ops.length) { const op = this.ops[i]; @@ -633,16 +671,22 @@ ${this.functions[i]} for (var id = 0; id < this.invocations; id++) { if (testBit(curMask, id)) { - if (countOnly) { - locs[id]++; - } else { - // if (op.caseValue == 1) { - // // Emit a magic value to indicate that we shouldn't validate this ballot - // ref[(outLoc[id]++)*invocationStride + id] = bitsetToU64(0x12345678, subgroupSize, id); - // } else { - // ref[(outLoc[id]++)*invocationStride + id] = bitsetToU64(stateStack[nesting].activeMask, subgroupSize, id); - // } + if (!countOnly) { + if (!op.uniform) { + // Emit a magic value to indicate that we shouldn't validate this ballot + this.refData[4 * locs[id] * this.invocations + id + 0] = 0x12345678 + this.refData[4 * locs[id] * this.invocations + id + 1] = 0x12345678 + this.refData[4 * locs[id] * this.invocations + id + 2] = 0x12345678 + this.refData[4 * locs[id] * this.invocations + id + 3] = 0x12345678 + } else { + let mask = getSubgroupMask(curMask, subgroupSize, id); + this.refData[4 * locs[id] * this.invocations + id + 0] = mask[0]; + this.refData[4 * locs[id] * this.invocations + id + 1] = mask[1]; + this.refData[4 * locs[id] * this.invocations + id + 2] = mask[2]; + this.refData[4 * locs[id] * this.invocations + id + 3] = mask[3]; + } } + locs[id]++; } } break; @@ -650,10 +694,13 @@ ${this.functions[i]} case OpType.Store: { for (var id = 0; id < 128; id++) { if (testBit(stack[nesting].activeMask, id)) { - if (countOnly) - locs[id]++; - //else - // ref[(outLoc[id]++)*invocationStride + id] = ops[i].value; + if (!countOnly) { + this.refData[4 * locs[id]++ * this.invocations + id + 0] = op.value; + this.refData[4 * locs[id]++ * this.invocations + id + 1] = 0; + this.refData[4 * locs[id]++ * this.invocations + id + 2] = 0; + this.refData[4 * locs[id]++ * this.invocations + id + 3] = 0; + } + locs[id]++; } } break; @@ -815,6 +862,7 @@ ${this.functions[i]} for (var j = 0; j < this.invocations; j++) { maxLoc = Math.max(maxLoc, locs[j]); } + console.log(`Max location = ${maxLoc}\n`); return maxLoc; } From 1136df0fa09be4773528d691cc78169025778d25 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Mon, 14 Aug 2023 17:46:37 -0400 Subject: [PATCH 05/32] switch readback style --- .../reconvergence/reconvergence.spec.ts | 58 +++++++++---------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index 9c9f45432053..bbf331607f52 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -23,19 +23,6 @@ function popcount(input: number): number { return ((n + (n >> 4) & 0xF0F0F0F) * 0x1010101) >> 24 } -class SizeReference { - private x: number; - constructor(n: number = 0) { - this.x = n; - } - set value(n : number) { - this.x = n; - } - get value(): number { - return this.x; - } -}; - /** * Checks that subgroup size reported by the shader is consistent. * @@ -46,10 +33,9 @@ class SizeReference { * @returns an error if either the builtin value or ballot count is outside [min, max], * not a a power of 2, or they do not match. */ -function checkSubgroupSizeConsistency(data: Uint32Array, min: number, max: number, sizeRef: SizeReference): Error | undefined { +function checkSubgroupSizeConsistency(data: Uint32Array, min: number, max: number): Error | undefined { const builtin: number = data[0]; const ballot: number = data[1]; - sizeRef.value = builtin; if (popcount(builtin) != 1) return new Error(`Subgroup size builtin value (${builtin}) is not a power of two`); if (builtin < min) @@ -73,7 +59,7 @@ function checkSubgroupSizeConsistency(data: Uint32Array, min: number, max: numbe return undefined; } -function testProgram(t: GPUTest, program: Program) { +async function testProgram(t: GPUTest, program: Program) { let wgsl = program.genCode(); console.log(wgsl); @@ -164,28 +150,42 @@ function testProgram(t: GPUTest, program: Program) { pass.end(); t.queue.submit([encoder.finish()]); - const actualSize = new SizeReference(0); + console.log(`READBACK NOW`); - t.expectGPUBufferValuesPassCheck( + const sizeReadback = await t.readGPUBufferRangeTyped( sizeBuffer, - a => checkSubgroupSizeConsistency(a, minSubgroupSize, maxSubgroupSize, actualSize), { srcByteOffset: 0, type: Uint32Array, typedLength: 2, method: 'copy', - mode: 'fail', } ); - - console.log(`Actual subgroup size = ${actualSize.value}`); - for (var i = minSubgroupSize; i <= maxSubgroupSize; i *= 2) { - console.log(` Simulated locs for size ${i} = ${locMap.get(i)}`); - } - program.sizeRefData(locMap.get(actualSize.value)); - console.log(`RefData length = ${program.refData.length}`); - //let num = program.simulate(false, actualSize.value); - //assert(num === locMap.get(actualSize.value)); + console.log(`POST READBACK`); + const sizeData: Uint32Array = sizeReadback.data; + const actualSize = sizeData[0]; + console.log(`Actual subgroup size = ${actualSize}`); + //t.expectOK(checkSubgroupSizeConsistency(sizeData, minSubgroupSize, maxSubgroupSize)); + + //t.expectGPUBufferValuesPassCheck( + // sizeBuffer, + // a => checkSubgroupSizeConsistency(a, minSubgroupSize, maxSubgroupSize, actualSize), + // { + // srcByteOffset: 0, + // type: Uint32Array, + // typedLength: 2, + // method: 'copy', + // mode: 'fail', + // } + //); + + //for (var i = minSubgroupSize; i <= maxSubgroupSize; i *= 2) { + // console.log(` Simulated locs for size ${i} = ${locMap.get(i)}`); + //} + //program.sizeRefData(locMap.get(actualSize)); + //console.log(`RefData length = ${program.refData.length}`); + //let num = program.simulate(false, actualSize); + //assert(num === locMap.get(actualSize)); } g.test('predefined_reconvergence') From e6568f79412e00e42b746ca69165adbc1f76ff14 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Mon, 14 Aug 2023 18:07:15 -0400 Subject: [PATCH 06/32] fix sync issue --- .../shader/execution/reconvergence/reconvergence.spec.ts | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index bbf331607f52..75f4aeb17984 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -150,8 +150,6 @@ async function testProgram(t: GPUTest, program: Program) { pass.end(); t.queue.submit([encoder.finish()]); - console.log(`READBACK NOW`); - const sizeReadback = await t.readGPUBufferRangeTyped( sizeBuffer, { @@ -161,11 +159,10 @@ async function testProgram(t: GPUTest, program: Program) { method: 'copy', } ); - console.log(`POST READBACK`); const sizeData: Uint32Array = sizeReadback.data; const actualSize = sizeData[0]; console.log(`Actual subgroup size = ${actualSize}`); - //t.expectOK(checkSubgroupSizeConsistency(sizeData, minSubgroupSize, maxSubgroupSize)); + t.expectOK(checkSubgroupSizeConsistency(sizeData, minSubgroupSize, maxSubgroupSize)); //t.expectGPUBufferValuesPassCheck( // sizeBuffer, @@ -244,11 +241,11 @@ g.test('random_reconvergence') //.beforeAllSubcases(t => { // t.selectDeviceOrSkipTestCase({requiredFeatures: ['chromium-experimental-subgroups']}); //}) - .fn(t => { + .fn(async t => { const invocations = 128; // t.device.limits.maxSubgroupSize; let program: Program = new Program(t.params.style, t.params.seed, invocations); program.generate(); - testProgram(t, program); + await testProgram(t, program); }); From 354e66c6bc887821f43d5d4aa1326f6083371450 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Thu, 17 Aug 2023 09:41:43 -0400 Subject: [PATCH 07/32] More implementation * Switch from var to let in most places * Add buffer checking * Fix some bugs: * all * simulation of IfId --- .../reconvergence/reconvergence.spec.ts | 147 +++++-- .../shader/execution/reconvergence/util.ts | 399 ++++++++++++++---- 2 files changed, 435 insertions(+), 111 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index 75f4aeb17984..d19dda2cfcd3 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -9,7 +9,12 @@ import { TypedArrayBufferViewConstructor, unreachable } from '../../../../common/util/util.js'; -import { Style, Program, generateSeeds } from './util.js' +import { + hex, + Style, + Program, + generateSeeds +} from './util.js' export const g = makeTestGroup(GPUTest); @@ -59,6 +64,20 @@ function checkSubgroupSizeConsistency(data: Uint32Array, min: number, max: numbe return undefined; } +/** + * Checks the mapping of subgroup_invocation_id to local_invocation_index + */ +function checkIds(data: Uint32Array, subgroupSize: number): Error | undefined { + for (let i = 0; i < data.length; i++) { + if (data[i] !== (i % subgroupSize)) { + return Error(`subgroup_invocation_id does not map as assumed to local_invocation_index: +location_invocation_index = ${i} +subgroup_invocation_id = ${data[i]}`); + } + } + return undefined; +} + async function testProgram(t: GPUTest, program: Program) { let wgsl = program.genCode(); console.log(wgsl); @@ -69,7 +88,7 @@ async function testProgram(t: GPUTest, program: Program) { let numLocs = 0; const locMap = new Map(); - for (var size = minSubgroupSize; size <= maxSubgroupSize; size *= 2) { + for (let size = minSubgroupSize; size <= maxSubgroupSize; size *= 2) { let num = program.simulate(true, size); locMap.set(size, num); numLocs = Math.max(num, numLocs); @@ -87,31 +106,42 @@ async function testProgram(t: GPUTest, program: Program) { }, }); + // Inputs have a value equal to their index. const inputBuffer = t.makeBufferWithContents( new Uint32Array([...iterRange(128, x => x)]), GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST ); t.trackForCleanup(inputBuffer); - const ballotBuffer = t.device.createBuffer({ - // Each location stores 16 bytes per invocation. - size: numLocs * program.invocations * 4 * 4, - usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC, - }); + // Each location stores 4 uint32s per invocation. + const ballotLength = numLocs * program.invocations * 4; + const ballotBuffer = t.makeBufferWithContents( + new Uint32Array([...iterRange(ballotLength, x => 0)]), + GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC + ); t.trackForCleanup(ballotBuffer); + const locationLength = program.invocations; const locationBuffer = t.makeBufferWithContents( - new Uint32Array([...iterRange(program.invocations, x => 0)]), + new Uint32Array([...iterRange(locationLength, x => 0)]), GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC ); t.trackForCleanup(locationBuffer); + const sizeLength = 2; const sizeBuffer = t.makeBufferWithContents( - new Uint32Array([...iterRange(2, x => 0)]), + new Uint32Array([...iterRange(sizeLength, x => 0)]), GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC ); t.trackForCleanup(sizeBuffer); + const idLength = program.invocations; + const idBuffer = t.makeBufferWithContents( + new Uint32Array([...iterRange(idLength, x => 0)]), + GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC + ); + t.trackForCleanup(idBuffer); + const bindGroup = t.device.createBindGroup({ layout: pipeline.getBindGroupLayout(0), entries: [ @@ -139,6 +169,12 @@ async function testProgram(t: GPUTest, program: Program) { buffer: sizeBuffer }, }, + { + binding: 4, + resource: { + buffer: idBuffer + }, + }, ], }); @@ -150,52 +186,86 @@ async function testProgram(t: GPUTest, program: Program) { pass.end(); t.queue.submit([encoder.finish()]); + // The simulaton assumes subgroup_invocation_id maps directly to local_invocation_index. + // That is: + // SID: 0, 1, 2, ..., SGSize-1, 0, ..., SGSize-1, ... + // LID: 0, 1, 2, ..., 128 + // Generate a warning if this is not true of the device. + // This mapping is not guaranteed by APIs (Vulkan particularly), but seems reliable + // (for linear workgroups at least). const sizeReadback = await t.readGPUBufferRangeTyped( sizeBuffer, { srcByteOffset: 0, type: Uint32Array, - typedLength: 2, + typedLength: sizeLength, method: 'copy', } ); const sizeData: Uint32Array = sizeReadback.data; const actualSize = sizeData[0]; - console.log(`Actual subgroup size = ${actualSize}`); t.expectOK(checkSubgroupSizeConsistency(sizeData, minSubgroupSize, maxSubgroupSize)); - //t.expectGPUBufferValuesPassCheck( - // sizeBuffer, - // a => checkSubgroupSizeConsistency(a, minSubgroupSize, maxSubgroupSize, actualSize), - // { - // srcByteOffset: 0, - // type: Uint32Array, - // typedLength: 2, - // method: 'copy', - // mode: 'fail', - // } - //); - - //for (var i = minSubgroupSize; i <= maxSubgroupSize; i *= 2) { - // console.log(` Simulated locs for size ${i} = ${locMap.get(i)}`); - //} - //program.sizeRefData(locMap.get(actualSize)); - //console.log(`RefData length = ${program.refData.length}`); - //let num = program.simulate(false, actualSize); - //assert(num === locMap.get(actualSize)); + program.sizeRefData(locMap.get(actualSize)); + let num = program.simulate(false, actualSize); + + const idReadback = await t.readGPUBufferRangeTyped( + idBuffer, + { + srcByteOffset: 0, + type: Uint32Array, + typedLength: idLength, + method: 'copy', + } + ); + const idData = idReadback.data; + t.expectOK(checkIds(idData, actualSize), { mode: 'warn' }); + + const locationReadback = await t.readGPUBufferRangeTyped( + locationBuffer, + { + srcByteOffset: 0, + type: Uint32Array, + typedLength: locationLength, + method: 'copy', + } + ); + const locationData = locationReadback.data; + + const ballotReadback = await t.readGPUBufferRangeTyped( + ballotBuffer, + { + srcByteOffset: 0, + type: Uint32Array, + typedLength: ballotLength, + method: 'copy', + } + ); + const ballotData = ballotReadback.data; + + console.log(`Ballots`); + for (let id = 0; id < program.invocations; id++) { + console.log(` id[${id}]:`); + for (let loc = 0; loc < numLocs; loc++) { + const idx = 4 * (program.invocations * loc + id); + console.log(` loc[${loc}] = (${hex(ballotData[idx+3])},${hex(ballotData[idx+2])},${hex(ballotData[idx+1])},${hex(ballotData[idx])}), (${ballotData[idx+3]},${ballotData[idx+2]},${ballotData[idx+1]},${ballotData[idx]})`); + } + } + + t.expectOK(program.checkResults(ballotData, locationData, actualSize, num)); } g.test('predefined_reconvergence') .desc(`Test reconvergence using some predefined programs`) .params(u => u - .combine('test', [...iterRange(3, x => x)] as const) + .combine('test', [...iterRange(4, x => x)] as const) .beginSubcases() ) //.beforeAllSubcases(t => { // t.selectDeviceOrSkipTestCase({ requiredFeatures: ['chromium-experimental-subgroups'] }); //}) - .fn(t => { + .fn(async t => { const invocations = 128; // t.device.limits.maxSubgroupSize; let program: Program; @@ -215,13 +285,18 @@ g.test('predefined_reconvergence') program.predefinedProgram2(); break; } + case 3: { + program = new Program(Style.Maximal, 1, invocations); + program.predefinedProgram3(); + break; + } default: { program = new Program(); unreachable('Unhandled testcase'); } } - testProgram(t, program); + await testProgram(t, program); }); g.test('random_reconvergence') @@ -230,12 +305,6 @@ g.test('random_reconvergence') u .combine('style', [Style.Workgroup, Style.Subgroup, Style.Maximal] as const) .combine('seed', generateSeeds(5)) - .filter(u => { - if (u.style == Style.Workgroup) { - return true; - } - return false; - }) .beginSubcases() ) //.beforeAllSubcases(t => { diff --git a/src/webgpu/shader/execution/reconvergence/util.ts b/src/webgpu/shader/execution/reconvergence/util.ts index 4e9f2c7323ec..40db77553653 100644 --- a/src/webgpu/shader/execution/reconvergence/util.ts +++ b/src/webgpu/shader/execution/reconvergence/util.ts @@ -1,6 +1,10 @@ import { assert, unreachable } from '../../../../common/util/util.js'; import { PRNG } from '../../../util/prng.js'; +export function hex(n: number): string { + return n.toString(16); +} + /** @returns A bitmask where bits [0,size) are 1s. */ function getMask(size: number): bigint { return (1n << BigInt(size)) - 1n; @@ -8,9 +12,9 @@ function getMask(size: number): bigint { /** @returns A bitmask where submask is repeated every size bits for total bits. */ function getReplicatedMask(submask: bigint, size: number, total: number = 128): bigint { - const reps = total / size; - var mask: bigint = submask; - for (var i = 1; i < reps; i++) { + const reps = Math.floor(total / size); + let mask: bigint = submask; + for (let i = 1; i < reps; i++) { mask |= (mask << BigInt(size)); } return mask; @@ -25,9 +29,11 @@ function getReplicatedMask(submask: bigint, size: number, total: number = 128): * * @returns A Uint32Array with 4 elements containing the subgroup mask. */ -function getSubgroupMask(fullMask: bigint, size: number, id: number): Uint32Array { +function getSubgroupMask(fullMask: bigint, size: number, id: number = 0): Uint32Array { const arr: Uint32Array = new Uint32Array(4); - let mask: bigint = fullMask >> BigInt((id / size) * size); + const subgroup_id: number = Math.floor(id / size); + const shift: number = subgroup_id * size; + let mask: bigint = (fullMask >> BigInt(shift)) & getMask(size); arr[0] = Number(BigInt.asUintN(32, mask)); arr[1] = Number(BigInt.asUintN(32, mask >> 32n)); arr[2] = Number(BigInt.asUintN(32, mask >> 64n)); @@ -47,7 +53,7 @@ function any(value: bigint): boolean { /** @returns true if all bits in value from [0, size) are 1. */ function all(value: bigint, size: number): boolean { - return value === ((1n << BigInt(size) - 1n)); + return value === ((1n << BigInt(size)) - 1n); } export enum Style { @@ -82,8 +88,8 @@ enum OpType { ElseLoopCount, // if (subgroup_id < inputs[value]) - IfLid, - ElseLid, + IfId, + ElseId, // Break/continue Break, @@ -99,6 +105,29 @@ enum OpType { MAX, } +function serializeOpType(op: OpType): string { + switch (op) { + case OpType.Ballot: return 'Ballot'; + case OpType.Store: return 'Store'; + case OpType.IfMask: return 'IfMask'; + case OpType.ElseMask: return 'ElseMask'; + case OpType.EndIf: return 'EndIf'; + case OpType.IfLoopCount: return 'IfLoopCount'; + case OpType.ElseLoopCount: return 'ElseLoopCount'; + case OpType.IfId: return 'IfId'; + case OpType.ElseId: return 'ElseId'; + case OpType.Break: return 'Break'; + case OpType.Continue: return 'Continue'; + case OpType.ForUniform: return 'ForUniform'; + case OpType.EndForUniform: return 'EndForUniform'; + case OpType.Return: return 'Return'; + default: + unreachable('Unhandled op'); + break; + } + return ''; +} + enum IfType { Mask, Uniform, @@ -140,7 +169,7 @@ export class Program { private curFunc: number; private functions: string[]; private indents: number[]; - private storeBase: number; + private readonly storeBase: number; public refData: Uint32Array; /** @@ -154,8 +183,8 @@ export class Program { this.prng = new PRNG(seed); this.ops = []; this.style = style; - this.minCount = 5; // 30; - this.maxNesting = 5; // this.getRandomUint(70) + 30; // [30,100) + this.minCount = 30; + this.maxNesting = this.getRandomUint(40) + 20; //this.getRandomUint(70) + 30; // [30,100) this.nesting = 0; this.loopNesting = 0; this.loopNestingThisFunction = 0; @@ -166,7 +195,7 @@ export class Program { this.masks.push(0xffffffff); this.masks.push(0xffffffff); this.masks.push(0xffffffff); - for (var i = 1; i < this.numMasks; i++) { + for (let i = 1; i < this.numMasks; i++) { this.masks.push(this.getRandomUint(0xffffffff)); this.masks.push(this.getRandomUint(0xffffffff)); this.masks.push(this.getRandomUint(0xffffffff)); @@ -192,7 +221,7 @@ export class Program { } private pickOp(count : number) { - for (var i = 0; i < count; i++) { + for (let i = 0; i < count; i++) { this.genBallot(); if (this.nesting < this.maxNesting) { const r = this.getRandomUint(12); @@ -261,7 +290,7 @@ export class Program { !(this.ops[cur_length - 1].type == OpType.Ballot || (this.ops[cur_length-1].type == OpType.Store && this.ops[cur_length - 2].type == OpType.Ballot))) { // Perform a store with each ballot so the results can be correlated. - if (this.style != Style.Maximal) + //if (this.style != Style.Maximal) this.ops.push(new Op(OpType.Store, cur_length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); } @@ -294,7 +323,7 @@ export class Program { const lid = this.getRandomUint(128); if (type == IfType.Lid) { - this.ops.push(new Op(OpType.IfLid, lid)); + this.ops.push(new Op(OpType.IfId, lid)); } else if (type == IfType.LoopCount) { this.ops.push(new Op(OpType.IfLoopCount, 0)); } else { @@ -310,7 +339,7 @@ export class Program { const randElse = this.getRandomFloat(); if (randElse < 0.5) { if (type == IfType.Lid) { - this.ops.push(new Op(OpType.ElseLid, lid)); + this.ops.push(new Op(OpType.ElseId, lid)); } else if (type == IfType.LoopCount) { this.ops.push(new Op(OpType.ElseLoopCount, 0)); } else { @@ -319,7 +348,7 @@ export class Program { // Sometimes make the else identical to the if. if (randElse < 0.1 && beforeSize != afterSize) { - for (var i = beforeSize; i < afterSize; i++) { + for (let i = beforeSize; i < afterSize; i++) { const op = this.ops[i]; this.ops.push(new Op(op.type, op.value, op.uniform)); } @@ -401,10 +430,10 @@ export class Program { } public genCode(): string { - for (var i = 0; i < this.ops.length; i++) { + for (let i = 0; i < this.ops.length; i++) { const op = this.ops[i]; this.genIndent() - this.addCode(`// ops[${i}] = ${op.type}\n`); + this.addCode(`// ops[${i}] = ${serializeOpType(op.type)}\n`); switch (op.type) { case OpType.Ballot: { this.genIndent(); @@ -417,7 +446,7 @@ export class Program { this.genIndent(); this.addCode(`locations[local_id]++;\n`); this.genIndent(); - this.addCode(`ballots[stride * output_loc + local_id] = vec4u(${op.value},0,0,0);\n`); + this.addCode(`ballots[stride * output_loc + local_id] = vec4u(${op.value});\n`); this.genIndent(); this.addCode(`output_loc++;\n`); break; @@ -438,12 +467,12 @@ export class Program { const y = this.masks[4*idx+1]; const z = this.masks[4*idx+2]; const w = this.masks[4*idx+3]; - this.addCode(`if testBit(vec4u(${x},${y},${z},${w}), subgroup_id) {\n`); + this.addCode(`if testBit(vec4u(0x${hex(x)},0x${hex(y)},0x${hex(z)},0x${hex(w)}), subgroup_id) {\n`); } this.increaseIndent(); break; } - case OpType.IfLid: { + case OpType.IfId: { this.genIndent(); this.addCode(`if subgroup_id < inputs[${op.value}] {\n`); this.increaseIndent(); @@ -456,7 +485,7 @@ export class Program { break; } case OpType.ElseMask: - case OpType.ElseLid: + case OpType.ElseId: case OpType.ElseLoopCount: { this.decreaseIndent(); this.genIndent(); @@ -516,12 +545,14 @@ var ballots : array; var locations : array; @group(0) @binding(3) var size : array; +@group(0) @binding(4) +var ids : array; var subgroup_id : u32; var local_id : u32; var output_loc : u32 = 0; -@compute @workgroup_size(${this.invocations},1,1) +@compute @workgroup_size(stride,1,1) fn main( @builtin(local_invocation_index) lid : u32, @builtin(subgroup_invocation_id) sid : u32, @@ -532,6 +563,7 @@ fn main( _ = locations[0]; subgroup_id = sid; local_id = lid; + ids[lid] = sid; // Store the subgroup size from the built-in value and ballot to check for // consistency. @@ -559,7 +591,7 @@ fn testBit(mask : vec4u, id : u32) -> bool { } `; - for (var i = 0; i < this.functions.length; i++) { + for (let i = 0; i < this.functions.length; i++) { code += ` fn f${i}() { ${this.functions[i]} @@ -597,6 +629,11 @@ ${this.functions[i]} this.functions[this.curFunc] += code; } + /** + * Sizes the simulation buffer. + * + * The total size is # of invocations * |locs| * 4 (uint4 is written). + */ public sizeRefData(locs: number) { this.refData = new Uint32Array(locs * 4 * this.invocations); this.refData.fill(0); @@ -605,12 +642,19 @@ ${this.functions[i]} // TODO: Reconvergence guarantees are not as strong as this simulation. public simulate(countOnly: boolean, subgroupSize: number): number { class State { + // Active invocations activeMask: bigint; + // Invocations that rejoin at the head of a loop continueMask: bigint; + // Header index header: number; + // This state is a loop isLoop: boolean; + // Current trip count tripCount: number; + // This state is a call isCall: boolean; + // This state is a switch isSwitch: boolean; constructor() { @@ -633,24 +677,32 @@ ${this.functions[i]} this.isSwitch = other.isSwitch; } }; - var stack = new Array(); + for (let idx = 0; idx < this.ops.length; idx++) { + this.ops[idx].uniform = true; + } + + let stack = new Array(); stack.push(new State()); stack[0].activeMask = (1n << 128n) - 1n; - var nesting = 0; - var loopNesting = 0; - var locs = new Array(this.invocations); + let nesting = 0; + let loopNesting = 0; + let locs = new Array(this.invocations); locs.fill(0); - console.log(`Simulating subgroup size = ${subgroupSize}`); - var i = 0; + if (!countOnly) { + console.log(`Simulating subgroup size = ${subgroupSize}`); + } + let i = 0; while (i < this.ops.length) { const op = this.ops[i]; - console.log(`ops[${i}] = ${op.type}, nesting = ${nesting}`); - console.log(` mask = ${stack[nesting].activeMask.toString(16)}`); - //for (var j = 0; j <= nesting; j++) { - // console.log(` mask[${j}] = ${stack[j].activeMask.toString(16)}`); - //} + if (!countOnly) { + console.log(`ops[${i}] = ${serializeOpType(op.type)}, nesting = ${nesting}`); + console.log(` mask = ${stack[nesting].activeMask.toString(16)}`); + //for (let j = 0; j <= nesting; j++) { + // console.log(` mask[${j}] = ${stack[j].activeMask.toString(16)}`); + //} + } switch (op.type) { case OpType.Ballot: { const curMask = stack[nesting].activeMask; @@ -661,7 +713,7 @@ ${this.functions[i]} // Flag if this ballot is not subgroup uniform. if (this.style == Style.Subgroup) { - for (var id = 0; id < this.invocations; id += subgroupSize) { + for (let id = 0; id < this.invocations; id += subgroupSize) { const subgroupMask = (curMask >> BigInt(id)) & getMask(subgroupSize); if (subgroupMask != 0n && !all(subgroupMask, subgroupSize)) { op.uniform = false; @@ -669,21 +721,22 @@ ${this.functions[i]} } } - for (var id = 0; id < this.invocations; id++) { + for (let id = 0; id < this.invocations; id++) { if (testBit(curMask, id)) { if (!countOnly) { - if (!op.uniform) { - // Emit a magic value to indicate that we shouldn't validate this ballot - this.refData[4 * locs[id] * this.invocations + id + 0] = 0x12345678 - this.refData[4 * locs[id] * this.invocations + id + 1] = 0x12345678 - this.refData[4 * locs[id] * this.invocations + id + 2] = 0x12345678 - this.refData[4 * locs[id] * this.invocations + id + 3] = 0x12345678 - } else { + const idx = this.baseIndex(id, locs[id]); + if (op.uniform) { let mask = getSubgroupMask(curMask, subgroupSize, id); - this.refData[4 * locs[id] * this.invocations + id + 0] = mask[0]; - this.refData[4 * locs[id] * this.invocations + id + 1] = mask[1]; - this.refData[4 * locs[id] * this.invocations + id + 2] = mask[2]; - this.refData[4 * locs[id] * this.invocations + id + 3] = mask[3]; + this.refData[idx + 0] = mask[0]; + this.refData[idx + 1] = mask[1]; + this.refData[idx + 2] = mask[2]; + this.refData[idx + 3] = mask[3]; + } else { + // Emit a magic value to indicate that we shouldn't validate this ballot + this.refData[idx + 0] = 0x12345678 + this.refData[idx + 1] = 0x12345678 + this.refData[idx + 2] = 0x12345678 + this.refData[idx + 3] = 0x12345678 } } locs[id]++; @@ -692,13 +745,14 @@ ${this.functions[i]} break; } case OpType.Store: { - for (var id = 0; id < 128; id++) { + for (let id = 0; id < this.invocations; id++) { if (testBit(stack[nesting].activeMask, id)) { if (!countOnly) { - this.refData[4 * locs[id]++ * this.invocations + id + 0] = op.value; - this.refData[4 * locs[id]++ * this.invocations + id + 1] = 0; - this.refData[4 * locs[id]++ * this.invocations + id + 2] = 0; - this.refData[4 * locs[id]++ * this.invocations + id + 3] = 0; + const idx = this.baseIndex(id, locs[id]); + this.refData[idx + 0] = op.value; + this.refData[idx + 1] = op.value; + this.refData[idx + 2] = op.value; + this.refData[idx + 3] = op.value; } locs[id]++; } @@ -711,9 +765,13 @@ ${this.functions[i]} const cur = stack[nesting]; cur.copy(stack[nesting-1]); cur.header = i; + cur.isLoop = 0; + cur.isSwitch = 0; // O is always uniform true. if (op.value != 0) { - cur.activeMask &= this.getValueMask(op.value); + let subMask = this.getValueMask(op.value); + subMask &= getMask(subgroupSize); + cur.activeMask &= getReplicatedMask(subMask, subgroupSize, this.invocations); } break; } @@ -724,21 +782,26 @@ ${this.functions[i]} cur.activeMask = 0n; } else { const prev = stack[nesting-1]; - cur.activeMask = prev.activeMask & ~this.getValueMask(op.value); + let subMask = this.getValueMask(op.value); + subMask &= getMask(subgroupSize); + cur.activeMask = prev.activeMask; + cur.activeMask &= ~getReplicatedMask(subMask, subgroupSize, this.invocations); } break; } - case OpType.IfLid: { + case OpType.IfId: { nesting++; stack.push(new State()); const cur = stack[nesting]; cur.copy(stack[nesting-1]); cur.header = i; + cur.isLoop = 0; + cur.isSwitch = 0; // All invocations with subgroup invocation id less than op.value are active. cur.activeMask &= getReplicatedMask(getMask(op.value), subgroupSize, this.invocations); break; } - case OpType.ElseLid: { + case OpType.ElseId: { const prev = stack[nesting-1]; // All invocations with a subgroup invocation id greater or equal to op.value are active. stack[nesting].activeMask = prev.activeMask; @@ -841,7 +904,7 @@ ${this.functions[i]} case OpType.Return: { // Remove this active mask from all stack entries for this function. let mask: bigint = stack[nesting].activeMask; - for (var n = nesting; n >= 0; n--) { + for (let n = nesting; n >= 0; n--) { stack[n].activeMask &= ~mask; if (stack[n].isCall) { break; @@ -858,11 +921,13 @@ ${this.functions[i]} assert(stack.length == 1); - var maxLoc = 0; - for (var j = 0; j < this.invocations; j++) { - maxLoc = Math.max(maxLoc, locs[j]); + let maxLoc = 0; + for (let id = 0; id < this.invocations; id++) { + maxLoc = Math.max(maxLoc, locs[id]); + } + if (!countOnly) { + console.log(`Max location = ${maxLoc}\n`); } - console.log(`Max location = ${maxLoc}\n`); return maxLoc; } @@ -872,7 +937,7 @@ ${this.functions[i]} const y = this.masks[4*idx+1]; const z = this.masks[4*idx+2]; const w = this.masks[4*idx+3]; - var mask: bigint = 0n; + let mask: bigint = 0n; mask |= BigInt(x); mask |= BigInt(y) << 32n; mask |= BigInt(z) << 64n; @@ -882,9 +947,143 @@ ${this.functions[i]} /** @returns a randomized program */ public generate() { - while (this.ops.length < this.minCount) { - this.pickOp(1); + do { + this.ops = []; + while (this.ops.length < this.minCount) { + this.pickOp(1); + } + + // If this is an uniform control flow case, make sure a uniform ballot is + // generated. A subgroup size of 64 is used for testing purposes here. + if (this.style != Style.Maximal) { + this.simulate(true, 64); + } + } while (this.style != Style.Maximal && !this.isUCF()); + } + + private isUCF(): boolean { + let ucf: boolean = false; + for (let i = 0; i < this.ops.length; i++) { + const op = this.ops[i]; + if (op.type === OpType.Ballot && op.uniform) { + ucf = true; + } + } + return ucf; + } + + /** + * Calculates the base index for values in the result arrays. + * + * @param id The invocation id + * @param loc The location + * + * @returns The base index in a Uint32Array + */ + private baseIndex(id: number, loc: number): number { + return 4 * (this.invocations * loc + id); + } + + /** + * Determines if an instance of results match. + * + * @param res The result data + * @param resIdx The base result index + * @param ref The reference data + * @param refIdx The base reference index + * + * @returns true if 4 successive values match in both arrays + */ + private matchResult(res: Uint32Array, resIdx: number, ref: Uint32Array, refIdx: number): boolean { + return res[resIdx + 0] === ref[refIdx + 0] && + res[resIdx + 1] === ref[refIdx + 1] && + res[resIdx + 2] === ref[refIdx + 2] && + res[resIdx + 3] === ref[refIdx + 3]; + } + + /** + * Validates the results of the program. + * + * @param ballots The output data array + * @param locations The location data array + * @param subgroupSize Subgroup size that was executed on device + * @param numLocs The maximum locations used in simulation + * @returns an error if the results do meet expectatations + */ + public checkResults(ballots: Uint32Array, locations: Uint32Array, + subgroupSize: number, numLocs: number): Error | undefined { + //console.log(`Verifying numLocs = ${numLocs}`); + if (this.style == Style.Workgroup || this.style === Style.Subgroup) { + if (!this.isUCF()) { + return Error(`Expected some uniform condition for this test`); + } + // Subgroup and Workgroup tests always have an associated store + // preceeding them in the buffer. + const maskArray = getSubgroupMask(getMask(subgroupSize), subgroupSize); + const zeroArray = new Uint32Array([0,0,0,0]); + for (let id = 0; id < this.invocations; id++) { + let refLoc = 1; + let resLoc = 0; + while (refLoc < numLocs) { + while (refLoc < numLocs && + !this.matchResult(this.refData, this.baseIndex(id, refLoc), maskArray, 0)) { + refLoc++; + } + if (refLoc < numLocs) { + // Fully converged simulation + + // Search for the corresponding store in the result data. + const storeRefLoc = refLoc - 1; + while (resLoc < numLocs && + !this.matchResult(ballots, this.baseIndex(id, resLoc), + this.refData, this.baseIndex(id, storeRefLoc))) { + resLoc++; + } + + if (resLoc >= numLocs) { + return Error(`Failure for invocation ${id}: could not find associated store for reference location ${storeRefLoc}`); + } else { + // Found a matching store, now check the ballot. + const resIdx = this.baseIndex(id, resLoc + 1); + const refIdx = this.baseIndex(id, refLoc); + if (!this.matchResult(ballots, resIdx, this.refData, refIdx)) { + return Error(`Failure for invocation ${id} at location ${resLoc} +- expected: (0x${hex(this.refData[refIdx+3])},0x${hex(this.refData[refIdx+2])},0x${hex(this.refData[refIdx+1])},0x${hex(this.refData[refIdx])}) +- got: (0x${hex(ballots[resIdx+3])},0x${hex(ballots[resIdx+2])},0x${hex(ballots[resIdx+1])},0x${hex(ballots[resIdx])})`); + } + resLoc++; + } + refLoc++; + } + } + // Check there were no extra writes. + const idx = this.baseIndex(id, numLocs); + if (!this.matchResult(ballots, idx, zeroArray, 0)) { + return Error(`Unexpected write at end of buffer (location = ${numLocs}) for invocation ${id} +- got: (${ballots[idx]}, ${ballots[idx + 1]}, ${ballots[idx + 2]}, ${ballots[idx + 3]})`); + } + } + } else if (this.style == Style.Maximal) { + // Expect exact matches. + for (let i = 0; i < this.refData.length; i += 4) { + const idx_uvec4 = Math.floor(i / 4); + const id = Math.floor(idx_uvec4 % this.invocations); + const loc = Math.floor(idx_uvec4 / this.invocations); + if (!this.matchResult(ballots, i, this.refData, i)) { + return Error(`Failure for invocation ${id} at location ${loc}: +- expected: (0x${hex(this.refData[i+3])},0x${hex(this.refData[i+2])},0x${hex(this.refData[i+1])},0x${hex(this.refData[i])}) +- got: (0x${hex(ballots[i+3])},0x${hex(ballots[i+2])},0x${hex(ballots[i+1])},0x${hex(ballots[i])})`); + } + } + for (let i = this.refData.length; i < ballots.length; i++) { + if (ballots[i] !== 0) { + return Error(`Unexpected write at end of buffer (index = ${i}): +- got: (${ballots[i]})`); + } + } } + + return undefined; } /** @@ -974,7 +1173,7 @@ ${this.functions[i]} this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); - this.ops.push(new Op(OpType.IfLid, 16)); + this.ops.push(new Op(OpType.IfId, 16)); this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); @@ -994,23 +1193,79 @@ ${this.functions[i]} this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); - this.ops.push(new Op(OpType.ElseLid, 16)); + this.ops.push(new Op(OpType.ElseId, 16)); this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); this.ops.push(new Op(OpType.Return, 16)); - this.ops.push(new Op(OpType.EndIf, 16)); + this.ops.push(new Op(OpType.EndIf, 0)); this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); } + + /** + * Equivalent to: + * + * if subgroup_id < inputs[107] { + * if subgroup_id < inputs[112] { + * ballot(); + * if testBit(vec4u(0xd2f269c6,0xffe83b3f,0xa279f695,0x58899224), subgroup_id) { + * ballot(); + * } else { + * ballot() + * } + * ballot(); + * } else { + * ballot(); + * } + * } + * + * The first two if statements are uniform for subgroup sizes 64 or less. + * The third if statement is non-uniform for all subgroup sizes. + * It is tempting for compilers to collapse the third if/else into a single + * basic block which can lead to unexpected convergence of the ballots. + */ + public predefinedProgram3() { + // Set the mask for index 1 + this.masks[4*1 + 0] = 0xd2f269c6; + this.masks[4*1 + 1] = 0xffe83b3f; + this.masks[4*1 + 2] = 0xa279f695; + this.masks[4*1 + 3] = 0x58899224; + + this.ops.push(new Op(OpType.IfId, 107)); + + this.ops.push(new Op(OpType.IfId, 112)); + + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.IfMask, 1)); + + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.ElseMask, 1)); + + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.EndIf, 0)); + + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.ElseId, 0)); + + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.EndIf, 0)); + + this.ops.push(new Op(OpType.EndIf, 0)); + } }; export function generateSeeds(numCases: number): number[] { - var prng: PRNG = new PRNG(1); - var output: number[] = new Array(numCases); - for (var i = 0; i < numCases; i++) { + let prng: PRNG = new PRNG(1); + let output: number[] = new Array(numCases); + for (let i = 0; i < numCases; i++) { output[i] = prng.randomU32(); } return output; From 986e2a891b1d6a8dd96c0612e5857f07d473de35 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Fri, 18 Aug 2023 14:58:36 -0400 Subject: [PATCH 08/32] Implementation * Add infinite for loops and elects * remove tabs * small optimizations to simulation runtime * add a predefined case with infinite for loop --- .../reconvergence/reconvergence.spec.ts | 28 +- .../shader/execution/reconvergence/util.ts | 383 ++++++++++++++---- 2 files changed, 335 insertions(+), 76 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index d19dda2cfcd3..e419d529a517 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -89,6 +89,7 @@ async function testProgram(t: GPUTest, program: Program) { let numLocs = 0; const locMap = new Map(); for (let size = minSubgroupSize; size <= maxSubgroupSize; size *= 2) { + console.log(`${new Date()}: simulating subgroup size = ${size}`); let num = program.simulate(true, size); locMap.set(size, num); numLocs = Math.max(num, numLocs); @@ -178,6 +179,7 @@ async function testProgram(t: GPUTest, program: Program) { ], }); + console.log(`${new Date()}: running pipeline`); const encoder = t.device.createCommandEncoder(); const pass = encoder.beginComputePass(); pass.setPipeline(pipeline); @@ -190,6 +192,7 @@ async function testProgram(t: GPUTest, program: Program) { // That is: // SID: 0, 1, 2, ..., SGSize-1, 0, ..., SGSize-1, ... // LID: 0, 1, 2, ..., 128 + // // Generate a warning if this is not true of the device. // This mapping is not guaranteed by APIs (Vulkan particularly), but seems reliable // (for linear workgroups at least). @@ -202,11 +205,13 @@ async function testProgram(t: GPUTest, program: Program) { method: 'copy', } ); + console.log(`${new Date()}: done pipeline`); const sizeData: Uint32Array = sizeReadback.data; const actualSize = sizeData[0]; t.expectOK(checkSubgroupSizeConsistency(sizeData, minSubgroupSize, maxSubgroupSize)); program.sizeRefData(locMap.get(actualSize)); + console.log(`${new Date()}: Full simulation size = ${actualSize}`); let num = program.simulate(false, actualSize); const idReadback = await t.readGPUBufferRangeTyped( @@ -232,6 +237,7 @@ async function testProgram(t: GPUTest, program: Program) { ); const locationData = locationReadback.data; + console.log(`${new Date()}: Reading ballot buffer ${ballotLength * 4} bytes`); const ballotReadback = await t.readGPUBufferRangeTyped( ballotBuffer, { @@ -243,10 +249,12 @@ async function testProgram(t: GPUTest, program: Program) { ); const ballotData = ballotReadback.data; + console.log(`${Date()}: Finished buffer readbacks`); console.log(`Ballots`); - for (let id = 0; id < program.invocations; id++) { + //for (let id = 0; id < program.invocations; id++) { + for (let id = 0; id < actualSize; id++) { console.log(` id[${id}]:`); - for (let loc = 0; loc < numLocs; loc++) { + for (let loc = 0; loc < num; loc++) { const idx = 4 * (program.invocations * loc + id); console.log(` loc[${loc}] = (${hex(ballotData[idx+3])},${hex(ballotData[idx+2])},${hex(ballotData[idx+1])},${hex(ballotData[idx])}), (${ballotData[idx+3]},${ballotData[idx+2]},${ballotData[idx+1]},${ballotData[idx]})`); } @@ -259,7 +267,7 @@ g.test('predefined_reconvergence') .desc(`Test reconvergence using some predefined programs`) .params(u => u - .combine('test', [...iterRange(4, x => x)] as const) + .combine('test', [...iterRange(5, x => x)] as const) .beginSubcases() ) //.beforeAllSubcases(t => { @@ -290,6 +298,11 @@ g.test('predefined_reconvergence') program.predefinedProgram3(); break; } + case 4: { + program = new Program(Style.Workgroup, 1, invocations); + program.predefinedProgram4(); + break; + } default: { program = new Program(); unreachable('Unhandled testcase'); @@ -305,6 +318,15 @@ g.test('random_reconvergence') u .combine('style', [Style.Workgroup, Style.Subgroup, Style.Maximal] as const) .combine('seed', generateSeeds(5)) + .filter(u => { + if (u.style == Style.Maximal) { + return false; + } + if (u.style == Style.Subgroup) { + return false; + } + return true; + }) .beginSubcases() ) //.beforeAllSubcases(t => { diff --git a/src/webgpu/shader/execution/reconvergence/util.ts b/src/webgpu/shader/execution/reconvergence/util.ts index 40db77553653..640dc98124d0 100644 --- a/src/webgpu/shader/execution/reconvergence/util.ts +++ b/src/webgpu/shader/execution/reconvergence/util.ts @@ -20,6 +20,18 @@ function getReplicatedMask(submask: bigint, size: number, total: number = 128): return mask; } +/** @returns a mask with only the least significant 1 in |value| set for each subgroup. */ +function getElectMask(value: bigint, size: number, total: number = 128): bigint { + let mask = value; + let count = 0; + while (!(mask & 1n)) { + mask >>= 1n; + count++; + } + mask = value & (1n << BigInt(count)); + return getReplicatedMask(mask, size, total); +} + /** * Produce the subgroup mask for local invocation |id| within |fullMask| * @@ -99,28 +111,41 @@ enum OpType { ForUniform, EndForUniform, + // Equivalent to: + // for (var i = 0u; ; i++, ballot) + // Always includes an "elect"-based break in the loop. + ForInf, + EndForInf, + // Function return Return, + // Emulated elect for breaks from infinite loops. + Elect, + MAX, } function serializeOpType(op: OpType): string { + // prettier-ignore switch (op) { - case OpType.Ballot: return 'Ballot'; - case OpType.Store: return 'Store'; - case OpType.IfMask: return 'IfMask'; - case OpType.ElseMask: return 'ElseMask'; - case OpType.EndIf: return 'EndIf'; - case OpType.IfLoopCount: return 'IfLoopCount'; + case OpType.Ballot: return 'Ballot'; + case OpType.Store: return 'Store'; + case OpType.IfMask: return 'IfMask'; + case OpType.ElseMask: return 'ElseMask'; + case OpType.EndIf: return 'EndIf'; + case OpType.IfLoopCount: return 'IfLoopCount'; case OpType.ElseLoopCount: return 'ElseLoopCount'; - case OpType.IfId: return 'IfId'; - case OpType.ElseId: return 'ElseId'; + case OpType.IfId: return 'IfId'; + case OpType.ElseId: return 'ElseId'; case OpType.Break: return 'Break'; case OpType.Continue: return 'Continue'; case OpType.ForUniform: return 'ForUniform'; case OpType.EndForUniform: return 'EndForUniform'; + case OpType.ForInf: return 'ForInf'; + case OpType.EndForInf: return 'EndForInf'; case OpType.Return: return 'Return'; + case OpType.Elect: return 'Elect'; default: unreachable('Unhandled op'); break; @@ -171,6 +196,8 @@ export class Program { private indents: number[]; private readonly storeBase: number; public refData: Uint32Array; + private isLoopInf: Map; + private doneInfLoopBreak: Map; /** * constructor @@ -208,6 +235,8 @@ export class Program { this.indents.push(2); this.storeBase = 0x10000; this.refData = new Uint32Array(); + this.isLoopInf = new Map(); + this.doneInfLoopBreak = new Map(); } /** @returns A random float between 0 and 1 */ @@ -247,10 +276,12 @@ export class Program { break; } case 4: { - if (this.loopNesting <= 3) { + // Avoid very deep loop nests to limit memory and runtime. + if (this.loopNesting <= 2) { const r2 = this.getRandomUint(3); switch (r2) { case 0: this.genForUniform(); break; + case 1: this.genForInf(); break; case 2: default: { break; @@ -278,41 +309,42 @@ export class Program { } } } + this.genBallot(); } } private genBallot() { - // Optionally insert ballots, stores, and noise. + // Optionally insert ballots, stores, and noise. // Ballots and stores are used to determine correctness. - if (this.getRandomFloat() < 0.2) { + if (this.getRandomFloat() < 0.2) { const cur_length = this.ops.length; - if (cur_length < 2 || - !(this.ops[cur_length - 1].type == OpType.Ballot || - (this.ops[cur_length-1].type == OpType.Store && this.ops[cur_length - 2].type == OpType.Ballot))) { + if (cur_length < 2 || + !(this.ops[cur_length - 1].type == OpType.Ballot || + (this.ops[cur_length-1].type == OpType.Store && this.ops[cur_length - 2].type == OpType.Ballot))) { // Perform a store with each ballot so the results can be correlated. - //if (this.style != Style.Maximal) - this.ops.push(new Op(OpType.Store, cur_length + this.storeBase)); - this.ops.push(new Op(OpType.Ballot, 0)); - } - } + //if (this.style != Style.Maximal) + this.ops.push(new Op(OpType.Store, cur_length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + } + } - if (this.getRandomFloat() < 0.1) { + if (this.getRandomFloat() < 0.1) { const cur_length = this.ops.length; - if (cur_length < 2 || - !(this.ops[cur_length - 1].type == OpType.Store || - (this.ops[cur_length - 1].type == OpType.Ballot && this.ops[cur_length - 2].type == OpType.Store))) { - // Subgroup and workgroup styles do a store with every ballot. + if (cur_length < 2 || + !(this.ops[cur_length - 1].type == OpType.Store || + (this.ops[cur_length - 1].type == OpType.Ballot && this.ops[cur_length - 2].type == OpType.Store))) { + // Subgroup and workgroup styles do a store with every ballot. // Don't bloat the code by adding more. - if (this.style == Style.Maximal) - this.ops.push(new Op(OpType.Store, cur_length + this.storeBase)); - } - } - - //deUint32 r = this.getRandomUint(10000); - //if (r < 3) { - // ops.push_back({OP_NOISE, 0}); + if (this.style == Style.Maximal) + this.ops.push(new Op(OpType.Store, cur_length + this.storeBase)); + } + } + + //deUint32 r = this.getRandomUint(10000); + //if (r < 3) { + // ops.push_back({OP_NOISE, 0}); //} else if (r < 10) { - // ops.push_back({OP_NOISE, 1}); + // ops.push_back({OP_NOISE, 1}); //} } @@ -363,8 +395,8 @@ export class Program { private genForUniform() { const n = this.getRandomUint(5) + 1; // [1, 5] + const header = this.ops.length; this.ops.push(new Op(OpType.ForUniform, n)); - const header = this.ops.length - 1; this.nesting++; this.loopNesting++; this.loopNestingThisFunction++; @@ -375,39 +407,85 @@ export class Program { this.nesting--; } + private genForInf() { + const header = this.ops.length; + this.ops.push(new Op(OpType.ForInf, 0)); + this.nesting++; + this.loopNesting++; + this.loopNestingThisFunction++; + this.isLoopInf.set(this.loopNesting, true); + this.doneInfLoopBreak.set(this.loopNesting, false); + + this.pickOp(2); + + this.genElect(true); + this.doneInfLoopBreak.set(this.loopNesting, true); + + this.pickOp(2); + + this.ops.push(new Op(OpType.EndForInf, header)); + this.isLoopInf.set(this.loopNesting, false); + this.doneInfLoopBreak.set(this.loopNesting, false); + this.loopNestingThisFunction--; + this.loopNesting--; + this.nesting--; + } + + private genElect(forceBreak: boolean) { + this.ops.push(new Op(OpType.Elect, 0)); + this.nesting++; + + if (forceBreak) { + this.genBallot(); + this.genBallot(); + if (this.getRandomFloat() < 0.1) { + this.pickOp(1); + } + + // Sometimes use a return if we're in a call. + if (this.callNesting > 0 && this.getRandomFloat() < 0.3) { + this.ops.push(new Op(OpType.Return, 0)); + } else { + this.genBreak(); + } + } else { + this.pickOp(2); + } + + this.ops.push(new Op(OpType.EndIf, 0)); + this.nesting--; + } + private genBreak() { - if (this.loopNestingThisFunction > 0) - { - // Sometimes put the break in a divergent if - if (this.getRandomFloat() < 0.1) { + if (this.loopNestingThisFunction > 0) { + // Sometimes put the break in a divergent if + if (this.getRandomFloat() < 0.1) { const r = this.getRandomUint(this.numMasks-1) + 1; this.ops.push(new Op(OpType.IfMask, r)); this.ops.push(new Op(OpType.Break, 0)); this.ops.push(new Op(OpType.ElseMask, r)); this.ops.push(new Op(OpType.Break, 0)); this.ops.push(new Op(OpType.EndIf, 0)); - } else { - this.ops.push(new Op(OpType.Break, 0)); + } else { + this.ops.push(new Op(OpType.Break, 0)); } - } + } } private genContinue() { - // TODO: need to avoid infinite loops - if (this.loopNestingThisFunction > 0) - { - // Sometimes put the continue in a divergent if - if (this.getRandomFloat() < 0.1) { + if (this.loopNestingThisFunction > 0 && !this.isLoopInf.get(this.loopNesting)) { + // Sometimes put the continue in a divergent if + if (this.getRandomFloat() < 0.1) { const r = this.getRandomUint(this.numMasks-1) + 1; this.ops.push(new Op(OpType.IfMask, r)); this.ops.push(new Op(OpType.Continue, 0)); this.ops.push(new Op(OpType.ElseMask, r)); this.ops.push(new Op(OpType.Break, 0)); this.ops.push(new Op(OpType.EndIf, 0)); - } else { - this.ops.push(new Op(OpType.Continue, 0)); + } else { + this.ops.push(new Op(OpType.Continue, 0)); } - } + } } private genReturn() { @@ -429,6 +507,7 @@ export class Program { } } + /** @returns The WGSL code for the program */ public genCode(): string { for (let i = 0; i < this.ops.length; i++) { const op = this.ops[i]; @@ -507,13 +586,27 @@ export class Program { this.loopNesting++; break; } - case OpType.EndForUniform: { + case OpType.EndForUniform: + case OpType.EndForInf: { this.loopNesting--; this.decreaseIndent(); this.genIndent(); this.addCode(`}\n`); break; } + case OpType.ForInf: { + this.genIndent(); + const iter = `i${this.loopNesting}`; + this.addCode(`for (var ${iter} = 0u; true; ${iter} = infLoopIncrement(${iter})) {\n`); + this.loopNesting++; + this.increaseIndent(); + // Safety mechanism for hardware runs. + this.genIndent(); + this.addCode(`// Safety valve\n`); + this.genIndent(); + this.addCode(`if ${iter} >= 128u { break; }\n\n`); + break; + } case OpType.Break: { this.genIndent(); this.addCode(`break;\n`); @@ -529,6 +622,12 @@ export class Program { this.addCode(`return;\n`); break; } + case OpType.Elect: { + this.genIndent(); + this.addCode(`if subgroupElect() {\n`); + this.increaseIndent(); + break; + } } } @@ -577,6 +676,24 @@ fn main( f0(); } +fn infLoopIncrement(iter : u32) -> u32 { + ballots[stride * output_loc + local_id] = subgroupBallot(); + output_loc++; + return iter + 1; +} + +fn subgroupElect() -> bool { + let b = subgroupBallot(); + let lsb = firstTrailingBit(b); + let x_m1 = lsb.x != 0xffffffffu; + let y_m1 = lsb.y != 0xffffffffu; + let z_m1 = lsb.z != 0xffffffffu; + let w_or_z = select(lsb.w + 96, lsb.z + 64, z_m1); + let wz_or_y = select(w_or_z, lsb.y + 32, y_m1); + let val = select(wz_or_y, lsb.x, x_m1); + return val == subgroup_id; +} + fn testBit(mask : vec4u, id : u32) -> bool { let xbit = extractBits(mask.x, id, 1); let ybit = extractBits(mask.y, id - 32, 1); @@ -632,7 +749,7 @@ ${this.functions[i]} /** * Sizes the simulation buffer. * - * The total size is # of invocations * |locs| * 4 (uint4 is written). + * The total size is (# of invocations) * |locs| * 4 (uint4 is written). */ public sizeRefData(locs: number) { this.refData = new Uint32Array(locs * 4 * this.invocations); @@ -640,6 +757,14 @@ ${this.functions[i]} } // TODO: Reconvergence guarantees are not as strong as this simulation. + /** + * Simulate the program for the given subgroup size + * + * @param countOnly If true, the reference output is not generated just max locations + * @param subgroupSize The subgroup size to simulate + * + * BigInt is not the fastest value to manipulate. Care should be taken to optimize it's use. + */ public simulate(countOnly: boolean, subgroupSize: number): number { class State { // Active invocations @@ -706,7 +831,7 @@ ${this.functions[i]} switch (op.type) { case OpType.Ballot: { const curMask = stack[nesting].activeMask; - // Flag if this ballot is not workgroup uniform. + // Flag if this ballot is not workgroup uniform. if (this.style == Style.Workgroup && any(curMask) && !all(curMask, this.invocations)) { op.uniform = false; } @@ -721,22 +846,26 @@ ${this.functions[i]} } } + if (!any(curMask)) { + break; + } + + let mask = new Uint32Array(); for (let id = 0; id < this.invocations; id++) { + if (id % subgroupSize === 0) { + mask = getSubgroupMask(curMask, subgroupSize, id); + } if (testBit(curMask, id)) { if (!countOnly) { const idx = this.baseIndex(id, locs[id]); if (op.uniform) { - let mask = getSubgroupMask(curMask, subgroupSize, id); this.refData[idx + 0] = mask[0]; this.refData[idx + 1] = mask[1]; this.refData[idx + 2] = mask[2]; this.refData[idx + 3] = mask[3]; } else { // Emit a magic value to indicate that we shouldn't validate this ballot - this.refData[idx + 0] = 0x12345678 - this.refData[idx + 1] = 0x12345678 - this.refData[idx + 2] = 0x12345678 - this.refData[idx + 3] = 0x12345678 + this.refData.fill(0x12345678, idx, idx + 4); } } locs[id]++; @@ -745,14 +874,15 @@ ${this.functions[i]} break; } case OpType.Store: { + if (!any(stack[nesting].activeMask)) { + break; + } + for (let id = 0; id < this.invocations; id++) { if (testBit(stack[nesting].activeMask, id)) { if (!countOnly) { const idx = this.baseIndex(id, locs[id]); - this.refData[idx + 0] = op.value; - this.refData[idx + 1] = op.value; - this.refData[idx + 2] = op.value; - this.refData[idx + 3] = op.value; + this.refData.fill(op.value, idx, idx + 4); } locs[id]++; } @@ -768,7 +898,7 @@ ${this.functions[i]} cur.isLoop = 0; cur.isSwitch = 0; // O is always uniform true. - if (op.value != 0) { + if (op.value != 0 && any(cur.activeMask)) { let subMask = this.getValueMask(op.value); subMask &= getMask(subgroupSize); cur.activeMask &= getReplicatedMask(subMask, subgroupSize, this.invocations); @@ -780,7 +910,7 @@ ${this.functions[i]} const cur = stack[nesting]; if (op.value == 0) { cur.activeMask = 0n; - } else { + } else if (any(cur.activeMask)) { const prev = stack[nesting-1]; let subMask = this.getValueMask(op.value); subMask &= getMask(subgroupSize); @@ -797,15 +927,19 @@ ${this.functions[i]} cur.header = i; cur.isLoop = 0; cur.isSwitch = 0; - // All invocations with subgroup invocation id less than op.value are active. - cur.activeMask &= getReplicatedMask(getMask(op.value), subgroupSize, this.invocations); + if (any(cur.activeMask)) { + // All invocations with subgroup invocation id less than op.value are active. + cur.activeMask &= getReplicatedMask(getMask(op.value), subgroupSize, this.invocations); + } break; } case OpType.ElseId: { const prev = stack[nesting-1]; // All invocations with a subgroup invocation id greater or equal to op.value are active. stack[nesting].activeMask = prev.activeMask; - stack[nesting].activeMask &= ~getReplicatedMask(getMask(op.value), subgroupSize, this.invocations); + if (any(prev.activeMask)) { + stack[nesting].activeMask &= ~getReplicatedMask(getMask(op.value), subgroupSize, this.invocations); + } break; } case OpType.IfLoopCount: { @@ -822,7 +956,9 @@ ${this.functions[i]} cur.header = i; cur.isLoop = 0; cur.isSwitch = 0; - cur.activeMask &= getReplicatedMask(BigInt(1 << stack[n].tripCount), subgroupSize, this.invocations); + if (any(cur.activeMask)) { + cur.activeMask &= getReplicatedMask(BigInt(1 << stack[n].tripCount), subgroupSize, this.invocations); + } break; } case OpType.ElseLoopCount: { @@ -835,7 +971,9 @@ ${this.functions[i]} } stack[nesting].activeMask = stack[nesting-1].activeMask; - stack[nesting].activeMask &= ~getReplicatedMask(BigInt(1 << stack[n].tripCount), subgroupSize, this.invocations); + if (any(stack[nesting].activeMask)) { + stack[nesting].activeMask &= ~getReplicatedMask(BigInt(1 << stack[n].tripCount), subgroupSize, this.invocations); + } break; } case OpType.EndIf: { @@ -859,7 +997,7 @@ ${this.functions[i]} // Determine which invocations have another iteration of the loop to execute. const cur = stack[nesting]; cur.tripCount++; - cur.activeMask |= stack[nesting].continueMask; + cur.activeMask |= cur.continueMask; cur.continueMask = 0n; if (cur.tripCount < this.ops[cur.header].value && any(cur.activeMask)) { @@ -872,10 +1010,55 @@ ${this.functions[i]} } break; } + case OpType.ForInf: { + nesting++; + loopNesting++; + stack.push(new State()); + const cur = stack[nesting]; + cur.header = i; + cur.isLoop = true; + cur.activeMask = stack[nesting-1].activeMask; + break; + } + case OpType.EndForInf: { + const cur = stack[nesting]; + cur.tripCount++; + cur.activeMask |= cur.continueMask; + cur.continueMask = 0n; + if (any(cur.activeMask)) { + let maskArray = new Uint32Array(); + for (let id = 0; id < this.invocations; id++) { + if (id % subgroupSize === 0) { + maskArray = getSubgroupMask(cur.activeMask, subgroupSize, id); + } + if (testBit(cur.activeMask, id)) { + if (!countOnly) { + const idx = this.baseIndex(id, locs[id]); + this.refData[idx + 0] = maskArray[0]; + this.refData[idx + 1] = maskArray[1]; + this.refData[idx + 2] = maskArray[2]; + this.refData[idx + 3] = maskArray[3]; + } + locs[id]++; + } + } + i = cur.header + 1; + continue; + } else { + loopNesting--; + nesting--; + stack.pop(); + } + break; + } case OpType.Break: { // Remove this active mask from all stack entries for the current loop/switch. let n = nesting; let mask: bigint = stack[nesting].activeMask; + if (!any(mask)) { + break; + } + while (true) { stack[n].activeMask &= ~mask; if (stack[n].isLoop || stack[n].isSwitch) { @@ -891,6 +1074,10 @@ ${this.functions[i]} // Add this mask to the loop's continue mask for the next iteration. let n = nesting; let mask: bigint = stack[nesting].activeMask; + if (!any(mask)) { + break; + } + while (true) { stack[n].activeMask &= ~mask; if (stack[n].isLoop) { @@ -904,6 +1091,10 @@ ${this.functions[i]} case OpType.Return: { // Remove this active mask from all stack entries for this function. let mask: bigint = stack[nesting].activeMask; + if (!any(mask)) { + break; + } + for (let n = nesting; n >= 0; n--) { stack[n].activeMask &= ~mask; if (stack[n].isCall) { @@ -912,8 +1103,21 @@ ${this.functions[i]} } break; } + case OpType.Elect: { + nesting++; + stack.push(new State()); + const cur = stack[nesting]; + cur.copy(stack[nesting-1]); + cur.header = i; + cur.isLoop = 0; + cur.isSwitch = 0; + if (any(cur.activeMask)) { + cur.activeMask = getElectMask(cur.activeMask, subgroupSize, this.invocations); + } + break; + } default: { - unreachable(`Unhandled op ${op.type}`); + unreachable(`Unhandled op ${serializeOpType(op.type)}`); } } i++; @@ -931,7 +1135,9 @@ ${this.functions[i]} return maxLoc; } - // Returns an active mask for the mask at the given index. + /** + * @returns a mask formed from |masks[idx]| + */ private getValueMask(idx: number): bigint { const x = this.masks[4*idx]; const y = this.masks[4*idx+1]; @@ -952,6 +1158,7 @@ ${this.functions[i]} while (this.ops.length < this.minCount) { this.pickOp(1); } + //break; // If this is an uniform control flow case, make sure a uniform ballot is // generated. A subgroup size of 64 is used for testing purposes here. @@ -961,6 +1168,7 @@ ${this.functions[i]} } while (this.style != Style.Maximal && !this.isUCF()); } + /** @returns true if the program has uniform control flow for some ballot */ private isUCF(): boolean { let ucf: boolean = false; for (let i = 0; i < this.ops.length; i++) { @@ -1033,7 +1241,7 @@ ${this.functions[i]} // Fully converged simulation // Search for the corresponding store in the result data. - const storeRefLoc = refLoc - 1; + let storeRefLoc = refLoc - 1; while (resLoc < numLocs && !this.matchResult(ballots, this.baseIndex(id, resLoc), this.refData, this.baseIndex(id, storeRefLoc))) { @@ -1041,7 +1249,8 @@ ${this.functions[i]} } if (resLoc >= numLocs) { - return Error(`Failure for invocation ${id}: could not find associated store for reference location ${storeRefLoc}`); + const refIdx = this.baseIndex(id, storeRefLoc); + return Error(`Failure for invocation ${id}: could not find associated store for reference location ${storeRefLoc}: ${this.refData[refIdx]},${this.refData[refIdx+1]},${this.refData[refIdx+2]},${this.refData[refIdx+3]}`); } else { // Found a matching store, now check the ballot. const resIdx = this.baseIndex(id, resLoc + 1); @@ -1260,6 +1469,34 @@ ${this.functions[i]} this.ops.push(new Op(OpType.EndIf, 0)); } + + /** + * Equivalent to: + * + * for (var i = 0; ; i++, ballot()) { + * ballot(); + * if (subgroupElect()) { + * ballot(); + * break; + * } + * } + * ballot(); + */ + public predefinedProgram4() { + this.ops.push(new Op(OpType.ForInf, 0)); + + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.Elect, 0)); + + this.ops.push(new Op(OpType.Break, 0)); + + this.ops.push(new Op(OpType.EndIf, 0)); + this.ops.push(new Op(OpType.EndForInf, 0)); + + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + } }; export function generateSeeds(numCases: number): number[] { From 6a38736c95f1fac7e35dde46781196494cd33674 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Fri, 18 Aug 2023 15:41:48 -0400 Subject: [PATCH 09/32] Add another for loop variant * Add a variable based for that iterates based on subgroup id * Add a predefined test case to cover it --- .../reconvergence/reconvergence.spec.ts | 19 ++- .../shader/execution/reconvergence/util.ts | 115 ++++++++++++++++-- 2 files changed, 122 insertions(+), 12 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index e419d529a517..f6304a4e48cb 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -267,7 +267,7 @@ g.test('predefined_reconvergence') .desc(`Test reconvergence using some predefined programs`) .params(u => u - .combine('test', [...iterRange(5, x => x)] as const) + .combine('test', [...iterRange(8, x => x)] as const) .beginSubcases() ) //.beforeAllSubcases(t => { @@ -300,7 +300,22 @@ g.test('predefined_reconvergence') } case 4: { program = new Program(Style.Workgroup, 1, invocations); - program.predefinedProgram4(); + program.predefinedProgramForInf(); + break; + } + case 5: { + program = new Program(Style.Subgroup, 1, invocations); + program.predefinedProgramForInf(); + break; + } + case 6: { + program = new Program(Style.Subgroup, 1, invocations); + program.predefinedProgramForVar(); + break; + } + case 7: { + program = new Program(Style.Maximal, 1, invocations); + program.predefinedProgramForVar(); break; } default: { diff --git a/src/webgpu/shader/execution/reconvergence/util.ts b/src/webgpu/shader/execution/reconvergence/util.ts index 640dc98124d0..221482527ead 100644 --- a/src/webgpu/shader/execution/reconvergence/util.ts +++ b/src/webgpu/shader/execution/reconvergence/util.ts @@ -117,6 +117,11 @@ enum OpType { ForInf, EndForInf, + // Equivalent to: + // for (var i = 0u; i < subgroup_invocation_id + 1; i++) + ForVar, + EndForVar, + // Function return Return, @@ -144,6 +149,8 @@ function serializeOpType(op: OpType): string { case OpType.EndForUniform: return 'EndForUniform'; case OpType.ForInf: return 'ForInf'; case OpType.EndForInf: return 'EndForInf'; + case OpType.ForVar: return 'ForVar'; + case OpType.EndForVar: return 'EndForVar'; case OpType.Return: return 'Return'; case OpType.Elect: return 'Elect'; default: @@ -282,7 +289,7 @@ export class Program { switch (r2) { case 0: this.genForUniform(); break; case 1: this.genForInf(); break; - case 2: + case 2: this.genForVar(); break; default: { break; } @@ -431,6 +438,21 @@ export class Program { this.nesting--; } + private genForVar() { + const header = this.ops.length; + this.ops.push(new Op(OpType.ForVar, 0)); + this.nesting++; + this.loopNesting++; + this.loopNestingThisFunction++; + + this.pickOp(2); + + this.ops.push(new Op(OpType.EndForVar, header)); + this.loopNestingThisFunction--; + this.loopNesting--; + this.nesting--; + } + private genElect(forceBreak: boolean) { this.ops.push(new Op(OpType.Elect, 0)); this.nesting++; @@ -586,14 +608,6 @@ export class Program { this.loopNesting++; break; } - case OpType.EndForUniform: - case OpType.EndForInf: { - this.loopNesting--; - this.decreaseIndent(); - this.genIndent(); - this.addCode(`}\n`); - break; - } case OpType.ForInf: { this.genIndent(); const iter = `i${this.loopNesting}`; @@ -607,6 +621,23 @@ export class Program { this.addCode(`if ${iter} >= 128u { break; }\n\n`); break; } + case OpType.ForVar: { + this.genIndent(); + const iter = `i${this.loopNesting}`; + this.addCode(`for (var ${iter} = 0u; ${iter} < subgroup_id + 1; ${iter}++) {\n`); + this.loopNesting++; + this.increaseIndent(); + break; + } + case OpType.EndForUniform: + case OpType.EndForInf: + case OpType.EndForVar: { + this.loopNesting--; + this.decreaseIndent(); + this.genIndent(); + this.addCode(`}\n`); + break; + } case OpType.Break: { this.genIndent(); this.addCode(`break;\n`); @@ -1051,6 +1082,38 @@ ${this.functions[i]} } break; } + case OpType.ForVar: { + nesting++; + loopNesting++; + stack.push(new State()); + const cur = stack[nesting]; + cur.header = i; + cur.isLoop = true; + cur.activeMask = stack[nesting-1].activeMask; + break; + } + case OpType.EndForVar: { + const cur = stack[nesting]; + cur.tripCount++; + cur.activeMask |= cur.continueMask; + cur.continueMask = 0n; + let done = !any(cur.activeMask) || cur.tripCount === subgroupSize; + if (!done) { + let submask = getMask(subgroupSize) & ~getMask(cur.tripCount); + let mask = getReplicatedMask(submask, subgroupSize, this.invocations); + cur.activeMask &= mask; + done = !any(cur.activeMask); + } + + if (done) { + loopNesting--; + nesting--; + stack.pop(); + } else { + i = cur.header + 1; + } + break; + } case OpType.Break: { // Remove this active mask from all stack entries for the current loop/switch. let n = nesting; @@ -1482,7 +1545,7 @@ ${this.functions[i]} * } * ballot(); */ - public predefinedProgram4() { + public predefinedProgramForInf() { this.ops.push(new Op(OpType.ForInf, 0)); this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); @@ -1497,6 +1560,38 @@ ${this.functions[i]} this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); } + + /** + * Equivalent to: + * + * for (var i = 0; i < subgroup_invocation_id + 1; i++) { + * ballot(); + * } + * ballot(); + * for (var i = 0; i < subgroup_invocation_id + 1; i++) { + * ballot(); + * } + * ballot(); + */ + public predefinedProgramForVar() { + this.ops.push(new Op(OpType.ForVar, 0)); + + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.EndForVar, 0)); + + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + + this.ops.push(new Op(OpType.ForVar, 0)); + + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.EndForVar, 0)); + + this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); + this.ops.push(new Op(OpType.Ballot, 0)); + } }; export function generateSeeds(numCases: number): number[] { From 4cd3e8ad61b4a7c009c4f9abbc4b61958b3a5b6d Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Sun, 20 Aug 2023 20:43:44 -0400 Subject: [PATCH 10/32] Add function calls * program generation, code generations, and simulation * Add a predefined program that is run for both workgroup and maxial reconvergence --- .../reconvergence/reconvergence.spec.ts | 12 +- .../shader/execution/reconvergence/util.ts | 150 +++++++++++++++++- 2 files changed, 155 insertions(+), 7 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index f6304a4e48cb..1f1f5b7badb2 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -267,7 +267,7 @@ g.test('predefined_reconvergence') .desc(`Test reconvergence using some predefined programs`) .params(u => u - .combine('test', [...iterRange(8, x => x)] as const) + .combine('test', [...iterRange(10, x => x)] as const) .beginSubcases() ) //.beforeAllSubcases(t => { @@ -318,6 +318,16 @@ g.test('predefined_reconvergence') program.predefinedProgramForVar(); break; } + case 8: { + program = new Program(Style.Workgroup, 1, invocations); + program.predefinedProgramCall(); + break; + } + case 9: { + program = new Program(Style.Maximal, 1, invocations); + program.predefinedProgramCall(); + break; + } default: { program = new Program(); unreachable('Unhandled testcase'); diff --git a/src/webgpu/shader/execution/reconvergence/util.ts b/src/webgpu/shader/execution/reconvergence/util.ts index 221482527ead..d537fac8e09b 100644 --- a/src/webgpu/shader/execution/reconvergence/util.ts +++ b/src/webgpu/shader/execution/reconvergence/util.ts @@ -128,6 +128,10 @@ enum OpType { // Emulated elect for breaks from infinite loops. Elect, + // Function call + Call, + EndCall, + MAX, } @@ -153,6 +157,8 @@ function serializeOpType(op: OpType): string { case OpType.EndForVar: return 'EndForVar'; case OpType.Return: return 'Return'; case OpType.Elect: return 'Elect'; + case OpType.Call: return 'Call'; + case OpType.EndCall: return 'EndCall'; default: unreachable('Unhandled op'); break; @@ -298,7 +304,11 @@ export class Program { break; } case 5: { - this.genBreak(); + if (this.getRandomFloat() < 0.2 && this.callNesting == 0 && this.nesting < this.maxNesting - 1) { + this.genCall(); + } else { + this.genBreak(); + } break; } case 6: { @@ -510,6 +520,21 @@ export class Program { } } + private genCall() { + this.ops.push(new Op(OpType.Call, 0)); + this.callNesting++; + this.nesting++; + const curLoopNesting = this.loopNestingThisFunction; + this.loopNestingThisFunction = 0; + + this.pickOp(2); + + this.loopNestingThisFunction = curLoopNesting; + this.nesting--; + this.callNesting--; + this.ops.push(new Op(OpType.EndCall, 0)); + } + private genReturn() { const r = this.getRandomFloat(); if (this.nesting > 0 && @@ -659,6 +684,30 @@ export class Program { this.increaseIndent(); break; } + case OpType.Call: { + this.genIndent(); + this.addCode(`f${this.functions.length}(`); + for (let i = 0; i < this.loopNesting; i++) { + this.addCode(`i${i},`); + } + this.addCode(`);\n`); + + this.curFunc = this.functions.length; + this.functions.push(`fn f${this.curFunc}(`); + for (let i = 0; i < this.loopNesting; i++) { + this.addCode(`i${i} : u32,`); + } + this.addCode(`) {\n`); + this.indents.push(2); + break; + } + case OpType.EndCall: { + this.decreaseIndent(); + this.addCode(`}\n`); + // Call nesting is limited to 1 so we always return to f0. + this.curFunc = 0; + break; + } } } @@ -737,14 +786,15 @@ fn testBit(mask : vec4u, id : u32) -> bool { let selb = select(zbit, ybit, lt64); return select(selb, sela, lt32) == 1; } -`; + +fn f0() {`; for (let i = 0; i < this.functions.length; i++) { code += ` -fn f${i}() { -${this.functions[i]} -} -`; +${this.functions[i]}`; + if (i == 0) { + code += `\n}\n`; + } } return code; } @@ -1179,6 +1229,19 @@ ${this.functions[i]} } break; } + case OpType.Call: { + nesting++; + stack.push(new State()); + const cur = stack[nesting]; + cur.activeMask = stack[nesting-1].activeMask; + cur.isCall = 1; + break; + } + case OpType.EndCall: { + nesting--; + stack.pop(); + break; + } default: { unreachable(`Unhandled op ${serializeOpType(op.type)}`); } @@ -1592,6 +1655,81 @@ ${this.functions[i]} this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); } + + /** + * Equivalent to: + * + * fn f0() { + * for (var i = 0; i < inputs[3]; i++) { + * f1(i); + * ballot(); + * } + * ballot(); + * if (inputs[3] == 3) { + * f2(); + * ballot(); + * } + * ballot() + * } + * fn f1(i : u32) { + * ballot(); + * if (subgroup_invocation_id == i) { + * ballot(); + * return; + * } + * } + * fn f2() { + * ballot(); + * if (testBit(vec4u(0xaaaaaaaa,0xaaaaaaaa,0xaaaaaaaa,0xaaaaaaaa), local_invocation_index)) { + * ballot(); + * return; + * } + * } + */ + public predefinedProgramCall() { + this.masks[4 + 0] = 0xaaaaaaaa; + this.masks[4 + 1] = 0xaaaaaaaa; + this.masks[4 + 2] = 0xaaaaaaaa; + this.masks[4 + 3] = 0xaaaaaaaa; + + this.ops.push(new Op(OpType.ForUniform, 3)); + + this.ops.push(new Op(OpType.Call, 0)); + // f1 + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.IfLoopCount, 0)); + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.Return, 0)); + this.ops.push(new Op(OpType.EndIf, 0)); + // end f1 + this.ops.push(new Op(OpType.EndCall, 0)); + + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.EndForUniform, 0)); + + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.IfMask, 0)); + + this.ops.push(new Op(OpType.Call, 0)); + // f2 + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.IfMask, 1)); + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.Return, 0)); + this.ops.push(new Op(OpType.EndIf, 0)); + // end f2 + this.ops.push(new Op(OpType.EndCall, 0)); + + this.ops.push(new Op(OpType.EndIf, 0)); + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + } }; export function generateSeeds(numCases: number): number[] { From c436200abb9a67fb6d0bbbff4d0da70761a7ac30 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Mon, 21 Aug 2023 14:41:05 -0400 Subject: [PATCH 11/32] Add uniform loop --- .../reconvergence/reconvergence.spec.ts | 8 +- .../shader/execution/reconvergence/util.ts | 161 ++++++++++++++---- 2 files changed, 134 insertions(+), 35 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index 1f1f5b7badb2..87d4823b0182 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -12,6 +12,7 @@ import { import { hex, Style, + OpType, Program, generateSeeds } from './util.js' @@ -267,7 +268,7 @@ g.test('predefined_reconvergence') .desc(`Test reconvergence using some predefined programs`) .params(u => u - .combine('test', [...iterRange(10, x => x)] as const) + .combine('test', [...iterRange(11, x => x)] as const) .beginSubcases() ) //.beforeAllSubcases(t => { @@ -328,6 +329,11 @@ g.test('predefined_reconvergence') program.predefinedProgramCall(); break; } + case 10: { + program = new Program(Style.Workgroup, 1, invocations); + program.predefinedProgram1(OpType.LoopUniform, OpType.EndLoopUniform); + break; + } default: { program = new Program(); unreachable('Unhandled testcase'); diff --git a/src/webgpu/shader/execution/reconvergence/util.ts b/src/webgpu/shader/execution/reconvergence/util.ts index d537fac8e09b..20092d1a1482 100644 --- a/src/webgpu/shader/execution/reconvergence/util.ts +++ b/src/webgpu/shader/execution/reconvergence/util.ts @@ -79,7 +79,7 @@ export enum Style { Maximal, }; -enum OpType { +export enum OpType { // Store a ballot. // During simulation, uniform is set to false if the // ballot is not fully uniform for the given style. @@ -122,6 +122,10 @@ enum OpType { ForVar, EndForVar, + // var i = 0; loop { ... continuing { i++; break if i >= inputs[value]; } } + LoopUniform, + EndLoopUniform, + // Function return Return, @@ -138,27 +142,29 @@ enum OpType { function serializeOpType(op: OpType): string { // prettier-ignore switch (op) { - case OpType.Ballot: return 'Ballot'; - case OpType.Store: return 'Store'; - case OpType.IfMask: return 'IfMask'; - case OpType.ElseMask: return 'ElseMask'; - case OpType.EndIf: return 'EndIf'; - case OpType.IfLoopCount: return 'IfLoopCount'; - case OpType.ElseLoopCount: return 'ElseLoopCount'; - case OpType.IfId: return 'IfId'; - case OpType.ElseId: return 'ElseId'; - case OpType.Break: return 'Break'; - case OpType.Continue: return 'Continue'; - case OpType.ForUniform: return 'ForUniform'; - case OpType.EndForUniform: return 'EndForUniform'; - case OpType.ForInf: return 'ForInf'; - case OpType.EndForInf: return 'EndForInf'; - case OpType.ForVar: return 'ForVar'; - case OpType.EndForVar: return 'EndForVar'; - case OpType.Return: return 'Return'; - case OpType.Elect: return 'Elect'; - case OpType.Call: return 'Call'; - case OpType.EndCall: return 'EndCall'; + case OpType.Ballot: return 'Ballot'; + case OpType.Store: return 'Store'; + case OpType.IfMask: return 'IfMask'; + case OpType.ElseMask: return 'ElseMask'; + case OpType.EndIf: return 'EndIf'; + case OpType.IfLoopCount: return 'IfLoopCount'; + case OpType.ElseLoopCount: return 'ElseLoopCount'; + case OpType.IfId: return 'IfId'; + case OpType.ElseId: return 'ElseId'; + case OpType.Break: return 'Break'; + case OpType.Continue: return 'Continue'; + case OpType.ForUniform: return 'ForUniform'; + case OpType.EndForUniform: return 'EndForUniform'; + case OpType.ForInf: return 'ForInf'; + case OpType.EndForInf: return 'EndForInf'; + case OpType.ForVar: return 'ForVar'; + case OpType.EndForVar: return 'EndForVar'; + case OpType.LoopUniform: return 'LoopUniform'; + case OpType.EndLoopUniform: return 'EndLoopUniform'; + case OpType.Return: return 'Return'; + case OpType.Elect: return 'Elect'; + case OpType.Call: return 'Call'; + case OpType.EndCall: return 'EndCall'; default: unreachable('Unhandled op'); break; @@ -290,7 +296,7 @@ export class Program { } case 4: { // Avoid very deep loop nests to limit memory and runtime. - if (this.loopNesting <= 2) { + if (this.loopNesting <= 3) { const r2 = this.getRandomUint(3); switch (r2) { case 0: this.genForUniform(); break; @@ -304,11 +310,7 @@ export class Program { break; } case 5: { - if (this.getRandomFloat() < 0.2 && this.callNesting == 0 && this.nesting < this.maxNesting - 1) { - this.genCall(); - } else { - this.genBreak(); - } + this.genBreak(); break; } case 6: { @@ -317,8 +319,26 @@ export class Program { } case 7: { // Calls and returns. - // TODO: calls - this.genReturn(); + if (this.getRandomFloat() < 0.2 && this.callNesting == 0 && this.nesting < this.maxNesting - 1) { + this.genCall(); + } else { + this.genBreak(); + } + break; + break; + } + case 8: { + if (this.loopNesting <= 3) { + const r2 = this.getRandomUint(3); + switch (r2) { + case 0: this.genLoopUniform(); break; + case 1: + case 2: + default: { + break; + } + } + } break; } default: { @@ -463,6 +483,22 @@ export class Program { this.nesting--; } + private genLoopUniform() { + const n = this.getRandomUint(5) + 1; + const header = this.ops.length; + this.ops.push(new Op(OpType.LoopUniform, n)); + this.nesting++; + this.loopNesting++; + this.loopNestingThisFunction++; + + this.pickOp(2); + + this.ops.push(new Op(OpType.EndLoopUniform, header)); + this.loopNestingThisFunction--; + this.loopNesting--; + this.nesting--; + } + private genElect(forceBreak: boolean) { this.ops.push(new Op(OpType.Elect, 0)); this.nesting++; @@ -663,6 +699,35 @@ export class Program { this.addCode(`}\n`); break; } + case OpType.LoopUniform: { + this.genIndent(); + const iter = `i${this.loopNesting}`; + this.addCode(`var ${iter} = 0u;\n`); + this.genIndent(); + this.addCode(`loop {\n`); + this.loopNesting++; + this.increaseIndent(); + break; + } + case OpType.EndLoopUniform: { + this.loopNesting--; + const iter = `i${this.loopNesting}`; + this.genIndent(); + this.addCode(`continuing {\n`); + this.increaseIndent(); + this.genIndent(); + this.addCode(`${iter}++;\n`); + this.genIndent(); + const limit = this.ops[op.value].value; + this.addCode(`break if ${iter} >= inputs[${limit}];\n`); + this.decreaseIndent(); + this.genIndent(); + this.addCode(`}\n`); + this.decreaseIndent(); + this.genIndent(); + this.addCode(`}\n`); + break; + } case OpType.Break: { this.genIndent(); this.addCode(`break;\n`); @@ -1164,6 +1229,32 @@ ${this.functions[i]}`; } break; } + case OpType.LoopUniform: { + nesting++; + loopNesting++; + stack.push(new State()); + const cur = stack[nesting]; + cur.header = i; + cur.isLoop = true; + cur.activeMask = stack[nesting-1].activeMask; + break; + } + case OpType.EndLoopUniform: { + const cur = stack[nesting]; + cur.tripCount++; + cur.activeMask |= cur.continueMask; + cur.continueMask = 0n; + if (cur.tripCount < this.ops[cur.header].value && + any(cur.activeMask)) { + i = cur.header + 1; + continue; + } else { + loopNesting--; + nesting--; + stack.pop(); + } + break; + } case OpType.Break: { // Remove this active mask from all stack entries for the current loop/switch. let n = nesting; @@ -1439,7 +1530,8 @@ ${this.functions[i]}`; * } * ballot(); // fully uniform */ - public predefinedProgram1() { + public predefinedProgram1(beginLoop: OpType = OpType.ForUniform, + endLoop: OpType = OpType.EndForUniform) { // Set the mask for index 1 this.masks[4*1 + 0] = 0xaaaaaaaa this.masks[4*1 + 1] = 0xaaaaaaaa @@ -1452,7 +1544,8 @@ ${this.functions[i]}`; this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); - this.ops.push(new Op(OpType.ForUniform, 3)); + const header = this.ops.length; + this.ops.push(new Op(beginLoop, 3)); this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); @@ -1466,7 +1559,7 @@ ${this.functions[i]}`; this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); - this.ops.push(new Op(OpType.EndForUniform, 0)); + this.ops.push(new Op(endLoop, header)); this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); @@ -1708,7 +1801,7 @@ ${this.functions[i]}`; this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); this.ops.push(new Op(OpType.Ballot, 0)); - this.ops.push(new Op(OpType.EndForUniform, 0)); + this.ops.push(new Op(OpType.EndForUniform, 3)); this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); this.ops.push(new Op(OpType.Ballot, 0)); From e31e16d350c25d93d70c0eb24503aac9bbb483fe Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Mon, 21 Aug 2023 21:12:03 -0400 Subject: [PATCH 12/32] Fixes and refactoring * Fixed a bug in the shader code for testbit * couldn't select bits [96,127] * Fix a bug in ForVar simulation * didn't loop correctly * Refactored some simulation code to reduce duplication --- .../reconvergence/reconvergence.spec.ts | 68 ++++------ .../shader/execution/reconvergence/util.ts | 124 ++++++++---------- 2 files changed, 77 insertions(+), 115 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index 87d4823b0182..1a72e5dba401 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -81,7 +81,7 @@ subgroup_invocation_id = ${data[i]}`); async function testProgram(t: GPUTest, program: Program) { let wgsl = program.genCode(); - console.log(wgsl); + //console.log(wgsl); // TODO: query the device const minSubgroupSize = 4; @@ -251,15 +251,15 @@ async function testProgram(t: GPUTest, program: Program) { const ballotData = ballotReadback.data; console.log(`${Date()}: Finished buffer readbacks`); - console.log(`Ballots`); - //for (let id = 0; id < program.invocations; id++) { - for (let id = 0; id < actualSize; id++) { - console.log(` id[${id}]:`); - for (let loc = 0; loc < num; loc++) { - const idx = 4 * (program.invocations * loc + id); - console.log(` loc[${loc}] = (${hex(ballotData[idx+3])},${hex(ballotData[idx+2])},${hex(ballotData[idx+1])},${hex(ballotData[idx])}), (${ballotData[idx+3]},${ballotData[idx+2]},${ballotData[idx+1]},${ballotData[idx]})`); - } - } + //console.log(`Ballots`); + ////for (let id = 0; id < program.invocations; id++) { + //for (let id = 0; id < actualSize; id++) { + // console.log(` id[${id}]:`); + // for (let loc = 0; loc < num; loc++) { + // const idx = 4 * (program.invocations * loc + id); + // console.log(` loc[${loc}] = (${hex(ballotData[idx+3])},${hex(ballotData[idx+2])},${hex(ballotData[idx+1])},${hex(ballotData[idx])}), (${ballotData[idx+3]},${ballotData[idx+2]},${ballotData[idx+1]},${ballotData[idx]})`); + // } + //} t.expectOK(program.checkResults(ballotData, locationData, actualSize, num)); } @@ -268,7 +268,8 @@ g.test('predefined_reconvergence') .desc(`Test reconvergence using some predefined programs`) .params(u => u - .combine('test', [...iterRange(11, x => x)] as const) + .combine('style', [Style.Workgroup, Style.Subgroup, Style.Maximal] as const) + .combine('test', [...iterRange(7, x => x)] as const) .beginSubcases() ) //.beforeAllSubcases(t => { @@ -276,61 +277,42 @@ g.test('predefined_reconvergence') //}) .fn(async t => { const invocations = 128; // t.device.limits.maxSubgroupSize; + const style = t.params.style; let program: Program; switch (t.params.test) { case 0: { - program = new Program(Style.Workgroup, 1, invocations); + program = new Program(style, 1, invocations); program.predefinedProgram1(); break; } case 1: { - program = new Program(Style.Subgroup, 1, invocations); - program.predefinedProgram1(); - break; - } - case 2: { - program = new Program(Style.Subgroup, 1, invocations); + program = new Program(style, 1, invocations); program.predefinedProgram2(); break; } - case 3: { - program = new Program(Style.Maximal, 1, invocations); + case 2: { + program = new Program(style, 1, invocations); program.predefinedProgram3(); break; } - case 4: { - program = new Program(Style.Workgroup, 1, invocations); - program.predefinedProgramForInf(); - break; - } - case 5: { - program = new Program(Style.Subgroup, 1, invocations); + case 3: { + program = new Program(style, 1, invocations); program.predefinedProgramForInf(); break; } - case 6: { - program = new Program(Style.Subgroup, 1, invocations); - program.predefinedProgramForVar(); - break; - } - case 7: { - program = new Program(Style.Maximal, 1, invocations); + case 4: { + program = new Program(style, 1, invocations); program.predefinedProgramForVar(); break; } - case 8: { - program = new Program(Style.Workgroup, 1, invocations); - program.predefinedProgramCall(); - break; - } - case 9: { - program = new Program(Style.Maximal, 1, invocations); + case 5: { + program = new Program(style, 1, invocations); program.predefinedProgramCall(); break; } - case 10: { - program = new Program(Style.Workgroup, 1, invocations); + case 6: { + program = new Program(style, 1, invocations); program.predefinedProgram1(OpType.LoopUniform, OpType.EndLoopUniform); break; } diff --git a/src/webgpu/shader/execution/reconvergence/util.ts b/src/webgpu/shader/execution/reconvergence/util.ts index 20092d1a1482..3f51d5db6f2b 100644 --- a/src/webgpu/shader/execution/reconvergence/util.ts +++ b/src/webgpu/shader/execution/reconvergence/util.ts @@ -70,13 +70,13 @@ function all(value: bigint, size: number): boolean { export enum Style { // Workgroup uniform control flow - Workgroup, + Workgroup = 0, // Subgroup uniform control flow - Subgroup, + Subgroup = 1, // Maximal uniformity - Maximal, + Maximal = 2, }; export enum OpType { @@ -322,7 +322,7 @@ export class Program { if (this.getRandomFloat() < 0.2 && this.callNesting == 0 && this.nesting < this.maxNesting - 1) { this.genCall(); } else { - this.genBreak(); + this.genReturn(); } break; break; @@ -432,20 +432,18 @@ export class Program { private genForUniform() { const n = this.getRandomUint(5) + 1; // [1, 5] - const header = this.ops.length; this.ops.push(new Op(OpType.ForUniform, n)); this.nesting++; this.loopNesting++; this.loopNestingThisFunction++; this.pickOp(2); - this.ops.push(new Op(OpType.EndForUniform, header)); + this.ops.push(new Op(OpType.EndForUniform, n)); this.loopNestingThisFunction--; this.loopNesting--; this.nesting--; } private genForInf() { - const header = this.ops.length; this.ops.push(new Op(OpType.ForInf, 0)); this.nesting++; this.loopNesting++; @@ -460,7 +458,7 @@ export class Program { this.pickOp(2); - this.ops.push(new Op(OpType.EndForInf, header)); + this.ops.push(new Op(OpType.EndForInf, 0)); this.isLoopInf.set(this.loopNesting, false); this.doneInfLoopBreak.set(this.loopNesting, false); this.loopNestingThisFunction--; @@ -469,7 +467,6 @@ export class Program { } private genForVar() { - const header = this.ops.length; this.ops.push(new Op(OpType.ForVar, 0)); this.nesting++; this.loopNesting++; @@ -477,7 +474,7 @@ export class Program { this.pickOp(2); - this.ops.push(new Op(OpType.EndForVar, header)); + this.ops.push(new Op(OpType.EndForVar, 0)); this.loopNestingThisFunction--; this.loopNesting--; this.nesting--; @@ -485,7 +482,6 @@ export class Program { private genLoopUniform() { const n = this.getRandomUint(5) + 1; - const header = this.ops.length; this.ops.push(new Op(OpType.LoopUniform, n)); this.nesting++; this.loopNesting++; @@ -493,7 +489,7 @@ export class Program { this.pickOp(2); - this.ops.push(new Op(OpType.EndLoopUniform, header)); + this.ops.push(new Op(OpType.EndLoopUniform, n)); this.loopNestingThisFunction--; this.loopNesting--; this.nesting--; @@ -718,8 +714,7 @@ export class Program { this.genIndent(); this.addCode(`${iter}++;\n`); this.genIndent(); - const limit = this.ops[op.value].value; - this.addCode(`break if ${iter} >= inputs[${limit}];\n`); + this.addCode(`break if ${iter} >= inputs[${op.value}];\n`); this.decreaseIndent(); this.genIndent(); this.addCode(`}\n`); @@ -844,12 +839,11 @@ fn testBit(mask : vec4u, id : u32) -> bool { let ybit = extractBits(mask.y, id - 32, 1); let zbit = extractBits(mask.z, id - 64, 1); let wbit = extractBits(mask.w, id - 96, 1); - let lt32 = id < 32; - let lt64 = id < 64; - let lt96 = id < 96; - let sela = select(wbit, xbit, lt96); - let selb = select(zbit, ybit, lt64); - return select(selb, sela, lt32) == 1; + let lower32 = (id & 63) < 32; + let lower64 = id < 64; + let xybit = select(ybit, xbit, lower32); + let zwbit = select(wbit, zbit, lower32); + return select(zwbit, xybit, lower64) == 1; } fn f0() {`; @@ -968,8 +962,9 @@ ${this.functions[i]}`; while (i < this.ops.length) { const op = this.ops[i]; if (!countOnly) { - console.log(`ops[${i}] = ${serializeOpType(op.type)}, nesting = ${nesting}`); - console.log(` mask = ${stack[nesting].activeMask.toString(16)}`); + //console.log(`ops[${i}] = ${serializeOpType(op.type)}, nesting = ${nesting}, loopNesting = ${loopNesting}`); + //console.log(` mask = ${stack[nesting].activeMask.toString(16)}`); + //console.log(` isLoop = ${stack[nesting].isLoop}`); //for (let j = 0; j <= nesting; j++) { // console.log(` mask[${j}] = ${stack[j].activeMask.toString(16)}`); //} @@ -1041,8 +1036,9 @@ ${this.functions[i]}`; const cur = stack[nesting]; cur.copy(stack[nesting-1]); cur.header = i; - cur.isLoop = 0; - cur.isSwitch = 0; + cur.isLoop = false; + cur.isSwitch = false; + cur.isCall = false; // O is always uniform true. if (op.value != 0 && any(cur.activeMask)) { let subMask = this.getValueMask(op.value); @@ -1071,8 +1067,9 @@ ${this.functions[i]}`; const cur = stack[nesting]; cur.copy(stack[nesting-1]); cur.header = i; - cur.isLoop = 0; - cur.isSwitch = 0; + cur.isLoop = false; + cur.isSwitch = false; + cur.isCall = false; if (any(cur.activeMask)) { // All invocations with subgroup invocation id less than op.value are active. cur.activeMask &= getReplicatedMask(getMask(op.value), subgroupSize, this.invocations); @@ -1094,14 +1091,18 @@ ${this.functions[i]}`; while (!stack[n].isLoop) { n--; } + if (n < 0) { + unreachable(`Failed to find loop for IfLoopCount`); + } nesting++; stack.push(new State()); const cur = stack[nesting]; cur.copy(stack[nesting-1]); cur.header = i; - cur.isLoop = 0; - cur.isSwitch = 0; + cur.isLoop = false; + cur.isSwitch = false; + cur.isCall = false; if (any(cur.activeMask)) { cur.activeMask &= getReplicatedMask(BigInt(1 << stack[n].tripCount), subgroupSize, this.invocations); } @@ -1128,8 +1129,10 @@ ${this.functions[i]}`; stack.pop(); break; } - case OpType.ForUniform: { - // New uniform for loop. + case OpType.ForUniform: + case OpType.ForInf: + case OpType.ForVar: + case OpType.LoopUniform: { nesting++; loopNesting++; stack.push(new State()); @@ -1156,16 +1159,6 @@ ${this.functions[i]}`; } break; } - case OpType.ForInf: { - nesting++; - loopNesting++; - stack.push(new State()); - const cur = stack[nesting]; - cur.header = i; - cur.isLoop = true; - cur.activeMask = stack[nesting-1].activeMask; - break; - } case OpType.EndForInf: { const cur = stack[nesting]; cur.tripCount++; @@ -1197,16 +1190,6 @@ ${this.functions[i]}`; } break; } - case OpType.ForVar: { - nesting++; - loopNesting++; - stack.push(new State()); - const cur = stack[nesting]; - cur.header = i; - cur.isLoop = true; - cur.activeMask = stack[nesting-1].activeMask; - break; - } case OpType.EndForVar: { const cur = stack[nesting]; cur.tripCount++; @@ -1226,19 +1209,10 @@ ${this.functions[i]}`; stack.pop(); } else { i = cur.header + 1; + continue; } break; } - case OpType.LoopUniform: { - nesting++; - loopNesting++; - stack.push(new State()); - const cur = stack[nesting]; - cur.header = i; - cur.isLoop = true; - cur.activeMask = stack[nesting-1].activeMask; - break; - } case OpType.EndLoopUniform: { const cur = stack[nesting]; cur.tripCount++; @@ -1257,39 +1231,43 @@ ${this.functions[i]}`; } case OpType.Break: { // Remove this active mask from all stack entries for the current loop/switch. - let n = nesting; let mask: bigint = stack[nesting].activeMask; if (!any(mask)) { break; } - while (true) { + let n = nesting; + for (; n >= 0; n--) { stack[n].activeMask &= ~mask; if (stack[n].isLoop || stack[n].isSwitch) { break; } - - n--; + } + if (n < 0) { + unreachable(`Failed to find loop/switch for break`); } break; } case OpType.Continue: { // Remove this active mask from stack entries in this loop. // Add this mask to the loop's continue mask for the next iteration. - let n = nesting; let mask: bigint = stack[nesting].activeMask; if (!any(mask)) { break; } - while (true) { + let n = nesting; + for (; n >= 0; n--) { stack[n].activeMask &= ~mask; if (stack[n].isLoop) { stack[n].continueMask |= mask; break; } - n--; } + if (n < 0) { + unreachable(`Failed to loop for continue`); + } + break; } case OpType.Return: { @@ -1299,7 +1277,8 @@ ${this.functions[i]}`; break; } - for (let n = nesting; n >= 0; n--) { + let n = nesting; + for (; n >= 0; n--) { stack[n].activeMask &= ~mask; if (stack[n].isCall) { break; @@ -1313,8 +1292,9 @@ ${this.functions[i]}`; const cur = stack[nesting]; cur.copy(stack[nesting-1]); cur.header = i; - cur.isLoop = 0; - cur.isSwitch = 0; + cur.isLoop = false; + cur.isSwitch = false; + cur.isCall = false; if (any(cur.activeMask)) { cur.activeMask = getElectMask(cur.activeMask, subgroupSize, this.invocations); } @@ -1380,6 +1360,7 @@ ${this.functions[i]}`; // If this is an uniform control flow case, make sure a uniform ballot is // generated. A subgroup size of 64 is used for testing purposes here. if (this.style != Style.Maximal) { + console.log(`${new Date()}: simulating for UCF`); this.simulate(true, 64); } } while (this.style != Style.Maximal && !this.isUCF()); @@ -1544,7 +1525,6 @@ ${this.functions[i]}`; this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); - const header = this.ops.length; this.ops.push(new Op(beginLoop, 3)); this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); @@ -1559,7 +1539,7 @@ ${this.functions[i]}`; this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); - this.ops.push(new Op(endLoop, header)); + this.ops.push(new Op(endLoop, 3)); this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); From 6b0a8070f5af9fe9ffe16c64bdc4074d65b7c72d Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Tue, 22 Aug 2023 17:07:22 -0400 Subject: [PATCH 13/32] Impl and fixes * added infinite loop * added better safeguards for program length * optimized simulation runtime * fixed a couple bugs --- .../reconvergence/reconvergence.spec.ts | 11 +- .../shader/execution/reconvergence/util.ts | 404 +++++++++++++----- 2 files changed, 295 insertions(+), 120 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index 1a72e5dba401..f6d3f56ac357 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -82,6 +82,7 @@ subgroup_invocation_id = ${data[i]}`); async function testProgram(t: GPUTest, program: Program) { let wgsl = program.genCode(); //console.log(wgsl); + //return; // TODO: query the device const minSubgroupSize = 4; @@ -98,6 +99,7 @@ async function testProgram(t: GPUTest, program: Program) { // Add 1 to ensure there are no extraneous writes. numLocs++; + console.log(`${new Date()}: creating pipeline`); const pipeline = t.device.createComputePipeline({ layout: 'auto', compute: { @@ -269,7 +271,7 @@ g.test('predefined_reconvergence') .params(u => u .combine('style', [Style.Workgroup, Style.Subgroup, Style.Maximal] as const) - .combine('test', [...iterRange(7, x => x)] as const) + .combine('test', [...iterRange(8, x => x)] as const) .beginSubcases() ) //.beforeAllSubcases(t => { @@ -298,7 +300,7 @@ g.test('predefined_reconvergence') } case 3: { program = new Program(style, 1, invocations); - program.predefinedProgramForInf(); + program.predefinedProgramInf(); break; } case 4: { @@ -316,6 +318,11 @@ g.test('predefined_reconvergence') program.predefinedProgram1(OpType.LoopUniform, OpType.EndLoopUniform); break; } + case 7: { + program = new Program(style, 1, invocations); + program.predefinedProgramInf(OpType.LoopInf, OpType.EndLoopInf); + break; + } default: { program = new Program(); unreachable('Unhandled testcase'); diff --git a/src/webgpu/shader/execution/reconvergence/util.ts b/src/webgpu/shader/execution/reconvergence/util.ts index 3f51d5db6f2b..d0293325be43 100644 --- a/src/webgpu/shader/execution/reconvergence/util.ts +++ b/src/webgpu/shader/execution/reconvergence/util.ts @@ -13,7 +13,7 @@ function getMask(size: number): bigint { /** @returns A bitmask where submask is repeated every size bits for total bits. */ function getReplicatedMask(submask: bigint, size: number, total: number = 128): bigint { const reps = Math.floor(total / size); - let mask: bigint = submask; + let mask: bigint = submask & ((1n << BigInt(size)) - 1n); for (let i = 1; i < reps; i++) { mask |= (mask << BigInt(size)); } @@ -126,6 +126,10 @@ export enum OpType { LoopUniform, EndLoopUniform, + // var i = 0; loop { /* break */ ... continuing { ballot(); i++; } } + LoopInf, + EndLoopInf, + // Function return Return, @@ -161,6 +165,8 @@ function serializeOpType(op: OpType): string { case OpType.EndForVar: return 'EndForVar'; case OpType.LoopUniform: return 'LoopUniform'; case OpType.EndLoopUniform: return 'EndLoopUniform'; + case OpType.LoopInf: return 'LoopInf'; + case OpType.EndLoopInf: return 'EndLoopInf'; case OpType.Return: return 'Return'; case OpType.Elect: return 'Elect'; case OpType.Call: return 'Call'; @@ -198,25 +204,28 @@ class Op { }; export class Program { - public invocations: number; + public readonly invocations: number; private readonly prng: PRNG; private ops : Op[]; public readonly style: Style; private readonly minCount: number; + private readonly maxCount: number; private readonly maxNesting: number; + private readonly maxLoopNesting: number; private nesting: number; private loopNesting: number; private loopNestingThisFunction: number; private callNesting: number; - private numMasks: number; + private readonly numMasks: number; private masks: number[]; private curFunc: number; private functions: string[]; private indents: number[]; private readonly storeBase: number; - public refData: Uint32Array; + private refData: Uint32Array; private isLoopInf: Map; private doneInfLoopBreak: Map; + private maxProgramNesting; /** * constructor @@ -230,7 +239,9 @@ export class Program { this.ops = []; this.style = style; this.minCount = 30; + this.maxCount = 50000; // TODO: what is a reasonable limit? this.maxNesting = this.getRandomUint(40) + 20; //this.getRandomUint(70) + 30; // [30,100) + this.maxLoopNesting = 4; this.nesting = 0; this.loopNesting = 0; this.loopNestingThisFunction = 0; @@ -256,6 +267,7 @@ export class Program { this.refData = new Uint32Array(); this.isLoopInf = new Map(); this.doneInfLoopBreak = new Map(); + this.maxProgramNesting = 10; // default stack allocation } /** @returns A random float between 0 and 1 */ @@ -270,6 +282,10 @@ export class Program { private pickOp(count : number) { for (let i = 0; i < count; i++) { + if (this.ops.length >= this.maxCount) { + return; + } + this.genBallot(); if (this.nesting < this.maxNesting) { const r = this.getRandomUint(12); @@ -296,7 +312,7 @@ export class Program { } case 4: { // Avoid very deep loop nests to limit memory and runtime. - if (this.loopNesting <= 3) { + if (this.loopNesting < this.maxLoopNesting) { const r2 = this.getRandomUint(3); switch (r2) { case 0: this.genForUniform(); break; @@ -319,21 +335,21 @@ export class Program { } case 7: { // Calls and returns. - if (this.getRandomFloat() < 0.2 && this.callNesting == 0 && this.nesting < this.maxNesting - 1) { + if (this.getRandomFloat() < 0.2 && + this.callNesting == 0 && + this.nesting < this.maxNesting - 1) { this.genCall(); } else { this.genReturn(); } break; - break; } case 8: { - if (this.loopNesting <= 3) { - const r2 = this.getRandomUint(3); + if (this.loopNesting < this.maxLoopNesting) { + const r2 = this.getRandomUint(2); switch (r2) { case 0: this.genLoopUniform(); break; - case 1: - case 2: + case 1: this.genLoopInf(); break; default: { break; } @@ -400,6 +416,7 @@ export class Program { } this.nesting++; + this.maxProgramNesting = Math.max(this.nesting, this.maxProgramNesting); let beforeSize = this.ops.length; this.pickOp(2); @@ -415,8 +432,10 @@ export class Program { this.ops.push(new Op(OpType.ElseMask, maskIdx)); } - // Sometimes make the else identical to the if. - if (randElse < 0.1 && beforeSize != afterSize) { + // Sometimes make the else identical to the if, but don't just completely + // blow up the instruction count. + if (randElse < 0.1 && beforeSize != afterSize && + (beforeSize + 2 * (afterSize - beforeSize)) < this.maxCount) { for (let i = beforeSize; i < afterSize; i++) { const op = this.ops[i]; this.ops.push(new Op(op.type, op.value, op.uniform)); @@ -434,6 +453,7 @@ export class Program { const n = this.getRandomUint(5) + 1; // [1, 5] this.ops.push(new Op(OpType.ForUniform, n)); this.nesting++; + this.maxProgramNesting = Math.max(this.nesting, this.maxProgramNesting); this.loopNesting++; this.loopNestingThisFunction++; this.pickOp(2); @@ -446,6 +466,7 @@ export class Program { private genForInf() { this.ops.push(new Op(OpType.ForInf, 0)); this.nesting++; + this.maxProgramNesting = Math.max(this.nesting, this.maxProgramNesting); this.loopNesting++; this.loopNestingThisFunction++; this.isLoopInf.set(this.loopNesting, true); @@ -469,6 +490,7 @@ export class Program { private genForVar() { this.ops.push(new Op(OpType.ForVar, 0)); this.nesting++; + this.maxProgramNesting = Math.max(this.nesting, this.maxProgramNesting); this.loopNesting++; this.loopNestingThisFunction++; @@ -484,6 +506,7 @@ export class Program { const n = this.getRandomUint(5) + 1; this.ops.push(new Op(OpType.LoopUniform, n)); this.nesting++; + this.maxProgramNesting = Math.max(this.nesting, this.maxProgramNesting); this.loopNesting++; this.loopNestingThisFunction++; @@ -495,9 +518,37 @@ export class Program { this.nesting--; } + private genLoopInf() { + const header = this.ops.length; + this.ops.push(new Op(OpType.LoopInf, 0)); + + this.nesting++; + this.maxProgramNesting = Math.max(this.nesting, this.maxProgramNesting); + this.loopNesting++; + this.loopNestingThisFunction++; + this.isLoopInf.set(this.loopNesting, true); + this.doneInfLoopBreak.set(this.loopNesting, false); + + this.pickOp(2); + + this.genElect(true); + this.doneInfLoopBreak.set(this.loopNesting, true); + + this.pickOp(2); + + this.ops.push(new Op(OpType.EndLoopInf, header)); + + this.isLoopInf.set(this.loopNesting, false); + this.doneInfLoopBreak.set(this.loopNesting, false); + this.loopNestingThisFunction--; + this.loopNesting--; + this.nesting--; + } + private genElect(forceBreak: boolean) { this.ops.push(new Op(OpType.Elect, 0)); this.nesting++; + this.maxProgramNesting = Math.max(this.nesting, this.maxProgramNesting); if (forceBreak) { this.genBallot(); @@ -530,6 +581,7 @@ export class Program { this.ops.push(new Op(OpType.ElseMask, r)); this.ops.push(new Op(OpType.Break, 0)); this.ops.push(new Op(OpType.EndIf, 0)); + this.maxProgramNesting = Math.max(this.nesting + 1, this.maxProgramNesting); } else { this.ops.push(new Op(OpType.Break, 0)); } @@ -546,6 +598,7 @@ export class Program { this.ops.push(new Op(OpType.ElseMask, r)); this.ops.push(new Op(OpType.Break, 0)); this.ops.push(new Op(OpType.EndIf, 0)); + this.maxProgramNesting = Math.max(this.nesting + 1, this.maxProgramNesting); } else { this.ops.push(new Op(OpType.Continue, 0)); } @@ -556,6 +609,7 @@ export class Program { this.ops.push(new Op(OpType.Call, 0)); this.callNesting++; this.nesting++; + this.maxProgramNesting = Math.max(this.nesting, this.maxProgramNesting); const curLoopNesting = this.loopNestingThisFunction; this.loopNestingThisFunction = 0; @@ -580,6 +634,7 @@ export class Program { this.ops.push(new Op(OpType.ElseMask, 0)); this.ops.push(new Op(OpType.Return, 0)); this.ops.push(new Op(OpType.EndIf, 0)); + this.maxProgramNesting = Math.max(this.nesting + 1, this.maxProgramNesting); } else { this.ops.push(new Op(OpType.Return, 0)); } @@ -590,55 +645,45 @@ export class Program { public genCode(): string { for (let i = 0; i < this.ops.length; i++) { const op = this.ops[i]; - this.genIndent() - this.addCode(`// ops[${i}] = ${serializeOpType(op.type)}\n`); + this.addCode(`// ops[${i}] = ${serializeOpType(op.type)}`); switch (op.type) { case OpType.Ballot: { - this.genIndent(); - this.addCode(`ballots[stride * output_loc + local_id] = subgroupBallot();\n`); - this.genIndent(); - this.addCode(`output_loc++;\n`); + this.addCode(`ballots[stride * output_loc + local_id] = subgroupBallot();`); + this.addCode(`output_loc++;`); break; } case OpType.Store: { - this.genIndent(); - this.addCode(`locations[local_id]++;\n`); - this.genIndent(); - this.addCode(`ballots[stride * output_loc + local_id] = vec4u(${op.value});\n`); - this.genIndent(); - this.addCode(`output_loc++;\n`); + this.addCode(`locations[local_id]++;`); + this.addCode(`ballots[stride * output_loc + local_id] = vec4u(${op.value});`); + this.addCode(`output_loc++;`); break; } default: { - this.genIndent(); - this.addCode(`/* missing op ${op.type} */\n`); + this.addCode(`/* missing op ${op.type} */`); break; } case OpType.IfMask: { - this.genIndent(); if (op.value == 0) { const idx = this.getRandomUint(4); - this.addCode(`if inputs[${idx}] == ${idx} {\n`); + this.addCode(`if inputs[${idx}] == ${idx} {`); } else { const idx = op.value; const x = this.masks[4*idx]; const y = this.masks[4*idx+1]; const z = this.masks[4*idx+2]; const w = this.masks[4*idx+3]; - this.addCode(`if testBit(vec4u(0x${hex(x)},0x${hex(y)},0x${hex(z)},0x${hex(w)}), subgroup_id) {\n`); + this.addCode(`if testBit(vec4u(0x${hex(x)},0x${hex(y)},0x${hex(z)},0x${hex(w)}), subgroup_id) {`); } this.increaseIndent(); break; } case OpType.IfId: { - this.genIndent(); - this.addCode(`if subgroup_id < inputs[${op.value}] {\n`); + this.addCode(`if subgroup_id < inputs[${op.value}] {`); this.increaseIndent(); break; } case OpType.IfLoopCount: { - this.genIndent(); - this.addCode(`if subgroup_id == i${this.loopNesting-1} {\n`); + this.addCode(`if subgroup_id == i${this.loopNesting-1} {`); this.increaseIndent(); break; } @@ -646,42 +691,36 @@ export class Program { case OpType.ElseId: case OpType.ElseLoopCount: { this.decreaseIndent(); - this.genIndent(); - this.addCode(`} else {\n`); + this.addCode(`} else {`); this.increaseIndent(); break; } case OpType.EndIf: { this.decreaseIndent(); - this.genIndent(); - this.addCode(`}\n`); + this.addCode(`}`); break; } case OpType.ForUniform: { - this.genIndent(); const iter = `i${this.loopNesting}`; - this.addCode(`for (var ${iter} = 0u; ${iter} < inputs[${op.value}]; ${iter}++) {\n`); + this.addCode(`for (var ${iter} = 0u; ${iter} < inputs[${op.value}]; ${iter}++) {`); this.increaseIndent(); this.loopNesting++; break; } case OpType.ForInf: { - this.genIndent(); const iter = `i${this.loopNesting}`; - this.addCode(`for (var ${iter} = 0u; true; ${iter} = infLoopIncrement(${iter})) {\n`); + this.addCode(`for (var ${iter} = 0u; true; ${iter} = infLoopIncrement(${iter})) {`); this.loopNesting++; this.increaseIndent(); // Safety mechanism for hardware runs. - this.genIndent(); - this.addCode(`// Safety valve\n`); - this.genIndent(); - this.addCode(`if ${iter} >= 128u { break; }\n\n`); + // Intention extra newline. + this.addCode(`// Safety valve`); + this.addCode(`if ${iter} >= 128u { break; }\n`); break; } case OpType.ForVar: { - this.genIndent(); const iter = `i${this.loopNesting}`; - this.addCode(`for (var ${iter} = 0u; ${iter} < subgroup_id + 1; ${iter}++) {\n`); + this.addCode(`for (var ${iter} = 0u; ${iter} < subgroup_id + 1; ${iter}++) {`); this.loopNesting++; this.increaseIndent(); break; @@ -691,16 +730,13 @@ export class Program { case OpType.EndForVar: { this.loopNesting--; this.decreaseIndent(); - this.genIndent(); - this.addCode(`}\n`); + this.addCode(`}`); break; } case OpType.LoopUniform: { - this.genIndent(); const iter = `i${this.loopNesting}`; - this.addCode(`var ${iter} = 0u;\n`); - this.genIndent(); - this.addCode(`loop {\n`); + this.addCode(`${iter} = 0u;`); + this.addCode(`loop {`); this.loopNesting++; this.increaseIndent(); break; @@ -708,62 +744,86 @@ export class Program { case OpType.EndLoopUniform: { this.loopNesting--; const iter = `i${this.loopNesting}`; - this.genIndent(); - this.addCode(`continuing {\n`); + this.addCode(`continuing {`); + this.increaseIndent(); + this.addCode(`${iter}++;`); + this.addCode(`break if ${iter} >= inputs[${op.value}];`); + this.decreaseIndent(); + this.addCode(`}`); + this.decreaseIndent(); + this.addCode(`}`); + break; + } + case OpType.LoopInf: { + const iter = `i${this.loopNesting}`; + this.addCode(`${iter} = 0u;`); + this.addCode(`loop {`); + this.loopNesting++; + this.increaseIndent(); + break; + } + case OpType.EndLoopInf: { + this.loopNesting--; + const iter = `i${this.loopNesting}`; + this.addCode(`continuing {`); this.increaseIndent(); - this.genIndent(); - this.addCode(`${iter}++;\n`); - this.genIndent(); - this.addCode(`break if ${iter} >= inputs[${op.value}];\n`); + this.addCode(`${iter}++;`); + this.addCode(`ballots[stride * output_loc + local_id] = subgroupBallot();`); + this.addCode(`output_loc++;`); + // Safety mechanism for hardware runs. + // Intentional extra newlines. + this.addCode(``); + this.addCode(`// Safety mechanism`); + this.addCode(`break if ${iter} >= 128;`); this.decreaseIndent(); - this.genIndent(); - this.addCode(`}\n`); + this.addCode(`}`); this.decreaseIndent(); - this.genIndent(); - this.addCode(`}\n`); + this.addCode(`}`); break; } case OpType.Break: { - this.genIndent(); - this.addCode(`break;\n`); + this.addCode(`break;`); break; } case OpType.Continue: { - this.genIndent(); - this.addCode(`continue;\n`); + this.addCode(`continue;`); break; } case OpType.Return: { - this.genIndent(); - this.addCode(`return;\n`); + this.addCode(`return;`); break; } case OpType.Elect: { - this.genIndent(); - this.addCode(`if subgroupElect() {\n`); + this.addCode(`if subgroupElect() {`); this.increaseIndent(); break; } case OpType.Call: { - this.genIndent(); - this.addCode(`f${this.functions.length}(`); + let call = `f${this.functions.length}(`; for (let i = 0; i < this.loopNesting; i++) { - this.addCode(`i${i},`); + call += `i${i},`; } - this.addCode(`);\n`); + call += `);`; + this.addCode(call); this.curFunc = this.functions.length; - this.functions.push(`fn f${this.curFunc}(`); + this.functions.push(``); + this.indents.push(0); + let decl = `fn f${this.curFunc}(` for (let i = 0; i < this.loopNesting; i++) { - this.addCode(`i${i} : u32,`); + decl += `i${i} : u32,`; + } + decl += `) {`; + this.addCode(decl); + this.increaseIndent(); + for (let i = this.loopNesting; i < this.maxLoopNesting; i++) { + this.addCode(`var i${i} = 0u;`); } - this.addCode(`) {\n`); - this.indents.push(2); break; } case OpType.EndCall: { this.decreaseIndent(); - this.addCode(`}\n`); + this.addCode(`}`); // Call nesting is limited to 1 so we always return to f0. this.curFunc = 0; break; @@ -846,7 +906,11 @@ fn testBit(mask : vec4u, id : u32) -> bool { return select(zwbit, xybit, lower64) == 1; } -fn f0() {`; +fn f0() { + var i0 = 0u; + var i1 = 0u; + var i2 = 0u; + var i3 = 0u;`; for (let i = 0; i < this.functions.length; i++) { code += ` @@ -880,10 +944,11 @@ ${this.functions[i]}`; } /** - * Adds 'code' to the current function + * Adds the line 'code' to the current function. */ private addCode(code: string) { - this.functions[this.curFunc] += code; + this.genIndent(); + this.functions[this.curFunc] += code + `\n`; } /** @@ -946,8 +1011,13 @@ ${this.functions[i]}`; this.ops[idx].uniform = true; } - let stack = new Array(); - stack.push(new State()); + //let stack = new Array(); + // Allocate the stack based on the maximum nesting in the program. + let stack: State[] = new Array(this.maxProgramNesting + 1); + for (let i = 0; i < stack.length; i++) { + stack[i] = new State(); + } + //stack.push(new State()); stack[0].activeMask = (1n << 128n) - 1n; let nesting = 0; @@ -957,10 +1027,14 @@ ${this.functions[i]}`; if (!countOnly) { console.log(`Simulating subgroup size = ${subgroupSize}`); + console.log(` Max program nesting = ${this.maxProgramNesting}`); } let i = 0; while (i < this.ops.length) { const op = this.ops[i]; + if (nesting >= stack.length) { + unreachable(`Max stack nesting surpassed (${stack.length} vs ${this.nesting}) at ops[${i}] = ${serializeOpType(op.type)}`); + } if (!countOnly) { //console.log(`ops[${i}] = ${serializeOpType(op.type)}, nesting = ${nesting}, loopNesting = ${loopNesting}`); //console.log(` mask = ${stack[nesting].activeMask.toString(16)}`); @@ -969,6 +1043,32 @@ ${this.functions[i]}`; // console.log(` mask[${j}] = ${stack[j].activeMask.toString(16)}`); //} } + + // Early outs if no invocations are active. + switch (op.type) { + case OpType.Ballot: + case OpType.Store: + case OpType.Return: + case OpType.Continue: + case OpType.Break: { + if (!any(stack[nesting].activeMask)) { + i++; + continue; + } + break; + } + case OpType.ElseMask: + case OpType.ElseId: + case OpType.ElseLoopCount: { + if (!any(stack[nesting-1].activeMask)) { + stack[nesting].activeMask = 0n; + i++; + continue; + } + } + default: + break; + } switch (op.type) { case OpType.Ballot: { const curMask = stack[nesting].activeMask; @@ -978,7 +1078,7 @@ ${this.functions[i]}`; } // Flag if this ballot is not subgroup uniform. - if (this.style == Style.Subgroup) { + if (this.style == Style.Subgroup && any(curMask)) { for (let id = 0; id < this.invocations; id += subgroupSize) { const subgroupMask = (curMask >> BigInt(id)) & getMask(subgroupSize); if (subgroupMask != 0n && !all(subgroupMask, subgroupSize)) { @@ -1015,12 +1115,13 @@ ${this.functions[i]}`; break; } case OpType.Store: { - if (!any(stack[nesting].activeMask)) { + const cur = stack[nesting]; + if (!any(cur.activeMask)) { break; } for (let id = 0; id < this.invocations; id++) { - if (testBit(stack[nesting].activeMask, id)) { + if (testBit(cur.activeMask, id)) { if (!countOnly) { const idx = this.baseIndex(id, locs[id]); this.refData.fill(op.value, idx, idx + 4); @@ -1032,13 +1133,14 @@ ${this.functions[i]}`; } case OpType.IfMask: { nesting++; - stack.push(new State()); + //stack.push(new State()); const cur = stack[nesting]; - cur.copy(stack[nesting-1]); + cur.activeMask = stack[nesting-1].activeMask; cur.header = i; cur.isLoop = false; cur.isSwitch = false; cur.isCall = false; + cur.continueMask = 0n; // O is always uniform true. if (op.value != 0 && any(cur.activeMask)) { let subMask = this.getValueMask(op.value); @@ -1050,10 +1152,10 @@ ${this.functions[i]}`; case OpType.ElseMask: { // 0 is always uniform true so the else will never be taken. const cur = stack[nesting]; + const prev = stack[nesting-1]; if (op.value == 0) { cur.activeMask = 0n; - } else if (any(cur.activeMask)) { - const prev = stack[nesting-1]; + } else if (any(prev.activeMask)) { let subMask = this.getValueMask(op.value); subMask &= getMask(subgroupSize); cur.activeMask = prev.activeMask; @@ -1063,16 +1165,18 @@ ${this.functions[i]}`; } case OpType.IfId: { nesting++; - stack.push(new State()); + //stack.push(new State()); const cur = stack[nesting]; - cur.copy(stack[nesting-1]); + cur.activeMask = stack[nesting-1].activeMask; cur.header = i; cur.isLoop = false; cur.isSwitch = false; cur.isCall = false; + cur.continueMask = 0n; if (any(cur.activeMask)) { // All invocations with subgroup invocation id less than op.value are active. - cur.activeMask &= getReplicatedMask(getMask(op.value), subgroupSize, this.invocations); + const mask = getReplicatedMask(getMask(op.value), subgroupSize, this.invocations); + cur.activeMask &= mask; } break; } @@ -1081,7 +1185,8 @@ ${this.functions[i]}`; // All invocations with a subgroup invocation id greater or equal to op.value are active. stack[nesting].activeMask = prev.activeMask; if (any(prev.activeMask)) { - stack[nesting].activeMask &= ~getReplicatedMask(getMask(op.value), subgroupSize, this.invocations); + const mask = getReplicatedMask(getMask(op.value), subgroupSize, this.invocations); + stack[nesting].activeMask &= ~mask; } break; } @@ -1096,13 +1201,14 @@ ${this.functions[i]}`; } nesting++; - stack.push(new State()); + //stack.push(new State()); const cur = stack[nesting]; - cur.copy(stack[nesting-1]); + cur.activeMask = stack[nesting-1].activeMask; cur.header = i; cur.isLoop = false; cur.isSwitch = false; cur.isCall = false; + cur.continueMask = 0n; if (any(cur.activeMask)) { cur.activeMask &= getReplicatedMask(BigInt(1 << stack[n].tripCount), subgroupSize, this.invocations); } @@ -1116,6 +1222,9 @@ ${this.functions[i]}`; while (!stack[n].isLoop) { n--; } + if (n < 0) { + unreachable(`Failed to find loop for ElseLoopCount`); + } stack[nesting].activeMask = stack[nesting-1].activeMask; if (any(stack[nesting].activeMask)) { @@ -1126,20 +1235,25 @@ ${this.functions[i]}`; case OpType.EndIf: { // End the current if. nesting--; - stack.pop(); + //stack.pop(); break; } case OpType.ForUniform: case OpType.ForInf: case OpType.ForVar: - case OpType.LoopUniform: { + case OpType.LoopUniform: + case OpType.LoopInf: { nesting++; loopNesting++; - stack.push(new State()); + assert(nesting < stack.length); + //stack.push(new State()); const cur = stack[nesting]; cur.header = i; cur.isLoop = true; cur.activeMask = stack[nesting-1].activeMask; + cur.isSwitch = false; + cur.isCall = false; + cur.continueMask = 0n; break; } case OpType.EndForUniform: { @@ -1155,7 +1269,7 @@ ${this.functions[i]}`; } else { loopNesting--; nesting--; - stack.pop(); + //stack.pop(); } break; } @@ -1186,7 +1300,7 @@ ${this.functions[i]}`; } else { loopNesting--; nesting--; - stack.pop(); + //stack.pop(); } break; } @@ -1206,7 +1320,7 @@ ${this.functions[i]}`; if (done) { loopNesting--; nesting--; - stack.pop(); + //stack.pop(); } else { i = cur.header + 1; continue; @@ -1225,7 +1339,37 @@ ${this.functions[i]}`; } else { loopNesting--; nesting--; - stack.pop(); + //stack.pop(); + } + break; + } + case OpType.EndLoopInf: { + const cur = stack[nesting]; + cur.tripCount++; + cur.activeMask |= cur.continueMask; + if (any(cur.activeMask)) { + let maskArray = new Uint32Array(); + for (let id = 0; id < this.invocations; id++) { + if (id % subgroupSize === 0) { + maskArray = getSubgroupMask(cur.activeMask, subgroupSize, id); + } + if (testBit(cur.activeMask, id)) { + if (!countOnly) { + const idx = this.baseIndex(id, locs[id]); + this.refData[idx + 0] = maskArray[0]; + this.refData[idx + 1] = maskArray[1]; + this.refData[idx + 2] = maskArray[2]; + this.refData[idx + 3] = maskArray[3]; + } + locs[id]++; + } + } + i = cur.header + 1; + continue; + } else { + loopNesting--; + nesting--; + //stack.pop(); } break; } @@ -1265,7 +1409,7 @@ ${this.functions[i]}`; } } if (n < 0) { - unreachable(`Failed to loop for continue`); + unreachable(`Failed to find loop for continue`); } break; @@ -1284,17 +1428,22 @@ ${this.functions[i]}`; break; } } + if (n < 0) { + unreachable(`Failed to find call for return`); + } + break; } case OpType.Elect: { nesting++; - stack.push(new State()); + //stack.push(new State()); const cur = stack[nesting]; - cur.copy(stack[nesting-1]); + cur.activeMask = stack[nesting-1].activeMask; cur.header = i; cur.isLoop = false; cur.isSwitch = false; cur.isCall = false; + cur.continueMask = 0n; if (any(cur.activeMask)) { cur.activeMask = getElectMask(cur.activeMask, subgroupSize, this.invocations); } @@ -1302,15 +1451,18 @@ ${this.functions[i]}`; } case OpType.Call: { nesting++; - stack.push(new State()); + //stack.push(new State()); const cur = stack[nesting]; cur.activeMask = stack[nesting-1].activeMask; - cur.isCall = 1; + cur.isCall = true; + cur.isLoop = false; + cur.isSwitch = false; + cur.continueMask = 0n; break; } case OpType.EndCall: { nesting--; - stack.pop(); + //stack.pop(); break; } default: { @@ -1320,7 +1472,8 @@ ${this.functions[i]}`; i++; } - assert(stack.length == 1); + assert(nesting == 0); + //assert(stack.length == 1); let maxLoc = 0; for (let id = 0; id < this.invocations; id++) { @@ -1355,7 +1508,7 @@ ${this.functions[i]}`; while (this.ops.length < this.minCount) { this.pickOp(1); } - //break; + break; // If this is an uniform control flow case, make sure a uniform ballot is // generated. A subgroup size of 64 is used for testing purposes here. @@ -1510,6 +1663,13 @@ ${this.functions[i]}`; * ballot(); // fully uniform * } * ballot(); // fully uniform + * + * @param beginLoop The loop type + * @param endLoop The end loop type + * + * |beginLoop| and |endLoop| must be paired. Currently supported pairs: + * * ForUniform and EndForUniform + * * LoopUniform and EndLoopUniform */ public predefinedProgram1(beginLoop: OpType = OpType.ForUniform, endLoop: OpType = OpType.EndForUniform) { @@ -1660,7 +1820,7 @@ ${this.functions[i]}`; this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); - this.ops.push(new Op(OpType.ElseId, 0)); + this.ops.push(new Op(OpType.ElseId, 112)); this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); @@ -1680,9 +1840,17 @@ ${this.functions[i]}`; * } * } * ballot(); + * + * @param beginType The loop type + * @param endType The end loop type + * + * |beginType| and |endType| must be paired. Currently supported pairs: + * * ForInf and EndForInf + * * LoopInf and EndLoopInf */ - public predefinedProgramForInf() { - this.ops.push(new Op(OpType.ForInf, 0)); + public predefinedProgramInf(beginType: OpType = OpType.ForInf, + endType: OpType = OpType.EndForInf) { + this.ops.push(new Op(beginType, 0)); this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); @@ -1691,7 +1859,7 @@ ${this.functions[i]}`; this.ops.push(new Op(OpType.Break, 0)); this.ops.push(new Op(OpType.EndIf, 0)); - this.ops.push(new Op(OpType.EndForInf, 0)); + this.ops.push(new Op(endType, 0)); this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); From c97f8e9998f60173aada7853a73a7d3fabab4cd7 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Wed, 23 Aug 2023 15:16:00 -0400 Subject: [PATCH 14/32] Improve performance and fixes * remove the locations buffer * cap the maximum number of locations per invocation such that the buffer size is guaranteed < 256MB * fix simulation of return in main function * remove checking of last buffer value in UCF tests --- .../reconvergence/reconvergence.spec.ts | 53 ++++--- .../shader/execution/reconvergence/util.ts | 143 +++++++----------- 2 files changed, 87 insertions(+), 109 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index f6d3f56ac357..f8c79cad11d9 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -96,8 +96,10 @@ async function testProgram(t: GPUTest, program: Program) { locMap.set(size, num); numLocs = Math.max(num, numLocs); } + numLocs = Math.min(program.maxLocations, numLocs); // Add 1 to ensure there are no extraneous writes. numLocs++; + console.log(`${new Date()}: Maximum locations = ${numLocs}`); console.log(`${new Date()}: creating pipeline`); const pipeline = t.device.createComputePipeline({ @@ -125,12 +127,12 @@ async function testProgram(t: GPUTest, program: Program) { ); t.trackForCleanup(ballotBuffer); - const locationLength = program.invocations; - const locationBuffer = t.makeBufferWithContents( - new Uint32Array([...iterRange(locationLength, x => 0)]), - GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC - ); - t.trackForCleanup(locationBuffer); + //const locationLength = program.invocations; + //const locationBuffer = t.makeBufferWithContents( + // new Uint32Array([...iterRange(locationLength, x => 0)]), + // GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC + //); + //t.trackForCleanup(locationBuffer); const sizeLength = 2; const sizeBuffer = t.makeBufferWithContents( @@ -161,12 +163,12 @@ async function testProgram(t: GPUTest, program: Program) { buffer: ballotBuffer }, }, - { - binding: 2, - resource: { - buffer: locationBuffer - }, - }, + //{ + // binding: 2, + // resource: { + // buffer: locationBuffer + // }, + //}, { binding: 3, resource: { @@ -216,6 +218,9 @@ async function testProgram(t: GPUTest, program: Program) { program.sizeRefData(locMap.get(actualSize)); console.log(`${new Date()}: Full simulation size = ${actualSize}`); let num = program.simulate(false, actualSize); + console.log(`${new Date()}: locations = ${num}`); + num = Math.min(program.maxLocations, num); + console.log(`${new Date()}: locations = ${num}`); const idReadback = await t.readGPUBufferRangeTyped( idBuffer, @@ -229,16 +234,16 @@ async function testProgram(t: GPUTest, program: Program) { const idData = idReadback.data; t.expectOK(checkIds(idData, actualSize), { mode: 'warn' }); - const locationReadback = await t.readGPUBufferRangeTyped( - locationBuffer, - { - srcByteOffset: 0, - type: Uint32Array, - typedLength: locationLength, - method: 'copy', - } - ); - const locationData = locationReadback.data; + //const locationReadback = await t.readGPUBufferRangeTyped( + // locationBuffer, + // { + // srcByteOffset: 0, + // type: Uint32Array, + // typedLength: locationLength, + // method: 'copy', + // } + //); + //const locationData = locationReadback.data; console.log(`${new Date()}: Reading ballot buffer ${ballotLength * 4} bytes`); const ballotReadback = await t.readGPUBufferRangeTyped( @@ -263,7 +268,7 @@ async function testProgram(t: GPUTest, program: Program) { // } //} - t.expectOK(program.checkResults(ballotData, locationData, actualSize, num)); + t.expectOK(program.checkResults(ballotData, /*locationData,*/ actualSize, num)); } g.test('predefined_reconvergence') @@ -337,7 +342,7 @@ g.test('random_reconvergence') .params(u => u .combine('style', [Style.Workgroup, Style.Subgroup, Style.Maximal] as const) - .combine('seed', generateSeeds(5)) + .combine('seed', generateSeeds(50)) .filter(u => { if (u.style == Style.Maximal) { return false; diff --git a/src/webgpu/shader/execution/reconvergence/util.ts b/src/webgpu/shader/execution/reconvergence/util.ts index d0293325be43..e9b1610a2cc3 100644 --- a/src/webgpu/shader/execution/reconvergence/util.ts +++ b/src/webgpu/shader/execution/reconvergence/util.ts @@ -226,6 +226,7 @@ export class Program { private isLoopInf: Map; private doneInfLoopBreak: Map; private maxProgramNesting; + public readonly maxLocations: number; /** * constructor @@ -240,8 +241,11 @@ export class Program { this.style = style; this.minCount = 30; this.maxCount = 50000; // TODO: what is a reasonable limit? + // TODO: https://crbug.com/tint/2011 + // Tint is double counting depth this.maxNesting = this.getRandomUint(40) + 20; //this.getRandomUint(70) + 30; // [30,100) - this.maxLoopNesting = 4; + // Loops significantly affect runtime and memory performance + this.maxLoopNesting = 3; //4; this.nesting = 0; this.loopNesting = 0; this.loopNestingThisFunction = 0; @@ -268,6 +272,7 @@ export class Program { this.isLoopInf = new Map(); this.doneInfLoopBreak = new Map(); this.maxProgramNesting = 10; // default stack allocation + this.maxLocations = 130000; // keep the buffer under 256MiB } /** @returns A random float between 0 and 1 */ @@ -559,7 +564,7 @@ export class Program { // Sometimes use a return if we're in a call. if (this.callNesting > 0 && this.getRandomFloat() < 0.3) { - this.ops.push(new Op(OpType.Return, 0)); + this.ops.push(new Op(OpType.Return, this.callNesting)); } else { this.genBreak(); } @@ -630,13 +635,13 @@ export class Program { this.genBallot(); if (this.getRandomFloat() < 0.1) { this.ops.push(new Op(OpType.IfMask, 0)); - this.ops.push(new Op(OpType.Return, 0)); + this.ops.push(new Op(OpType.Return, this.callNesting)); this.ops.push(new Op(OpType.ElseMask, 0)); - this.ops.push(new Op(OpType.Return, 0)); + this.ops.push(new Op(OpType.Return, this.callNesting)); this.ops.push(new Op(OpType.EndIf, 0)); this.maxProgramNesting = Math.max(this.nesting + 1, this.maxProgramNesting); } else { - this.ops.push(new Op(OpType.Return, 0)); + this.ops.push(new Op(OpType.Return, this.callNesting)); } } } @@ -653,7 +658,7 @@ export class Program { break; } case OpType.Store: { - this.addCode(`locations[local_id]++;`); + //this.addCode(`locations[local_id]++;`); this.addCode(`ballots[stride * output_loc + local_id] = vec4u(${op.value});`); this.addCode(`output_loc++;`); break; @@ -840,8 +845,8 @@ const stride = ${this.invocations}; var inputs : array; @group(0) @binding(1) var ballots : array; -@group(0) @binding(2) -var locations : array; +//@group(0) @binding(2) +//var locations : array; @group(0) @binding(3) var size : array; @group(0) @binding(4) @@ -859,7 +864,7 @@ fn main( ) { _ = inputs[0]; _ = ballots[0]; - _ = locations[0]; + //_ = locations[0]; subgroup_id = sid; local_id = lid; ids[lid] = sid; @@ -955,9 +960,11 @@ ${this.functions[i]}`; * Sizes the simulation buffer. * * The total size is (# of invocations) * |locs| * 4 (uint4 is written). + * |locs| is capped at this.maxLocations. */ public sizeRefData(locs: number) { - this.refData = new Uint32Array(locs * 4 * this.invocations); + const num = Math.min(this.maxLocations, locs); + this.refData = new Uint32Array(num * 4 * this.invocations); this.refData.fill(0); } @@ -997,27 +1004,25 @@ ${this.functions[i]}`; this.isSwitch = false; } - copy(other: State) { - this.activeMask = other.activeMask; - this.continueMask = other.continueMask; - this.header = other.header; - this.isLoop = other.isLoop; - this.tripCount = other.tripCount; - this.isCall = other.isCall; - this.isSwitch = other.isSwitch; + reset() { + this.activeMask = 0n; + this.continueMask = 0n; + this.header = 0; + this.isLoop = false; + this.tripCount = 0; + this.isCall = false; + this.isSwitch = false; } }; for (let idx = 0; idx < this.ops.length; idx++) { this.ops[idx].uniform = true; } - //let stack = new Array(); // Allocate the stack based on the maximum nesting in the program. let stack: State[] = new Array(this.maxProgramNesting + 1); for (let i = 0; i < stack.length; i++) { stack[i] = new State(); } - //stack.push(new State()); stack[0].activeMask = (1n << 128n) - 1n; let nesting = 0; @@ -1025,10 +1030,6 @@ ${this.functions[i]}`; let locs = new Array(this.invocations); locs.fill(0); - if (!countOnly) { - console.log(`Simulating subgroup size = ${subgroupSize}`); - console.log(` Max program nesting = ${this.maxProgramNesting}`); - } let i = 0; while (i < this.ops.length) { const op = this.ops[i]; @@ -1036,12 +1037,11 @@ ${this.functions[i]}`; unreachable(`Max stack nesting surpassed (${stack.length} vs ${this.nesting}) at ops[${i}] = ${serializeOpType(op.type)}`); } if (!countOnly) { - //console.log(`ops[${i}] = ${serializeOpType(op.type)}, nesting = ${nesting}, loopNesting = ${loopNesting}`); + //console.log(`ops[${i}] = ${serializeOpType(op.type)}, nesting = ${nesting}, loopNesting = ${loopNesting}, value = ${op.value}`); //console.log(` mask = ${stack[nesting].activeMask.toString(16)}`); - //console.log(` isLoop = ${stack[nesting].isLoop}`); - //for (let j = 0; j <= nesting; j++) { - // console.log(` mask[${j}] = ${stack[j].activeMask.toString(16)}`); - //} + // //for (let j = 0; j <= nesting; j++) { + // // console.log(` mask[${j}] = ${stack[j].activeMask.toString(16)}`); + // //} } // Early outs if no invocations are active. @@ -1133,14 +1133,10 @@ ${this.functions[i]}`; } case OpType.IfMask: { nesting++; - //stack.push(new State()); const cur = stack[nesting]; + cur.reset(); cur.activeMask = stack[nesting-1].activeMask; cur.header = i; - cur.isLoop = false; - cur.isSwitch = false; - cur.isCall = false; - cur.continueMask = 0n; // O is always uniform true. if (op.value != 0 && any(cur.activeMask)) { let subMask = this.getValueMask(op.value); @@ -1165,14 +1161,10 @@ ${this.functions[i]}`; } case OpType.IfId: { nesting++; - //stack.push(new State()); const cur = stack[nesting]; + cur.reset(); cur.activeMask = stack[nesting-1].activeMask; cur.header = i; - cur.isLoop = false; - cur.isSwitch = false; - cur.isCall = false; - cur.continueMask = 0n; if (any(cur.activeMask)) { // All invocations with subgroup invocation id less than op.value are active. const mask = getReplicatedMask(getMask(op.value), subgroupSize, this.invocations); @@ -1201,14 +1193,10 @@ ${this.functions[i]}`; } nesting++; - //stack.push(new State()); const cur = stack[nesting]; + cur.reset(); cur.activeMask = stack[nesting-1].activeMask; cur.header = i; - cur.isLoop = false; - cur.isSwitch = false; - cur.isCall = false; - cur.continueMask = 0n; if (any(cur.activeMask)) { cur.activeMask &= getReplicatedMask(BigInt(1 << stack[n].tripCount), subgroupSize, this.invocations); } @@ -1235,7 +1223,6 @@ ${this.functions[i]}`; case OpType.EndIf: { // End the current if. nesting--; - //stack.pop(); break; } case OpType.ForUniform: @@ -1245,15 +1232,11 @@ ${this.functions[i]}`; case OpType.LoopInf: { nesting++; loopNesting++; - assert(nesting < stack.length); - //stack.push(new State()); const cur = stack[nesting]; + cur.reset(); cur.header = i; cur.isLoop = true; cur.activeMask = stack[nesting-1].activeMask; - cur.isSwitch = false; - cur.isCall = false; - cur.continueMask = 0n; break; } case OpType.EndForUniform: { @@ -1269,7 +1252,6 @@ ${this.functions[i]}`; } else { loopNesting--; nesting--; - //stack.pop(); } break; } @@ -1300,7 +1282,6 @@ ${this.functions[i]}`; } else { loopNesting--; nesting--; - //stack.pop(); } break; } @@ -1320,7 +1301,6 @@ ${this.functions[i]}`; if (done) { loopNesting--; nesting--; - //stack.pop(); } else { i = cur.header + 1; continue; @@ -1339,7 +1319,6 @@ ${this.functions[i]}`; } else { loopNesting--; nesting--; - //stack.pop(); } break; } @@ -1369,7 +1348,6 @@ ${this.functions[i]}`; } else { loopNesting--; nesting--; - //stack.pop(); } break; } @@ -1428,7 +1406,9 @@ ${this.functions[i]}`; break; } } - if (n < 0) { + // op.value for Return is the call nesting. + // If the value is > 0 we should have encountered the call on the stack. + if (op.value != 0 && n < 0) { unreachable(`Failed to find call for return`); } @@ -1436,14 +1416,10 @@ ${this.functions[i]}`; } case OpType.Elect: { nesting++; - //stack.push(new State()); const cur = stack[nesting]; + cur.reset(); cur.activeMask = stack[nesting-1].activeMask; cur.header = i; - cur.isLoop = false; - cur.isSwitch = false; - cur.isCall = false; - cur.continueMask = 0n; if (any(cur.activeMask)) { cur.activeMask = getElectMask(cur.activeMask, subgroupSize, this.invocations); } @@ -1451,18 +1427,14 @@ ${this.functions[i]}`; } case OpType.Call: { nesting++; - //stack.push(new State()); const cur = stack[nesting]; + cur.reset(); cur.activeMask = stack[nesting-1].activeMask; cur.isCall = true; - cur.isLoop = false; - cur.isSwitch = false; - cur.continueMask = 0n; break; } case OpType.EndCall: { nesting--; - //stack.pop(); break; } default: { @@ -1473,15 +1445,15 @@ ${this.functions[i]}`; } assert(nesting == 0); - //assert(stack.length == 1); let maxLoc = 0; for (let id = 0; id < this.invocations; id++) { maxLoc = Math.max(maxLoc, locs[id]); } - if (!countOnly) { - console.log(`Max location = ${maxLoc}\n`); - } + maxLoc = Math.min(this.maxLocations, maxLoc); + //if (!countOnly) { + // console.log(`Max location = ${maxLoc}\n`); + //} return maxLoc; } @@ -1503,12 +1475,16 @@ ${this.functions[i]}`; /** @returns a randomized program */ public generate() { + let i = 0; do { + if (i != 0) { + console.log(`Warning regenerating UCF testcase`); + } this.ops = []; while (this.ops.length < this.minCount) { this.pickOp(1); } - break; + //break; // If this is an uniform control flow case, make sure a uniform ballot is // generated. A subgroup size of 64 is used for testing purposes here. @@ -1516,6 +1492,7 @@ ${this.functions[i]}`; console.log(`${new Date()}: simulating for UCF`); this.simulate(true, 64); } + i++; } while (this.style != Style.Maximal && !this.isUCF()); } @@ -1540,6 +1517,7 @@ ${this.functions[i]}`; * @returns The base index in a Uint32Array */ private baseIndex(id: number, loc: number): number { + const capped_loc = Math.min(this.maxLocations, loc); return 4 * (this.invocations * loc + id); } @@ -1569,8 +1547,9 @@ ${this.functions[i]}`; * @param numLocs The maximum locations used in simulation * @returns an error if the results do meet expectatations */ - public checkResults(ballots: Uint32Array, locations: Uint32Array, + public checkResults(ballots: Uint32Array, /*locations: Uint32Array,*/ subgroupSize: number, numLocs: number): Error | undefined { + let totalLocs = Math.min(numLocs, this.maxLocations); //console.log(`Verifying numLocs = ${numLocs}`); if (this.style == Style.Workgroup || this.style === Style.Subgroup) { if (!this.isUCF()) { @@ -1583,8 +1562,8 @@ ${this.functions[i]}`; for (let id = 0; id < this.invocations; id++) { let refLoc = 1; let resLoc = 0; - while (refLoc < numLocs) { - while (refLoc < numLocs && + while (refLoc < totalLocs) { + while (refLoc < totalLocs && !this.matchResult(this.refData, this.baseIndex(id, refLoc), maskArray, 0)) { refLoc++; } @@ -1593,13 +1572,13 @@ ${this.functions[i]}`; // Search for the corresponding store in the result data. let storeRefLoc = refLoc - 1; - while (resLoc < numLocs && + while (resLoc < totalLocs && !this.matchResult(ballots, this.baseIndex(id, resLoc), this.refData, this.baseIndex(id, storeRefLoc))) { resLoc++; } - if (resLoc >= numLocs) { + if (resLoc >= totalLocs) { const refIdx = this.baseIndex(id, storeRefLoc); return Error(`Failure for invocation ${id}: could not find associated store for reference location ${storeRefLoc}: ${this.refData[refIdx]},${this.refData[refIdx+1]},${this.refData[refIdx+2]},${this.refData[refIdx+3]}`); } else { @@ -1616,12 +1595,6 @@ ${this.functions[i]}`; refLoc++; } } - // Check there were no extra writes. - const idx = this.baseIndex(id, numLocs); - if (!this.matchResult(ballots, idx, zeroArray, 0)) { - return Error(`Unexpected write at end of buffer (location = ${numLocs}) for invocation ${id} -- got: (${ballots[idx]}, ${ballots[idx + 1]}, ${ballots[idx + 2]}, ${ballots[idx + 3]})`); - } } } else if (this.style == Style.Maximal) { // Expect exact matches. @@ -1765,7 +1738,7 @@ ${this.functions[i]}`; this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); - this.ops.push(new Op(OpType.Return, 16)); + this.ops.push(new Op(OpType.Return, 0)); this.ops.push(new Op(OpType.EndIf, 0)); @@ -1942,7 +1915,7 @@ ${this.functions[i]}`; this.ops.push(new Op(OpType.IfLoopCount, 0)); this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); this.ops.push(new Op(OpType.Ballot, 0)); - this.ops.push(new Op(OpType.Return, 0)); + this.ops.push(new Op(OpType.Return, 1)); this.ops.push(new Op(OpType.EndIf, 0)); // end f1 this.ops.push(new Op(OpType.EndCall, 0)); @@ -1962,7 +1935,7 @@ ${this.functions[i]}`; this.ops.push(new Op(OpType.IfMask, 1)); this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); this.ops.push(new Op(OpType.Ballot, 0)); - this.ops.push(new Op(OpType.Return, 0)); + this.ops.push(new Op(OpType.Return, 1)); this.ops.push(new Op(OpType.EndIf, 0)); // end f2 this.ops.push(new Op(OpType.EndCall, 0)); From fe9370cf6e649842f854944fff19cb4905fff1a8 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Thu, 24 Aug 2023 14:42:55 -0400 Subject: [PATCH 15/32] Cleanup * Add debug functions for dumping info * remove some logging * removve some default parameters to ensure consistency * change some limits to improve runtime performance * Add loop reduction factors to improve runtime performance * ForVar, ForInf, and LoopInf will execute half as many when inside 1 loop and a quarter inside 2 * Made all stores unique * identical else blocks previously reused values --- .../reconvergence/reconvergence.spec.ts | 34 ++-- .../shader/execution/reconvergence/util.ts | 161 ++++++++++++++---- 2 files changed, 146 insertions(+), 49 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index f8c79cad11d9..c64437e75555 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -65,6 +65,22 @@ function checkSubgroupSizeConsistency(data: Uint32Array, min: number, max: numbe return undefined; } +function dumpBallots(ballots: Uint32Array, invocations: number, locations: number) { + let dump = `Ballots\n`; + for (let id = 0; id < invocations; id++) { + dump += `id[${id}]\n`; + for (let loc = 0; loc < locations; loc++) { + const idx = 4 * (invocations * loc + id); + const w = ballots[idx+3]; + const z = ballots[idx+2]; + const y = ballots[idx+1]; + const x = ballots[idx+0]; + dump += ` loc[${loc}] = (0x${hex(w)},0x${hex(z)},0x${hex(y)},0x${hex(x)}), (${w},${z},${y},${x})\n`; + } + } + console.log(dump); +} + /** * Checks the mapping of subgroup_invocation_id to local_invocation_index */ @@ -80,8 +96,9 @@ subgroup_invocation_id = ${data[i]}`); } async function testProgram(t: GPUTest, program: Program) { - let wgsl = program.genCode(); + const wgsl = program.genCode(); //console.log(wgsl); + //program.dumpStats(true); //return; // TODO: query the device @@ -91,7 +108,6 @@ async function testProgram(t: GPUTest, program: Program) { let numLocs = 0; const locMap = new Map(); for (let size = minSubgroupSize; size <= maxSubgroupSize; size *= 2) { - console.log(`${new Date()}: simulating subgroup size = ${size}`); let num = program.simulate(true, size); locMap.set(size, num); numLocs = Math.max(num, numLocs); @@ -220,7 +236,6 @@ async function testProgram(t: GPUTest, program: Program) { let num = program.simulate(false, actualSize); console.log(`${new Date()}: locations = ${num}`); num = Math.min(program.maxLocations, num); - console.log(`${new Date()}: locations = ${num}`); const idReadback = await t.readGPUBufferRangeTyped( idBuffer, @@ -245,7 +260,6 @@ async function testProgram(t: GPUTest, program: Program) { //); //const locationData = locationReadback.data; - console.log(`${new Date()}: Reading ballot buffer ${ballotLength * 4} bytes`); const ballotReadback = await t.readGPUBufferRangeTyped( ballotBuffer, { @@ -258,15 +272,7 @@ async function testProgram(t: GPUTest, program: Program) { const ballotData = ballotReadback.data; console.log(`${Date()}: Finished buffer readbacks`); - //console.log(`Ballots`); - ////for (let id = 0; id < program.invocations; id++) { - //for (let id = 0; id < actualSize; id++) { - // console.log(` id[${id}]:`); - // for (let loc = 0; loc < num; loc++) { - // const idx = 4 * (program.invocations * loc + id); - // console.log(` loc[${loc}] = (${hex(ballotData[idx+3])},${hex(ballotData[idx+2])},${hex(ballotData[idx+1])},${hex(ballotData[idx])}), (${ballotData[idx+3]},${ballotData[idx+2]},${ballotData[idx+1]},${ballotData[idx]})`); - // } - //} + //dumpBallots(ballotData, program.invocations, num); t.expectOK(program.checkResults(ballotData, /*locationData,*/ actualSize, num)); } @@ -329,7 +335,7 @@ g.test('predefined_reconvergence') break; } default: { - program = new Program(); + program = new Program(style, 1, invocations); unreachable('Unhandled testcase'); } } diff --git a/src/webgpu/shader/execution/reconvergence/util.ts b/src/webgpu/shader/execution/reconvergence/util.ts index e9b1610a2cc3..1ed5025a12e7 100644 --- a/src/webgpu/shader/execution/reconvergence/util.ts +++ b/src/webgpu/shader/execution/reconvergence/util.ts @@ -11,7 +11,7 @@ function getMask(size: number): bigint { } /** @returns A bitmask where submask is repeated every size bits for total bits. */ -function getReplicatedMask(submask: bigint, size: number, total: number = 128): bigint { +function getReplicatedMask(submask: bigint, size: number, total: number): bigint { const reps = Math.floor(total / size); let mask: bigint = submask & ((1n << BigInt(size)) - 1n); for (let i = 1; i < reps; i++) { @@ -21,7 +21,7 @@ function getReplicatedMask(submask: bigint, size: number, total: number = 128): } /** @returns a mask with only the least significant 1 in |value| set for each subgroup. */ -function getElectMask(value: bigint, size: number, total: number = 128): bigint { +function getElectMask(value: bigint, size: number, total: number): bigint { let mask = value; let count = 0; while (!(mask & 1n)) { @@ -225,8 +225,8 @@ export class Program { private refData: Uint32Array; private isLoopInf: Map; private doneInfLoopBreak: Map; - private maxProgramNesting; public readonly maxLocations: number; + private maxProgramNesting; /** * constructor @@ -234,16 +234,19 @@ export class Program { * @param style Enum indicating the type of reconvergence being tested * @param seed Value used to seed the PRNG */ - constructor(style : Style = Style.Workgroup, seed: number = 1, invocations: number = 128) { + constructor(style : Style = Style.Workgroup, seed: number = 1, invocations: number) { this.invocations = invocations; this.prng = new PRNG(seed); this.ops = []; this.style = style; this.minCount = 30; - this.maxCount = 50000; // TODO: what is a reasonable limit? + //this.maxCount = 50000; // TODO: what is a reasonable limit? + this.maxCount = 20000; // TODO: what is a reasonable limit? // TODO: https://crbug.com/tint/2011 // Tint is double counting depth - this.maxNesting = this.getRandomUint(40) + 20; //this.getRandomUint(70) + 30; // [30,100) + //this.maxNesting = this.getRandomUint(70) + 30; // [30,100) + //this.maxNesting = this.getRandomUint(40) + 20; + this.maxNesting = this.getRandomUint(20) + 20; // Loops significantly affect runtime and memory performance this.maxLoopNesting = 3; //4; this.nesting = 0; @@ -380,7 +383,7 @@ export class Program { !(this.ops[cur_length - 1].type == OpType.Ballot || (this.ops[cur_length-1].type == OpType.Store && this.ops[cur_length - 2].type == OpType.Ballot))) { // Perform a store with each ballot so the results can be correlated. - //if (this.style != Style.Maximal) + if (this.style != Style.Maximal) this.ops.push(new Op(OpType.Store, cur_length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); } @@ -411,7 +414,7 @@ export class Program { if (type == IfType.Uniform) maskIdx = 0; - const lid = this.getRandomUint(128); + const lid = this.getRandomUint(this.invocations); if (type == IfType.Lid) { this.ops.push(new Op(OpType.IfId, lid)); } else if (type == IfType.LoopCount) { @@ -444,6 +447,10 @@ export class Program { for (let i = beforeSize; i < afterSize; i++) { const op = this.ops[i]; this.ops.push(new Op(op.type, op.value, op.uniform)); + // Make stores unique. + if (op.type == OpType.Store) { + this.ops[this.ops.length-1].value = this.storeBase + this.ops.length - 1; + } } } else { this.pickOp(2); @@ -479,7 +486,9 @@ export class Program { this.pickOp(2); - this.genElect(true); + // As loop become more deeply nested, execute fewer iterations. + const reduction = this.loopNesting === 1 ? 1 : this.loopNesting === 2 ? 2 : 4; + this.genElect(true, reduction); this.doneInfLoopBreak.set(this.loopNesting, true); this.pickOp(2); @@ -493,7 +502,9 @@ export class Program { } private genForVar() { - this.ops.push(new Op(OpType.ForVar, 0)); + // op.value is the iteration reduction factor. + const reduction = this.loopNesting === 0 ? 1 : this.loopNesting === 1 ? 2 : 4; + this.ops.push(new Op(OpType.ForVar, reduction)); this.nesting++; this.maxProgramNesting = Math.max(this.nesting, this.maxProgramNesting); this.loopNesting++; @@ -501,7 +512,7 @@ export class Program { this.pickOp(2); - this.ops.push(new Op(OpType.EndForVar, 0)); + this.ops.push(new Op(OpType.EndForVar, reduction)); this.loopNestingThisFunction--; this.loopNesting--; this.nesting--; @@ -536,7 +547,8 @@ export class Program { this.pickOp(2); - this.genElect(true); + const reduction = this.loopNesting === 1 ? 1 : this.loopNesting === 2 ? 2 : 4; + this.genElect(true, reduction); this.doneInfLoopBreak.set(this.loopNesting, true); this.pickOp(2); @@ -550,7 +562,7 @@ export class Program { this.nesting--; } - private genElect(forceBreak: boolean) { + private genElect(forceBreak: boolean, reduction: number = 1) { this.ops.push(new Op(OpType.Elect, 0)); this.nesting++; this.maxProgramNesting = Math.max(this.nesting, this.maxProgramNesting); @@ -574,6 +586,12 @@ export class Program { this.ops.push(new Op(OpType.EndIf, 0)); this.nesting--; + // Reduction injects extra breaks to reduce the number of iterations. + for (let i = 1; i < reduction; i++) { + this.ops.push(new Op(OpType.Elect, 0)); + this.ops.push(new Op(OpType.Break, 0)); + this.ops.push(new Op(OpType.EndIf, 0)); + } } private genBreak() { @@ -720,12 +738,12 @@ export class Program { // Safety mechanism for hardware runs. // Intention extra newline. this.addCode(`// Safety valve`); - this.addCode(`if ${iter} >= 128u { break; }\n`); + this.addCode(`if ${iter} >= sgsize { break; }\n`); break; } case OpType.ForVar: { const iter = `i${this.loopNesting}`; - this.addCode(`for (var ${iter} = 0u; ${iter} < subgroup_id + 1; ${iter}++) {`); + this.addCode(`for (var ${iter} = 0u; ${iter} < (subgroup_id / ${op.value}) + 1; ${iter}++) {`); this.loopNesting++; this.increaseIndent(); break; @@ -779,7 +797,7 @@ export class Program { // Intentional extra newlines. this.addCode(``); this.addCode(`// Safety mechanism`); - this.addCode(`break if ${iter} >= 128;`); + this.addCode(`break if ${iter} >= sgsize;`); this.decreaseIndent(); this.addCode(`}`); this.decreaseIndent(); @@ -855,6 +873,7 @@ var ids : array; var subgroup_id : u32; var local_id : u32; var output_loc : u32 = 0; +var sgsize : u32 = 0; @compute @workgroup_size(stride,1,1) fn main( @@ -868,6 +887,7 @@ fn main( subgroup_id = sid; local_id = lid; ids[lid] = sid; + sgsize = sg_size; // Store the subgroup size from the built-in value and ballot to check for // consistency. @@ -956,6 +976,82 @@ ${this.functions[i]}`; this.functions[this.curFunc] += code + `\n`; } + public dumpStats(detailed: boolean = true) { + let stats = `Total instructions: ${this.ops.length}\n`; + let nesting = 0; + let stores = 0; + let totalStores = 0; + let totalLoops = 0; + let loopsAtNesting = new Array(this.maxLoopNesting); + loopsAtNesting.fill(0); + let storesAtNesting = new Array(this.maxLoopNesting + 1); + storesAtNesting.fill(0); + for (let i = 0; i < this.ops.length; i++) { + const op = this.ops[i]; + switch (op.type) { + case OpType.Store: + case OpType.Ballot: { + stores++; + storesAtNesting[nesting]++; + break; + } + case OpType.ForUniform: + case OpType.LoopUniform: + case OpType.ForVar: + case OpType.ForInf: + case OpType.LoopInf: { + totalLoops++; + loopsAtNesting[nesting]++; + if (detailed) { + stats += ' '.repeat(nesting) + `${stores} stores\n`; + } + totalStores += stores; + stores = 0; + + if (detailed) { + let iters = `subgroup size`; + if (op.type === OpType.ForUniform || op.type === OpType.LoopUniform) { + iters = `${op.value}`; + } + stats += ' '.repeat(nesting) + serializeOpType(op.type) + `: ${iters} iterations\n`; + } + nesting++; + break; + } + case OpType.EndForUniform: + case OpType.EndForInf: + case OpType.EndForVar: + case OpType.EndLoopUniform: + case OpType.EndLoopInf: { + if (detailed) { + stats += ' '.repeat(nesting) + `${stores} stores\n`; + } + totalStores += stores; + stores = 0; + + nesting--; + if (detailed) { + stats += ' '.repeat(nesting) + serializeOpType(op.type) + '\n'; + } + break; + } + default: + break; + } + } + totalStores += stores; + stats += `\n`; + stats += `${totalLoops} loops\n`; + for (let i = 0; i < loopsAtNesting.length; i++) { + stats += ` ${loopsAtNesting[i]} at nesting ${i}\n`; + } + stats += `${totalStores} stores\n`; + for (let i = 0; i < storesAtNesting.length; i++) { + stats += ` ${storesAtNesting[i]} at nesting ${i}\n`; + } + console.log(stats); + } + /** * Sizes the simulation buffer. * @@ -968,7 +1064,6 @@ ${this.functions[i]}`; this.refData.fill(0); } - // TODO: Reconvergence guarantees are not as strong as this simulation. /** * Simulate the program for the given subgroup size * @@ -976,8 +1071,11 @@ ${this.functions[i]}`; * @param subgroupSize The subgroup size to simulate * * BigInt is not the fastest value to manipulate. Care should be taken to optimize it's use. + * TODO: would it be better to roll my own 128 bitvector? + * + * TODO: reconvergence guarantees in WGSL are not as strong as this simulation */ - public simulate(countOnly: boolean, subgroupSize: number): number { + public simulate(countOnly: boolean, subgroupSize: number, debug: boolean = false): number { class State { // Active invocations activeMask: bigint; @@ -1023,7 +1121,7 @@ ${this.functions[i]}`; for (let i = 0; i < stack.length; i++) { stack[i] = new State(); } - stack[0].activeMask = (1n << 128n) - 1n; + stack[0].activeMask = (1n << BigInt(this.invocations)) - 1n; let nesting = 0; let loopNesting = 0; @@ -1036,15 +1134,13 @@ ${this.functions[i]}`; if (nesting >= stack.length) { unreachable(`Max stack nesting surpassed (${stack.length} vs ${this.nesting}) at ops[${i}] = ${serializeOpType(op.type)}`); } - if (!countOnly) { - //console.log(`ops[${i}] = ${serializeOpType(op.type)}, nesting = ${nesting}, loopNesting = ${loopNesting}, value = ${op.value}`); - //console.log(` mask = ${stack[nesting].activeMask.toString(16)}`); - // //for (let j = 0; j <= nesting; j++) { - // // console.log(` mask[${j}] = ${stack[j].activeMask.toString(16)}`); - // //} + if (debug) { + console.log(`ops[${i}] = ${serializeOpType(op.type)}, nesting = ${nesting}, loopNesting = ${loopNesting}, value = ${op.value}`); + console.log(` mask = ${stack[nesting].activeMask.toString(16)}`); } // Early outs if no invocations are active. + // Don't skip ops that change nesting. switch (op.type) { case OpType.Ballot: case OpType.Store: @@ -1292,7 +1388,7 @@ ${this.functions[i]}`; cur.continueMask = 0n; let done = !any(cur.activeMask) || cur.tripCount === subgroupSize; if (!done) { - let submask = getMask(subgroupSize) & ~getMask(cur.tripCount); + let submask = getMask(Math.floor(subgroupSize / op.value)) & ~getMask(cur.tripCount); let mask = getReplicatedMask(submask, subgroupSize, this.invocations); cur.activeMask &= mask; done = !any(cur.activeMask); @@ -1451,9 +1547,6 @@ ${this.functions[i]}`; maxLoc = Math.max(maxLoc, locs[id]); } maxLoc = Math.min(this.maxLocations, maxLoc); - //if (!countOnly) { - // console.log(`Max location = ${maxLoc}\n`); - //} return maxLoc; } @@ -1489,7 +1582,6 @@ ${this.functions[i]}`; // If this is an uniform control flow case, make sure a uniform ballot is // generated. A subgroup size of 64 is used for testing purposes here. if (this.style != Style.Maximal) { - console.log(`${new Date()}: simulating for UCF`); this.simulate(true, 64); } i++; @@ -1550,7 +1642,6 @@ ${this.functions[i]}`; public checkResults(ballots: Uint32Array, /*locations: Uint32Array,*/ subgroupSize: number, numLocs: number): Error | undefined { let totalLocs = Math.min(numLocs, this.maxLocations); - //console.log(`Verifying numLocs = ${numLocs}`); if (this.style == Style.Workgroup || this.style === Style.Subgroup) { if (!this.isUCF()) { return Error(`Expected some uniform condition for this test`); @@ -1851,20 +1942,20 @@ ${this.functions[i]}`; * ballot(); */ public predefinedProgramForVar() { - this.ops.push(new Op(OpType.ForVar, 0)); + this.ops.push(new Op(OpType.ForVar, 1)); this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); - this.ops.push(new Op(OpType.EndForVar, 0)); + this.ops.push(new Op(OpType.EndForVar, 1)); this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); - this.ops.push(new Op(OpType.ForVar, 0)); + this.ops.push(new Op(OpType.ForVar, 1)); this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); - this.ops.push(new Op(OpType.EndForVar, 0)); + this.ops.push(new Op(OpType.EndForVar, 1)); this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); From 2399352437d84104b0501f3d6462139cbf5f7cc5 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Thu, 24 Aug 2023 15:43:37 -0400 Subject: [PATCH 16/32] Add uniform switch statements --- .../reconvergence/reconvergence.spec.ts | 17 +- .../shader/execution/reconvergence/util.ts | 150 +++++++++++++++++- 2 files changed, 152 insertions(+), 15 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index c64437e75555..f91fceab8921 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -282,7 +282,7 @@ g.test('predefined_reconvergence') .params(u => u .combine('style', [Style.Workgroup, Style.Subgroup, Style.Maximal] as const) - .combine('test', [...iterRange(8, x => x)] as const) + .combine('test', [...iterRange(9, x => x)] as const) .beginSubcases() ) //.beforeAllSubcases(t => { @@ -292,50 +292,45 @@ g.test('predefined_reconvergence') const invocations = 128; // t.device.limits.maxSubgroupSize; const style = t.params.style; - let program: Program; + let program: Program = new Program(style, 1, invocations);; switch (t.params.test) { case 0: { - program = new Program(style, 1, invocations); program.predefinedProgram1(); break; } case 1: { - program = new Program(style, 1, invocations); program.predefinedProgram2(); break; } case 2: { - program = new Program(style, 1, invocations); program.predefinedProgram3(); break; } case 3: { - program = new Program(style, 1, invocations); program.predefinedProgramInf(); break; } case 4: { - program = new Program(style, 1, invocations); program.predefinedProgramForVar(); break; } case 5: { - program = new Program(style, 1, invocations); program.predefinedProgramCall(); break; } case 6: { - program = new Program(style, 1, invocations); program.predefinedProgram1(OpType.LoopUniform, OpType.EndLoopUniform); break; } case 7: { - program = new Program(style, 1, invocations); program.predefinedProgramInf(OpType.LoopInf, OpType.EndLoopInf); break; } + case 8: { + program.predefinedProgramSwitchUniform(); + break; + } default: { - program = new Program(style, 1, invocations); unreachable('Unhandled testcase'); } } diff --git a/src/webgpu/shader/execution/reconvergence/util.ts b/src/webgpu/shader/execution/reconvergence/util.ts index 1ed5025a12e7..06ee968a1cfd 100644 --- a/src/webgpu/shader/execution/reconvergence/util.ts +++ b/src/webgpu/shader/execution/reconvergence/util.ts @@ -10,7 +10,7 @@ function getMask(size: number): bigint { return (1n << BigInt(size)) - 1n; } -/** @returns A bitmask where submask is repeated every size bits for total bits. */ +/** @returns A bitmask where |submask| is repeated every |size| bits for |total| bits. */ function getReplicatedMask(submask: bigint, size: number, total: number): bigint { const reps = Math.floor(total / size); let mask: bigint = submask & ((1n << BigInt(size)) - 1n); @@ -140,6 +140,18 @@ export enum OpType { Call, EndCall, + // Equivalent to: + // switch (inputs[x]) { + // case x*2: { ... } never taken + // case x: { ... } uniformly taken + // case x*4: { ... } never taken + // } + SwitchUniform, + EndSwitch, + + CaseMask, + EndCase, + MAX, } @@ -171,6 +183,10 @@ function serializeOpType(op: OpType): string { case OpType.Elect: return 'Elect'; case OpType.Call: return 'Call'; case OpType.EndCall: return 'EndCall'; + case OpType.SwitchUniform: return 'SwitchUniform'; + case OpType.EndSwitch: return 'EndSwitch'; + case OpType.CaseMask: return 'CaseMask'; + case OpType.EndCase: return 'EndCase'; default: unreachable('Unhandled op'); break; @@ -194,11 +210,13 @@ enum IfType { class Op { type : OpType; value : number; + caseValue: number; uniform : boolean; - constructor(type : OpType, value: number = 0, uniform: boolean = true) { + constructor(type : OpType, value: number, caseValue: number = 0, uniform: boolean = true) { this.type = type; this.value = value; + this.caseValue = caseValue; this.uniform = uniform; } }; @@ -365,6 +383,23 @@ export class Program { } break; } + case 9: { + const r2 = this.getRandomUint(4); + switch (r2) { + case 0: { + this.genSwitchUniform(); + break; + } + case 1: { + } + case 2: { + } + case 3: + default: { + break; + } + } + } default: { break; } @@ -446,7 +481,7 @@ export class Program { (beforeSize + 2 * (afterSize - beforeSize)) < this.maxCount) { for (let i = beforeSize; i < afterSize; i++) { const op = this.ops[i]; - this.ops.push(new Op(op.type, op.value, op.uniform)); + this.ops.push(new Op(op.type, op.value, op.caseValue, op.uniform)); // Make stores unique. if (op.type == OpType.Store) { this.ops[this.ops.length-1].value = this.storeBase + this.ops.length - 1; @@ -664,6 +699,27 @@ export class Program { } } + private genSwitchUniform() { + const r = this.getRandomUint(5); + this.ops.push(new Op(OpType.SwitchUniform, r)); + this.nesting++; + + this.ops.push(new Op(OpType.CaseMask, 0, 1 << (r+1))); + this.pickOp(1); + this.ops.push(new Op(OpType.EndCase, 0)); + + this.ops.push(new Op(OpType.CaseMask, 0xf, 1 << r)); + this.pickOp(1); + this.ops.push(new Op(OpType.EndCase, 0)); + + this.ops.push(new Op(OpType.CaseMask, 0, 1 << (r+2))); + this.pickOp(1); + this.ops.push(new Op(OpType.EndCase, 0)); + + this.ops.push(new Op(OpType.EndSwitch, 0)); + this.nesting--; + } + /** @returns The WGSL code for the program */ public genCode(): string { for (let i = 0; i < this.ops.length; i++) { @@ -851,6 +907,33 @@ export class Program { this.curFunc = 0; break; } + case OpType.SwitchUniform: { + this.addCode(`switch inputs[${op.value}] {`); + this.increaseIndent(); + this.addCode(`default { }`); + break; + } + case OpType.EndSwitch: { + this.decreaseIndent(); + this.addCode(`}`); + break; + } + case OpType.CaseMask: { + let values = ``; + for (let b = 0; b < 32; b++) { + if ((1 << b) & op.caseValue) { + values += `${b},`; + } + } + this.addCode(`case ${values} {`); + this.increaseIndent(); + break; + } + case OpType.EndCase: { + this.decreaseIndent(); + this.addCode(`}`); + break; + } } } @@ -1155,13 +1238,17 @@ ${this.functions[i]}`; } case OpType.ElseMask: case OpType.ElseId: - case OpType.ElseLoopCount: { + case OpType.ElseLoopCount: + case OpType.CaseMask: { if (!any(stack[nesting-1].activeMask)) { stack[nesting].activeMask = 0n; i++; continue; } } + case OpType.EndCase: + // No work + break; default: break; } @@ -1533,6 +1620,26 @@ ${this.functions[i]}`; nesting--; break; } + case OpType.SwitchUniform: { + nesting++; + const cur = stack[nesting]; + cur.reset(); + cur.activeMask = stack[nesting-1].activeMask; + cur.isSwitch = true; + break; + } + case OpType.EndSwitch: { + nesting--; + break; + } + case OpType.CaseMask: { + const mask = getReplicatedMask(BigInt(op.value), 4, this.invocations); + stack[nesting].activeMask = stack[nesting-1].activeMask & mask; + break; + } + case OpType.EndCase: { + break; + } default: { unreachable(`Unhandled op ${serializeOpType(op.type)}`); } @@ -2035,6 +2142,41 @@ ${this.functions[i]}`; this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); this.ops.push(new Op(OpType.Ballot, 0)); } + + /** + * Equivalent to: + * + * ballot() + * switch (inputs[5]) { + * default { } + * case 6 { ballot(); } + * case 5 { ballot(); } + * case 7 { ballot(); } + * } + * ballot(); + * + */ + public predefinedProgramSwitchUniform() { + const value = 5; + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.SwitchUniform, value)); + this.ops.push(new Op(OpType.CaseMask, 0, 1 << (value + 1))); + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.EndCase, 0)); + this.ops.push(new Op(OpType.CaseMask, 0xf, 1 << value)); + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.EndCase, 0)); + this.ops.push(new Op(OpType.CaseMask, 0, 1 << (value + 2))); + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.EndCase, 0)); + this.ops.push(new Op(OpType.EndSwitch, 0)); + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + } }; export function generateSeeds(numCases: number): number[] { From 8230f17570a586d8f44e6a891729f83e75bb9b74 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Fri, 25 Aug 2023 11:02:11 -0400 Subject: [PATCH 17/32] More implementation * refactored test code * each reconvergence style is now a separate set of tests * Added more switch varieties * added predefined tests for coverage * fixed simulation of ForVar * reduction was incorrectly handled * fixed result comparison for ucf cases * fixed how ucf is calculated * more documentation --- .../reconvergence/reconvergence.spec.ts | 213 ++++++--- .../shader/execution/reconvergence/util.ts | 434 +++++++++++++++--- 2 files changed, 517 insertions(+), 130 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index f91fceab8921..d4a981f612d7 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -65,12 +65,13 @@ function checkSubgroupSizeConsistency(data: Uint32Array, min: number, max: numbe return undefined; } -function dumpBallots(ballots: Uint32Array, invocations: number, locations: number) { +function dumpBallots(ballots: Uint32Array, totalInvocations: number, + invocations: number, locations: number) { let dump = `Ballots\n`; for (let id = 0; id < invocations; id++) { dump += `id[${id}]\n`; for (let loc = 0; loc < locations; loc++) { - const idx = 4 * (invocations * loc + id); + const idx = 4 * (totalInvocations * loc + id); const w = ballots[idx+3]; const z = ballots[idx+2]; const y = ballots[idx+1]; @@ -233,7 +234,7 @@ async function testProgram(t: GPUTest, program: Program) { program.sizeRefData(locMap.get(actualSize)); console.log(`${new Date()}: Full simulation size = ${actualSize}`); - let num = program.simulate(false, actualSize); + let num = program.simulate(false, actualSize, /* debug = */ false); console.log(`${new Date()}: locations = ${num}`); num = Math.min(program.maxLocations, num); @@ -271,88 +272,172 @@ async function testProgram(t: GPUTest, program: Program) { ); const ballotData = ballotReadback.data; - console.log(`${Date()}: Finished buffer readbacks`); - //dumpBallots(ballotData, program.invocations, num); + console.log(`${new Date()}: Finished buffer readbacks`); + // Only dump a single subgroup + //console.log(`${new Date()}: Reference data`); + //dumpBallots(program.refData, program.invocations, actualSize, num); + //console.log(`${new Date()}: GPU data`); + //dumpBallots(ballotData, program.invocations, actualSize, num); t.expectOK(program.checkResults(ballotData, /*locationData,*/ actualSize, num)); } -g.test('predefined_reconvergence') +async function predefinedTest(t: GPUTest, style: Style, test: number) { + const invocations = 128; // t.device.limits.maxSubgroupSize; + + let program: Program = new Program(style, 1, invocations);; + switch (test) { + case 0: { + program.predefinedProgram1(); + break; + } + case 1: { + program.predefinedProgram2(); + break; + } + case 2: { + program.predefinedProgram3(); + break; + } + case 3: { + program.predefinedProgramInf(); + break; + } + case 4: { + program.predefinedProgramForVar(); + break; + } + case 5: { + program.predefinedProgramCall(); + break; + } + case 6: { + program.predefinedProgram1(OpType.LoopUniform, OpType.EndLoopUniform); + break; + } + case 7: { + program.predefinedProgramInf(OpType.LoopInf, OpType.EndLoopInf); + break; + } + case 8: { + program.predefinedProgramSwitchUniform(); + break; + } + case 9: { + program.predefinedProgramSwitchVar(); + break; + } + case 10: { + program.predefinedProgramSwitchLoopCount(0); + break; + } + case 11: { + program.predefinedProgramSwitchLoopCount(1); + break; + } + case 12: { + program.predefinedProgramSwitchLoopCount(2); + break; + } + case 13: { + program.predefinedProgramSwitchMulticase(); + break; + } + default: { + unreachable('Unhandled testcase'); + } + } + + await testProgram(t, program); +} + +const kPredefinedTestCases = [...iterRange(14, x => x)]; + +g.test('predefined_workgroup') .desc(`Test reconvergence using some predefined programs`) .params(u => u - .combine('style', [Style.Workgroup, Style.Subgroup, Style.Maximal] as const) - .combine('test', [...iterRange(9, x => x)] as const) + .combine('test', kPredefinedTestCases) .beginSubcases() ) //.beforeAllSubcases(t => { // t.selectDeviceOrSkipTestCase({ requiredFeatures: ['chromium-experimental-subgroups'] }); //}) + .fn(async t => { + await predefinedTest(t, Style.Workgroup, t.params.test); + }); + +g.test('predefined_subgroup') + .desc(`Test reconvergence using some predefined programs`) + .params(u => + u + .combine('test', kPredefinedTestCases) + .beginSubcases() + ) + //.beforeAllSubcases(t => { + // t.selectDeviceOrSkipTestCase({ requiredFeatures: ['chromium-experimental-subgroups'] }); + //}) + .fn(async t => { + await predefinedTest(t, Style.Subgroup, t.params.test); + }); + +g.test('predefined_maximal') + .desc(`Test reconvergence using some predefined programs`) + .params(u => + u + .combine('test', kPredefinedTestCases) + .beginSubcases() + ) + //.beforeAllSubcases(t => { + // t.selectDeviceOrSkipTestCase({ requiredFeatures: ['chromium-experimental-subgroups'] }); + //}) + .fn(async t => { + await predefinedTest(t, Style.Maximal, t.params.test); + }); + +g.test('random_workgroup') + .desc(`Test reconvergence using randomly generated programs`) + .params(u => + u + .combine('seed', generateSeeds(50)) + .beginSubcases() + ) + //.beforeAllSubcases(t => { + // t.selectDeviceOrSkipTestCase({requiredFeatures: ['chromium-experimental-subgroups']}); + //}) .fn(async t => { const invocations = 128; // t.device.limits.maxSubgroupSize; - const style = t.params.style; - - let program: Program = new Program(style, 1, invocations);; - switch (t.params.test) { - case 0: { - program.predefinedProgram1(); - break; - } - case 1: { - program.predefinedProgram2(); - break; - } - case 2: { - program.predefinedProgram3(); - break; - } - case 3: { - program.predefinedProgramInf(); - break; - } - case 4: { - program.predefinedProgramForVar(); - break; - } - case 5: { - program.predefinedProgramCall(); - break; - } - case 6: { - program.predefinedProgram1(OpType.LoopUniform, OpType.EndLoopUniform); - break; - } - case 7: { - program.predefinedProgramInf(OpType.LoopInf, OpType.EndLoopInf); - break; - } - case 8: { - program.predefinedProgramSwitchUniform(); - break; - } - default: { - unreachable('Unhandled testcase'); - } - } + + let program: Program = new Program(Style.Workgroup, t.params.seed, invocations); + program.generate(); + + await testProgram(t, program); + }); + +g.test('random_subgroup') + .desc(`Test reconvergence using randomly generated programs`) + .params(u => + u + .combine('seed', generateSeeds(50)) + .beginSubcases() + ) + //.beforeAllSubcases(t => { + // t.selectDeviceOrSkipTestCase({requiredFeatures: ['chromium-experimental-subgroups']}); + //}) + .fn(async t => { + const invocations = 128; // t.device.limits.maxSubgroupSize; + + let program: Program = new Program(Style.Subgroup, t.params.seed, invocations); + program.generate(); await testProgram(t, program); }); -g.test('random_reconvergence') +g.test('random_maximal') .desc(`Test reconvergence using randomly generated programs`) .params(u => u - .combine('style', [Style.Workgroup, Style.Subgroup, Style.Maximal] as const) .combine('seed', generateSeeds(50)) - .filter(u => { - if (u.style == Style.Maximal) { - return false; - } - if (u.style == Style.Subgroup) { - return false; - } - return true; - }) .beginSubcases() ) //.beforeAllSubcases(t => { @@ -361,7 +446,7 @@ g.test('random_reconvergence') .fn(async t => { const invocations = 128; // t.device.limits.maxSubgroupSize; - let program: Program = new Program(t.params.style, t.params.seed, invocations); + let program: Program = new Program(Style.Maximal, t.params.seed, invocations); program.generate(); await testProgram(t, program); diff --git a/src/webgpu/shader/execution/reconvergence/util.ts b/src/webgpu/shader/execution/reconvergence/util.ts index 06ee968a1cfd..2d81b4bf959b 100644 --- a/src/webgpu/shader/execution/reconvergence/util.ts +++ b/src/webgpu/shader/execution/reconvergence/util.ts @@ -140,7 +140,6 @@ export enum OpType { Call, EndCall, - // Equivalent to: // switch (inputs[x]) { // case x*2: { ... } never taken // case x: { ... } uniformly taken @@ -149,7 +148,18 @@ export enum OpType { SwitchUniform, EndSwitch, + // switch (subgroup_invocation_id & 3) + SwitchVar, + + // switch (i) { + // case 1: { ... } + // case 2: { ... } + // default: { ... } + // } + SwitchLoopCount, + CaseMask, + CaseLoopCount, EndCase, MAX, @@ -158,35 +168,38 @@ export enum OpType { function serializeOpType(op: OpType): string { // prettier-ignore switch (op) { - case OpType.Ballot: return 'Ballot'; - case OpType.Store: return 'Store'; - case OpType.IfMask: return 'IfMask'; - case OpType.ElseMask: return 'ElseMask'; - case OpType.EndIf: return 'EndIf'; - case OpType.IfLoopCount: return 'IfLoopCount'; - case OpType.ElseLoopCount: return 'ElseLoopCount'; - case OpType.IfId: return 'IfId'; - case OpType.ElseId: return 'ElseId'; - case OpType.Break: return 'Break'; - case OpType.Continue: return 'Continue'; - case OpType.ForUniform: return 'ForUniform'; - case OpType.EndForUniform: return 'EndForUniform'; - case OpType.ForInf: return 'ForInf'; - case OpType.EndForInf: return 'EndForInf'; - case OpType.ForVar: return 'ForVar'; - case OpType.EndForVar: return 'EndForVar'; - case OpType.LoopUniform: return 'LoopUniform'; - case OpType.EndLoopUniform: return 'EndLoopUniform'; - case OpType.LoopInf: return 'LoopInf'; - case OpType.EndLoopInf: return 'EndLoopInf'; - case OpType.Return: return 'Return'; - case OpType.Elect: return 'Elect'; - case OpType.Call: return 'Call'; - case OpType.EndCall: return 'EndCall'; - case OpType.SwitchUniform: return 'SwitchUniform'; - case OpType.EndSwitch: return 'EndSwitch'; - case OpType.CaseMask: return 'CaseMask'; - case OpType.EndCase: return 'EndCase'; + case OpType.Ballot: return 'Ballot'; + case OpType.Store: return 'Store'; + case OpType.IfMask: return 'IfMask'; + case OpType.ElseMask: return 'ElseMask'; + case OpType.EndIf: return 'EndIf'; + case OpType.IfLoopCount: return 'IfLoopCount'; + case OpType.ElseLoopCount: return 'ElseLoopCount'; + case OpType.IfId: return 'IfId'; + case OpType.ElseId: return 'ElseId'; + case OpType.Break: return 'Break'; + case OpType.Continue: return 'Continue'; + case OpType.ForUniform: return 'ForUniform'; + case OpType.EndForUniform: return 'EndForUniform'; + case OpType.ForInf: return 'ForInf'; + case OpType.EndForInf: return 'EndForInf'; + case OpType.ForVar: return 'ForVar'; + case OpType.EndForVar: return 'EndForVar'; + case OpType.LoopUniform: return 'LoopUniform'; + case OpType.EndLoopUniform: return 'EndLoopUniform'; + case OpType.LoopInf: return 'LoopInf'; + case OpType.EndLoopInf: return 'EndLoopInf'; + case OpType.Return: return 'Return'; + case OpType.Elect: return 'Elect'; + case OpType.Call: return 'Call'; + case OpType.EndCall: return 'EndCall'; + case OpType.SwitchUniform: return 'SwitchUniform'; + case OpType.SwitchVar: return 'SwitchVar'; + case OpType.SwitchLoopCount: return 'SwitchLoopCount'; + case OpType.EndSwitch: return 'EndSwitch'; + case OpType.CaseMask: return 'CaseMask'; + case OpType.CaseLoopCount: return 'CaseLoopCount'; + case OpType.EndCase: return 'EndCase'; default: unreachable('Unhandled op'); break; @@ -222,29 +235,59 @@ class Op { }; export class Program { + // Number of invocations in the program + // Max supported is 128 public readonly invocations: number; + // Pseduo-random number generator private readonly prng: PRNG; + // Instruction list private ops : Op[]; + // Reconvergence style public readonly style: Style; + // Minimum number of instructions in a program private readonly minCount: number; + // Maximum number of instructions in a program + // Note: this is a soft max to ensure functional programs. private readonly maxCount: number; + // Maximum nesting of scopes permitted private readonly maxNesting: number; + // Maximum loop nesting permitted private readonly maxLoopNesting: number; + // Current nesting private nesting: number; + // Current loop nesting private loopNesting: number; + // Current loop nesting in the current function private loopNestingThisFunction: number; + // Current call nesting private callNesting: number; + // Number of pregenerated masks for testing private readonly numMasks: number; + // Pregenerated masks. + // 4 * |numMasks| entries representing ballots. private masks: number[]; + // Current function index private curFunc: number; + // WGSL code of each function private functions: string[]; + // Indent level for each function private indents: number[]; + // Offset value for OpType.Store private readonly storeBase: number; - private refData: Uint32Array; + // Reference simulation output + public refData: Uint32Array; + // Maps whether a particular loop nest is infinite or not private isLoopInf: Map; + // Maps whether a particular infinite loop has had a break inserted private doneInfLoopBreak: Map; + // Maximum number of locations per invocation + // Each location stores a vec4u public readonly maxLocations: number; + // Maximum nesting in the actual program private maxProgramNesting; + // Indicates if the program satisfies uniform control flow for |style| + // This depends on simulating a particular subgroup size + public ucf: boolean; /** * constructor @@ -254,6 +297,7 @@ export class Program { */ constructor(style : Style = Style.Workgroup, seed: number = 1, invocations: number) { this.invocations = invocations; + assert(invocations <= 128); this.prng = new PRNG(seed); this.ops = []; this.style = style; @@ -294,6 +338,7 @@ export class Program { this.doneInfLoopBreak = new Map(); this.maxProgramNesting = 10; // default stack allocation this.maxLocations = 130000; // keep the buffer under 256MiB + this.ucf = false; } /** @returns A random float between 0 and 1 */ @@ -391,11 +436,22 @@ export class Program { break; } case 1: { + if (this.loopNesting >= 0) { + this.genSwitchLoopCount(); + break; + } + // fallthrough } case 2: { + if (this.style != Style.Maximal) { + this.genSwitchMulticase(); + break; + } + // fallthrough } case 3: default: { + this.genSwitchVar(); break; } } @@ -703,6 +759,7 @@ export class Program { const r = this.getRandomUint(5); this.ops.push(new Op(OpType.SwitchUniform, r)); this.nesting++; + this.maxProgramNesting = Math.max(this.nesting, this.maxProgramNesting); this.ops.push(new Op(OpType.CaseMask, 0, 1 << (r+1))); this.pickOp(1); @@ -720,6 +777,80 @@ export class Program { this.nesting--; } + private genSwitchVar() { + this.ops.push(new Op(OpType.SwitchVar, 0)); + this.nesting++; + this.maxProgramNesting = Math.max(this.nesting, this.maxProgramNesting); + + this.ops.push(new Op(OpType.CaseMask, 0x1, 1<<0)); + this.pickOp(1); + this.ops.push(new Op(OpType.EndCase, 0)); + + this.ops.push(new Op(OpType.CaseMask, 0x2, 1<<1)); + this.pickOp(1); + this.ops.push(new Op(OpType.EndCase, 0)); + + this.ops.push(new Op(OpType.CaseMask, 0x4, 1<<2)); + this.pickOp(1); + this.ops.push(new Op(OpType.EndCase, 0)); + + this.ops.push(new Op(OpType.CaseMask, 0x8, 1<<3)); + this.pickOp(1); + this.ops.push(new Op(OpType.EndCase, 0)); + + this.ops.push(new Op(OpType.EndSwitch, 0)); + this.nesting--; + } + + private genSwitchLoopCount() { + const r = this.getRandomUint(this.loopNesting); + this.ops.push(new Op(OpType.SwitchLoopCount, r)); + this.nesting++; + this.maxProgramNesting = Math.max(this.nesting, this.maxProgramNesting); + + this.ops.push(new Op(OpType.CaseLoopCount, 1<<1, 1)); + this.pickOp(1); + this.ops.push(new Op(OpType.EndCase, 0)); + + this.ops.push(new Op(OpType.CaseLoopCount, 1<<2, 2)); + this.pickOp(1); + this.ops.push(new Op(OpType.EndCase, 0)); + + this.ops.push(new Op(OpType.CaseLoopCount, 0xfffffff9, 0xffffffff)); + this.pickOp(1); + this.ops.push(new Op(OpType.EndCase, 0)); + + this.ops.push(new Op(OpType.EndSwitch, 0)); + this.nesting--; + } + + // switch (subgroup_invocation_id & 3) { + // default { } + // case 0x3: { ... } + // case 0xc: { ... } + // } + // + // This is not generated for maximal style cases because it is not clear what + // convergence should be expected. There are multiple valid lowerings of a + // switch that would lead to different convergence scenarios. To test this + // properly would likely require a range of values which is difficult for + // this infrastructure to produce. + private genSwitchMulticase() { + this.ops.push(new Op(OpType.SwitchVar, 0)); + this.nesting++; + + this.ops.push(new Op(OpType.CaseMask, 0x3, (1<<0)|(1<<1))); + this.pickOp(2); + this.ops.push(new Op(OpType.EndCase, 0)); + + this.ops.push(new Op(OpType.CaseMask, 0xc, (1<<2)|(1<<3))); + this.pickOp(2); + this.ops.push(new Op(OpType.EndCase, 0)); + + this.ops.push(new Op(OpType.EndSwitch, 0)); + this.nesting--; + } + /** @returns The WGSL code for the program */ public genCode(): string { for (let i = 0; i < this.ops.length; i++) { @@ -913,6 +1044,18 @@ export class Program { this.addCode(`default { }`); break; } + case OpType.SwitchVar: { + this.addCode(`switch subgroup_id & 0x3 {`); + this.increaseIndent(); + this.addCode(`default { }`); + break; + } + case OpType.SwitchLoopCount: { + const iter = `i${op.value}`; + this.addCode(`switch ${iter} {`); + this.increaseIndent(); + break; + } case OpType.EndSwitch: { this.decreaseIndent(); this.addCode(`}`); @@ -929,6 +1072,15 @@ export class Program { this.increaseIndent(); break; } + case OpType.CaseLoopCount: { + if (op.caseValue === 0xffffffff) { + this.addCode(`default {`); + } else { + this.addCode(`case ${op.caseValue} {`); + } + this.increaseIndent(); + break; + } case OpType.EndCase: { this.decreaseIndent(); this.addCode(`}`); @@ -1230,6 +1382,7 @@ ${this.functions[i]}`; case OpType.Return: case OpType.Continue: case OpType.Break: { + // No reason to simulate if the current stack entry is inactive. if (!any(stack[nesting].activeMask)) { i++; continue; @@ -1239,7 +1392,9 @@ ${this.functions[i]}`; case OpType.ElseMask: case OpType.ElseId: case OpType.ElseLoopCount: - case OpType.CaseMask: { + case OpType.CaseMask: + case OpType.CaseLoopCount: { + // No reason to simulate if the previous stack entry is inactive. if (!any(stack[nesting-1].activeMask)) { stack[nesting].activeMask = 0n; i++; @@ -1258,6 +1413,9 @@ ${this.functions[i]}`; // Flag if this ballot is not workgroup uniform. if (this.style == Style.Workgroup && any(curMask) && !all(curMask, this.invocations)) { op.uniform = false; + } else { + op.uniform = true; + this.ucf = true; } // Flag if this ballot is not subgroup uniform. @@ -1266,6 +1424,9 @@ ${this.functions[i]}`; const subgroupMask = (curMask >> BigInt(id)) & getMask(subgroupSize); if (subgroupMask != 0n && !all(subgroupMask, subgroupSize)) { op.uniform = false; + } else { + op.uniform = true; + this.ucf = true; } } } @@ -1473,9 +1634,11 @@ ${this.functions[i]}`; cur.tripCount++; cur.activeMask |= cur.continueMask; cur.continueMask = 0n; - let done = !any(cur.activeMask) || cur.tripCount === subgroupSize; + let done = !any(cur.activeMask) || cur.tripCount === Math.floor(subgroupSize / op.value); if (!done) { - let submask = getMask(Math.floor(subgroupSize / op.value)) & ~getMask(cur.tripCount); + // i < (subgroup_invocation_id / reduction) + 1 + // So remove all ids < tripCount * reduction + let submask = getMask(subgroupSize) & ~getMask(cur.tripCount * op.value); let mask = getReplicatedMask(submask, subgroupSize, this.invocations); cur.activeMask &= mask; done = !any(cur.activeMask); @@ -1620,11 +1783,14 @@ ${this.functions[i]}`; nesting--; break; } - case OpType.SwitchUniform: { + case OpType.SwitchUniform: + case OpType.SwitchVar: + case OpType.SwitchLoopCount: { nesting++; const cur = stack[nesting]; cur.reset(); cur.activeMask = stack[nesting-1].activeMask; + cur.header = i; cur.isSwitch = true; break; } @@ -1637,6 +1803,31 @@ ${this.functions[i]}`; stack[nesting].activeMask = stack[nesting-1].activeMask & mask; break; } + case OpType.CaseLoopCount: { + let n = nesting; + let l = loopNesting; + + const findLoop = this.ops[stack[nesting].header].value; + while (n >= 0 && l >= 0) { + if (stack[n].isLoop) { + l--; + if (l == findLoop) { + break; + } + } + n--; + } + if (n < 0 || l < 0) { + unreachable(`Failed to find loop for CaseLoopCount`); + } + + if (((1 << stack[n].tripCount) & op.value) != 0) { + stack[nesting].activeMask = stack[nesting-1].activeMask; + } else { + stack[nesting].activeMask = 0n; + } + break; + } case OpType.EndCase: { break; } @@ -1692,19 +1883,12 @@ ${this.functions[i]}`; this.simulate(true, 64); } i++; - } while (this.style != Style.Maximal && !this.isUCF()); + } while (this.style != Style.Maximal && !this.ucf); } /** @returns true if the program has uniform control flow for some ballot */ private isUCF(): boolean { - let ucf: boolean = false; - for (let i = 0; i < this.ops.length; i++) { - const op = this.ops[i]; - if (op.type === OpType.Ballot && op.uniform) { - ucf = true; - } - } - return ucf; + return this.ucf; } /** @@ -1768,28 +1952,27 @@ ${this.functions[i]}`; if (refLoc < numLocs) { // Fully converged simulation - // Search for the corresponding store in the result data. + // Search for the corresponding data in the result. let storeRefLoc = refLoc - 1; - while (resLoc < totalLocs && - !this.matchResult(ballots, this.baseIndex(id, resLoc), - this.refData, this.baseIndex(id, storeRefLoc))) { + while (resLoc + 1 < totalLocs && + !(this.matchResult(ballots, this.baseIndex(id, resLoc), + this.refData, this.baseIndex(id, storeRefLoc)) && + this.matchResult(ballots, this.baseIndex(id, resLoc+1), + this.refData, this.baseIndex(id, refLoc)))) { resLoc++; } - if (resLoc >= totalLocs) { - const refIdx = this.baseIndex(id, storeRefLoc); - return Error(`Failure for invocation ${id}: could not find associated store for reference location ${storeRefLoc}: ${this.refData[refIdx]},${this.refData[refIdx+1]},${this.refData[refIdx+2]},${this.refData[refIdx+3]}`); - } else { - // Found a matching store, now check the ballot. - const resIdx = this.baseIndex(id, resLoc + 1); - const refIdx = this.baseIndex(id, refLoc); - if (!this.matchResult(ballots, resIdx, this.refData, refIdx)) { - return Error(`Failure for invocation ${id} at location ${resLoc} -- expected: (0x${hex(this.refData[refIdx+3])},0x${hex(this.refData[refIdx+2])},0x${hex(this.refData[refIdx+1])},0x${hex(this.refData[refIdx])}) -- got: (0x${hex(ballots[resIdx+3])},0x${hex(ballots[resIdx+2])},0x${hex(ballots[resIdx+1])},0x${hex(ballots[resIdx])})`); - } - resLoc++; + if (resLoc + 1 >= totalLocs) { + const sIdx = this.baseIndex(id, storeRefLoc); + const bIdx = this.baseIndex(id, refLoc); + const ref = this.refData; + let msg = `Failure for invocation ${id}: could not find match for:\n`; + msg += `- store[${storeRefLoc}] = ${this.refData[sIdx]}\n`; + msg += `- ballot[${refLoc}] = (0x${hex(ref[bIdx+3])},0x${hex(ref[bIdx+2])},0x${hex(ref[bIdx+1])},0x${hex(ref[bIdx])})`; + return Error(msg); } + // Match both locations so don't revisit them. + resLoc++; refLoc++; } } @@ -1801,15 +1984,17 @@ ${this.functions[i]}`; const id = Math.floor(idx_uvec4 % this.invocations); const loc = Math.floor(idx_uvec4 / this.invocations); if (!this.matchResult(ballots, i, this.refData, i)) { - return Error(`Failure for invocation ${id} at location ${loc}: -- expected: (0x${hex(this.refData[i+3])},0x${hex(this.refData[i+2])},0x${hex(this.refData[i+1])},0x${hex(this.refData[i])}) -- got: (0x${hex(ballots[i+3])},0x${hex(ballots[i+2])},0x${hex(ballots[i+1])},0x${hex(ballots[i])})`); + let msg = `Failure for invocation ${id} at location ${loc}:\n`; + msg += `- expected: (0x${hex(this.refData[i+3])},0x${hex(this.refData[i+2])},0x${hex(this.refData[i+1])},0x${hex(this.refData[i])})\n`; + msg += `- got: (0x${hex(ballots[i+3])},0x${hex(ballots[i+2])},0x${hex(ballots[i+1])},0x${hex(ballots[i])})`; + return Error(msg); } } for (let i = this.refData.length; i < ballots.length; i++) { if (ballots[i] !== 0) { - return Error(`Unexpected write at end of buffer (index = ${i}): -- got: (${ballots[i]})`); + let msg = `Unexpected write at end of buffer (index = ${i}):\n`; + msg += `- got: (${ballots[i]})`; + return Error(msg); } } } @@ -2177,6 +2362,123 @@ ${this.functions[i]}`; this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); this.ops.push(new Op(OpType.Ballot, 0)); } + + /** + * Equivalent to: + * + * ballot(); + * switch subgroup_invocation_id & 3 { + * default { } + * case 0: { ballot(); } + * case 1: { ballot(); } + * case 2: { ballot(); } + * case 3: { ballot(); } + * } + * ballot(); + */ + public predefinedProgramSwitchVar() { + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.SwitchVar, 0)); + this.ops.push(new Op(OpType.CaseMask, 0x1, 1<<0)); + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.EndCase, 0)); + this.ops.push(new Op(OpType.CaseMask, 0x2, 1<<1)); + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.EndCase, 0)); + this.ops.push(new Op(OpType.CaseMask, 0x4, 1<<2)); + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.EndCase, 0)); + this.ops.push(new Op(OpType.CaseMask, 0x8, 1<<3)); + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.EndCase, 0)); + this.ops.push(new Op(OpType.EndSwitch, 0)); + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + } + + /** + * Equivalent to: + * + * for (var i0 = 0u; i0 < inputs[3]; i0++) { + * for (var i1 = 0u; i1 < inputs[3]; i1++) { + * for (var i2 = 0u; i2 < subgroup_invocation_id + 1; i2++) { + * ballot(); + * switch i_loop { + * case 1 { ballot(); } + * case 2 { ballot(); } + * default { ballot(); } + * } + * ballot(); + * } + * } + * } + */ + public predefinedProgramSwitchLoopCount(loop: number) { + this.ops.push(new Op(OpType.ForUniform, 1)); + this.ops.push(new Op(OpType.ForUniform, 2)); + this.ops.push(new Op(OpType.ForVar, 4)); + + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.SwitchLoopCount, loop)); + + this.ops.push(new Op(OpType.CaseLoopCount, 1<<1, 1)); + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.EndCase, 0)); + + this.ops.push(new Op(OpType.CaseLoopCount, 1<<2, 2)); + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.EndCase, 0)); + + this.ops.push(new Op(OpType.CaseLoopCount, 0xfffffff9, 0xffffffff)); + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.EndCase, 0)); + + this.ops.push(new Op(OpType.EndSwitch, 0)); + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + + this.ops.push(new Op(OpType.EndForVar, 4)); + this.ops.push(new Op(OpType.EndForUniform, 3)); + this.ops.push(new Op(OpType.EndForUniform, 3)); + } + + /** + * Equivalent to: + * + * switch subgroup_invocation_id & 0x3 { + * default { } + * case 0,1 { ballot(); } + * case 2,3 { ballot(); } + * } + */ + public predefinedProgramSwitchMulticase() { + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.SwitchVar, 0)); + + this.ops.push(new Op(OpType.CaseMask, 0x3, (1<<0)|(1<<1))); + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.EndCase, 0)); + + this.ops.push(new Op(OpType.CaseMask, 0xc, (1<<2)|(1<<3))); + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.EndCase, 0)); + + this.ops.push(new Op(OpType.EndSwitch, 0)); + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + } }; export function generateSeeds(numCases: number): number[] { From 9449f3f92919a9a7bfeb26c37af2eb954d02e1fc Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Fri, 25 Aug 2023 13:42:03 -0400 Subject: [PATCH 18/32] cleanup --- .../execution/reconvergence/reconvergence.spec.ts | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index d4a981f612d7..3eb58211594d 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -57,10 +57,10 @@ function checkSubgroupSizeConsistency(data: Uint32Array, min: number, max: numbe return new Error(`Subgroup size ballot value (${ballot}) is greater than device maximum ${max}`); if (builtin != ballot) { - return new Error(`Subgroup size mismatch: - - builtin value = ${builtin} - - ballot = ${ballot} -`); + let msg = `Subgroup size mismatch:\n`; + msg += `- builtin value = ${builtin}\n`; + msg += `- ballot = ${ballot}`; + return Error(msg); } return undefined; } @@ -88,9 +88,10 @@ function dumpBallots(ballots: Uint32Array, totalInvocations: number, function checkIds(data: Uint32Array, subgroupSize: number): Error | undefined { for (let i = 0; i < data.length; i++) { if (data[i] !== (i % subgroupSize)) { - return Error(`subgroup_invocation_id does not map as assumed to local_invocation_index: -location_invocation_index = ${i} -subgroup_invocation_id = ${data[i]}`); + let msg = `subgroup_invocation_id does map as assumed to local_invocation_index:\n`; + msg += `location_invocation_index = ${i}\n`; + msg += `subgroup_invocation_id = ${data[i]}`; + return Error(msg); } } return undefined; From 01cd4c81b5e5393467ef57de4b1885b55a25dedd Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Fri, 25 Aug 2023 14:24:05 -0400 Subject: [PATCH 19/32] Add feature based skips * requires unsafe typecast for experimental feature --- .../reconvergence/reconvergence.spec.ts | 63 ++++++++++++------- 1 file changed, 41 insertions(+), 22 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index 3eb58211594d..8f0e1d7818e6 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -283,8 +283,10 @@ async function testProgram(t: GPUTest, program: Program) { t.expectOK(program.checkResults(ballotData, /*locationData,*/ actualSize, num)); } +const kNumInvocations = 128; + async function predefinedTest(t: GPUTest, style: Style, test: number) { - const invocations = 128; // t.device.limits.maxSubgroupSize; + const invocations = kNumInvocations; // t.device.limits.maxSubgroupSize; let program: Program = new Program(style, 1, invocations);; switch (test) { @@ -361,9 +363,11 @@ g.test('predefined_workgroup') .combine('test', kPredefinedTestCases) .beginSubcases() ) - //.beforeAllSubcases(t => { - // t.selectDeviceOrSkipTestCase({ requiredFeatures: ['chromium-experimental-subgroups'] }); - //}) + .beforeAllSubcases(t => { + t.selectDeviceOrSkipTestCase({ + requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName] + }); + }) .fn(async t => { await predefinedTest(t, Style.Workgroup, t.params.test); }); @@ -375,9 +379,11 @@ g.test('predefined_subgroup') .combine('test', kPredefinedTestCases) .beginSubcases() ) - //.beforeAllSubcases(t => { - // t.selectDeviceOrSkipTestCase({ requiredFeatures: ['chromium-experimental-subgroups'] }); - //}) + .beforeAllSubcases(t => { + t.selectDeviceOrSkipTestCase({ + requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName] + }); + }) .fn(async t => { await predefinedTest(t, Style.Subgroup, t.params.test); }); @@ -389,25 +395,34 @@ g.test('predefined_maximal') .combine('test', kPredefinedTestCases) .beginSubcases() ) - //.beforeAllSubcases(t => { - // t.selectDeviceOrSkipTestCase({ requiredFeatures: ['chromium-experimental-subgroups'] }); - //}) + .beforeAllSubcases(t => { + t.selectDeviceOrSkipTestCase({ + requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName] + }); + }) .fn(async t => { await predefinedTest(t, Style.Maximal, t.params.test); }); +const kNumRandomCases = 50; + g.test('random_workgroup') .desc(`Test reconvergence using randomly generated programs`) .params(u => u - .combine('seed', generateSeeds(50)) + .combine('seed', generateSeeds(kNumRandomCases)) .beginSubcases() ) //.beforeAllSubcases(t => { // t.selectDeviceOrSkipTestCase({requiredFeatures: ['chromium-experimental-subgroups']}); //}) + .beforeAllSubcases(t => { + t.selectDeviceOrSkipTestCase({ + requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName] + }); + }) .fn(async t => { - const invocations = 128; // t.device.limits.maxSubgroupSize; + const invocations = kNumInvocations; // t.device.limits.maxSubgroupSize; let program: Program = new Program(Style.Workgroup, t.params.seed, invocations); program.generate(); @@ -419,14 +434,16 @@ g.test('random_subgroup') .desc(`Test reconvergence using randomly generated programs`) .params(u => u - .combine('seed', generateSeeds(50)) + .combine('seed', generateSeeds(kNumRandomCases)) .beginSubcases() ) - //.beforeAllSubcases(t => { - // t.selectDeviceOrSkipTestCase({requiredFeatures: ['chromium-experimental-subgroups']}); - //}) + .beforeAllSubcases(t => { + t.selectDeviceOrSkipTestCase({ + requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName] + }); + }) .fn(async t => { - const invocations = 128; // t.device.limits.maxSubgroupSize; + const invocations = kNumInvocations; // t.device.limits.maxSubgroupSize; let program: Program = new Program(Style.Subgroup, t.params.seed, invocations); program.generate(); @@ -438,14 +455,16 @@ g.test('random_maximal') .desc(`Test reconvergence using randomly generated programs`) .params(u => u - .combine('seed', generateSeeds(50)) + .combine('seed', generateSeeds(kNumRandomCases)) .beginSubcases() ) - //.beforeAllSubcases(t => { - // t.selectDeviceOrSkipTestCase({requiredFeatures: ['chromium-experimental-subgroups']}); - //}) + .beforeAllSubcases(t => { + t.selectDeviceOrSkipTestCase({ + requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName] + }); + }) .fn(async t => { - const invocations = 128; // t.device.limits.maxSubgroupSize; + const invocations = kNumInvocations; // t.device.limits.maxSubgroupSize; let program: Program = new Program(Style.Maximal, t.params.seed, invocations); program.generate(); From d2a42a67731b20a0072b13e9537eb772bc40069c Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Fri, 25 Aug 2023 14:26:07 -0400 Subject: [PATCH 20/32] cleanup --- .../shader/execution/reconvergence/reconvergence.spec.ts | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index 8f0e1d7818e6..c76272ad8e3d 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -413,9 +413,6 @@ g.test('random_workgroup') .combine('seed', generateSeeds(kNumRandomCases)) .beginSubcases() ) - //.beforeAllSubcases(t => { - // t.selectDeviceOrSkipTestCase({requiredFeatures: ['chromium-experimental-subgroups']}); - //}) .beforeAllSubcases(t => { t.selectDeviceOrSkipTestCase({ requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName] From d5a691ca36ad64fc9b24e0f02c11c6f297e46b46 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Fri, 25 Aug 2023 14:36:33 -0400 Subject: [PATCH 21/32] more docs --- .../shader/execution/reconvergence/util.ts | 41 +++++++++++++------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/util.ts b/src/webgpu/shader/execution/reconvergence/util.ts index 2d81b4bf959b..19ddd646c9d4 100644 --- a/src/webgpu/shader/execution/reconvergence/util.ts +++ b/src/webgpu/shader/execution/reconvergence/util.ts @@ -208,9 +208,17 @@ function serializeOpType(op: OpType): string { } enum IfType { + // If the mask is 0, generates a random uniform comparison + // Otherwise, tests subgroup_invocation_id against a mask Mask, + + // Generates a uniform true comparison Uniform, + + // if subgroup_invocation_id == iN LoopCount, + + // if subgroup_id < inputs[N] Lid, }; @@ -221,9 +229,13 @@ enum IfType { * not the operation is uniform. */ class Op { + // Instruction type type : OpType; + // Instruction specific value value : number; + // Case specific value caseValue: number; + // Indicates if the instruction is uniform or not uniform : boolean; constructor(type : OpType, value: number, caseValue: number = 0, uniform: boolean = true) { @@ -351,6 +363,12 @@ export class Program { return this.prng.randomU32() % max; } + /** + * Pick |count| random instructions generators + * + * @param count the number of instructions + * + */ private pickOp(count : number) { for (let i = 0; i < count; i++) { if (this.ops.length >= this.maxCount) { @@ -1182,35 +1200,34 @@ ${this.functions[i]}`; return code; } - /** - * Adds indentation to the code for the current function. - */ + /** Adds indentation to the code for the current function. */ private genIndent() { this.functions[this.curFunc] += ' '.repeat(this.indents[this.curFunc]); } - /** - * Increase the amount of indenting for the current function. - */ + /** Increase the amount of indenting for the current function. */ private increaseIndent() { this.indents[this.curFunc] += 2; } - /** - * Decrease the amount of indenting for the current function. - */ + /** Decrease the amount of indenting for the current function. */ private decreaseIndent() { this.indents[this.curFunc] -= 2; } - /** - * Adds the line 'code' to the current function. - */ + /** Adds the line 'code' to the current function. */ private addCode(code: string) { this.genIndent(); this.functions[this.curFunc] += code + `\n`; } + /** + * Debugging function that dump statistics about the program + * + * Reports number of instructions, stores, and loops. + * + * @param detailed If true, dumps more detailed stats + */ public dumpStats(detailed: boolean = true) { let stats = `Total instructions: ${this.ops.length}\n`; let nesting = 0; From df87cbb55da55d4857626d858fce363c5343717e Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Fri, 25 Aug 2023 14:41:52 -0400 Subject: [PATCH 22/32] fix switch loop count conditional generation --- src/webgpu/shader/execution/reconvergence/util.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/webgpu/shader/execution/reconvergence/util.ts b/src/webgpu/shader/execution/reconvergence/util.ts index 19ddd646c9d4..f58d684f5693 100644 --- a/src/webgpu/shader/execution/reconvergence/util.ts +++ b/src/webgpu/shader/execution/reconvergence/util.ts @@ -454,7 +454,7 @@ export class Program { break; } case 1: { - if (this.loopNesting >= 0) { + if (this.loopNesting > 0) { this.genSwitchLoopCount(); break; } From 306132a5edb005be5f3840e0d0320f11a3e43194 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Sun, 27 Aug 2023 15:45:01 -0400 Subject: [PATCH 23/32] Add another reconvergence style * WGSLv1 style is a closer match to current WGSL spec * doesn't require loops to converge between iterations * added predefined and random test suites to exercise it * refactored uniform tests for reuse * improved loop continue ballot generation to avoid false errors * added a new predefined test that could distinguish between Workgroup and WGSLv1 reconvergence --- .../reconvergence/reconvergence.spec.ts | 55 ++++- .../shader/execution/reconvergence/util.ts | 192 +++++++++++++----- 2 files changed, 185 insertions(+), 62 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index c76272ad8e3d..763e1a008cc7 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -346,6 +346,10 @@ async function predefinedTest(t: GPUTest, style: Style, test: number) { program.predefinedProgramSwitchMulticase(); break; } + case 14: { + program.predefinedProgramWGSLv1(); + break; + } default: { unreachable('Unhandled testcase'); } @@ -354,10 +358,10 @@ async function predefinedTest(t: GPUTest, style: Style, test: number) { await testProgram(t, program); } -const kPredefinedTestCases = [...iterRange(14, x => x)]; +const kPredefinedTestCases = [...iterRange(15, x => x)]; g.test('predefined_workgroup') - .desc(`Test reconvergence using some predefined programs`) + .desc(`Test workgroup reconvergence using some predefined programs`) .params(u => u .combine('test', kPredefinedTestCases) @@ -373,7 +377,7 @@ g.test('predefined_workgroup') }); g.test('predefined_subgroup') - .desc(`Test reconvergence using some predefined programs`) + .desc(`Test subgroup reconvergence using some predefined programs`) .params(u => u .combine('test', kPredefinedTestCases) @@ -389,7 +393,7 @@ g.test('predefined_subgroup') }); g.test('predefined_maximal') - .desc(`Test reconvergence using some predefined programs`) + .desc(`Test maximal reconvergence using some predefined programs`) .params(u => u .combine('test', kPredefinedTestCases) @@ -404,10 +408,26 @@ g.test('predefined_maximal') await predefinedTest(t, Style.Maximal, t.params.test); }); +g.test('predefined_wgslv1') + .desc(`Test WGSL v1 reconvergence using some predefined programs`) + .params(u => + u + .combine('test', kPredefinedTestCases) + .beginSubcases() + ) + .beforeAllSubcases(t => { + t.selectDeviceOrSkipTestCase({ + requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName] + }); + }) + .fn(async t => { + await predefinedTest(t, Style.WGSLv1, t.params.test); + }); + const kNumRandomCases = 50; g.test('random_workgroup') - .desc(`Test reconvergence using randomly generated programs`) + .desc(`Test workgroup reconvergence using randomly generated programs`) .params(u => u .combine('seed', generateSeeds(kNumRandomCases)) @@ -428,7 +448,7 @@ g.test('random_workgroup') }); g.test('random_subgroup') - .desc(`Test reconvergence using randomly generated programs`) + .desc(`Test subgroup reconvergence using randomly generated programs`) .params(u => u .combine('seed', generateSeeds(kNumRandomCases)) @@ -449,7 +469,7 @@ g.test('random_subgroup') }); g.test('random_maximal') - .desc(`Test reconvergence using randomly generated programs`) + .desc(`Test maximal reconvergence using randomly generated programs`) .params(u => u .combine('seed', generateSeeds(kNumRandomCases)) @@ -468,3 +488,24 @@ g.test('random_maximal') await testProgram(t, program); }); + +g.test('random_wgslv1') + .desc(`Test WGSL v1 reconvergence using randomly generated programs`) + .params(u => + u + .combine('seed', generateSeeds(kNumRandomCases)) + .beginSubcases() + ) + .beforeAllSubcases(t => { + t.selectDeviceOrSkipTestCase({ + requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName] + }); + }) + .fn(async t => { + const invocations = kNumInvocations; // t.device.limits.maxSubgroupSize; + + let program: Program = new Program(Style.WGSLv1, t.params.seed, invocations); + program.generate(); + + await testProgram(t, program); + }); diff --git a/src/webgpu/shader/execution/reconvergence/util.ts b/src/webgpu/shader/execution/reconvergence/util.ts index f58d684f5693..a894dbca1004 100644 --- a/src/webgpu/shader/execution/reconvergence/util.ts +++ b/src/webgpu/shader/execution/reconvergence/util.ts @@ -77,6 +77,10 @@ export enum Style { // Maximal uniformity Maximal = 2, + + // Guarantees provided by WGSL v1. + // Very similar to Workgroup, but less strict for loops. + WGSLv1 = 3, }; export enum OpType { @@ -1316,6 +1320,36 @@ ${this.functions[i]}`; this.refData.fill(0); } + /** + * Returns true if |mask| is uniform for the given style + * + * @param mask The active mask + * @param size The subgroup size + * @returns true if |mask| is uniform for the given style + * + */ + private isUniform(mask: bigint, size: number): boolean { + if (this.style == Style.Workgroup || this.style === Style.WGSLv1) { + if (any(mask) && !all(mask, this.invocations)) { + return false; + } else { + return true; + } + } else if (this.style === Style.Subgroup) { + let uniform: boolean = true; + for (let id = 0; id < this.invocations; id += size) { + const subgroupMask = (mask >> BigInt(id)) & getMask(size); + if (subgroupMask != 0n && !all(subgroupMask, size)) { + uniform = false; + break; + } + } + return uniform; + } + + return true; + } + /** * Simulate the program for the given subgroup size * @@ -1343,6 +1377,8 @@ ${this.functions[i]}`; isCall: boolean; // This state is a switch isSwitch: boolean; + // This state is considered nonuniform despite the active mask. + isNonUniform: boolean; constructor() { this.activeMask = 0n; @@ -1352,16 +1388,19 @@ ${this.functions[i]}`; this.tripCount = 0; this.isCall = false; this.isSwitch = false; + this.isNonUniform = false; } - reset() { - this.activeMask = 0n; + // Reset the stack entry based on the parent state. + reset(prev: State, header: number) { + this.activeMask = prev.activeMask; this.continueMask = 0n; - this.header = 0; + this.header = header; this.isLoop = false; this.tripCount = 0; this.isCall = false; this.isSwitch = false; + this.isNonUniform = prev.isNonUniform; } }; for (let idx = 0; idx < this.ops.length; idx++) { @@ -1387,7 +1426,7 @@ ${this.functions[i]}`; unreachable(`Max stack nesting surpassed (${stack.length} vs ${this.nesting}) at ops[${i}] = ${serializeOpType(op.type)}`); } if (debug) { - console.log(`ops[${i}] = ${serializeOpType(op.type)}, nesting = ${nesting}, loopNesting = ${loopNesting}, value = ${op.value}`); + console.log(`ops[${i}] = ${serializeOpType(op.type)}, nesting = ${nesting}, loopNesting = ${loopNesting}, value = ${op.value}, nonuniform = ${stack[nesting].isNonUniform}`); console.log(` mask = ${stack[nesting].activeMask.toString(16)}`); } @@ -1427,25 +1466,12 @@ ${this.functions[i]}`; switch (op.type) { case OpType.Ballot: { const curMask = stack[nesting].activeMask; - // Flag if this ballot is not workgroup uniform. - if (this.style == Style.Workgroup && any(curMask) && !all(curMask, this.invocations)) { - op.uniform = false; - } else { - op.uniform = true; - this.ucf = true; + const uniform = this.isUniform(curMask, subgroupSize); + if (this.style !== Style.Maximal) { + op.uniform = uniform; } - - // Flag if this ballot is not subgroup uniform. - if (this.style == Style.Subgroup && any(curMask)) { - for (let id = 0; id < this.invocations; id += subgroupSize) { - const subgroupMask = (curMask >> BigInt(id)) & getMask(subgroupSize); - if (subgroupMask != 0n && !all(subgroupMask, subgroupSize)) { - op.uniform = false; - } else { - op.uniform = true; - this.ucf = true; - } - } + if (uniform) { + this.ucf = true; } if (!any(curMask)) { @@ -1460,7 +1486,7 @@ ${this.functions[i]}`; if (testBit(curMask, id)) { if (!countOnly) { const idx = this.baseIndex(id, locs[id]); - if (op.uniform) { + if (op.uniform && !stack[nesting].isNonUniform) { this.refData[idx + 0] = mask[0]; this.refData[idx + 1] = mask[1]; this.refData[idx + 2] = mask[2]; @@ -1495,9 +1521,7 @@ ${this.functions[i]}`; case OpType.IfMask: { nesting++; const cur = stack[nesting]; - cur.reset(); - cur.activeMask = stack[nesting-1].activeMask; - cur.header = i; + cur.reset(stack[nesting-1], i); // O is always uniform true. if (op.value != 0 && any(cur.activeMask)) { let subMask = this.getValueMask(op.value); @@ -1523,9 +1547,7 @@ ${this.functions[i]}`; case OpType.IfId: { nesting++; const cur = stack[nesting]; - cur.reset(); - cur.activeMask = stack[nesting-1].activeMask; - cur.header = i; + cur.reset(stack[nesting-1], i); if (any(cur.activeMask)) { // All invocations with subgroup invocation id less than op.value are active. const mask = getReplicatedMask(getMask(op.value), subgroupSize, this.invocations); @@ -1555,11 +1577,11 @@ ${this.functions[i]}`; nesting++; const cur = stack[nesting]; - cur.reset(); - cur.activeMask = stack[nesting-1].activeMask; - cur.header = i; + cur.reset(stack[nesting-1], i); if (any(cur.activeMask)) { - cur.activeMask &= getReplicatedMask(BigInt(1 << stack[n].tripCount), subgroupSize, this.invocations); + const submask = BigInt(1 << stack[n].tripCount); + const mask = getReplicatedMask(submask, subgroupSize, this.invocations); + cur.activeMask &= mask; } break; } @@ -1577,7 +1599,9 @@ ${this.functions[i]}`; stack[nesting].activeMask = stack[nesting-1].activeMask; if (any(stack[nesting].activeMask)) { - stack[nesting].activeMask &= ~getReplicatedMask(BigInt(1 << stack[n].tripCount), subgroupSize, this.invocations); + const submask = BigInt(1 << stack[n].tripCount); + const mask = getReplicatedMask(submask, subgroupSize, this.invocations); + stack[nesting].activeMask &= ~mask; } break; } @@ -1594,10 +1618,8 @@ ${this.functions[i]}`; nesting++; loopNesting++; const cur = stack[nesting]; - cur.reset(); - cur.header = i; + cur.reset(stack[nesting-1], i); cur.isLoop = true; - cur.activeMask = stack[nesting-1].activeMask; break; } case OpType.EndForUniform: { @@ -1609,6 +1631,9 @@ ${this.functions[i]}`; if (cur.tripCount < this.ops[cur.header].value && any(cur.activeMask)) { i = cur.header + 1; + if (this.style === Style.WGSLv1 && !all(cur.activeMask, subgroupSize)) { + cur.isNonUniform = true; + } continue; } else { loopNesting--; @@ -1623,6 +1648,7 @@ ${this.functions[i]}`; cur.continueMask = 0n; if (any(cur.activeMask)) { let maskArray = new Uint32Array(); + const uniform = this.isUniform(cur.activeMask, subgroupSize) && !cur.isNonUniform; for (let id = 0; id < this.invocations; id++) { if (id % subgroupSize === 0) { maskArray = getSubgroupMask(cur.activeMask, subgroupSize, id); @@ -1630,14 +1656,21 @@ ${this.functions[i]}`; if (testBit(cur.activeMask, id)) { if (!countOnly) { const idx = this.baseIndex(id, locs[id]); - this.refData[idx + 0] = maskArray[0]; - this.refData[idx + 1] = maskArray[1]; - this.refData[idx + 2] = maskArray[2]; - this.refData[idx + 3] = maskArray[3]; + if (uniform) { + this.refData[idx + 0] = maskArray[0]; + this.refData[idx + 1] = maskArray[1]; + this.refData[idx + 2] = maskArray[2]; + this.refData[idx + 3] = maskArray[3]; + } else { + this.refData.fill(0x12345678, idx, idx + 4); + } } locs[id]++; } } + if (this.style === Style.WGSLv1 && !uniform) { + cur.isNonUniform = true; + } i = cur.header + 1; continue; } else { @@ -1665,6 +1698,9 @@ ${this.functions[i]}`; loopNesting--; nesting--; } else { + if (this.style === Style.WGSLv1 && !all(cur.activeMask, subgroupSize)) { + cur.isNonUniform = true; + } i = cur.header + 1; continue; } @@ -1677,6 +1713,9 @@ ${this.functions[i]}`; cur.continueMask = 0n; if (cur.tripCount < this.ops[cur.header].value && any(cur.activeMask)) { + if (this.style === Style.WGSLv1 && !all(cur.activeMask, subgroupSize)) { + cur.isNonUniform = true; + } i = cur.header + 1; continue; } else { @@ -1691,6 +1730,7 @@ ${this.functions[i]}`; cur.activeMask |= cur.continueMask; if (any(cur.activeMask)) { let maskArray = new Uint32Array(); + const uniform = this.isUniform(cur.activeMask, subgroupSize) && !cur.isNonUniform; for (let id = 0; id < this.invocations; id++) { if (id % subgroupSize === 0) { maskArray = getSubgroupMask(cur.activeMask, subgroupSize, id); @@ -1698,14 +1738,21 @@ ${this.functions[i]}`; if (testBit(cur.activeMask, id)) { if (!countOnly) { const idx = this.baseIndex(id, locs[id]); - this.refData[idx + 0] = maskArray[0]; - this.refData[idx + 1] = maskArray[1]; - this.refData[idx + 2] = maskArray[2]; - this.refData[idx + 3] = maskArray[3]; + if (uniform) { + this.refData[idx + 0] = maskArray[0]; + this.refData[idx + 1] = maskArray[1]; + this.refData[idx + 2] = maskArray[2]; + this.refData[idx + 3] = maskArray[3]; + } else { + this.refData.fill(0x12345678, idx, idx + 4); + } } locs[id]++; } } + if (this.style === Style.WGSLv1 && !uniform) { + cur.isNonUniform = true; + } i = cur.header + 1; continue; } else { @@ -1741,9 +1788,15 @@ ${this.functions[i]}`; break; } + const uniform = this.style === Style.WGSLv1 && this.isUniform(mask, subgroupSize); + let n = nesting; for (; n >= 0; n--) { stack[n].activeMask &= ~mask; + if (!uniform) { + // Not all invocations continue on the same path. + stack[n].isNonUniform = true; + } if (stack[n].isLoop) { stack[n].continueMask |= mask; break; @@ -1780,9 +1833,7 @@ ${this.functions[i]}`; case OpType.Elect: { nesting++; const cur = stack[nesting]; - cur.reset(); - cur.activeMask = stack[nesting-1].activeMask; - cur.header = i; + cur.reset(stack[nesting-1], i); if (any(cur.activeMask)) { cur.activeMask = getElectMask(cur.activeMask, subgroupSize, this.invocations); } @@ -1791,8 +1842,8 @@ ${this.functions[i]}`; case OpType.Call: { nesting++; const cur = stack[nesting]; - cur.reset(); - cur.activeMask = stack[nesting-1].activeMask; + // Header is unused for calls. + cur.reset(stack[nesting-1], 0); cur.isCall = true; break; } @@ -1805,9 +1856,7 @@ ${this.functions[i]}`; case OpType.SwitchLoopCount: { nesting++; const cur = stack[nesting]; - cur.reset(); - cur.activeMask = stack[nesting-1].activeMask; - cur.header = i; + cur.reset(stack[nesting-1], i); cur.isSwitch = true; break; } @@ -1950,7 +1999,7 @@ ${this.functions[i]}`; public checkResults(ballots: Uint32Array, /*locations: Uint32Array,*/ subgroupSize: number, numLocs: number): Error | undefined { let totalLocs = Math.min(numLocs, this.maxLocations); - if (this.style == Style.Workgroup || this.style === Style.Subgroup) { + if (this.style !== Style.Maximal) { if (!this.isUCF()) { return Error(`Expected some uniform condition for this test`); } @@ -1994,7 +2043,7 @@ ${this.functions[i]}`; } } } - } else if (this.style == Style.Maximal) { + } else { // Expect exact matches. for (let i = 0; i < this.refData.length; i += 4) { const idx_uvec4 = Math.floor(i / 4); @@ -2496,6 +2545,39 @@ ${this.functions[i]}`; this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); this.ops.push(new Op(OpType.Ballot, 0)); } + + /** + * Equivalent to: + * + * ballot(); + * for (var i = 0; i < inputs[3]; i++) { + * ballot(); + * if (subgroupElect()) { + * continue; + * } + * } + * ballot(); + * + * This case can distinguish between Workgroup and WGSLv1 reconvergence. + * The ballot in the loop is not required to be converged for WGSLv1. + */ + public predefinedProgramWGSLv1() { + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.ForUniform, 3)); + + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.Elect, 0)); + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.Continue, 0)); + this.ops.push(new Op(OpType.EndIf, 0)); + + this.ops.push(new Op(OpType.EndForUniform, 2)); + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + } }; export function generateSeeds(numCases: number): number[] { From a190a07e9c7706155a55d0ff5b1a03808b7e6f39 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Mon, 28 Aug 2023 10:55:30 -0400 Subject: [PATCH 24/32] Add noise generation * Add no-op code fragments at low frequency * Increase number of random testcases --- .../reconvergence/reconvergence.spec.ts | 2 +- .../shader/execution/reconvergence/util.ts | 36 +++++++++++++++---- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index 763e1a008cc7..cfdc1b5ccc55 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -424,7 +424,7 @@ g.test('predefined_wgslv1') await predefinedTest(t, Style.WGSLv1, t.params.test); }); -const kNumRandomCases = 50; +const kNumRandomCases = 100; g.test('random_workgroup') .desc(`Test workgroup reconvergence using randomly generated programs`) diff --git a/src/webgpu/shader/execution/reconvergence/util.ts b/src/webgpu/shader/execution/reconvergence/util.ts index a894dbca1004..1a58a59d3342 100644 --- a/src/webgpu/shader/execution/reconvergence/util.ts +++ b/src/webgpu/shader/execution/reconvergence/util.ts @@ -166,6 +166,9 @@ export enum OpType { CaseLoopCount, EndCase, + // Fancy no-ops. + Noise, + MAX, } @@ -204,6 +207,7 @@ function serializeOpType(op: OpType): string { case OpType.CaseMask: return 'CaseMask'; case OpType.CaseLoopCount: return 'CaseLoopCount'; case OpType.EndCase: return 'EndCase'; + case OpType.Noise: return 'Noise'; default: unreachable('Unhandled op'); break; @@ -514,12 +518,12 @@ export class Program { } } - //deUint32 r = this.getRandomUint(10000); - //if (r < 3) { - // ops.push_back({OP_NOISE, 0}); - //} else if (r < 10) { - // ops.push_back({OP_NOISE, 1}); - //} + const r = this.getRandomUint(10000); + if (r < 3) { + this.ops.push(new Op(OpType.Noise, 0)); + } else if (r < 10) { + this.ops.push(new Op(OpType.Noise, 1)); + } } private genIf(type: IfType) { @@ -1108,6 +1112,24 @@ export class Program { this.addCode(`}`); break; } + case OpType.Noise: { + if (op.value == 0) { + this.addCode(`while (!subgroupElect()) { }`); + } else { + // The if is uniform false. + this.addCode(`if inputs[0] == 1234 {`); + this.increaseIndent(); + this.addCode(`var b = subgroupBallot();`); + this.addCode(`while b.x != 0 {`); + this.increaseIndent(); + this.addCode(`b = subgroupBallot();`); + this.decreaseIndent(); + this.addCode(`}`); + this.decreaseIndent(); + this.addCode(`}`); + } + break; + } } } @@ -1458,6 +1480,7 @@ ${this.functions[i]}`; } } case OpType.EndCase: + case OpType.Noise: // No work break; default: @@ -1894,6 +1917,7 @@ ${this.functions[i]}`; } break; } + case OpType.Noise: case OpType.EndCase: { break; } From 7411e0c8b74e056a82acb58285debf056338b503 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Mon, 28 Aug 2023 12:07:50 -0400 Subject: [PATCH 25/32] Cleanup * remove some debug output * add a control for other debug output --- .../reconvergence/reconvergence.spec.ts | 54 +++++++++++++------ 1 file changed, 38 insertions(+), 16 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index cfdc1b5ccc55..8280e83f63a7 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -97,11 +97,31 @@ function checkIds(data: Uint32Array, subgroupSize: number): Error | undefined { return undefined; } +/** + * Bitmask for debug information: + * + * 0x1 - wgsl + * 0x2 - stats + * 0x4 - terminate after wgsl + * 0x8 - simulation active masks + * 0x10 - simulation reference data + * 0x20 - gpu data + * + * So setting kDebugLevel to 0x5 would dump WGSL and end the test. + */ +const kDebugLevel = 0x0; + async function testProgram(t: GPUTest, program: Program) { const wgsl = program.genCode(); - //console.log(wgsl); - //program.dumpStats(true); - //return; + if (kDebugLevel & 0x1) { + console.log(wgsl); + } + if (kDebugLevel & 0x2) { + program.dumpStats(true); + } + if (kDebugLevel & 0x4) { + return; + } // TODO: query the device const minSubgroupSize = 4; @@ -114,12 +134,14 @@ async function testProgram(t: GPUTest, program: Program) { locMap.set(size, num); numLocs = Math.max(num, numLocs); } - numLocs = Math.min(program.maxLocations, numLocs); + if (numLocs > program.maxLocations) { + t.expectOK(Error(`Total locations (${numLocs}) truncated to ${program.maxLocations}`), + { mode: 'warn' }); + numLocs = program.maxLocations; + } // Add 1 to ensure there are no extraneous writes. numLocs++; - console.log(`${new Date()}: Maximum locations = ${numLocs}`); - console.log(`${new Date()}: creating pipeline`); const pipeline = t.device.createComputePipeline({ layout: 'auto', compute: { @@ -202,7 +224,6 @@ async function testProgram(t: GPUTest, program: Program) { ], }); - console.log(`${new Date()}: running pipeline`); const encoder = t.device.createCommandEncoder(); const pass = encoder.beginComputePass(); pass.setPipeline(pipeline); @@ -228,15 +249,13 @@ async function testProgram(t: GPUTest, program: Program) { method: 'copy', } ); - console.log(`${new Date()}: done pipeline`); const sizeData: Uint32Array = sizeReadback.data; const actualSize = sizeData[0]; t.expectOK(checkSubgroupSizeConsistency(sizeData, minSubgroupSize, maxSubgroupSize)); program.sizeRefData(locMap.get(actualSize)); - console.log(`${new Date()}: Full simulation size = ${actualSize}`); - let num = program.simulate(false, actualSize, /* debug = */ false); - console.log(`${new Date()}: locations = ${num}`); + const debug = (kDebugLevel & 0x8) !== 0; + let num = program.simulate(false, actualSize, debug); num = Math.min(program.maxLocations, num); const idReadback = await t.readGPUBufferRangeTyped( @@ -273,12 +292,15 @@ async function testProgram(t: GPUTest, program: Program) { ); const ballotData = ballotReadback.data; - console.log(`${new Date()}: Finished buffer readbacks`); // Only dump a single subgroup - //console.log(`${new Date()}: Reference data`); - //dumpBallots(program.refData, program.invocations, actualSize, num); - //console.log(`${new Date()}: GPU data`); - //dumpBallots(ballotData, program.invocations, actualSize, num); + if (kDebugLevel & 0x10) { + console.log(`${new Date()}: Reference data`); + dumpBallots(program.refData, program.invocations, actualSize, num); + } + if (kDebugLevel & 0x20) { + console.log(`${new Date()}: GPU data`); + dumpBallots(ballotData, program.invocations, actualSize, num); + } t.expectOK(program.checkResults(ballotData, /*locationData,*/ actualSize, num)); } From cede40adcc525c9d7948efe389700692193aa955 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Mon, 28 Aug 2023 14:32:34 -0400 Subject: [PATCH 26/32] docs --- .../shader/execution/reconvergence/util.ts | 156 ++++++++++++++++-- 1 file changed, 140 insertions(+), 16 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/util.ts b/src/webgpu/shader/execution/reconvergence/util.ts index 1a58a59d3342..21d920264341 100644 --- a/src/webgpu/shader/execution/reconvergence/util.ts +++ b/src/webgpu/shader/execution/reconvergence/util.ts @@ -68,6 +68,9 @@ function all(value: bigint, size: number): boolean { return value === ((1n << BigInt(size)) - 1n); } +/** + * Reconvergence style being tested. + */ export enum Style { // Workgroup uniform control flow Workgroup = 0, @@ -83,6 +86,9 @@ export enum Style { WGSLv1 = 3, }; +/** + * Instruction type + */ export enum OpType { // Store a ballot. // During simulation, uniform is set to false if the @@ -172,6 +178,7 @@ export enum OpType { MAX, } +/** @returns The stringified version of |op|. */ function serializeOpType(op: OpType): string { // prettier-ignore switch (op) { @@ -215,6 +222,9 @@ function serializeOpType(op: OpType): string { return ''; } +/** + * Different styles of if conditions + */ enum IfType { // If the mask is 0, generates a random uniform comparison // Otherwise, tests subgroup_invocation_id against a mask @@ -254,6 +264,18 @@ class Op { } }; +/** + * Main class for testcase generation. + * + * Major steps involved in a test: + * 1. Generation (either generate() or a predefined case) + * 2. Simulation + * 3. Result comparison + * + * The interface of the program is fixed and invariant of the particular + * program being tested. + * + */ export class Program { // Number of invocations in the program // Max supported is 128 @@ -482,6 +504,10 @@ export class Program { } } } + case 10: { + this.genElect(false); + break; + } default: { break; } @@ -491,6 +517,13 @@ export class Program { } } + /** + * Ballot generation + * + * Can insert ballots, stores, noise into the program. + * For non-maximal styles, if a ballot is generated, a store always precedes + * it. + */ private genBallot() { // Optionally insert ballots, stores, and noise. // Ballots and stores are used to determine correctness. @@ -526,6 +559,13 @@ export class Program { } } + /** + * Generate an if based on |type| + * + * @param type The type of the if condition, see IfType + * + * Generates if/else structures. + */ private genIf(type: IfType) { let maskIdx = this.getRandomUint(this.numMasks); if (type == IfType.Uniform) @@ -578,6 +618,11 @@ export class Program { this.nesting--; } + /** + * Generate a uniform for loop + * + * The number of iterations is randomly selected [1, 5]. + */ private genForUniform() { const n = this.getRandomUint(5) + 1; // [1, 5] this.ops.push(new Op(OpType.ForUniform, n)); @@ -592,6 +637,19 @@ export class Program { this.nesting--; } + /** + * Generate an infinite for loop + * + * The loop will always include an elect based break to prevent a truly + * infinite loop. The maximum number of iterations is the number of + * invocations in the program, but it is scaled by the loop nesting. Inside + * one loop the number of iterations is halved and inside two loops the + * number of iterations in quartered. This scaling is used to reduce runtime + * and memory. + * + * The for_update also performs a ballot. + * + */ private genForInf() { this.ops.push(new Op(OpType.ForInf, 0)); this.nesting++; @@ -618,6 +676,14 @@ export class Program { this.nesting--; } + /** + * Generate a for loop with variable iterations per invocation + * + * The loop condition is based on subgroup_invocation_id + 1. So each + * invocation executes a different number of iterations, though the this is + * scaled by the amount of loop nesting the same as |generateForInf|. + * + */ private genForVar() { // op.value is the iteration reduction factor. const reduction = this.loopNesting === 0 ? 1 : this.loopNesting === 1 ? 2 : 4; @@ -635,6 +701,11 @@ export class Program { this.nesting--; } + /** + * Generate a loop construct with uniform iterations + * + * Same as |genForUniform|, but coded as a loop construct. + */ private genLoopUniform() { const n = this.getRandomUint(5) + 1; this.ops.push(new Op(OpType.LoopUniform, n)); @@ -651,6 +722,11 @@ export class Program { this.nesting--; } + /** + * Generate an infinite loop construct + * + * This is the same as |genForInf| but uses a loop construct. + */ private genLoopInf() { const header = this.ops.length; this.ops.push(new Op(OpType.LoopInf, 0)); @@ -679,6 +755,13 @@ export class Program { this.nesting--; } + /** + * Generates an if based on subgroupElect() + * + * @param forceBreak If true, forces the then statement to contain a break + * @param reduction This generates extra breaks + * + */ private genElect(forceBreak: boolean, reduction: number = 1) { this.ops.push(new Op(OpType.Elect, 0)); this.nesting++; @@ -711,6 +794,13 @@ export class Program { } } + /** + * Generate a break if in a loop. + * + * Only generates a break within a loop, but may break out of a switch and + * not just a loop. Sometimes the break uses a non-uniform if/else to break. + * + */ private genBreak() { if (this.loopNestingThisFunction > 0) { // Sometimes put the break in a divergent if @@ -728,6 +818,11 @@ export class Program { } } + /** + * Generate a continue if in a loop + * + * Sometimes uses a non-uniform if/else to continue. + */ private genContinue() { if (this.loopNestingThisFunction > 0 && !this.isLoopInf.get(this.loopNesting)) { // Sometimes put the continue in a divergent if @@ -745,6 +840,10 @@ export class Program { } } + /** + * Generates a function call. + * + */ private genCall() { this.ops.push(new Op(OpType.Call, 0)); this.callNesting++; @@ -761,6 +860,11 @@ export class Program { this.ops.push(new Op(OpType.EndCall, 0)); } + /** + * Generates a return + * + * Rarely, this will return from the main function + */ private genReturn() { const r = this.getRandomFloat(); if (this.nesting > 0 && @@ -781,20 +885,28 @@ export class Program { } } + /** + * Generate a uniform switch. + * + * Some dead case constructs are also generated. + */ private genSwitchUniform() { const r = this.getRandomUint(5); this.ops.push(new Op(OpType.SwitchUniform, r)); this.nesting++; this.maxProgramNesting = Math.max(this.nesting, this.maxProgramNesting); + // Never taken this.ops.push(new Op(OpType.CaseMask, 0, 1 << (r+1))); this.pickOp(1); this.ops.push(new Op(OpType.EndCase, 0)); + // Always taken this.ops.push(new Op(OpType.CaseMask, 0xf, 1 << r)); this.pickOp(1); this.ops.push(new Op(OpType.EndCase, 0)); + // Never taken this.ops.push(new Op(OpType.CaseMask, 0, 1 << (r+2))); this.pickOp(1); this.ops.push(new Op(OpType.EndCase, 0)); @@ -803,6 +915,10 @@ export class Program { this.nesting--; } + /** + * Generates a non-uniform switch based on subgroup_invocation_id + * + */ private genSwitchVar() { this.ops.push(new Op(OpType.SwitchVar, 0)); this.nesting++; @@ -828,6 +944,10 @@ export class Program { this.nesting--; } + /** + * Generates switch based on an active loop induction variable. + * + */ private genSwitchLoopCount() { const r = this.getRandomUint(this.loopNesting); this.ops.push(new Op(OpType.SwitchLoopCount, r)); @@ -850,17 +970,20 @@ export class Program { this.nesting--; } - // switch (subgroup_invocation_id & 3) { - // default { } - // case 0x3: { ... } - // case 0xc: { ... } - // } - // - // This is not generated for maximal style cases because it is not clear what - // convergence should be expected. There are multiple valid lowerings of a - // switch that would lead to different convergence scenarios. To test this - // properly would likely require a range of values which is difficult for - // this infrastructure to produce. + /** + * switch (subgroup_invocation_id & 3) { + * default { } + * case 0x3: { ... } + * case 0xc: { ... } + * } + * + * This is not generated for maximal style cases because it is not clear what + * convergence should be expected. There are multiple valid lowerings of a + * switch that would lead to different convergence scenarios. To test this + * properly would likely require a range of values which is difficult for + * this infrastructure to produce. + * + */ private genSwitchMulticase() { this.ops.push(new Op(OpType.SwitchVar, 0)); this.nesting++; @@ -1381,7 +1504,6 @@ ${this.functions[i]}`; * BigInt is not the fastest value to manipulate. Care should be taken to optimize it's use. * TODO: would it be better to roll my own 128 bitvector? * - * TODO: reconvergence guarantees in WGSL are not as strong as this simulation */ public simulate(countOnly: boolean, subgroupSize: number, debug: boolean = false): number { class State { @@ -1430,6 +1552,8 @@ ${this.functions[i]}`; } // Allocate the stack based on the maximum nesting in the program. + // Note: this has proven to be considerably more performant than pushing + // and popping from the array. let stack: State[] = new Array(this.maxProgramNesting + 1); for (let i = 0; i < stack.length; i++) { stack[i] = new State(); @@ -1479,10 +1603,6 @@ ${this.functions[i]}`; continue; } } - case OpType.EndCase: - case OpType.Noise: - // No work - break; default: break; } @@ -1919,6 +2039,7 @@ ${this.functions[i]}`; } case OpType.Noise: case OpType.EndCase: { + // No work break; } default: { @@ -1940,6 +2061,9 @@ ${this.functions[i]}`; /** * @returns a mask formed from |masks[idx]| + * + * @param idx The index in |this.masks| to use. + * */ private getValueMask(idx: number): bigint { const x = this.masks[4*idx]; From aa37e32f5dbdf1cef481d9dbdbcf8c919acf9156 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Thu, 31 Aug 2023 19:52:43 -0400 Subject: [PATCH 27/32] Formatting * fix errors from npm run fix --- .../reconvergence/reconvergence.spec.ts | 208 ++++------ .../shader/execution/reconvergence/util.ts | 391 ++++++++++-------- 2 files changed, 315 insertions(+), 284 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index 8280e83f63a7..325db4e2745e 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -2,20 +2,10 @@ export const description = `Experimental reconvergence tests based on the Vulkan https://github.com/KhronosGroup/VK-GL-CTS/blob/main/external/vulkancts/modules/vulkan/reconvergence/vktReconvergenceTests.cpp`; import { makeTestGroup } from '../../../../common/framework/test_group.js'; +import { iterRange, unreachable } from '../../../../common/util/util.js'; import { GPUTest } from '../../../gpu_test.js'; -import { - assert, - iterRange, - TypedArrayBufferViewConstructor, - unreachable -} from '../../../../common/util/util.js'; -import { - hex, - Style, - OpType, - Program, - generateSeeds -} from './util.js' + +import { hex, Style, OpType, Program, generateSeeds } from './util.js'; export const g = makeTestGroup(GPUTest); @@ -24,9 +14,9 @@ export const g = makeTestGroup(GPUTest); */ function popcount(input: number): number { let n = input; - n = n - ((n >> 1) & 0x55555555) - n = (n & 0x33333333) + ((n >> 2) & 0x33333333) - return ((n + (n >> 4) & 0xF0F0F0F) * 0x1010101) >> 24 + n = n - ((n >> 1) & 0x55555555); + n = (n & 0x33333333) + ((n >> 2) & 0x33333333); + return (((n + (n >> 4)) & 0xf0f0f0f) * 0x1010101) >> 24; } /** @@ -39,24 +29,32 @@ function popcount(input: number): number { * @returns an error if either the builtin value or ballot count is outside [min, max], * not a a power of 2, or they do not match. */ -function checkSubgroupSizeConsistency(data: Uint32Array, min: number, max: number): Error | undefined { +function checkSubgroupSizeConsistency( + data: Uint32Array, + min: number, + max: number +): Error | undefined { const builtin: number = data[0]; const ballot: number = data[1]; - if (popcount(builtin) != 1) + if (popcount(builtin) !== 1) return new Error(`Subgroup size builtin value (${builtin}) is not a power of two`); if (builtin < min) return new Error(`Subgroup size builtin value (${builtin}) is less than device minimum ${min}`); if (max < builtin) - return new Error(`Subgroup size builtin value (${builtin}) is greater than device maximum ${max}`); + return new Error( + `Subgroup size builtin value (${builtin}) is greater than device maximum ${max}` + ); - if (popcount(ballot) != 1) + if (popcount(ballot) !== 1) return new Error(`Subgroup size ballot value (${builtin}) is not a power of two`); if (ballot < min) return new Error(`Subgroup size ballot value (${ballot}) is less than device minimum ${min}`); if (max < ballot) - return new Error(`Subgroup size ballot value (${ballot}) is greater than device maximum ${max}`); + return new Error( + `Subgroup size ballot value (${ballot}) is greater than device maximum ${max}` + ); - if (builtin != ballot) { + if (builtin !== ballot) { let msg = `Subgroup size mismatch:\n`; msg += `- builtin value = ${builtin}\n`; msg += `- ballot = ${ballot}`; @@ -65,18 +63,24 @@ function checkSubgroupSizeConsistency(data: Uint32Array, min: number, max: numbe return undefined; } -function dumpBallots(ballots: Uint32Array, totalInvocations: number, - invocations: number, locations: number) { +function dumpBallots( + ballots: Uint32Array, + totalInvocations: number, + invocations: number, + locations: number +) { let dump = `Ballots\n`; for (let id = 0; id < invocations; id++) { dump += `id[${id}]\n`; for (let loc = 0; loc < locations; loc++) { const idx = 4 * (totalInvocations * loc + id); - const w = ballots[idx+3]; - const z = ballots[idx+2]; - const y = ballots[idx+1]; - const x = ballots[idx+0]; - dump += ` loc[${loc}] = (0x${hex(w)},0x${hex(z)},0x${hex(y)},0x${hex(x)}), (${w},${z},${y},${x})\n`; + const w = ballots[idx + 3]; + const z = ballots[idx + 2]; + const y = ballots[idx + 1]; + const x = ballots[idx + 0]; + dump += ` loc[${loc}] = (0x${hex(w)},0x${hex(z)},0x${hex(y)},0x${hex( + x + )}), (${w},${z},${y},${x})\n`; } } console.log(dump); @@ -87,7 +91,7 @@ function dumpBallots(ballots: Uint32Array, totalInvocations: number, */ function checkIds(data: Uint32Array, subgroupSize: number): Error | undefined { for (let i = 0; i < data.length; i++) { - if (data[i] !== (i % subgroupSize)) { + if (data[i] !== i % subgroupSize) { let msg = `subgroup_invocation_id does map as assumed to local_invocation_index:\n`; msg += `location_invocation_index = ${i}\n`; msg += `subgroup_invocation_id = ${data[i]}`; @@ -123,20 +127,21 @@ async function testProgram(t: GPUTest, program: Program) { return; } - // TODO: query the device + // TODO: Query the limits when they are wired up. const minSubgroupSize = 4; const maxSubgroupSize = 128; let numLocs = 0; const locMap = new Map(); for (let size = minSubgroupSize; size <= maxSubgroupSize; size *= 2) { - let num = program.simulate(true, size); + const num = program.simulate(true, size); locMap.set(size, num); numLocs = Math.max(num, numLocs); } if (numLocs > program.maxLocations) { - t.expectOK(Error(`Total locations (${numLocs}) truncated to ${program.maxLocations}`), - { mode: 'warn' }); + t.expectOK(Error(`Total locations (${numLocs}) truncated to ${program.maxLocations}`), { + mode: 'warn', + }); numLocs = program.maxLocations; } // Add 1 to ensure there are no extraneous writes. @@ -163,7 +168,7 @@ async function testProgram(t: GPUTest, program: Program) { const ballotLength = numLocs * program.invocations * 4; const ballotBuffer = t.makeBufferWithContents( new Uint32Array([...iterRange(ballotLength, x => 0)]), - GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC + GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC ); t.trackForCleanup(ballotBuffer); @@ -177,14 +182,14 @@ async function testProgram(t: GPUTest, program: Program) { const sizeLength = 2; const sizeBuffer = t.makeBufferWithContents( new Uint32Array([...iterRange(sizeLength, x => 0)]), - GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC + GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC ); t.trackForCleanup(sizeBuffer); const idLength = program.invocations; const idBuffer = t.makeBufferWithContents( new Uint32Array([...iterRange(idLength, x => 0)]), - GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC + GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC ); t.trackForCleanup(idBuffer); @@ -194,13 +199,13 @@ async function testProgram(t: GPUTest, program: Program) { { binding: 0, resource: { - buffer: inputBuffer + buffer: inputBuffer, }, }, { binding: 1, resource: { - buffer: ballotBuffer + buffer: ballotBuffer, }, }, //{ @@ -212,13 +217,13 @@ async function testProgram(t: GPUTest, program: Program) { { binding: 3, resource: { - buffer: sizeBuffer + buffer: sizeBuffer, }, }, { binding: 4, resource: { - buffer: idBuffer + buffer: idBuffer, }, }, ], @@ -228,7 +233,7 @@ async function testProgram(t: GPUTest, program: Program) { const pass = encoder.beginComputePass(); pass.setPipeline(pipeline); pass.setBindGroup(0, bindGroup); - pass.dispatchWorkgroups(1,1,1); + pass.dispatchWorkgroups(1, 1, 1); pass.end(); t.queue.submit([encoder.finish()]); @@ -240,15 +245,12 @@ async function testProgram(t: GPUTest, program: Program) { // Generate a warning if this is not true of the device. // This mapping is not guaranteed by APIs (Vulkan particularly), but seems reliable // (for linear workgroups at least). - const sizeReadback = await t.readGPUBufferRangeTyped( - sizeBuffer, - { - srcByteOffset: 0, - type: Uint32Array, - typedLength: sizeLength, - method: 'copy', - } - ); + const sizeReadback = await t.readGPUBufferRangeTyped(sizeBuffer, { + srcByteOffset: 0, + type: Uint32Array, + typedLength: sizeLength, + method: 'copy', + }); const sizeData: Uint32Array = sizeReadback.data; const actualSize = sizeData[0]; t.expectOK(checkSubgroupSizeConsistency(sizeData, minSubgroupSize, maxSubgroupSize)); @@ -258,15 +260,12 @@ async function testProgram(t: GPUTest, program: Program) { let num = program.simulate(false, actualSize, debug); num = Math.min(program.maxLocations, num); - const idReadback = await t.readGPUBufferRangeTyped( - idBuffer, - { - srcByteOffset: 0, - type: Uint32Array, - typedLength: idLength, - method: 'copy', - } - ); + const idReadback = await t.readGPUBufferRangeTyped(idBuffer, { + srcByteOffset: 0, + type: Uint32Array, + typedLength: idLength, + method: 'copy', + }); const idData = idReadback.data; t.expectOK(checkIds(idData, actualSize), { mode: 'warn' }); @@ -281,15 +280,12 @@ async function testProgram(t: GPUTest, program: Program) { //); //const locationData = locationReadback.data; - const ballotReadback = await t.readGPUBufferRangeTyped( - ballotBuffer, - { - srcByteOffset: 0, - type: Uint32Array, - typedLength: ballotLength, - method: 'copy', - } - ); + const ballotReadback = await t.readGPUBufferRangeTyped(ballotBuffer, { + srcByteOffset: 0, + type: Uint32Array, + typedLength: ballotLength, + method: 'copy', + }); const ballotData = ballotReadback.data; // Only dump a single subgroup @@ -310,7 +306,7 @@ const kNumInvocations = 128; async function predefinedTest(t: GPUTest, style: Style, test: number) { const invocations = kNumInvocations; // t.device.limits.maxSubgroupSize; - let program: Program = new Program(style, 1, invocations);; + const program: Program = new Program(style, 1, invocations); switch (test) { case 0: { program.predefinedProgram1(); @@ -384,14 +380,10 @@ const kPredefinedTestCases = [...iterRange(15, x => x)]; g.test('predefined_workgroup') .desc(`Test workgroup reconvergence using some predefined programs`) - .params(u => - u - .combine('test', kPredefinedTestCases) - .beginSubcases() - ) + .params(u => u.combine('test', kPredefinedTestCases).beginSubcases()) .beforeAllSubcases(t => { t.selectDeviceOrSkipTestCase({ - requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName] + requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName], }); }) .fn(async t => { @@ -400,14 +392,10 @@ g.test('predefined_workgroup') g.test('predefined_subgroup') .desc(`Test subgroup reconvergence using some predefined programs`) - .params(u => - u - .combine('test', kPredefinedTestCases) - .beginSubcases() - ) + .params(u => u.combine('test', kPredefinedTestCases).beginSubcases()) .beforeAllSubcases(t => { t.selectDeviceOrSkipTestCase({ - requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName] + requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName], }); }) .fn(async t => { @@ -416,14 +404,10 @@ g.test('predefined_subgroup') g.test('predefined_maximal') .desc(`Test maximal reconvergence using some predefined programs`) - .params(u => - u - .combine('test', kPredefinedTestCases) - .beginSubcases() - ) + .params(u => u.combine('test', kPredefinedTestCases).beginSubcases()) .beforeAllSubcases(t => { t.selectDeviceOrSkipTestCase({ - requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName] + requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName], }); }) .fn(async t => { @@ -432,14 +416,10 @@ g.test('predefined_maximal') g.test('predefined_wgslv1') .desc(`Test WGSL v1 reconvergence using some predefined programs`) - .params(u => - u - .combine('test', kPredefinedTestCases) - .beginSubcases() - ) + .params(u => u.combine('test', kPredefinedTestCases).beginSubcases()) .beforeAllSubcases(t => { t.selectDeviceOrSkipTestCase({ - requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName] + requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName], }); }) .fn(async t => { @@ -450,20 +430,16 @@ const kNumRandomCases = 100; g.test('random_workgroup') .desc(`Test workgroup reconvergence using randomly generated programs`) - .params(u => - u - .combine('seed', generateSeeds(kNumRandomCases)) - .beginSubcases() - ) + .params(u => u.combine('seed', generateSeeds(kNumRandomCases)).beginSubcases()) .beforeAllSubcases(t => { t.selectDeviceOrSkipTestCase({ - requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName] + requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName], }); }) .fn(async t => { const invocations = kNumInvocations; // t.device.limits.maxSubgroupSize; - let program: Program = new Program(Style.Workgroup, t.params.seed, invocations); + const program: Program = new Program(Style.Workgroup, t.params.seed, invocations); program.generate(); await testProgram(t, program); @@ -471,20 +447,16 @@ g.test('random_workgroup') g.test('random_subgroup') .desc(`Test subgroup reconvergence using randomly generated programs`) - .params(u => - u - .combine('seed', generateSeeds(kNumRandomCases)) - .beginSubcases() - ) + .params(u => u.combine('seed', generateSeeds(kNumRandomCases)).beginSubcases()) .beforeAllSubcases(t => { t.selectDeviceOrSkipTestCase({ - requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName] + requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName], }); }) .fn(async t => { const invocations = kNumInvocations; // t.device.limits.maxSubgroupSize; - let program: Program = new Program(Style.Subgroup, t.params.seed, invocations); + const program: Program = new Program(Style.Subgroup, t.params.seed, invocations); program.generate(); await testProgram(t, program); @@ -492,20 +464,16 @@ g.test('random_subgroup') g.test('random_maximal') .desc(`Test maximal reconvergence using randomly generated programs`) - .params(u => - u - .combine('seed', generateSeeds(kNumRandomCases)) - .beginSubcases() - ) + .params(u => u.combine('seed', generateSeeds(kNumRandomCases)).beginSubcases()) .beforeAllSubcases(t => { t.selectDeviceOrSkipTestCase({ - requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName] + requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName], }); }) .fn(async t => { const invocations = kNumInvocations; // t.device.limits.maxSubgroupSize; - let program: Program = new Program(Style.Maximal, t.params.seed, invocations); + const program: Program = new Program(Style.Maximal, t.params.seed, invocations); program.generate(); await testProgram(t, program); @@ -513,20 +481,16 @@ g.test('random_maximal') g.test('random_wgslv1') .desc(`Test WGSL v1 reconvergence using randomly generated programs`) - .params(u => - u - .combine('seed', generateSeeds(kNumRandomCases)) - .beginSubcases() - ) + .params(u => u.combine('seed', generateSeeds(kNumRandomCases)).beginSubcases()) .beforeAllSubcases(t => { t.selectDeviceOrSkipTestCase({ - requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName] + requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName], }); }) .fn(async t => { const invocations = kNumInvocations; // t.device.limits.maxSubgroupSize; - let program: Program = new Program(Style.WGSLv1, t.params.seed, invocations); + const program: Program = new Program(Style.WGSLv1, t.params.seed, invocations); program.generate(); await testProgram(t, program); diff --git a/src/webgpu/shader/execution/reconvergence/util.ts b/src/webgpu/shader/execution/reconvergence/util.ts index 21d920264341..b1f4acc5d524 100644 --- a/src/webgpu/shader/execution/reconvergence/util.ts +++ b/src/webgpu/shader/execution/reconvergence/util.ts @@ -15,7 +15,7 @@ function getReplicatedMask(submask: bigint, size: number, total: number): bigint const reps = Math.floor(total / size); let mask: bigint = submask & ((1n << BigInt(size)) - 1n); for (let i = 1; i < reps; i++) { - mask |= (mask << BigInt(size)); + mask |= mask << BigInt(size); } return mask; } @@ -45,7 +45,7 @@ function getSubgroupMask(fullMask: bigint, size: number, id: number = 0): Uint32 const arr: Uint32Array = new Uint32Array(4); const subgroup_id: number = Math.floor(id / size); const shift: number = subgroup_id * size; - let mask: bigint = (fullMask >> BigInt(shift)) & getMask(size); + const mask: bigint = (fullMask >> BigInt(shift)) & getMask(size); arr[0] = Number(BigInt.asUintN(32, mask)); arr[1] = Number(BigInt.asUintN(32, mask >> 32n)); arr[2] = Number(BigInt.asUintN(32, mask >> 64n)); @@ -55,7 +55,7 @@ function getSubgroupMask(fullMask: bigint, size: number, id: number = 0): Uint32 /** @returns true if bit |bit| is set to 1. */ function testBit(mask: bigint, bit: number): boolean { - return ((mask >> BigInt(bit)) & 0x1n) == 1n; + return ((mask >> BigInt(bit)) & 0x1n) === 1n; } /** @returns true if any bit in value is 1. */ @@ -65,7 +65,7 @@ function any(value: bigint): boolean { /** @returns true if all bits in value from [0, size) are 1. */ function all(value: bigint, size: number): boolean { - return value === ((1n << BigInt(size)) - 1n); + return value === (1n << BigInt(size)) - 1n; } /** @@ -84,7 +84,7 @@ export enum Style { // Guarantees provided by WGSL v1. // Very similar to Workgroup, but less strict for loops. WGSLv1 = 3, -}; +} /** * Instruction type @@ -238,7 +238,7 @@ enum IfType { // if subgroup_id < inputs[N] Lid, -}; +} /** * Operation in a Program. @@ -248,21 +248,21 @@ enum IfType { */ class Op { // Instruction type - type : OpType; + type: OpType; // Instruction specific value - value : number; + value: number; // Case specific value caseValue: number; // Indicates if the instruction is uniform or not - uniform : boolean; + uniform: boolean; - constructor(type : OpType, value: number, caseValue: number = 0, uniform: boolean = true) { + constructor(type: OpType, value: number, caseValue: number = 0, uniform: boolean = true) { this.type = type; this.value = value; this.caseValue = caseValue; this.uniform = uniform; } -}; +} /** * Main class for testcase generation. @@ -283,7 +283,7 @@ export class Program { // Pseduo-random number generator private readonly prng: PRNG; // Instruction list - private ops : Op[]; + private ops: Op[]; // Reconvergence style public readonly style: Style; // Minimum number of instructions in a program @@ -337,7 +337,7 @@ export class Program { * @param style Enum indicating the type of reconvergence being tested * @param seed Value used to seed the PRNG */ - constructor(style : Style = Style.Workgroup, seed: number = 1, invocations: number) { + constructor(style: Style = Style.Workgroup, seed: number = 1, invocations: number) { this.invocations = invocations; assert(invocations <= 128); this.prng = new PRNG(seed); @@ -399,7 +399,7 @@ export class Program { * @param count the number of instructions * */ - private pickOp(count : number) { + private pickOp(count: number) { for (let i = 0; i < count; i++) { if (this.ops.length >= this.maxCount) { return; @@ -434,9 +434,15 @@ export class Program { if (this.loopNesting < this.maxLoopNesting) { const r2 = this.getRandomUint(3); switch (r2) { - case 0: this.genForUniform(); break; - case 1: this.genForInf(); break; - case 2: this.genForVar(); break; + case 0: + this.genForUniform(); + break; + case 1: + this.genForInf(); + break; + case 2: + this.genForVar(); + break; default: { break; } @@ -454,9 +460,11 @@ export class Program { } case 7: { // Calls and returns. - if (this.getRandomFloat() < 0.2 && - this.callNesting == 0 && - this.nesting < this.maxNesting - 1) { + if ( + this.getRandomFloat() < 0.2 && + this.callNesting === 0 && + this.nesting < this.maxNesting - 1 + ) { this.genCall(); } else { this.genReturn(); @@ -467,8 +475,12 @@ export class Program { if (this.loopNesting < this.maxLoopNesting) { const r2 = this.getRandomUint(2); switch (r2) { - case 0: this.genLoopUniform(); break; - case 1: this.genLoopInf(); break; + case 0: + this.genLoopUniform(); + break; + case 1: + this.genLoopInf(); + break; default: { break; } @@ -491,7 +503,7 @@ export class Program { // fallthrough } case 2: { - if (this.style != Style.Maximal) { + if (this.style !== Style.Maximal) { this.genSwitchMulticase(); break; } @@ -503,6 +515,7 @@ export class Program { break; } } + break; } case 10: { this.genElect(false); @@ -529,11 +542,16 @@ export class Program { // Ballots and stores are used to determine correctness. if (this.getRandomFloat() < 0.2) { const cur_length = this.ops.length; - if (cur_length < 2 || - !(this.ops[cur_length - 1].type == OpType.Ballot || - (this.ops[cur_length-1].type == OpType.Store && this.ops[cur_length - 2].type == OpType.Ballot))) { + if ( + cur_length < 2 || + !( + this.ops[cur_length - 1].type === OpType.Ballot || + (this.ops[cur_length - 1].type === OpType.Store && + this.ops[cur_length - 2].type === OpType.Ballot) + ) + ) { // Perform a store with each ballot so the results can be correlated. - if (this.style != Style.Maximal) + if (this.style !== Style.Maximal) this.ops.push(new Op(OpType.Store, cur_length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); } @@ -541,12 +559,17 @@ export class Program { if (this.getRandomFloat() < 0.1) { const cur_length = this.ops.length; - if (cur_length < 2 || - !(this.ops[cur_length - 1].type == OpType.Store || - (this.ops[cur_length - 1].type == OpType.Ballot && this.ops[cur_length - 2].type == OpType.Store))) { + if ( + cur_length < 2 || + !( + this.ops[cur_length - 1].type === OpType.Store || + (this.ops[cur_length - 1].type === OpType.Ballot && + this.ops[cur_length - 2].type === OpType.Store) + ) + ) { // Subgroup and workgroup styles do a store with every ballot. // Don't bloat the code by adding more. - if (this.style == Style.Maximal) + if (this.style === Style.Maximal) this.ops.push(new Op(OpType.Store, cur_length + this.storeBase)); } } @@ -568,13 +591,12 @@ export class Program { */ private genIf(type: IfType) { let maskIdx = this.getRandomUint(this.numMasks); - if (type == IfType.Uniform) - maskIdx = 0; + if (type === IfType.Uniform) maskIdx = 0; const lid = this.getRandomUint(this.invocations); - if (type == IfType.Lid) { + if (type === IfType.Lid) { this.ops.push(new Op(OpType.IfId, lid)); - } else if (type == IfType.LoopCount) { + } else if (type === IfType.LoopCount) { this.ops.push(new Op(OpType.IfLoopCount, 0)); } else { this.ops.push(new Op(OpType.IfMask, maskIdx)); @@ -583,15 +605,15 @@ export class Program { this.nesting++; this.maxProgramNesting = Math.max(this.nesting, this.maxProgramNesting); - let beforeSize = this.ops.length; + const beforeSize = this.ops.length; this.pickOp(2); - let afterSize = this.ops.length; + const afterSize = this.ops.length; const randElse = this.getRandomFloat(); if (randElse < 0.5) { - if (type == IfType.Lid) { + if (type === IfType.Lid) { this.ops.push(new Op(OpType.ElseId, lid)); - } else if (type == IfType.LoopCount) { + } else if (type === IfType.LoopCount) { this.ops.push(new Op(OpType.ElseLoopCount, 0)); } else { this.ops.push(new Op(OpType.ElseMask, maskIdx)); @@ -599,14 +621,17 @@ export class Program { // Sometimes make the else identical to the if, but don't just completely // blow up the instruction count. - if (randElse < 0.1 && beforeSize != afterSize && - (beforeSize + 2 * (afterSize - beforeSize)) < this.maxCount) { + if ( + randElse < 0.1 && + beforeSize !== afterSize && + beforeSize + 2 * (afterSize - beforeSize) < this.maxCount + ) { for (let i = beforeSize; i < afterSize; i++) { const op = this.ops[i]; this.ops.push(new Op(op.type, op.value, op.caseValue, op.uniform)); // Make stores unique. - if (op.type == OpType.Store) { - this.ops[this.ops.length-1].value = this.storeBase + this.ops.length - 1; + if (op.type === OpType.Store) { + this.ops[this.ops.length - 1].value = this.storeBase + this.ops.length - 1; } } } else { @@ -805,7 +830,7 @@ export class Program { if (this.loopNestingThisFunction > 0) { // Sometimes put the break in a divergent if if (this.getRandomFloat() < 0.1) { - const r = this.getRandomUint(this.numMasks-1) + 1; + const r = this.getRandomUint(this.numMasks - 1) + 1; this.ops.push(new Op(OpType.IfMask, r)); this.ops.push(new Op(OpType.Break, 0)); this.ops.push(new Op(OpType.ElseMask, r)); @@ -827,7 +852,7 @@ export class Program { if (this.loopNestingThisFunction > 0 && !this.isLoopInf.get(this.loopNesting)) { // Sometimes put the continue in a divergent if if (this.getRandomFloat() < 0.1) { - const r = this.getRandomUint(this.numMasks-1) + 1; + const r = this.getRandomUint(this.numMasks - 1) + 1; this.ops.push(new Op(OpType.IfMask, r)); this.ops.push(new Op(OpType.Continue, 0)); this.ops.push(new Op(OpType.ElseMask, r)); @@ -867,10 +892,12 @@ export class Program { */ private genReturn() { const r = this.getRandomFloat(); - if (this.nesting > 0 && - (r < 0.05 || - (this.callNesting > 0 && this.loopNestingThisFunction > 0 && r < 0.2) || - (this.callNesting > 0 && this.loopNestingThisFunction > 1 && r < 0.5))) { + if ( + this.nesting > 0 && + (r < 0.05 || + (this.callNesting > 0 && this.loopNestingThisFunction > 0 && r < 0.2) || + (this.callNesting > 0 && this.loopNestingThisFunction > 1 && r < 0.5)) + ) { this.genBallot(); if (this.getRandomFloat() < 0.1) { this.ops.push(new Op(OpType.IfMask, 0)); @@ -897,7 +924,7 @@ export class Program { this.maxProgramNesting = Math.max(this.nesting, this.maxProgramNesting); // Never taken - this.ops.push(new Op(OpType.CaseMask, 0, 1 << (r+1))); + this.ops.push(new Op(OpType.CaseMask, 0, 1 << (r + 1))); this.pickOp(1); this.ops.push(new Op(OpType.EndCase, 0)); @@ -907,7 +934,7 @@ export class Program { this.ops.push(new Op(OpType.EndCase, 0)); // Never taken - this.ops.push(new Op(OpType.CaseMask, 0, 1 << (r+2))); + this.ops.push(new Op(OpType.CaseMask, 0, 1 << (r + 2))); this.pickOp(1); this.ops.push(new Op(OpType.EndCase, 0)); @@ -924,19 +951,19 @@ export class Program { this.nesting++; this.maxProgramNesting = Math.max(this.nesting, this.maxProgramNesting); - this.ops.push(new Op(OpType.CaseMask, 0x1, 1<<0)); + this.ops.push(new Op(OpType.CaseMask, 0x1, 1 << 0)); this.pickOp(1); this.ops.push(new Op(OpType.EndCase, 0)); - this.ops.push(new Op(OpType.CaseMask, 0x2, 1<<1)); + this.ops.push(new Op(OpType.CaseMask, 0x2, 1 << 1)); this.pickOp(1); this.ops.push(new Op(OpType.EndCase, 0)); - this.ops.push(new Op(OpType.CaseMask, 0x4, 1<<2)); + this.ops.push(new Op(OpType.CaseMask, 0x4, 1 << 2)); this.pickOp(1); this.ops.push(new Op(OpType.EndCase, 0)); - this.ops.push(new Op(OpType.CaseMask, 0x8, 1<<3)); + this.ops.push(new Op(OpType.CaseMask, 0x8, 1 << 3)); this.pickOp(1); this.ops.push(new Op(OpType.EndCase, 0)); @@ -954,11 +981,11 @@ export class Program { this.nesting++; this.maxProgramNesting = Math.max(this.nesting, this.maxProgramNesting); - this.ops.push(new Op(OpType.CaseLoopCount, 1<<1, 1)); + this.ops.push(new Op(OpType.CaseLoopCount, 1 << 1, 1)); this.pickOp(1); this.ops.push(new Op(OpType.EndCase, 0)); - this.ops.push(new Op(OpType.CaseLoopCount, 1<<2, 2)); + this.ops.push(new Op(OpType.CaseLoopCount, 1 << 2, 2)); this.pickOp(1); this.ops.push(new Op(OpType.EndCase, 0)); @@ -988,11 +1015,11 @@ export class Program { this.ops.push(new Op(OpType.SwitchVar, 0)); this.nesting++; - this.ops.push(new Op(OpType.CaseMask, 0x3, (1<<0)|(1<<1))); + this.ops.push(new Op(OpType.CaseMask, 0x3, (1 << 0) | (1 << 1))); this.pickOp(2); this.ops.push(new Op(OpType.EndCase, 0)); - this.ops.push(new Op(OpType.CaseMask, 0xc, (1<<2)|(1<<3))); + this.ops.push(new Op(OpType.CaseMask, 0xc, (1 << 2) | (1 << 3))); this.pickOp(2); this.ops.push(new Op(OpType.EndCase, 0)); @@ -1022,16 +1049,18 @@ export class Program { break; } case OpType.IfMask: { - if (op.value == 0) { + if (op.value === 0) { const idx = this.getRandomUint(4); this.addCode(`if inputs[${idx}] == ${idx} {`); } else { const idx = op.value; - const x = this.masks[4*idx]; - const y = this.masks[4*idx+1]; - const z = this.masks[4*idx+2]; - const w = this.masks[4*idx+3]; - this.addCode(`if testBit(vec4u(0x${hex(x)},0x${hex(y)},0x${hex(z)},0x${hex(w)}), subgroup_id) {`); + const x = this.masks[4 * idx]; + const y = this.masks[4 * idx + 1]; + const z = this.masks[4 * idx + 2]; + const w = this.masks[4 * idx + 3]; + this.addCode( + `if testBit(vec4u(0x${hex(x)},0x${hex(y)},0x${hex(z)},0x${hex(w)}), subgroup_id) {` + ); } this.increaseIndent(); break; @@ -1042,7 +1071,7 @@ export class Program { break; } case OpType.IfLoopCount: { - this.addCode(`if subgroup_id == i${this.loopNesting-1} {`); + this.addCode(`if subgroup_id == i${this.loopNesting - 1} {`); this.increaseIndent(); break; } @@ -1079,7 +1108,9 @@ export class Program { } case OpType.ForVar: { const iter = `i${this.loopNesting}`; - this.addCode(`for (var ${iter} = 0u; ${iter} < (subgroup_id / ${op.value}) + 1; ${iter}++) {`); + this.addCode( + `for (var ${iter} = 0u; ${iter} < (subgroup_id / ${op.value}) + 1; ${iter}++) {` + ); this.loopNesting++; this.increaseIndent(); break; @@ -1168,7 +1199,7 @@ export class Program { this.curFunc = this.functions.length; this.functions.push(``); this.indents.push(0); - let decl = `fn f${this.curFunc}(` + let decl = `fn f${this.curFunc}(`; for (let i = 0; i < this.loopNesting; i++) { decl += `i${i} : u32,`; } @@ -1236,7 +1267,7 @@ export class Program { break; } case OpType.Noise: { - if (op.value == 0) { + if (op.value === 0) { this.addCode(`while (!subgroupElect()) { }`); } else { // The if is uniform false. @@ -1342,7 +1373,7 @@ fn f0() { for (let i = 0; i < this.functions.length; i++) { code += ` ${this.functions[i]}`; - if (i == 0) { + if (i === 0) { code += `\n}\n`; } } @@ -1383,9 +1414,9 @@ ${this.functions[i]}`; let stores = 0; let totalStores = 0; let totalLoops = 0; - let loopsAtNesting = new Array(this.maxLoopNesting); + const loopsAtNesting = new Array(this.maxLoopNesting); loopsAtNesting.fill(0); - let storesAtNesting = new Array(this.maxLoopNesting + 1); + const storesAtNesting = new Array(this.maxLoopNesting + 1); storesAtNesting.fill(0); for (let i = 0; i < this.ops.length; i++) { const op = this.ops[i]; @@ -1474,7 +1505,7 @@ ${this.functions[i]}`; * */ private isUniform(mask: bigint, size: number): boolean { - if (this.style == Style.Workgroup || this.style === Style.WGSLv1) { + if (this.style === Style.Workgroup || this.style === Style.WGSLv1) { if (any(mask) && !all(mask, this.invocations)) { return false; } else { @@ -1484,7 +1515,7 @@ ${this.functions[i]}`; let uniform: boolean = true; for (let id = 0; id < this.invocations; id += size) { const subgroupMask = (mask >> BigInt(id)) & getMask(size); - if (subgroupMask != 0n && !all(subgroupMask, size)) { + if (subgroupMask !== 0n && !all(subgroupMask, size)) { uniform = false; break; } @@ -1546,7 +1577,7 @@ ${this.functions[i]}`; this.isSwitch = false; this.isNonUniform = prev.isNonUniform; } - }; + } for (let idx = 0; idx < this.ops.length; idx++) { this.ops[idx].uniform = true; } @@ -1554,7 +1585,7 @@ ${this.functions[i]}`; // Allocate the stack based on the maximum nesting in the program. // Note: this has proven to be considerably more performant than pushing // and popping from the array. - let stack: State[] = new Array(this.maxProgramNesting + 1); + const stack: State[] = new Array(this.maxProgramNesting + 1); for (let i = 0; i < stack.length; i++) { stack[i] = new State(); } @@ -1562,17 +1593,27 @@ ${this.functions[i]}`; let nesting = 0; let loopNesting = 0; - let locs = new Array(this.invocations); + const locs = new Array(this.invocations); locs.fill(0); let i = 0; while (i < this.ops.length) { const op = this.ops[i]; if (nesting >= stack.length) { - unreachable(`Max stack nesting surpassed (${stack.length} vs ${this.nesting}) at ops[${i}] = ${serializeOpType(op.type)}`); + unreachable( + `Max stack nesting surpassed (${stack.length} vs ${ + this.nesting + }) at ops[${i}] = ${serializeOpType(op.type)}` + ); } if (debug) { - console.log(`ops[${i}] = ${serializeOpType(op.type)}, nesting = ${nesting}, loopNesting = ${loopNesting}, value = ${op.value}, nonuniform = ${stack[nesting].isNonUniform}`); + console.log( + `ops[${i}] = ${serializeOpType( + op.type + )}, nesting = ${nesting}, loopNesting = ${loopNesting}, value = ${ + op.value + }, nonuniform = ${stack[nesting].isNonUniform}` + ); console.log(` mask = ${stack[nesting].activeMask.toString(16)}`); } @@ -1597,11 +1638,12 @@ ${this.functions[i]}`; case OpType.CaseMask: case OpType.CaseLoopCount: { // No reason to simulate if the previous stack entry is inactive. - if (!any(stack[nesting-1].activeMask)) { + if (!any(stack[nesting - 1].activeMask)) { stack[nesting].activeMask = 0n; i++; continue; } + break; } default: break; @@ -1664,9 +1706,9 @@ ${this.functions[i]}`; case OpType.IfMask: { nesting++; const cur = stack[nesting]; - cur.reset(stack[nesting-1], i); + cur.reset(stack[nesting - 1], i); // O is always uniform true. - if (op.value != 0 && any(cur.activeMask)) { + if (op.value !== 0 && any(cur.activeMask)) { let subMask = this.getValueMask(op.value); subMask &= getMask(subgroupSize); cur.activeMask &= getReplicatedMask(subMask, subgroupSize, this.invocations); @@ -1676,8 +1718,8 @@ ${this.functions[i]}`; case OpType.ElseMask: { // 0 is always uniform true so the else will never be taken. const cur = stack[nesting]; - const prev = stack[nesting-1]; - if (op.value == 0) { + const prev = stack[nesting - 1]; + if (op.value === 0) { cur.activeMask = 0n; } else if (any(prev.activeMask)) { let subMask = this.getValueMask(op.value); @@ -1690,7 +1732,7 @@ ${this.functions[i]}`; case OpType.IfId: { nesting++; const cur = stack[nesting]; - cur.reset(stack[nesting-1], i); + cur.reset(stack[nesting - 1], i); if (any(cur.activeMask)) { // All invocations with subgroup invocation id less than op.value are active. const mask = getReplicatedMask(getMask(op.value), subgroupSize, this.invocations); @@ -1699,7 +1741,7 @@ ${this.functions[i]}`; break; } case OpType.ElseId: { - const prev = stack[nesting-1]; + const prev = stack[nesting - 1]; // All invocations with a subgroup invocation id greater or equal to op.value are active. stack[nesting].activeMask = prev.activeMask; if (any(prev.activeMask)) { @@ -1720,7 +1762,7 @@ ${this.functions[i]}`; nesting++; const cur = stack[nesting]; - cur.reset(stack[nesting-1], i); + cur.reset(stack[nesting - 1], i); if (any(cur.activeMask)) { const submask = BigInt(1 << stack[n].tripCount); const mask = getReplicatedMask(submask, subgroupSize, this.invocations); @@ -1740,7 +1782,7 @@ ${this.functions[i]}`; unreachable(`Failed to find loop for ElseLoopCount`); } - stack[nesting].activeMask = stack[nesting-1].activeMask; + stack[nesting].activeMask = stack[nesting - 1].activeMask; if (any(stack[nesting].activeMask)) { const submask = BigInt(1 << stack[n].tripCount); const mask = getReplicatedMask(submask, subgroupSize, this.invocations); @@ -1761,7 +1803,7 @@ ${this.functions[i]}`; nesting++; loopNesting++; const cur = stack[nesting]; - cur.reset(stack[nesting-1], i); + cur.reset(stack[nesting - 1], i); cur.isLoop = true; break; } @@ -1771,8 +1813,7 @@ ${this.functions[i]}`; cur.tripCount++; cur.activeMask |= cur.continueMask; cur.continueMask = 0n; - if (cur.tripCount < this.ops[cur.header].value && - any(cur.activeMask)) { + if (cur.tripCount < this.ops[cur.header].value && any(cur.activeMask)) { i = cur.header + 1; if (this.style === Style.WGSLv1 && !all(cur.activeMask, subgroupSize)) { cur.isNonUniform = true; @@ -1831,8 +1872,8 @@ ${this.functions[i]}`; if (!done) { // i < (subgroup_invocation_id / reduction) + 1 // So remove all ids < tripCount * reduction - let submask = getMask(subgroupSize) & ~getMask(cur.tripCount * op.value); - let mask = getReplicatedMask(submask, subgroupSize, this.invocations); + const submask = getMask(subgroupSize) & ~getMask(cur.tripCount * op.value); + const mask = getReplicatedMask(submask, subgroupSize, this.invocations); cur.activeMask &= mask; done = !any(cur.activeMask); } @@ -1854,8 +1895,7 @@ ${this.functions[i]}`; cur.tripCount++; cur.activeMask |= cur.continueMask; cur.continueMask = 0n; - if (cur.tripCount < this.ops[cur.header].value && - any(cur.activeMask)) { + if (cur.tripCount < this.ops[cur.header].value && any(cur.activeMask)) { if (this.style === Style.WGSLv1 && !all(cur.activeMask, subgroupSize)) { cur.isNonUniform = true; } @@ -1906,7 +1946,7 @@ ${this.functions[i]}`; } case OpType.Break: { // Remove this active mask from all stack entries for the current loop/switch. - let mask: bigint = stack[nesting].activeMask; + const mask: bigint = stack[nesting].activeMask; if (!any(mask)) { break; } @@ -1926,7 +1966,7 @@ ${this.functions[i]}`; case OpType.Continue: { // Remove this active mask from stack entries in this loop. // Add this mask to the loop's continue mask for the next iteration. - let mask: bigint = stack[nesting].activeMask; + const mask: bigint = stack[nesting].activeMask; if (!any(mask)) { break; } @@ -1953,7 +1993,7 @@ ${this.functions[i]}`; } case OpType.Return: { // Remove this active mask from all stack entries for this function. - let mask: bigint = stack[nesting].activeMask; + const mask: bigint = stack[nesting].activeMask; if (!any(mask)) { break; } @@ -1967,7 +2007,7 @@ ${this.functions[i]}`; } // op.value for Return is the call nesting. // If the value is > 0 we should have encountered the call on the stack. - if (op.value != 0 && n < 0) { + if (op.value !== 0 && n < 0) { unreachable(`Failed to find call for return`); } @@ -1976,7 +2016,7 @@ ${this.functions[i]}`; case OpType.Elect: { nesting++; const cur = stack[nesting]; - cur.reset(stack[nesting-1], i); + cur.reset(stack[nesting - 1], i); if (any(cur.activeMask)) { cur.activeMask = getElectMask(cur.activeMask, subgroupSize, this.invocations); } @@ -1986,7 +2026,7 @@ ${this.functions[i]}`; nesting++; const cur = stack[nesting]; // Header is unused for calls. - cur.reset(stack[nesting-1], 0); + cur.reset(stack[nesting - 1], 0); cur.isCall = true; break; } @@ -1999,7 +2039,7 @@ ${this.functions[i]}`; case OpType.SwitchLoopCount: { nesting++; const cur = stack[nesting]; - cur.reset(stack[nesting-1], i); + cur.reset(stack[nesting - 1], i); cur.isSwitch = true; break; } @@ -2009,7 +2049,7 @@ ${this.functions[i]}`; } case OpType.CaseMask: { const mask = getReplicatedMask(BigInt(op.value), 4, this.invocations); - stack[nesting].activeMask = stack[nesting-1].activeMask & mask; + stack[nesting].activeMask = stack[nesting - 1].activeMask & mask; break; } case OpType.CaseLoopCount: { @@ -2020,7 +2060,7 @@ ${this.functions[i]}`; while (n >= 0 && l >= 0) { if (stack[n].isLoop) { l--; - if (l == findLoop) { + if (l === findLoop) { break; } } @@ -2030,8 +2070,8 @@ ${this.functions[i]}`; unreachable(`Failed to find loop for CaseLoopCount`); } - if (((1 << stack[n].tripCount) & op.value) != 0) { - stack[nesting].activeMask = stack[nesting-1].activeMask; + if (((1 << stack[n].tripCount) & op.value) !== 0) { + stack[nesting].activeMask = stack[nesting - 1].activeMask; } else { stack[nesting].activeMask = 0n; } @@ -2049,7 +2089,7 @@ ${this.functions[i]}`; i++; } - assert(nesting == 0); + assert(nesting === 0); let maxLoc = 0; for (let id = 0; id < this.invocations; id++) { @@ -2066,10 +2106,10 @@ ${this.functions[i]}`; * */ private getValueMask(idx: number): bigint { - const x = this.masks[4*idx]; - const y = this.masks[4*idx+1]; - const z = this.masks[4*idx+2]; - const w = this.masks[4*idx+3]; + const x = this.masks[4 * idx]; + const y = this.masks[4 * idx + 1]; + const z = this.masks[4 * idx + 2]; + const w = this.masks[4 * idx + 3]; let mask: bigint = 0n; mask |= BigInt(x); mask |= BigInt(y) << 32n; @@ -2082,7 +2122,7 @@ ${this.functions[i]}`; public generate() { let i = 0; do { - if (i != 0) { + if (i !== 0) { console.log(`Warning regenerating UCF testcase`); } this.ops = []; @@ -2093,11 +2133,11 @@ ${this.functions[i]}`; // If this is an uniform control flow case, make sure a uniform ballot is // generated. A subgroup size of 64 is used for testing purposes here. - if (this.style != Style.Maximal) { + if (this.style !== Style.Maximal) { this.simulate(true, 64); } i++; - } while (this.style != Style.Maximal && !this.ucf); + } while (this.style !== Style.Maximal && !this.ucf); } /** @returns true if the program has uniform control flow for some ballot */ @@ -2114,7 +2154,6 @@ ${this.functions[i]}`; * @returns The base index in a Uint32Array */ private baseIndex(id: number, loc: number): number { - const capped_loc = Math.min(this.maxLocations, loc); return 4 * (this.invocations * loc + id); } @@ -2125,14 +2164,16 @@ ${this.functions[i]}`; * @param resIdx The base result index * @param ref The reference data * @param refIdx The base reference index - * + * * @returns true if 4 successive values match in both arrays */ private matchResult(res: Uint32Array, resIdx: number, ref: Uint32Array, refIdx: number): boolean { - return res[resIdx + 0] === ref[refIdx + 0] && - res[resIdx + 1] === ref[refIdx + 1] && - res[resIdx + 2] === ref[refIdx + 2] && - res[resIdx + 3] === ref[refIdx + 3]; + return ( + res[resIdx + 0] === ref[refIdx + 0] && + res[resIdx + 1] === ref[refIdx + 1] && + res[resIdx + 2] === ref[refIdx + 2] && + res[resIdx + 3] === ref[refIdx + 3] + ); } /** @@ -2144,9 +2185,12 @@ ${this.functions[i]}`; * @param numLocs The maximum locations used in simulation * @returns an error if the results do meet expectatations */ - public checkResults(ballots: Uint32Array, /*locations: Uint32Array,*/ - subgroupSize: number, numLocs: number): Error | undefined { - let totalLocs = Math.min(numLocs, this.maxLocations); + public checkResults( + ballots: Uint32Array /*locations: Uint32Array,*/, + subgroupSize: number, + numLocs: number + ): Error | undefined { + const totalLocs = Math.min(numLocs, this.maxLocations); if (this.style !== Style.Maximal) { if (!this.isUCF()) { return Error(`Expected some uniform condition for this test`); @@ -2154,25 +2198,38 @@ ${this.functions[i]}`; // Subgroup and Workgroup tests always have an associated store // preceeding them in the buffer. const maskArray = getSubgroupMask(getMask(subgroupSize), subgroupSize); - const zeroArray = new Uint32Array([0,0,0,0]); for (let id = 0; id < this.invocations; id++) { let refLoc = 1; let resLoc = 0; while (refLoc < totalLocs) { - while (refLoc < totalLocs && - !this.matchResult(this.refData, this.baseIndex(id, refLoc), maskArray, 0)) { + while ( + refLoc < totalLocs && + !this.matchResult(this.refData, this.baseIndex(id, refLoc), maskArray, 0) + ) { refLoc++; } if (refLoc < numLocs) { // Fully converged simulation // Search for the corresponding data in the result. - let storeRefLoc = refLoc - 1; - while (resLoc + 1 < totalLocs && - !(this.matchResult(ballots, this.baseIndex(id, resLoc), - this.refData, this.baseIndex(id, storeRefLoc)) && - this.matchResult(ballots, this.baseIndex(id, resLoc+1), - this.refData, this.baseIndex(id, refLoc)))) { + const storeRefLoc = refLoc - 1; + while ( + resLoc + 1 < totalLocs && + !( + this.matchResult( + ballots, + this.baseIndex(id, resLoc), + this.refData, + this.baseIndex(id, storeRefLoc) + ) && + this.matchResult( + ballots, + this.baseIndex(id, resLoc + 1), + this.refData, + this.baseIndex(id, refLoc) + ) + ) + ) { resLoc++; } @@ -2182,7 +2239,9 @@ ${this.functions[i]}`; const ref = this.refData; let msg = `Failure for invocation ${id}: could not find match for:\n`; msg += `- store[${storeRefLoc}] = ${this.refData[sIdx]}\n`; - msg += `- ballot[${refLoc}] = (0x${hex(ref[bIdx+3])},0x${hex(ref[bIdx+2])},0x${hex(ref[bIdx+1])},0x${hex(ref[bIdx])})`; + msg += `- ballot[${refLoc}] = (0x${hex(ref[bIdx + 3])},0x${hex( + ref[bIdx + 2] + )},0x${hex(ref[bIdx + 1])},0x${hex(ref[bIdx])})`; return Error(msg); } // Match both locations so don't revisit them. @@ -2199,8 +2258,12 @@ ${this.functions[i]}`; const loc = Math.floor(idx_uvec4 / this.invocations); if (!this.matchResult(ballots, i, this.refData, i)) { let msg = `Failure for invocation ${id} at location ${loc}:\n`; - msg += `- expected: (0x${hex(this.refData[i+3])},0x${hex(this.refData[i+2])},0x${hex(this.refData[i+1])},0x${hex(this.refData[i])})\n`; - msg += `- got: (0x${hex(ballots[i+3])},0x${hex(ballots[i+2])},0x${hex(ballots[i+1])},0x${hex(ballots[i])})`; + msg += `- expected: (0x${hex(this.refData[i + 3])},0x${hex(this.refData[i + 2])},0x${hex( + this.refData[i + 1] + )},0x${hex(this.refData[i])})\n`; + msg += `- got: (0x${hex(ballots[i + 3])},0x${hex(ballots[i + 2])},0x${hex( + ballots[i + 1] + )},0x${hex(ballots[i])})`; return Error(msg); } } @@ -2241,13 +2304,15 @@ ${this.functions[i]}`; * * ForUniform and EndForUniform * * LoopUniform and EndLoopUniform */ - public predefinedProgram1(beginLoop: OpType = OpType.ForUniform, - endLoop: OpType = OpType.EndForUniform) { + public predefinedProgram1( + beginLoop: OpType = OpType.ForUniform, + endLoop: OpType = OpType.EndForUniform + ) { // Set the mask for index 1 - this.masks[4*1 + 0] = 0xaaaaaaaa - this.masks[4*1 + 1] = 0xaaaaaaaa - this.masks[4*1 + 2] = 0xaaaaaaaa - this.masks[4*1 + 3] = 0xaaaaaaaa + this.masks[4 * 1 + 0] = 0xaaaaaaaa; + this.masks[4 * 1 + 1] = 0xaaaaaaaa; + this.masks[4 * 1 + 2] = 0xaaaaaaaa; + this.masks[4 * 1 + 3] = 0xaaaaaaaa; this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); @@ -2304,10 +2369,10 @@ ${this.functions[i]}`; */ public predefinedProgram2() { // Set the mask for index 1 - this.masks[4*1 + 0] = 0x00ff00ff - this.masks[4*1 + 1] = 0x00ff00ff - this.masks[4*1 + 2] = 0x00ff00ff - this.masks[4*1 + 3] = 0x00ff00ff + this.masks[4 * 1 + 0] = 0x00ff00ff; + this.masks[4 * 1 + 1] = 0x00ff00ff; + this.masks[4 * 1 + 2] = 0x00ff00ff; + this.masks[4 * 1 + 3] = 0x00ff00ff; this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); this.ops.push(new Op(OpType.Ballot, 0)); @@ -2367,10 +2432,10 @@ ${this.functions[i]}`; */ public predefinedProgram3() { // Set the mask for index 1 - this.masks[4*1 + 0] = 0xd2f269c6; - this.masks[4*1 + 1] = 0xffe83b3f; - this.masks[4*1 + 2] = 0xa279f695; - this.masks[4*1 + 3] = 0x58899224; + this.masks[4 * 1 + 0] = 0xd2f269c6; + this.masks[4 * 1 + 1] = 0xffe83b3f; + this.masks[4 * 1 + 2] = 0xa279f695; + this.masks[4 * 1 + 3] = 0x58899224; this.ops.push(new Op(OpType.IfId, 107)); @@ -2418,8 +2483,10 @@ ${this.functions[i]}`; * * ForInf and EndForInf * * LoopInf and EndLoopInf */ - public predefinedProgramInf(beginType: OpType = OpType.ForInf, - endType: OpType = OpType.EndForInf) { + public predefinedProgramInf( + beginType: OpType = OpType.ForInf, + endType: OpType = OpType.EndForInf + ) { this.ops.push(new Op(beginType, 0)); this.ops.push(new Op(OpType.Store, this.ops.length + this.storeBase)); @@ -2594,19 +2661,19 @@ ${this.functions[i]}`; this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); this.ops.push(new Op(OpType.Ballot, 0)); this.ops.push(new Op(OpType.SwitchVar, 0)); - this.ops.push(new Op(OpType.CaseMask, 0x1, 1<<0)); + this.ops.push(new Op(OpType.CaseMask, 0x1, 1 << 0)); this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); this.ops.push(new Op(OpType.Ballot, 0)); this.ops.push(new Op(OpType.EndCase, 0)); - this.ops.push(new Op(OpType.CaseMask, 0x2, 1<<1)); + this.ops.push(new Op(OpType.CaseMask, 0x2, 1 << 1)); this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); this.ops.push(new Op(OpType.Ballot, 0)); this.ops.push(new Op(OpType.EndCase, 0)); - this.ops.push(new Op(OpType.CaseMask, 0x4, 1<<2)); + this.ops.push(new Op(OpType.CaseMask, 0x4, 1 << 2)); this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); this.ops.push(new Op(OpType.Ballot, 0)); this.ops.push(new Op(OpType.EndCase, 0)); - this.ops.push(new Op(OpType.CaseMask, 0x8, 1<<3)); + this.ops.push(new Op(OpType.CaseMask, 0x8, 1 << 3)); this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); this.ops.push(new Op(OpType.Ballot, 0)); this.ops.push(new Op(OpType.EndCase, 0)); @@ -2641,12 +2708,12 @@ ${this.functions[i]}`; this.ops.push(new Op(OpType.Ballot, 0)); this.ops.push(new Op(OpType.SwitchLoopCount, loop)); - this.ops.push(new Op(OpType.CaseLoopCount, 1<<1, 1)); + this.ops.push(new Op(OpType.CaseLoopCount, 1 << 1, 1)); this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); this.ops.push(new Op(OpType.Ballot, 0)); this.ops.push(new Op(OpType.EndCase, 0)); - this.ops.push(new Op(OpType.CaseLoopCount, 1<<2, 2)); + this.ops.push(new Op(OpType.CaseLoopCount, 1 << 2, 2)); this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); this.ops.push(new Op(OpType.Ballot, 0)); this.ops.push(new Op(OpType.EndCase, 0)); @@ -2679,12 +2746,12 @@ ${this.functions[i]}`; this.ops.push(new Op(OpType.Ballot, 0)); this.ops.push(new Op(OpType.SwitchVar, 0)); - this.ops.push(new Op(OpType.CaseMask, 0x3, (1<<0)|(1<<1))); + this.ops.push(new Op(OpType.CaseMask, 0x3, (1 << 0) | (1 << 1))); this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); this.ops.push(new Op(OpType.Ballot, 0)); this.ops.push(new Op(OpType.EndCase, 0)); - this.ops.push(new Op(OpType.CaseMask, 0xc, (1<<2)|(1<<3))); + this.ops.push(new Op(OpType.CaseMask, 0xc, (1 << 2) | (1 << 3))); this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); this.ops.push(new Op(OpType.Ballot, 0)); this.ops.push(new Op(OpType.EndCase, 0)); @@ -2726,11 +2793,11 @@ ${this.functions[i]}`; this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); this.ops.push(new Op(OpType.Ballot, 0)); } -}; +} export function generateSeeds(numCases: number): number[] { - let prng: PRNG = new PRNG(1); - let output: number[] = new Array(numCases); + const prng: PRNG = new PRNG(1); + const output: number[] = new Array(numCases); for (let i = 0; i < numCases; i++) { output[i] = prng.randomU32(); } From 3367642e3c74ba6530c7ec7e897072cfcc2b5e54 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Thu, 31 Aug 2023 20:07:23 -0400 Subject: [PATCH 28/32] Lots of comments to satisfy linting --- .../reconvergence/reconvergence.spec.ts | 70 +++---- .../shader/execution/reconvergence/util.ts | 192 +++++++++--------- 2 files changed, 129 insertions(+), 133 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index 325db4e2745e..e76be22ee21b 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -5,7 +5,7 @@ import { makeTestGroup } from '../../../../common/framework/test_group.js'; import { iterRange, unreachable } from '../../../../common/util/util.js'; import { GPUTest } from '../../../gpu_test.js'; -import { hex, Style, OpType, Program, generateSeeds } from './util.js'; +import { /*hex, */ Style, OpType, Program, generateSeeds } from './util.js'; export const g = makeTestGroup(GPUTest); @@ -63,28 +63,28 @@ function checkSubgroupSizeConsistency( return undefined; } -function dumpBallots( - ballots: Uint32Array, - totalInvocations: number, - invocations: number, - locations: number -) { - let dump = `Ballots\n`; - for (let id = 0; id < invocations; id++) { - dump += `id[${id}]\n`; - for (let loc = 0; loc < locations; loc++) { - const idx = 4 * (totalInvocations * loc + id); - const w = ballots[idx + 3]; - const z = ballots[idx + 2]; - const y = ballots[idx + 1]; - const x = ballots[idx + 0]; - dump += ` loc[${loc}] = (0x${hex(w)},0x${hex(z)},0x${hex(y)},0x${hex( - x - )}), (${w},${z},${y},${x})\n`; - } - } - console.log(dump); -} +//function dumpBallots( +// ballots: Uint32Array, +// totalInvocations: number, +// invocations: number, +// locations: number +//) { +// let dump = `Ballots\n`; +// for (let id = 0; id < invocations; id++) { +// dump += `id[${id}]\n`; +// for (let loc = 0; loc < locations; loc++) { +// const idx = 4 * (totalInvocations * loc + id); +// const w = ballots[idx + 3]; +// const z = ballots[idx + 2]; +// const y = ballots[idx + 1]; +// const x = ballots[idx + 0]; +// dump += ` loc[${loc}] = (0x${hex(w)},0x${hex(z)},0x${hex(y)},0x${hex( +// x +// )}), (${w},${z},${y},${x})\n`; +// } +// } +// console.log(dump); +//} /** * Checks the mapping of subgroup_invocation_id to local_invocation_index @@ -117,9 +117,9 @@ const kDebugLevel = 0x0; async function testProgram(t: GPUTest, program: Program) { const wgsl = program.genCode(); - if (kDebugLevel & 0x1) { - console.log(wgsl); - } + //if (kDebugLevel & 0x1) { + // console.log(wgsl); + //} if (kDebugLevel & 0x2) { program.dumpStats(true); } @@ -127,7 +127,7 @@ async function testProgram(t: GPUTest, program: Program) { return; } - // TODO: Query the limits when they are wired up. + // Query the limits when they are wired up. const minSubgroupSize = 4; const maxSubgroupSize = 128; @@ -289,14 +289,14 @@ async function testProgram(t: GPUTest, program: Program) { const ballotData = ballotReadback.data; // Only dump a single subgroup - if (kDebugLevel & 0x10) { - console.log(`${new Date()}: Reference data`); - dumpBallots(program.refData, program.invocations, actualSize, num); - } - if (kDebugLevel & 0x20) { - console.log(`${new Date()}: GPU data`); - dumpBallots(ballotData, program.invocations, actualSize, num); - } + //if (kDebugLevel & 0x10) { + // console.log(`${new Date()}: Reference data`); + // dumpBallots(program.refData, program.invocations, actualSize, num); + //} + //if (kDebugLevel & 0x20) { + // console.log(`${new Date()}: GPU data`); + // dumpBallots(ballotData, program.invocations, actualSize, num); + //} t.expectOK(program.checkResults(ballotData, /*locationData,*/ actualSize, num)); } diff --git a/src/webgpu/shader/execution/reconvergence/util.ts b/src/webgpu/shader/execution/reconvergence/util.ts index b1f4acc5d524..9aa5e5a68539 100644 --- a/src/webgpu/shader/execution/reconvergence/util.ts +++ b/src/webgpu/shader/execution/reconvergence/util.ts @@ -344,12 +344,9 @@ export class Program { this.ops = []; this.style = style; this.minCount = 30; - //this.maxCount = 50000; // TODO: what is a reasonable limit? - this.maxCount = 20000; // TODO: what is a reasonable limit? - // TODO: https://crbug.com/tint/2011 + this.maxCount = 20000; // what is a reasonable limit? + // https://crbug.com/tint/2011 // Tint is double counting depth - //this.maxNesting = this.getRandomUint(70) + 30; // [30,100) - //this.maxNesting = this.getRandomUint(40) + 20; this.maxNesting = this.getRandomUint(20) + 20; // Loops significantly affect runtime and memory performance this.maxLoopNesting = 3; //4; @@ -1408,81 +1405,80 @@ ${this.functions[i]}`; * * @param detailed If true, dumps more detailed stats */ - public dumpStats(detailed: boolean = true) { - let stats = `Total instructions: ${this.ops.length}\n`; - let nesting = 0; - let stores = 0; - let totalStores = 0; - let totalLoops = 0; - const loopsAtNesting = new Array(this.maxLoopNesting); - loopsAtNesting.fill(0); - const storesAtNesting = new Array(this.maxLoopNesting + 1); - storesAtNesting.fill(0); - for (let i = 0; i < this.ops.length; i++) { - const op = this.ops[i]; - switch (op.type) { - case OpType.Store: - case OpType.Ballot: { - stores++; - storesAtNesting[nesting]++; - break; - } - case OpType.ForUniform: - case OpType.LoopUniform: - case OpType.ForVar: - case OpType.ForInf: - case OpType.LoopInf: { - totalLoops++; - loopsAtNesting[nesting]++; - if (detailed) { - stats += ' '.repeat(nesting) + `${stores} stores\n`; - } - totalStores += stores; - stores = 0; - - if (detailed) { - let iters = `subgroup size`; - if (op.type === OpType.ForUniform || op.type === OpType.LoopUniform) { - iters = `${op.value}`; - } - stats += ' '.repeat(nesting) + serializeOpType(op.type) + `: ${iters} iterations\n`; - } - nesting++; - break; - } - case OpType.EndForUniform: - case OpType.EndForInf: - case OpType.EndForVar: - case OpType.EndLoopUniform: - case OpType.EndLoopInf: { - if (detailed) { - stats += ' '.repeat(nesting) + `${stores} stores\n`; - } - totalStores += stores; - stores = 0; - - nesting--; - if (detailed) { - stats += ' '.repeat(nesting) + serializeOpType(op.type) + '\n'; - } - break; - } - default: - break; - } - } - totalStores += stores; - stats += `\n`; - stats += `${totalLoops} loops\n`; - for (let i = 0; i < loopsAtNesting.length; i++) { - stats += ` ${loopsAtNesting[i]} at nesting ${i}\n`; - } - stats += `${totalStores} stores\n`; - for (let i = 0; i < storesAtNesting.length; i++) { - stats += ` ${storesAtNesting[i]} at nesting ${i}\n`; - } - console.log(stats); - } + //public dumpStats(detailed: boolean = true) { + // let stats = `Total instructions: ${this.ops.length}\n`; + // let nesting = 0; + // let stores = 0; + // let totalStores = 0; + // let totalLoops = 0; + // const loopsAtNesting = new Array(this.maxLoopNesting); + // loopsAtNesting.fill(0); + // const storesAtNesting = new Array(this.maxLoopNesting + 1); + // storesAtNesting.fill(0); + // for (const op of this.ops) { + // switch (op.type) { + // case OpType.Store: + // case OpType.Ballot: { + // stores++; + // storesAtNesting[nesting]++; + // break; + // } + // case OpType.ForUniform: + // case OpType.LoopUniform: + // case OpType.ForVar: + // case OpType.ForInf: + // case OpType.LoopInf: { + // totalLoops++; + // loopsAtNesting[nesting]++; + // if (detailed) { + // stats += ' '.repeat(nesting) + `${stores} stores\n`; + // } + // totalStores += stores; + // stores = 0; + + // if (detailed) { + // let iters = `subgroup size`; + // if (op.type === OpType.ForUniform || op.type === OpType.LoopUniform) { + // iters = `${op.value}`; + // } + // stats += ' '.repeat(nesting) + serializeOpType(op.type) + `: ${iters} iterations\n`; + // } + // nesting++; + // break; + // } + // case OpType.EndForUniform: + // case OpType.EndForInf: + // case OpType.EndForVar: + // case OpType.EndLoopUniform: + // case OpType.EndLoopInf: { + // if (detailed) { + // stats += ' '.repeat(nesting) + `${stores} stores\n`; + // } + // totalStores += stores; + // stores = 0; + + // nesting--; + // if (detailed) { + // stats += ' '.repeat(nesting) + serializeOpType(op.type) + '\n'; + // } + // break; + // } + // default: + // break; + // } + // } + // totalStores += stores; + // stats += `\n`; + // stats += `${totalLoops} loops\n`; + // for (let i = 0; i < loopsAtNesting.length; i++) { + // stats += ` ${loopsAtNesting[i]} at nesting ${i}\n`; + // } + // stats += `${totalStores} stores\n`; + // for (let i = 0; i < storesAtNesting.length; i++) { + // stats += ` ${storesAtNesting[i]} at nesting ${i}\n`; + // } + // console.log(stats); + //} /** * Sizes the simulation buffer. @@ -1533,7 +1529,7 @@ ${this.functions[i]}`; * @param subgroupSize The subgroup size to simulate * * BigInt is not the fastest value to manipulate. Care should be taken to optimize it's use. - * TODO: would it be better to roll my own 128 bitvector? + * Would it be better to roll my own 128 bitvector? * */ public simulate(countOnly: boolean, subgroupSize: number, debug: boolean = false): number { @@ -1578,8 +1574,8 @@ ${this.functions[i]}`; this.isNonUniform = prev.isNonUniform; } } - for (let idx = 0; idx < this.ops.length; idx++) { - this.ops[idx].uniform = true; + for (const op of this.ops) { + op.uniform = true; } // Allocate the stack based on the maximum nesting in the program. @@ -1606,16 +1602,16 @@ ${this.functions[i]}`; }) at ops[${i}] = ${serializeOpType(op.type)}` ); } - if (debug) { - console.log( - `ops[${i}] = ${serializeOpType( - op.type - )}, nesting = ${nesting}, loopNesting = ${loopNesting}, value = ${ - op.value - }, nonuniform = ${stack[nesting].isNonUniform}` - ); - console.log(` mask = ${stack[nesting].activeMask.toString(16)}`); - } + //if (debug) { + // console.log( + // `ops[${i}] = ${serializeOpType( + // op.type + // )}, nesting = ${nesting}, loopNesting = ${loopNesting}, value = ${ + // op.value + // }, nonuniform = ${stack[nesting].isNonUniform}` + // ); + // console.log(` mask = ${stack[nesting].activeMask.toString(16)}`); + //} // Early outs if no invocations are active. // Don't skip ops that change nesting. @@ -2120,11 +2116,11 @@ ${this.functions[i]}`; /** @returns a randomized program */ public generate() { - let i = 0; + //let i = 0; do { - if (i !== 0) { - console.log(`Warning regenerating UCF testcase`); - } + //if (i !== 0) { + // console.log(`Warning regenerating UCF testcase`); + //} this.ops = []; while (this.ops.length < this.minCount) { this.pickOp(1); @@ -2136,7 +2132,7 @@ ${this.functions[i]}`; if (this.style !== Style.Maximal) { this.simulate(true, 64); } - i++; + //i++; } while (this.style !== Style.Maximal && !this.ucf); } From 6a80aaa9db5491969eedeb2cfddb1322098b0c40 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Thu, 31 Aug 2023 20:13:25 -0400 Subject: [PATCH 29/32] missed function --- .../shader/execution/reconvergence/reconvergence.spec.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index e76be22ee21b..1b91fac9d00b 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -120,9 +120,9 @@ async function testProgram(t: GPUTest, program: Program) { //if (kDebugLevel & 0x1) { // console.log(wgsl); //} - if (kDebugLevel & 0x2) { - program.dumpStats(true); - } + //if (kDebugLevel & 0x2) { + // program.dumpStats(true); + //} if (kDebugLevel & 0x4) { return; } From b8f7b881383221e0574acd713cbce426cebd9be5 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Thu, 31 Aug 2023 20:24:18 -0400 Subject: [PATCH 30/32] Comments * switch code fragments in comments to drop { } because the linter is over aggressive --- .../shader/execution/reconvergence/util.ts | 131 +++++++----------- 1 file changed, 51 insertions(+), 80 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/util.ts b/src/webgpu/shader/execution/reconvergence/util.ts index 9aa5e5a68539..e4666e2a2d7a 100644 --- a/src/webgpu/shader/execution/reconvergence/util.ts +++ b/src/webgpu/shader/execution/reconvergence/util.ts @@ -995,11 +995,10 @@ export class Program { } /** - * switch (subgroup_invocation_id & 3) { - * default { } - * case 0x3: { ... } - * case 0xc: { ... } - * } + * switch (subgroup_invocation_id & 3) + * default + * case 0x3: ... + * case 0xc: ... * * This is not generated for maximal style cases because it is not clear what * convergence should be expected. There are multiple valid lowerings of a @@ -2279,18 +2278,15 @@ ${this.functions[i]}`; * Equivalent to: * * ballot(); // fully uniform - * if (inputs[1] == 1) { + * if (inputs[1] == 1) * ballot(); // fullly uniform - * for (var i = 0; i < 3; i++) { + * for (var i = 0; i < 3; i++) * ballot(); // Simulation expects fully uniform, WGSL does not. - * if (testBit(vec4u(0xaaaaaaaa,0xaaaaaaa,0xaaaaaaaa,0xaaaaaaaa), subgroup_id)) { + * if (testBit(vec4u(0xaaaaaaaa,0xaaaaaaa,0xaaaaaaaa,0xaaaaaaaa), subgroup_id)) * ballot(); // non-uniform * continue; - * } * ballot(); // non-uniform - * } * ballot(); // fully uniform - * } * ballot(); // fully uniform * * @param beginLoop The loop type @@ -2344,18 +2340,16 @@ ${this.functions[i]}`; * Equivalent to: * * ballot(); // uniform - * if (subgroup_id < 16) { + * if (subgroup_id < 16) * ballot(); // 0xffff - * if (testbit(vec4u(0x00ff00ff,00ff00ff,00ff00ff,00ff00ff), subgroup_id)) { + * if (testbit(vec4u(0x00ff00ff,00ff00ff,00ff00ff,00ff00ff), subgroup_id)) * ballot(); // 0xff - * if (inputs[1] == 1) { + * if (inputs[1] == 1) * ballot(); // 0xff - * } * ballot(); // 0xff - * } else { + * else * ballot(); // 0xF..0000 * return; - * } * ballot; // 0xffff * * In this program, subgroups larger than 16 invocations diverge at the first if. @@ -2407,19 +2401,16 @@ ${this.functions[i]}`; /** * Equivalent to: * - * if subgroup_id < inputs[107] { - * if subgroup_id < inputs[112] { + * if subgroup_id < inputs[107] + * if subgroup_id < inputs[112] * ballot(); - * if testBit(vec4u(0xd2f269c6,0xffe83b3f,0xa279f695,0x58899224), subgroup_id) { + * if testBit(vec4u(0xd2f269c6,0xffe83b3f,0xa279f695,0x58899224), subgroup_id) * ballot(); - * } else { + * else * ballot() - * } * ballot(); - * } else { + * else * ballot(); - * } - * } * * The first two if statements are uniform for subgroup sizes 64 or less. * The third if statement is non-uniform for all subgroup sizes. @@ -2463,13 +2454,11 @@ ${this.functions[i]}`; /** * Equivalent to: * - * for (var i = 0; ; i++, ballot()) { + * for (var i = 0; ; i++, ballot()) * ballot(); - * if (subgroupElect()) { + * if (subgroupElect()) * ballot(); * break; - * } - * } * ballot(); * * @param beginType The loop type @@ -2501,13 +2490,11 @@ ${this.functions[i]}`; /** * Equivalent to: * - * for (var i = 0; i < subgroup_invocation_id + 1; i++) { + * for (var i = 0; i < subgroup_invocation_id + 1; i++) * ballot(); - * } * ballot(); - * for (var i = 0; i < subgroup_invocation_id + 1; i++) { + * for (var i = 0; i < subgroup_invocation_id + 1; i++) * ballot(); - * } * ballot(); */ public predefinedProgramForVar() { @@ -2533,32 +2520,25 @@ ${this.functions[i]}`; /** * Equivalent to: * - * fn f0() { - * for (var i = 0; i < inputs[3]; i++) { + * fn f0() + * for (var i = 0; i < inputs[3]; i++) * f1(i); * ballot(); - * } * ballot(); - * if (inputs[3] == 3) { + * if (inputs[3] == 3) * f2(); * ballot(); - * } * ballot() - * } - * fn f1(i : u32) { + * fn f1(i : u32) * ballot(); - * if (subgroup_invocation_id == i) { + * if (subgroup_invocation_id == i) * ballot(); * return; - * } - * } - * fn f2() { + * fn f2() * ballot(); - * if (testBit(vec4u(0xaaaaaaaa,0xaaaaaaaa,0xaaaaaaaa,0xaaaaaaaa), local_invocation_index)) { + * if (testBit(vec4u(0xaaaaaaaa,0xaaaaaaaa,0xaaaaaaaa,0xaaaaaaaa), local_invocation_index)) * ballot(); * return; - * } - * } */ public predefinedProgramCall() { this.masks[4 + 0] = 0xaaaaaaaa; @@ -2609,12 +2589,11 @@ ${this.functions[i]}`; * Equivalent to: * * ballot() - * switch (inputs[5]) { - * default { } - * case 6 { ballot(); } - * case 5 { ballot(); } - * case 7 { ballot(); } - * } + * switch (inputs[5]) + * default + * case 6 ballot(); + * case 5 ballot(); + * case 7 ballot(); * ballot(); * */ @@ -2644,13 +2623,12 @@ ${this.functions[i]}`; * Equivalent to: * * ballot(); - * switch subgroup_invocation_id & 3 { - * default { } - * case 0: { ballot(); } - * case 1: { ballot(); } - * case 2: { ballot(); } - * case 3: { ballot(); } - * } + * switch subgroup_invocation_id & 3 + * default + * case 0: ballot(); + * case 1: ballot(); + * case 2: ballot(); + * case 3: ballot(); * ballot(); */ public predefinedProgramSwitchVar() { @@ -2681,19 +2659,15 @@ ${this.functions[i]}`; /** * Equivalent to: * - * for (var i0 = 0u; i0 < inputs[3]; i0++) { - * for (var i1 = 0u; i1 < inputs[3]; i1++) { - * for (var i2 = 0u; i2 < subgroup_invocation_id + 1; i2++) { + * for (var i0 = 0u; i0 < inputs[3]; i0++) + * for (var i1 = 0u; i1 < inputs[3]; i1++) + * for (var i2 = 0u; i2 < subgroup_invocation_id + 1; i2++) * ballot(); - * switch i_loop { - * case 1 { ballot(); } - * case 2 { ballot(); } - * default { ballot(); } - * } + * switch i_loop + * case 1 ballot(); + * case 2 ballot(); + * default ballot(); * ballot(); - * } - * } - * } */ public predefinedProgramSwitchLoopCount(loop: number) { this.ops.push(new Op(OpType.ForUniform, 1)); @@ -2731,11 +2705,10 @@ ${this.functions[i]}`; /** * Equivalent to: * - * switch subgroup_invocation_id & 0x3 { - * default { } - * case 0,1 { ballot(); } - * case 2,3 { ballot(); } - * } + * switch subgroup_invocation_id & 0x3 + * default + * case 0,1 ballot(); + * case 2,3 ballot(); */ public predefinedProgramSwitchMulticase() { this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); @@ -2761,12 +2734,10 @@ ${this.functions[i]}`; * Equivalent to: * * ballot(); - * for (var i = 0; i < inputs[3]; i++) { + * for (var i = 0; i < inputs[3]; i++) * ballot(); - * if (subgroupElect()) { + * if (subgroupElect()) * continue; - * } - * } * ballot(); * * This case can distinguish between Workgroup and WGSLv1 reconvergence. From d7ea45d279c180f9d22f003cd32c8ad3dae4e8fe Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Thu, 7 Sep 2023 20:28:54 -0400 Subject: [PATCH 31/32] Fix a bug in the calculation of stack uniformity * The simulation incorrectly marked some continues as non-uniform * this leads to some ballots being marked with the special value * in WGSLv1, Workgroup, and Subgroup styles this would lead to fewer ballots being checked (so likely no change in results) * in Maximal though, this leads to some ballots being a wacky value and extra failures --- src/webgpu/shader/execution/reconvergence/util.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/webgpu/shader/execution/reconvergence/util.ts b/src/webgpu/shader/execution/reconvergence/util.ts index e4666e2a2d7a..69a4392a5721 100644 --- a/src/webgpu/shader/execution/reconvergence/util.ts +++ b/src/webgpu/shader/execution/reconvergence/util.ts @@ -1966,7 +1966,7 @@ ${this.functions[i]}`; break; } - const uniform = this.style === Style.WGSLv1 && this.isUniform(mask, subgroupSize); + const uniform = this.style !== Style.WGSLv1 || this.isUniform(mask, subgroupSize); let n = nesting; for (; n >= 0; n--) { From 80726955389a2bd50fae9e070ad6686a89784a32 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Mon, 11 Sep 2023 15:36:29 -0400 Subject: [PATCH 32/32] Add a new test suite * Added a new test suite 'uniform_maximal' that tests that ballots all work as expected when no divergent branches exist in the code * The generator has a mode to only generate uniform conditions * removes several if, loop, and switch styles * restricts types of breaks and continues that generated * removes the generation of the election based noise operation * Adds a predefined test to cover some operations --- .../reconvergence/reconvergence.spec.ts | 28 ++- .../shader/execution/reconvergence/util.ts | 196 +++++++++++++++++- 2 files changed, 212 insertions(+), 12 deletions(-) diff --git a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts index 1b91fac9d00b..65dc6b52ea12 100644 --- a/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts +++ b/src/webgpu/shader/execution/reconvergence/reconvergence.spec.ts @@ -113,7 +113,7 @@ function checkIds(data: Uint32Array, subgroupSize: number): Error | undefined { * * So setting kDebugLevel to 0x5 would dump WGSL and end the test. */ -const kDebugLevel = 0x0; +const kDebugLevel = 0x00; async function testProgram(t: GPUTest, program: Program) { const wgsl = program.genCode(); @@ -159,7 +159,7 @@ async function testProgram(t: GPUTest, program: Program) { // Inputs have a value equal to their index. const inputBuffer = t.makeBufferWithContents( - new Uint32Array([...iterRange(128, x => x)]), + new Uint32Array([...iterRange(129, x => x)]), GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST ); t.trackForCleanup(inputBuffer); @@ -368,6 +368,10 @@ async function predefinedTest(t: GPUTest, style: Style, test: number) { program.predefinedProgramWGSLv1(); break; } + case 15: { + program.predefinedProgramAllUniform(); + break; + } default: { unreachable('Unhandled testcase'); } @@ -376,7 +380,7 @@ async function predefinedTest(t: GPUTest, style: Style, test: number) { await testProgram(t, program); } -const kPredefinedTestCases = [...iterRange(15, x => x)]; +const kPredefinedTestCases = [...iterRange(16, x => x)]; g.test('predefined_workgroup') .desc(`Test workgroup reconvergence using some predefined programs`) @@ -495,3 +499,21 @@ g.test('random_wgslv1') await testProgram(t, program); }); + +g.test('uniform_maximal') + .desc(`Test workgroup reconvergence with only uniform branches`) + .params(u => u.combine('seed', generateSeeds(500)).beginSubcases()) + .beforeAllSubcases(t => { + t.selectDeviceOrSkipTestCase({ + requiredFeatures: ['chromium-experimental-subgroups' as GPUFeatureName], + }); + }) + .fn(async t => { + const invocations = kNumInvocations; // t.device.limits.maxSubgroupSize; + + const onlyUniform: boolean = true; + const program: Program = new Program(Style.Maximal, t.params.seed, invocations, onlyUniform); + program.generate(); + + await testProgram(t, program); + }); diff --git a/src/webgpu/shader/execution/reconvergence/util.ts b/src/webgpu/shader/execution/reconvergence/util.ts index 69a4392a5721..0f4ff8a221a5 100644 --- a/src/webgpu/shader/execution/reconvergence/util.ts +++ b/src/webgpu/shader/execution/reconvergence/util.ts @@ -330,6 +330,8 @@ export class Program { // Indicates if the program satisfies uniform control flow for |style| // This depends on simulating a particular subgroup size public ucf: boolean; + // Indicates that only uniform branches should be generated. + private onlyUniform: boolean; /** * constructor @@ -337,7 +339,12 @@ export class Program { * @param style Enum indicating the type of reconvergence being tested * @param seed Value used to seed the PRNG */ - constructor(style: Style = Style.Workgroup, seed: number = 1, invocations: number) { + constructor( + style: Style = Style.Workgroup, + seed: number = 1, + invocations: number, + onlyUniform: boolean = false + ) { this.invocations = invocations; assert(invocations <= 128); this.prng = new PRNG(seed); @@ -378,6 +385,7 @@ export class Program { this.maxProgramNesting = 10; // default stack allocation this.maxLocations = 130000; // keep the buffer under 256MiB this.ucf = false; + this.onlyUniform = onlyUniform; } /** @returns A random float between 0 and 1 */ @@ -390,13 +398,31 @@ export class Program { return this.prng.randomU32() % max; } + /** + * Pick |count| random instructions + * + * @param count The number of instructions + * + * If |this.onlyUniform| is true then only uniform instructions will be + * selected. + * + */ + private pickOp(count: number) { + if (this.onlyUniform) { + this.pickUniformOp(count); + } else { + this.pickAnyOp(count); + } + } + /** * Pick |count| random instructions generators * * @param count the number of instructions * + * These instructions could be uniform or non-uniform. */ - private pickOp(count: number) { + private pickAnyOp(count: number) { for (let i = 0; i < count; i++) { if (this.ops.length >= this.maxCount) { return; @@ -527,6 +553,97 @@ export class Program { } } + /** + * Pick |count| random uniform instructions generators + * + * @param count the number of instructions + * + */ + private pickUniformOp(count: number) { + for (let i = 0; i < count; i++) { + if (this.ops.length >= this.maxCount) { + return; + } + + this.genBallot(); + if (this.nesting < this.maxNesting) { + const r = this.getRandomUint(10); + switch (r) { + case 0: + case 1: { + this.genIf(IfType.Lid); + break; + } + case 2: + case 3: { + this.genIf(IfType.Uniform); + break; + } + case 4: { + // Avoid very deep loop nests to limit memory and runtime. + if (this.loopNesting < this.maxLoopNesting) { + this.genForUniform(); + } + break; + } + case 5: { + this.genBreak(); + break; + } + case 6: { + this.genContinue(); + break; + } + case 7: { + // Calls and returns. + if ( + this.getRandomFloat() < 0.2 && + this.callNesting === 0 && + this.nesting < this.maxNesting - 1 + ) { + this.genCall(); + } else { + this.genReturn(); + } + break; + } + case 8: { + if (this.loopNesting < this.maxLoopNesting) { + this.genLoopUniform(); + } + break; + } + case 9: { + // crbug.com/tint/2039 + // Tint generates invalid code for switch inside loops. + if (this.loopNestingThisFunction > 0) { + break; + } + const r2 = this.getRandomUint(2); + switch (r2) { + case 1: { + if (this.loopNesting > 0) { + this.genSwitchLoopCount(); + break; + } + // fallthrough + } + default: { + this.genSwitchUniform(); + break; + } + } + break; + } + default: { + break; + } + } + } + this.genBallot(); + } + } + /** * Ballot generation * @@ -572,7 +689,7 @@ export class Program { } const r = this.getRandomUint(10000); - if (r < 3) { + if (r < 3 && !this.onlyUniform) { this.ops.push(new Op(OpType.Noise, 0)); } else if (r < 10) { this.ops.push(new Op(OpType.Noise, 1)); @@ -590,7 +707,7 @@ export class Program { let maskIdx = this.getRandomUint(this.numMasks); if (type === IfType.Uniform) maskIdx = 0; - const lid = this.getRandomUint(this.invocations); + const lid = this.onlyUniform ? this.invocations : this.getRandomUint(this.invocations); if (type === IfType.Lid) { this.ops.push(new Op(OpType.IfId, lid)); } else if (type === IfType.LoopCount) { @@ -820,13 +937,14 @@ export class Program { * Generate a break if in a loop. * * Only generates a break within a loop, but may break out of a switch and - * not just a loop. Sometimes the break uses a non-uniform if/else to break. + * not just a loop. Sometimes the break uses a non-uniform if/else to break + * (unless only uniform branches are specified). * */ private genBreak() { if (this.loopNestingThisFunction > 0) { // Sometimes put the break in a divergent if - if (this.getRandomFloat() < 0.1) { + if (this.getRandomFloat() < 0.1 && !this.onlyUniform) { const r = this.getRandomUint(this.numMasks - 1) + 1; this.ops.push(new Op(OpType.IfMask, r)); this.ops.push(new Op(OpType.Break, 0)); @@ -843,12 +961,13 @@ export class Program { /** * Generate a continue if in a loop * - * Sometimes uses a non-uniform if/else to continue. + * Sometimes uses a non-uniform if/else to continue (unless only uniform + * branches are specified). */ private genContinue() { if (this.loopNestingThisFunction > 0 && !this.isLoopInf.get(this.loopNesting)) { // Sometimes put the continue in a divergent if - if (this.getRandomFloat() < 0.1) { + if (this.getRandomFloat() < 0.1 && !this.onlyUniform) { const r = this.getRandomUint(this.numMasks - 1) + 1; this.ops.push(new Op(OpType.IfMask, r)); this.ops.push(new Op(OpType.Continue, 0)); @@ -896,7 +1015,7 @@ export class Program { (this.callNesting > 0 && this.loopNestingThisFunction > 1 && r < 0.5)) ) { this.genBallot(); - if (this.getRandomFloat() < 0.1) { + if (this.getRandomFloat() < 0.1 && !this.onlyUniform) { this.ops.push(new Op(OpType.IfMask, 0)); this.ops.push(new Op(OpType.Return, this.callNesting)); this.ops.push(new Op(OpType.ElseMask, 0)); @@ -2760,6 +2879,65 @@ ${this.functions[i]}`; this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); this.ops.push(new Op(OpType.Ballot, 0)); } + + /** + * Equivalent to: + * + * for (var i0 = 0u; i0 < inputs[3]; i0++) + * ballot(); + * if subgroup_invocation_id < inputs[128] + * ballot(); + * if subgroup_invocation_id < inputs[128] + * ballot(); + * if subgroup_invocation_id < inputs[128] + * for (var i1 = 0u; i1 < inputs[3]; i1++) + * if subgroup_invocation_id < inputs[128] + * ballot(); + * break; + * if inputs[3] == 3 + * ballot(); + * ballot(); + * + */ + public predefinedProgramAllUniform() { + this.ops.push(new Op(OpType.ForUniform, 3)); // for 0 + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + + this.ops.push(new Op(OpType.IfId, 128)); // if 0 + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + + this.ops.push(new Op(OpType.IfId, 128)); // if 1 + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + + this.ops.push(new Op(OpType.IfId, 128)); // if 2 + this.ops.push(new Op(OpType.ForUniform, 3)); // for 1 + this.ops.push(new Op(OpType.IfId, 128)); // if 3 + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.Break, 0)); + this.ops.push(new Op(OpType.EndIf, 0)); // end if 3 + + this.ops.push(new Op(OpType.IfMask, 0)); // if 4 + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + this.ops.push(new Op(OpType.EndIf, 0)); // end if 4 + + this.ops.push(new Op(OpType.EndForUniform, 0)); // end for 1 + + this.ops.push(new Op(OpType.ElseId, 128)); // else if 2 + this.ops.push(new Op(OpType.EndIf, 0)); // end if 2 + this.ops.push(new Op(OpType.Store, this.storeBase + this.ops.length)); + this.ops.push(new Op(OpType.Ballot, 0)); + + this.ops.push(new Op(OpType.EndIf, 0)); // end if 1 + + this.ops.push(new Op(OpType.EndIf, 0)); // end if 0 + + this.ops.push(new Op(OpType.EndForUniform, 0)); // end for 0 + } } export function generateSeeds(numCases: number): number[] {