From 1fccd44424fad02d1eefd0c7a935205164ce4450 Mon Sep 17 00:00:00 2001 From: "Stefan J. Wernli" Date: Tue, 6 Feb 2024 10:17:21 -0800 Subject: [PATCH] Fix state ordering in Python (#1122) This change uses the vector returned from the internals directly rather than converting into a hashmap so that the state ordering can be preserved for display. Fixes #1119 Before: ![image](https://github.com/microsoft/qsharp/assets/10567287/4ff3d4d1-021b-4b27-b797-266312cc13cc) After: ![image](https://github.com/microsoft/qsharp/assets/10567287/eff67e4e-a756-45e3-b246-16829f9befcf) --- pip/src/displayable_output.rs | 3 +-- pip/src/displayable_output/tests.rs | 29 +++++++++++++++++++---------- pip/src/interpreter.rs | 15 ++++++++------- 3 files changed, 28 insertions(+), 19 deletions(-) diff --git a/pip/src/displayable_output.rs b/pip/src/displayable_output.rs index fceb8f549e..9d5fa103c9 100644 --- a/pip/src/displayable_output.rs +++ b/pip/src/displayable_output.rs @@ -7,11 +7,10 @@ mod tests; use num_bigint::BigUint; use num_complex::{Complex64, ComplexFloat}; use qsc::{fmt_basis_state_label, fmt_complex, format_state_id, get_phase}; -use rustc_hash::FxHashMap; use std::fmt::Write; #[derive(Clone)] -pub struct DisplayableState(pub FxHashMap, pub usize); +pub struct DisplayableState(pub Vec<(BigUint, Complex64)>, pub usize); impl DisplayableState { pub fn to_plain(&self) -> String { diff --git a/pip/src/displayable_output/tests.rs b/pip/src/displayable_output/tests.rs index 2aa54cb59d..1a9550e703 100644 --- a/pip/src/displayable_output/tests.rs +++ b/pip/src/displayable_output/tests.rs @@ -3,18 +3,12 @@ use num_bigint::BigUint; use num_complex::Complex; -use rustc_hash::FxHashMap; use crate::displayable_output::DisplayableState; #[test] fn display_neg_zero() { - let s = DisplayableState( - vec![(BigUint::default(), Complex::new(-0.0, -0.0))] - .into_iter() - .collect::>(), - 1, - ); + let s = DisplayableState(vec![(BigUint::default(), Complex::new(-0.0, -0.0))], 1); // -0 should be displayed as 0.0000 without a minus sign assert_eq!("STATE:\n|0⟩: 0.0000+0.0000𝑖", s.to_plain()); } @@ -22,11 +16,26 @@ fn display_neg_zero() { #[test] fn display_rounds_to_neg_zero() { let s = DisplayableState( - vec![(BigUint::default(), Complex::new(-0.00001, -0.00001))] - .into_iter() - .collect::>(), + vec![(BigUint::default(), Complex::new(-0.00001, -0.00001))], 1, ); // -0.00001 should be displayed as 0.0000 without a minus sign assert_eq!("STATE:\n|0⟩: 0.0000+0.0000𝑖", s.to_plain()); } + +#[test] +fn display_preserves_order() { + let s = DisplayableState( + vec![ + (BigUint::from(0_u64), Complex::new(0.0, 0.0)), + (BigUint::from(1_u64), Complex::new(0.0, 1.0)), + (BigUint::from(2_u64), Complex::new(1.0, 0.0)), + (BigUint::from(3_u64), Complex::new(1.0, 1.0)), + ], + 2, + ); + assert_eq!( + "STATE:\n|00⟩: 0.0000+0.0000𝑖\n|01⟩: 0.0000+1.0000𝑖\n|10⟩: 1.0000+0.0000𝑖\n|11⟩: 1.0000+1.0000𝑖", + s.to_plain() + ); +} diff --git a/pip/src/interpreter.rs b/pip/src/interpreter.rs index 360b6cf530..bd34b4692c 100644 --- a/pip/src/interpreter.rs +++ b/pip/src/interpreter.rs @@ -28,7 +28,6 @@ use qsc::{ PackageType, SourceMap, }; use resource_estimator::{self as re, estimate_expr}; -use rustc_hash::FxHashMap; use std::fmt::Write; #[pymodule] @@ -168,10 +167,7 @@ impl Interpreter { /// pairs of real and imaginary amplitudes. fn dump_machine(&mut self) -> StateDump { let (state, qubit_count) = self.interpreter.get_quantum_state(); - StateDump(DisplayableState( - state.into_iter().collect::>(), - qubit_count, - )) + StateDump(DisplayableState(state, qubit_count)) } fn run( @@ -336,7 +332,13 @@ impl StateDump { // Pass by value is needed for compatiblity with the pyo3 API. #[allow(clippy::needless_pass_by_value)] fn __getitem__(&self, key: BigUint) -> Option<(f64, f64)> { - self.0 .0.get(&key).map(|state| (state.re, state.im)) + self.0 .0.iter().find_map(|state| { + if state.0 == key { + Some((state.1.re, state.1.im)) + } else { + None + } + }) } fn __len__(&self) -> usize { @@ -459,7 +461,6 @@ impl Receiver for OptionalCallbackReceiver<'_> { qubit_count: usize, ) -> core::result::Result<(), Error> { if let Some(callback) = &self.callback { - let state = state.into_iter().collect::>(); let out = DisplayableOutput::State(DisplayableState(state, qubit_count)); callback .call1(