Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MethodPulseTemplate #664

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion qupulse/_program/waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from numbers import Real
from typing import (
AbstractSet, Any, FrozenSet, Iterable, Mapping, NamedTuple, Sequence, Set,
Tuple, Union, cast, Optional, List, Hashable)
Tuple, Union, cast, Optional, List, Hashable, Callable)
from weakref import WeakValueDictionary, ref

import numpy as np
Expand Down Expand Up @@ -1228,3 +1228,51 @@ def compare_key(self) -> Hashable:

def reversed(self) -> 'Waveform':
return self._inner



class MethodWaveform(Waveform):
"""Waveform obtained from instantiating a FunctionPulseTemplate."""

def __init__(self, pulse_method: Callable,
duration: float,
channel: ChannelID) -> None:
"""Creates a new FunctionWaveform instance.

Args:
pulse_method: The method used to define this waveform.
duration: The duration of the waveform
measurement_windows: A list of measurement windows
channel: The channel this waveform is played on
"""
super().__init__(duration=TimeType.from_float(duration, absolute_error=PULSE_TO_WAVEFORM_ERROR))

self._pulse_method = pulse_method
self._channel_id = channel

@property
def defined_channels(self) -> Set[ChannelID]:
return {self._channel_id}

@property
def compare_key(self) -> Any:
return self._channel_id, self._expression, self._duration

@property
def duration(self) -> TimeType:
return self._duration

def unsafe_sample(self,
channel: ChannelID,
sample_times: np.ndarray,
output_array: Union[np.ndarray, None] = None) -> np.ndarray:
if channel != self._channel_id:
raise IndexError(f'channel {channel} cannot be sampled from {self}')

if output_array is None:
output_array = np.empty(len(sample_times))
output_array[:] = self._pulse_method(sample_times)
return output_array

def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> Waveform:
return self
132 changes: 132 additions & 0 deletions qupulse/pulses/method_pulse_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import numbers
from typing import (Any, Callable, Dict, List, Optional, Set, Union)

import numpy as np
from qupulse._program.waveforms import MethodWaveform
from qupulse.expressions import ExpressionScalar
from qupulse.pulses.measurement import MeasurementDeclaration
from qupulse.pulses.parameters import ParameterConstrainer, ParameterConstraint
from qupulse.pulses.pulse_template import (AtomicPulseTemplate, ChannelID)
from qupulse.serialization import PulseRegistryType, Serializer


import functools
@functools.lru_cache(maxsize=1024)
def ExpressionScalarCache(value):
return ExpressionScalar(value)

class MethodPulseTemplate(AtomicPulseTemplate, ParameterConstrainer):
"""Defines a pulse via a method

MethodPulseTemplate.

"""

def __init__(self,
pulse_method: Callable,
duration: ExpressionScalar,
channel: ChannelID = 'default',
identifier: Optional[str] = None,
*,
measurements: Optional[List[MeasurementDeclaration]] = None,
parameter_constraints: Optional[List[Union[str, ParameterConstraint]]] = None,
registry: PulseRegistryType = None) -> None:
"""Creates a new FunctionPulseTemplate object.

Args:
method: The function represented by this MethodPulseTemplate
duration: Duration
channel: The channel this pulse template is defined on.
identifier: A unique identifier for use in serialization.
measurements: A list of measurement declarations forwarded to the
:class:`~qupulse.pulses.measurement.MeasurementDefiner` superclass
parameter_constraints: A list of parameter constraints forwarded to the
:class:`~qupulse.pulses.measurement.ParameterConstrainer` superclass
"""
AtomicPulseTemplate.__init__(self, identifier=identifier, measurements=measurements)
ParameterConstrainer.__init__(self, parameter_constraints=parameter_constraints)

self._pulse_method = pulse_method
self._duration = ExpressionScalarCache(duration)
self.__parameter_names: Set[str] = set()
self.__channel = channel

self._register(registry=registry)

def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]:
raise NotImplementedError(f'expression not available for {self.__class__}')

@property
def pulse_method(self) -> Callable:
return self._pulse_method

@property
def function_parameters(self) -> Set[str]:
return self.__parameter_names

@property
def parameter_names(self) -> Set[str]:
return self.function_parameters | self.measurement_parameters | self.constrained_parameters

@property
def defined_channels(self) -> Set[ChannelID]:
return {self.__channel}

@property
def duration(self) -> ExpressionScalar:
return self._duration

def build_waveform(self,
parameters: Dict[str, numbers.Real],
channel_mapping: Dict[ChannelID, Optional[ChannelID]]) -> Optional['MethodWaveform']:
self.validate_parameter_constraints(parameters=parameters, volatile=set())

channel = channel_mapping[self.__channel]
if channel is None:
return None

duration = self._duration

return MethodWaveform(pulse_method=self.pulse_method,
duration=float(duration),
channel=channel_mapping[self.__channel])

def get_serialization_data(self, serializer: Optional[Serializer] = None) -> Dict[str, Any]:
data = super().get_serialization_data(serializer)

if serializer: # compatibility to old serialization routines, deprecated
raise NotImplementedError

local_data = dict(
duration=self.duration,
method=str(self.pulse_method),
channel=self.__channel,
measurements=self.measurement_declarations,
parameter_constraints=[str(c) for c in self.parameter_constraints]
)

data.update(**local_data)
return data

@classmethod
def deserialize(cls,
serializer: Optional[Serializer] = None,
**kwargs) -> 'MethodPulseTemplate':
raise NotImplementedError()

@property
def integral(self) -> Dict[ChannelID, ExpressionScalar]:
try:
import scipy.integrate
except ImportError:
raise ValueError(f'scipy package is required to perform integral calculations for {self.__class__}')

return {self.__channel: ExpressionScalar(scipy.integrate.quad(self._pulse_method, 0, float(self.duration))[0]
)}



if __name__ == '__main__':
from qupulse.pulses.plotting import plot
px = MethodPulseTemplate(pulse_method=lambda t: np.sin(.2*t), duration=100)
plot(px, sample_rate=10)
56 changes: 56 additions & 0 deletions tests/pulses/method_pulse_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import unittest
import numpy as np

from qupulse.pulses.method_pulse_template import MethodPulseTemplate
from qupulse.expressions import ExpressionScalar
from qupulse.pulses.plotting import render


class MethodPulseTest(unittest.TestCase):
def setUp(self) -> None:
def pulse_method(t): return np.sin(.2*t)
self.fpt = MethodPulseTemplate(pulse_method, duration=100, channel='A')


class MethodPulsePropertyTest(MethodPulseTest):

def test_defined_channels(self) -> None:
self.assertEqual({'A'}, self.fpt.defined_channels)

def test_parameter_names(self):
self.assertEqual(self.fpt.parameter_names, set())

def test_duration(self):
self.assertEqual(self.fpt.duration, 100)

def test_integral(self) -> None:
try:
import scipy.integrate
except ImportError:
return

pulse = MethodPulseTemplate(pulse_method=lambda t: 0*t, duration=30)
self.assertDictEqual(pulse.integral, {'default': 0})
pulse = MethodPulseTemplate(pulse_method=lambda t: 1+0*t, duration=30)
self.assertDictEqual(pulse.integral, {'default': 30})
pulse = MethodPulseTemplate(pulse_method=lambda t: np.sin(t), duration=30)
self.assertDictEqual(pulse.integral, {'default': ExpressionScalar(0.8457485501124153)})

def test_get_serialization_data(self):
s = self.fpt.get_serialization_data()
self.assertEqual(s['channel'], 'A')
self.assertIsInstance(s['method'], str)


class MethodPulseSequencingTest(MethodPulseTest):
def test_build_waveform(self) -> None:
wf = self.fpt.build_waveform({}, channel_mapping={'A': 'B'})
self.assertEqual(wf.defined_channels, {'B'})

def test_sample(self) -> None:
times, values, _ = render(self.fpt.create_program(), sample_rate=2)
np.testing.assert_almost_equal(values['A'], np.sin(.2 * np.arange(0, 100.1, .5)))


if __name__ == '__main__':
unittest.main()