From 0afe203cf0493c548831379df66fa42220079960 Mon Sep 17 00:00:00 2001 From: Craig Gidney Date: Fri, 18 Aug 2023 18:54:21 -0700 Subject: [PATCH] Fix crumble MPP fusing/splitting bugs (#610) --- glue/crumble/circuit/circuit.js | 52 ++++++++----- glue/crumble/circuit/circuit.test.js | 31 +++++++- glue/crumble/circuit/pauli_frame.js | 2 +- glue/crumble/gates/gateset_mpp.js | 2 +- glue/crumble/run_tests_headless.js | 2 +- glue/crumble/test/test_main.js | 25 ++++++- glue/crumble/test/test_util.js | 31 +++++--- src/stim/diagram/crumble_data.cc | 106 +++++++++++++-------------- 8 files changed, 161 insertions(+), 90 deletions(-) diff --git a/glue/crumble/circuit/circuit.js b/glue/crumble/circuit/circuit.js index 79e7acff3..61d1f0f87 100644 --- a/glue/crumble/circuit/circuit.js +++ b/glue/crumble/circuit/circuit.js @@ -94,7 +94,11 @@ function simplifiedMPP(args, combinedTargets) { for (let t of combinedTargets) { if (t[0] === 'X' || t[0] === 'Y' || t[0] === 'Z') { bases += t[0]; - qubits.push(parseInt(t.substring(1))); + let v = parseInt(t.substring(1)); + if (v !== v) { + throw Error(`Non-Pauli target given to MPP: ${combinedTargets}`); + } + qubits.push(v); } else { throw Error(`Non-Pauli target given to MPP: ${combinedTargets}`); } @@ -148,17 +152,6 @@ class Circuit { let layers = [new Layer()]; let i2q = new Map(); let used_positions = new Set(); - let next_auto_position_x = 0; - let ensure_has_coords = t => { - while (!i2q.has(t)) { - let k = `${next_auto_position_x},0`; - if (!used_positions.has(k)) { - used_positions.add(k); - i2q.set(t, [next_auto_position_x, 0]); - } - next_auto_position_x++; - } - }; let findEndOfBlock = (lines, startIndex, endIndex) => { let nestLevel = 0; @@ -305,7 +298,6 @@ class Circuit { if (typeof parseInt(targ) !== 'number') { throw new Error(line); } - ensure_has_coords(t); } if (ignored) { console.warn("IGNORED", name); @@ -348,6 +340,29 @@ class Circuit { layers.pop(); } + let next_auto_position_x = 0; + let ensure_has_coords = (t) => { + let b = true; + while (!i2q.has(t)) { + let x = b ? t : next_auto_position_x; + let k = `${x},0`; + if (!used_positions.has(k)) { + used_positions.add(k); + i2q.set(t, [x, 0]); + } + next_auto_position_x += !b; + b = false; + } + }; + + for (let layer of layers) { + for (let op of layer.iter_gates_and_markers()) { + for (let t of op.id_targets) { + ensure_has_coords(t); + } + } + } + let numQubits = Math.max(...i2q.keys(), 0) + 1; let qubitCoords = new Float64Array(numQubits*2); for (let q = 0; q < numQubits; q++) { @@ -503,6 +518,9 @@ class Circuit { for (let layer of this.layers) { let opsByName = groupBy(layer.iter_gates_and_markers(), op => { let key = op.gate.name; + if (key.startsWith('M') && !GATE_MAP.has(key)) { + key = 'MPP'; + } if (op.args.length > 0) { key += '(' + [...op.args].join(',') + ')'; } @@ -525,18 +543,18 @@ class Circuit { let gateName = nameWithArgs.split('(')[0]; let gate = GATE_MAP.get(gateName); - if (gate === undefined && gateName.startsWith('M')) { + if (gate === undefined && gateName === 'MPP') { let line = ['MPP ']; for (let op of group) { for (let k = 0; k < op.id_targets.length; k++) { - line.push(gateName[k + 1] + old2new.get(op.id_targets[k])); + line.push(op.gate.name[k + 1] + old2new.get(op.id_targets[k])); line.push('*'); } line.pop(); + line.push(' '); } - out.push(line.join('')); + out.push(line.join('').trim()); } else { - if (gate !== undefined && gate.can_fuse) { let flatTargetGroups = []; for (let op of group) { diff --git a/glue/crumble/circuit/circuit.test.js b/glue/crumble/circuit/circuit.test.js index 8a396429d..744758313 100644 --- a/glue/crumble/circuit/circuit.test.js +++ b/glue/crumble/circuit/circuit.test.js @@ -470,9 +470,9 @@ test("circuit.inferAndConvertCoordinates", () => { H 0 3 2 1 `)).isEqualTo(Circuit.fromStimCircuit(` QUBIT_COORDS(0, 0) 0 - QUBIT_COORDS(3, 0) 1 + QUBIT_COORDS(1, 0) 1 QUBIT_COORDS(2, 0) 2 - QUBIT_COORDS(1, 0) 3 + QUBIT_COORDS(3, 0) 3 H 0 3 2 1 `)); @@ -504,3 +504,30 @@ test("circuit.inferAndConvertCoordinates", () => { H 0 3 2 1 `)); }); + +test("circuit.parse_mpp", () => { + assertThat(Circuit.fromStimCircuit(` + MPP Z0*Z1*Z2 Z3*Z4*Z6 + `).toString()).isEqualTo(` +QUBIT_COORDS(0, 0) 0 +QUBIT_COORDS(1, 0) 1 +QUBIT_COORDS(2, 0) 2 +QUBIT_COORDS(3, 0) 3 +QUBIT_COORDS(4, 0) 4 +QUBIT_COORDS(6, 0) 5 +MPP Z0*Z1*Z2 Z3*Z4*Z5 + `.trim()) + + assertThat(Circuit.fromStimCircuit(` + MPP Z0*Z1*Z2 Z3*Z4*Z5*X6 + `).toString()).isEqualTo(` +QUBIT_COORDS(0, 0) 0 +QUBIT_COORDS(1, 0) 1 +QUBIT_COORDS(2, 0) 2 +QUBIT_COORDS(3, 0) 3 +QUBIT_COORDS(4, 0) 4 +QUBIT_COORDS(5, 0) 5 +QUBIT_COORDS(6, 0) 6 +MPP Z0*Z1*Z2 Z3*Z4*Z5*X6 + `.trim()) +}); diff --git a/glue/crumble/circuit/pauli_frame.js b/glue/crumble/circuit/pauli_frame.js index 5e69d892e..4c3bda6ff 100644 --- a/glue/crumble/circuit/pauli_frame.js +++ b/glue/crumble/circuit/pauli_frame.js @@ -231,7 +231,7 @@ class PauliFrame { /** * @param {!string} bases - * @param {!Uint32Array} targets + * @param {!Uint32Array|!Array.} targets */ do_mpp(bases, targets) { let anticommutes = 0; diff --git a/glue/crumble/gates/gateset_mpp.js b/glue/crumble/gates/gateset_mpp.js index bfa724d91..15486e51b 100644 --- a/glue/crumble/gates/gateset_mpp.js +++ b/glue/crumble/gates/gateset_mpp.js @@ -10,7 +10,7 @@ function make_mpp_gate(bases) { return new Gate( 'M' + bases, bases.length, - false, + true, false, undefined, (frame, targets) => frame.do_mpp(bases, targets), diff --git a/glue/crumble/run_tests_headless.js b/glue/crumble/run_tests_headless.js index f7fc4d3d3..8c08a09e6 100644 --- a/glue/crumble/run_tests_headless.js +++ b/glue/crumble/run_tests_headless.js @@ -1,7 +1,7 @@ import {run_tests} from "./test/test_util.js" import "./test/test_import_all.js" -let total = await run_tests(() => {}); +let total = await run_tests(() => {}, _name => true); if (!total.passed) { throw new Error("Some tests failed"); } diff --git a/glue/crumble/test/test_main.js b/glue/crumble/test/test_main.js index 141880e76..9a9e813e7 100644 --- a/glue/crumble/test/test_main.js +++ b/glue/crumble/test/test_main.js @@ -6,7 +6,20 @@ await imported.catch(() => {}); let status = /** @type {!HTMLDivElement} */ document.getElementById("status"); let acc = /** @type {!HTMLDivElement} */ document.getElementById("acc"); + +let testFilter = undefined; +let hash = document.location.hash; +if (hash.startsWith('#')) { + hash = hash.substring(1); +} +if (hash.startsWith('test=')) { + hash = hash.substring(5); + testFilter = hash; + console.log(`Only running '${testFilter}'`) +} + status.textContent = "Running tests..."; + let total = await run_tests(progress => { status.textContent = `${progress.num_tests_left - progress.num_tests}/${progress.num_tests} ${progress.name} ${progress.passed ? 'passed' : 'failed'} (${progress.num_tests_failed} failed)`; if (!progress.passed) { @@ -24,11 +37,15 @@ let total = await run_tests(progress => { d.textContent = `Test '${progress.name}' skipped`; acc.appendChild(d); } -}); +}, name => testFilter === undefined || name === testFilter); if (!total.passed) { - status.textContent = `${total.num_tests_failed} tests failed out of ${total.num_tests}.`; -} else if (total.skipped) { - status.textContent = `All ${total.num_tests} tests passed (some were skipped).`; + if (total.num_skipped > 0) { + status.textContent = `${total.num_tests_failed} tests failed out of ${total.num_tests - total.num_skipped} (${total.num_skipped} skipped).`; + } else { + status.textContent = `${total.num_tests_failed} tests failed out of ${total.num_tests}.`; + } +} else if (total.num_skipped > 0) { + status.textContent = `All ${total.num_tests} tests passed (${total.num_skipped} skipped).`; } else { status.textContent = `All ${total.num_tests} tests passed.`; } diff --git a/glue/crumble/test/test_util.js b/glue/crumble/test/test_util.js index 874aca22d..350ce9a62 100644 --- a/glue/crumble/test/test_util.js +++ b/glue/crumble/test/test_util.js @@ -383,29 +383,34 @@ class TestProgress { * @param {!int} num_tests * @param {!int} num_tests_failed * @param {!int} num_tests_left - * @param {!boolean} skipped + * @param {!int} num_skipped */ - constructor(name, passed, error, num_tests, num_tests_failed, num_tests_left, skipped) { + constructor(name, passed, error, num_tests, num_tests_failed, num_tests_left, num_skipped) { this.name = name; this.passed = passed; this.error = error; this.num_tests = num_tests; this.num_tests_failed = num_tests_failed; this.num_tests_left = num_tests_left; - this.skipped = skipped; + this.num_skipped = num_skipped; } } /** * @param {!function(progress: !TestProgress)} callback + * @param {!function(name: !string): !boolean} name_filter */ -async function run_tests(callback) { +async function run_tests(callback, name_filter) { let num_tests = _tests.length; let num_tests_left = _tests.length; let num_tests_failed = 0; let all_passed = true; - let any_skipped = false; + let num_skipped = 0; for (let test of _tests) { + if (!name_filter(test.name)) { + num_skipped += 1; + continue; + } console.log("run test", test.name); _usedAssertIndices = 0; let name = test.name; @@ -425,7 +430,7 @@ async function run_tests(callback) { if (ex instanceof Error && ex.message === "skipRestOfTestIfHeadless:document === undefined") { console.warn(`skipped part of test '${test.name}' because tests are running headless`); skipped = true; - any_skipped = true; + num_skipped += 1; passed = true; } else { error = ex; @@ -435,14 +440,18 @@ async function run_tests(callback) { } } num_tests_left--; - callback(new TestProgress(name, passed, error, num_tests, num_tests_failed, num_tests_left, skipped)); + callback(new TestProgress(name, passed, error, num_tests, num_tests_failed, num_tests_left, num_skipped)); } - if (all_passed) { - console.log("all tests passed"); + + let msg = `done running tests: ${num_tests_failed} failed, ${num_skipped} skipped`; + if (num_tests_failed > 0) { + console.error(msg); + } else if (num_skipped > 0) { + console.warn(msg); } else { - console.error("all tests run, some tests failed"); + console.log(msg); } - return new TestProgress('', all_passed, undefined, num_tests, num_tests_failed, 0, any_skipped); + return new TestProgress('', all_passed, undefined, num_tests, num_tests_failed, 0, num_skipped); } export {test, run_tests, assertThat, skipRestOfTestIfHeadless} diff --git a/src/stim/diagram/crumble_data.cc b/src/stim/diagram/crumble_data.cc index 9ac764ec9..90ad18a16 100644 --- a/src/stim/diagram/crumble_data.cc +++ b/src/stim/diagram/crumble_data.cc @@ -138,37 +138,37 @@ std::string stim_draw_internal::make_crumble_html() { )CRUMBLE_PART"); result.append(R"CRUMBLE_PART(