From 14f114a0218ae3efcc405ba2f63e9dddcfdb1fc7 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Mon, 11 May 2020 11:38:19 +0200 Subject: [PATCH 01/11] Add central function for duration comparability test --- qupulse/_program/waveforms.py | 6 +-- .../pulses/multi_channel_pulse_template.py | 23 +++++++--- qupulse/utils/numeric.py | 43 ++++++++++++++++++- 3 files changed, 60 insertions(+), 12 deletions(-) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index fcfed0810..6f633e00b 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -15,6 +15,7 @@ from qupulse import ChannelID from qupulse.utils import checked_int_cast, isclose from qupulse.utils.types import TimeType, time_from_float +from qupulse.utils.numeric import are_durations_compatible from qupulse.comparable import Comparable from qupulse.expressions import ExpressionScalar from qupulse.pulses.interpolation import InterpolationStrategy @@ -409,10 +410,9 @@ def get_sub_waveform_sort_key(waveform): waveform.defined_channels & self.__defined_channels) self.__defined_channels |= waveform.defined_channels - if not all(isclose(waveform.duration, self._sub_waveforms[0].duration) for waveform in self._sub_waveforms[1:]): - # meaningful error message: + durations = list(subwaveform.duration for subwaveform in self._sub_waveforms) + if not are_durations_compatible(*durations): durations = {} - for waveform in self._sub_waveforms: for duration, channels in durations.items(): if isclose(waveform.duration, duration): diff --git a/qupulse/pulses/multi_channel_pulse_template.py b/qupulse/pulses/multi_channel_pulse_template.py index 5b8c2a50e..ba7ed7a49 100644 --- a/qupulse/pulses/multi_channel_pulse_template.py +++ b/qupulse/pulses/multi_channel_pulse_template.py @@ -17,6 +17,7 @@ from qupulse.utils import isclose from qupulse.utils.sympy import almost_equal, Sympifyable from qupulse.utils.types import ChannelID, TimeType +from qupulse.utils.numeric import are_durations_compatible from qupulse._program.waveforms import MultiChannelWaveform, Waveform, TransformingWaveform from qupulse._program.transformation import ParallelConstantChannelTransformation, Transformation, chain_transformations from qupulse.pulses.pulse_template import PulseTemplate, AtomicPulseTemplate @@ -87,13 +88,21 @@ def __init__(self, category=DeprecationWarning) if not duration: - duration = self._subtemplates[0].duration - for subtemplate in self._subtemplates[1:]: - if almost_equal(duration.sympified_expression, subtemplate.duration.sympified_expression): - continue - else: - raise ValueError('Could not assert duration equality of {} and {}'.format(duration, - subtemplate.duration)) + durations = list(subtemplate.duration for subtemplate in subtemplates) + are_compatible = are_durations_compatible(*durations) + + if are_compatible is False: + # durations definitely not compatible + raise ValueError('Could not assert duration equality of {} and {}'.format(repr(duration), + repr(subtemplate.duration))) + elif are_compatible is None: + # cannot assert compatibility + raise ValueError('Could not assert duration equality of {} and {}'.format(repr(duration), + repr(subtemplate.duration))) + + else: + assert are_compatible is True + self._duration = None elif duration is True: self._duration = None diff --git a/qupulse/utils/numeric.py b/qupulse/utils/numeric.py index 53c640bbf..caf8431d8 100644 --- a/qupulse/utils/numeric.py +++ b/qupulse/utils/numeric.py @@ -1,5 +1,5 @@ -from typing import Tuple, Type -from numbers import Rational +from typing import Tuple, Type, Optional +from numbers import Rational, Real from math import gcd @@ -98,3 +98,42 @@ def approximate_rational(x: Rational, abs_err: Rational, fraction_type: Type[Rat def approximate_double(x: float, abs_err: float, fraction_type: Type[Rational]) -> Rational: """Return the fraction with the smallest denominator in (x - abs_err, x + abs_err).""" return approximate_rational(fraction_type(x), fraction_type(abs_err), fraction_type=fraction_type) + + +def are_durations_compatible(first_duration: Real, *other_durations: Real, + max_abs_spread=1e-10, max_rel_spread=1e-10) -> Optional[bool]: + """Durations and maximum allowed spreads must be positive. + + For the durations to be considered compatible, the difference between them must be smaller than at least one of + the allowed spreads. + + Args: + first_duration: Singled out duration for performance reasons. Not handled differently by the algorithm. + *other_durations: Other durations to compare for compatibility + max_abs_spread: Maximum difference for being considered "compatible", regardless of the magnitude of the input + max_rel_spread: maximum difference for being considered "compatible", relative to the magnitude of the + maximum input duration + + Returns: + True or False if decidable else None + """ + min_duration = max_duration = first_duration + for duration in other_durations: + min_duration = min(min_duration, duration) + max_duration = max(max_duration, duration) + assert 0 < max_duration, "At least one duration must be positive" + # spread = max_duration - min_duration + # allowed_spread = max(max_rel_spread * max_duration, max_abs_spread) + are_compatible = max_duration - min_duration < max(max_rel_spread * max_duration, max_abs_spread) + if are_compatible in (False, True): + return are_compatible + + # durations are sympy expressions with clear ordering + elif are_compatible.is_Boolean: + return bool(are_compatible) + + else: + # Not decidable + return None + + From 02c3336da10d94f50f475a3bedcf13319bcdd0de Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Tue, 19 May 2020 09:50:37 +0200 Subject: [PATCH 02/11] Create central waveform allocation function that initializes with nan per default --- qupulse/_program/waveforms.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index 6f633e00b..9d9526a64 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -26,6 +26,12 @@ "MultiChannelWaveform", "RepetitionWaveform", "TransformingWaveform", "ArithmeticWaveform"] +def alloc_for_sample(size: int) -> np.ndarray: + """All "preallocation" happens via this function. It uses NaN by default to make incomplete initialization better + visible.""" + return np.full(shape=size, fill_value=np.nan) + + class Waveform(Comparable, metaclass=ABCMeta): """Represents an instantiated PulseTemplate which can be sampled to retrieve arrays of voltage values for the hardware.""" @@ -216,7 +222,7 @@ def unsafe_sample(self, sample_times: np.ndarray, output_array: Union[np.ndarray, None]=None) -> np.ndarray: if output_array is None: - output_array = np.empty_like(sample_times) + output_array = alloc_for_sample(sample_times.size) for entry1, entry2 in zip(self._table[:-1], self._table[1:]): indices = slice(np.searchsorted(sample_times, entry1.t, 'left'), @@ -273,7 +279,7 @@ def unsafe_sample(self, sample_times: np.ndarray, output_array: Union[np.ndarray, None] = None) -> np.ndarray: if output_array is None: - output_array = np.empty(len(sample_times)) + output_array = alloc_for_sample(sample_times.size) output_array[:] = self._expression.evaluate_numeric(t=sample_times) return output_array @@ -317,7 +323,7 @@ def unsafe_sample(self, sample_times: np.ndarray, output_array: Union[np.ndarray, None]=None) -> np.ndarray: if output_array is None: - output_array = np.empty_like(sample_times) + output_array = alloc_for_sample(sample_times.size) time = 0 for subwaveform in self._sequenced_waveforms: # before you change anything here, make sure to understand the difference between basic and advanced @@ -480,7 +486,7 @@ def unsafe_sample(self, sample_times: np.ndarray, output_array: Union[np.ndarray, None]=None) -> np.ndarray: if output_array is None: - output_array = np.empty_like(sample_times) + output_array = alloc_for_sample(sample_times.size) body_duration = self._body.duration time = 0 for _ in range(self._repetition_count): From be17b2a064c7a62282e14b20619152cabb6a2d3e Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Tue, 19 May 2020 09:57:21 +0200 Subject: [PATCH 03/11] Add padding functionality to MultiChannelWaveform --- qupulse/_program/waveforms.py | 218 +++++++++++++++++++----------- tests/_program/waveforms_tests.py | 75 +++++++--- 2 files changed, 196 insertions(+), 97 deletions(-) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index 9d9526a64..31911ae86 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -9,12 +9,13 @@ from weakref import WeakValueDictionary, ref from typing import Union, Set, Sequence, NamedTuple, Tuple, Any, Iterable, FrozenSet, Optional, Mapping, AbstractSet import operator +import collections import numpy as np from qupulse import ChannelID from qupulse.utils import checked_int_cast, isclose -from qupulse.utils.types import TimeType, time_from_float +from qupulse.utils.types import TimeType, FrozenDict from qupulse.utils.numeric import are_durations_compatible from qupulse.comparable import Comparable from qupulse.expressions import ExpressionScalar @@ -143,6 +144,11 @@ def __neg__(self): def __pos__(self): return self + def last_value(self, channel) -> float: + """Get the last value of the waveform""" + # TODO: Optimize this + return self.unsafe_sample(channel, np.array([float(self.duration)]))[0] + class TableWaveformEntry(NamedTuple('TableWaveformEntry', [('t', float), ('v', float), @@ -353,120 +359,178 @@ def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'W class MultiChannelWaveform(Waveform): """A MultiChannelWaveform is a Waveform object that allows combining arbitrary Waveform objects - to into a single waveform defined for several channels. + to into a single waveform defined for several channels. Most of the time you want to use + :py:meth:`MultiChannelWaveform.from_iterable` to construct a MultiChannelWaveform. + + Automatic padding and truncation: + The duration of the overall waveform is specified by the `duration` argument (None means maximum sub-waveform + duration). All channels that are not in `pad_values` need to be compatible with this duration (determined with + :func:`are_durations_compatible`). Channels that are in `pad_values` are truncated or padded with + the specified value to the required duration. A `None` value is replaced by the result of + `sub_waveform.last_sample(channel_id)`. + + Implementation detail: + Channels that have compatible durations are handled as if their pad_value entry is None. This is only relevant + in numeric corner cases to be always well behaved. + """ + + def __init__(self, + sub_waveforms: Mapping[ChannelID, Waveform], + pad_values: Mapping[ChannelID, Optional[float]], + duration: TimeType) -> None: + super().__init__() + assert sub_waveforms - The number of channels used by the MultiChannelWaveform object is the sum of the channels used - by the Waveform objects it consists of. + wf_pad_dict = {} + for ch, waveform in sub_waveforms.items(): + assert ch in waveform.defined_channels - MultiChannelWaveform allows an arbitrary mapping of channels defined by the Waveforms it - consists of and the channels it defines. For example, if the MultiChannelWaveform consists - of a two Waveform objects A and B which define two channels each, then the channels of the - MultiChannelWaveform may be 0: A.1, 1: B.0, 2: B.1, 3: A.0 where A.0 means channel 0 of Waveform - object A. + if ch not in pad_values: + assert are_durations_compatible(duration, waveform.duration) - The following constraints must hold: - - The durations of all Waveform objects must be equal. - - The channel mapping must be sane, i.e., no channel of the MultiChannelWaveform must be - assigned more than one channel of any Waveform object it consists of - """ + # add default pad that is only required in corner cases of numeric accuracy + pad_value = pad_values.get(ch, None) + if pad_value is None: + pad_value = waveform.last_value(ch) - def __init__(self, sub_waveforms: Iterable[Waveform]) -> None: - """Create a new MultiChannelWaveform instance. + wf_pad_dict[ch] = (waveform, pad_value) - Requires a list of subwaveforms in the form (Waveform, List(int)) where the list defines - the channel mapping, i.e., a value y at index x in the list means that channel x of the - subwaveform will be mapped to channel y of this MultiChannelWaveform object. + self._wf_pad = FrozenDict(wf_pad_dict) + self._duration = duration + @classmethod + def from_iterable(cls, + sub_waveforms: Iterable[Waveform], + pad_values: Optional[Mapping[ChannelID, Optional[float]]] = None, + duration: Optional[TimeType] = None + ) -> 'MultiChannelWaveform': + """Construct a MultiChannelWaveform from an iterable of Waveforms. Args: sub_waveforms (Iterable( Waveform )): The list of sub waveforms of this MultiChannelWaveform + pad_values: Value for padding if desired. None implies :py:meth:`Waveform.last_value`. Channels not + mentioned must have a compatible duration. + duration: Duration of this waveform. None implies the maximum subwaveform duration. Raises: - ValueError, if a channel mapping is out of bounds of the channels defined by this - MultiChannelWaveform - ValueError, if several subwaveform channels are assigned to a single channel of this - MultiChannelWaveform - ValueError, if subwaveforms have inconsistent durations + ValueError, if `sub_waveforms` is empty + ValueError, if the defined channels several subwaveform overlap + ValueError, if subwaveforms have incompatible durations and are not padded + ValueError, if a channel is padded that is not defined in a subwaveform """ - super().__init__() if not sub_waveforms: raise ValueError( "MultiChannelWaveform cannot be constructed without channel waveforms." ) + if pad_values is None: + pad_values = {} - # avoid unnecessary multi channel nesting - def flatten_sub_waveforms(to_flatten): - for sub_waveform in to_flatten: - if isinstance(sub_waveform, MultiChannelWaveform): - yield from sub_waveform._sub_waveforms - else: - yield sub_waveform + duration = max(sub_waveform.duration for sub_waveform in sub_waveforms) if duration is None else duration + defined_channels = collections.Counter() - # sort the waveforms with their defined channels to make compare key reproducible - def get_sub_waveform_sort_key(waveform): - return tuple(sorted(tuple('{}_stringified_numeric_channel'.format(ch) if isinstance(ch, int) else ch - for ch in waveform.defined_channels))) - - self._sub_waveforms = tuple(sorted(flatten_sub_waveforms(sub_waveforms), - key=get_sub_waveform_sort_key)) - - self.__defined_channels = set() - for waveform in self._sub_waveforms: - if waveform.defined_channels & self.__defined_channels: - raise ValueError('Channel may not be defined in multiple waveforms', - waveform.defined_channels & self.__defined_channels) - self.__defined_channels |= waveform.defined_channels - - durations = list(subwaveform.duration for subwaveform in self._sub_waveforms) - if not are_durations_compatible(*durations): - durations = {} - for waveform in self._sub_waveforms: - for duration, channels in durations.items(): - if isclose(waveform.duration, duration): - channels.update(waveform.defined_channels) - break - else: - durations[waveform.duration] = set(waveform.defined_channels) + flattened_wf = {} + flattened_pad = {} + incompatible_durations = {} + for waveform in sub_waveforms: + # if pad is not defined the sub waveform duration needs to be compatible with the overall duration + undefined_pad = waveform.defined_channels - pad_values.keys() + if undefined_pad and not are_durations_compatible(duration, waveform.duration): + # prepare error message + incompatible_durations.setdefault(waveform.duration, set()).intersection_update(undefined_pad) + + defined_channels.update(waveform.defined_channels) + + if isinstance(waveform, MultiChannelWaveform) and are_durations_compatible(waveform.duration, duration): + for ch, (wf, pad) in waveform._wf_pad.items(): + flattened_wf[ch] = wf + flattened_pad[ch] = pad + else: + for ch in waveform.defined_channels: + flattened_wf[ch] = waveform + + if incompatible_durations: raise ValueError( - "MultiChannelWaveform cannot be constructed from channel waveforms of different durations.", - durations + "MultiChannelWaveform cannot be constructed from channel waveforms of incompatible durations.", + incompatible_durations ) + if defined_channels.most_common()[0][1] > 1: + multi_defined = {ch for ch, count in defined_channels.items() if count > 1} + raise ValueError('Channel may not be defined in multiple waveforms', + multi_defined) + if pad_values.keys() - defined_channels.keys(): + raise ValueError('pad_values contains channels not defined in subwaveforms', + pad_values.keys() - defined_channels.keys()) + + return cls(flattened_wf, + {**pad_values, **flattened_pad}, + duration) @property def duration(self) -> TimeType: - return self._sub_waveforms[0].duration + return self._duration def __getitem__(self, key: ChannelID) -> Waveform: - for waveform in self._sub_waveforms: - if key in waveform.defined_channels: - return waveform - raise KeyError('Unknown channel ID: {}'.format(key), key) + try: + return self._wf_pad[key][0] + except KeyError: + raise KeyError('Unknown channel ID: {}'.format(key), key) @property def defined_channels(self) -> Set[ChannelID]: - return self.__defined_channels + return self._wf_pad.keys() @property def compare_key(self) -> Any: - # sort with channels - return self._sub_waveforms + return self._duration, self._wf_pad def unsafe_sample(self, channel: ChannelID, sample_times: np.ndarray, output_array: Union[np.ndarray, None]=None) -> np.ndarray: - return self[channel].unsafe_sample(channel, sample_times, output_array) + """Pad with last value to length of longest waveform""" + sub_waveform, pad_value = self._wf_pad[channel] + max_idx = np.searchsorted(sample_times, float(sub_waveform.duration), 'right') + if max_idx < len(sample_times): + # we need to pad in the output + if output_array is None: + output_array = alloc_for_sample(sample_times.size) + inner_output_array = output_array[:max_idx] + + sub_waveform.unsafe_sample(channel, sample_times, output_array=inner_output_array) + output_array[max_idx:] = pad_value + return output_array - def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'Waveform': - relevant_sub_waveforms = tuple(swf for swf in self._sub_waveforms if swf.defined_channels & channels) - if len(relevant_sub_waveforms) == 1: - return relevant_sub_waveforms[0].get_subset_for_channels(channels) - elif len(relevant_sub_waveforms) > 1: - return MultiChannelWaveform( - sub_waveform.get_subset_for_channels(channels & sub_waveform.defined_channels) - for sub_waveform in relevant_sub_waveforms) else: - raise KeyError('Unknown channels: {}'.format(channels)) + return sub_waveform.unsafe_sample(channel, sample_times, output_array=output_array) + + def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'Waveform': + # TODO: is the optimization to detect if the result can be expressed as a sub-waveform worth it? + # need to check duration compatibility then for consistent padding / truncation + self_duration = self.duration + waveforms = {} + pad_values = {} + padding = False + + for ch in channels: + wf, pad = self._wf_pad[ch] + padding = padding or wf.duration != self_duration + waveforms[ch] = wf + pad_values[ch] = pad + + if not padding: + single_waveform = None + if len(waveforms) == 1: + single_waveform, = waveforms.values() + elif len(set(waveforms.values())) == 1: + _, single_waveform = waveforms.popitem() + if single_waveform is not None: + return single_waveform.get_subset_for_channels(channels) + + return MultiChannelWaveform( + {ch: self._wf_pad[ch][0] for ch in channels}, + {ch: self._wf_pad[ch][1] for ch in channels}, + self._duration + ) class RepetitionWaveform(Waveform): diff --git a/tests/_program/waveforms_tests.py b/tests/_program/waveforms_tests.py index bc9bbce15..a1757654e 100644 --- a/tests/_program/waveforms_tests.py +++ b/tests/_program/waveforms_tests.py @@ -94,16 +94,16 @@ def test_get_subset_for_channels(self): class MultiChannelWaveformTest(unittest.TestCase): def test_init_no_args(self) -> None: with self.assertRaises(ValueError): - MultiChannelWaveform(dict()) + MultiChannelWaveform.from_iterable(dict()) with self.assertRaises(ValueError): - MultiChannelWaveform(None) + MultiChannelWaveform.from_iterable(None) def test_get_item(self): dwf_a = DummyWaveform(duration=2.2, defined_channels={'A'}) dwf_b = DummyWaveform(duration=2.2, defined_channels={'B'}) dwf_c = DummyWaveform(duration=2.2, defined_channels={'C'}) - wf = MultiChannelWaveform([dwf_a, dwf_b, dwf_c]) + wf = MultiChannelWaveform.from_iterable([dwf_a, dwf_b, dwf_c]) self.assertIs(wf['A'], dwf_a) self.assertIs(wf['B'], dwf_b) @@ -115,7 +115,7 @@ def test_get_item(self): def test_init_single_channel(self) -> None: dwf = DummyWaveform(duration=1.3, defined_channels={'A'}) - waveform = MultiChannelWaveform([dwf]) + waveform = MultiChannelWaveform.from_iterable([dwf]) self.assertEqual({'A'}, waveform.defined_channels) self.assertEqual(TimeType.from_float(1.3), waveform.duration) @@ -124,28 +124,28 @@ def test_init_several_channels(self) -> None: dwf_b = DummyWaveform(duration=2.2, defined_channels={'B'}) dwf_c = DummyWaveform(duration=2.3, defined_channels={'C'}) - waveform = MultiChannelWaveform([dwf_a, dwf_b]) + waveform = MultiChannelWaveform.from_iterable([dwf_a, dwf_b]) self.assertEqual({'A', 'B'}, waveform.defined_channels) self.assertEqual(TimeType.from_float(2.2), waveform.duration) - with self.assertRaises(ValueError): - MultiChannelWaveform([dwf_a, dwf_c]) - with self.assertRaises(ValueError): - MultiChannelWaveform([waveform, dwf_c]) - with self.assertRaises(ValueError): - MultiChannelWaveform((dwf_a, dwf_a)) + with self.assertRaisesRegex(ValueError, 'incompatible duration'): + MultiChannelWaveform.from_iterable([dwf_a, dwf_c]) + with self.assertRaisesRegex(ValueError, 'incompatible duration'): + MultiChannelWaveform.from_iterable([waveform, dwf_c]) + with self.assertRaisesRegex(ValueError, 'multiple waveforms'): + MultiChannelWaveform.from_iterable((dwf_a, dwf_a)) dwf_c_valid = DummyWaveform(duration=2.2, defined_channels={'C'}) - waveform_flat = MultiChannelWaveform((waveform, dwf_c_valid)) - self.assertEqual(len(waveform_flat.compare_key), 3) + waveform_flat = MultiChannelWaveform.from_iterable((waveform, dwf_c_valid)) + self.assertEqual(len(waveform_flat.compare_key), 2) def test_unsafe_sample(self) -> None: - sample_times = numpy.linspace(98.5, 103.5, num=11) + sample_times = numpy.linspace(.1, .534, num=11) samples_a = numpy.linspace(4, 5, 11) samples_b = numpy.linspace(2, 3, 11) dwf_a = DummyWaveform(duration=3.2, sample_output=samples_a, defined_channels={'A'}) dwf_b = DummyWaveform(duration=3.2, sample_output=samples_b, defined_channels={'B', 'C'}) - waveform = MultiChannelWaveform((dwf_a, dwf_b)) + waveform = MultiChannelWaveform.from_iterable((dwf_a, dwf_b)) result_a = waveform.unsafe_sample('A', sample_times) numpy.testing.assert_equal(result_a, samples_a) @@ -172,23 +172,58 @@ def test_unsafe_sample(self) -> None: self.assertIs(result_a, dwf_a.sample_calls[1][2]) numpy.testing.assert_equal(result_b, samples_b) + def test_padding(self): + duration = TimeType.from_float(4) + sub_duration = TimeType.from_float(3.2) + pad_sample_times = numpy.linspace(0, float(duration), num=11, dtype=float) + n_sub, n_pad = sum(pad_sample_times <= sub_duration), sum(pad_sample_times > sub_duration) + no_pad_sample_times = pad_sample_times[:n_sub] + samples_a = numpy.linspace(4, 5, n_sub) + samples_b = numpy.linspace(2, 3, n_sub) + dwf_a = DummyWaveform(duration=sub_duration, sample_output=samples_a, defined_channels={'A'}) + dwf_b = DummyWaveform(duration=sub_duration, sample_output=samples_b, defined_channels={'B'}) + waveform = MultiChannelWaveform.from_iterable((dwf_a, dwf_b), + pad_values={'A': -1, 'B': None}, + duration=duration) + result_a = waveform.unsafe_sample('A', pad_sample_times) + result_b = waveform.unsafe_sample('B', pad_sample_times) + expected_a = np.array(samples_a.tolist() + [-1] * n_pad) + expected_b = np.array(samples_b.tolist() + [samples_b[-1]] * n_pad) + np.testing.assert_equal(expected_a, result_a) + np.testing.assert_equal(expected_b, result_b) + + with mock.patch.object(dwf_a, 'unsafe_sample') as sam_a, mock.patch.object(dwf_b, 'unsafe_sample') as sam_b: + self.assertIs(sam_a.return_value, waveform.unsafe_sample('A', no_pad_sample_times)) + self.assertIs(sam_b.return_value, waveform.unsafe_sample('B', no_pad_sample_times)) + def test_equality(self) -> None: dwf_a = DummyWaveform(duration=246.2, defined_channels={'A'}) dwf_b = DummyWaveform(duration=246.2, defined_channels={'B'}) dwf_c = DummyWaveform(duration=246.2, defined_channels={'C'}) - waveform_a1 = MultiChannelWaveform([dwf_a, dwf_b]) - waveform_a2 = MultiChannelWaveform([dwf_a, dwf_b]) - waveform_a3 = MultiChannelWaveform([dwf_a, dwf_c]) + waveform_a1 = MultiChannelWaveform.from_iterable([dwf_a, dwf_b]) + waveform_a2 = MultiChannelWaveform.from_iterable([dwf_a, dwf_b]) + waveform_a3 = MultiChannelWaveform.from_iterable([dwf_a, dwf_b], duration=TimeType.from_float(246.2)) + waveform_a4 = MultiChannelWaveform.from_iterable([dwf_a, dwf_b], pad_values={'A': 1}) + waveform_a5 = MultiChannelWaveform.from_iterable([dwf_a, dwf_c]) + + waveform_b1 = MultiChannelWaveform.from_iterable([dwf_a, dwf_b], pad_values={'A': 1, 'B': 1}) + waveform_b2 = MultiChannelWaveform.from_iterable([dwf_a, dwf_b], pad_values={'A': 1, 'B': 1}, + duration=TimeType.from_float(246.2)) + waveform_b3 = MultiChannelWaveform.from_iterable([dwf_a, dwf_b], pad_values={'A': 1, 'B': 1}, duration=TimeType.from_float(246.)) self.assertEqual(waveform_a1, waveform_a1) self.assertEqual(waveform_a1, waveform_a2) - self.assertNotEqual(waveform_a1, waveform_a3) + self.assertEqual(waveform_a1, waveform_a3) + self.assertNotEqual(waveform_a1, waveform_a4) + self.assertNotEqual(waveform_a1, waveform_a5) + self.assertEqual(waveform_b1, waveform_b2) + self.assertNotEqual(waveform_b2, waveform_b3) def test_unsafe_get_subset_for_channels(self): dwf_a = DummyWaveform(duration=246.2, defined_channels={'A'}) dwf_b = DummyWaveform(duration=246.2, defined_channels={'B'}) dwf_c = DummyWaveform(duration=246.2, defined_channels={'C'}) - mcwf = MultiChannelWaveform((dwf_a, dwf_b, dwf_c)) + mcwf = MultiChannelWaveform.from_iterable((dwf_a, dwf_b, dwf_c)) with self.assertRaises(KeyError): mcwf.unsafe_get_subset_for_channels({'D'}) with self.assertRaises(KeyError): From e64caf9d8d103e9d0a521137be15e4cc4f17c35c Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Tue, 4 Aug 2020 13:06:42 +0200 Subject: [PATCH 04/11] Add _as_expression function to AtomicPulseTemplates to allow truncation --- qupulse/pulses/arithmetic_pulse_template.py | 27 ++++++---- qupulse/pulses/function_pulse_template.py | 6 ++- qupulse/pulses/mapping_pulse_template.py | 12 ++++- qupulse/pulses/point_pulse_template.py | 43 ++++++++++------ qupulse/pulses/pulse_template.py | 8 +++ qupulse/pulses/table_pulse_template.py | 56 ++++++++++++++++++--- qupulse/utils/sympy.py | 11 +++- 7 files changed, 126 insertions(+), 37 deletions(-) diff --git a/qupulse/pulses/arithmetic_pulse_template.py b/qupulse/pulses/arithmetic_pulse_template.py index 284e426a6..ae685f26f 100644 --- a/qupulse/pulses/arithmetic_pulse_template.py +++ b/qupulse/pulses/arithmetic_pulse_template.py @@ -18,6 +18,18 @@ IdentityTransformation +def _apply_operation_to_channel_dict(operator: str, + lhs: Mapping[ChannelID, Any], + rhs: Mapping[ChannelID, Any]) -> Dict[ChannelID, Any]: + result = dict(lhs) + for channel, rhs_value in rhs.items(): + if channel in result: + result[channel] = ArithmeticWaveform.operator_map[operator](result[channel], rhs_value) + else: + result[channel] = ArithmeticWaveform.rhs_only_map[operator](rhs_value) + return result + + class ArithmeticAtomicPulseTemplate(AtomicPulseTemplate): def __init__(self, lhs: AtomicPulseTemplate, @@ -96,17 +108,12 @@ def duration(self) -> ExpressionScalar: @property def integral(self) -> Dict[ChannelID, ExpressionScalar]: - lhs = self.lhs.integral - rhs = self.rhs.integral + return _apply_operation_to_channel_dict(self._arithmetic_operator, self.lhs.integral, self.rhs.integral) - result = lhs.copy() - - for channel, rhs_value in rhs.items(): - if channel in result: - result[channel] = ArithmeticWaveform.operator_map[self._arithmetic_operator](result[channel], rhs_value) - else: - result[channel] = ArithmeticWaveform.rhs_only_map[self._arithmetic_operator](rhs_value) - return result + def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: + return _apply_operation_to_channel_dict(self._arithmetic_operator, + self.lhs._as_expression(), + self.rhs._as_expression()) def build_waveform(self, parameters: Dict[str, Real], diff --git a/qupulse/pulses/function_pulse_template.py b/qupulse/pulses/function_pulse_template.py index 9df064681..b97d00564 100644 --- a/qupulse/pulses/function_pulse_template.py +++ b/qupulse/pulses/function_pulse_template.py @@ -6,7 +6,7 @@ """ -from typing import Any, Dict, List, Set, Optional, Union +from typing import Any, Dict, List, Set, Optional, Union, Tuple import numbers import numpy as np @@ -148,4 +148,8 @@ def integral(self) -> Dict[ChannelID, ExpressionScalar]: sympy.integrate(self.__expression.sympified_expression, ('t', 0, self.duration.sympified_expression)) )} + def _as_expression(self) -> Tuple[Dict[ChannelID, ExpressionScalar], list]: + expr = ExpressionScalar.make(self.__expression.underlying_expression.subs({'t': self._AS_EXPRESSION_TIME})) + return {self.__channel: expr}, [] + diff --git a/qupulse/pulses/mapping_pulse_template.py b/qupulse/pulses/mapping_pulse_template.py index a9ff5a0f1..d25e1848b 100644 --- a/qupulse/pulses/mapping_pulse_template.py +++ b/qupulse/pulses/mapping_pulse_template.py @@ -354,7 +354,6 @@ def integral(self) -> Dict[ChannelID, ExpressionScalar]: # todo: make Expressions compatible with sympy.subs() parameter_mapping = {parameter_name: expression.underlying_expression for parameter_name, expression in self.__parameter_mapping.items()} - for channel, ch_integral in internal_integral.items(): channel_out = self.__channel_mapping.get(channel, channel) if channel_out is None: @@ -366,6 +365,17 @@ def integral(self) -> Dict[ChannelID, ExpressionScalar]: return expressions + def _as_expression(self) -> Tuple[Dict[ChannelID, ExpressionScalar], list]: + parameter_mapping = {parameter_name: expression.underlying_expression + for parameter_name, expression in self.__parameter_mapping.items()} + inner, assumptions = self.__template._as_expression() + raise NotImplementedError("map assumptions") + return { + self.__channel_mapping.get(ch, ch): ExpressionScalar(ch_expr.sympified_expression.subs(parameter_mapping)) + for ch, ch_expr in inner.items() + if self.__channel_mapping.get(ch, ch) is not None + }, assumptions + class MissingMappingException(Exception): """Indicates that no mapping was specified for some parameter declaration of a diff --git a/qupulse/pulses/point_pulse_template.py b/qupulse/pulses/point_pulse_template.py index 40abcde4c..5a355a7c3 100644 --- a/qupulse/pulses/point_pulse_template.py +++ b/qupulse/pulses/point_pulse_template.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Union, Set, Dict, Sequence, Any +from typing import Optional, List, Union, Set, Dict, Sequence, Any, Tuple from numbers import Real import itertools import numbers @@ -136,23 +136,34 @@ def parameter_names(self) -> Set[str]: @property def integral(self) -> Dict[ChannelID, ExpressionScalar]: - expressions = {channel: 0 for channel in self._channels} - for first_entry, second_entry in zip(self._entries[:-1], self._entries[1:]): - substitutions = {'t0': first_entry.t.sympified_expression, - 't1': second_entry.t.sympified_expression} - - v0 = sympy.IndexedBase(Broadcast(first_entry.v.underlying_expression, (len(self.defined_channels),))) - v1 = sympy.IndexedBase(Broadcast(second_entry.v.underlying_expression, (len(self.defined_channels),))) - - for i, channel in enumerate(self._channels): - substitutions['v0'] = v0[i] - substitutions['v1'] = v1[i] - - expressions[channel] += first_entry.interp.integral.sympified_expression.subs(substitutions) - - expressions = {c: ExpressionScalar(expressions[c]) for c in expressions} + expressions = {} + shape = (len(self.defined_channels),) + + for i, channel in enumerate(self._channels): + def value_trafo(v): + try: + return v.underlying_expression[i] + except TypeError: + return sympy.IndexedBase(Broadcast(v.underlying_expression, shape))[i] + expressions[channel] = TableEntry._sequence_integral(self._entries, expression_extractor=value_trafo) return expressions + def _as_expression(self) -> Tuple[Dict[ChannelID, ExpressionScalar], list]: + t = self._AS_EXPRESSION_TIME + shape = (len(self.defined_channels),) + expressions = {} + assumptions = [] + + for i, channel in enumerate(self._channels): + def value_trafo(v): + try: + return v.underlying_expression[i] + except TypeError: + return sympy.IndexedBase(Broadcast(v.underlying_expression, shape))[i] + + pw, assumptions = TableEntry._sequence_as_expression(self._entries, expression_extractor=value_trafo, t=t) + expressions[channel] = pw + return expressions, assumptions class InvalidPointDimension(Exception): def __init__(self, expected, received): diff --git a/qupulse/pulses/pulse_template.py b/qupulse/pulses/pulse_template.py index d546e6614..6a4a9dc72 100644 --- a/qupulse/pulses/pulse_template.py +++ b/qupulse/pulses/pulse_template.py @@ -12,6 +12,8 @@ import collections from numbers import Real, Number +import sympy + from qupulse.utils.types import ChannelID, DocStringABCMeta, FrozenDict from qupulse.serialization import Serializable from qupulse.expressions import ExpressionScalar, Expression, ExpressionLike @@ -290,6 +292,8 @@ class AtomicPulseTemplate(PulseTemplate, MeasurementDefiner): Implies that no AtomicPulseTemplate object is interruptable. """ + _AS_EXPRESSION_TIME = sympy.Dummy('_t', positive=True) + def __init__(self, *, identifier: Optional[str], measurements: Optional[List[MeasurementDeclaration]]): @@ -345,6 +349,10 @@ def build_waveform(self, does not represent a valid waveform of finite length. """ + @abstractmethod + def _as_expression(self) -> Tuple[Dict[ChannelID, ExpressionScalar], list]: + """Helper function to allow integral calculation in case of truncation. 't' is by convention the time.""" + class DoubleParameterNameException(Exception): diff --git a/qupulse/pulses/table_pulse_template.py b/qupulse/pulses/table_pulse_template.py index 0d903aa33..b8096ecca 100644 --- a/qupulse/pulses/table_pulse_template.py +++ b/qupulse/pulses/table_pulse_template.py @@ -7,13 +7,14 @@ declared parameters. """ -from typing import Union, Dict, List, Set, Optional, Any, Tuple, Sequence, NamedTuple +from typing import Union, Dict, List, Set, Optional, Any, Tuple, Sequence, NamedTuple, Callable import numbers import itertools import warnings import numpy as np import sympy +import more_itertools from qupulse.utils.types import ChannelID from qupulse.serialization import Serializer, PulseRegistryType @@ -58,6 +59,44 @@ def instantiate(self, parameters: Dict[str, numbers.Real]) -> TableWaveformEntry def get_serialization_data(self) -> tuple: return self.t.get_serialization_data(), self.v.get_serialization_data(), str(self.interp) + @classmethod + def _sequence_integral(cls, entry_sequence: Sequence['TableEntry'], + expression_extractor: Callable[[Expression], sympy.Expr]) -> ExpressionScalar: + expr = 0 + for first_entry, second_entry in more_itertools.pairwise(entry_sequence): + substitutions = {'t0': first_entry.t.sympified_expression, + 'v0': expression_extractor(first_entry.v), + 't1': second_entry.t.sympified_expression, + 'v1': expression_extractor(second_entry.v)} + + expr += first_entry.interp.integral.sympified_expression.subs(substitutions) + return ExpressionScalar(expr) + + @classmethod + def _sequence_as_expression(cls, entry_sequence: Sequence['TableEntry'], + expression_extractor: Callable[[Expression], sympy.Expr], + t: sympy.Dummy) -> Tuple[ExpressionScalar, List[sympy.AppliedPredicate]]: + + # args are tested in order + piecewise_args = [] + assumptions = [] + for first_entry, second_entry in more_itertools.pairwise(entry_sequence): + t0, t1 = first_entry.t.sympified_expression, second_entry.t.sympified_expression + substitutions = {'t0': t0, + 'v0': expression_extractor(first_entry.v), + 't1': t1, + 'v1': expression_extractor(second_entry.v), + 't': t} + time_gate = sympy.And(t0 <= t, t < t1) + + interpolation_expr = first_entry.interp.expression.underlying_expression.subs(substitutions) + + piecewise_args.append((interpolation_expr, time_gate)) + assumptions.append(sympy.Q.is_true(t0 <= t1)) + + piecewise_args.append((0, True)) + return ExpressionScalar(sympy.Piecewise(*piecewise_args)), assumptions + class TablePulseTemplate(AtomicPulseTemplate, ParameterConstrainer): """The TablePulseTemplate class implements pulses described by a table with time, voltage and interpolation strategy @@ -347,15 +386,16 @@ def is_valid_interpolation_strategy(inter): def integral(self) -> Dict[ChannelID, ExpressionScalar]: expressions = dict() for channel, channel_entries in self._entries.items(): + expressions[channel] = TableEntry._sequence_integral(channel_entries, lambda v: v.sympified_expression) - expr = 0 - for first_entry, second_entry in zip(channel_entries[:-1], channel_entries[1:]): - substitutions = {'t0': ExpressionScalar(first_entry.t).sympified_expression, 'v0': ExpressionScalar(first_entry.v).sympified_expression, - 't1': ExpressionScalar(second_entry.t).sympified_expression, 'v1': ExpressionScalar(second_entry.v).sympified_expression} - - expr += first_entry.interp.integral.sympified_expression.subs(substitutions) - expressions[channel] = ExpressionScalar(expr) + return expressions + def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: + expressions = dict() + for channel, channel_entries in self._entries.items(): + expressions[channel] = TableEntry._sequence_as_expression(channel_entries, + lambda v: v.sympified_expression, + t=self._AS_EXPRESSION_TIME) return expressions diff --git a/qupulse/utils/sympy.py b/qupulse/utils/sympy.py index 1e3350842..0e43df6a9 100644 --- a/qupulse/utils/sympy.py +++ b/qupulse/utils/sympy.py @@ -87,16 +87,25 @@ class Broadcast(sympy.Function): >>> assert bc.subs({'a': 2}) == sympy.Array([2, 2, 2]) >>> assert bc.subs({'a': (1, 2, 3)}) == sympy.Array([1, 2, 3]) """ + nargs = (2,) @classmethod def eval(cls, x, shape) -> Optional[sympy.Array]: - if hasattr(shape, 'free_symbols') and shape.free_symbols: + if getattr(shape, 'free_symbols', None): # cannot do anything return None if hasattr(x, '__len__') or not x.free_symbols: return sympy.Array(numpy.broadcast_to(x, shape)) + def _eval_Integral(self, *symbols, **assumptions): + x, shape = self.args + return Broadcast(sympy.Integral(x, *symbols, **assumptions), shape) + + def _eval_derivative(self, sym): + x, shape = self.args + return Broadcast(sympy.diff(x, sym), shape) + class Len(sympy.Function): nargs = 1 From 98a91c1a34ae29a023a841e8e93468cbae5e7a14 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Tue, 4 Aug 2020 13:25:08 +0200 Subject: [PATCH 05/11] Cleanup --- ReleaseNotes.txt | 7 + qupulse/expressions.py | 9 + qupulse/pulses/function_pulse_template.py | 7 +- qupulse/pulses/mapping_pulse_template.py | 7 +- .../pulses/multi_channel_pulse_template.py | 189 ++++++++++++------ qupulse/pulses/point_pulse_template.py | 3 +- qupulse/pulses/pulse_template.py | 5 +- qupulse/pulses/table_pulse_template.py | 8 +- qupulse/utils/numeric.py | 4 +- tests/_program/loop_tests.py | 4 +- .../multi_channel_pulse_template_tests.py | 26 +-- tests/pulses/point_pulse_template_tests.py | 49 +++++ tests/pulses/pulse_template_tests.py | 7 +- tests/pulses/sequencing_dummies.py | 31 ++- tests/pulses/table_pulse_template_tests.py | 2 +- 15 files changed, 257 insertions(+), 101 deletions(-) diff --git a/ReleaseNotes.txt b/ReleaseNotes.txt index 547282905..da58dabad 100644 --- a/ReleaseNotes.txt +++ b/ReleaseNotes.txt @@ -2,6 +2,13 @@ - General: - Unify `TimeType.from_float` between fractions and gmpy2 backend behaviour (fixes issue 529). + - Add central allocation function for sampled data `_program.waveforms.alloc_for_sample` that initializes with nan + per default + + - Pulse Templates: + - AtomicMultiChannelPulseTemplate: + - Remove deprecated `external_parameters` keyword argument. + - Add padding and truncation functionality with `pad_values` ## 0.5 ## diff --git a/qupulse/expressions.py b/qupulse/expressions.py index 90594dc69..54057bce8 100644 --- a/qupulse/expressions.py +++ b/qupulse/expressions.py @@ -347,6 +347,15 @@ def get_serialization_data(self) -> Union[str, float, int]: def is_nan(self) -> bool: return sympy.sympify('nan') == self._sympified_expression + def _parse_evaluate_numeric_result(self, + result: Union[Number, numpy.ndarray], + call_arguments: Any) -> Number: + parsed = super()._parse_evaluate_numeric_result(result, call_arguments) + if isinstance(parsed, numpy.ndarray): + return parsed[()] + else: + return parsed + class ExpressionVariableMissingException(Exception): """An exception indicating that a variable value was not provided during expression evaluation. diff --git a/qupulse/pulses/function_pulse_template.py b/qupulse/pulses/function_pulse_template.py index b97d00564..145863a4c 100644 --- a/qupulse/pulses/function_pulse_template.py +++ b/qupulse/pulses/function_pulse_template.py @@ -6,10 +6,9 @@ """ -from typing import Any, Dict, List, Set, Optional, Union, Tuple +from typing import Any, Dict, List, Set, Optional, Union import numbers -import numpy as np import sympy from qupulse.expressions import ExpressionScalar @@ -148,8 +147,8 @@ def integral(self) -> Dict[ChannelID, ExpressionScalar]: sympy.integrate(self.__expression.sympified_expression, ('t', 0, self.duration.sympified_expression)) )} - def _as_expression(self) -> Tuple[Dict[ChannelID, ExpressionScalar], list]: + def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: expr = ExpressionScalar.make(self.__expression.underlying_expression.subs({'t': self._AS_EXPRESSION_TIME})) - return {self.__channel: expr}, [] + return {self.__channel: expr} diff --git a/qupulse/pulses/mapping_pulse_template.py b/qupulse/pulses/mapping_pulse_template.py index d25e1848b..ee60e26f0 100644 --- a/qupulse/pulses/mapping_pulse_template.py +++ b/qupulse/pulses/mapping_pulse_template.py @@ -365,16 +365,15 @@ def integral(self) -> Dict[ChannelID, ExpressionScalar]: return expressions - def _as_expression(self) -> Tuple[Dict[ChannelID, ExpressionScalar], list]: + def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: parameter_mapping = {parameter_name: expression.underlying_expression for parameter_name, expression in self.__parameter_mapping.items()} - inner, assumptions = self.__template._as_expression() - raise NotImplementedError("map assumptions") + inner = self.__template._as_expression() return { self.__channel_mapping.get(ch, ch): ExpressionScalar(ch_expr.sympified_expression.subs(parameter_mapping)) for ch, ch_expr in inner.items() if self.__channel_mapping.get(ch, ch) is not None - }, assumptions + } class MissingMappingException(Exception): diff --git a/qupulse/pulses/multi_channel_pulse_template.py b/qupulse/pulses/multi_channel_pulse_template.py index ba7ed7a49..3f474f374 100644 --- a/qupulse/pulses/multi_channel_pulse_template.py +++ b/qupulse/pulses/multi_channel_pulse_template.py @@ -11,12 +11,14 @@ import numbers import warnings +import sympy + from qupulse.serialization import Serializer, PulseRegistryType from qupulse.parameter_scope import Scope from qupulse.utils import isclose from qupulse.utils.sympy import almost_equal, Sympifyable -from qupulse.utils.types import ChannelID, TimeType +from qupulse.utils.types import ChannelID, FrozenDict, TimeType from qupulse.utils.numeric import are_durations_compatible from qupulse._program.waveforms import MultiChannelWaveform, Waveform, TransformingWaveform from qupulse._program.transformation import ParallelConstantChannelTransformation, Transformation, chain_transformations @@ -24,22 +26,31 @@ from qupulse.pulses.mapping_pulse_template import MappingPulseTemplate, MappingTuple from qupulse.pulses.parameters import Parameter, ParameterConstrainer from qupulse.pulses.measurement import MeasurementDeclaration, MeasurementWindow -from qupulse.expressions import Expression, ExpressionScalar +from qupulse.expressions import Expression, ExpressionScalar, ExpressionLike __all__ = ["AtomicMultiChannelPulseTemplate", "ParallelConstantChannelPulseTemplate"] class AtomicMultiChannelPulseTemplate(AtomicPulseTemplate, ParameterConstrainer): - """Combines multiple PulseTemplates that are defined on different channels into an AtomicPulseTemplate.""" def __init__(self, *subtemplates: Union[AtomicPulseTemplate, MappingTuple, MappingPulseTemplate], - external_parameters: Optional[Set[str]]=None, identifier: Optional[str]=None, parameter_constraints: Optional[List]=None, measurements: Optional[List[MeasurementDeclaration]]=None, registry: PulseRegistryType=None, - duration: Union[str, Expression, bool]=False) -> None: - """Parallels multiple AtomicPulseTemplates of the same duration. The duration equality check is performed on + duration: Optional[ExpressionLike] = None, + pad_values: Mapping[ChannelID, ExpressionLike] = None) -> None: + """Parallels multiple AtomicPulseTemplates that are defined on different channels. The `duration` and + `pad_values` arguments can be used to determine how differences in the sub-templates' durations are handled. + + `duration` is True: + There are no compatibility checks performed during the initialization of this object. + `duration` is None (default): + The durations may not be incompatible if it can be determined + + + + equality check is performed on construction by default. If the duration keyword argument is given the check is performed on instantiation (when build_waveform is called). duration can be a Expression to enforce a certain duration or True for an unspecified duration. @@ -56,23 +67,52 @@ def __init__(self, AtomicPulseTemplate.__init__(self, identifier=identifier, measurements=measurements) ParameterConstrainer.__init__(self, parameter_constraints=parameter_constraints) + if duration in (False, True): + warnings.warn('Boolean duration is deprecated since qupulse 0.6', DeprecationWarning) + duration = None + self._subtemplates = [st if isinstance(st, PulseTemplate) else MappingPulseTemplate.from_tuple(st) for st in subtemplates] - for subtemplate in self._subtemplates: - if isinstance(subtemplate, AtomicPulseTemplate): - continue - elif isinstance(subtemplate, MappingPulseTemplate): - if isinstance(subtemplate.template, AtomicPulseTemplate): - continue - else: - raise TypeError('Non atomic subtemplate of MappingPulseTemplate: {}'.format(subtemplate.template)) - else: - raise TypeError('Non atomic subtemplate: {}'.format(subtemplate)) + if duration is None: + self._duration = None + else: + self._duration = ExpressionScalar(duration) + + if pad_values is None: + self._pad_values = FrozenDict() + else: + self._pad_values = FrozenDict((ch, None if value is None else ExpressionScalar(value)) + for ch, value in pad_values.items()) if not self._subtemplates: raise ValueError('Cannot create empty MultiChannelPulseTemplate') + if self._pad_values.keys() - self.defined_channels: + raise ValueError('Padding value for channels not defined in subtemplates', + self._pad_values.keys() - self.defined_channels) + + # factored out for easier readability + # important that asserts happen before register + self._assert_atomic_sub_templates() + self._assert_disjoint_channels() + self._assert_compatible_durations() + + self._register(registry=registry) + + def _assert_atomic_sub_templates(self): + for sub_template in self._subtemplates: + template = sub_template + while isinstance(template, MappingPulseTemplate): + template = template.template + + if not isinstance(template, AtomicPulseTemplate): + if template is sub_template: + raise TypeError('Non atomic subtemplate: {}'.format(template)) + else: + raise TypeError('Non atomic subtemplate of MappingPulseTemplate: {}'.format(template)) + + def _assert_disjoint_channels(self): defined_channels = [st.defined_channels for st in self._subtemplates] # check there are no intersections between channels @@ -83,52 +123,40 @@ def __init__(self, 'subtemplate {}'.format(i + 2 + j), (channels_i & channels_j).pop()) - if external_parameters is not None: - warnings.warn("external_parameters is an obsolete argument and will be removed in the future.", - category=DeprecationWarning) - - if not duration: - durations = list(subtemplate.duration for subtemplate in subtemplates) - are_compatible = are_durations_compatible(*durations) - - if are_compatible is False: - # durations definitely not compatible - raise ValueError('Could not assert duration equality of {} and {}'.format(repr(duration), - repr(subtemplate.duration))) - elif are_compatible is None: - # cannot assert compatibility - raise ValueError('Could not assert duration equality of {} and {}'.format(repr(duration), - repr(subtemplate.duration))) - - else: - assert are_compatible is True - - self._duration = None - elif duration is True: - self._duration = None - else: - self._duration = ExpressionScalar(duration) - - self._register(registry=registry) + def _assert_compatible_durations(self): + """Check if we can prove that durations of unpadded waveforms are incompatible.""" + unpadded_durations = [sub_template.duration + for sub_template in self._subtemplates + if sub_template.defined_channels - self._pad_values.keys()] + are_compatible = are_durations_compatible(self.duration, *unpadded_durations) + if are_compatible is False: + # durations definitely not compatible + raise ValueError('Durations are definitely not compatible: {}'.format(unpadded_durations), + unpadded_durations) @property def duration(self) -> ExpressionScalar: if self._duration: return self._duration else: - return self._subtemplates[0].duration + return ExpressionScalar(sympy.Max(*(subtemplate.duration for subtemplate in self._subtemplates))) @property def parameter_names(self) -> Set[str]: return set.union(self.measurement_parameters, self.constrained_parameters, *(st.parameter_names for st in self._subtemplates), - self._duration.variables if self._duration else ()) + self._duration.variables if self._duration else (), + *(value.variables for value in self._pad_values.values() if value is not None)) @property def subtemplates(self) -> Sequence[Union[AtomicPulseTemplate, MappingPulseTemplate]]: return self._subtemplates + @property + def pad_values(self) -> Mapping[ChannelID, Optional[Expression]]: + return self._pad_values + @property def defined_channels(self) -> Set[ChannelID]: return set.union(*(st.defined_channels for st in self._subtemplates)) @@ -148,21 +176,29 @@ def build_waveform(self, parameters: Dict[str, numbers.Real], if sub_waveform is not None: sub_waveforms.append(sub_waveform) + pad_values = {} + for ch, pad_expression in self._pad_values.items(): + ch = channel_mapping[ch] + if ch is None: + continue + elif pad_expression is None: + pad_values[ch] = None + else: + pad_values[ch] = pad_expression.evaluate_in_scope(parameters) + if len(sub_waveforms) == 0: return None - if len(sub_waveforms) == 1: - waveform = sub_waveforms[0] + if self._duration is None: + duration = None else: - waveform = MultiChannelWaveform(sub_waveforms) - - if self._duration: - expected_duration = self._duration.evaluate_numeric(**parameters) + duration = TimeType.from_float(self._duration.evaluate_numeric(**parameters)) - if not isclose(expected_duration, waveform.duration): - raise ValueError('The duration does not ' - 'equal the expected duration', - expected_duration, waveform.duration) + if len(sub_waveforms) == 1 and (duration in (None, sub_waveforms[0].duration)): + # No padding + waveform = sub_waveforms[0] + else: + waveform = MultiChannelWaveform.from_iterable(sub_waveforms, pad_values, duration=duration) return waveform @@ -188,6 +224,10 @@ def get_serialization_data(self, serializer: Optional[Serializer]=None) -> Dict[ data['parameter_constraints'] = [str(constraint) for constraint in self.parameter_constraints] if self.measurement_declarations: data['measurements'] = self.measurement_declarations + if self._pad_values: + data['pad_values'] = self._pad_values + if self._duration is not None: + data['duration'] = self._duration return data @@ -203,10 +243,41 @@ def deserialize(cls, serializer: Optional[Serializer]=None, **kwargs) -> 'Atomic @property def integral(self) -> Dict[ChannelID, ExpressionScalar]: - expressions = dict() - for subtemplate in self._subtemplates: - expressions.update(subtemplate.integral) - return expressions + t = self._AS_EXPRESSION_TIME + self_duration = self.duration.underlying_expression + as_expression = self._as_expression() + integral = {} + for sub_template in self._subtemplates: + if sub_template.duration == self.duration: + # we use this shortcut if there is no truncation/padding to get nicer expressions + integral.update(sub_template.integral) + else: + for ch in sub_template.defined_channels: + expr = as_expression[ch] + ch_integral = sympy.integrate(expr.underlying_expression, (t, 0, self_duration)) + integral[ch] = ExpressionScalar(ch_integral) + return integral + + def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: + t = self._AS_EXPRESSION_TIME + as_expression = {} + for sub_template in self.subtemplates: + sub_duration = sub_template.duration.sympified_expression + sub_as_expression = sub_template._as_expression() + + if sub_duration == self.duration: + # we use this shortcut if there is no truncation/padding to get nicer expressions + as_expression.update(sub_as_expression) + else: + padding = t > sub_duration + + for ch, ch_expr in sub_as_expression.items(): + pad_value = self._pad_values.get(ch, None) + if pad_value is None: + pad_value = ch_expr.underlying_expression.subs({t: sub_duration}) + as_expression[ch] = ExpressionScalar(sympy.Piecewise((pad_value, padding), + (ch_expr.underlying_expression, True))) + return as_expression class ParallelConstantChannelPulseTemplate(PulseTemplate): diff --git a/qupulse/pulses/point_pulse_template.py b/qupulse/pulses/point_pulse_template.py index 5a355a7c3..ffa8d1867 100644 --- a/qupulse/pulses/point_pulse_template.py +++ b/qupulse/pulses/point_pulse_template.py @@ -98,7 +98,7 @@ def build_waveform(self, if len(waveforms) == 1: return waveforms.pop() else: - return MultiChannelWaveform(waveforms) + return MultiChannelWaveform.from_iterable(waveforms) @property def point_pulse_entries(self) -> Sequence[PointPulseEntry]: @@ -165,6 +165,7 @@ def value_trafo(v): expressions[channel] = pw return expressions, assumptions + class InvalidPointDimension(Exception): def __init__(self, expected, received): super().__init__('Expected a point of dimension {} but received {}'.format(expected, received)) diff --git a/qupulse/pulses/pulse_template.py b/qupulse/pulses/pulse_template.py index 6a4a9dc72..aa558f561 100644 --- a/qupulse/pulses/pulse_template.py +++ b/qupulse/pulses/pulse_template.py @@ -350,8 +350,9 @@ def build_waveform(self, """ @abstractmethod - def _as_expression(self) -> Tuple[Dict[ChannelID, ExpressionScalar], list]: - """Helper function to allow integral calculation in case of truncation. 't' is by convention the time.""" + def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: + """Helper function to allow integral calculation in case of truncation. AtomicPulseTemplate._AS_EXPRESSION_TIME + is by convention the time variable.""" class DoubleParameterNameException(Exception): diff --git a/qupulse/pulses/table_pulse_template.py b/qupulse/pulses/table_pulse_template.py index b8096ecca..226b14138 100644 --- a/qupulse/pulses/table_pulse_template.py +++ b/qupulse/pulses/table_pulse_template.py @@ -75,11 +75,10 @@ def _sequence_integral(cls, entry_sequence: Sequence['TableEntry'], @classmethod def _sequence_as_expression(cls, entry_sequence: Sequence['TableEntry'], expression_extractor: Callable[[Expression], sympy.Expr], - t: sympy.Dummy) -> Tuple[ExpressionScalar, List[sympy.AppliedPredicate]]: + t: sympy.Dummy) -> ExpressionScalar: # args are tested in order piecewise_args = [] - assumptions = [] for first_entry, second_entry in more_itertools.pairwise(entry_sequence): t0, t1 = first_entry.t.sympified_expression, second_entry.t.sympified_expression substitutions = {'t0': t0, @@ -92,10 +91,9 @@ def _sequence_as_expression(cls, entry_sequence: Sequence['TableEntry'], interpolation_expr = first_entry.interp.expression.underlying_expression.subs(substitutions) piecewise_args.append((interpolation_expr, time_gate)) - assumptions.append(sympy.Q.is_true(t0 <= t1)) piecewise_args.append((0, True)) - return ExpressionScalar(sympy.Piecewise(*piecewise_args)), assumptions + return ExpressionScalar(sympy.Piecewise(*piecewise_args)) class TablePulseTemplate(AtomicPulseTemplate, ParameterConstrainer): @@ -297,7 +295,7 @@ def build_waveform(self, if len(waveforms) == 1: return waveforms.pop() else: - return MultiChannelWaveform(waveforms) + return MultiChannelWaveform.from_iterable(waveforms) @staticmethod def from_array(times: np.ndarray, voltages: np.ndarray, channels: List[ChannelID]) -> 'TablePulseTemplate': diff --git a/qupulse/utils/numeric.py b/qupulse/utils/numeric.py index caf8431d8..dc243d5b2 100644 --- a/qupulse/utils/numeric.py +++ b/qupulse/utils/numeric.py @@ -121,7 +121,7 @@ def are_durations_compatible(first_duration: Real, *other_durations: Real, for duration in other_durations: min_duration = min(min_duration, duration) max_duration = max(max_duration, duration) - assert 0 < max_duration, "At least one duration must be positive" + assert (0 <= max_duration) is not False, "At least one duration must be positive" # spread = max_duration - min_duration # allowed_spread = max(max_rel_spread * max_duration, max_abs_spread) are_compatible = max_duration - min_duration < max(max_rel_spread * max_duration, max_abs_spread) @@ -129,7 +129,7 @@ def are_durations_compatible(first_duration: Real, *other_durations: Real, return are_compatible # durations are sympy expressions with clear ordering - elif are_compatible.is_Boolean: + elif getattr(are_compatible, 'is_Boolean', False): return bool(are_compatible) else: diff --git a/tests/_program/loop_tests.py b/tests/_program/loop_tests.py index 69e363857..2254934de 100644 --- a/tests/_program/loop_tests.py +++ b/tests/_program/loop_tests.py @@ -32,8 +32,8 @@ def generate_single_channel_waveform(self, channel): defined_channels={channel}) def generate_multi_channel_waveform(self): - return MultiChannelWaveform([self.generate_single_channel_waveform(self.channel_names[ch_i]) - for ch_i in range(self.num_channels)]) + return MultiChannelWaveform.from_iterable([self.generate_single_channel_waveform(self.channel_names[ch_i]) + for ch_i in range(self.num_channels)]) def __call__(self): return self.generate_multi_channel_waveform() diff --git a/tests/pulses/multi_channel_pulse_template_tests.py b/tests/pulses/multi_channel_pulse_template_tests.py index 3e920c7d3..9e6205cb7 100644 --- a/tests/pulses/multi_channel_pulse_template_tests.py +++ b/tests/pulses/multi_channel_pulse_template_tests.py @@ -2,6 +2,7 @@ from unittest import mock import numpy +import sympy from qupulse.parameter_scope import DictScope from qupulse.pulses.multi_channel_pulse_template import MultiChannelWaveform, MappingPulseTemplate,\ @@ -84,11 +85,10 @@ def test_instantiation_duration_check(self): duration='t_3', waveform=DummyWaveform(duration=4, defined_channels={'c3'}))] - with self.assertRaisesRegex(ValueError, 'duration equality'): - AtomicMultiChannelPulseTemplate(*subtemplates) - amcpt = AtomicMultiChannelPulseTemplate(*subtemplates, duration=True) - self.assertIs(amcpt.duration, subtemplates[0].duration) + self.assertIs(amcpt.duration, sympy.Max(subtemplates[0].duration, + subtemplates[1].duration, + subtemplates[2].duration)) with self.assertRaisesRegex(ValueError, 'duration'): amcpt.build_waveform(parameters=dict(t_1=3, t_2=3, t_3=3), @@ -110,15 +110,10 @@ def test_instantiation_duration_check(self): amcpt.build_waveform(parameters=dict(t_1=3+1e-11, t_2=3, t_3=3, t_0=3), channel_mapping={ch: ch for ch in 'c1 c2 c3'.split()}) - def test_external_parameters_warning(self): - with self.assertWarnsRegex(DeprecationWarning, "external_parameters", - msg="AtomicMultiChannelPulseTemplate did not issue a warning for argument external_parameters"): - AtomicMultiChannelPulseTemplate(DummyPulseTemplate(), external_parameters={'a'}) - def test_duration(self): - sts = [DummyPulseTemplate(duration='t1', defined_channels={'A'}), - DummyPulseTemplate(duration='t1', defined_channels={'B'}), - DummyPulseTemplate(duration='t2', defined_channels={'C'})] + sts = [DummyPulseTemplate(duration=1, defined_channels={'A'}), + DummyPulseTemplate(duration=1, defined_channels={'B'}), + DummyPulseTemplate(duration=2, defined_channels={'C'})] with self.assertRaises(ValueError): AtomicMultiChannelPulseTemplate(*sts) @@ -126,7 +121,7 @@ def test_duration(self): AtomicMultiChannelPulseTemplate(sts[0], sts[2]) template = AtomicMultiChannelPulseTemplate(*sts[:1]) - self.assertEqual(template.duration, 't1') + self.assertEqual(template.duration, 1) def test_mapping_template_pure_conversion(self): template = AtomicMultiChannelPulseTemplate(*zip(self.subtemplates, self.param_maps, self.chan_maps)) @@ -183,11 +178,12 @@ def test_parameter_names_2(self): self.assertEqual({'pp1', 'pp2', 'pp3', 'hugo', 'd', 'my_duration'}, template.parameter_names) def test_integral(self) -> None: - sts = [DummyPulseTemplate(duration='t1', defined_channels={'A'}, + sts = [DummyPulseTemplate(duration=ExpressionScalar('t1'), defined_channels={'A'}, integrals={'A': ExpressionScalar('2+k')}), - DummyPulseTemplate(duration='t1', defined_channels={'B', 'C'}, + DummyPulseTemplate(duration=ExpressionScalar('t1'), defined_channels={'B', 'C'}, integrals={'B': ExpressionScalar('t1-t0*3.1'), 'C': ExpressionScalar('l')})] pulse = AtomicMultiChannelPulseTemplate(*sts) + self.assertEqual({'A': ExpressionScalar('2+k'), 'B': ExpressionScalar('t1-t0*3.1'), 'C': ExpressionScalar('l')}, diff --git a/tests/pulses/point_pulse_template_tests.py b/tests/pulses/point_pulse_template_tests.py index 969a240c5..9d75bdc44 100644 --- a/tests/pulses/point_pulse_template_tests.py +++ b/tests/pulses/point_pulse_template_tests.py @@ -97,6 +97,13 @@ def test_integral(self) -> None: ppt = PointPulseTemplate([(0, 0), ('t_init', 0)], ['X', 'Y']) self.assertEqual(ppt.integral, {'X': 0, 'Y': 0}) + ppt = PointPulseTemplate([(0., 'a', 'linear'), ('t_1', 'b'), ('t_2', (0, 0))], ('X', 'Y')) + parameters = {'a': (3.4, 4.1), 'b': 4, 't_1': 2, 't_2': 5} + integral = {ch: v.evaluate_in_scope(parameters) for ch, v in ppt.integral.items()} + self.assertEqual({'X': 2 * (3.4 + 4) / 2 + (5 - 2) * 4, + 'Y': 2 * (4.1 + 4) / 2 + (5 - 2) * 4}, + integral) + class PointPulseTemplateSequencingTests(unittest.TestCase): def test_build_waveform_empty(self): @@ -290,3 +297,45 @@ def test_serializer_integration_old(self): self.assertEqual(template.point_pulse_entries, self.template.point_pulse_entries) self.assertEqual(template.measurement_declarations, self.template.measurement_declarations) self.assertEqual(template.parameter_constraints, self.template.parameter_constraints) + + +class PointPulseExpressionIntegralTests(unittest.TestCase): + def test_integral_as_expression_compatible(self): + import sympy + from sympy import Q + from sympy.assumptions import assuming + template = PointPulseTemplate(**PointPulseTemplateSerializationTests().make_kwargs()) + + t = template._AS_EXPRESSION_TIME + as_expression, assumptions = template._as_expression() + integral = template.integral + duration = template.duration.underlying_expression + + self.assertEqual(template.defined_channels, integral.keys()) + self.assertEqual(template.defined_channels, as_expression.keys()) + + assumptions.append(Q.is_true(t <= duration)) + assumptions.append(Q.is_true(0 <= duration)) + + parameter_sets = [ + {'foo': 1., 'hugo': 2., 'sudo': 3., 'A': 4., 'B': 5., 'a': 6.}, + {'foo': 1.1, 'hugo': 2.6, 'sudo': 2.7, 'A': np.array([3., 4.]), 'B': 5., 'a': 6.}, + ] + + with assuming(*assumptions): + for channel in template.defined_channels: + ch_expr = as_expression[channel].underlying_expression + ch_int = integral[channel].underlying_expression + + symbolic = sympy.integrate(ch_expr, (t, 0, duration)) + for assumption in assumptions: + symbolic = sympy.refine(symbolic, assumption) + symbolic = sympy.simplify(symbolic) + + for parameters in parameter_sets: + num_from_expr = ExpressionScalar(symbolic).evaluate_in_scope(parameters) + num_from_in = ExpressionScalar(ch_int).evaluate_in_scope(parameters) + np.testing.assert_almost_equal(num_from_in, num_from_expr) + + # TODO: the following fails + # self.assertEqual(ch_int, symbolic) diff --git a/tests/pulses/pulse_template_tests.py b/tests/pulses/pulse_template_tests.py index 6ef448118..f78793685 100644 --- a/tests/pulses/pulse_template_tests.py +++ b/tests/pulses/pulse_template_tests.py @@ -3,6 +3,7 @@ from unittest import mock from typing import Optional, Dict, Set, Any, Union +import sympy from qupulse.parameter_scope import Scope, DictScope from qupulse.utils.types import ChannelID @@ -134,6 +135,9 @@ def duration(self) -> Expression: def integral(self) -> Dict[ChannelID, ExpressionScalar]: raise NotImplementedError() + def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: + raise NotImplementedError() + class PulseTemplateTest(unittest.TestCase): @@ -352,7 +356,7 @@ class AtomicPulseTemplateTests(unittest.TestCase): def test_internal_create_program(self) -> None: measurement_windows = [('M', 0, 5)] single_wf = DummyWaveform(duration=6, defined_channels={'A'}) - wf = MultiChannelWaveform([single_wf]) + wf = MultiChannelWaveform.from_iterable([single_wf]) template = AtomicPulseTemplateStub(measurements=measurement_windows, parameter_names={'foo'}) scope = DictScope.from_kwargs(foo=7.2, volatile={'gutes_zeuch'}) @@ -437,3 +441,4 @@ def test_internal_create_program_volatile(self): to_single_waveform=set(), global_transformation=None) self.assertEqual(Loop(), program) + diff --git a/tests/pulses/sequencing_dummies.py b/tests/pulses/sequencing_dummies.py index 549935a32..d22952bed 100644 --- a/tests/pulses/sequencing_dummies.py +++ b/tests/pulses/sequencing_dummies.py @@ -75,9 +75,9 @@ def __hash__(self): class DummyWaveform(Waveform): - def __init__(self, duration: float=0, sample_output: Union[numpy.ndarray, dict]=None, defined_channels=None) -> None: + def __init__(self, duration: Union[float, TimeType]=0, sample_output: Union[numpy.ndarray, dict]=None, defined_channels=None) -> None: super().__init__() - self.duration_ = TimeType.from_float(duration) + self.duration_ = duration if isinstance(duration, TimeType) else TimeType.from_float(duration) self.sample_output = sample_output if defined_channels is None: if isinstance(sample_output, dict): @@ -142,6 +142,15 @@ def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> 'Waveform' def defined_channels(self): return self.defined_channels_ + def last_value(self, channel) -> float: + if self.sample_output is None: + return 0. + elif isinstance(self.sample_output, dict): + sample_output = self.sample_output[channel] + else: + sample_output = self.sample_output + return sample_output[-1] + class DummyInterpolationStrategy(InterpolationStrategy): @@ -168,13 +177,13 @@ class DummyPulseTemplate(AtomicPulseTemplate): def __init__(self, requires_stop: bool=False, - parameter_names: Set[str]={}, - defined_channels: Set[ChannelID]={'default'}, + parameter_names: Set[str]=set(), + defined_channels: Set[ChannelID]=None, duration: Any=0, waveform: Waveform=tuple(), measurement_names: Set[str] = set(), measurements: list=list(), - integrals: Dict[ChannelID, ExpressionScalar]={'default': ExpressionScalar(0)}, + integrals: Dict[ChannelID, ExpressionScalar]=None, program: Optional[Loop]=None, identifier=None, registry=None) -> None: @@ -182,6 +191,11 @@ def __init__(self, self.requires_stop_ = requires_stop self.requires_stop_arguments = [] + if defined_channels is None: + defined_channels = {'default'} + if integrals is None: + integrals = {ch: ExpressionScalar(0) for ch in defined_channels} + self.parameter_names_ = parameter_names self.defined_channels_ = defined_channels self._duration = Expression(duration) @@ -252,3 +266,10 @@ def integral(self) -> Dict[ChannelID, ExpressionScalar]: def compare_key(self) -> Tuple[Any, ...]: return (self.requires_stop_, self.parameter_names, self.defined_channels, self.duration, self.waveform, self.measurement_names, self.integral) + + def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: + assert self.duration != 0 + t = self._AS_EXPRESSION_TIME + duration = self.duration.underlying_expression + return {ch: ExpressionScalar(integral.underlying_expression*t/duration) + for ch, integral in self.integral.items()} diff --git a/tests/pulses/table_pulse_template_tests.py b/tests/pulses/table_pulse_template_tests.py index 8fc7fd49c..d94898e7d 100644 --- a/tests/pulses/table_pulse_template_tests.py +++ b/tests/pulses/table_pulse_template_tests.py @@ -422,7 +422,7 @@ def test_integral(self) -> None: 'symbolic': [(3, 'a', 'hold'), ('b', 4, 'linear'), ('c', Expression('d'), 'hold')]}) expected = {0: Expression('6'), 'other_channel': Expression(7), - 'symbolic': Expression('(b-3.)*a + (c-b)*(d+4.) / 2')} + 'symbolic': Expression('(b-3)*a + (c-b)*(d+4) / 2')} self.assertEqual(expected, pulse.integral) From 8d3f2c81c0f2868d151489a2e85d9cad653175f2 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Tue, 4 Aug 2020 13:44:42 +0200 Subject: [PATCH 06/11] Fix tests --- qupulse/pulses/point_pulse_template.py | 10 ++-- tests/_program/waveforms_tests.py | 2 +- .../multi_channel_pulse_template_tests.py | 10 ++-- tests/pulses/point_pulse_template_tests.py | 47 ++++++++----------- tests/pulses/table_pulse_template_tests.py | 4 +- 5 files changed, 34 insertions(+), 39 deletions(-) diff --git a/qupulse/pulses/point_pulse_template.py b/qupulse/pulses/point_pulse_template.py index ffa8d1867..63939aa33 100644 --- a/qupulse/pulses/point_pulse_template.py +++ b/qupulse/pulses/point_pulse_template.py @@ -64,7 +64,8 @@ def defined_channels(self) -> Set[ChannelID]: def build_waveform(self, parameters: Dict[str, Real], - channel_mapping: Dict[ChannelID, Optional[ChannelID]]) -> Optional[TableWaveform]: + channel_mapping: Dict[ChannelID, Optional[ChannelID]]) -> Optional[Union[TableWaveform, + MultiChannelWaveform]]: self.validate_parameter_constraints(parameters=parameters, volatile=set()) if all(channel_mapping[channel] is None @@ -148,11 +149,10 @@ def value_trafo(v): expressions[channel] = TableEntry._sequence_integral(self._entries, expression_extractor=value_trafo) return expressions - def _as_expression(self) -> Tuple[Dict[ChannelID, ExpressionScalar], list]: + def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: t = self._AS_EXPRESSION_TIME shape = (len(self.defined_channels),) expressions = {} - assumptions = [] for i, channel in enumerate(self._channels): def value_trafo(v): @@ -161,9 +161,9 @@ def value_trafo(v): except TypeError: return sympy.IndexedBase(Broadcast(v.underlying_expression, shape))[i] - pw, assumptions = TableEntry._sequence_as_expression(self._entries, expression_extractor=value_trafo, t=t) + pw = TableEntry._sequence_as_expression(self._entries, expression_extractor=value_trafo, t=t) expressions[channel] = pw - return expressions, assumptions + return expressions class InvalidPointDimension(Exception): diff --git a/tests/_program/waveforms_tests.py b/tests/_program/waveforms_tests.py index a1757654e..4c0e1cfc8 100644 --- a/tests/_program/waveforms_tests.py +++ b/tests/_program/waveforms_tests.py @@ -176,7 +176,7 @@ def test_padding(self): duration = TimeType.from_float(4) sub_duration = TimeType.from_float(3.2) pad_sample_times = numpy.linspace(0, float(duration), num=11, dtype=float) - n_sub, n_pad = sum(pad_sample_times <= sub_duration), sum(pad_sample_times > sub_duration) + n_sub, n_pad = sum(pad_sample_times <= float(sub_duration)), sum(pad_sample_times > float(sub_duration)) no_pad_sample_times = pad_sample_times[:n_sub] samples_a = numpy.linspace(4, 5, n_sub) samples_b = numpy.linspace(2, 3, n_sub) diff --git a/tests/pulses/multi_channel_pulse_template_tests.py b/tests/pulses/multi_channel_pulse_template_tests.py index 9e6205cb7..72f9b8f4a 100644 --- a/tests/pulses/multi_channel_pulse_template_tests.py +++ b/tests/pulses/multi_channel_pulse_template_tests.py @@ -85,10 +85,12 @@ def test_instantiation_duration_check(self): duration='t_3', waveform=DummyWaveform(duration=4, defined_channels={'c3'}))] - amcpt = AtomicMultiChannelPulseTemplate(*subtemplates, duration=True) - self.assertIs(amcpt.duration, sympy.Max(subtemplates[0].duration, - subtemplates[1].duration, - subtemplates[2].duration)) + with self.assertWarnsRegex(DeprecationWarning, "Boolean duration is deprecated since qupulse 0.6"): + amcpt = AtomicMultiChannelPulseTemplate(*subtemplates, duration=True) + self.assertEqual(amcpt.duration.sympified_expression, + sympy.Max(subtemplates[0].duration, + subtemplates[1].duration, + subtemplates[2].duration)) with self.assertRaisesRegex(ValueError, 'duration'): amcpt.build_waveform(parameters=dict(t_1=3, t_2=3, t_3=3), diff --git a/tests/pulses/point_pulse_template_tests.py b/tests/pulses/point_pulse_template_tests.py index 9d75bdc44..527738afb 100644 --- a/tests/pulses/point_pulse_template_tests.py +++ b/tests/pulses/point_pulse_template_tests.py @@ -154,10 +154,10 @@ def test_build_waveform_multi_channel_same(self): (1., 0., HoldInterpolationStrategy()), (1.1, 21., LinearInterpolationStrategy())]) self.assertEqual(wf.defined_channels, {1, 'A'}) - self.assertEqual(wf._sub_waveforms[0].defined_channels, {1}) - self.assertEqual(wf._sub_waveforms[0], expected_1) - self.assertEqual(wf._sub_waveforms[1].defined_channels, {'A'}) - self.assertEqual(wf._sub_waveforms[1], expected_A) + self.assertEqual(wf._wf_pad[1][0].defined_channels, {1}) + self.assertEqual(wf._wf_pad[1][0], expected_1) + self.assertEqual(wf._wf_pad['A'][0].defined_channels, {'A'}) + self.assertEqual(wf._wf_pad['A'][0], expected_A) def test_build_waveform_multi_channel_vectorized(self): ppt = PointPulseTemplate([('t1', 'A'), @@ -175,10 +175,10 @@ def test_build_waveform_multi_channel_vectorized(self): (1., 0., HoldInterpolationStrategy()), (1.1, 20., LinearInterpolationStrategy())]) self.assertEqual(wf.defined_channels, {1, 'A'}) - self.assertEqual(wf._sub_waveforms[0].defined_channels, {1}) - self.assertEqual(wf._sub_waveforms[0], expected_1) - self.assertEqual(wf._sub_waveforms[1].defined_channels, {'A'}) - self.assertEqual(wf._sub_waveforms[1], expected_A) + self.assertEqual(wf._wf_pad[1][0].defined_channels, {1}) + self.assertEqual(wf._wf_pad[1][0], expected_1) + self.assertEqual(wf._wf_pad['A'][0].defined_channels, {'A'}) + self.assertEqual(wf._wf_pad['A'][0], expected_A) def test_build_waveform_none_channel(self): ppt = PointPulseTemplate([('t1', 'A'), @@ -303,39 +303,32 @@ class PointPulseExpressionIntegralTests(unittest.TestCase): def test_integral_as_expression_compatible(self): import sympy from sympy import Q - from sympy.assumptions import assuming template = PointPulseTemplate(**PointPulseTemplateSerializationTests().make_kwargs()) t = template._AS_EXPRESSION_TIME - as_expression, assumptions = template._as_expression() + as_expression = template._as_expression() integral = template.integral duration = template.duration.underlying_expression self.assertEqual(template.defined_channels, integral.keys()) self.assertEqual(template.defined_channels, as_expression.keys()) - assumptions.append(Q.is_true(t <= duration)) - assumptions.append(Q.is_true(0 <= duration)) - parameter_sets = [ {'foo': 1., 'hugo': 2., 'sudo': 3., 'A': 4., 'B': 5., 'a': 6.}, {'foo': 1.1, 'hugo': 2.6, 'sudo': 2.7, 'A': np.array([3., 4.]), 'B': 5., 'a': 6.}, ] - with assuming(*assumptions): - for channel in template.defined_channels: - ch_expr = as_expression[channel].underlying_expression - ch_int = integral[channel].underlying_expression + for channel in template.defined_channels: + ch_expr = as_expression[channel].underlying_expression + ch_int = integral[channel].underlying_expression - symbolic = sympy.integrate(ch_expr, (t, 0, duration)) - for assumption in assumptions: - symbolic = sympy.refine(symbolic, assumption) - symbolic = sympy.simplify(symbolic) + symbolic = sympy.integrate(ch_expr, (t, 0, duration)) + symbolic = sympy.simplify(symbolic) - for parameters in parameter_sets: - num_from_expr = ExpressionScalar(symbolic).evaluate_in_scope(parameters) - num_from_in = ExpressionScalar(ch_int).evaluate_in_scope(parameters) - np.testing.assert_almost_equal(num_from_in, num_from_expr) + for parameters in parameter_sets: + num_from_expr = ExpressionScalar(symbolic).evaluate_in_scope(parameters) + num_from_in = ExpressionScalar(ch_int).evaluate_in_scope(parameters) + np.testing.assert_almost_equal(num_from_in, num_from_expr) - # TODO: the following fails - # self.assertEqual(ch_int, symbolic) + # TODO: the following fails even with a lot of assumptions in sympy 1.6 + # self.assertEqual(ch_int, symbolic) diff --git a/tests/pulses/table_pulse_template_tests.py b/tests/pulses/table_pulse_template_tests.py index d08732777..049cbc58d 100644 --- a/tests/pulses/table_pulse_template_tests.py +++ b/tests/pulses/table_pulse_template_tests.py @@ -614,10 +614,10 @@ def test_build_waveform_multi_channel(self): channel_mapping=channel_mapping) self.assertIsInstance(waveform, MultiChannelWaveform) - self.assertEqual(len(waveform._sub_waveforms), 2) + self.assertEqual(len(waveform._wf_pad), 2) channels = {'oh', 'ch'} - for wf in waveform._sub_waveforms: + for wf, _ in waveform._wf_pad.values(): self.assertIsInstance(wf, TableWaveform) self.assertIn(wf._channel_id, channels) channels.remove(wf._channel_id) From 835441d2448f784d836cc80e6f7a26d0182e265c Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Wed, 5 Aug 2020 12:30:10 +0200 Subject: [PATCH 07/11] Add test, doc and a small fix to TableEnty sequence classmethods --- qupulse/pulses/mapping_pulse_template.py | 5 ++- qupulse/pulses/table_pulse_template.py | 27 ++++++++++-- tests/pulses/table_pulse_template_tests.py | 51 +++++++++++++++++++++- 3 files changed, 77 insertions(+), 6 deletions(-) diff --git a/qupulse/pulses/mapping_pulse_template.py b/qupulse/pulses/mapping_pulse_template.py index ee60e26f0..af5b64e00 100644 --- a/qupulse/pulses/mapping_pulse_template.py +++ b/qupulse/pulses/mapping_pulse_template.py @@ -360,7 +360,7 @@ def integral(self) -> Dict[ChannelID, ExpressionScalar]: continue expressions[channel_out] = ExpressionScalar( - ch_integral.sympified_expression.subs(parameter_mapping) + ch_integral.sympified_expression.subs(parameter_mapping, simultaneous=True) ) return expressions @@ -370,7 +370,8 @@ def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: for parameter_name, expression in self.__parameter_mapping.items()} inner = self.__template._as_expression() return { - self.__channel_mapping.get(ch, ch): ExpressionScalar(ch_expr.sympified_expression.subs(parameter_mapping)) + self.__channel_mapping.get(ch, ch): ExpressionScalar(ch_expr.sympified_expression.subs(parameter_mapping, + simultaneous=True)) for ch, ch_expr in inner.items() if self.__channel_mapping.get(ch, ch) is not None } diff --git a/qupulse/pulses/table_pulse_template.py b/qupulse/pulses/table_pulse_template.py index a1b86a84c..7ce492960 100644 --- a/qupulse/pulses/table_pulse_template.py +++ b/qupulse/pulses/table_pulse_template.py @@ -63,20 +63,40 @@ def get_serialization_data(self) -> tuple: @classmethod def _sequence_integral(cls, entry_sequence: Sequence['TableEntry'], expression_extractor: Callable[[Expression], sympy.Expr]) -> ExpressionScalar: + """Returns an expression for the time integral over the complete sequence of table entries. + + Args: + entry_sequence: Sequence of table entries. Assumed to be ordered by time. + expression_extractor: Convert each entry's voltage into a sympy expression. Can be used to select single + channels from a vectorized expression. + + Returns: + Scalar expression for the integral. + """ expr = 0 for first_entry, second_entry in more_itertools.pairwise(entry_sequence): substitutions = {'t0': first_entry.t.sympified_expression, 'v0': expression_extractor(first_entry.v), 't1': second_entry.t.sympified_expression, 'v1': expression_extractor(second_entry.v)} - - expr += first_entry.interp.integral.sympified_expression.subs(substitutions) + expr += first_entry.interp.integral.sympified_expression.subs(substitutions, simultaneous=True) return ExpressionScalar(expr) @classmethod def _sequence_as_expression(cls, entry_sequence: Sequence['TableEntry'], expression_extractor: Callable[[Expression], sympy.Expr], t: sympy.Dummy) -> ExpressionScalar: + """Create an expression out of a sequence of table entries. + + Args: + entry_sequence: Table entries to be represented as an expression. They are assumed to be ordered by time. + expression_extractor: Convert each entry's voltage into a sympy expression. Can be used to select single + channels from a vectorized expression. + t: Time variable + + Returns: + Scalar expression that covers the complete sequence and is zero outside. + """ # args are tested in order piecewise_args = [] @@ -89,7 +109,8 @@ def _sequence_as_expression(cls, entry_sequence: Sequence['TableEntry'], 't': t} time_gate = sympy.And(t0 <= t, t < t1) - interpolation_expr = first_entry.interp.expression.underlying_expression.subs(substitutions) + interpolation_expr = first_entry.interp.expression.underlying_expression.subs(substitutions, + simultaneous=True) piecewise_args.append((interpolation_expr, time_gate)) diff --git a/tests/pulses/table_pulse_template_tests.py b/tests/pulses/table_pulse_template_tests.py index 049cbc58d..ed298be39 100644 --- a/tests/pulses/table_pulse_template_tests.py +++ b/tests/pulses/table_pulse_template_tests.py @@ -2,8 +2,9 @@ import warnings import numpy +import sympy -from qupulse.expressions import Expression +from qupulse.expressions import Expression, ExpressionScalar from qupulse.serialization import Serializer from qupulse.pulses.table_pulse_template import TablePulseTemplate, TableWaveform, TableEntry, TableWaveformEntry, ZeroDurationTablePulseTemplate, AmbiguousTablePulseEntry, concatenate from qupulse.pulses.parameters import ParameterNotProvidedException, ParameterConstraintViolation, ParameterConstraint @@ -38,6 +39,54 @@ def test_unknown_interpolation_strategy(self): with self.assertRaises(KeyError): TableEntry(0, 0, 'foo') + def test_sequence_integral(self): + def get_sympy(v): + return v.sympified_expression + + entries = [TableEntry(0, 0, 'hold'), TableEntry(1, 0, 'hold')] + self.assertEqual(ExpressionScalar(0), TableEntry._sequence_integral(entries, get_sympy)) + + entries = [TableEntry(0, 1, 'hold'), TableEntry(1, 1, 'hold')] + self.assertEqual(ExpressionScalar(1), TableEntry._sequence_integral(entries, get_sympy)) + + entries = [TableEntry(0, 0, 'linear'), TableEntry(1, 1, 'hold')] + self.assertEqual(ExpressionScalar(.5), TableEntry._sequence_integral(entries, get_sympy)) + + entries = [TableEntry('t0', 'a', 'linear'), TableEntry('t1', 'b', 'hold'), TableEntry('t2', 'c', 'hold')] + self.assertEqual(ExpressionScalar('(t1-t0)*(a+b)/2 + (t2-t1)*b'), + TableEntry._sequence_integral(entries, get_sympy)) + + def test_sequence_as_expression(self): + def get_sympy(v): + return v.sympified_expression + + t = sympy.Dummy('t') + + times = { + t: 0.5, + 't0': 0.3, + 't1': 0.7, + 't2': 1.3, + } + + entries = [TableEntry(0, 0, 'hold'), TableEntry(1, 0, 'hold')] + self.assertEqual(ExpressionScalar(0), + TableEntry._sequence_as_expression(entries, get_sympy, t).sympified_expression.subs(times)) + + entries = [TableEntry(0, 1, 'hold'), TableEntry(1, 1, 'hold')] + self.assertEqual(ExpressionScalar(1), + TableEntry._sequence_as_expression(entries, get_sympy, t).sympified_expression.subs(times)) + + entries = [TableEntry(0, 0, 'linear'), TableEntry(1, 1, 'hold')] + self.assertEqual(ExpressionScalar(.5), + TableEntry._sequence_as_expression(entries, get_sympy, t).sympified_expression.subs(times)) + + entries = [TableEntry('t0', 'a', 'linear'), + TableEntry('t1', 'b', 'hold'), + TableEntry('t2', 'c', 'hold')] + self.assertEqual(ExpressionScalar('(a+b)*.5'), + TableEntry._sequence_as_expression(entries, get_sympy, t).sympified_expression.subs(times)) + class TablePulseTemplateTest(unittest.TestCase): def __init__(self, *args, **kwargs): From 3934008536f742c173d8d6cce76f8dd65c774ca3 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Wed, 5 Aug 2020 12:46:53 +0200 Subject: [PATCH 08/11] Additional tests --- tests/utils/time_type_tests.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/utils/time_type_tests.py b/tests/utils/time_type_tests.py index 6273d501c..d5dcecd62 100644 --- a/tests/utils/time_type_tests.py +++ b/tests/utils/time_type_tests.py @@ -5,6 +5,7 @@ import importlib import fractions import random +import math from unittest import mock try: @@ -66,6 +67,14 @@ def test_non_finite_float(self): with self.assertRaisesRegex(ValueError, 'Cannot represent'): qutypes.TimeType.from_float(float('nan')) + def assert_self_init_works(self, time_type): + t = time_type.from_fraction(1, 3) + self.assertIs(t._value, time_type(t)._value) + + def test_self_init(self): + self.assert_self_init_works(self.fallback_qutypes.TimeType) + self.assert_self_init_works(qutypes.TimeType) + def test_fraction_fallback(self): self.assertIs(fractions.Fraction, self.fallback_qutypes.TimeType._InternalType) @@ -180,6 +189,29 @@ def test_comparisons_work(self): def test_comparisons_work_fallback(self): self.assert_comparisons_work(self.fallback_qutypes.TimeType) + def assert_simple_arithmetic_work(self, time_type): + t1 = time_type.from_fraction(19, 3) + as_frac = fractions.Fraction(19, 3) + + self.assertEqual(math.ceil(as_frac), math.ceil(t1)) + self.assertEqual(math.floor(as_frac), math.floor(t1)) + self.assertEqual(math.trunc(as_frac), math.trunc(t1)) + self.assertEqual(3 % as_frac, 3 % t1) + self.assertEqual(-as_frac, -t1) + self.assertIs(t1, +t1) + self.assertEqual(as_frac**3, t1**3) + self.assertEqual(3 ** as_frac, 3 ** t1) + self.assertEqual(3 / as_frac, 3 / t1) + self.assertEqual(as_frac // 3, t1 // 3) + self.assertEqual(3 // as_frac, 3 // t1) + + def test_simple_arithmetic(self): + self.assert_simple_arithmetic_work(qutypes.TimeType) + self.assert_simple_arithmetic_work(self.fallback_qutypes.TimeType) + + def test_time_from_fraction(self): + self.assertEqual(qutypes.time_from_fraction(1, 3), qutypes.TimeType.from_fraction(1, 3)) + def get_some_floats(seed=42, n=1000): rand = random.Random(seed) From 5b58f8fcf53efe2515347f6d21cf348528aa0933 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Wed, 5 Aug 2020 13:07:38 +0200 Subject: [PATCH 09/11] Add as_expression test stubs --- tests/pulses/arithmetic_pulse_template_tests.py | 3 +++ tests/pulses/function_pulse_tests.py | 3 +++ tests/pulses/multi_channel_pulse_template_tests.py | 3 +++ tests/pulses/table_pulse_template_tests.py | 3 +++ 4 files changed, 12 insertions(+) diff --git a/tests/pulses/arithmetic_pulse_template_tests.py b/tests/pulses/arithmetic_pulse_template_tests.py index df2929ca3..5bd4bd2dc 100644 --- a/tests/pulses/arithmetic_pulse_template_tests.py +++ b/tests/pulses/arithmetic_pulse_template_tests.py @@ -118,6 +118,9 @@ def test_integral(self): self.assertEqual(expected_plus, (lhs + rhs).integral) self.assertEqual(expected_minus, (lhs - rhs).integral) + def test_as_expression(self): + raise NotImplementedError() + def test_duration(self): lhs = DummyPulseTemplate(duration=ExpressionScalar('x'), defined_channels={'a', 'b'}, parameter_names={'x', 'y'}) rhs = DummyPulseTemplate(duration=ExpressionScalar('y'), defined_channels={'a', 'c'}, parameter_names={'x', 'z'}) diff --git a/tests/pulses/function_pulse_tests.py b/tests/pulses/function_pulse_tests.py index b33c91e2e..59b76a23e 100644 --- a/tests/pulses/function_pulse_tests.py +++ b/tests/pulses/function_pulse_tests.py @@ -84,6 +84,9 @@ def test_integral(self) -> None: pulse = FunctionPulseTemplate('sin(0.5*t+b)', '2*Tmax') self.assertEqual({'default': Expression('2.0*cos(b) - 2.0*cos(1.0*Tmax+b)')}, pulse.integral) + def test_as_expression(self): + raise NotImplementedError() + class FunctionPulseSerializationTest(SerializableTests, unittest.TestCase): diff --git a/tests/pulses/multi_channel_pulse_template_tests.py b/tests/pulses/multi_channel_pulse_template_tests.py index 72f9b8f4a..c2eb661c3 100644 --- a/tests/pulses/multi_channel_pulse_template_tests.py +++ b/tests/pulses/multi_channel_pulse_template_tests.py @@ -191,6 +191,9 @@ def test_integral(self) -> None: 'C': ExpressionScalar('l')}, pulse.integral) + def test_as_expression(self): + raise NotImplementedError() + class MultiChannelPulseTemplateSequencingTests(unittest.TestCase): def test_build_waveform(self): diff --git a/tests/pulses/table_pulse_template_tests.py b/tests/pulses/table_pulse_template_tests.py index ed298be39..8ad2243cf 100644 --- a/tests/pulses/table_pulse_template_tests.py +++ b/tests/pulses/table_pulse_template_tests.py @@ -475,6 +475,9 @@ def test_integral(self) -> None: self.assertEqual(expected, pulse.integral) + def test_as_expression(self): + raise NotImplementedError() + class TablePulseTemplateConstraintTest(ParameterConstrainerTest): def __init__(self, *args, **kwargs): From d616f3c5c92f69495f416fdd4917d39227118c81 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Wed, 5 Aug 2020 15:27:34 +0200 Subject: [PATCH 10/11] Add more tests and remove requirement of padding value for truncation --- qupulse/_program/waveforms.py | 17 ++++++++--- .../pulses/arithmetic_pulse_template_tests.py | 29 ++++++++++++++++++- tests/pulses/function_pulse_tests.py | 4 ++- tests/pulses/mapping_pulse_template_tests.py | 21 ++++++++++++++ .../multi_channel_pulse_template_tests.py | 21 ++++++++++++++ 5 files changed, 86 insertions(+), 6 deletions(-) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index 31911ae86..5273e2df1 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -386,7 +386,7 @@ def __init__(self, assert ch in waveform.defined_channels if ch not in pad_values: - assert are_durations_compatible(duration, waveform.duration) + assert waveform.duration > duration or are_durations_compatible(duration, waveform.duration) # add default pad that is only required in corner cases of numeric accuracy pad_value = pad_values.get(ch, None) @@ -409,7 +409,7 @@ def from_iterable(cls, sub_waveforms (Iterable( Waveform )): The list of sub waveforms of this MultiChannelWaveform pad_values: Value for padding if desired. None implies :py:meth:`Waveform.last_value`. Channels not - mentioned must have a compatible duration. + mentioned must have a longer or compatible duration. duration: Duration of this waveform. None implies the maximum subwaveform duration. Raises: ValueError, if `sub_waveforms` is empty @@ -434,9 +434,12 @@ def from_iterable(cls, for waveform in sub_waveforms: # if pad is not defined the sub waveform duration needs to be compatible with the overall duration undefined_pad = waveform.defined_channels - pad_values.keys() - if undefined_pad and not are_durations_compatible(duration, waveform.duration): + if waveform.duration > duration: + # truncation is allowed + pass + elif undefined_pad and not are_durations_compatible(duration, waveform.duration): # prepare error message - incompatible_durations.setdefault(waveform.duration, set()).intersection_update(undefined_pad) + incompatible_durations.setdefault(waveform.duration, set()).update(undefined_pad) defined_channels.update(waveform.defined_channels) @@ -532,6 +535,12 @@ def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'W self._duration ) + def __repr__(self): + sub_waveforms = {ch: wf for ch, (wf, _) in self._wf_pad.items()} + pad_values = {ch: pad for ch, (_, pad) in self._wf_pad.items()} + duration = self.duration + return f"{type(self).__name__}(sub_waveforms={sub_waveforms}, pad_values={pad_values}, duration={duration})" + class RepetitionWaveform(Waveform): """This class allows putting multiple PulseTemplate together in one waveform on the hardware.""" diff --git a/tests/pulses/arithmetic_pulse_template_tests.py b/tests/pulses/arithmetic_pulse_template_tests.py index 5bd4bd2dc..11b7d65d7 100644 --- a/tests/pulses/arithmetic_pulse_template_tests.py +++ b/tests/pulses/arithmetic_pulse_template_tests.py @@ -119,7 +119,34 @@ def test_integral(self): self.assertEqual(expected_minus, (lhs - rhs).integral) def test_as_expression(self): - raise NotImplementedError() + integrals_lhs = dict(a=ExpressionScalar('a_lhs'), b=ExpressionScalar('b')) + integrals_rhs = dict(a=ExpressionScalar('a_rhs'), c=ExpressionScalar('c')) + + duration = 4 + t = DummyPulseTemplate._AS_EXPRESSION_TIME + expr_lhs = {ch: i * t / duration for ch, i in integrals_lhs.items()} + expr_rhs = {ch: i * t / duration for ch, i in integrals_rhs.items()} + + lhs = DummyPulseTemplate(duration=duration, defined_channels={'a', 'b'}, + parameter_names={'x', 'y'}, integrals=integrals_lhs) + rhs = DummyPulseTemplate(duration=duration, defined_channels={'a', 'c'}, + parameter_names={'x', 'z'}, integrals=integrals_rhs) + + expected_added = { + 'a': expr_lhs['a'] + expr_rhs['a'], + 'b': expr_lhs['b'], + 'c': expr_rhs['c'] + } + added_expr = (lhs + rhs)._as_expression() + self.assertEqual(expected_added, added_expr) + + subs_expr = (lhs - rhs)._as_expression() + expected_subs = { + 'a': expr_lhs['a'] - expr_rhs['a'], + 'b': expr_lhs['b'], + 'c': -expr_rhs['c'] + } + self.assertEqual(expected_subs, subs_expr) def test_duration(self): lhs = DummyPulseTemplate(duration=ExpressionScalar('x'), defined_channels={'a', 'b'}, parameter_names={'x', 'y'}) diff --git a/tests/pulses/function_pulse_tests.py b/tests/pulses/function_pulse_tests.py index 59b76a23e..ba4d214da 100644 --- a/tests/pulses/function_pulse_tests.py +++ b/tests/pulses/function_pulse_tests.py @@ -85,7 +85,9 @@ def test_integral(self) -> None: self.assertEqual({'default': Expression('2.0*cos(b) - 2.0*cos(1.0*Tmax+b)')}, pulse.integral) def test_as_expression(self): - raise NotImplementedError() + pulse = FunctionPulseTemplate('sin(0.5*t+b)', '2*Tmax') + expr = sympy.sin(0.5 * pulse._AS_EXPRESSION_TIME + sympy.sympify('b')) + self.assertEqual({'default': Expression.make(expr)}, pulse._as_expression()) class FunctionPulseSerializationTest(SerializableTests, unittest.TestCase): diff --git a/tests/pulses/mapping_pulse_template_tests.py b/tests/pulses/mapping_pulse_template_tests.py index 2a70c78da..2c9d4355f 100644 --- a/tests/pulses/mapping_pulse_template_tests.py +++ b/tests/pulses/mapping_pulse_template_tests.py @@ -253,6 +253,26 @@ def test_integral(self) -> None: self.assertEqual({'a': Expression('2*f'), 'B': Expression('-3.2*f+2.3')}, pulse.integral) + def test_as_expression(self): + from sympy.abc import f, k, b + duration = 5 + dummy = DummyPulseTemplate(defined_channels={'A', 'B', 'C'}, + parameter_names={'k', 'f', 'b'}, + integrals={'A': Expression(2 * k), + 'B': Expression(-3.2*f+b), + 'C': Expression(1)}, duration=duration) + t = DummyPulseTemplate._AS_EXPRESSION_TIME + dummy_expr = {ch: i * t / duration for ch, i in dummy._integrals.items()} + pulse = MappingPulseTemplate(dummy, parameter_mapping={'k': 'f', 'b': 2.3}, channel_mapping={'A': 'a', + 'C': None}, + allow_partial_parameter_mapping=True) + + expected = { + 'a': Expression(2*f*t/duration), + 'B': Expression((-3.2*f + 2.3)*t/duration), + } + self.assertEqual(expected, pulse._as_expression()) + def test_duration(self): seconds2ns = 1e9 pulse_duration = 1.0765001496284785e-07 @@ -507,6 +527,7 @@ def test_deserialize(self) -> None: self.assertEqual(data['parameter_constraints'], [str(pc) for pc in deserialized.parameter_constraints]) self.assertIs(deserialized.template, dummy_pt) + class MappingPulseTemplateRegressionTests(unittest.TestCase): def test_issue_451(self): from qupulse.pulses import TablePT, SequencePT, AtomicMultiChannelPT diff --git a/tests/pulses/multi_channel_pulse_template_tests.py b/tests/pulses/multi_channel_pulse_template_tests.py index c2eb661c3..b21c04639 100644 --- a/tests/pulses/multi_channel_pulse_template_tests.py +++ b/tests/pulses/multi_channel_pulse_template_tests.py @@ -11,6 +11,7 @@ from qupulse.pulses.parameters import ParameterConstraint, ParameterConstraintViolation, ConstantParameter from qupulse.expressions import ExpressionScalar, Expression from qupulse._program.transformation import LinearTransformation, chain_transformations +from qupulse.utils.types import TimeType from tests.pulses.sequencing_dummies import DummyPulseTemplate, DummyWaveform from tests.serialization_dummies import DummySerializer @@ -262,6 +263,26 @@ def test_get_measurement_windows(self): meas_windows = pt.get_measurement_windows({}, measurement_mapping) self.assertEqual(expected, meas_windows) + def test_build_waveform_padding_and_truncation(self): + wfs = [DummyWaveform(duration=1.1, defined_channels={'A'}), + DummyWaveform(duration=1.2, defined_channels={'B'}), + DummyWaveform(duration=0.9, defined_channels={'C'})] + + sts = [DummyPulseTemplate(duration='ta', defined_channels={'A'}, waveform=wfs[0]), + DummyPulseTemplate(duration='tb', defined_channels={'B'}, waveform=wfs[1]), + DummyPulseTemplate(duration='tc', defined_channels={'C'}, waveform=wfs[2])] + + pt = AtomicMultiChannelPulseTemplate(*sts, duration='t_dur', pad_values={'C': 'c'}) + + wf = pt.build_waveform(parameters={'a': 1., 'b': 2., 'c': 3., + 't_dur': 1.1, 'ta': 1.1, + 'tb': 1.2, 'tc': 0.9}, channel_mapping={'A': 'A', 'B': 'B', 'C': 'C'}) + + expected_wf = MultiChannelWaveform(dict(zip('ABC', wfs)), pad_values={'C': 3.}, + duration=TimeType.from_float(1.1)) + self.assertEqual(expected_wf, wf) + + class AtomicMultiChannelPulseTemplateSerializationTests(SerializableTests, unittest.TestCase): From 2b0eaac312ab1388e40e12fa8e5f5d6dca9f3d25 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Wed, 5 Aug 2020 18:59:53 +0200 Subject: [PATCH 11/11] Sympy expression fixes. Still failing, missing adn incomplete tests --- qupulse/_program/waveforms.py | 6 +- qupulse/pulses/point_pulse_template.py | 13 +++- qupulse/pulses/table_pulse_template.py | 42 +++++++++---- qupulse/utils/sympy.py | 28 +++++++-- tests/pulses/point_pulse_template_tests.py | 71 ++++++++++++++-------- tests/pulses/table_pulse_template_tests.py | 42 ++++++++++--- tests/utils/sympy_tests.py | 11 ++++ 7 files changed, 156 insertions(+), 57 deletions(-) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index 5273e2df1..a91415cce 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -89,7 +89,7 @@ def get_sampled(self, if np.any(sample_times[:-1] >= sample_times[1:]): raise ValueError('The sample times are not monotonously increasing') - if sample_times[0] < 0 or sample_times[-1] > self.duration: + if sample_times[0] < 0 or sample_times[-1] > float(self.duration): raise ValueError('The sample times are not in the range [0, duration]') if channel not in self.defined_channels: raise KeyError('Channel not defined in this waveform: {}'.format(channel)) @@ -153,8 +153,8 @@ def last_value(self, channel) -> float: class TableWaveformEntry(NamedTuple('TableWaveformEntry', [('t', float), ('v', float), ('interp', InterpolationStrategy)])): - def __init__(self, t: float, v: float, interp: InterpolationStrategy): - if not callable(interp): + def __init__(self, t: float, v: float, interp: Optional[InterpolationStrategy]): + if not callable(interp) or interp is None: raise TypeError('{} is neither callable nor of type InterpolationStrategy'.format(interp)) diff --git a/qupulse/pulses/point_pulse_template.py b/qupulse/pulses/point_pulse_template.py index 63939aa33..8f5b561ff 100644 --- a/qupulse/pulses/point_pulse_template.py +++ b/qupulse/pulses/point_pulse_template.py @@ -146,7 +146,9 @@ def value_trafo(v): return v.underlying_expression[i] except TypeError: return sympy.IndexedBase(Broadcast(v.underlying_expression, shape))[i] - expressions[channel] = TableEntry._sequence_integral(self._entries, expression_extractor=value_trafo) + pre_entry = TableEntry(0, self._entries[0].v, None) + entries = [pre_entry] + self._entries + expressions[channel] = TableEntry._sequence_integral(entries, expression_extractor=value_trafo) return expressions def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: @@ -160,8 +162,13 @@ def value_trafo(v): return v.underlying_expression[i] except TypeError: return sympy.IndexedBase(Broadcast(v.underlying_expression, shape))[i] - - pw = TableEntry._sequence_as_expression(self._entries, expression_extractor=value_trafo, t=t) + pre_value = value_trafo(self._entries[0].v) + post_value = value_trafo(self._entries[-1].v) + pw = TableEntry._sequence_as_expression(self._entries, + expression_extractor=value_trafo, + t=t, + post_value=post_value, + pre_value=pre_value) expressions[channel] = pw return expressions diff --git a/qupulse/pulses/table_pulse_template.py b/qupulse/pulses/table_pulse_template.py index 7ce492960..f94794e9a 100644 --- a/qupulse/pulses/table_pulse_template.py +++ b/qupulse/pulses/table_pulse_template.py @@ -39,13 +39,13 @@ class TableEntry(NamedTuple('TableEntry', [('t', ExpressionScalar), ('v', Expression), - ('interp', InterpolationStrategy)])): + ('interp', Optional[InterpolationStrategy])])): __slots__ = () - def __new__(cls, t: ValueInInit, v: ValueInInit, interp: Union[str, InterpolationStrategy]='default'): + def __new__(cls, t: ValueInInit, v: ValueInInit, interp: Optional[Union[str, InterpolationStrategy]]='default'): if interp in TablePulseTemplate.interpolation_strategies: interp = TablePulseTemplate.interpolation_strategies[interp] - if not isinstance(interp, InterpolationStrategy): + if interp is not None and not isinstance(interp, InterpolationStrategy): raise KeyError(interp, 'is not a valid interpolation strategy') return super().__new__(cls, ExpressionScalar.make(t), @@ -58,7 +58,8 @@ def instantiate(self, parameters: Dict[str, numbers.Real]) -> TableWaveformEntry self.interp) def get_serialization_data(self) -> tuple: - return self.t.get_serialization_data(), self.v.get_serialization_data(), str(self.interp) + interp = None if self.interp is None else str(self.interp) + return self.t.get_serialization_data(), self.v.get_serialization_data(), interp @classmethod def _sequence_integral(cls, entry_sequence: Sequence['TableEntry'], @@ -79,13 +80,15 @@ def _sequence_integral(cls, entry_sequence: Sequence['TableEntry'], 'v0': expression_extractor(first_entry.v), 't1': second_entry.t.sympified_expression, 'v1': expression_extractor(second_entry.v)} - expr += first_entry.interp.integral.sympified_expression.subs(substitutions, simultaneous=True) + expr += second_entry.interp.integral.sympified_expression.subs(substitutions, simultaneous=True) return ExpressionScalar(expr) @classmethod def _sequence_as_expression(cls, entry_sequence: Sequence['TableEntry'], expression_extractor: Callable[[Expression], sympy.Expr], - t: sympy.Dummy) -> ExpressionScalar: + t: sympy.Dummy, + pre_value: Optional[sympy.Expr], + post_value: Optional[sympy.Expr]) -> ExpressionScalar: """Create an expression out of a sequence of table entries. Args: @@ -93,6 +96,8 @@ def _sequence_as_expression(cls, entry_sequence: Sequence['TableEntry'], expression_extractor: Convert each entry's voltage into a sympy expression. Can be used to select single channels from a vectorized expression. t: Time variable + pre_value: If not None all t values smaller than the first entry's time give this value + post_value: If not None all t values larger than the last entry's time give this value Returns: Scalar expression that covers the complete sequence and is zero outside. @@ -109,12 +114,16 @@ def _sequence_as_expression(cls, entry_sequence: Sequence['TableEntry'], 't': t} time_gate = sympy.And(t0 <= t, t < t1) - interpolation_expr = first_entry.interp.expression.underlying_expression.subs(substitutions, + interpolation_expr = second_entry.interp.expression.underlying_expression.subs(substitutions, simultaneous=True) piecewise_args.append((interpolation_expr, time_gate)) - piecewise_args.append((0, True)) + if pre_value is not None: + piecewise_args.append((pre_value, t < entry_sequence[0].t.sympified_expression)) + if post_value is not None: + piecewise_args.append((post_value, t >= entry_sequence[-1].t.sympified_expression)) + return ExpressionScalar(sympy.Piecewise(*piecewise_args)) @@ -200,16 +209,17 @@ def __init__(self, entries: Dict[ChannelID, Sequence[EntryInInit]], self._register(registry=registry) def _add_entry(self, channel, new_entry: TableEntry) -> None: + ch_entries = self._entries[channel] # comparisons with Expression can yield None -> use 'is True' and 'is False' if (new_entry.t < 0) is True: raise ValueError('Time parameter number {} of channel {} is negative.'.format( - len(self._entries[channel]), channel)) + len(ch_entries), channel)) - for previous_entry in self._entries[channel]: + for previous_entry in ch_entries: if (new_entry.t < previous_entry.t) is True: raise ValueError('Time parameter number {} of channel {} is smaller than a previous one'.format( - len(self._entries[channel]), channel)) + len(ch_entries), channel)) self._entries[channel].append(new_entry) @@ -406,6 +416,9 @@ def is_valid_interpolation_strategy(inter): def integral(self) -> Dict[ChannelID, ExpressionScalar]: expressions = dict() for channel, channel_entries in self._entries.items(): + pre_entry = TableEntry(0, channel_entries[0].v, None) + post_entry = TableEntry(self.duration, channel_entries[-1].v, 'hold') + channel_entries = [pre_entry] + channel_entries + [post_entry] expressions[channel] = TableEntry._sequence_integral(channel_entries, lambda v: v.sympified_expression) return expressions @@ -413,9 +426,14 @@ def integral(self) -> Dict[ChannelID, ExpressionScalar]: def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: expressions = dict() for channel, channel_entries in self._entries.items(): + pre_value = channel_entries[0].v.sympified_expression + post_value = channel_entries[-1].v.sympified_expression + expressions[channel] = TableEntry._sequence_as_expression(channel_entries, lambda v: v.sympified_expression, - t=self._AS_EXPRESSION_TIME) + t=self._AS_EXPRESSION_TIME, + pre_value=pre_value, + post_value=post_value) return expressions diff --git a/qupulse/utils/sympy.py b/qupulse/utils/sympy.py index 0e43df6a9..2bbc6adf8 100644 --- a/qupulse/utils/sympy.py +++ b/qupulse/utils/sympy.py @@ -134,6 +134,22 @@ def numpy_compatible_mul(*args) -> Union[sympy.Mul, sympy.Array]: return sympy.Mul(*args) +def numpy_compatible_add(*args) -> Union[sympy.Add, sympy.Array]: + if any(isinstance(a, sympy.NDimArray) for a in args): + result = 0 + for a in args: + result = result + (numpy.array(a.tolist()) if isinstance(a, sympy.NDimArray) else a) + return sympy.Array(result) + else: + return sympy.Add(*args) + + +_NUMPY_COMPATIBLE = { + sympy.Add: numpy_compatible_add, + sympy.Mul: numpy_compatible_mul +} + + def numpy_compatible_ceiling(input_value: Any) -> Any: if isinstance(input_value, numpy.ndarray): return numpy.ceil(input_value).astype(numpy.int64) @@ -163,6 +179,8 @@ def sympify(expr: Union[str, Number, sympy.Expr, numpy.str_], **kwargs) -> sympy # putting numpy.str_ in sympy.sympify behaves unexpected in version 1.1.1 # It seems to ignore the locals argument expr = str(expr) + if isinstance(expr, (tuple, list)): + expr = numpy.array(expr) try: return sympy.sympify(expr, **kwargs, locals=sympify_namespace) except TypeError as err: @@ -218,20 +236,18 @@ def _recursive_substitution(expression: sympy.Expr, substitutions: Dict[sympy.Symbol, sympy.Expr]) -> sympy.Expr: if not expression.free_symbols: return expression - elif expression.func is sympy.Symbol: + elif expression.func in (sympy.Symbol, sympy.Dummy): return substitutions.get(expression, expression) - elif expression.func is sympy.Mul: - func = numpy_compatible_mul - else: - func = expression.func + func = _NUMPY_COMPATIBLE.get(expression.func, expression.func) substitutions = {s: substitutions.get(s, s) for s in get_free_symbols(expression)} return func(*(_recursive_substitution(arg, substitutions) for arg in expression.args)) def recursive_substitution(expression: sympy.Expr, substitutions: Dict[str, Union[sympy.Expr, numpy.ndarray, str]]) -> sympy.Expr: - substitutions = {sympy.Symbol(k): sympify(v) for k, v in substitutions.items()} + substitutions = {k if isinstance(k, (sympy.Symbol, sympy.Dummy)) else sympy.Symbol(k): sympify(v) + for k, v in substitutions.items()} for s in get_free_symbols(expression): substitutions.setdefault(s, s) return _recursive_substitution(expression, substitutions) diff --git a/tests/pulses/point_pulse_template_tests.py b/tests/pulses/point_pulse_template_tests.py index 527738afb..07e858243 100644 --- a/tests/pulses/point_pulse_template_tests.py +++ b/tests/pulses/point_pulse_template_tests.py @@ -76,28 +76,28 @@ def test_parameter_names(self): def test_integral(self) -> None: pulse = PointPulseTemplate( - [(1, (2, 'b'), 'linear'), - (3, (0, 0), 'jump'), - (4, (2, 'c'), 'hold'), + [(1, (2, 'b'), 'hold'), + (3, (0, 0), 'linear'), + (4, (2, 'c'), 'jump'), (5, (8, 'd'), 'hold')], [0, 'other_channel'] ) - self.assertEqual({0: ExpressionScalar('6'), - 'other_channel': ExpressionScalar('b + 2*c')}, + self.assertEqual({0: ExpressionScalar('2 + 6'), + 'other_channel': ExpressionScalar('b + b + 2*c')}, pulse.integral) pulse = PointPulseTemplate( - [(1, ('2', 'b'), 'linear'), ('t0', (0, 0), 'jump'), (4, (2.0, 'c'), 'hold'), ('g', (8, 'd'), 'hold')], + [(1, ('2', 'b'), 'hold'), ('t0', (0, 0), 'linear'), (4, (2.0, 'c'), 'jump'), ('g', (8, 'd'), 'hold')], ['symbolic', 1] ) - self.assertEqual({'symbolic': ExpressionScalar('2.0*g - 1.0*t0 - 1.0'), - 1: ExpressionScalar('b*(t0 - 1) / 2 + c*(g - 4) + c*(-t0 + 4)')}, + self.assertEqual({'symbolic': ExpressionScalar('2 + 2.0*g - 1.0*t0 - 1.0'), + 1: ExpressionScalar('b + b*(t0 - 1) / 2 + c*(g - 4) + c*(-t0 + 4)')}, pulse.integral) ppt = PointPulseTemplate([(0, 0), ('t_init', 0)], ['X', 'Y']) self.assertEqual(ppt.integral, {'X': 0, 'Y': 0}) - ppt = PointPulseTemplate([(0., 'a', 'linear'), ('t_1', 'b'), ('t_2', (0, 0))], ('X', 'Y')) + ppt = PointPulseTemplate([(0., 'a'), ('t_1', 'b', 'linear'), ('t_2', (0, 0))], ('X', 'Y')) parameters = {'a': (3.4, 4.1), 'b': 4, 't_1': 2, 't_2': 5} integral = {ch: v.evaluate_in_scope(parameters) for ch, v in ppt.integral.items()} self.assertEqual({'X': 2 * (3.4 + 4) / 2 + (5 - 2) * 4, @@ -300,35 +300,58 @@ def test_serializer_integration_old(self): class PointPulseExpressionIntegralTests(unittest.TestCase): + def setUp(self): + self.template = PointPulseTemplate(**PointPulseTemplateSerializationTests().make_kwargs()) + self.parameter_sets = [ + {'foo': 1., 'hugo': 2., 'sudo': 3., 'A': 4., 'B': 5., 'a': 6., 'ilse': 7., 'k': 8.}, + {'foo': 1.1, 'hugo': 2.6, 'sudo': 2.7, 'A': np.array([3., 4.]), 'B': 5., 'a': 6., 'ilse': 7., 'k': 8.}, + ] + def test_integral_as_expression_compatible(self): import sympy - from sympy import Q - template = PointPulseTemplate(**PointPulseTemplateSerializationTests().make_kwargs()) - - t = template._AS_EXPRESSION_TIME - as_expression = template._as_expression() - integral = template.integral - duration = template.duration.underlying_expression - self.assertEqual(template.defined_channels, integral.keys()) - self.assertEqual(template.defined_channels, as_expression.keys()) + t = self.template._AS_EXPRESSION_TIME + as_expression = self.template._as_expression() + integral = self.template.integral + duration = self.template.duration.underlying_expression - parameter_sets = [ - {'foo': 1., 'hugo': 2., 'sudo': 3., 'A': 4., 'B': 5., 'a': 6.}, - {'foo': 1.1, 'hugo': 2.6, 'sudo': 2.7, 'A': np.array([3., 4.]), 'B': 5., 'a': 6.}, - ] + self.assertEqual(self.template.defined_channels, integral.keys()) + self.assertEqual(self.template.defined_channels, as_expression.keys()) - for channel in template.defined_channels: + for channel in self.template.defined_channels: ch_expr = as_expression[channel].underlying_expression ch_int = integral[channel].underlying_expression symbolic = sympy.integrate(ch_expr, (t, 0, duration)) symbolic = sympy.simplify(symbolic) - for parameters in parameter_sets: + for parameters in self.parameter_sets: num_from_expr = ExpressionScalar(symbolic).evaluate_in_scope(parameters) num_from_in = ExpressionScalar(ch_int).evaluate_in_scope(parameters) np.testing.assert_almost_equal(num_from_in, num_from_expr) # TODO: the following fails even with a lot of assumptions in sympy 1.6 # self.assertEqual(ch_int, symbolic) + + def test_as_expression_wf_and_sample_compatible(self): + as_expression = self.template._as_expression() + + for parameters in self.parameter_sets: + wf = self.template.build_waveform(parameters, {c: c for c in self.template.defined_channels}) + + ts = np.linspace(0, float(wf.duration), num=33) + sampled = {ch: wf.get_sampled(ch, ts) for ch in self.template.defined_channels} + + from_expr = {} + for ch, expected_vs in sampled.items(): + ch_expr = as_expression[ch] + + ch_from_expr = [] + for t, expected in zip(ts, expected_vs): + result_expr = ch_expr.evaluate_symbolic({**parameters, self.template._AS_EXPRESSION_TIME: t}) + ch_from_expr.append(result_expr.sympified_expression) + from_expr[ch] = ch_from_expr + + np.testing.assert_almost_equal(expected_vs, ch_from_expr) + + diff --git a/tests/pulses/table_pulse_template_tests.py b/tests/pulses/table_pulse_template_tests.py index 8ad2243cf..580fd9b7a 100644 --- a/tests/pulses/table_pulse_template_tests.py +++ b/tests/pulses/table_pulse_template_tests.py @@ -69,15 +69,15 @@ def get_sympy(v): 't2': 1.3, } - entries = [TableEntry(0, 0, 'hold'), TableEntry(1, 0, 'hold')] + entries = [TableEntry(0, 0, None), TableEntry(1, 0, 'hold')] self.assertEqual(ExpressionScalar(0), TableEntry._sequence_as_expression(entries, get_sympy, t).sympified_expression.subs(times)) - entries = [TableEntry(0, 1, 'hold'), TableEntry(1, 1, 'hold')] + entries = [TableEntry(0, 1, None), TableEntry(1, 1, 'hold')] self.assertEqual(ExpressionScalar(1), TableEntry._sequence_as_expression(entries, get_sympy, t).sympified_expression.subs(times)) - entries = [TableEntry(0, 0, 'linear'), TableEntry(1, 1, 'hold')] + entries = [TableEntry(0, 0, None), TableEntry(1, 1, 'linear')] self.assertEqual(ExpressionScalar(.5), TableEntry._sequence_as_expression(entries, get_sympy, t).sympified_expression.subs(times)) @@ -466,17 +466,41 @@ def test_identifier(self) -> None: self.assertEqual(pulse.identifier, identifier) def test_integral(self) -> None: - pulse = TablePulseTemplate(entries={0: [(1, 2, 'linear'), (3, 0, 'jump'), (4, 2, 'hold'), (5, 8, 'hold')], - 'other_channel': [(0, 7, 'linear'), (2, 0, 'hold'), (10, 0)], - 'symbolic': [(3, 'a', 'hold'), ('b', 4, 'linear'), ('c', Expression('d'), 'hold')]}) - expected = {0: Expression('6'), + pulse = TablePulseTemplate(entries={0: [(1, 2), (3, 0, 'linear'), (4, 2, 'jump'), (5, 8, 'hold')], + 'other_channel': [(0, 7), (2, 0, 'linear'), (10, 0)], + 'symbolic': [(3, 'a'), ('b', 4, 'hold'), ('c', Expression('d'), 'linear')]}) + expected = {0: Expression('2 + 2 + 2 + 2 + (Max(c, 10) - 5) * 8'), 'other_channel': Expression(7), - 'symbolic': Expression('(b-3)*a + (c-b)*(d+4) / 2')} + 'symbolic': Expression('3 * a + (b-3)*a + (c-b)*(d+4) / 2 + (Max(10, c) - c) * d')} self.assertEqual(expected, pulse.integral) def test_as_expression(self): - raise NotImplementedError() + pulse = TablePulseTemplate(entries={0: [(0, 0), (1, 2), (3, 0, 'linear'), (4, 2, 'jump'), (5, 8, 'hold')], + 'other_channel': [(0, 7), (2, 0, 'linear'), (10, 0)], + 'symbolic': [(3, 'a'), ('b', 4, 'hold'), + ('c', Expression('d'), 'linear')]}) + parameters = dict(a=2., b=4, c=9, d=8) + wf = pulse.build_waveform(parameters, channel_mapping={0: 0, + 'other_channel': 'other_channel', + 'symbolic': 'symbolic'}) + expr = pulse._as_expression() + ts = numpy.linspace(0, float(wf.duration), num=33) + sampled = {ch: wf.get_sampled(ch, ts) for ch in pulse.defined_channels} + + from_expr = {} + for ch, expected_vs in sampled.items(): + ch_expr = expr[ch] + + ch_from_expr = [] + for t, expected in zip(ts, expected_vs): + params = {**parameters, TablePulseTemplate._AS_EXPRESSION_TIME: t} + result = ch_expr.sympified_expression.subs(params, simultaneous=True) + ch_from_expr.append(result) + from_expr[ch] = ch_from_expr + + numpy.testing.assert_almost_equal(expected_vs, ch_from_expr) + class TablePulseTemplateConstraintTest(ParameterConstrainerTest): diff --git a/tests/utils/sympy_tests.py b/tests/utils/sympy_tests.py index def2f9d2b..32d37a474 100644 --- a/tests/utils/sympy_tests.py +++ b/tests/utils/sympy_tests.py @@ -14,6 +14,7 @@ a_ = IndexedBase(a) b_ = IndexedBase(b) +dummy_a = sympy.Dummy('a') from qupulse.utils.sympy import sympify as qc_sympify, substitute_with_eval, recursive_substitution, Len,\ evaluate_lambdified, evaluate_compiled, get_most_simple_representation, get_variables, get_free_symbols,\ @@ -50,6 +51,11 @@ (Sum(a_[i], (i, 0, Len(a) - 1)), {'a': sympy.Array([1, 2, 3])}, 6), ] +dummy_substitution_cases = [ + (a * dummy_a + sympy.exp(dummy_a), {'a': b}, b * dummy_a + sympy.exp(dummy_a)), + (a * dummy_a + sympy.exp(dummy_a), {dummy_a: b}, a * b + sympy.exp(b)), +] + ##################################################### SYMPIFY ########################################################## simple_sympify = [ @@ -199,6 +205,11 @@ def test_full_featured_cases(self): result = self.substitute(expr, subs) self.assertEqual(result, expected) + def test_dummy_subs(self): + for expr, subs, expected in dummy_substitution_cases: + result = self.substitute(expr, subs) + self.assertEqual(result, expected) + class SubstituteWithEvalTests(SubstitutionTests): def substitute(self, expression: sympy.Expr, substitutions: dict):