From 7cfcb7437ca0a4f63781aa55155b49c2b07c9a58 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Fri, 31 May 2024 20:43:40 +0200 Subject: [PATCH 01/35] equality handling SimpleExpression; file refactor --- qupulse/expressions/simple.py | 126 ++++++++++++++++++++++++++++++++++ qupulse/program/__init__.py | 97 ++------------------------ qupulse/program/linspace.py | 4 +- qupulse/program/waveforms.py | 99 ++++++++++++++++++++------ qupulse/utils/__init__.py | 21 ++++-- 5 files changed, 224 insertions(+), 123 deletions(-) create mode 100644 qupulse/expressions/simple.py diff --git a/qupulse/expressions/simple.py b/qupulse/expressions/simple.py new file mode 100644 index 00000000..5504b7db --- /dev/null +++ b/qupulse/expressions/simple.py @@ -0,0 +1,126 @@ +import numpy as np +from numbers import Real, Number +from typing import Optional, Union, Sequence, ContextManager, Mapping, Tuple, Generic, TypeVar, Iterable, Dict +from dataclasses import dataclass + +from functools import total_ordering +from qupulse.utils.sympy import _lambdify_modules +from qupulse.expressions import sympy as sym_expr, Expression +from qupulse.utils.types import MeasurementWindow, TimeType, FrozenMapping + + +NumVal = TypeVar('NumVal', bound=Real) + + +@total_ordering +@dataclass +class SimpleExpression(Generic[NumVal]): + """This is a potential hardware evaluable expression of the form + + C + C1*R1 + C2*R2 + ... + where R1, R2, ... are potential runtime parameters. + + The main use case is the expression of for loop dependent variables where the Rs are loop indices. There the + expressions can be calculated via simple increments. + """ + + base: NumVal + offsets: Mapping[str, NumVal] + + def __post_init__(self): + assert isinstance(self.offsets, Mapping) + + def value(self, scope: Mapping[str, NumVal]) -> NumVal: + value = self.base + for name, factor in self.offsets.items(): + value += scope[name] * factor + return value + + def __abs__(self): + return abs(self.base)+sum([abs(o) for o in self.offsets.values()]) + + def __eq__(self, other): + #there is no good way to compare it without having a value, + #but cannot require more parameters in magic method? + #so have this weird full equality for now which doesn logically make sense + #in most cases to catch unintended consequences + + if isinstance(other, (float, int, TimeType)): + return self.base==other and all([o==other for o in self.offsets]) + + if type(other) == type(self): + if len(self.offsets)!=len(other.offsets): return False + return self.base==other.base and all([o1==o2 for o1,o2 in zip(self.offsets,other.offsets)]) + + return NotImplemented + + def __gt__(self, other): + #there is no good way to compare it without having a value, + #but cannot require more parameters in magic method? + #so have this weird full equality for now which doesn logically make sense + #in most cases to catch unintended consequences + + if isinstance(other, (float, int, TimeType)): + return self.base>other and all([o>other for o in self.offsets.values()]) + + if type(other) == type(self): + if len(self.offsets)!=len(other.offsets): return False + return self.base>other.base and all([o1>o2 for o1,o2 in zip(self.offsets.values(),other.offsets.values())]) + + return NotImplemented + + + def __add__(self, other): + if isinstance(other, (float, int, TimeType)): + return SimpleExpression(self.base + other, self.offsets) + + if type(other) == type(self): + offsets = self.offsets.copy() + for name, value in other.offsets.items(): + offsets[name] = value + offsets.get(name, 0) + return SimpleExpression(self.base + other.base, offsets) + + return NotImplemented + + def __radd__(self, other): + return self.__add__(other) + + def __sub__(self, other): + return self.__add__(-other) + + def __rsub__(self, other): + (-self).__add__(other) + + def __neg__(self): + return SimpleExpression(-self.base, {name: -value for name, value in self.offsets.items()}) + + def __mul__(self, other: NumVal): + if isinstance(other, (float, int, TimeType)): + return SimpleExpression(self.base * other, {name: other * value for name, value in self.offsets.items()}) + + return NotImplemented + + def __rmul__(self, other): + return self.__mul__(other) + + def __truediv__(self, other): + inv = 1 / other + return self.__mul__(inv) + + @property + def free_symbols(self): + return () + + def _sympy_(self): + return self + + def replace(self, r, s): + return self + + def evaluate_in_scope_(self, *args, **kwargs): + # TODO: remove. It is currently required to avoid nesting this class in an expression for the MappedScope + # We can maybe replace is with a HardwareScope or something along those lines + return self + + +_lambdify_modules.append({'SimpleExpression': SimpleExpression}) diff --git a/qupulse/program/__init__.py b/qupulse/program/__init__.py index 611a96fc..82b28148 100644 --- a/qupulse/program/__init__.py +++ b/qupulse/program/__init__.py @@ -1,101 +1,12 @@ -import contextlib -from abc import ABC, abstractmethod -from dataclasses import dataclass from typing import Optional, Union, Sequence, ContextManager, Mapping, Tuple, Generic, TypeVar, Iterable, Dict -from numbers import Real, Number - -import numpy as np +from typing import Protocol, runtime_checkable from qupulse._program.waveforms import Waveform -from qupulse.utils.types import MeasurementWindow, TimeType, FrozenMapping +from qupulse.utils.types import MeasurementWindow, TimeType from qupulse._program.volatile import VolatileRepetitionCount from qupulse.parameter_scope import Scope -from qupulse.expressions import sympy as sym_expr, Expression -from qupulse.utils.sympy import _lambdify_modules - -from typing import Protocol, runtime_checkable - - -NumVal = TypeVar('NumVal', bound=Real) - - -@dataclass -class SimpleExpression(Generic[NumVal]): - """This is a potential hardware evaluable expression of the form - - C + C1*R1 + C2*R2 + ... - where R1, R2, ... are potential runtime parameters. - - The main use case is the expression of for loop dependent variables where the Rs are loop indices. There the - expressions can be calculated via simple increments. - """ - - base: NumVal - offsets: Mapping[str, NumVal] - - def __post_init__(self): - assert isinstance(self.offsets, Mapping) - - def value(self, scope: Mapping[str, NumVal]) -> NumVal: - value = self.base - for name, factor in self.offsets: - value += scope[name] * factor - return value - - def __add__(self, other): - if isinstance(other, (float, int, TimeType)): - return SimpleExpression(self.base + other, self.offsets) - - if type(other) == type(self): - offsets = self.offsets.copy() - for name, value in other.offsets.items(): - offsets[name] = value + offsets.get(name, 0) - return SimpleExpression(self.base + other.base, offsets) - - return NotImplemented - - def __radd__(self, other): - return self.__add__(other) - - def __sub__(self, other): - return self.__add__(-other) - - def __rsub__(self, other): - (-self).__add__(other) - - def __neg__(self): - return SimpleExpression(-self.base, {name: -value for name, value in self.offsets.items()}) - - def __mul__(self, other: NumVal): - if isinstance(other, (float, int, TimeType)): - return SimpleExpression(self.base * other, {name: other * value for name, value in self.offsets.items()}) - - return NotImplemented - - def __rmul__(self, other): - return self.__mul__(other) - - def __truediv__(self, other): - inv = 1 / other - return self.__mul__(inv) - - @property - def free_symbols(self): - return () - - def _sympy_(self): - return self - - def replace(self, r, s): - return self - - def evaluate_in_scope_(self, *args, **kwargs): - # TODO: remove. It is currently required to avoid nesting this class in an expression for the MappedScope - # We can maybe replace is with a HardwareScope or something along those lines - return self - - -_lambdify_modules.append({'SimpleExpression': SimpleExpression}) +from qupulse.expressions import sympy as sym_expr +from qupulse.expressions.simple import SimpleExpression RepetitionCount = Union[int, VolatileRepetitionCount, SimpleExpression[int]] diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index 0d454c09..9ed021a1 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -7,8 +7,8 @@ from qupulse import ChannelID, MeasurementWindow from qupulse.parameter_scope import Scope, MappedScope, FrozenDict -from qupulse.program import (ProgramBuilder, HardwareTime, HardwareVoltage, Waveform, RepetitionCount, TimeType, - SimpleExpression) +from qupulse.program import ProgramBuilder, HardwareTime, HardwareVoltage, Waveform, RepetitionCount, TimeType +from qupulse.expressions.simple import SimpleExpression from qupulse.program.waveforms import MultiChannelWaveform # this resolution is used to unify increments diff --git a/qupulse/program/waveforms.py b/qupulse/program/waveforms.py index 94c39553..45e24b89 100644 --- a/qupulse/program/waveforms.py +++ b/qupulse/program/waveforms.py @@ -18,15 +18,13 @@ from qupulse import ChannelID from qupulse.program.transformation import Transformation -from qupulse.utils import checked_int_cast, isclose -from qupulse.utils.types import TimeType, time_from_float from qupulse.utils.performance import is_monotonic from qupulse.comparable import Comparable from qupulse.expressions import ExpressionScalar +from qupulse.expressions.simple import SimpleExpression from qupulse.pulses.interpolation import InterpolationStrategy from qupulse.utils import checked_int_cast, isclose -from qupulse.utils.types import TimeType, time_from_float, FrozenDict -from qupulse.program.transformation import Transformation +from qupulse.utils.types import TimeType, time_from_float from qupulse.utils import pairwise class ConstantFunctionPulseTemplateWarning(UserWarning): @@ -51,6 +49,13 @@ def _to_time_type(duration: Real) -> TimeType: else: return time_from_float(float(duration), absolute_error=PULSE_TO_WAVEFORM_ERROR) +def _to_hardware_time(duration: Union[Real, SimpleExpression]) -> Union[TimeType, SimpleExpression[TimeType]]: + if isinstance(duration, SimpleExpression): + return SimpleExpression[TimeType](_to_time_type(duration.base), + {name:_to_time_type(value) for name,value in duration.offsets.items()}) + else: + return _to_time_type(duration) + class Waveform(Comparable, metaclass=ABCMeta): """Represents an instantiated PulseTemplate which can be sampled to retrieve arrays of voltage @@ -215,7 +220,16 @@ def reversed(self) -> 'Waveform': """Returns a reversed version of this waveform.""" # We don't check for constness here because const waveforms are supposed to override this method return ReversedWaveform(self) - + + @abstractmethod + def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Hashable: + """key for hashing *without* channel reference + """ + + def _hash_only_subset(self, channel_subset: Set[ChannelID]) -> int: + """Return a hash value of this Comparable object.""" + return hash(self.get_subset_for_channels(channel_subset)._compare_subset_key(channel_subset)) + class TableWaveformEntry(NamedTuple('TableWaveformEntry', [('t', Real), ('v', float), @@ -248,7 +262,7 @@ def __init__(self, category=DeprecationWarning) waveform_table = self._validate_input(waveform_table) - super().__init__(duration=_to_time_type(waveform_table[-1].t)) + super().__init__(duration=_to_hardware_time(waveform_table[-1].t)) self._table = waveform_table self._channel_id = channel @@ -353,7 +367,11 @@ def from_table(cls, channel: ChannelID, table: Sequence[EntryInInit]) -> Union[' @property def compare_key(self) -> Any: return self._channel_id, self._table - + + def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Any: + assert self.defined_channels == channel_subset + return self._table + def unsafe_sample(self, channel: ChannelID, sample_times: np.ndarray, @@ -397,10 +415,7 @@ class ConstantWaveform(Waveform): def __init__(self, duration: Real, amplitude: Any, channel: ChannelID): """ Create a qupulse waveform corresponding to a ConstantPulseTemplate """ - super().__init__(duration=_to_time_type(duration)) - if hasattr(amplitude, 'shape'): - amplitude = amplitude[()] - hash(amplitude) + super().__init__(duration=_to_hardware_time(duration)) self._amplitude = amplitude self._channel = channel @@ -409,7 +424,7 @@ def from_mapping(cls, duration: Real, constant_values: Mapping[ChannelID, float] 'MultiChannelWaveform']: """Construct a ConstantWaveform or a MultiChannelWaveform of ConstantWaveforms with given duration and values""" assert constant_values - duration = _to_time_type(duration) + duration = _to_hardware_time(duration) if len(constant_values) == 1: (channel, amplitude), = constant_values.items() return cls(duration, amplitude=amplitude, channel=channel) @@ -437,7 +452,11 @@ def defined_channels(self) -> AbstractSet[ChannelID]: @property def compare_key(self) -> Tuple[Any, ...]: return self._duration, self._amplitude, self._channel - + + def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Tuple[Any, ...]: + assert self.defined_channels == channel_subset + return self._duration, self._amplitude + def unsafe_sample(self, channel: ChannelID, sample_times: np.ndarray, @@ -483,7 +502,7 @@ def __init__(self, expression: ExpressionScalar, elif not expression.variables: warnings.warn("Constant FunctionWaveform is not recommended as the constant propagation will be suboptimal", category=ConstantFunctionPulseTemplateWarning) - super().__init__(duration=_to_time_type(duration)) + super().__init__(duration=_to_hardware_time(duration)) self._expression = expression self._channel_id = channel @@ -509,7 +528,11 @@ def defined_channels(self) -> AbstractSet[ChannelID]: @property def compare_key(self) -> Any: return self._channel_id, self._expression, self._duration - + + def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Any: + assert self.defined_channels == channel_subset + return self._expression, self._duration + @property def duration(self) -> TimeType: return self._duration @@ -639,7 +662,10 @@ def unsafe_sample(self, @property def compare_key(self) -> Tuple[Waveform]: return self._sequenced_waveforms - + + def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Tuple[Any]: + return tuple(wf._compare_subset_key(channel_subset) for wf in self._sequenced_waveforms) + @property def duration(self) -> TimeType: return self._duration @@ -788,7 +814,15 @@ def defined_channels(self) -> AbstractSet[ChannelID]: def compare_key(self) -> Any: # sort with channels return self._sub_waveforms - + + def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Any: + if len(channel_subset) == 0: return + if channel_subset != self.defined_channels: #also catches channel_subset >= self.defined_channels + # print(self.defined_channels) + # print(channel_subset) + return self.get_subset_for_channels(channel_subset)._compare_subset_key(channel_subset) + return tuple(self[channel]._compare_subset_key({channel}) for channel in channel_subset) + def unsafe_sample(self, channel: ChannelID, sample_times: np.ndarray, @@ -856,7 +890,10 @@ def unsafe_sample(self, @property def compare_key(self) -> Tuple[Any, int]: return self._body.compare_key, self._repetition_count - + + def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Tuple[Any, int]: + return self._body._compare_subset_key(channel_subset), self._repetition_count + def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> Waveform: return RepetitionWaveform.from_repetition_count( body=self._body.unsafe_get_subset_for_channels(channels), @@ -931,7 +968,11 @@ def defined_channels(self) -> AbstractSet[ChannelID]: @property def compare_key(self) -> Tuple[Waveform, Transformation]: return self.inner_waveform, self.transformation - + + def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Tuple[Any, Transformation]: + remaining_channels = self.transformation.get_input_channels(channel_subset) + return self.inner_waveform._compare_subset_key(remaining_channels), self.transformation + def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> 'SubsetWaveform': return SubsetWaveform(self, channel_subset=channels) @@ -980,7 +1021,12 @@ def defined_channels(self) -> FrozenSet[ChannelID]: @property def compare_key(self) -> Tuple[frozenset, Waveform]: return self.defined_channels, self.inner_waveform - + + def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Any: + #creating another subset from inner_waveform may run into recursive loops? + #so pipe through until MultiChannelWF is reached basically? + return self._inner_waveform._compare_subset_key(channel_subset) + def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> Waveform: return self.inner_waveform.get_subset_for_channels(channels) @@ -1131,7 +1177,10 @@ def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> Waveform: @property def compare_key(self) -> Tuple[str, Waveform, Waveform]: return self._arithmetic_operator, self._lhs, self._rhs - + + def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Tuple[str, Any, Any]: + return self._arithmetic_operator, self._lhs._compare_subset_key(channel_subset), self._rhs._compare_subset_key(channel_subset) + class FunctorWaveform(Waveform): # TODO: Use Protocol to enforce that it accepts second argument has the keyword out @@ -1191,7 +1240,10 @@ def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> Waveform: @property def compare_key(self) -> Tuple[Waveform, FrozenSet]: return self._inner_waveform, frozenset(self._functor.items()) - + + def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Tuple[Any, FrozenSet]: + return self._inner_waveform._compare_subset_key(channel_subset), frozenset(self._functor.items()) + class ReversedWaveform(Waveform): """Reverses the inner waveform in time.""" @@ -1232,6 +1284,9 @@ def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'W @property def compare_key(self) -> Hashable: return self._inner.compare_key + + def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Any: + return self._inner._compare_subset_key(channel_subset) def reversed(self) -> 'Waveform': return self._inner diff --git a/qupulse/utils/__init__.py b/qupulse/utils/__init__.py index 326072f4..a568ac80 100644 --- a/qupulse/utils/__init__.py +++ b/qupulse/utils/__init__.py @@ -7,14 +7,15 @@ from collections import OrderedDict from frozendict import frozendict from qupulse.expressions import ExpressionScalar, ExpressionLike +from qupulse.expressions.simple import SimpleExpression import numpy try: - from math import isclose + from math import isclose as math_isclose except ImportError: # py version < 3.5 - isclose = None + math_isclose = None try: from functools import cached_property @@ -51,8 +52,17 @@ def _fallback_is_close(a, b, *, rel_tol=1e-09, abs_tol=0.0): return abs(a-b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol) # pragma: no cover -if not isclose: - isclose = _fallback_is_close +if not math_isclose: + math_isclose = _fallback_is_close + + +def checked_is_close(a, b, *, rel_tol=1e-09, abs_tol=0.0): + if isinstance(a,SimpleExpression) or isinstance(b,SimpleExpression): + return _fallback_is_close(a, b, rel_tol=rel_tol, abs_tol=abs_tol) + return math_isclose(a,b,rel_tol=rel_tol,abs_tol=abs_tol) + + +isclose = checked_is_close def _fallback_pairwise(iterable: Iterable[_T]) -> Iterator[Tuple[_T, _T]]: @@ -148,5 +158,4 @@ def to_next_multiple(sample_rate: ExpressionLike, quantum: int, return lambda duration: -(-(duration*sample_rate)//quantum) * (quantum/sample_rate) else: #still return 0 if duration==0 - return lambda duration: ExpressionScalar(f'{quantum}/({sample_rate})*Max({min_quanta},-(-{duration}*{sample_rate}//{quantum}))*Max(0, sign({duration}))') - \ No newline at end of file + return lambda duration: ExpressionScalar(f'{quantum}/({sample_rate})*Max({min_quanta},-(-{duration}*{sample_rate}//{quantum}))*Max(0, sign({duration}))') \ No newline at end of file From 74eb4d7ffb2dfada46791c99cf24331086315a68 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Fri, 31 May 2024 21:10:35 +0200 Subject: [PATCH 02/35] supply __lt__ as well to be logically consistent --- qupulse/expressions/simple.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/qupulse/expressions/simple.py b/qupulse/expressions/simple.py index 5504b7db..e5ab3a1f 100644 --- a/qupulse/expressions/simple.py +++ b/qupulse/expressions/simple.py @@ -1,6 +1,6 @@ import numpy as np from numbers import Real, Number -from typing import Optional, Union, Sequence, ContextManager, Mapping, Tuple, Generic, TypeVar, Iterable, Dict +from typing import Optional, Union, Sequence, ContextManager, Mapping, Tuple, Generic, TypeVar, Iterable, Dict, List from dataclasses import dataclass from functools import total_ordering @@ -55,21 +55,25 @@ def __eq__(self, other): return NotImplemented def __gt__(self, other): + return all([b for b in self._return_greater_comparison_bools(other)]) + + def __lt__(self, other): + return all([not b for b in self._return_greater_comparison_bools(other)]) + + def _return_greater_comparison_bools(self, other) -> List[bool]: #there is no good way to compare it without having a value, #but cannot require more parameters in magic method? #so have this weird full equality for now which doesn logically make sense #in most cases to catch unintended consequences - if isinstance(other, (float, int, TimeType)): - return self.base>other and all([o>other for o in self.offsets.values()]) + return [self.base>other] + [o>other for o in self.offsets.values()] if type(other) == type(self): - if len(self.offsets)!=len(other.offsets): return False - return self.base>other.base and all([o1>o2 for o1,o2 in zip(self.offsets.values(),other.offsets.values())]) + if len(self.offsets)!=len(other.offsets): return [False] + return [self.base>other.base] + [o1>o2 for o1,o2 in zip(self.offsets.values(),other.offsets.values())] return NotImplemented - def __add__(self, other): if isinstance(other, (float, int, TimeType)): return SimpleExpression(self.base + other, self.offsets) From f5958062a3460666359a05187fed839116fc040e Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Mon, 3 Jun 2024 10:57:16 +0200 Subject: [PATCH 03/35] implement dummywaveform abstract compare subset --- tests/pulses/sequencing_dummies.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/pulses/sequencing_dummies.py b/tests/pulses/sequencing_dummies.py index 21c3c7e6..99358dab 100644 --- a/tests/pulses/sequencing_dummies.py +++ b/tests/pulses/sequencing_dummies.py @@ -59,7 +59,24 @@ def compare_key(self) -> Any: ) else: return id(self) - + + @property + def _compare_subset_key(self, channel_subset) -> Any: + assert self.channels==channel_subset + if self.sample_output is not None: + try: + if isinstance(self.sample_output,dict): + return hash(self.sample_output.values().tobytes()) + return hash(self.sample_output.tobytes()) + except AttributeError: + pass + return hash( + tuple(sorted((getattr(output, 'tobytes', lambda: output)(),) + for output in self.sample_output.values())) + ) + else: + return id(self) + @property def measurement_windows(self): return [] From bc985cd3dccce96199f05c51c17386ef1192042f Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Mon, 3 Jun 2024 13:33:18 +0200 Subject: [PATCH 04/35] tests --- qupulse/program/waveforms.py | 2 +- tests/_program/waveforms_tests.py | 103 ++++++++++++++++++++++++++++-- 2 files changed, 100 insertions(+), 5 deletions(-) diff --git a/qupulse/program/waveforms.py b/qupulse/program/waveforms.py index 45e24b89..e1f623c8 100644 --- a/qupulse/program/waveforms.py +++ b/qupulse/program/waveforms.py @@ -1286,7 +1286,7 @@ def compare_key(self) -> Hashable: return self._inner.compare_key def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Any: - return self._inner._compare_subset_key(channel_subset) + return (self._inner._compare_subset_key(channel_subset),'-') def reversed(self) -> 'Waveform': return self._inner diff --git a/tests/_program/waveforms_tests.py b/tests/_program/waveforms_tests.py index c62fceb3..6aba9b1e 100644 --- a/tests/_program/waveforms_tests.py +++ b/tests/_program/waveforms_tests.py @@ -50,7 +50,11 @@ def unsafe_sample(self, def compare_key(self): raise NotImplementedError() - + @property + def _compare_subset_key(self, channel_subset): + raise NotImplementedError() + + class WaveformTest(unittest.TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -292,8 +296,25 @@ def test_constant_default_impl(self): self.assertIsNone(wf_non_const.constant_value_dict()) self.assertIsNone(wf_mixed.constant_value_dict()) self.assertEqual(wf_mixed.constant_value('C'), 2.2) - - + + def test_hash_subset(self): + dwf_a = DummyWaveform(duration=246.2, defined_channels={'A'}, sample_output={'A': 1*np.ones(3)}) + dwf_b = DummyWaveform(duration=246.2, defined_channels={'B'}, sample_output={'B': 2*np.ones(3)}) + dwf_c = DummyWaveform(duration=246.2, defined_channels={'C'}, sample_output={'C': 3*np.ones(3)}) + waveform_a1 = MultiChannelWaveform([dwf_a, dwf_b, dwf_c]) + waveform_a2 = MultiChannelWaveform([dwf_a, dwf_b]) + waveform_a3 = MultiChannelWaveform([dwf_a, dwf_c]) + + self.assertEqual(waveform_a1._hash_only_subset({'A','B'}), + waveform_a2._hash_only_subset({'A','B'})) + self.assertEqual(waveform_a1._hash_only_subset({'A','C'}), + waveform_a3._hash_only_subset({'A','C'})) + self.assertNotEqual(waveform_a1._hash_only_subset({'A','B'}), + waveform_a3._hash_only_subset({'A','C'})) + + self.assertRaises(KeyError, lambda: waveform_a1._hash_only_subset({'A','B','C','D'})) + + class RepetitionWaveformTest(unittest.TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -335,6 +356,11 @@ def test_compare_key(self): wf = RepetitionWaveform(body_wf, 2) self.assertEqual(wf.compare_key, (body_wf.compare_key, 2)) + def test_compare_subset(self): + body_wf = DummyWaveform(defined_channels={'a'}) + wf = RepetitionWaveform(body_wf, 2) + self.assertEqual(wf._compare_subset_key({'a',}), (body_wf._compare_subset_key({'a',}), 2)) + def test_unsafe_get_subset_for_channels(self): body_wf = DummyWaveform(defined_channels={'a', 'b'}) @@ -492,6 +518,11 @@ def test_repr(self): r = repr(swf) self.assertEqual(swf, eval(r)) + def test_compare_subset(self): + body_wf = DummyWaveform(defined_channels={'a'}) + wf = SequenceWaveform([body_wf, body_wf]) + self.assertEqual(wf._compare_subset_key({'a',}), tuple(2*[body_wf._compare_subset_key({'a',}),])) + class ConstantWaveformTests(unittest.TestCase): def test_waveform_duration(self): @@ -521,6 +552,14 @@ def test_constness(self): self.assertTrue(waveform.is_constant()) assert_constant_consistent(self, waveform) + def test_hash_subset(self): + wf_1 = ConstantWaveform(10, 1., 'A') + wf_2 = ConstantWaveform(10, 1., 'B') + wf_3 = ConstantWaveform(10, 2., 'A') + + self.assertEqual(wf_1._hash_only_subset({'A',}), wf_2._hash_only_subset({'B',})) + self.assertNotEqual(wf_1._hash_only_subset({'A',}), wf_3._hash_only_subset({'A',})) + class TableWaveformTests(unittest.TestCase): @@ -654,6 +693,25 @@ def test_simple_properties(self): evaled = eval(repr(waveform)) self.assertEqual(evaled, waveform) + def test_hash_subset(self): + + interp = 'jump' + entries = (TableWaveformEntry(0, 0, interp), + TableWaveformEntry(2.1, -33.2, interp), + TableWaveformEntry(5.7, 123.4, interp)) + wf_1 = TableWaveform('A', entries) + entries = (TableWaveformEntry(0, 0, interp), + TableWaveformEntry(2.1, -33.2, interp), + TableWaveformEntry(5.7, 123.4, interp)) + wf_2 = TableWaveform('B', entries) + entries = (TableWaveformEntry(0.5, 0, interp), + TableWaveformEntry(2.1, -33.2, interp), + TableWaveformEntry(5.7, 123.4, interp)) + wf_3 = TableWaveform('A', entries) + + self.assertEqual(wf_1._hash_only_subset({'A',}), wf_2._hash_only_subset({'B',})) + self.assertNotEqual(wf_1._hash_only_subset({'A',}), wf_3._hash_only_subset({'A',})) + class WaveformEntryTest(unittest.TestCase): def test_interpolation_exception(self): @@ -796,6 +854,16 @@ def test_const_value(self): with mock.patch.object(inner_wf, 'constant_value', side_effect=inner_const_values.values()) as constant_value: self.assertIsNone(trafo_wf.constant_value('C')) + def test_compare_subset(self): + output_channels = {'c', 'd', 'e'} + input_channels = {'a', 'b'} + trafo = TransformationDummy(output_channels=output_channels,input_channels=input_channels) + inner_wf = DummyWaveform(duration=1.5, defined_channels=input_channels) + trafo_wf = TransformingWaveform(inner_waveform=inner_wf, transformation=trafo) + + self.assertEqual(trafo_wf._compare_subset_key(output_channels), + (inner_wf._compare_subset_key(input_channels), trafo)) + class SubsetWaveformTest(unittest.TestCase): def test_simple_properties(self): @@ -836,6 +904,13 @@ def test_unsafe_sample(self): self.assertIs(expected_data, actual_data) unsafe_sample.assert_called_once_with('g', time, output) + def test_compare_subset(self): + inner_wf = DummyWaveform(duration=1.5, defined_channels={'a', 'b', 'c'}) + subset_wf = SubsetWaveform(inner_wf, {'a', 'c'}) + + self.assertEqual(subset_wf._compare_subset_key({'a', 'c'}), + inner_wf._compare_subset_key({'a', 'c'}),) + class ArithmeticWaveformTest(unittest.TestCase): def test_from_operator(self): @@ -893,6 +968,8 @@ def test_simple_properties(self): self.assertEqual(lhs.duration, arith.duration) self.assertEqual(('-', lhs, rhs), arith.compare_key) + self.assertEqual(('-', lhs._compare_subset_key({'a','b'}), rhs._compare_subset_key({'a','b'})), + arith._compare_subset_key({'a','b'})) def test_unsafe_get_subset_for_channels(self): lhs = DummyWaveform(duration=1.5, defined_channels={'a', 'b', 'c'}) @@ -1007,6 +1084,14 @@ def test_repr(self): r = repr(wf) self.assertEqual(wf, eval(r)) + def test_compare_subset(self): + wf1a = FunctionWaveform(ExpressionScalar('1+2*t'), 3, channel='A') + wf1b = FunctionWaveform(ExpressionScalar('t*2+1'), 3, channel='A') + + self.assertEqual(wf1a._compare_subset_key({'A',}), (wf1a._expression,wf1a._duration)) + self.assertEqual(wf1a._compare_subset_key({'A',}), + wf1b._compare_subset_key({'A',}),) + class FunctorWaveformTests(unittest.TestCase): def test_duration(self): @@ -1075,6 +1160,15 @@ def test_compare_key(self): self.assertNotEqual(wf11, wf21) self.assertNotEqual(wf11, wf22) + def test_compare_subset(self): + inner_wf_1 = DummyWaveform(defined_channels={'A', 'B'}) + functors_1 = dict(A=np.positive, B=np.negative) + + wf11 = FunctorWaveform(inner_wf_1, functors_1) + + self.assertEqual((inner_wf_1._compare_subset_key({'A',}), frozenset(functors_1.items())), + wf11._compare_subset_key({'A',})) + class ReversedWaveformTest(unittest.TestCase): def test_simple_properties(self): @@ -1084,6 +1178,7 @@ def test_simple_properties(self): self.assertEqual(dummy_wf.duration, reversed_wf.duration) self.assertEqual(dummy_wf.defined_channels, reversed_wf.defined_channels) self.assertEqual(dummy_wf.compare_key, reversed_wf.compare_key) + self.assertEqual(reversed_wf._compare_subset_key({'A',}), (dummy_wf._compare_subset_key({'A',}),'-')) self.assertNotEqual(reversed_wf, dummy_wf) def test_reversed_sample(self): @@ -1103,4 +1198,4 @@ def test_reversed_sample(self): np.testing.assert_equal(output, sample_output[::-1]) np.testing.assert_equal(dummy_wf.sample_calls, [ ('A', list(1.5 - time_array[::-1]), None), - ('A', list(1.5 - time_array[::-1]), mem[::-1])]) + ('A', list(1.5 - time_array[::-1]), mem[::-1])]) \ No newline at end of file From 627a27f4c3066fb603fb29dc2b0f28becb8ac1c2 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Mon, 3 Jun 2024 14:33:54 +0200 Subject: [PATCH 05/35] Update setup.cfg --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index a77b1e90..76b607da 100644 --- a/setup.cfg +++ b/setup.cfg @@ -71,7 +71,7 @@ qctoolkit = *.pyi [tool:pytest] -testpaths = tests tests/pulses tests/hardware tests/backward_compatibility +testpaths = tests tests/pulses tests/hardware tests/backward_compatibility tests/_program tests/expressions tests/program tests/utils python_files=*_tests.py *_bug.py filterwarnings = # syntax is action:message_regex:category:module_regex:lineno From 06aa1b36e914dc549111af2773d034ccf5cd880f Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Mon, 3 Jun 2024 14:46:24 +0200 Subject: [PATCH 06/35] delete wrong decorator --- tests/pulses/sequencing_dummies.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/pulses/sequencing_dummies.py b/tests/pulses/sequencing_dummies.py index 99358dab..46895643 100644 --- a/tests/pulses/sequencing_dummies.py +++ b/tests/pulses/sequencing_dummies.py @@ -60,7 +60,6 @@ def compare_key(self) -> Any: else: return id(self) - @property def _compare_subset_key(self, channel_subset) -> Any: assert self.channels==channel_subset if self.sample_output is not None: From f87b1ea706d2482188ed513290bc75e7fe4709d9 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Mon, 3 Jun 2024 15:56:41 +0200 Subject: [PATCH 07/35] somewhat fix the tests --- qupulse/program/waveforms.py | 3 ++- tests/_program/waveforms_tests.py | 39 +++++++++++++++++++++--------- tests/pulses/sequencing_dummies.py | 2 +- 3 files changed, 30 insertions(+), 14 deletions(-) diff --git a/qupulse/program/waveforms.py b/qupulse/program/waveforms.py index e1f623c8..0d044953 100644 --- a/qupulse/program/waveforms.py +++ b/qupulse/program/waveforms.py @@ -223,7 +223,8 @@ def reversed(self) -> 'Waveform': @abstractmethod def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Hashable: - """key for hashing *without* channel reference + """key for hashing *without* channel reference. Don't call directly, + only via _hash_only_subset. """ def _hash_only_subset(self, channel_subset: Set[ChannelID]) -> int: diff --git a/tests/_program/waveforms_tests.py b/tests/_program/waveforms_tests.py index 6aba9b1e..21fdcbb6 100644 --- a/tests/_program/waveforms_tests.py +++ b/tests/_program/waveforms_tests.py @@ -13,8 +13,8 @@ from qupulse.program.transformation import LinearTransformation from qupulse.expressions import ExpressionScalar, Expression -from tests.pulses.sequencing_dummies import DummyWaveform, DummyInterpolationStrategy -from tests._program.transformation_tests import TransformationStub +# from tests.pulses.sequencing_dummies import DummyWaveform, DummyInterpolationStrategy +# from tests._program.transformation_tests import TransformationStub def assert_constant_consistent(test_case: unittest.TestCase, wf: Waveform): @@ -521,7 +521,8 @@ def test_repr(self): def test_compare_subset(self): body_wf = DummyWaveform(defined_channels={'a'}) wf = SequenceWaveform([body_wf, body_wf]) - self.assertEqual(wf._compare_subset_key({'a',}), tuple(2*[body_wf._compare_subset_key({'a',}),])) + self.assertEqual(wf.get_subset_for_channels({'a'})._compare_subset_key({'a',}), + tuple(2*[body_wf.get_subset_for_channels({'a'})._compare_subset_key({'a',}),])) class ConstantWaveformTests(unittest.TestCase): @@ -695,7 +696,7 @@ def test_simple_properties(self): def test_hash_subset(self): - interp = 'jump' + interp = HoldInterpolationStrategy() entries = (TableWaveformEntry(0, 0, interp), TableWaveformEntry(2.1, -33.2, interp), TableWaveformEntry(5.7, 123.4, interp)) @@ -908,8 +909,8 @@ def test_compare_subset(self): inner_wf = DummyWaveform(duration=1.5, defined_channels={'a', 'b', 'c'}) subset_wf = SubsetWaveform(inner_wf, {'a', 'c'}) - self.assertEqual(subset_wf._compare_subset_key({'a', 'c'}), - inner_wf._compare_subset_key({'a', 'c'}),) + self.assertEqual(subset_wf._compare_subset_key({'a', 'b', 'c'}), + inner_wf._compare_subset_key({'a', 'b', 'c'}),) class ArithmeticWaveformTest(unittest.TestCase): @@ -966,10 +967,22 @@ def test_simple_properties(self): self.assertIs(rhs, arith.rhs) self.assertEqual('-', arith.arithmetic_operator) self.assertEqual(lhs.duration, arith.duration) - + + def test_compare_subset(self): + lhs_1 = DummyWaveform(duration=1.5, defined_channels={'a',}) + lhs_2 = DummyWaveform(duration=1.5, defined_channels={'b',}) + lhs_3 = DummyWaveform(duration=1.5, defined_channels={'c'}) + rhs_1 = DummyWaveform(duration=1.5, defined_channels={'a',}) + rhs_2 = DummyWaveform(duration=1.5, defined_channels={'b',}) + rhs_3 = DummyWaveform(duration=1.5, defined_channels={'d'}) + + lhs = MultiChannelWaveform([lhs_1,lhs_2,lhs_3]) + rhs = MultiChannelWaveform([rhs_1,rhs_2,rhs_3]) + arith = ArithmeticWaveform(lhs, '-', rhs) + self.assertEqual(('-', lhs, rhs), arith.compare_key) - self.assertEqual(('-', lhs._compare_subset_key({'a','b'}), rhs._compare_subset_key({'a','b'})), - arith._compare_subset_key({'a','b'})) + self.assertEqual(('-', lhs._compare_subset_key({'a','b',}),rhs._compare_subset_key({'a','b',})), + arith._compare_subset_key({'a','b',})) def test_unsafe_get_subset_for_channels(self): lhs = DummyWaveform(duration=1.5, defined_channels={'a', 'b', 'c'}) @@ -1166,8 +1179,9 @@ def test_compare_subset(self): wf11 = FunctorWaveform(inner_wf_1, functors_1) - self.assertEqual((inner_wf_1._compare_subset_key({'A',}), frozenset(functors_1.items())), - wf11._compare_subset_key({'A',})) + self.assertEqual((inner_wf_1._compare_subset_key({'A', 'B'}), + frozenset(functors_1.items())), + wf11._compare_subset_key({'A', 'B'})) class ReversedWaveformTest(unittest.TestCase): @@ -1178,7 +1192,8 @@ def test_simple_properties(self): self.assertEqual(dummy_wf.duration, reversed_wf.duration) self.assertEqual(dummy_wf.defined_channels, reversed_wf.defined_channels) self.assertEqual(dummy_wf.compare_key, reversed_wf.compare_key) - self.assertEqual(reversed_wf._compare_subset_key({'A',}), (dummy_wf._compare_subset_key({'A',}),'-')) + self.assertEqual(reversed_wf._compare_subset_key({'A','B'}), + (dummy_wf._compare_subset_key({'A','B'}),'-')) self.assertNotEqual(reversed_wf, dummy_wf) def test_reversed_sample(self): diff --git a/tests/pulses/sequencing_dummies.py b/tests/pulses/sequencing_dummies.py index 46895643..3d2ccaba 100644 --- a/tests/pulses/sequencing_dummies.py +++ b/tests/pulses/sequencing_dummies.py @@ -43,7 +43,7 @@ def __init__(self, duration: Union[float, TimeType]=0, sample_output: Union[nump defined_channels = set(sample_output.keys()) else: defined_channels = {'A'} - self.defined_channels_ = defined_channels + self.defined_channels_ = self.channels = defined_channels self.sample_calls = [] @property From e6b1e9e485ad54049606e3222609b605858fecd7 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Tue, 4 Jun 2024 10:27:36 +0200 Subject: [PATCH 08/35] uncomment imports --- tests/_program/waveforms_tests.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/_program/waveforms_tests.py b/tests/_program/waveforms_tests.py index 21fdcbb6..f38631e8 100644 --- a/tests/_program/waveforms_tests.py +++ b/tests/_program/waveforms_tests.py @@ -13,8 +13,8 @@ from qupulse.program.transformation import LinearTransformation from qupulse.expressions import ExpressionScalar, Expression -# from tests.pulses.sequencing_dummies import DummyWaveform, DummyInterpolationStrategy -# from tests._program.transformation_tests import TransformationStub +from tests.pulses.sequencing_dummies import DummyWaveform, DummyInterpolationStrategy +from tests._program.transformation_tests import TransformationStub def assert_constant_consistent(test_case: unittest.TestCase, wf: Waveform): From 3c65a9c5d56292bfe2494121d4d3a8cac4817e82 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Wed, 5 Jun 2024 10:18:00 +0200 Subject: [PATCH 09/35] forward voltage increment resolution --- qupulse/hardware/awgs/base.py | 9 +++++++-- qupulse/program/linspace.py | 6 ++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/qupulse/hardware/awgs/base.py b/qupulse/hardware/awgs/base.py index 5b1bb7c7..a580eb48 100644 --- a/qupulse/hardware/awgs/base.py +++ b/qupulse/hardware/awgs/base.py @@ -17,7 +17,7 @@ from qupulse.hardware.util import get_sample_times, not_none_indices from qupulse.utils.types import ChannelID from qupulse.program.linspace import LinSpaceNode, LinSpaceArbitraryWaveform, to_increment_commands, Command, \ - Increment, Set as LSPSet, LoopLabel, LoopJmp, Wait, Play + Increment, Set as LSPSet, LoopLabel, LoopJmp, Wait, Play, DEFAULT_INCREMENT_RESOLUTION from qupulse.program.loop import Loop from qupulse.program.waveforms import Waveform from qupulse.comparable import Comparable @@ -191,6 +191,7 @@ def __init__(self, program: AllowedProgramTypes, voltage_transformations: Tuple[Optional[Callable], ...], sample_rate: TimeType, waveforms: Sequence[Waveform] = None, + voltage_resolution: Optional[float] = None, program_type: _ProgramType = _ProgramType.Loop): """ @@ -204,6 +205,8 @@ def __init__(self, program: AllowedProgramTypes, sample_rate: waveforms: These waveforms are sampled and stored in _waveforms. If None the waveforms are extracted from loop + voltage_resolution: voltage resolution for LinSpaceProgram, i.e. 2**(-16) for 16 bit AWG + program_type: type of program from _ProgramType, determined by the ProgramBuilder used. """ assert len(channels) == len(amplitudes) == len(offsets) == len(voltage_transformations) @@ -219,7 +222,9 @@ def __init__(self, program: AllowedProgramTypes, self._program = program if program_type == _ProgramType.Linspace: - self._transformed_commands = self._transform_linspace_commands(to_increment_commands(program)) + #!!! the voltage resolution may not be adequately represented if voltage transformations are not None? + self._transformed_commands = self._transform_linspace_commands( + to_increment_commands(program,resolution=voltage_resolution if voltage_resolution is not None else DEFAULT_INCREMENT_RESOLUTION)) if waveforms is None: if program_type is _ProgramType.Loop: diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index 9ed021a1..d5c85bac 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -402,9 +402,11 @@ def add_node(self, node: Union[LinSpaceNode, Sequence[LinSpaceNode]]): raise TypeError("The node type is not handled", type(node), node) -def to_increment_commands(linspace_nodes: Sequence[LinSpaceNode]) -> List[Command]: +def to_increment_commands(linspace_nodes: Sequence[LinSpaceNode], + resolution: float = DEFAULT_INCREMENT_RESOLUTION + ) -> List[Command]: """translate the given linspace node tree to a minimal sequence of set and increment commands as well as loops.""" - state = _TranslationState() + state = _TranslationState(resolution=resolution) state.add_node(linspace_nodes) return state.commands From 634cd3097aa7c3a5d61058b53b11964a8dde5c93 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Wed, 5 Jun 2024 14:06:31 +0200 Subject: [PATCH 10/35] wrongly assumed hardware resolution --- qupulse/program/linspace.py | 1 + 1 file changed, 1 insertion(+) diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index d5c85bac..c3f06590 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -406,6 +406,7 @@ def to_increment_commands(linspace_nodes: Sequence[LinSpaceNode], resolution: float = DEFAULT_INCREMENT_RESOLUTION ) -> List[Command]: """translate the given linspace node tree to a minimal sequence of set and increment commands as well as loops.""" + if resolution: raise NotImplementedError('wrongly assumed resolution. need to fix') state = _TranslationState(resolution=resolution) state.add_node(linspace_nodes) return state.commands From 7d6b01a9692f9ec848c03d2deea8ef5a9f8c6e17 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Thu, 6 Jun 2024 10:29:45 +0200 Subject: [PATCH 11/35] draft dependent waits & dependency domains --- qupulse/hardware/awgs/base.py | 8 ++- qupulse/program/linspace.py | 100 +++++++++++++++++++++++++--------- 2 files changed, 81 insertions(+), 27 deletions(-) diff --git a/qupulse/hardware/awgs/base.py b/qupulse/hardware/awgs/base.py index a580eb48..5d524c61 100644 --- a/qupulse/hardware/awgs/base.py +++ b/qupulse/hardware/awgs/base.py @@ -221,10 +221,11 @@ def __init__(self, program: AllowedProgramTypes, self._program_type = program_type self._program = program + self._voltage_resolution = voltage_resolution + if program_type == _ProgramType.Linspace: #!!! the voltage resolution may not be adequately represented if voltage transformations are not None? - self._transformed_commands = self._transform_linspace_commands( - to_increment_commands(program,resolution=voltage_resolution if voltage_resolution is not None else DEFAULT_INCREMENT_RESOLUTION)) + self._transformed_commands = self._transform_linspace_commands(to_increment_commands(program,self._voltage_resolution)) if waveforms is None: if program_type is _ProgramType.Loop: @@ -272,10 +273,13 @@ def _channel_transformations(self) -> Mapping[ChannelID, ChannelTransformation]: def _transform_linspace_commands(self, command_list: List[Command]) -> List[Command]: # all commands = Union[Increment, Set, LoopLabel, LoopJmp, Wait, Play] + TODO: voltage resolution and sample rate time->sample conversion + trafos_by_channel_idx = list(self._channel_transformations().values()) for command in command_list: if isinstance(command, (LoopLabel, LoopJmp, Play, Wait)): + # play is handled by transforming the sampled waveform continue elif isinstance(command, Increment): diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index c3f06590..77e68170 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -4,6 +4,7 @@ import numpy as np from dataclasses import dataclass from typing import Mapping, Optional, Sequence, ContextManager, Iterable, Tuple, Union, Dict, List, Iterator +from enum import Enum from qupulse import ChannelID, MeasurementWindow from qupulse.parameter_scope import Scope, MappedScope, FrozenDict @@ -13,7 +14,21 @@ # this resolution is used to unify increments # the increments themselves remain floats +# !!! translated: this is NOT a hardware resolution, +# just a programmatic 'small epsilon' to avoid rounding errors. DEFAULT_INCREMENT_RESOLUTION: float = 1e-9 +DEFAULT_TIME_RESOLUTION: float = 1e-3 + +class DepDomain(Enum): + VOLTAGE = 0 + TIME_LIN = -1 + TIME_LOG = -2 + FREQUENCY = -3 + NODEP = None + +# class DepStrategy(Enum): +# CONSTANT = 0 +# VARIABLE = 1 @dataclass(frozen=True) @@ -24,14 +39,25 @@ class DepKey: These objects allow backends which support it to track multiple amplitudes at once. """ factors: Tuple[int, ...] - + domain: DepDomain + # strategy: DepStrategy + @classmethod - def from_voltages(cls, voltages: Sequence[float], resolution: float): + def from_domain(cls, factors, resolution, domain): # remove trailing zeros - while voltages and voltages[-1] == 0: - voltages = voltages[:-1] - return cls(tuple(int(round(voltage / resolution)) for voltage in voltages)) - + while factors and factors[-1] == 0: + factors = factors[:-1] + return cls(tuple(int(round(factor / resolution)) for factor in factors), + domain) + + @classmethod + def from_voltages(cls, voltages: Sequence[float], resolution: float): + return cls.from_domain(voltages, resolution, DepDomain.VOLTAGE) + + @classmethod + def from_lin_times(cls, times: Sequence[float], resolution: float): + return cls.from_domain(times, resolution, DepDomain.TIME_LIN) + @dataclass class LinSpaceNode: @@ -139,6 +165,9 @@ def hold_voltage(self, duration: HardwareTime, voltages: Mapping[ChannelID, Hard ranges = self._get_ranges() factors = [] bases = [] + duration_base = duration + duration_factors = None + for value in voltages: if isinstance(value, float): bases.append(value) @@ -160,11 +189,21 @@ def hold_voltage(self, duration: HardwareTime, voltages: Mapping[ChannelID, Hard bases.append(base) if isinstance(duration, SimpleExpression): - duration_factors = duration.offsets + # duration_factors = duration.offsets + # duration_base = duration.base + duration_offsets = duration.offsets duration_base = duration.base - else: - duration_base = duration - duration_factors = None + duration_factors = [] + for rng_name, rng in ranges.items(): + start = TimeType(0) + step = TimeType(0) + offset = duration_offsets.get(rng_name, None) + if offset: + start += rng.start * offset + step += rng.step * offset + duration_base += start + duration_factors.append(step) + set_cmd = LinSpaceHold(bases=tuple(bases), factors=tuple(factors), @@ -223,21 +262,22 @@ class LoopLabel: @dataclass class Increment: - channel: int - value: float + channel: Optional[int] + value: Union[float,TimeType] dependency_key: DepKey @dataclass class Set: - channel: int - value: float - key: DepKey = dataclasses.field(default_factory=lambda: DepKey(())) + channel: Optional[int] + value: Union[float,TimeType] + key: DepKey = dataclasses.field(default_factory=lambda: DepKey((),DepDomain.NODEP)) @dataclass class Wait: - duration: TimeType + duration: Optional[TimeType] + key: DepKey = dataclasses.field(default_factory=lambda: DepKey((),DepDomain.NODEP)) @dataclass @@ -296,6 +336,7 @@ class _TranslationState: dep_states: Dict[int, Dict[DepKey, DepState]] = dataclasses.field(default_factory=dict) plain_voltage: Dict[int, float] = dataclasses.field(default_factory=dict) resolution: float = dataclasses.field(default_factory=lambda: DEFAULT_INCREMENT_RESOLUTION) + resolution_time: float = dataclasses.field(default_factory=lambda: DEFAULT_TIME_RESOLUTION) def new_loop(self, count: int): label = LoopLabel(self.label_num, count) @@ -311,7 +352,7 @@ def get_dependency_state(self, dependencies: Mapping[int, set]): } def set_voltage(self, channel: int, value: float): - key = DepKey(()) + key = DepKey((),DepDomain.VOLTAGE) if self.active_dep.get(channel, None) != key or self.plain_voltage.get(channel, None) != value: self.commands.append(Set(channel, value, key)) self.active_dep[channel] = key @@ -343,8 +384,14 @@ def _add_iteration_node(self, node: LinSpaceIter): self.add_node(node.body) self.commands.append(jmp) self.iterations.pop() - + def _set_indexed_voltage(self, channel: int, base: float, factors: Sequence[float]): + self.set_indexed_value(channel, base, factors, domain=DepDomain.VOLTAGE) + + def _set_indexed_lin_time(self, base: TimeType, factors: Sequence[TimeType]): + self.set_indexed_value(DepDomain.TIME_LIN.value, base, factors, domain=DepDomain.TIME_LIN) + + def set_indexed_value(self, channel, base, factors, domain): dep_key = DepKey.from_voltages(voltages=factors, resolution=self.resolution) new_dep_state = DepState( base, @@ -365,20 +412,23 @@ def _set_indexed_voltage(self, channel: int, base: float, factors: Sequence[floa self.commands.append(Increment(channel, inc, dep_key)) self.active_dep[channel] = dep_key self.dep_states[channel][dep_key] = new_dep_state - + def _add_hold_node(self, node: LinSpaceHold): - if node.duration_factors: - raise NotImplementedError("TODO") for ch, (base, factors) in enumerate(zip(node.bases, node.factors)): if factors is None: self.set_voltage(ch, base) continue - else: self._set_indexed_voltage(ch, base, factors) + + if node.duration_factors: + self._set_indexed_lin_time(node.duration_base,) + # raise NotImplementedError("TODO") - self.commands.append(Wait(node.duration_base)) + self.commands.append(Wait(None, self.active_dep[DepDomain.TIME_LIN.value])) + else: + self.commands.append(Wait(node.duration_base)) def add_node(self, node: Union[LinSpaceNode, Sequence[LinSpaceNode]]): """Translate a (sequence of) linspace node(s) to commands and add it to the internal command list.""" @@ -406,8 +456,8 @@ def to_increment_commands(linspace_nodes: Sequence[LinSpaceNode], resolution: float = DEFAULT_INCREMENT_RESOLUTION ) -> List[Command]: """translate the given linspace node tree to a minimal sequence of set and increment commands as well as loops.""" - if resolution: raise NotImplementedError('wrongly assumed resolution. need to fix') - state = _TranslationState(resolution=resolution) + # if resolution: raise NotImplementedError('wrongly assumed resolution. need to fix') + state = _TranslationState(resolution=resolution if voltage_resolution is not None else DEFAULT_INCREMENT_RESOLUTION) state.add_node(linspace_nodes) return state.commands From 1ea517670c25b761cc05cc80869468241eeedfc6 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Thu, 6 Jun 2024 16:01:30 +0200 Subject: [PATCH 12/35] first syntactic debug --- qupulse/hardware/awgs/base.py | 2 +- qupulse/program/linspace.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/qupulse/hardware/awgs/base.py b/qupulse/hardware/awgs/base.py index 5d524c61..f13f5e36 100644 --- a/qupulse/hardware/awgs/base.py +++ b/qupulse/hardware/awgs/base.py @@ -273,7 +273,7 @@ def _channel_transformations(self) -> Mapping[ChannelID, ChannelTransformation]: def _transform_linspace_commands(self, command_list: List[Command]) -> List[Command]: # all commands = Union[Increment, Set, LoopLabel, LoopJmp, Wait, Play] - TODO: voltage resolution and sample rate time->sample conversion + # TODO: voltage resolution trafos_by_channel_idx = list(self._channel_transformations().values()) diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index 77e68170..9039b256 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -346,7 +346,7 @@ def new_loop(self, count: int): def get_dependency_state(self, dependencies: Mapping[int, set]): return { - self.dep_states.get(ch, {}).get(DepKey.from_voltages(dep, self.resolution), None) + self.dep_states.get(ch, {}).get(DepKey.from_domain(dep, self.resolution), None) for ch, deps in dependencies.items() for dep in deps } @@ -386,13 +386,14 @@ def _add_iteration_node(self, node: LinSpaceIter): self.iterations.pop() def _set_indexed_voltage(self, channel: int, base: float, factors: Sequence[float]): - self.set_indexed_value(channel, base, factors, domain=DepDomain.VOLTAGE) + key = DepKey.from_voltages(voltages=factors, resolution=self.resolution) + self.set_indexed_value(key, channel, base, factors, domain=DepDomain.VOLTAGE) def _set_indexed_lin_time(self, base: TimeType, factors: Sequence[TimeType]): - self.set_indexed_value(DepDomain.TIME_LIN.value, base, factors, domain=DepDomain.TIME_LIN) + key = DepKey.from_lin_times(times=factors, resolution=self.resolution) + self.set_indexed_value(key, DepDomain.TIME_LIN.value, base, factors, domain=DepDomain.TIME_LIN) - def set_indexed_value(self, channel, base, factors, domain): - dep_key = DepKey.from_voltages(voltages=factors, resolution=self.resolution) + def set_indexed_value(self, dep_key, channel, base, factors, domain): new_dep_state = DepState( base, iterations=tuple(self.iterations) @@ -423,9 +424,8 @@ def _add_hold_node(self, node: LinSpaceHold): self._set_indexed_voltage(ch, base, factors) if node.duration_factors: - self._set_indexed_lin_time(node.duration_base,) + self._set_indexed_lin_time(node.duration_base,node.duration_factors) # raise NotImplementedError("TODO") - self.commands.append(Wait(None, self.active_dep[DepDomain.TIME_LIN.value])) else: self.commands.append(Wait(node.duration_base)) @@ -457,7 +457,7 @@ def to_increment_commands(linspace_nodes: Sequence[LinSpaceNode], ) -> List[Command]: """translate the given linspace node tree to a minimal sequence of set and increment commands as well as loops.""" # if resolution: raise NotImplementedError('wrongly assumed resolution. need to fix') - state = _TranslationState(resolution=resolution if voltage_resolution is not None else DEFAULT_INCREMENT_RESOLUTION) + state = _TranslationState(resolution=resolution if resolution is not None else DEFAULT_INCREMENT_RESOLUTION) state.add_node(linspace_nodes) return state.commands From 37ee51d3624a2eecffd705b0c581e87ccd1816b8 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Thu, 6 Jun 2024 18:10:18 +0200 Subject: [PATCH 13/35] fix definition of iterations --- qupulse/program/linspace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index 9039b256..ac1cd5ae 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -378,7 +378,7 @@ def _add_iteration_node(self, node: LinSpaceIter): self.add_node(node.body) if node.length > 1: - self.iterations[-1] = node.length + self.iterations[-1] = node.length - 1 label, jmp = self.new_loop(node.length - 1) self.commands.append(label) self.add_node(node.body) From b37b510f0e2508fce9c4b377b3879aa62adb0553 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Sat, 8 Jun 2024 00:11:05 +0200 Subject: [PATCH 14/35] see if one can replace int->ChannelID --- qupulse/program/linspace.py | 109 +++++++++++++++++++++++------------- 1 file changed, 69 insertions(+), 40 deletions(-) diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index ac1cd5ae..b1f2612f 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -10,7 +10,7 @@ from qupulse.parameter_scope import Scope, MappedScope, FrozenDict from qupulse.program import ProgramBuilder, HardwareTime, HardwareVoltage, Waveform, RepetitionCount, TimeType from qupulse.expressions.simple import SimpleExpression -from qupulse.program.waveforms import MultiChannelWaveform +from qupulse.program.waveforms import MultiChannelWaveform, TransformingWaveform # this resolution is used to unify increments # the increments themselves remain floats @@ -26,6 +26,8 @@ class DepDomain(Enum): FREQUENCY = -3 NODEP = None +GeneralizedChannel = Union[DepDomain,ChannelID] + # class DepStrategy(Enum): # CONSTANT = 0 # VARIABLE = 1 @@ -62,32 +64,41 @@ def from_lin_times(cls, times: Sequence[float], resolution: float): @dataclass class LinSpaceNode: """AST node for a program that supports linear spacing of set points as well as nested sequencing and repetitions""" - - def dependencies(self) -> Mapping[int, set]: + + def dependencies(self) -> Mapping[GeneralizedChannel, set]: raise NotImplementedError +@dataclass +class LinSpaceNodeChannelSpecific(LinSpaceNode): + + channels: Tuple[GeneralizedChannel, ...] + + @property + def play_channels(self) -> Tuple[ChannelID, ...]: + return tuple(ch for ch in self.channels if isinstance(ch,ChannelID)) + @dataclass -class LinSpaceHold(LinSpaceNode): +class LinSpaceHold(LinSpaceNodeChannelSpecific): """Hold voltages for a given time. The voltages and the time may depend on the iteration index.""" - bases: Tuple[float, ...] - factors: Tuple[Optional[Tuple[float, ...]], ...] + bases: Dict[GeneralizedChannel, float] + factors: Dict[GeneralizedChannel, Optional[Tuple[float, ...]]] duration_base: TimeType duration_factors: Optional[Tuple[TimeType, ...]] - def dependencies(self) -> Mapping[int, set]: + def dependencies(self) -> Mapping[GeneralizedChannel, set]: return {idx: {factors} - for idx, factors in enumerate(self.factors) + for idx, factors in self.factors.items() if factors} @dataclass -class LinSpaceArbitraryWaveform(LinSpaceNode): +class LinSpaceArbitraryWaveform(LinSpaceNodeChannelSpecific): """This is just a wrapper to pipe arbitrary waveforms through the system.""" waveform: Waveform - channels: Tuple[ChannelID, ...] + # channels: Tuple[ChannelID, ...] @dataclass @@ -132,10 +143,12 @@ class LinSpaceBuilder(ProgramBuilder): Arbitrary waveforms are not implemented yet """ - def __init__(self, channels: Tuple[ChannelID, ...]): + def __init__(self, + # channels: Tuple[ChannelID, ...] + ): super().__init__() - self._name_to_idx = {name: idx for idx, name in enumerate(channels)} - self._idx_to_name = channels + # self._name_to_idx = {name: idx for idx, name in enumerate(channels)} + # self._voltage_idx_to_name = channels self._stack = [[]] self._ranges = [] @@ -159,19 +172,19 @@ def _get_ranges(self): return dict(self._ranges) def hold_voltage(self, duration: HardwareTime, voltages: Mapping[ChannelID, HardwareVoltage]): - voltages = sorted((self._name_to_idx[ch_name], value) for ch_name, value in voltages.items()) - voltages = [value for _, value in voltages] + # voltages = sorted((self._name_to_idx[ch_name], value) for ch_name, value in voltages.items()) + # voltages = [value for _, value in voltages] ranges = self._get_ranges() - factors = [] - bases = [] + factors = {} + bases = {} duration_base = duration duration_factors = None - for value in voltages: + for ch_name,value in voltages.items(): if isinstance(value, float): - bases.append(value) - factors.append(None) + bases[ch_name] = value + factors[ch_name] = None continue offsets = value.offsets base = value.base @@ -185,8 +198,8 @@ def hold_voltage(self, duration: HardwareTime, voltages: Mapping[ChannelID, Hard step += rng.step * offset base += start incs.append(step) - factors.append(tuple(incs)) - bases.append(base) + factors[ch_name] = tuple(incs) + bases[ch_name] = base if isinstance(duration, SimpleExpression): # duration_factors = duration.offsets @@ -205,15 +218,29 @@ def hold_voltage(self, duration: HardwareTime, voltages: Mapping[ChannelID, Hard duration_factors.append(step) - set_cmd = LinSpaceHold(bases=tuple(bases), - factors=tuple(factors), + set_cmd = LinSpaceHold(channels=tuple(voltages.keys()), + bases=bases, + factors=factors, duration_base=duration_base, duration_factors=duration_factors) self._stack[-1].append(set_cmd) def play_arbitrary_waveform(self, waveform: Waveform): - return self._stack[-1].append(LinSpaceArbitraryWaveform(waveform, self._idx_to_name)) + if not isinstance(waveform,TransformingWaveform): + return self._stack[-1].append(LinSpaceArbitraryWaveform(waveform=waveform, + channels=waveform.defined_channels, + # self._voltage_idx_to_name + ) + ) + + #test for transformations that contain SimpleExpression + wf_transformation = waveform.transformation + + return self._stack[-1].append(LinSpaceArbitraryWaveform(waveform=waveform, + # self._voltage_idx_to_name + channels=waveform.defined_channels + )) def measure(self, measurements: Optional[Sequence[MeasurementWindow]]): """Ignores measurements""" @@ -262,14 +289,14 @@ class LoopLabel: @dataclass class Increment: - channel: Optional[int] + channel: Optional[GeneralizedChannel] value: Union[float,TimeType] dependency_key: DepKey @dataclass class Set: - channel: Optional[int] + channel: Optional[GeneralizedChannel] value: Union[float,TimeType] key: DepKey = dataclasses.field(default_factory=lambda: DepKey((),DepDomain.NODEP)) @@ -332,9 +359,9 @@ class _TranslationState: label_num: int = dataclasses.field(default=0) commands: List[Command] = dataclasses.field(default_factory=list) iterations: List[int] = dataclasses.field(default_factory=list) - active_dep: Dict[int, DepKey] = dataclasses.field(default_factory=dict) - dep_states: Dict[int, Dict[DepKey, DepState]] = dataclasses.field(default_factory=dict) - plain_voltage: Dict[int, float] = dataclasses.field(default_factory=dict) + active_dep: Dict[GeneralizedChannel, DepKey] = dataclasses.field(default_factory=dict) + dep_states: Dict[GeneralizedChannel, Dict[DepKey, DepState]] = dataclasses.field(default_factory=dict) + plain_voltage: Dict[ChannelID, float] = dataclasses.field(default_factory=dict) resolution: float = dataclasses.field(default_factory=lambda: DEFAULT_INCREMENT_RESOLUTION) resolution_time: float = dataclasses.field(default_factory=lambda: DEFAULT_TIME_RESOLUTION) @@ -344,14 +371,14 @@ def new_loop(self, count: int): self.label_num += 1 return label, jmp - def get_dependency_state(self, dependencies: Mapping[int, set]): + def get_dependency_state(self, dependencies: Mapping[GeneralizedChannel, set]): return { self.dep_states.get(ch, {}).get(DepKey.from_domain(dep, self.resolution), None) for ch, deps in dependencies.items() for dep in deps } - def set_voltage(self, channel: int, value: float): + def set_voltage(self, channel: ChannelID, value: float): key = DepKey((),DepDomain.VOLTAGE) if self.active_dep.get(channel, None) != key or self.plain_voltage.get(channel, None) != value: self.commands.append(Set(channel, value, key)) @@ -385,15 +412,17 @@ def _add_iteration_node(self, node: LinSpaceIter): self.commands.append(jmp) self.iterations.pop() - def _set_indexed_voltage(self, channel: int, base: float, factors: Sequence[float]): + def _set_indexed_voltage(self, channel: ChannelID, base: float, factors: Sequence[float]): key = DepKey.from_voltages(voltages=factors, resolution=self.resolution) self.set_indexed_value(key, channel, base, factors, domain=DepDomain.VOLTAGE) def _set_indexed_lin_time(self, base: TimeType, factors: Sequence[TimeType]): key = DepKey.from_lin_times(times=factors, resolution=self.resolution) - self.set_indexed_value(key, DepDomain.TIME_LIN.value, base, factors, domain=DepDomain.TIME_LIN) + self.set_indexed_value(key, DepDomain.TIME_LIN, base, factors, domain=DepDomain.TIME_LIN) - def set_indexed_value(self, dep_key, channel, base, factors, domain): + def set_indexed_value(self, dep_key: DepKey, channel: GeneralizedChannel, + base: Union[float,TimeType], factors: Sequence[Union[float,TimeType]], + domain: DepDomain): new_dep_state = DepState( base, iterations=tuple(self.iterations) @@ -416,17 +445,17 @@ def set_indexed_value(self, dep_key, channel, base, factors, domain): def _add_hold_node(self, node: LinSpaceHold): - for ch, (base, factors) in enumerate(zip(node.bases, node.factors)): - if factors is None: - self.set_voltage(ch, base) + for ch in node.play_channels: + if node.factors[ch] is None: + self.set_voltage(ch, node.bases[ch]) continue else: - self._set_indexed_voltage(ch, base, factors) + self._set_indexed_voltage(ch, node.bases[ch], node.factors[ch]) if node.duration_factors: self._set_indexed_lin_time(node.duration_base,node.duration_factors) # raise NotImplementedError("TODO") - self.commands.append(Wait(None, self.active_dep[DepDomain.TIME_LIN.value])) + self.commands.append(Wait(None, self.active_dep[DepDomain.TIME_LIN])) else: self.commands.append(Wait(node.duration_base)) From c796b61b908736293fbf57160596274b76b7aee2 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Sat, 8 Jun 2024 01:06:22 +0200 Subject: [PATCH 15/35] fix channel trafo call --- qupulse/hardware/awgs/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/qupulse/hardware/awgs/base.py b/qupulse/hardware/awgs/base.py index f13f5e36..e6dd0da6 100644 --- a/qupulse/hardware/awgs/base.py +++ b/qupulse/hardware/awgs/base.py @@ -275,7 +275,7 @@ def _transform_linspace_commands(self, command_list: List[Command]) -> List[Comm # all commands = Union[Increment, Set, LoopLabel, LoopJmp, Wait, Play] # TODO: voltage resolution - trafos_by_channel_idx = list(self._channel_transformations().values()) + # trafos_by_channel_idx = list(self._channel_transformations().values()) for command in command_list: if isinstance(command, (LoopLabel, LoopJmp, Play, Wait)): @@ -283,12 +283,12 @@ def _transform_linspace_commands(self, command_list: List[Command]) -> List[Comm # play is handled by transforming the sampled waveform continue elif isinstance(command, Increment): - ch_trafo = trafos_by_channel_idx[command.channel] + ch_trafo = self._channel_transformations()[command.channel] if ch_trafo.voltage_transformation: raise RuntimeError("Cannot apply a voltage transformation to a linspace increment command") command.value /= ch_trafo.amplitude elif isinstance(command, LSPSet): - ch_trafo = trafos_by_channel_idx[command.channel] + ch_trafo = self._channel_transformations()[command.channel] if ch_trafo.voltage_transformation: command.value = float(ch_trafo.voltage_transformation(command.value)) command.value -= ch_trafo.offset From 6d4e835e8fe1b12bca6e7a49e82b5df7431ccdb5 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Sat, 8 Jun 2024 10:19:12 +0200 Subject: [PATCH 16/35] fix transform commands --- qupulse/hardware/awgs/base.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/qupulse/hardware/awgs/base.py b/qupulse/hardware/awgs/base.py index e6dd0da6..6e683767 100644 --- a/qupulse/hardware/awgs/base.py +++ b/qupulse/hardware/awgs/base.py @@ -17,7 +17,7 @@ from qupulse.hardware.util import get_sample_times, not_none_indices from qupulse.utils.types import ChannelID from qupulse.program.linspace import LinSpaceNode, LinSpaceArbitraryWaveform, to_increment_commands, Command, \ - Increment, Set as LSPSet, LoopLabel, LoopJmp, Wait, Play, DEFAULT_INCREMENT_RESOLUTION + Increment, Set as LSPSet, LoopLabel, LoopJmp, Wait, Play, DEFAULT_INCREMENT_RESOLUTION, DepDomain from qupulse.program.loop import Loop from qupulse.program.waveforms import Waveform from qupulse.comparable import Comparable @@ -279,15 +279,18 @@ def _transform_linspace_commands(self, command_list: List[Command]) -> List[Comm for command in command_list: if isinstance(command, (LoopLabel, LoopJmp, Play, Wait)): - # play is handled by transforming the sampled waveform continue elif isinstance(command, Increment): + if command.dependency_key is not DepDomain.VOLTAGE: + continue ch_trafo = self._channel_transformations()[command.channel] if ch_trafo.voltage_transformation: raise RuntimeError("Cannot apply a voltage transformation to a linspace increment command") command.value /= ch_trafo.amplitude elif isinstance(command, LSPSet): + if command.key is not DepDomain.VOLTAGE: + continue ch_trafo = self._channel_transformations()[command.channel] if ch_trafo.voltage_transformation: command.value = float(ch_trafo.voltage_transformation(command.value)) From 9fd050c68dca87a4723dd39da7bef943ba0586ec Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Sat, 8 Jun 2024 12:08:38 +0200 Subject: [PATCH 17/35] resolution dependent set/increment --- qupulse/program/linspace.py | 65 ++++++++++++++++++++++++++++--------- 1 file changed, 49 insertions(+), 16 deletions(-) diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index b1f2612f..f9b3918e 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -3,13 +3,13 @@ import dataclasses import numpy as np from dataclasses import dataclass -from typing import Mapping, Optional, Sequence, ContextManager, Iterable, Tuple, Union, Dict, List, Iterator +from typing import Mapping, Optional, Sequence, ContextManager, Iterable, Tuple, Union, Dict, List, Iterator, Generic from enum import Enum from qupulse import ChannelID, MeasurementWindow from qupulse.parameter_scope import Scope, MappedScope, FrozenDict from qupulse.program import ProgramBuilder, HardwareTime, HardwareVoltage, Waveform, RepetitionCount, TimeType -from qupulse.expressions.simple import SimpleExpression +from qupulse.expressions.simple import SimpleExpression, NumVal from qupulse.program.waveforms import MultiChannelWaveform, TransformingWaveform # this resolution is used to unify increments @@ -28,9 +28,35 @@ class DepDomain(Enum): GeneralizedChannel = Union[DepDomain,ChannelID] -# class DepStrategy(Enum): -# CONSTANT = 0 -# VARIABLE = 1 + +class ResolutionDependentValue(Generic[NumVal]): + + def __init__(self, + bases: Sequence[NumVal], + multiplicities: Sequence[int], + offset: NumVal): + + self.bases = bases + self.multiplicities = multiplicities + self.offset = offset + self.__is_time = all(isinstance(b,TimeType) for b in bases) and isinstance(offset,TimeType) + + #this is not to circumvent float errors in python, but rounding errors from awg-increment commands. + #python float are thereby accurate enough if no awg with a 500 bit resolution is invented. + def __call__(self, resolution: Optional[float]) -> Union[NumVal,TimeType]: + #with resolution = None handle TimeType case? + if resolution is None: + assert self.__is_time + return sum(b*m for b,m in zip(self.bases,self.multiplicities)) + self.offset + #resolution as float value of granularity of base val. + #to avoid conflicts between positive and negative vals from casting half to even, + #use abs val + return sum(np.sign(b) * round(abs(b) / resolution) * m * resolution for b,m in zip(self.bases,self.multiplicities))\ + + np.sign(self.offset) * round(abs(self.offset) / resolution) * resolution + #cast the offset only once? + + def __bool__(self): + return any(bool(b) for b in self.bases) or bool(self.offset) @dataclass(frozen=True) @@ -59,7 +85,7 @@ def from_voltages(cls, voltages: Sequence[float], resolution: float): @classmethod def from_lin_times(cls, times: Sequence[float], resolution: float): return cls.from_domain(times, resolution, DepDomain.TIME_LIN) - + @dataclass class LinSpaceNode: @@ -290,14 +316,14 @@ class LoopLabel: @dataclass class Increment: channel: Optional[GeneralizedChannel] - value: Union[float,TimeType] + value: ResolutionDependentValue dependency_key: DepKey @dataclass class Set: channel: Optional[GeneralizedChannel] - value: Union[float,TimeType] + value: ResolutionDependentValue key: DepKey = dataclasses.field(default_factory=lambda: DepKey((),DepDomain.NODEP)) @@ -326,11 +352,13 @@ class DepState: base: float iterations: Tuple[int, ...] - def required_increment_from(self, previous: 'DepState', factors: Sequence[float]) -> float: + def required_increment_from(self, previous: 'DepState', + factors: Sequence[float]) -> ResolutionDependentValue: assert len(self.iterations) == len(previous.iterations) assert len(self.iterations) == len(factors) - increment = self.base - previous.base + # increment = self.base - previous.base + res_bases, res_mults, offset = [], [], self.base - previous.base for old, new, factor in zip(previous.iterations, self.iterations, factors): # By convention there are only two possible values for each integer here: 0 or the last index # The three possible increments are none, regular and jump to next line @@ -343,13 +371,18 @@ def required_increment_from(self, previous: 'DepState', factors: Sequence[float] assert old == 0 # regular iteration, although the new value will probably be > 1, the resulting increment will be # applied multiple times so only one factor is needed. - increment += factor - + # increment += factor + res_bases.append(factor) + res_mults.append(1) + else: assert new == 0 # we need to jump back. The old value gives us the number of increments to reverse - increment -= factor * old - return increment + # increment -= factor * old + res_bases.append(-factor) + res_mults.append(old) + + return ResolutionDependentValue(res_bases,res_mults,offset) @dataclass @@ -381,7 +414,7 @@ def get_dependency_state(self, dependencies: Mapping[GeneralizedChannel, set]): def set_voltage(self, channel: ChannelID, value: float): key = DepKey((),DepDomain.VOLTAGE) if self.active_dep.get(channel, None) != key or self.plain_voltage.get(channel, None) != value: - self.commands.append(Set(channel, value, key)) + self.commands.append(Set(channel, ResolutionDependentValue((),(),offset=value), key)) self.active_dep[channel] = key self.plain_voltage[channel] = value @@ -431,7 +464,7 @@ def set_indexed_value(self, dep_key: DepKey, channel: GeneralizedChannel, current_dep_state = self.dep_states.setdefault(channel, {}).get(dep_key, None) if current_dep_state is None: assert all(it == 0 for it in self.iterations) - self.commands.append(Set(channel, base, dep_key)) + self.commands.append(Set(channel, ResolutionDependentValue((),(),offset=base), dep_key)) self.active_dep[channel] = dep_key else: From 051e8b9dd84d1c7acab591d4b70f39749100da02 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Sat, 8 Jun 2024 16:26:58 +0200 Subject: [PATCH 18/35] remove outdates resolution handling attempt --- qupulse/hardware/awgs/base.py | 8 ++++---- qupulse/program/linspace.py | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/qupulse/hardware/awgs/base.py b/qupulse/hardware/awgs/base.py index 6e683767..045446aa 100644 --- a/qupulse/hardware/awgs/base.py +++ b/qupulse/hardware/awgs/base.py @@ -191,7 +191,7 @@ def __init__(self, program: AllowedProgramTypes, voltage_transformations: Tuple[Optional[Callable], ...], sample_rate: TimeType, waveforms: Sequence[Waveform] = None, - voltage_resolution: Optional[float] = None, + # voltage_resolution: Optional[float] = None, program_type: _ProgramType = _ProgramType.Loop): """ @@ -205,7 +205,7 @@ def __init__(self, program: AllowedProgramTypes, sample_rate: waveforms: These waveforms are sampled and stored in _waveforms. If None the waveforms are extracted from loop - voltage_resolution: voltage resolution for LinSpaceProgram, i.e. 2**(-16) for 16 bit AWG + # voltage_resolution: voltage resolution for LinSpaceProgram, i.e. 2**(-16) for 16 bit AWG program_type: type of program from _ProgramType, determined by the ProgramBuilder used. """ assert len(channels) == len(amplitudes) == len(offsets) == len(voltage_transformations) @@ -221,11 +221,11 @@ def __init__(self, program: AllowedProgramTypes, self._program_type = program_type self._program = program - self._voltage_resolution = voltage_resolution + # self._voltage_resolution = voltage_resolution if program_type == _ProgramType.Linspace: #!!! the voltage resolution may not be adequately represented if voltage transformations are not None? - self._transformed_commands = self._transform_linspace_commands(to_increment_commands(program,self._voltage_resolution)) + self._transformed_commands = self._transform_linspace_commands(to_increment_commands(program,)) if waveforms is None: if program_type is _ProgramType.Loop: diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index f9b3918e..aa74ea48 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -39,14 +39,14 @@ def __init__(self, self.bases = bases self.multiplicities = multiplicities self.offset = offset - self.__is_time = all(isinstance(b,TimeType) for b in bases) and isinstance(offset,TimeType) + self.__is_time_or_int = all(isinstance(b,(TimeType,int)) for b in bases) and isinstance(offset,(TimeType,int)) #this is not to circumvent float errors in python, but rounding errors from awg-increment commands. #python float are thereby accurate enough if no awg with a 500 bit resolution is invented. def __call__(self, resolution: Optional[float]) -> Union[NumVal,TimeType]: - #with resolution = None handle TimeType case? + #with resolution = None handle TimeType/int case? if resolution is None: - assert self.__is_time + assert self.__is_time_or_int return sum(b*m for b,m in zip(self.bases,self.multiplicities)) + self.offset #resolution as float value of granularity of base val. #to avoid conflicts between positive and negative vals from casting half to even, @@ -515,11 +515,11 @@ def add_node(self, node: Union[LinSpaceNode, Sequence[LinSpaceNode]]): def to_increment_commands(linspace_nodes: Sequence[LinSpaceNode], - resolution: float = DEFAULT_INCREMENT_RESOLUTION + # resolution: float = DEFAULT_INCREMENT_RESOLUTION ) -> List[Command]: """translate the given linspace node tree to a minimal sequence of set and increment commands as well as loops.""" # if resolution: raise NotImplementedError('wrongly assumed resolution. need to fix') - state = _TranslationState(resolution=resolution if resolution is not None else DEFAULT_INCREMENT_RESOLUTION) + state = _TranslationState() state.add_node(linspace_nodes) return state.commands From e58277ced51b740c157eb062c7c567f8e1f5f374 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Sat, 8 Jun 2024 17:22:06 +0200 Subject: [PATCH 19/35] math methods for resolution class --- qupulse/program/linspace.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index aa74ea48..709d1358 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -58,7 +58,32 @@ def __call__(self, resolution: Optional[float]) -> Union[NumVal,TimeType]: def __bool__(self): return any(bool(b) for b in self.bases) or bool(self.offset) - + def __add__(self, other): + # this should happen in the context of an offset being added to it, not the bases being modified. + if isinstance(other, (float, int, TimeType)): + return ResolutionDependentValue(self.bases,self.multiplicities,self.offset+other) + return NotImplemented + + def __radd__(self, other): + return self.__add__(other) + + def __sub__(self, other): + return self.__add__(-other) + + def __mul__(self, other): + # this should happen when the amplitude is being scaled + if isinstance(other, (float, int, TimeType)): + return ResolutionDependentValue(self.bases*other,self.multiplicities,self.offset*other) + return NotImplemented + + def __rmul__(self,other): + return self.__mul__(other) + + def __truediv__(self,other): + return self.__mul__(1/other) + + + @dataclass(frozen=True) class DepKey: """The key that identifies how a certain set command depends on iteration indices. The factors are rounded with a From cf74051f4b42132aa97ac58ebb28f14cdb930c05 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Sat, 8 Jun 2024 17:22:14 +0200 Subject: [PATCH 20/35] fix domain check --- qupulse/hardware/awgs/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/qupulse/hardware/awgs/base.py b/qupulse/hardware/awgs/base.py index 045446aa..5573b654 100644 --- a/qupulse/hardware/awgs/base.py +++ b/qupulse/hardware/awgs/base.py @@ -282,14 +282,14 @@ def _transform_linspace_commands(self, command_list: List[Command]) -> List[Comm # play is handled by transforming the sampled waveform continue elif isinstance(command, Increment): - if command.dependency_key is not DepDomain.VOLTAGE: + if command.dependency_key.domain is not DepDomain.VOLTAGE: continue ch_trafo = self._channel_transformations()[command.channel] if ch_trafo.voltage_transformation: raise RuntimeError("Cannot apply a voltage transformation to a linspace increment command") command.value /= ch_trafo.amplitude elif isinstance(command, LSPSet): - if command.key is not DepDomain.VOLTAGE: + if command.key.domain is not DepDomain.VOLTAGE: continue ch_trafo = self._channel_transformations()[command.channel] if ch_trafo.voltage_transformation: From 2447e23c3c2012025ca91c8be792f779b1e91a11 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Sat, 8 Jun 2024 17:24:51 +0200 Subject: [PATCH 21/35] fix __mul__ --- qupulse/program/linspace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index 709d1358..ecd26d18 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -73,7 +73,7 @@ def __sub__(self, other): def __mul__(self, other): # this should happen when the amplitude is being scaled if isinstance(other, (float, int, TimeType)): - return ResolutionDependentValue(self.bases*other,self.multiplicities,self.offset*other) + return ResolutionDependentValue(tuple(b*other for b in self.bases),self.multiplicities,self.offset*other) return NotImplemented def __rmul__(self,other): From a9f17101d4fab259bd2eab5dfd99ab9be79df1c2 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Tue, 11 Jun 2024 09:02:17 +0200 Subject: [PATCH 22/35] test wf amp sweep --- qupulse/expressions/simple.py | 5 +- qupulse/hardware/awgs/base.py | 6 +- qupulse/program/linspace.py | 180 ++++++++++++++++++++++++++---- qupulse/program/transformation.py | 38 ++++++- 4 files changed, 201 insertions(+), 28 deletions(-) diff --git a/qupulse/expressions/simple.py b/qupulse/expressions/simple.py index e5ab3a1f..b3e33005 100644 --- a/qupulse/expressions/simple.py +++ b/qupulse/expressions/simple.py @@ -110,7 +110,10 @@ def __rmul__(self, other): def __truediv__(self, other): inv = 1 / other return self.__mul__(inv) - + + def __hash__(self): + return hash((self.base,frozenset(sorted(self.offsets.items())))) + @property def free_symbols(self): return () diff --git a/qupulse/hardware/awgs/base.py b/qupulse/hardware/awgs/base.py index 5573b654..b622368c 100644 --- a/qupulse/hardware/awgs/base.py +++ b/qupulse/hardware/awgs/base.py @@ -276,13 +276,15 @@ def _transform_linspace_commands(self, command_list: List[Command]) -> List[Comm # TODO: voltage resolution # trafos_by_channel_idx = list(self._channel_transformations().values()) - + # increment_domains_to_transform = {DepDomain.VOLTAGE, DepDomain.WF_SCALE, DepDomain.WF_OFFSET} + for command in command_list: if isinstance(command, (LoopLabel, LoopJmp, Play, Wait)): # play is handled by transforming the sampled waveform continue elif isinstance(command, Increment): if command.dependency_key.domain is not DepDomain.VOLTAGE: + #for sweeps of wf-scale and wf-offset, the channel amplitudes/offsets are already considered in the wf sampling. continue ch_trafo = self._channel_transformations()[command.channel] if ch_trafo.voltage_transformation: @@ -290,9 +292,11 @@ def _transform_linspace_commands(self, command_list: List[Command]) -> List[Comm command.value /= ch_trafo.amplitude elif isinstance(command, LSPSet): if command.key.domain is not DepDomain.VOLTAGE: + #for sweeps of wf-scale and wf-offset, the channel amplitudes/offsets are already considered in the wf sampling. continue ch_trafo = self._channel_transformations()[command.channel] if ch_trafo.voltage_transformation: + # for the case of swept parameters, this is defaulted to identity command.value = float(ch_trafo.voltage_transformation(command.value)) command.value -= ch_trafo.offset command.value /= ch_trafo.amplitude diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index ecd26d18..c19da6ae 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -12,6 +12,9 @@ from qupulse.expressions.simple import SimpleExpression, NumVal from qupulse.program.waveforms import MultiChannelWaveform, TransformingWaveform +from qupulse.program.transformation import ChainedTransformation, ScalingTransformation, OffsetTransformation,\ + IdentityTransformation, ParallelChannelTransformation, Transformation + # this resolution is used to unify increments # the increments themselves remain floats # !!! translated: this is NOT a hardware resolution, @@ -24,6 +27,8 @@ class DepDomain(Enum): TIME_LIN = -1 TIME_LOG = -2 FREQUENCY = -3 + WF_SCALE = -4 + WF_OFFSET = -5 NODEP = None GeneralizedChannel = Union[DepDomain,ChannelID] @@ -149,7 +154,18 @@ def dependencies(self) -> Mapping[GeneralizedChannel, set]: class LinSpaceArbitraryWaveform(LinSpaceNodeChannelSpecific): """This is just a wrapper to pipe arbitrary waveforms through the system.""" waveform: Waveform - # channels: Tuple[ChannelID, ...] + + +@dataclass +class LinSpaceArbitraryWaveformIndexed(LinSpaceNodeChannelSpecific): + """This is just a wrapper to pipe arbitrary waveforms through the system.""" + waveform: Waveform + + scale_bases: Dict[ChannelID, float] + scale_factors: Dict[ChannelID, Optional[Tuple[float, ...]]] + + offset_bases: Dict[ChannelID, float] + offset_factors: Dict[ChannelID, Optional[Tuple[float, ...]]] @dataclass @@ -278,20 +294,55 @@ def hold_voltage(self, duration: HardwareTime, voltages: Mapping[ChannelID, Hard self._stack[-1].append(set_cmd) def play_arbitrary_waveform(self, waveform: Waveform): + + #recognize voltage trafo sweep syntax from a transforming waveform. other sweepable things may need different approaches. if not isinstance(waveform,TransformingWaveform): - return self._stack[-1].append(LinSpaceArbitraryWaveform(waveform=waveform, - channels=waveform.defined_channels, - # self._voltage_idx_to_name - ) - ) + return self._stack[-1].append(LinSpaceArbitraryWaveform(waveform=waveform,channels=waveform.defined_channels,)) #test for transformations that contain SimpleExpression wf_transformation = waveform.transformation - return self._stack[-1].append(LinSpaceArbitraryWaveform(waveform=waveform, - # self._voltage_idx_to_name - channels=waveform.defined_channels - )) + # chainedTransformation should now have flat hierachy. + collected_trafos, dependent_vals_flag = collect_scaling_and_offset_per_channel(waveform.defined_channels,wf_transformation) + + #fast track + if not dependent_vals_flag: + return self._stack[-1].append(LinSpaceArbitraryWaveform(waveform=waveform,channels=waveform.defined_channels,)) + + ranges = self._get_ranges() + scale_factors, offset_factors = {}, {} + scale_bases, offset_bases = {}, {} + + for ch_name,scale_offset_dict in collected_trafos.items(): + for bases,factors,key in zip((scale_bases, offset_bases),(scale_factors, offset_factors),('s','o')): + value = scale_offset_dict[key] + if isinstance(value, float): + bases[ch_name] = value + factors[ch_name] = None + continue + offsets = value.offsets + base = value.base + incs = [] + for rng_name, rng in ranges.items(): + start = 0. + step = 0. + offset = offsets.get(rng_name, None) + if offset: + start += rng.start * offset + step += rng.step * offset + base += start + incs.append(step) + factors[ch_name] = tuple(incs) + bases[ch_name] = base + + + return self._stack[-1].append(LinSpaceArbitraryWaveformIndexed(waveform=waveform, + channels=waveform.defined_channels, + scale_bases=scale_bases, + scale_factors=scale_factors, + offset_bases=offset_bases, + offset_factors=offset_factors, + )) def measure(self, measurements: Optional[Sequence[MeasurementWindow]]): """Ignores measurements""" @@ -332,6 +383,59 @@ def to_program(self) -> Optional[Sequence[LinSpaceNode]]: return self._root() +def collect_scaling_and_offset_per_channel(channels: Sequence[ChannelID], + transformation: Transformation) \ + -> Tuple[Dict[ChannelID,Dict[str,Union[NumVal,SimpleExpression]]], bool]: + + ch_trafo_dict = {ch: {'s':1.,'o':0.} for ch in channels} + + # allowed_trafos = {IdentityTransformation,} + if not isinstance(transformation,ChainedTransformation): + transformations = (transformation,) + else: + transformations = transformation.transformations + + is_dependent_flag = [] + + for trafo in transformations: + #first elements of list are applied first in trafos. + assert trafo.is_constant_invariant() + if isinstance(trafo,ParallelChannelTransformation): + for ch,val in trafo._channels.items(): + is_dependent_flag.append(trafo.contains_sweepval) + # assert not ch in ch_trafo_dict.keys() + # the waveform is sampled with these values taken into account, no change needed. + # ch_trafo_dict[ch]['o'] = val + # ch_trafo_dict.setdefault(ch,{'s':1.,'o':val}) + elif isinstance(trafo,ScalingTransformation): + is_dependent_flag.append(trafo.contains_sweepval) + for ch,val in trafo._factors.items(): + try: + ch_trafo_dict[ch]['s'] = reduce_non_swept(ch_trafo_dict[ch]['s']*val) + ch_trafo_dict[ch]['o'] = reduce_non_swept(ch_trafo_dict[ch]['o']*val) + except TypeError as e: + print('Attempting scale sweep of other sweep val') + raise e + elif isinstance(trafo,OffsetTransformation): + is_dependent_flag.append(trafo.contains_sweepval) + for ch,val in trafo._offsets.items(): + ch_trafo_dict[ch]['o'] += val + elif isinstance(trafo,IdentityTransformation): + continue + elif isinstance(trafo,ChainedTransformation): + raise RuntimeError() + else: + raise NotImplementedError() + + return ch_trafo_dict, any(is_dependent_flag) + + +def reduce_non_swept(val: Union[SimpleExpression,NumVal]) -> Union[SimpleExpression,NumVal]: + if isinstance(val,SimpleExpression) and all(v==0 for v in val.offsets.values()): + return val.base + return val + + @dataclass class LoopLabel: idx: int @@ -341,14 +445,14 @@ class LoopLabel: @dataclass class Increment: channel: Optional[GeneralizedChannel] - value: ResolutionDependentValue + value: Union[ResolutionDependentValue,Tuple[ResolutionDependentValue]] dependency_key: DepKey @dataclass class Set: channel: Optional[GeneralizedChannel] - value: ResolutionDependentValue + value: Union[ResolutionDependentValue,Tuple[ResolutionDependentValue]] key: DepKey = dataclasses.field(default_factory=lambda: DepKey((),DepDomain.NODEP)) @@ -367,7 +471,13 @@ class LoopJmp: class Play: waveform: Waveform channels: Tuple[ChannelID] + keys: Sequence[DepKey] = None + def __post_init__(self): + if self.keys is None: + self.keys = tuple(dataclasses.field(default_factory=lambda: DepKey((),DepDomain.NODEP)) + for i in range(len(self.channels))) + Command = Union[Increment, Set, LoopLabel, LoopJmp, Wait, Play] @@ -419,7 +529,7 @@ class _TranslationState: iterations: List[int] = dataclasses.field(default_factory=list) active_dep: Dict[GeneralizedChannel, DepKey] = dataclasses.field(default_factory=dict) dep_states: Dict[GeneralizedChannel, Dict[DepKey, DepState]] = dataclasses.field(default_factory=dict) - plain_voltage: Dict[ChannelID, float] = dataclasses.field(default_factory=dict) + plain_value: Dict[GeneralizedChannel, Dict[DepDomain,float]] = dataclasses.field(default_factory=dict) resolution: float = dataclasses.field(default_factory=lambda: DEFAULT_INCREMENT_RESOLUTION) resolution_time: float = dataclasses.field(default_factory=lambda: DEFAULT_TIME_RESOLUTION) @@ -437,12 +547,24 @@ def get_dependency_state(self, dependencies: Mapping[GeneralizedChannel, set]): } def set_voltage(self, channel: ChannelID, value: float): - key = DepKey((),DepDomain.VOLTAGE) - if self.active_dep.get(channel, None) != key or self.plain_voltage.get(channel, None) != value: + self.set_non_indexed_value(channel, value, domain=DepDomain.VOLTAGE) + + def set_wf_scale(self, channel: ChannelID, value: float): + self.set_non_indexed_value(channel, value, domain=DepDomain.WF_SCALE) + + def set_wf_offset(self, channel: ChannelID, value: float): + self.set_non_indexed_value(channel, value, domain=DepDomain.WF_OFFSET) + + def set_non_indexed_value(self, channel: GeneralizedChannel, value: float, domain: DepDomain): + key = DepKey((),domain) + # I do not completely get why it would have to be set again if not in active dep. + # if not key != self.active_dep.get(channel, None) or + if self.plain_value.get(channel, {}).get(domain, None) != value: self.commands.append(Set(channel, ResolutionDependentValue((),(),offset=value), key)) self.active_dep[channel] = key - self.plain_voltage[channel] = value - + self.plain_value.setdefault(channel,{}) + self.plain_value[channel][domain] = value + def _add_repetition_node(self, node: LinSpaceRepeat): pre_dep_state = self.get_dependency_state(node.dependencies()) label, jmp = self.new_loop(node.count) @@ -477,7 +599,7 @@ def _set_indexed_voltage(self, channel: ChannelID, base: float, factors: Sequenc def _set_indexed_lin_time(self, base: TimeType, factors: Sequence[TimeType]): key = DepKey.from_lin_times(times=factors, resolution=self.resolution) self.set_indexed_value(key, DepDomain.TIME_LIN, base, factors, domain=DepDomain.TIME_LIN) - + def set_indexed_value(self, dep_key: DepKey, channel: GeneralizedChannel, base: Union[float,TimeType], factors: Sequence[Union[float,TimeType]], domain: DepDomain): @@ -516,7 +638,22 @@ def _add_hold_node(self, node: LinSpaceHold): self.commands.append(Wait(None, self.active_dep[DepDomain.TIME_LIN])) else: self.commands.append(Wait(node.duration_base)) - + + def _add_indexed_play_node(self, node: LinSpaceArbitraryWaveformIndexed): + + for ch in node.channels: + for base,factors,domain in zip((node.scale_bases[ch], node.offset_bases[ch]), + (node.scale_factors[ch], node.offset_factors[ch]), + (DepDomain.WF_SCALE,DepDomain.WF_OFFSET)): + if factors is None: + self.set_non_indexed_value(ch, base, domain) + else: + key = DepKey.from_domain(factors, resolution=self.resolution, domain=domain) + self.set_indexed_value(key, ch, base, factors, key.domain) + + self.commands.append(Play(node.waveform, node.channels, keys=tuple(self.active_dep[ch] for ch in node.channels))) + + def add_node(self, node: Union[LinSpaceNode, Sequence[LinSpaceNode]]): """Translate a (sequence of) linspace node(s) to commands and add it to the internal command list.""" if isinstance(node, Sequence): @@ -531,7 +668,10 @@ def add_node(self, node: Union[LinSpaceNode, Sequence[LinSpaceNode]]): elif isinstance(node, LinSpaceHold): self._add_hold_node(node) - + + elif isinstance(node, LinSpaceArbitraryWaveformIndexed): + self._add_indexed_play_node(node) + elif isinstance(node, LinSpaceArbitraryWaveform): self.commands.append(Play(node.waveform, node.channels)) diff --git a/qupulse/program/transformation.py b/qupulse/program/transformation.py index 1d3c8687..21e43772 100644 --- a/qupulse/program/transformation.py +++ b/qupulse/program/transformation.py @@ -8,9 +8,9 @@ from qupulse.comparable import Comparable from qupulse.utils.types import SingletonABCMeta, frozendict from qupulse.expressions import ExpressionScalar +from qupulse.expressions.simple import SimpleExpression - -_TrafoValue = Union[Real, ExpressionScalar] +_TrafoValue = Union[Real, ExpressionScalar, SimpleExpression] __all__ = ['Transformation', 'IdentityTransformation', 'LinearTransformation', 'ScalingTransformation', @@ -88,7 +88,16 @@ def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) - class ChainedTransformation(Transformation): def __init__(self, *transformations: Transformation): - self._transformations = transformations + #avoid nesting also here in init to ensure always flat hierachy? + parsed = [] + for t in transformations: + if t is IdentityTransformation() or t is None: + pass + elif isinstance(t,ChainedTransformation): + parsed.extend(t.transformations) + else: + parsed.append(t) + self._transformations = tuple(parsed) @property def transformations(self) -> Tuple[Transformation, ...]: @@ -231,7 +240,7 @@ def __init__(self, offsets: Mapping[ChannelID, _TrafoValue]): def __call__(self, time: Union[np.ndarray, float], data: Mapping[ChannelID, Union[np.ndarray, float]]) -> Mapping[ChannelID, Union[np.ndarray, float]]: - offsets = _instantiate_expression_dict(time, self._offsets) + offsets = _instantiate_expression_dict(time, self._offsets, default_sweepval = 0.) return {channel: channel_values + offsets[channel] if channel in offsets else channel_values for channel, channel_values in data.items()} @@ -254,6 +263,10 @@ def is_constant_invariant(self): def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: return _get_constant_output_channels(self._offsets, input_channels) + + @property + def contains_sweepval(self) -> bool: + return any(isinstance(o,SimpleExpression) for o in self._offsets.values()) class ScalingTransformation(Transformation): @@ -263,7 +276,7 @@ def __init__(self, factors: Mapping[ChannelID, _TrafoValue]): def __call__(self, time: Union[np.ndarray, float], data: Mapping[ChannelID, Union[np.ndarray, float]]) -> Mapping[ChannelID, Union[np.ndarray, float]]: - factors = _instantiate_expression_dict(time, self._factors) + factors = _instantiate_expression_dict(time, self._factors, default_sweepval = 1.) return {channel: channel_values * factors[channel] if channel in factors else channel_values for channel, channel_values in data.items()} @@ -287,6 +300,10 @@ def is_constant_invariant(self): def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: return _get_constant_output_channels(self._factors, input_channels) + @property + def contains_sweepval(self) -> bool: + return any(isinstance(o,SimpleExpression) for o in self._factors.values()) + try: if TYPE_CHECKING: @@ -359,6 +376,10 @@ def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) - output_channels.add(ch) return output_channels + + @property + def contains_sweepval(self) -> bool: + return any(isinstance(o,SimpleExpression) for o in self._channels.values()) def chain_transformations(*transformations: Transformation) -> Transformation: @@ -378,12 +399,17 @@ def chain_transformations(*transformations: Transformation) -> Transformation: return ChainedTransformation(*parsed_transformations) -def _instantiate_expression_dict(time, expressions: Mapping[str, _TrafoValue]) -> Mapping[str, Union[Real, np.ndarray]]: +def _instantiate_expression_dict(time, expressions: Mapping[str, _TrafoValue], + default_sweepval: float) -> Mapping[str, Union[Real, np.ndarray]]: scope = {'t': time} modified_expressions = {} for name, value in expressions.items(): if hasattr(value, 'evaluate_in_scope'): modified_expressions[name] = value.evaluate_in_scope(scope) + if isinstance(value, SimpleExpression): + # it is assumed that swept parameters will be handled by the ProgramBuilder accordingly + # such that here only an "identity" trafo is to be applied and the trafos are set in the program internally. + modified_expressions[name] = default_sweepval if modified_expressions: return {**expressions, **modified_expressions} else: From d9dad64ec98f39193daf57ea455ff8cb1b258f8f Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Tue, 11 Jun 2024 09:10:27 +0200 Subject: [PATCH 23/35] bugfix --- qupulse/program/linspace.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index c19da6ae..928439bf 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -474,9 +474,7 @@ class Play: keys: Sequence[DepKey] = None def __post_init__(self): if self.keys is None: - self.keys = tuple(dataclasses.field(default_factory=lambda: DepKey((),DepDomain.NODEP)) - for i in range(len(self.channels))) - + self.keys = tuple(DepKey((),DepDomain.NODEP) for i in range(len(self.channels))) Command = Union[Increment, Set, LoopLabel, LoopJmp, Wait, Play] From 5f22a52b8a9ce03426d66aaac5297e8c00c245a9 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Tue, 11 Jun 2024 09:35:27 +0200 Subject: [PATCH 24/35] not sure if correct (depstate comparison) --- qupulse/program/linspace.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index 928439bf..963d9875 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -616,7 +616,8 @@ def set_indexed_value(self, dep_key: DepKey, channel: GeneralizedChannel, inc = new_dep_state.required_increment_from(previous=current_dep_state, factors=factors) # we insert all inc here (also inc == 0) because it signals to activate this amplitude register - if inc or self.active_dep.get(channel, None) != dep_key: + #not really sure if correct, but if dep states are the same, dont emit increment call. + if (inc or self.active_dep.get(channel, None) != dep_key) and new_dep_state != current_dep_state: self.commands.append(Increment(channel, inc, dep_key)) self.active_dep[channel] = dep_key self.dep_states[channel][dep_key] = new_dep_state From b04cbcf89b69db23c7721e98af35ce70865c2074 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Wed, 12 Jun 2024 18:04:52 +0200 Subject: [PATCH 25/35] fix some of the depkey confusion --- qupulse/program/linspace.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index 963d9875..9f7a34e4 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -459,7 +459,7 @@ class Set: @dataclass class Wait: duration: Optional[TimeType] - key: DepKey = dataclasses.field(default_factory=lambda: DepKey((),DepDomain.NODEP)) + key_by_domain: Dict[DepDomain,DepKey] = dataclasses.field(default_factory=lambda: {}) @dataclass @@ -471,10 +471,10 @@ class LoopJmp: class Play: waveform: Waveform channels: Tuple[ChannelID] - keys: Sequence[DepKey] = None + keys_by_ch_by_domain: Dict[DepDomain,Dict[ChannelID,DepKey]] = None def __post_init__(self): - if self.keys is None: - self.keys = tuple(DepKey((),DepDomain.NODEP) for i in range(len(self.channels))) + if self.keys_by_ch_by_domain is None: + self.keys_by_ch_by_domain = {ch: {} for ch in self.channels} Command = Union[Increment, Set, LoopLabel, LoopJmp, Wait, Play] @@ -525,7 +525,7 @@ class _TranslationState: label_num: int = dataclasses.field(default=0) commands: List[Command] = dataclasses.field(default_factory=list) iterations: List[int] = dataclasses.field(default_factory=list) - active_dep: Dict[GeneralizedChannel, DepKey] = dataclasses.field(default_factory=dict) + active_dep: Dict[GeneralizedChannel, Dict[DepDomain, DepKey]] = dataclasses.field(default_factory=dict) dep_states: Dict[GeneralizedChannel, Dict[DepKey, DepState]] = dataclasses.field(default_factory=dict) plain_value: Dict[GeneralizedChannel, Dict[DepDomain,float]] = dataclasses.field(default_factory=dict) resolution: float = dataclasses.field(default_factory=lambda: DEFAULT_INCREMENT_RESOLUTION) @@ -559,7 +559,8 @@ def set_non_indexed_value(self, channel: GeneralizedChannel, value: float, domai # if not key != self.active_dep.get(channel, None) or if self.plain_value.get(channel, {}).get(domain, None) != value: self.commands.append(Set(channel, ResolutionDependentValue((),(),offset=value), key)) - self.active_dep[channel] = key + # there has to be no active dep when the value is not indexed? + # self.active_dep.setdefault(channel,{})[DepDomain.NODEP] = key self.plain_value.setdefault(channel,{}) self.plain_value[channel][domain] = value @@ -610,16 +611,16 @@ def set_indexed_value(self, dep_key: DepKey, channel: GeneralizedChannel, if current_dep_state is None: assert all(it == 0 for it in self.iterations) self.commands.append(Set(channel, ResolutionDependentValue((),(),offset=base), dep_key)) - self.active_dep[channel] = dep_key + self.active_dep.setdefault(channel,{})[dep_key.domain] = dep_key else: inc = new_dep_state.required_increment_from(previous=current_dep_state, factors=factors) # we insert all inc here (also inc == 0) because it signals to activate this amplitude register #not really sure if correct, but if dep states are the same, dont emit increment call. - if (inc or self.active_dep.get(channel, None) != dep_key) and new_dep_state != current_dep_state: + if (inc or self.active_dep.get(channel, {}).get(dep_key.domain) != dep_key) and new_dep_state != current_dep_state: self.commands.append(Increment(channel, inc, dep_key)) - self.active_dep[channel] = dep_key + self.active_dep.setdefault(channel,{})[dep_key.domain] = dep_key self.dep_states[channel][dep_key] = new_dep_state def _add_hold_node(self, node: LinSpaceHold): @@ -634,7 +635,7 @@ def _add_hold_node(self, node: LinSpaceHold): if node.duration_factors: self._set_indexed_lin_time(node.duration_base,node.duration_factors) # raise NotImplementedError("TODO") - self.commands.append(Wait(None, self.active_dep[DepDomain.TIME_LIN])) + self.commands.append(Wait(None, {DepDomain.TIME_LIN: self.active_dep[DepDomain.TIME_LIN][DepDomain.TIME_LIN]})) else: self.commands.append(Wait(node.duration_base)) @@ -645,12 +646,17 @@ def _add_indexed_play_node(self, node: LinSpaceArbitraryWaveformIndexed): (node.scale_factors[ch], node.offset_factors[ch]), (DepDomain.WF_SCALE,DepDomain.WF_OFFSET)): if factors is None: - self.set_non_indexed_value(ch, base, domain) + continue + # assume here that the waveform will have the correct settings the TransformingWaveform, + # where no SimpleExpression is replaced now. + # will yield the correct trafo already without having to make adjustments + # self.set_non_indexed_value(ch, base, domain) else: key = DepKey.from_domain(factors, resolution=self.resolution, domain=domain) self.set_indexed_value(key, ch, base, factors, key.domain) - self.commands.append(Play(node.waveform, node.channels, keys=tuple(self.active_dep[ch] for ch in node.channels))) + self.commands.append(Play(node.waveform, node.channels, + keys_by_ch_by_domain={c: self.active_dep.get(c,{}) for c in node.channels})) def add_node(self, node: Union[LinSpaceNode, Sequence[LinSpaceNode]]): From e4c2366797ef8d3a2edc3eeaf68c54f3642f6c71 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Thu, 13 Jun 2024 14:46:57 +0200 Subject: [PATCH 26/35] more flexible repetition in sequence structure --- qupulse/program/linspace.py | 114 ++++++++++++++++++++++++++++-------- 1 file changed, 91 insertions(+), 23 deletions(-) diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index 9f7a34e4..085008a0 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from typing import Mapping, Optional, Sequence, ContextManager, Iterable, Tuple, Union, Dict, List, Iterator, Generic from enum import Enum +from itertools import dropwhile from qupulse import ChannelID, MeasurementWindow from qupulse.parameter_scope import Scope, MappedScope, FrozenDict @@ -122,6 +123,8 @@ class LinSpaceNode: """AST node for a program that supports linear spacing of set points as well as nested sequencing and repetitions""" def dependencies(self) -> Mapping[GeneralizedChannel, set]: + # doing this as a set _should_ get rid of non-active deps that are one level above? + #!!! raise NotImplementedError @dataclass @@ -144,17 +147,25 @@ class LinSpaceHold(LinSpaceNodeChannelSpecific): duration_base: TimeType duration_factors: Optional[Tuple[TimeType, ...]] - def dependencies(self) -> Mapping[GeneralizedChannel, set]: - return {idx: {factors} - for idx, factors in self.factors.items() + def dependencies(self) -> Mapping[DepDomain, Mapping[ChannelID, set]]: + return {dom: {ch: {factors}} + for dom, ch_to_factors in self._dep_by_domain().items() + for ch, factors in ch_to_factors.items() if factors} - + + def _dep_by_domain(self) -> Mapping[DepDomain, set]: + return {DepDomain.VOLTAGE: self.factors, + DepDomain.TIME_LIN: {DepDomain.TIME_LIN:self.duration_factors},} + @dataclass class LinSpaceArbitraryWaveform(LinSpaceNodeChannelSpecific): """This is just a wrapper to pipe arbitrary waveforms through the system.""" waveform: Waveform - + + def dependencies(self): + return {} + @dataclass class LinSpaceArbitraryWaveformIndexed(LinSpaceNodeChannelSpecific): @@ -166,7 +177,17 @@ class LinSpaceArbitraryWaveformIndexed(LinSpaceNodeChannelSpecific): offset_bases: Dict[ChannelID, float] offset_factors: Dict[ChannelID, Optional[Tuple[float, ...]]] - + + def dependencies(self) -> Mapping[DepDomain, Mapping[ChannelID, set]]: + return {dom: {ch: {factors}} + for dom, ch_to_factors in self._dep_by_domain().items() + for ch, factors in ch_to_factors.items() + if factors} + + def _dep_by_domain(self) -> Mapping[DepDomain, set]: + return {DepDomain.WF_SCALE: self.scale_factors, + DepDomain.WF_OFFSET: self.offset_factors,} + @dataclass class LinSpaceRepeat(LinSpaceNode): @@ -177,8 +198,9 @@ class LinSpaceRepeat(LinSpaceNode): def dependencies(self): dependencies = {} for node in self.body: - for idx, deps in node.dependencies().items(): - dependencies.setdefault(idx, set()).update(deps) + for dom, ch_to_deps in node.dependencies().items(): + for ch, deps in ch_to_deps.items(): + dependencies.setdefault(dom,{}).setdefault(ch, set()).update(deps) return dependencies @@ -193,11 +215,12 @@ class LinSpaceIter(LinSpaceNode): def dependencies(self): dependencies = {} for node in self.body: - for idx, deps in node.dependencies().items(): - # remove the last elemt in index because this iteration sets it -> no external dependency - shortened = {dep[:-1] for dep in deps} - if shortened != {()}: - dependencies.setdefault(idx, set()).update(shortened) + for dom, ch_to_deps in node.dependencies().items(): + for ch, deps in ch_to_deps.items(): + # remove the last elemt in index because this iteration sets it -> no external dependency + shortened = {dep[:-1] for dep in deps} + if shortened != {()}: + dependencies.setdefault(dom,{}).setdefault(ch, set()).update(shortened) return dependencies @@ -537,13 +560,52 @@ def new_loop(self, count: int): self.label_num += 1 return label, jmp - def get_dependency_state(self, dependencies: Mapping[GeneralizedChannel, set]): - return { - self.dep_states.get(ch, {}).get(DepKey.from_domain(dep, self.resolution), None) - for ch, deps in dependencies.items() - for dep in deps - } - + def get_dependency_state(self, dependencies: Mapping[DepDomain, Mapping[GeneralizedChannel, set]]): + dom_to_ch_to_depstates = {} + + for dom, ch_to_deps in dependencies.items(): + dom_to_ch_to_depstates.setdefault(dom,{}) + for ch, deps in ch_to_deps.items(): + dom_to_ch_to_depstates[dom].setdefault(ch,set()) + for dep in deps: + dom_to_ch_to_depstates[dom][ch].add(self.dep_states.get(ch, {}).get( + DepKey.from_domain(dep, self.resolution, dom),None)) + + return dom_to_ch_to_depstates + # return { + # dom: self.dep_states.get(ch, {}).get(DepKey.from_domain(dep, self.resolution, dom), + # None) + # for dom, ch_to_deps in dependencies.items() + # for ch, deps in ch_to_deps.items() + # for dep in deps + # } + + def compare_ignoring_post_trailing_zeros(self, + pre_state: Mapping[DepDomain, Mapping[GeneralizedChannel, set]], + post_state: Mapping[DepDomain, Mapping[GeneralizedChannel, set]]) -> bool: + + def reduced_or_none(dep_state: DepState) -> Union[DepState,None]: + new_iterations = tuple(dropwhile(lambda x: x == 0, reversed(dep_state.iterations)))[::-1] + return DepState(dep_state.base, new_iterations) if len(new_iterations)>0 else None + + has_changed = False + dom_keys = set(pre_state.keys()).union(post_state.keys()) + for dom_key in dom_keys: + pre_state_dom, post_state_dom = pre_state.get(dom_key,{}), post_state.get(dom_key,{}) + ch_keys = set(pre_state_dom.keys()).union(post_state_dom.keys()) + for ch_key in ch_keys: + pre_state_dom_ch, post_state_dom_ch = pre_state_dom.get(ch_key,set()), post_state_dom.get(ch_key,set()) + # reduce the depStates to the ones which do not just contain zeros + reduced_pre_set = set(reduced_or_none(dep_state) for dep_state in pre_state_dom_ch + if dep_state is not None) - {None} + reduced_post_set = set(reduced_or_none(dep_state) for dep_state in post_state_dom_ch + if dep_state is not None) - {None} + + if not reduced_post_set <= reduced_pre_set: + has_changed == True + + return has_changed + def set_voltage(self, channel: ChannelID, value: float): self.set_non_indexed_value(channel, value, domain=DepDomain.VOLTAGE) @@ -562,7 +624,7 @@ def set_non_indexed_value(self, channel: GeneralizedChannel, value: float, domai # there has to be no active dep when the value is not indexed? # self.active_dep.setdefault(channel,{})[DepDomain.NODEP] = key self.plain_value.setdefault(channel,{}) - self.plain_value[channel][domain] = value + self.plain_value[channel][domain] = value def _add_repetition_node(self, node: LinSpaceRepeat): pre_dep_state = self.get_dependency_state(node.dependencies()) @@ -571,14 +633,20 @@ def _add_repetition_node(self, node: LinSpaceRepeat): self.commands.append(label) self.add_node(node.body) post_dep_state = self.get_dependency_state(node.dependencies()) - if pre_dep_state != post_dep_state: + # the last index in the iterations may not be initialized in pre_dep_state if the outer loop only sets an index + # after this loop is in the sequence of the current level, + # meaning that an trailing 0 at the end of iterations of each depState in the post_dep_state + # should be ignored when comparing. + # zeros also should only mean a "Set" command, which is not harmful if executed multiple times. + # if pre_dep_state != post_dep_state: + if self.compare_ignoring_post_trailing_zeros(pre_dep_state,post_dep_state): # hackedy self.commands.pop(initial_position) self.commands.append(label) label.count -= 1 self.add_node(node.body) self.commands.append(jmp) - + def _add_iteration_node(self, node: LinSpaceIter): self.iterations.append(0) self.add_node(node.body) From f25022642775eef5f220746704d2526107754ff1 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Thu, 13 Jun 2024 16:28:50 +0200 Subject: [PATCH 27/35] dependency_key -> key for consistency --- qupulse/program/linspace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index 085008a0..435ec15d 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -469,7 +469,7 @@ class LoopLabel: class Increment: channel: Optional[GeneralizedChannel] value: Union[ResolutionDependentValue,Tuple[ResolutionDependentValue]] - dependency_key: DepKey + key: DepKey @dataclass From e84ecf85b3504febdb6617685b03d68eeb01b256 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Fri, 14 Jun 2024 00:33:35 +0200 Subject: [PATCH 28/35] always emit incr/set before wait --- qupulse/hardware/awgs/base.py | 2 +- qupulse/program/linspace.py | 20 ++++++++++++-------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/qupulse/hardware/awgs/base.py b/qupulse/hardware/awgs/base.py index b622368c..7e773e9d 100644 --- a/qupulse/hardware/awgs/base.py +++ b/qupulse/hardware/awgs/base.py @@ -283,7 +283,7 @@ def _transform_linspace_commands(self, command_list: List[Command]) -> List[Comm # play is handled by transforming the sampled waveform continue elif isinstance(command, Increment): - if command.dependency_key.domain is not DepDomain.VOLTAGE: + if command.key.domain is not DepDomain.VOLTAGE: #for sweeps of wf-scale and wf-offset, the channel amplitudes/offsets are already considered in the wf sampling. continue ch_trafo = self._channel_transformations()[command.channel] diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index 435ec15d..d7372886 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -312,7 +312,7 @@ def hold_voltage(self, duration: HardwareTime, voltages: Mapping[ChannelID, Hard bases=bases, factors=factors, duration_base=duration_base, - duration_factors=duration_factors) + duration_factors=tuple(duration_factors)) self._stack[-1].append(set_cmd) @@ -607,7 +607,7 @@ def reduced_or_none(dep_state: DepState) -> Union[DepState,None]: return has_changed def set_voltage(self, channel: ChannelID, value: float): - self.set_non_indexed_value(channel, value, domain=DepDomain.VOLTAGE) + self.set_non_indexed_value(channel, value, domain=DepDomain.VOLTAGE, always_emit_set=True) def set_wf_scale(self, channel: ChannelID, value: float): self.set_non_indexed_value(channel, value, domain=DepDomain.WF_SCALE) @@ -615,11 +615,11 @@ def set_wf_scale(self, channel: ChannelID, value: float): def set_wf_offset(self, channel: ChannelID, value: float): self.set_non_indexed_value(channel, value, domain=DepDomain.WF_OFFSET) - def set_non_indexed_value(self, channel: GeneralizedChannel, value: float, domain: DepDomain): + def set_non_indexed_value(self, channel: GeneralizedChannel, value: float, domain: DepDomain, always_emit_set: bool=False): key = DepKey((),domain) # I do not completely get why it would have to be set again if not in active dep. # if not key != self.active_dep.get(channel, None) or - if self.plain_value.get(channel, {}).get(domain, None) != value: + if self.plain_value.get(channel, {}).get(domain, None) != value or always_emit_set: self.commands.append(Set(channel, ResolutionDependentValue((),(),offset=value), key)) # there has to be no active dep when the value is not indexed? # self.active_dep.setdefault(channel,{})[DepDomain.NODEP] = key @@ -661,7 +661,7 @@ def _add_iteration_node(self, node: LinSpaceIter): def _set_indexed_voltage(self, channel: ChannelID, base: float, factors: Sequence[float]): key = DepKey.from_voltages(voltages=factors, resolution=self.resolution) - self.set_indexed_value(key, channel, base, factors, domain=DepDomain.VOLTAGE) + self.set_indexed_value(key, channel, base, factors, domain=DepDomain.VOLTAGE, always_emit_incr=True) def _set_indexed_lin_time(self, base: TimeType, factors: Sequence[TimeType]): key = DepKey.from_lin_times(times=factors, resolution=self.resolution) @@ -669,7 +669,7 @@ def _set_indexed_lin_time(self, base: TimeType, factors: Sequence[TimeType]): def set_indexed_value(self, dep_key: DepKey, channel: GeneralizedChannel, base: Union[float,TimeType], factors: Sequence[Union[float,TimeType]], - domain: DepDomain): + domain: DepDomain, always_emit_incr: bool = False): new_dep_state = DepState( base, iterations=tuple(self.iterations) @@ -685,8 +685,12 @@ def set_indexed_value(self, dep_key: DepKey, channel: GeneralizedChannel, inc = new_dep_state.required_increment_from(previous=current_dep_state, factors=factors) # we insert all inc here (also inc == 0) because it signals to activate this amplitude register - #not really sure if correct, but if dep states are the same, dont emit increment call. - if (inc or self.active_dep.get(channel, {}).get(dep_key.domain) != dep_key) and new_dep_state != current_dep_state: + # -> since this is not necessary for other domains, make it stricter and bypass if necessary for voltage. + if ((inc or self.active_dep.get(channel, {}).get(dep_key.domain) != dep_key) and new_dep_state != current_dep_state)\ + or always_emit_incr: + # if always_emit_incr and new_dep_state == current_dep_state, inc should be zero. + if always_emit_incr and new_dep_state == current_dep_state: + assert inc==0. self.commands.append(Increment(channel, inc, dep_key)) self.active_dep.setdefault(channel,{})[dep_key.domain] = dep_key self.dep_states[channel][dep_key] = new_dep_state From 73513ebd1b83aa135136f1e364424207a840a805 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Sun, 16 Jun 2024 01:34:58 +0200 Subject: [PATCH 29/35] dirty stepped play node in LinSpaceBuilder --- qupulse/expressions/simple.py | 9 +- qupulse/hardware/awgs/base.py | 16 +- qupulse/program/__init__.py | 7 +- qupulse/program/linspace.py | 371 ++++++++++++++++++++------ qupulse/program/loop.py | 3 +- qupulse/program/waveforms.py | 28 ++ qupulse/pulses/loop_pulse_template.py | 5 +- qupulse/pulses/pulse_template.py | 17 +- 8 files changed, 368 insertions(+), 88 deletions(-) diff --git a/qupulse/expressions/simple.py b/qupulse/expressions/simple.py index b3e33005..10d1e94f 100644 --- a/qupulse/expressions/simple.py +++ b/qupulse/expressions/simple.py @@ -130,4 +130,11 @@ def evaluate_in_scope_(self, *args, **kwargs): return self -_lambdify_modules.append({'SimpleExpression': SimpleExpression}) +#alibi class to allow instance check? +@dataclass +class SimpleExpressionStepped(SimpleExpression): + step_nesting_level: int + rng: range + + +_lambdify_modules.append({'SimpleExpression': SimpleExpression, 'SimpleExpressionStepped': SimpleExpressionStepped}) diff --git a/qupulse/hardware/awgs/base.py b/qupulse/hardware/awgs/base.py index 7e773e9d..794decf1 100644 --- a/qupulse/hardware/awgs/base.py +++ b/qupulse/hardware/awgs/base.py @@ -19,7 +19,7 @@ from qupulse.program.linspace import LinSpaceNode, LinSpaceArbitraryWaveform, to_increment_commands, Command, \ Increment, Set as LSPSet, LoopLabel, LoopJmp, Wait, Play, DEFAULT_INCREMENT_RESOLUTION, DepDomain from qupulse.program.loop import Loop -from qupulse.program.waveforms import Waveform +from qupulse.program.waveforms import Waveform, WaveformCollection from qupulse.comparable import Comparable from qupulse.utils.types import TimeType @@ -234,8 +234,18 @@ def __init__(self, program: AllowedProgramTypes, elif program_type is _ProgramType.Linspace: #not so clean #TODO: also marker handling not optimal - waveforms = OrderedDict((command.waveform, None) - for command in self._transformed_commands if isinstance(command,Play)).keys() + waveforms_d = OrderedDict() + for command in self._transformed_commands: + if not isinstance(command,Play): + continue + if isinstance(command.waveform,Waveform): + waveforms_d[command.waveform] = None + elif isinstance(command.waveform,WaveformCollection): + for w in command.waveform.flatten(): + waveforms_d[w] = None + else: + raise NotImplementedError() + waveforms = waveforms_d.keys() else: raise NotImplementedError() diff --git a/qupulse/program/__init__.py b/qupulse/program/__init__.py index 82b28148..41c2afe7 100644 --- a/qupulse/program/__init__.py +++ b/qupulse/program/__init__.py @@ -66,9 +66,14 @@ def new_subprogram(self, global_transformation: 'Transformation' = None) -> Cont it is not empty.""" def with_iteration(self, index_name: str, rng: range, + pt_obj: 'ForLoopPT', #hack this in for now. + # can be placed more suitably, like in pulsemetadata later on, but need some working thing now. measurements: Optional[Sequence[MeasurementWindow]] = None) -> Iterable['ProgramBuilder']: pass - + + def evaluate_nested_stepping(self, scope: Scope, parameter_names: set[str]) -> bool: + return False + def to_program(self) -> Optional[Program]: """Further addition of new elements might fail after finalizing the program.""" diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index d7372886..e4ec7e7e 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -3,15 +3,20 @@ import dataclasses import numpy as np from dataclasses import dataclass -from typing import Mapping, Optional, Sequence, ContextManager, Iterable, Tuple, Union, Dict, List, Iterator, Generic +from typing import Mapping, Optional, Sequence, ContextManager, Iterable, Tuple, Union, Dict, List, Iterator, Generic,\ + Set as TypingSet, Callable from enum import Enum -from itertools import dropwhile +from itertools import dropwhile, count +from numbers import Real, Number + from qupulse import ChannelID, MeasurementWindow from qupulse.parameter_scope import Scope, MappedScope, FrozenDict +# from qupulse.pulses.pulse_template import PulseTemplate +# from qupulse.pulses import ForLoopPT from qupulse.program import ProgramBuilder, HardwareTime, HardwareVoltage, Waveform, RepetitionCount, TimeType -from qupulse.expressions.simple import SimpleExpression, NumVal -from qupulse.program.waveforms import MultiChannelWaveform, TransformingWaveform +from qupulse.expressions.simple import SimpleExpression, NumVal, SimpleExpressionStepped +from qupulse.program.waveforms import MultiChannelWaveform, TransformingWaveform, WaveformCollection from qupulse.program.transformation import ChainedTransformation, ScalingTransformation, OffsetTransformation,\ IdentityTransformation, ParallelChannelTransformation, Transformation @@ -30,9 +35,50 @@ class DepDomain(Enum): FREQUENCY = -3 WF_SCALE = -4 WF_OFFSET = -5 + STEP_INDEX = -6 NODEP = None -GeneralizedChannel = Union[DepDomain,ChannelID] + +class InstanceCounterMeta(type): + def __init__(cls, name, bases, dct): + super().__init__(name, bases, dct) + cls._instance_tracker = {} + + def __call__(cls, *args, **kwargs): + normalized_args = cls._normalize_args(*args, **kwargs) + # Create a key based on the arguments + key = tuple(sorted(normalized_args.items())) + cls._instance_tracker.setdefault(key,count(start=0)) + instance = super().__call__(*args, **kwargs) + instance._channel_num = next(cls._instance_tracker[key]) + return instance + + def _normalize_args(cls, *args, **kwargs): + # Get the parameter names from the __init__ method + param_names = cls.__init__.__code__.co_varnames[1:cls.__init__.__code__.co_argcount] + # Create a dictionary with default values + normalized_args = dict(zip(param_names, args)) + # Update with any kwargs + normalized_args.update(kwargs) + return normalized_args + +@dataclass +class StepRegister(metaclass=InstanceCounterMeta): + #set this as name of sweepval var + register_name: str + register_nesting: int + #should be increased by metaclass every time the class is instantiated with the same arguments + _channel_num: int = dataclasses.field(default_factory=lambda: None) + + @property + def reg_var_name(self): + return self.register_name+'_'+str(self.register_num)+'_'+str(self._channel_num) + + def __hash__(self): + return hash((self.register_name,self.register_nesting,self._channel_num)) + + +GeneralizedChannel = Union[DepDomain,ChannelID,StepRegister] class ResolutionDependentValue(Generic[NumVal]): @@ -146,16 +192,17 @@ class LinSpaceHold(LinSpaceNodeChannelSpecific): duration_base: TimeType duration_factors: Optional[Tuple[TimeType, ...]] - + def dependencies(self) -> Mapping[DepDomain, Mapping[ChannelID, set]]: return {dom: {ch: {factors}} for dom, ch_to_factors in self._dep_by_domain().items() for ch, factors in ch_to_factors.items() if factors} - def _dep_by_domain(self) -> Mapping[DepDomain, set]: + def _dep_by_domain(self) -> Mapping[DepDomain, Mapping[GeneralizedChannel, set]]: return {DepDomain.VOLTAGE: self.factors, - DepDomain.TIME_LIN: {DepDomain.TIME_LIN:self.duration_factors},} + DepDomain.TIME_LIN: {DepDomain.TIME_LIN:self.duration_factors}, + } @dataclass @@ -170,7 +217,7 @@ def dependencies(self): @dataclass class LinSpaceArbitraryWaveformIndexed(LinSpaceNodeChannelSpecific): """This is just a wrapper to pipe arbitrary waveforms through the system.""" - waveform: Waveform + waveform: Union[Waveform,WaveformCollection] scale_bases: Dict[ChannelID, float] scale_factors: Dict[ChannelID, Optional[Tuple[float, ...]]] @@ -178,16 +225,27 @@ class LinSpaceArbitraryWaveformIndexed(LinSpaceNodeChannelSpecific): offset_bases: Dict[ChannelID, float] offset_factors: Dict[ChannelID, Optional[Tuple[float, ...]]] - def dependencies(self) -> Mapping[DepDomain, Mapping[ChannelID, set]]: + index_factors: Optional[Dict[StepRegister,Tuple[int, ...]]] = dataclasses.field(default_factory=lambda: None) + + def __post_init__(self): + #somewhat assert the integrity in this case. + if isinstance(self.waveform,WaveformCollection): + assert self.index_factors is not None + + def dependencies(self) -> Mapping[DepDomain, Mapping[GeneralizedChannel, set]]: return {dom: {ch: {factors}} for dom, ch_to_factors in self._dep_by_domain().items() for ch, factors in ch_to_factors.items() if factors} - def _dep_by_domain(self) -> Mapping[DepDomain, set]: + def _dep_by_domain(self) -> Mapping[DepDomain, Mapping[GeneralizedChannel, set]]: return {DepDomain.WF_SCALE: self.scale_factors, - DepDomain.WF_OFFSET: self.offset_factors,} + DepDomain.WF_OFFSET: self.offset_factors, + DepDomain.STEP_INDEX: self.index_factors} + @property + def step_channels(self) -> Optional[Tuple[StepRegister]]: + return tuple(self.index_factors.keys()) if self.index_factors else None @dataclass class LinSpaceRepeat(LinSpaceNode): @@ -211,13 +269,15 @@ class LinSpaceIter(LinSpaceNode): Offsets and spacing are stored in the hold node.""" body: Tuple[LinSpaceNode, ...] length: int - + + to_be_stepped: bool + def dependencies(self): dependencies = {} for node in self.body: for dom, ch_to_deps in node.dependencies().items(): for ch, deps in ch_to_deps.items(): - # remove the last elemt in index because this iteration sets it -> no external dependency + # remove the last element in index because this iteration sets it -> no external dependency shortened = {dep[:-1] for dep in deps} if shortened != {()}: dependencies.setdefault(dom,{}).setdefault(ch, set()).update(shortened) @@ -235,6 +295,8 @@ class LinSpaceBuilder(ProgramBuilder): def __init__(self, # channels: Tuple[ChannelID, ...] + to_stepping_repeat: TypingSet[Union[str,'ForLoopPT']] = set() + # identifier, loop_index or ForLoopPT which is to be stepped. ): super().__init__() # self._name_to_idx = {name: idx for idx, name in enumerate(channels)} @@ -242,6 +304,7 @@ def __init__(self, self._stack = [[]] self._ranges = [] + self._to_stepping_repeat = to_stepping_repeat def _root(self): return self._stack[0] @@ -249,11 +312,17 @@ def _root(self): def _get_rng(self, idx_name: str) -> range: return self._get_ranges()[idx_name] - def inner_scope(self, scope: Scope) -> Scope: + def inner_scope(self, scope: Scope, pt_obj: 'ForLoopPT') -> Scope: """This function is necessary to inject program builder specific parameter implementations into the build process.""" if self._ranges: - name, _ = self._ranges[-1] + name, rng = self._ranges[-1] + if pt_obj in self._to_stepping_repeat or pt_obj.identifier in self._to_stepping_repeat \ + or pt_obj.loop_index in self._to_stepping_repeat: + # the nesting level should be simply the amount of this type in the scope. + nest = len(tuple(v for v in scope.values() if isinstance(v,SimpleExpressionStepped))) + return scope.overwrite({name:SimpleExpressionStepped( + base=0,offsets={name: 1},step_nesting_level=nest+1,rng=rng)}) return scope.overwrite({name: SimpleExpression(base=0, offsets={name: 1})}) else: return scope @@ -312,60 +381,126 @@ def hold_voltage(self, duration: HardwareTime, voltages: Mapping[ChannelID, Hard bases=bases, factors=factors, duration_base=duration_base, - duration_factors=tuple(duration_factors)) + duration_factors=tuple(duration_factors), + ) self._stack[-1].append(set_cmd) - def play_arbitrary_waveform(self, waveform: Waveform): + def play_arbitrary_waveform(self, waveform: Union[Waveform,WaveformCollection], + stepped_var_list: Optional[List[Tuple[str,SimpleExpressionStepped]]] = None): - #recognize voltage trafo sweep syntax from a transforming waveform. other sweepable things may need different approaches. - if not isinstance(waveform,TransformingWaveform): + # recognize voltage trafo sweep syntax from a transforming waveform. + # other sweepable things may need different approaches. + if not isinstance(waveform,(TransformingWaveform,WaveformCollection)): + assert stepped_var_list is None return self._stack[-1].append(LinSpaceArbitraryWaveform(waveform=waveform,channels=waveform.defined_channels,)) - #test for transformations that contain SimpleExpression - wf_transformation = waveform.transformation + + #should be sufficient to test the first wf, as all should have the same trafo + waveform_propertyextractor = waveform + while isinstance(waveform_propertyextractor,WaveformCollection): + waveform_propertyextractor = waveform.waveform_collection[0] - # chainedTransformation should now have flat hierachy. - collected_trafos, dependent_vals_flag = collect_scaling_and_offset_per_channel(waveform.defined_channels,wf_transformation) + if isinstance(waveform_propertyextractor,TransformingWaveform): + #test for transformations that contain SimpleExpression + wf_transformation = waveform_propertyextractor.transformation + + # chainedTransformation should now have flat hierachy. + collected_trafos, dependent_trafo_vals_flag = collect_scaling_and_offset_per_channel( + waveform_propertyextractor.defined_channels,wf_transformation) + else: + dependent_trafo_vals_flag = False #fast track - if not dependent_vals_flag: + if not dependent_trafo_vals_flag and not isinstance(waveform,WaveformCollection): return self._stack[-1].append(LinSpaceArbitraryWaveform(waveform=waveform,channels=waveform.defined_channels,)) - + ranges = self._get_ranges() - scale_factors, offset_factors = {}, {} - scale_bases, offset_bases = {}, {} + ranges_list = list(ranges) + index_factors = {} - for ch_name,scale_offset_dict in collected_trafos.items(): - for bases,factors,key in zip((scale_bases, offset_bases),(scale_factors, offset_factors),('s','o')): - value = scale_offset_dict[key] - if isinstance(value, float): - bases[ch_name] = value - factors[ch_name] = None - continue + if stepped_var_list: + # the index ordering shall be with the last index changing fastest. + # (assuming the WaveformColleciton will be flattened) + # this means increments on last shall be 1, next lower 1*len(fastest), + # next 1*len(fastest)*len(second_fastest),... -> product(higher_reg_range_lens) + # total_reg_len = len(stepped_var_list) + reg_lens = tuple(len(v.rng) for s,v in stepped_var_list) + total_rng_len = np.cumprod(reg_lens)[-1] + reg_incr_values = list(np.cumprod(reg_lens[::-1]))[::-1][1:] + [1,] + + assert isinstance(waveform,WaveformCollection) + + for reg_num,(var_name,value) in enumerate(stepped_var_list): + # this should be given anyway: + assert isinstance(value, SimpleExpressionStepped) + + """ + # by definition, every var_name should be relevant for the waveform/ + # has been included in the nested WaveformCollection. + # so, each time this code is called, a new waveform node containing this is called, + # and one can/must increase the offset by the + + # assert value.base += total_rng_len + """ + + assert value.base == 0 + offsets = value.offsets - base = value.base - incs = [] - for rng_name, rng in ranges.items(): - start = 0. - step = 0. - offset = offsets.get(rng_name, None) - if offset: - start += rng.start * offset - step += rng.step * offset - base += start - incs.append(step) - factors[ch_name] = tuple(incs) - bases[ch_name] = base + #there can never be more than one key in this + # (nowhere is an evaluation of arithmetics betwen steppings intended) + assert len(offsets)==1 + assert all(v==1 for v in offsets.values()) + assert set(offsets.keys())=={var_name,} + + # this makes the search through ranges pointless; have tuple of zeros + # except for one inc at the position of the stepvar in the ranges dict + + incs = [0 for v in ranges_list] + incs[ranges_list.index(var_name)] = reg_incr_values[reg_num] + + #needs to be new "channel" each time? should be handled by metaclass + reg_channel = StepRegister(var_name,reg_num) + index_factors[reg_channel] = tuple(incs) + # bases[reg_channel] = value.base + scale_factors, offset_factors = {}, {} + scale_bases, offset_bases = {}, {} - return self._stack[-1].append(LinSpaceArbitraryWaveformIndexed(waveform=waveform, - channels=waveform.defined_channels, - scale_bases=scale_bases, - scale_factors=scale_factors, - offset_bases=offset_bases, - offset_factors=offset_factors, - )) + if dependent_trafo_vals_flag: + for ch_name,scale_offset_dict in collected_trafos.items(): + for bases,factors,key in zip((scale_bases, offset_bases),(scale_factors, offset_factors),('s','o')): + value = scale_offset_dict[key] + if isinstance(value, float): + bases[ch_name] = value + factors[ch_name] = None + continue + offsets = value.offsets + base = value.base + incs = [] + for rng_name, rng in ranges.items(): + start = 0. + step = 0. + offset = offsets.get(rng_name, None) + if offset: + start += rng.start * offset + step += rng.step * offset + base += start + incs.append(step) + factors[ch_name] = tuple(incs) + bases[ch_name] = base + + # assert ba + + return self._stack[-1].append(LinSpaceArbitraryWaveformIndexed( + waveform=waveform, + channels=waveform_propertyextractor.defined_channels.union(set(index_factors.keys())), + scale_bases=scale_bases, + scale_factors=scale_factors, + offset_bases=offset_bases, + offset_factors=offset_factors, + index_factors=index_factors, + )) def measure(self, measurements: Optional[Sequence[MeasurementWindow]]): """Ignores measurements""" @@ -390,6 +525,7 @@ def new_subprogram(self, global_transformation: 'Transformation' = None) -> Cont raise NotImplementedError('Not implemented yet (postponed)') def with_iteration(self, index_name: str, rng: range, + pt_obj: 'ForLoopPT', measurements: Optional[Sequence[MeasurementWindow]] = None) -> Iterable['ProgramBuilder']: if len(rng) == 0: return @@ -399,8 +535,71 @@ def with_iteration(self, index_name: str, rng: range, cmds = self._stack.pop() self._ranges.pop() if cmds: - self._stack[-1].append(LinSpaceIter(body=tuple(cmds), length=len(rng))) + stepped = False + if pt_obj in self._to_stepping_repeat or pt_obj.identifier in self._to_stepping_repeat \ + or pt_obj.loop_index in self._to_stepping_repeat: + stepped = True + self._stack[-1].append(LinSpaceIter(body=tuple(cmds), length=len(rng), to_be_stepped=stepped)) + + + def evaluate_nested_stepping(self, scope: Scope, parameter_names: set[str]) -> bool: + + stepped_vals = {k:v for k,v in scope.items() if isinstance(v,SimpleExpressionStepped)} + #when overlap, then the PT is part of the stepped progression + if stepped_vals.keys() & parameter_names: + return True + return False + + def dispatch_to_stepped_wf_or_hold(self, + build_func: Callable[[Mapping[str, Real],Dict[ChannelID, Optional[ChannelID]]],Optional[Waveform]], + build_parameters: Scope, + parameter_names: set[str], + channel_mapping: Dict[ChannelID, Optional[ChannelID]], + #measurements tbd + global_transformation: Optional["Transformation"]) -> None: + + stepped_vals = {k:v for k,v in build_parameters.items() + if isinstance(v,SimpleExpressionStepped) and k in parameter_names} + sorted_steps = list(sorted(stepped_vals.items(), key=lambda item: item[1].step_nesting_level)) + def build_nested_wf_colls(remaining_ranges: List[Tuple], fixed_elements: List[Tuple] = []): + + if len(remaining_ranges) == 0: + inner_scope = build_parameters.overwrite(dict(fixed_elements)) + #by now, no SimpleExpressionStepped should remain here. + assert not any(isinstance(v,SimpleExpressionStepped) for v in inner_scope.values()) + waveform = build_func(inner_scope,channel_mapping=channel_mapping) + if global_transformation: + waveform = TransformingWaveform.from_transformation(waveform, global_transformation) + #this case should not happen, should have been caught beforehand: + # or maybe not, if e.g. amp is zero for some reason + # assert waveform.constant_value_dict() is None + return waveform + else: + return WaveformCollection( + tuple(build_nested_wf_colls(remaining_ranges[1:], + fixed_elements+[(remaining_ranges[0][0],remaining_ranges[0][1].value({remaining_ranges[0][0]:it})),]) + for it in remaining_ranges[0][1].rng)) + + # not completely convinced this works as intended. + # doesn't this - also in pulse_template program creation - lead to complications with ParallelConstantChannelTrafo? + # dirty, quick workaround - if this doesnt work, assume it is also not constant: + try: + potential_waveform = build_func(build_parameters,channel_mapping=channel_mapping) + if global_transformation: + potential_waveform = TransformingWaveform.from_transformation(potential_waveform, global_transformation) + constant_values = potential_waveform.constant_value_dict() + except: + constant_values = None + + if constant_values is None: + wf_coll = build_nested_wf_colls(sorted_steps) + self.play_arbitrary_waveform(wf_coll,sorted_steps) + else: + # in the other case, all dependencies _should_ be on amp and length, which is covered by hold appropriately + # and doesn't need to be stepped? + self.hold_voltage(potential_waveform.duration, constant_values) + def to_program(self) -> Optional[Sequence[LinSpaceNode]]: if self._root(): return self._root() @@ -492,12 +691,14 @@ class LoopJmp: @dataclass class Play: - waveform: Waveform - channels: Tuple[ChannelID] - keys_by_ch_by_domain: Dict[DepDomain,Dict[ChannelID,DepKey]] = None + waveform: Union[Waveform,WaveformCollection] + play_channels: Tuple[ChannelID] + step_channels: Optional[Tuple[StepRegister]] = None + #actually did the name + keys_by_domain_by_ch: Dict[ChannelID,Dict[DepDomain,DepKey]] = None def __post_init__(self): - if self.keys_by_ch_by_domain is None: - self.keys_by_ch_by_domain = {ch: {} for ch in self.channels} + if self.keys_by_domain_by_ch is None: + self.keys_by_domain_by_ch = {ch: {} for ch in self.channels} Command = Union[Increment, Set, LoopLabel, LoopJmp, Wait, Play] @@ -615,7 +816,8 @@ def set_wf_scale(self, channel: ChannelID, value: float): def set_wf_offset(self, channel: ChannelID, value: float): self.set_non_indexed_value(channel, value, domain=DepDomain.WF_OFFSET) - def set_non_indexed_value(self, channel: GeneralizedChannel, value: float, domain: DepDomain, always_emit_set: bool=False): + def set_non_indexed_value(self, channel: GeneralizedChannel, value: float, + domain: DepDomain, always_emit_set: bool=False): key = DepKey((),domain) # I do not completely get why it would have to be set again if not in active dep. # if not key != self.active_dep.get(channel, None) or @@ -648,6 +850,7 @@ def _add_repetition_node(self, node: LinSpaceRepeat): self.commands.append(jmp) def _add_iteration_node(self, node: LinSpaceIter): + self.iterations.append(0) self.add_node(node.body) @@ -686,7 +889,8 @@ def set_indexed_value(self, dep_key: DepKey, channel: GeneralizedChannel, # we insert all inc here (also inc == 0) because it signals to activate this amplitude register # -> since this is not necessary for other domains, make it stricter and bypass if necessary for voltage. - if ((inc or self.active_dep.get(channel, {}).get(dep_key.domain) != dep_key) and new_dep_state != current_dep_state)\ + if ((inc or self.active_dep.get(channel, {}).get(dep_key.domain) != dep_key) + and new_dep_state != current_dep_state)\ or always_emit_incr: # if always_emit_incr and new_dep_state == current_dep_state, inc should be zero. if always_emit_incr and new_dep_state == current_dep_state: @@ -713,22 +917,31 @@ def _add_hold_node(self, node: LinSpaceHold): def _add_indexed_play_node(self, node: LinSpaceArbitraryWaveformIndexed): - for ch in node.channels: - for base,factors,domain in zip((node.scale_bases[ch], node.offset_bases[ch]), - (node.scale_factors[ch], node.offset_factors[ch]), - (DepDomain.WF_SCALE,DepDomain.WF_OFFSET)): - if factors is None: - continue - # assume here that the waveform will have the correct settings the TransformingWaveform, - # where no SimpleExpression is replaced now. - # will yield the correct trafo already without having to make adjustments - # self.set_non_indexed_value(ch, base, domain) - else: - key = DepKey.from_domain(factors, resolution=self.resolution, domain=domain) - self.set_indexed_value(key, ch, base, factors, key.domain) - - self.commands.append(Play(node.waveform, node.channels, - keys_by_ch_by_domain={c: self.active_dep.get(c,{}) for c in node.channels})) + #assume this as criterion: + if len(node.scale_bases) and len(node.offset_bases): + for ch in node.play_channels: + for base,factors,domain in zip((node.scale_bases[ch], node.offset_bases[ch]), + (node.scale_factors[ch], node.offset_factors[ch]), + (DepDomain.WF_SCALE,DepDomain.WF_OFFSET)): + if factors is None: + continue + # assume here that the waveform will have the correct settings the TransformingWaveform, + # where no SimpleExpression is replaced now. + # will yield the correct trafo already without having to make adjustments + # self.set_non_indexed_value(ch, base, domain) + else: + key = DepKey.from_domain(factors, resolution=self.resolution, domain=domain) + self.set_indexed_value(key, ch, base, factors, key.domain) + + for st_ch, st_factors in node.index_factors.items(): + #this should not happen: + assert st_factors is not None + key = DepKey.from_domain(st_factors, resolution=self.resolution, domain=DepDomain.STEP_INDEX) + self.set_indexed_value(key, st_ch, 0, st_factors, key.domain) + + + self.commands.append(Play(node.waveform, node.channels, step_channels=node.step_channels, + keys_by_domain_by_ch={c: self.active_dep.get(c,{}) for c in node.channels})) def add_node(self, node: Union[LinSpaceNode, Sequence[LinSpaceNode]]): diff --git a/qupulse/program/loop.py b/qupulse/program/loop.py index 0f035653..9e8af1f3 100644 --- a/qupulse/program/loop.py +++ b/qupulse/program/loop.py @@ -771,7 +771,7 @@ def __init__(self): self._stack: List[StackFrame] = [StackFrame(self._root, None)] - def inner_scope(self, scope: Scope) -> Scope: + def inner_scope(self, scope: Scope, pt_obj: 'ForLoopPT') -> Scope: local_vars = self._stack[-1].iterating if local_vars is None: return scope @@ -806,6 +806,7 @@ def with_repetition(self, repetition_count: RepetitionCount, self._try_append(repetition_loop, measurements) def with_iteration(self, index_name: str, rng: range, + pt_obj: 'ForLoopPT', measurements: Optional[Sequence[MeasurementWindow]] = None) -> Iterable['ProgramBuilder']: with self.with_sequence(measurements): top_frame = self._stack[-1] diff --git a/qupulse/program/waveforms.py b/qupulse/program/waveforms.py index 0d044953..2ba64a6f 100644 --- a/qupulse/program/waveforms.py +++ b/qupulse/program/waveforms.py @@ -1291,3 +1291,31 @@ def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Any: def reversed(self) -> 'Waveform': return self._inner + + + +class WaveformCollection(): + + def __init__(self, waveform_collection: Tuple[Union[Waveform,"WaveformCollection"]]): + + self._waveform_collection = tuple(waveform_collection) + + @property + def waveform_collection(self): + return self._waveform_collection + + @property + def nesting_level(self): + #assume it is balanced for now. + if isinstance(self.waveform_collection[0],type(self)): + return self.waveform_collection[0].nesting_level+1 + return 0 + + def flatten(self) -> Tuple[Waveform]: + def flatten_tuple(nested_tuple): + for item in nested_tuple: + if isinstance(item, tuple): + yield from flatten_tuple(item) + else: + yield item + return flatten_tuple(self.waveform_collection) diff --git a/qupulse/pulses/loop_pulse_template.py b/qupulse/pulses/loop_pulse_template.py index 0f458c68..e439840f 100644 --- a/qupulse/pulses/loop_pulse_template.py +++ b/qupulse/pulses/loop_pulse_template.py @@ -159,8 +159,9 @@ def _internal_create_program(self, *, measurements = self.get_measurement_windows(scope, measurement_mapping) for iteration_program_builder in program_builder.with_iteration(loop_index_name, loop_range, - measurements=measurements): - self.body._create_program(scope=iteration_program_builder.inner_scope(scope), + measurements=measurements, + pt_obj=self): + self.body._create_program(scope=iteration_program_builder.inner_scope(scope,pt_obj=self), measurement_mapping=measurement_mapping, channel_mapping=channel_mapping, global_transformation=global_transformation, diff --git a/qupulse/pulses/pulse_template.py b/qupulse/pulses/pulse_template.py index 97dd5cda..53075f5a 100644 --- a/qupulse/pulses/pulse_template.py +++ b/qupulse/pulses/pulse_template.py @@ -28,6 +28,9 @@ from qupulse.parameter_scope import Scope, DictScope from qupulse.program import ProgramBuilder, default_program_builder, Program +from qupulse.program.linspace import LinSpaceBuilder + +from qupulse.expressions.simple import SimpleExpressionStepped __all__ = ["PulseTemplate", "AtomicPulseTemplate", "DoubleParameterNameException", "MappingTuple", "UnknownVolatileParameter"] @@ -513,7 +516,19 @@ def _internal_create_program(self, *, ### current behavior (same as previously): only adds EXEC Loop and measurements if a waveform exists. ### measurements are directly added to parent_loop (to reflect behavior of Sequencer + MultiChannelProgram) assert not scope.get_volatile_parameters().keys() & self.parameter_names, "AtomicPT cannot be volatile" - + + + # "hackedy": + if program_builder.evaluate_nested_stepping(scope,self.parameter_names): + program_builder.dispatch_to_stepped_wf_or_hold(build_func=self.build_waveform, + build_parameters=scope, + parameter_names=self.parameter_names, + channel_mapping=channel_mapping, + #measurements + global_transformation=global_transformation + ) + return + waveform = self.build_waveform(parameters=scope, channel_mapping=channel_mapping) if waveform: From 468a46c72718bc0b69233add3a3307a7895ffd52 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Sun, 16 Jun 2024 13:23:21 +0200 Subject: [PATCH 30/35] further bugfixes --- qupulse/program/linspace.py | 18 ++++++++++-------- qupulse/program/waveforms.py | 6 +++--- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index e4ec7e7e..608310d1 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -80,6 +80,8 @@ def __hash__(self): GeneralizedChannel = Union[DepDomain,ChannelID,StepRegister] +# is there any way to cast the numpy cumprod to int? +int_type = Union[np.int32,int] class ResolutionDependentValue(Generic[NumVal]): @@ -91,7 +93,7 @@ def __init__(self, self.bases = bases self.multiplicities = multiplicities self.offset = offset - self.__is_time_or_int = all(isinstance(b,(TimeType,int)) for b in bases) and isinstance(offset,(TimeType,int)) + self.__is_time_or_int = all(isinstance(b,(TimeType,int_type)) for b in bases) and isinstance(offset,(TimeType,int_type)) #this is not to circumvent float errors in python, but rounding errors from awg-increment commands. #python float are thereby accurate enough if no awg with a 500 bit resolution is invented. @@ -245,7 +247,7 @@ def _dep_by_domain(self) -> Mapping[DepDomain, Mapping[GeneralizedChannel, set]] @property def step_channels(self) -> Optional[Tuple[StepRegister]]: - return tuple(self.index_factors.keys()) if self.index_factors else None + return tuple(self.index_factors.keys()) if self.index_factors else () @dataclass class LinSpaceRepeat(LinSpaceNode): @@ -399,7 +401,7 @@ def play_arbitrary_waveform(self, waveform: Union[Waveform,WaveformCollection], #should be sufficient to test the first wf, as all should have the same trafo waveform_propertyextractor = waveform while isinstance(waveform_propertyextractor,WaveformCollection): - waveform_propertyextractor = waveform.waveform_collection[0] + waveform_propertyextractor = waveform_propertyextractor.waveform_collection[0] if isinstance(waveform_propertyextractor,TransformingWaveform): #test for transformations that contain SimpleExpression @@ -566,8 +568,8 @@ def build_nested_wf_colls(remaining_ranges: List[Tuple], fixed_elements: List[Tu if len(remaining_ranges) == 0: inner_scope = build_parameters.overwrite(dict(fixed_elements)) - #by now, no SimpleExpressionStepped should remain here. - assert not any(isinstance(v,SimpleExpressionStepped) for v in inner_scope.values()) + #by now, no SimpleExpressionStepped should remain here that is relevant for the current loop. + assert not any(isinstance(v,SimpleExpressionStepped) for k,v in inner_scope.items() if k in parameter_names) waveform = build_func(inner_scope,channel_mapping=channel_mapping) if global_transformation: waveform = TransformingWaveform.from_transformation(waveform, global_transformation) @@ -693,12 +695,12 @@ class LoopJmp: class Play: waveform: Union[Waveform,WaveformCollection] play_channels: Tuple[ChannelID] - step_channels: Optional[Tuple[StepRegister]] = None + step_channels: Tuple[StepRegister] = () #actually did the name keys_by_domain_by_ch: Dict[ChannelID,Dict[DepDomain,DepKey]] = None def __post_init__(self): if self.keys_by_domain_by_ch is None: - self.keys_by_domain_by_ch = {ch: {} for ch in self.channels} + self.keys_by_domain_by_ch = {ch: {} for ch in self.play_channels+self.step_channels} Command = Union[Increment, Set, LoopLabel, LoopJmp, Wait, Play] @@ -963,7 +965,7 @@ def add_node(self, node: Union[LinSpaceNode, Sequence[LinSpaceNode]]): self._add_indexed_play_node(node) elif isinstance(node, LinSpaceArbitraryWaveform): - self.commands.append(Play(node.waveform, node.channels)) + self.commands.append(Play(node.waveform, node.play_channels)) else: raise TypeError("The node type is not handled", type(node), node) diff --git a/qupulse/program/waveforms.py b/qupulse/program/waveforms.py index 2ba64a6f..01b4cfc6 100644 --- a/qupulse/program/waveforms.py +++ b/qupulse/program/waveforms.py @@ -1314,8 +1314,8 @@ def nesting_level(self): def flatten(self) -> Tuple[Waveform]: def flatten_tuple(nested_tuple): for item in nested_tuple: - if isinstance(item, tuple): - yield from flatten_tuple(item) + if isinstance(item, type(self)): + yield from flatten_tuple(item.waveform_collection) else: yield item - return flatten_tuple(self.waveform_collection) + return tuple(flatten_tuple(self.waveform_collection)) From a7eacf4d7caf4487c2f13466f6713c1bc2c9c1ca Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Sun, 16 Jun 2024 23:56:00 +0200 Subject: [PATCH 31/35] further bug patching --- qupulse/program/linspace.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index 608310d1..ee7d3bd0 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -314,18 +314,22 @@ def _root(self): def _get_rng(self, idx_name: str) -> range: return self._get_ranges()[idx_name] - def inner_scope(self, scope: Scope, pt_obj: 'ForLoopPT') -> Scope: + def inner_scope(self, scope: Scope, pt_obj: Optional['ForLoopPT']=None) -> Scope: """This function is necessary to inject program builder specific parameter implementations into the build process.""" if self._ranges: name, rng = self._ranges[-1] - if pt_obj in self._to_stepping_repeat or pt_obj.identifier in self._to_stepping_repeat \ - or pt_obj.loop_index in self._to_stepping_repeat: + if pt_obj and (pt_obj in self._to_stepping_repeat or pt_obj.identifier in self._to_stepping_repeat \ + or pt_obj.loop_index in self._to_stepping_repeat): # the nesting level should be simply the amount of this type in the scope. nest = len(tuple(v for v in scope.values() if isinstance(v,SimpleExpressionStepped))) return scope.overwrite({name:SimpleExpressionStepped( base=0,offsets={name: 1},step_nesting_level=nest+1,rng=rng)}) - return scope.overwrite({name: SimpleExpression(base=0, offsets={name: 1})}) + else: + if isinstance(scope.get(name,None),SimpleExpressionStepped): + return scope + else: + return scope.overwrite({name: SimpleExpression(base=0, offsets={name: 1})}) else: return scope @@ -383,7 +387,7 @@ def hold_voltage(self, duration: HardwareTime, voltages: Mapping[ChannelID, Hard bases=bases, factors=factors, duration_base=duration_base, - duration_factors=tuple(duration_factors), + duration_factors=tuple(duration_factors) if duration_factors else None, ) self._stack[-1].append(set_cmd) From 550a31cbeaa169ef7fef74be9d40abbceb82c4a9 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Mon, 17 Jun 2024 14:26:04 +0200 Subject: [PATCH 32/35] hash Commands --- qupulse/program/linspace.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index ee7d3bd0..923cae8b 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -86,8 +86,8 @@ def __hash__(self): class ResolutionDependentValue(Generic[NumVal]): def __init__(self, - bases: Sequence[NumVal], - multiplicities: Sequence[int], + bases: Tuple[NumVal], + multiplicities: Tuple[int], offset: NumVal): self.bases = bases @@ -675,6 +675,9 @@ class Increment: channel: Optional[GeneralizedChannel] value: Union[ResolutionDependentValue,Tuple[ResolutionDependentValue]] key: DepKey + + def __hash__(self): + return hash((self.channel,self.value,self.key)) @dataclass @@ -683,12 +686,16 @@ class Set: value: Union[ResolutionDependentValue,Tuple[ResolutionDependentValue]] key: DepKey = dataclasses.field(default_factory=lambda: DepKey((),DepDomain.NODEP)) + def __hash__(self): + return hash((self.channel,self.value,self.key)) @dataclass class Wait: duration: Optional[TimeType] key_by_domain: Dict[DepDomain,DepKey] = dataclasses.field(default_factory=lambda: {}) + def __hash__(self): + return hash((self.duration,frozenset(self.key_by_domain.items()))) @dataclass class LoopJmp: @@ -702,10 +709,15 @@ class Play: step_channels: Tuple[StepRegister] = () #actually did the name keys_by_domain_by_ch: Dict[ChannelID,Dict[DepDomain,DepKey]] = None + def __post_init__(self): if self.keys_by_domain_by_ch is None: self.keys_by_domain_by_ch = {ch: {} for ch in self.play_channels+self.step_channels} + def __hash__(self): + return hash((self.waveform,self.play_channels,self.step_channels, + frozenset((k,frozenset(d.items())) for k,d in self.keys_by_domain_by_ch.items()))) + Command = Union[Increment, Set, LoopLabel, LoopJmp, Wait, Play] From fc8aee4d8443ee4c616098a0752dd78655bd21e8 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Fri, 5 Jul 2024 23:39:41 +0200 Subject: [PATCH 33/35] only modify commands that affect the current awg --- qupulse/hardware/awgs/base.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/qupulse/hardware/awgs/base.py b/qupulse/hardware/awgs/base.py index 794decf1..a4418b64 100644 --- a/qupulse/hardware/awgs/base.py +++ b/qupulse/hardware/awgs/base.py @@ -293,15 +293,18 @@ def _transform_linspace_commands(self, command_list: List[Command]) -> List[Comm # play is handled by transforming the sampled waveform continue elif isinstance(command, Increment): - if command.key.domain is not DepDomain.VOLTAGE: + if command.key.domain is not DepDomain.VOLTAGE or \ + command.channel not in self._channels: #for sweeps of wf-scale and wf-offset, the channel amplitudes/offsets are already considered in the wf sampling. continue + ch_trafo = self._channel_transformations()[command.channel] if ch_trafo.voltage_transformation: raise RuntimeError("Cannot apply a voltage transformation to a linspace increment command") command.value /= ch_trafo.amplitude elif isinstance(command, LSPSet): - if command.key.domain is not DepDomain.VOLTAGE: + if command.key.domain is not DepDomain.VOLTAGE or \ + command.channel not in self._channels: #for sweeps of wf-scale and wf-offset, the channel amplitudes/offsets are already considered in the wf sampling. continue ch_trafo = self._channel_transformations()[command.channel] From b3176f0abfad0f2e23f5fa030acd5a442fbad275 Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Wed, 10 Jul 2024 09:53:26 +0200 Subject: [PATCH 34/35] re-commit P.S.' initial changes --- qupulse/expressions/simple.py | 2 +- qupulse/pulses/repetition_pulse_template.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/qupulse/expressions/simple.py b/qupulse/expressions/simple.py index 10d1e94f..367c86ad 100644 --- a/qupulse/expressions/simple.py +++ b/qupulse/expressions/simple.py @@ -93,7 +93,7 @@ def __sub__(self, other): return self.__add__(-other) def __rsub__(self, other): - (-self).__add__(other) + return (-self).__add__(other) def __neg__(self): return SimpleExpression(-self.base, {name: -value for name, value in self.offsets.items()}) diff --git a/qupulse/pulses/repetition_pulse_template.py b/qupulse/pulses/repetition_pulse_template.py index ead19c6d..9657051a 100644 --- a/qupulse/pulses/repetition_pulse_template.py +++ b/qupulse/pulses/repetition_pulse_template.py @@ -135,7 +135,7 @@ def _internal_create_program(self, *, for repetition_program_builder in program_builder.with_repetition(repetition_definition, measurements=measurements): - self.body._create_program(scope=repetition_program_builder.inner_scope(scope), + self.body._create_program(scope=repetition_program_builder.inner_scope(scope, pt_obj=self), measurement_mapping=measurement_mapping, channel_mapping=channel_mapping, global_transformation=global_transformation, From def9369d3612727fc430e84ce04c1ca63527ca2b Mon Sep 17 00:00:00 2001 From: Nomos11 <82180697+Nomos11@users.noreply.github.com> Date: Wed, 10 Jul 2024 10:05:03 +0200 Subject: [PATCH 35/35] re-commit P.S.' bugfixes --- qupulse/hardware/awgs/base.py | 5 +++-- qupulse/hardware/setup.py | 4 +++- qupulse/program/linspace.py | 7 +++++-- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/qupulse/hardware/awgs/base.py b/qupulse/hardware/awgs/base.py index a4418b64..8cd5a8ad 100644 --- a/qupulse/hardware/awgs/base.py +++ b/qupulse/hardware/awgs/base.py @@ -300,7 +300,8 @@ def _transform_linspace_commands(self, command_list: List[Command]) -> List[Comm ch_trafo = self._channel_transformations()[command.channel] if ch_trafo.voltage_transformation: - raise RuntimeError("Cannot apply a voltage transformation to a linspace increment command") + if ch_trafo.voltage_transformation(1.0) != 1.0: + raise RuntimeError("Cannot apply a voltage transformation to a linspace increment command") command.value /= ch_trafo.amplitude elif isinstance(command, LSPSet): if command.key.domain is not DepDomain.VOLTAGE or \ @@ -310,7 +311,7 @@ def _transform_linspace_commands(self, command_list: List[Command]) -> List[Comm ch_trafo = self._channel_transformations()[command.channel] if ch_trafo.voltage_transformation: # for the case of swept parameters, this is defaulted to identity - command.value = float(ch_trafo.voltage_transformation(command.value)) + command.value = ch_trafo.voltage_transformation(command.value) command.value -= ch_trafo.offset command.value /= ch_trafo.amplitude else: diff --git a/qupulse/hardware/setup.py b/qupulse/hardware/setup.py index e976a674..3a300ef0 100644 --- a/qupulse/hardware/setup.py +++ b/qupulse/hardware/setup.py @@ -94,6 +94,7 @@ def register_program(self, name: str, program: Loop, run_callback=lambda: None, update: bool = False, + channels = None, measurements: Mapping[str, Tuple[np.ndarray, np.ndarray]] = None) -> None: """Register a program under a given name at the hardware setup. The program will be uploaded to the participating AWGs and DACs. The run callback is used for triggering the program after arming. @@ -109,7 +110,8 @@ def register_program(self, name: str, if not callable(run_callback): raise TypeError('The provided run_callback is not callable') - channels = next(program.get_depth_first_iterator()).waveform.defined_channels + if channels is None: + channels = next(program.get_depth_first_iterator()).waveform.defined_channels if channels - set(self._channel_map.keys()): raise KeyError('The following channels are unknown to the HardwareSetup: {}'.format( channels - set(self._channel_map.keys()))) diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index 923cae8b..2842dfa4 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -135,6 +135,9 @@ def __rmul__(self,other): def __truediv__(self,other): return self.__mul__(1/other) + + def __float__(self): + return float(self(resolution=None)) @@ -347,8 +350,8 @@ def hold_voltage(self, duration: HardwareTime, voltages: Mapping[ChannelID, Hard duration_factors = None for ch_name,value in voltages.items(): - if isinstance(value, float): - bases[ch_name] = value + if isinstance(value, (float, int)): + bases[ch_name] = float(value) factors[ch_name] = None continue offsets = value.offsets