diff --git a/qopt/__init__.py b/qopt/__init__.py index 48e362e..42f4f8b 100644 --- a/qopt/__init__.py +++ b/qopt/__init__.py @@ -82,3 +82,12 @@ __version__ = '1.3' __license__ = 'GNU GPLv3+' __author__ = 'Julian Teske, Forschungszentrum Juelich' + + +try: + from jax.config import config + config.update("jax_enable_x64", True) + #TODO: add new objects here/ import other stuff? + # __all__ += [] +except ImportError: + pass \ No newline at end of file diff --git a/qopt/amplitude_functions.py b/qopt/amplitude_functions.py index 832a94c..f0b378a 100644 --- a/qopt/amplitude_functions.py +++ b/qopt/amplitude_functions.py @@ -64,10 +64,11 @@ """ from abc import ABC, abstractmethod -from typing import Callable +from typing import Callable, Optional import numpy as np +from typing import Union class AmplitudeFunction(ABC): """Abstract Base class of the amplitude function. """ @@ -218,3 +219,125 @@ def derivative_by_chain_rule(self, deriv_by_ctrl_amps: np.ndarray, # return: shape (time, func, par) return np.einsum('imj,ikj->ikm', du_by_dx, deriv_by_ctrl_amps) + + +############################################################################### + +try: + import jax.numpy as jnp + from jax import jit,vmap,jacfwd + _HAS_JAX = True +except ImportError: + from unittest import mock + jit, vmap, jacfwd = mock.Mock(), mock.Mock(), mock.Mock() + jnp = mock.Mock() + _HAS_JAX = False + + +class IdentityAmpFuncJAX(AmplitudeFunction): + """See docstring of class without JAX. + Designed to return jax-numpy-arrays. + """ + + def __init__(self): + if not _HAS_JAX: + raise ImportError("JAX not available") + + def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: + """See base class. """ + return jnp.asarray(x) + + def derivative_by_chain_rule( + self, + deriv_by_ctrl_amps: Union[np.ndarray,jnp.ndarray], + x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: + """See base class. """ + return jnp.asarray(deriv_by_ctrl_amps) + + +class UnaryAnalyticAmpFuncJAX(AmplitudeFunction): + """See docstring of class without JAX. + Designed to return jax-numpy-arrays. + Functions need to be compatible with jit. + (Includes that functions need to be pure + (i.e. output solely depends on input)). + """ + + def __init__(self, + value_function: Callable[[float, ], float], + derivative_function: [Callable[[float, ], float]]): + if not _HAS_JAX: + raise ImportError("JAX not available") + self.value_function = jit(jnp.vectorize(value_function)) + self.derivative_function = jit(jnp.vectorize(derivative_function)) + + def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: + """See base class. """ + return jnp.asarray(self.value_function(x)) + + def derivative_by_chain_rule( + self, + deriv_by_ctrl_amps: Union[np.ndarray, jnp.ndarray], x): + """See base class. """ + du_by_dx = self.derivative_function(x) + # du_by_dx shape: (n_time, n_ctrl) + # deriv_by_ctrl_amps shape: (n_time, n_func, n_ctrl) + # deriv_by_opt_par shape: (n_time, n_func, n_ctrl + # since the function is unary we have n_ctrl = n_amps + return jnp.einsum('ij,ikj->ikj', du_by_dx, deriv_by_ctrl_amps) + + +class CustomAmpFuncJAX(AmplitudeFunction): + """See docstring of class without JAX. + Designed to return jax-numpy-arrays. + Functions need to be compatible with jit. + (Includes that functions need to be pure + (i.e. output solely depends on input)). + If derivative_function=None, autodiff is used. + t_to_vectorize: if value_function/derivative_function not yet + vectorized for num_t + """ + + def __init__( + self, + value_function: Callable[[Union[np.ndarray, jnp.ndarray],], + Union[np.ndarray, jnp.ndarray]], + derivative_function: Callable[[Union[np.ndarray, jnp.ndarray],], + Union[np.ndarray, jnp.ndarray]], + t_to_vectorize: bool = False + ): + if not _HAS_JAX: + raise ImportError("JAX not available") + if t_to_vectorize == True: + self.value_function = jit(vmap(value_function),in_axes=(0,)) + else: + self.value_function = jit(value_function) + if derivative_function is not None: + if t_to_vectorize == True: + self.derivative_function = jit(vmap(derivative_function),in_axes=(0,)) + else: + self.derivative_function = jit(derivative_function) + else: + if t_to_vectorize == True: + def der_wrapper(x): + return jnp.swapaxes(vmap(jacfwd(lambda x: value_function(x)),in_axes=(0,))(x),1,2) + else: + def der_wrapper(x): + return jnp.swapaxes(vmap(jacfwd(lambda x: value_function(jnp.expand_dims(x,axis=0))[0,:]),in_axes=(0,))(x),1,2) + self.derivative_function = jit(der_wrapper) + + def __call__(self, x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: + """See base class. """ + return jnp.asarray(self.value_function(x)) + + def derivative_by_chain_rule( + self, + deriv_by_ctrl_amps: Union[np.ndarray, jnp.ndarray], + x: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: + """See base class. """ + du_by_dx = self.derivative_function(x) + # du_by_dx: shape (time, par, ctrl) + # deriv_by_ctrl_amps: shape (time, func, ctrl) + # return: shape (time, func, par) + + return jnp.einsum('imj,ikj->ikm', du_by_dx, deriv_by_ctrl_amps) diff --git a/qopt/cost_functions.py b/qopt/cost_functions.py index 0bd8bff..cdb8c36 100644 --- a/qopt/cost_functions.py +++ b/qopt/cost_functions.py @@ -104,7 +104,9 @@ from qopt.util import needs_refactoring, deprecated from qopt.matrix import ket_vectorize_density_matrix, \ convert_ket_vectorized_density_matrix_to_square, \ - convert_unitary_to_super_operator + convert_unitary_to_super_operator, DenseOperator + +from functools import partial class CostFunction(ABC): @@ -122,7 +124,6 @@ class CostFunction(ABC): storing the data. """ - def __init__(self, solver: solver_algorithms.Solver, label: Optional[List[str]] = None): self.solver = solver @@ -830,6 +831,51 @@ def derivative_entanglement_fidelity_with_du( return derivative_fidelity +def derivative_entanglement_fidelity_with_dfreq( + target: matrix.OperatorMatrix, + target_der: matrix.OperatorMatrix, + forward_propagators: List[matrix.OperatorMatrix], + computational_states: Optional[List[int]] = None, + map_to_closest_unitary: bool = False +) -> np.ndarray: + + target_unitary_dag = target.dag(do_copy=True) + if computational_states: + trace = np.conj( + ((forward_propagators[-1].truncate_to_subspace( + computational_states, + map_to_closest_unitary=map_to_closest_unitary) + * target_unitary_dag).tr()) + ) + else: + trace = np.conj(((forward_propagators[-1] * target_unitary_dag).tr())) + num_ctrls = 1 + num_time_steps = 1 + d = target.shape[0] + + derivative_fidelity = np.zeros(shape=(num_time_steps, num_ctrls), + dtype=float) + + target_unitary_dag = target_der.dag(do_copy=True) + + ctrl=0 + t=-1 + # here we need to take the real part. + if computational_states: + derivative_fidelity[t, ctrl] = 2 / d / d * np.real( + trace * (forward_propagators[t].truncate_to_subspace( + subspace_indices=computational_states, + map_to_closest_unitary=map_to_closest_unitary + ) + * target_unitary_dag).tr()) + else: + derivative_fidelity[t, ctrl] = 2 / d / d * np.real( + trace * (forward_propagators[t] + * target_unitary_dag).tr()) + + return derivative_fidelity + + def entanglement_fidelity_super_operator( target: Union[np.ndarray, matrix.OperatorMatrix], propagator: Union[np.ndarray, matrix.OperatorMatrix], @@ -986,6 +1032,30 @@ def deriv_entanglement_fid_sup_op_with_du( return derivative_fidelity +def deriv_entanglement_fid_sup_op_with_dfreq( + target: matrix.OperatorMatrix, + target_der: matrix.OperatorMatrix, + forward_propagators: List[matrix.OperatorMatrix], + computational_states: Optional[List[int]] = None, + map_to_closest_unitary: bool = False +): + + num_ctrls = 1 + num_time_steps = 1 + + derivative_fidelity = np.zeros(shape=(num_time_steps, num_ctrls), + dtype=float) + ctrl=0 + t=-1 + # here we need to take the real part. + derivative_fidelity[t, ctrl] = \ + entanglement_fidelity_super_operator( + target=target_der, + propagator=forward_propagators[t], + computational_states=computational_states) + return derivative_fidelity + + class StateInfidelity(CostFunction): """Quantum state infidelity. @@ -1038,6 +1108,65 @@ def grad(self) -> np.ndarray: return -1. * np.real(derivative_fid) +class StateInfidelity2(CostFunction): + """Quantum state infidelity. + + TODO: + * support super operator formalism + * handle leakage states? + """ + + def __init__(self, + solver: solver_algorithms.Solver, + target: matrix.OperatorMatrix, + initial: matrix.OperatorMatrix, + label: Optional[List[str]] = None, + computational_states: Optional[List[int]] = None, + rescale_propagated_state: bool = False + ): + if label is None: + label = ['State Infidelity', ] + super().__init__(solver=solver, label=label) + # assure target is a bra vector + + if target.shape[0] > target.shape[1]: + self.target = target.dag() + else: + self.target = target + + #1D + self.initial = initial + + self.computational_states = computational_states + self.rescale_propagated_state = rescale_propagated_state + + def costs(self) -> np.float64: + """See base class. """ + + final = DenseOperator((self.solver.forward_propagators[-1]*self.initial).data[:,np.newaxis]) + + infid = 1. - state_fidelity( + target=self.target, + propagated_state=final, + computational_states=self.computational_states, + rescale_propagated_state=self.rescale_propagated_state + ) + return infid + + def grad(self) -> np.ndarray: + """See base class. """ + derivative_fid = derivative_state_fidelity( + forward_propagators=[DenseOperator((p*self.initial).data[:,np.newaxis]) for p in self.solver.forward_propagators], + target=self.target, + reversed_propagators=self.solver.reversed_propagators, + propagator_derivatives=self.solver.frechet_deriv_propagators, + computational_states=self.computational_states, + rescale_propagated_state=self.rescale_propagated_state + ) + return -1. * np.real(derivative_fid) + + + class StateInfidelitySubspace(CostFunction): """ Quantum state infidelity on a subspace. @@ -1247,7 +1376,8 @@ def __init__(self, super_operator_formalism: bool = False, label: Optional[List[str]] = None, computational_states: Optional[List[int]] = None, - map_to_closest_unitary: bool = False + map_to_closest_unitary: bool = False, + total_ang_time = None, ): if label is None: if fidelity_measure == 'entanglement': @@ -1267,8 +1397,17 @@ def __init__(self, 'currently supported.') self.super_operator = super_operator_formalism - - def costs(self) -> float: + + + + if total_ang_time is None: + self.total_ang_time = 0 + elif total_ang_time <0: + self.total_ang_time = sum(solver.transferred_time)-0.5*solver.transferred_time[-1] + else: + self.total_ang_time = total_ang_time + + def costs_original(self) -> float: """Calculates the costs by the selected fidelity measure. """ final = self.solver.forward_propagators[-1] @@ -1290,9 +1429,12 @@ def costs(self) -> float: 'implemented in this version.') return np.real(infid) - def grad(self) -> np.ndarray: + def grad_original(self) -> np.ndarray: """Calculates the derivatives of the selected fidelity measure with respect to the control amplitudes. """ + + + if self.fidelity_measure == 'entanglement' and self.super_operator: derivative_fid = deriv_entanglement_fid_sup_op_with_du( forward_propagators=self.solver.forward_propagators, @@ -1315,6 +1457,90 @@ def grad(self) -> np.ndarray: 'version.') return -1 * np.real(derivative_fid) + + + def costs(self,freq=0) -> float: + """Calculates the costs by the selected fidelity measure. """ + final = self.solver.forward_propagators[-1] + + + r = DenseOperator(np.array([[np.exp(1j*freq/2*self.total_ang_time),0],[0,np.exp(-1j*freq/2*self.total_ang_time)]])) + + if self.fidelity_measure == 'entanglement' and self.super_operator: + infid = 1 - entanglement_fidelity_super_operator( + propagator=final, + target=r.dag()*self.target, + computational_states=self.computational_states, + ) + elif self.fidelity_measure == 'entanglement': + infid = 1 - entanglement_fidelity( + propagator=final, + target=r.dag()*self.target, + computational_states=self.computational_states, + map_to_closest_unitary=self.map_to_closest_unitary + ) + else: + raise NotImplementedError('Only the entanglement fidelity is ' + 'implemented in this version.') + return np.real(infid) + + def grad(self,freq=0) -> np.ndarray: + """Calculates the derivatives of the selected fidelity measure with + respect to the control amplitudes. """ + + + r = DenseOperator(np.array([[np.exp(1j*freq/2*self.total_ang_time),0],[0,np.exp(-1j*freq/2*self.total_ang_time)]])) + + if self.fidelity_measure == 'entanglement' and self.super_operator: + derivative_fid = deriv_entanglement_fid_sup_op_with_du( + forward_propagators=self.solver.forward_propagators, + target=r.dag()*self.target, + reversed_propagators=self.solver.reversed_propagators, + unitary_derivatives=self.solver.frechet_deriv_propagators, + computational_states=self.computational_states, + ) + elif self.fidelity_measure == 'entanglement': + derivative_fid = derivative_entanglement_fidelity_with_du( + forward_propagators=self.solver.forward_propagators, + target=r.dag()*self.target, + reversed_propagators=self.solver.reversed_propagators, + propagator_derivatives=self.solver.frechet_deriv_propagators, + computational_states=self.computational_states, + ) + else: + raise NotImplementedError('Only the average and entanglement' + 'fidelity is implemented in this ' + 'version.') + return -1 * np.real(derivative_fid) + + def der_freq_test(self,freq): + + + r_der = 1j*self.total_ang_time/2*DenseOperator(np.array([[np.exp(1j*freq/2*self.total_ang_time),0],[0,-np.exp(-1j*freq/2*self.total_ang_time)]])) + r = DenseOperator(np.array([[np.exp(1j*freq/2*self.total_ang_time),0],[0,np.exp(-1j*freq/2*self.total_ang_time)]])) + + if self.fidelity_measure == 'entanglement' and self.super_operator: + derivative_fid = deriv_entanglement_fid_sup_op_with_dfreq( + forward_propagators=self.solver.forward_propagators, + target_der = r_der.dag()*self.target, + target=r.dag()*self.target, + computational_states=self.computational_states, + ) + elif self.fidelity_measure == 'entanglement': + derivative_fid = derivative_entanglement_fidelity_with_dfreq( + forward_propagators=self.solver.forward_propagators, + target_der = r_der.dag()*self.target, + target=r.dag()*self.target, + computational_states=self.computational_states, + ) + + else: + raise NotImplementedError('Only the average and entanglement' + 'fidelity is implemented in this ' + 'version.') + return -1 * np.real(derivative_fid) + + class OperationNoiseInfidelity(CostFunction): """ @@ -1365,7 +1591,9 @@ def __init__(self, fidelity_measure: str = 'entanglement', computational_states: Optional[List[int]] = None, map_to_closest_unitary: bool = False, - neglect_systematic_errors: bool = True): + neglect_systematic_errors: bool = True, + total_ang_time = None, + ): if label is None: label = ['Operator Noise Infidelity'] super().__init__(solver=solver, label=label) @@ -1381,6 +1609,16 @@ def __init__(self, print('The systematic errors must be neglected if no target is ' 'set!') self.neglect_systematic_errors = True + + + + if total_ang_time is None: + self.total_ang_time = 0 + elif total_ang_time <0: + self.total_ang_time = sum(solver.transferred_time)-0.5*solver.transferred_time[-1] + else: + self.total_ang_time = total_ang_time + def _to_comp_space(self, dynamic_target: matrix.OperatorMatrix) -> matrix.OperatorMatrix: @@ -1393,18 +1631,28 @@ def _to_comp_space(self, else: return dynamic_target - def _effective_target(self) -> matrix.OperatorMatrix: + def _effective_target(self,freq=0) -> matrix.OperatorMatrix: if self.neglect_systematic_errors: return self._to_comp_space(self.solver.forward_propagators[-1]) else: - return self.target - - def costs(self): + + r = DenseOperator(np.array([[np.exp(1j*freq/2*self.total_ang_time),0],[0,np.exp(-1j*freq/2*self.total_ang_time)]])) + return r.dag()*self.target + + def _effective_target_der(self,freq=0) -> matrix.OperatorMatrix: + if self.neglect_systematic_errors: + return 0*self.target + else: + + r = 1j*self.total_ang_time/2*DenseOperator(np.array([[np.exp(1j*freq/2*self.total_ang_time),0],[0,-np.exp(-1j*freq/2*self.total_ang_time)]])) + return r.dag()*self.target + + def costs(self,freq=0): """See base class. """ n_traces = self.solver.noise_trace_generator.n_traces infidelities = np.zeros((n_traces,)) - target = self._effective_target() + target = self._effective_target(freq=freq) if self.fidelity_measure == 'entanglement': for i in range(n_traces): @@ -1421,9 +1669,9 @@ def costs(self): return np.mean(np.real(infidelities)) - def grad(self): + def grad(self,freq=0): """See base class. """ - target = self._effective_target() + target = self._effective_target(freq) n_traces = self.solver.noise_trace_generator.n_traces num_t = len(self.solver.transferred_time) @@ -1452,7 +1700,28 @@ def grad(self): ) derivative[:, :, i] = np.real(temp) return np.mean(-derivative, axis=2) + + def der_freq_test(self,freq): + + + + target_der = self._effective_target_der(freq) + target = self._effective_target(freq) + n_traces = self.solver.noise_trace_generator.n_traces + num_t=1 + num_ctrl = 1 + derivative = np.zeros((num_t, num_ctrl, n_traces, )) + for i in range(n_traces): + temp = derivative_entanglement_fidelity_with_dfreq( + target_der=target_der, + target=target, + forward_propagators=self.solver.forward_propagators_noise[i], + computational_states=self.computational_states + ) + derivative[:, :, i] = np.real(temp) + return np.mean(-derivative, axis=2) + class LiouvilleMonteCarloEntanglementInfidelity(CostFunction): """ @@ -1864,81 +2133,1150 @@ def costs(self): # the result should always be positive within numerical accuracy return leakage.data.real[0] + def grad(self): """See base class. """ - raise NotImplementedError('The derivative of the cost function ' - 'LeakageLiouville has not been implemented' - 'yet.') - -@deprecated -def derivative_entanglement_fidelity( - control_hamiltonians: List[matrix.OperatorMatrix], - forward_propagators: List[matrix.OperatorMatrix], - reversed_propagators: List[matrix.OperatorMatrix], - delta_t: List[float], - target_unitary: matrix.OperatorMatrix) -> np.ndarray: - """ - Derivative of the entanglement fidelity using the grape approximation. - - dU / du = -i delta_t H_ctrl U + num_ctrls = len(self.solver.frechet_deriv_propagators) + num_time_steps = len(self.solver.frechet_deriv_propagators[0]) - Parameters - ---------- - control_hamiltonians: List[ControlMatrix], len: num_ctrl - The control hamiltonians of the simulation. + derivative_leakage = np.zeros(shape=(num_time_steps, num_ctrls), + dtype=np.float64) - forward_propagators: List[ControlMatrix], len: num_t +1 - The forward propagators calculated in the systems simulation. + for ctrl in range(num_ctrls): + for t in range(num_time_steps): + derivative_leakage[t, ctrl] = (1 / self.dim_comp) * ( + self.projector_leakage_bra + * self.solver.reversed_propagators[::-1][t + 1] \ + * self.solver.frechet_deriv_propagators[ctrl][t] + * self.solver.forward_propagators[t] + * self.projector_comp_ket + ).data.real[0] + + return derivative_leakage + + +############################################################################### + +try: + import jax.numpy as jnp + from jax import jit, vmap + import jax + _HAS_JAX = True +except ImportError: + from unittest import mock + jit = mock.Mock() + jnp = mock.Mock() + vmap = mock.Mock() + jax = mock.Mock() + _HAS_JAX = False + + +@jit +def _closest_unitary_jnp(matrix: jnp.ndarray) -> jnp.ndarray: + """Return the closest unitary to the matrix.""" + + left_singular_vec, __, right_singular_vec_h = jnp.linalg.svd( + matrix) + return left_singular_vec.dot(right_singular_vec_h) + + +@partial(jit,static_argnums=(1,)) +def _truncate_to_subspace_jnp_unmapped(arr: jnp.ndarray, + subspace_indices: Optional[tuple], + ) -> jnp.ndarray: + """Return the truncated jnp array""" + # subspace_indices = jnp.asarray(subspace_indices) + if subspace_indices is None: + return arr + elif arr.shape[0] == arr.shape[1]: + # square matrix + subspace_indices = jnp.asarray(subspace_indices) + out = arr[jnp.ix_(subspace_indices, subspace_indices)] + + elif arr.shape[0] == 1: + # bra-vector + subspace_indices = jnp.asarray(subspace_indices) + out = arr[jnp.ix_(jnp.array([0]), subspace_indices)] + + elif arr.shape[0] == 1: + # ket-vector + subspace_indices = jnp.asarray(subspace_indices) + out = arr[jnp.ix_(subspace_indices, jnp.array([0]))] - reversed_propagators: List[ControlMatrix] - The reversed propagators calculated in the systems simulation. + else: + subspace_indices = jnp.asarray(subspace_indices) + out = arr[jnp.ix_(subspace_indices)] - delta_t: List[float], len: num_t - The durations of the time steps. + return out - target_unitary: ControlMatrix - The target unitary evolution. - Returns - ------- - derivative_fidelity: np.ndarray, shape: (num_t, num_ctrl) - The derivatives of the entanglement fidelity. +@partial(jit,static_argnums=(1,)) +def _truncate_to_subspace_jnp_mapped(arr: jnp.ndarray, + subspace_indices: Optional[tuple], + ) -> jnp.ndarray: + """Return the truncated jnp array mapped to the closest unitary (matrix) / + renormalized (vector) + """ + # subspace_indices = jnp.asarray(subspace_indices) + if subspace_indices is None: + return arr + elif arr.shape[0] == arr.shape[1]: + # square matrix + subspace_indices = jnp.asarray(subspace_indices) + out = arr[jnp.ix_(subspace_indices, subspace_indices)] + out = _closest_unitary_jnp(out) + elif arr.shape[0] == 1: + # bra-vector + subspace_indices = jnp.asarray(subspace_indices) + out = arr[jnp.ix_(jnp.array([0]), subspace_indices)] + out *= 1 / jnp.linalg.norm(out,'fro') + elif arr.shape[0] == 1: + # ket-vector + subspace_indices = jnp.asarray(subspace_indices) + out = arr[jnp.ix_(subspace_indices, jnp.array([0]))] + out *= 1 / jnp.linalg.norm(out,'fro') + else: + subspace_indices = jnp.asarray(subspace_indices) + out = arr[jnp.ix_(subspace_indices)] + return out + +@partial(jit,static_argnums=(1,2)) +def _truncate_to_subspace_jnp(arr,subspace_indices,map_to_closest_unitary): + """Return the truncated jnp array, either mapped to the + closest unitary (matrix) / renormalized (vector) or not + """ + if map_to_closest_unitary==True: + return _truncate_to_subspace_jnp_mapped(arr,subspace_indices) + else: + return _truncate_to_subspace_jnp_unmapped(arr,subspace_indices) + + +@partial(jit,static_argnums=(2,3)) +def _entanglement_fidelity_jnp( + target: jnp.ndarray, + propagator: jnp.ndarray, + computational_states: Optional[tuple] = None, + map_to_closest_unitary: bool = False +) -> jnp.float64: + """Return the entanglement fidelity of target and propagator""" + d = target.shape[0] + if computational_states is None: + trace = (jnp.conj(target).T @ propagator).trace() + else: + trace = (jnp.conj(target).T @ _truncate_to_subspace_jnp(propagator, + computational_states, + map_to_closest_unitary)).trace() + return (jnp.abs(trace) ** 2) / d / d + + +@partial(jit,static_argnums=(2,3)) +def _entanglement_fidelity_super_operator_jnp( + target: jnp.ndarray, + propagator: jnp.ndarray, + dim_prop: int, + computational_states: Optional[tuple] = None, +) -> jnp.float64: + """Return the entanglement fidelity of target and propagator in super- + operator formalism """ - target_unitary_dag = target_unitary.dag(do_copy=True) - trace = np.conj(((forward_propagators[-1] * target_unitary_dag).tr())) - num_ctrls = len(control_hamiltonians) - num_time_steps = len(delta_t) - d = target_unitary.shape[0] - derivative_fidelity = np.zeros(shape=(num_time_steps, num_ctrls), - dtype=complex) + dim_comp = target.shape[0] - for ctrl in range(num_ctrls): - for t in range(num_time_steps): - # we take the imaginary part because we took a factor of i out - derivative_fidelity[t, ctrl] = 2 / d / d * delta_t * np.imag( - trace * (reversed_propagators[::-1][t + 1] - * control_hamiltonians[ctrl] - * forward_propagators[t + 1] - * target_unitary_dag).tr()) - return derivative_fidelity + if computational_states is None: + target_super_operator_inv = \ + jnp.kron(target.T, jnp.conj(target.T)) + trace = (target_super_operator_inv @ propagator).trace().real + else: + # Here we assume that the full Hilbertspace is the outer sum of a + # computational and a leakage space. + # Thus the dimension of the propagator is (d_comp + d_leak) ** 2 + d_leakage = dim_prop - dim_comp -@needs_refactoring -def averge_gate_fidelity(unitary: matrix.OperatorMatrix, - target_unitary: matrix.OperatorMatrix): - """ - Average gate fidelity. + # We fill zeros to the target on the leakage space. We will project + # onto the computational space anyway. - Parameters - ---------- - unitary: ControlMatrix - The evolution matrix of the system. + target_inv = jnp.conj(target.T) + target_inv_full_space = jnp.zeros((d_leakage + dim_comp, + d_leakage + dim_comp),dtype=complex) + + clist = jnp.array(computational_states) - target_unitary: ControlMatrix - The target unitary to which the evolution is compared. + for i, row in enumerate(computational_states): + for k, column in enumerate(computational_states): + target_inv_full_space = target_inv_full_space.at[row, column].set(target_inv[i, k]) + + # Then convert the target unitary into Liouville space. + + target_super_operator_inv = jnp.kron(jnp.conj(target_inv_full_space), + target_inv_full_space) + + # We start the projector with a zero matrix of dimension + # (d_comp + d_leak). + projector_comp_state = 0 * jnp.identity(target_inv_full_space.shape[0]) + + # for state in computational_states: + projector_comp_state = projector_comp_state.at[clist, + clist].set(1) + + # Then convert the projector into liouville space. + projector_comp_state=jnp.kron(jnp.conj(projector_comp_state), + projector_comp_state) + + trace = ( + projector_comp_state @ target_super_operator_inv @ propagator + ).trace().real + return trace / dim_comp / dim_comp + + +@partial(jit,static_argnums=(4,5)) +def _derivative_entanglement_fidelity_with_du_jnp( + target: jnp.ndarray, + forward_propagators_jnp: jnp.ndarray, + propagator_derivatives_jnp: jnp.ndarray, + reversed_propagators_jnp: jnp.ndarray, + computational_states: Optional[tuple] = None, + map_to_closest_unitary: bool = False +) -> jnp.ndarray: + """Return the derivative of the entanglement fidelity of target and + propagator + """ + target_unitary_dag = jnp.conj(target).T + if computational_states is not None: + trace = jnp.conj( + ((_truncate_to_subspace_jnp(forward_propagators_jnp[-1], + computational_states, + map_to_closest_unitary=map_to_closest_unitary) + @ target_unitary_dag).trace()) + ) + else: + trace = jnp.conj(((forward_propagators_jnp[-1]@ + target_unitary_dag).trace())) + d = target.shape[0] + + # here we need to take the real part. + if computational_states: + derivative_fidelity = 2/d/d * jnp.real(trace*_der_fid_comp_states( + propagator_derivatives_jnp, + reversed_propagators_jnp[::-1][1:], + forward_propagators_jnp[:-1],computational_states, + map_to_closest_unitary,target_unitary_dag)).T + + else: + derivative_fidelity = 2/d/d * jnp.real(trace*_der_fid( + propagator_derivatives_jnp, + reversed_propagators_jnp[::-1][1:], + forward_propagators_jnp[:-1],target_unitary_dag)).T + + return derivative_fidelity + + +def _der_fid_comp_states_loop(prop_der,rev_prop_rev,fwd_prop,comp_states, + map_to_closest_unitary,target_unitary_dag): + """Internal loop of derivative of entanglement fidelity w/ truncation""" + return (_truncate_to_subspace_jnp( + rev_prop_rev @ prop_der @ fwd_prop, + subspace_indices=comp_states, + map_to_closest_unitary=map_to_closest_unitary) + @ target_unitary_dag).trace() + + +#(to be used with additional .T for previously used shape) +@partial(jit,static_argnums=(3,4)) +def _der_fid_comp_states(prop_der,rev_prop_rev,fwd_prop,comp_states, + map_to_closest_unitary,target_unitary_dag): + """Derivative of entanglement fidelity w/ truncation, n_ctrl&n_timesteps on + first two axes + """ + return vmap(vmap(_der_fid_comp_states_loop,in_axes=(0,0,0,None,None,None)), + in_axes=(0,None,None,None,None,None))( + prop_der,rev_prop_rev,fwd_prop,comp_states, + map_to_closest_unitary,target_unitary_dag) + +def _der_fid_loop(prop_der,rev_prop_rev,fwd_prop,target_unitary_dag): + """Internal loop of derivative of entanglement fidelity w/o truncation""" + return (rev_prop_rev @ prop_der @ fwd_prop @ target_unitary_dag).trace() + +#(to be used with additional .T for previous shape) +@jit +def _der_fid(prop_der,rev_prop_rev,fwd_prop,target_unitary_dag): + """Derivative of entanglement fidelity w/o truncation""" + return vmap(vmap(_der_fid_loop,in_axes=(0,0,0,None)), + in_axes=(0,None,None,None))( + prop_der,rev_prop_rev,fwd_prop,target_unitary_dag) + + +@partial(jit,static_argnums=(4,5)) +def _deriv_entanglement_fid_sup_op_with_du_jnp( + target: jnp.ndarray, + forward_propagators: jnp.ndarray, + unitary_derivatives: jnp.ndarray, + reversed_propagators: jnp.ndarray, + dim_prop: int, + computational_states: Optional[tuple] = None +): + """Return the derivative of the entanglement fidelity of target and + propagator in super-operator formalism + """ + + derivative_fidelity = _der_entanglement_fidelity_super_operator_jnp( + target, + reversed_propagators[::-1][1:] @ unitary_derivatives @ + forward_propagators[:-1], + dim_prop, + computational_states).T + + return derivative_fidelity + + +#(to be used with additional .T for previous shape) +@partial(jit,static_argnums=(2,3)) +def _der_entanglement_fidelity_super_operator_jnp(target,propagators,dim_prop, + computational_states): + """Unnecessarily nested function for the derivative of the + entanglement fidelity of target and propagator in super-operator formalism + """ + return vmap(vmap(_entanglement_fidelity_super_operator_jnp, + in_axes=(None,0,None,None)),in_axes=(None,0,None,None))( + target,propagators,dim_prop,computational_states) + + +class StateInfidelityJAX(CostFunction): + """See docstring of class w/o JAX. Requires solver with JAX""" + def __init__(self, + solver: solver_algorithms.SolverJAX, + target: matrix.OperatorMatrix, + label: Optional[List[str]] = None, + computational_states: Optional[List[int]] = None, + rescale_propagated_state: bool = False + ): + if not _HAS_JAX: + raise ImportError("JAX not available") + if label is None: + label = ['State Infidelity', ] + super().__init__(solver=solver, label=label) + # assure target is a bra vector + + if target.shape[0] > target.shape[1]: + self.target = target.dag() + else: + self.target = target + + self._target_jnp = jnp.array(target.data) + if computational_states is None: + self.computational_states = None + else: + self.computational_states = tuple(computational_states) + self.rescale_propagated_state = rescale_propagated_state + + def costs(self) -> jnp.float64: + """See base class. """ + final = self.solver.forward_propagators_jnp[-1] + infid = 1. - _state_fidelity_jnp( + target=self._target_jnp, + propagated_state=final, + computational_states=self.computational_states, + rescale_propagated_state=self.rescale_propagated_state + ) + return jnp.real(infid) + + def grad(self) -> jnp.ndarray: + """See base class. """ + derivative_fid = _derivative_state_fidelity_jnp( + forward_propagators=self.solver.forward_propagators_jnp, + target=self._target_jnp, + reversed_propagators=self.solver.reversed_propagators_jnp, + propagator_derivatives=self.solver.frechet_deriv_propagators_jnp, + computational_states=self.computational_states, + rescale_propagated_state=self.rescale_propagated_state + ) + return -1. * jnp.real(derivative_fid) + + +@partial(jit,static_argnums=(2,3)) +def _state_fidelity_jnp( + target: jnp.ndarray, + propagated_state: jnp.ndarray, + computational_states: Optional[tuple] = None, + rescale_propagated_state: bool = False +) -> jnp.float64: + """Quantum state fidelity of target and propagated_state""" + + if computational_states is not None: + scalar_prod = jnp.dot( + target, + _truncate_to_subspace_jnp( + propagated_state, + computational_states, + map_to_closest_unitary=rescale_propagated_state + )) + else: + scalar_prod = jnp.dot(target, propagated_state) + + if scalar_prod.shape != (1, 1): + raise ValueError('The scalar product is not a scalar. This means that' + 'either the target is not a bra vector or the the ' + 'propagated state not a ket, or both!') + scalar_prod = scalar_prod[0, 0] + return jnp.abs(scalar_prod)**2 + + +@partial(jit,static_argnums=(4,5)) +def _derivative_state_fidelity_jnp( + target: jnp.ndarray, + forward_propagators: jnp.ndarray, + propagator_derivatives: jnp.ndarray, + reversed_propagators: jnp.ndarray, + computational_states: Optional[tuple] = None, + rescale_propagated_state: bool = False +) -> jnp.ndarray: + """Derivative of the state fidelity.""" + + if computational_states is not None: + scalar_prod = jnp.dot( + target, + _truncate_to_subspace_jnp( + forward_propagators[-1],subspace_indices=computational_states, + map_to_closest_unitary=rescale_propagated_state + )) + else: + scalar_prod = jnp.dot(target,forward_propagators[-1]) + + scalar_prod = jnp.conj(scalar_prod) + + if computational_states: + derivative_fidelity = 2 * jnp.real(scalar_prod*_der_fid_comp_states( + propagator_derivatives, + reversed_propagators[::-1][1:], + forward_propagators[:-1],computational_states, + rescale_propagated_state,target)).T + + else: + derivative_fidelity = 2 * jnp.real(scalar_prod*_der_fid( + propagator_derivatives, + reversed_propagators[::-1][1:], + forward_propagators[:-1],target)).T + + return derivative_fidelity + + +def _der_state_fid_comp_states_loop(prop_der,rev_prop_rev,fwd_prop,comp_states, + map_to_closest_unitary,target): + """Internal loop of derivative of state fidelity w/ truncation""" + return (target@_truncate_to_subspace_jnp( + rev_prop_rev@prop_der@fwd_prop, + subspace_indices=comp_states, + map_to_closest_unitary=map_to_closest_unitary))[0,0] + +#(to be used with additional .T for previous shape) +@partial(jit,static_argnums=(3,4)) +def _der_state_fid_comp_states(prop_der,rev_prop_rev,fwd_prop,comp_states, + map_to_closest_unitary,target): + """Derivative of state fidelity w/ truncation, + n_ctrl&n_time_steps on first two axes + """ + return vmap(vmap( + _der_state_fid_comp_states_loop,in_axes=(0,0,0,None,None,None)), + in_axes=(0,None,None,None,None,None))( + prop_der,rev_prop_rev,fwd_prop, + comp_states,map_to_closest_unitary,target) + +def _der_state_fid_loop(prop_der,rev_prop_rev,fwd_prop,target): + """Internal loop of derivative of state fidelity w/o truncation""" + + return (target @ rev_prop_rev @ prop_der @ fwd_prop)[0,0] + +#(to be used with additional .T for previous shape) +@jit +def _der_state_fid(prop_der,rev_prop_rev,fwd_prop,target): + """Derivative of state fidelity w/o truncation, + n_ctrl&n_time_steps on first two axes + """ + return vmap(vmap( + _der_state_fid_loop,in_axes=(0,0,0,None)),in_axes=(0,None,None,None))( + prop_der,rev_prop_rev,fwd_prop,target) + + +class StateInfidelitySubspaceJAX(CostFunction): + """See docstring of class w/o JAX. Requires solver with JAX""" + def __init__(self, + solver: solver_algorithms.SolverJAX, + target: matrix.OperatorMatrix, + dims: List[int], + remove: List[int], + label: Optional[List[str]] = None + ): + if not _HAS_JAX: + raise ImportError("JAX not available") + if label is None: + label = ['State Infidelity', ] + super().__init__(solver=solver, label=label) + # assure target is a bra vector + + if target.shape[0] > target.shape[1]: + self.target = target.dag() + else: + self.target = target + + self._target_jnp = jnp.asarray(self.target.data) + self.dims = tuple(dims) + self.remove = tuple(remove) + + def costs(self) -> jnp.float64: + """See base class. """ + final = self.solver.forward_propagators_jnp[-1] + infid = 1. - _state_fidelity_subspace_jnp( + target=self._target_jnp, + propagated_state=final, + dims=self.dims, + remove=self.remove + ) + return infid + + def grad(self) -> jnp.ndarray: + """See base class. """ + derivative_fid = _derivative_state_fidelity_subspace_jnp( + forward_propagators=self.solver.forward_propagators_jnp, + target=self._target_jnp, + reversed_propagators=self.solver.reversed_propagators_jnp, + propagator_derivatives=self.solver.frechet_deriv_propagators_jnp, + dims=self.dims, + remove=self.remove + ) + return -1. * derivative_fid + + +# @partial(jit,static_argnums=(2,3)) +def _state_fidelity_subspace_jnp( + target: jnp.ndarray, + propagated_state: jnp.ndarray, + dims: tuple, + remove: tuple +) -> jnp.float64: + r"""Derivative of the state fidelity on a subspace. + The unused subspace is traced out. + TODO: DID NOT include changes of last master commit -> WONT work with + vectorized density matrices. not as benefitial to have if statements in jax + functions; better create new func for it + """ + + rho = _ptrace_jnp(propagated_state,dims,remove) + + scalar_prod = target @ rho @ jnp.conj(target).T + + if scalar_prod.shape != (1, 1): + raise ValueError('The scalar product is not a scalar. This means that' + 'either the target is not a bra vector or the the ' + 'propagated state not a ket, or both!') + scalar_prod = scalar_prod[0, 0] + scalar_prod_real = scalar_prod.real + assert jnp.abs(scalar_prod - scalar_prod_real) < 1e-5 + return scalar_prod_real + + + +def _ptrace_jnp(mat: jnp.ndarray, + dims: tuple, + remove: tuple) -> jnp.ndarray: + """Partial trace of the matrix""" + + if mat.shape[1] == 1: + mat = (mat @ jnp.conj(mat).T) + + n_dim = len(dims) # number of subspaces + dims = jnp.asarray(dims, dtype=int) + + remove = jnp.sort(jnp.asarray(remove)) + + # indices of subspace that are kept + keep = jnp.array(jnp.where(jnp.arange(n_dim)!=remove)) + + keep=keep[0] + + dims_rm = dims[remove] + dims_keep = dims[keep] + dims = dims + + # 1. Reshape: Split matrix into subspaces + # 2. Transpose: Change subspace/index ordering such that the subspaces + # over which is traced correspond to the first axes + # 3. Reshape: Merge each, subspaces to be removed (A) and to be kept + # (B), common spaces/axes. + # The trace of the merged spaces (A \otimes B) can then be + # calculated as Tr_A(mat) using np.trace for input with + # more than two axes effectively resulting in + # pmat[j,k] = Sum_i mat[i,i,j,k] for all j,k = 0..prod(dims_keep) + pmat = jnp.trace(mat.reshape(jnp.hstack((dims,dims))) + .transpose(jnp.hstack((remove,n_dim + remove, + keep,n_dim +keep))) + .reshape(jnp.hstack((jnp.prod(dims_rm), + jnp.prod(dims_rm), + jnp.prod(dims_keep), + jnp.prod(dims_keep)))) + ) + + return pmat + + +def _derivative_state_fidelity_subspace_jnp( + target: jnp.ndarray, + forward_propagators: jnp.ndarray, + propagator_derivatives: jnp.ndarray, + reversed_propagators: jnp.ndarray, + dims: tuple, + remove: tuple +) -> jnp.ndarray: + """Derivative of the state fidelity on a subspace. + The unused subspace is traced out. + """ + + num_ctrls = len(propagator_derivatives) + num_time_steps = len(propagator_derivatives[0]) + + derivative_fidelity = np.zeros(shape=(num_time_steps, num_ctrls), + dtype=float) + + derivative_fidelity = 2 * jnp.real(_der_state_sub_fid_comp_states( + propagator_derivatives, + reversed_propagators[::-1][1:], + forward_propagators[:-1],dims, + remove,target)).T + + return derivative_fidelity + + +def _der_state_sub_fid_comp_states_loop(prop_der,rev_prop_rev,fwd_prop, + dims,remove,target): + """Internal loop of derivative of state fidelity on subspace""" + return (target @ _ptrace_jnp( + rev_prop_rev@prop_der@fwd_prop@ jnp.conj(fwd_prop[-1]).T,dims,remove)@ + jnp.conj(target).T)[0,0] + +#(to be used with additional .T for previous shape) +# @partial(jit,static_argnums=(3,4)) +def _der_state_sub_fid_comp_states(prop_der,rev_prop_rev,fwd_prop, + dims,remove,target): + """Derivative of state fidelity on subspace, n_ctrl&n_timesteps on first + two axes + """ + return vmap(vmap( + _der_state_sub_fid_comp_states_loop,in_axes=(0,0,0,None,None,None)), + in_axes=(0,None,None,None,None,None))( + prop_der,rev_prop_rev,fwd_prop,dims,remove,target) + + +class StateNoiseInfidelityJAX(CostFunction): + """See docstring of class w/o JAX. Requires solver with JAX""" + + def __init__(self, + solver: solver_algorithms.SchroedingerSMonteCarloJAX, + target: matrix.OperatorMatrix, + label: Optional[List[str]] = None, + computational_states: Optional[List[int]] = None, + rescale_propagated_state: bool = False, + neglect_systematic_errors: bool = True + ): + if not _HAS_JAX: + raise ImportError("JAX not available") + if label is None: + label = ['State Infidelity', ] + super().__init__(solver=solver, label=label) + self.solver = solver + + # assure target is a bra vector + if target.shape[0] > target.shape[1]: + self.target = target.dag() + else: + self.target = target + + self._target_jnp = jnp.array(target.data) + if computational_states is None: + self.computational_states = None + else: + self.computational_states = tuple(computational_states) + self.rescale_propagated_state = rescale_propagated_state + + self.neglect_systematic_errors = neglect_systematic_errors + if target is None and not neglect_systematic_errors: + print('The systematic errors must be neglected if no target is ' + 'set!') + self.neglect_systematic_errors = True + + def costs(self) -> jnp.float64: + """See base class. """ + n_traces = self.solver.noise_trace_generator.n_traces + infidelities = np.zeros((n_traces,)) + + if self.neglect_systematic_errors: + if self.computational_states is None: + target = self.solver.forward_propagators_jnp[-1] + else: + target = _truncate_to_subspace_jnp( + self.solver.forward_propagators_jnp[-1], + self.computational_states, + map_to_closest_unitary=self.rescale_propagated_state + ) + target = jnp.conj(target).T + else: + target = self._target_jnp + + # for i in range(n_traces): + final = self.solver.forward_propagators_noise_jnp[:,-1] + infidelities = 1. - jit(vmap( + _state_fidelity_jnp, + in_axes=(None,0,None,None)),static_argnums=(2,))( + target, + final, + self.computational_states, + self.rescale_propagated_state + ) + + return jnp.mean(jnp.real(infidelities)) + + def grad(self) -> jnp.ndarray: + """See base class. """ + raise NotImplementedError + + +class OperationInfidelityJAX(CostFunction): + """See docstring of class w/o JAX. Requires solver with JAX""" + + def __init__(self, + solver: solver_algorithms.SolverJAX, + target: matrix.OperatorMatrix, + fidelity_measure: str = 'entanglement', + super_operator_formalism: bool = False, + label: Optional[List[str]] = None, + computational_states: Optional[List[int]] = None, + map_to_closest_unitary: bool = False + ): + if not _HAS_JAX: + raise ImportError("JAX not available") + if label is None: + if fidelity_measure == 'entanglement': + label = ['Entanglement Infidelity', ] + else: + label = ['Operator Infidelity', ] + + super().__init__(solver=solver, label=label) + self.target = target + self._target_jnp = jnp.array(target.data) + if computational_states is None: + self.computational_states = None + else: + self.computational_states = tuple(computational_states) + self.map_to_closest_unitary = map_to_closest_unitary + + if fidelity_measure == 'entanglement': + self.fidelity_measure = fidelity_measure + else: + raise NotImplementedError('Only the entanglement fidelity is ' + 'currently supported.') + + self.super_operator = super_operator_formalism + + def costs(self) -> float: + """Calculates the costs by the selected fidelity measure. """ + final = self.solver.forward_propagators_jnp[-1] + + if self.fidelity_measure == 'entanglement' and self.super_operator: + infid = 1 - _entanglement_fidelity_super_operator_jnp( + self._target_jnp, + final, + jnp.sqrt(final.shape[0]).astype(int), + self.computational_states, + ) + elif self.fidelity_measure == 'entanglement': + infid = 1 - _entanglement_fidelity_jnp( + self._target_jnp, + final, + self.computational_states, + self.map_to_closest_unitary + ) + else: + raise NotImplementedError('Only the entanglement fidelity is ' + 'implemented in this version.') + return jnp.real(infid) + + + def grad(self) -> jnp.ndarray: + """Calculates the derivatives of the selected fidelity measure with + respect to the control amplitudes. """ + if self.fidelity_measure == 'entanglement' and self.super_operator: + derivative_fid = _deriv_entanglement_fid_sup_op_with_du_jnp( + self._target_jnp, + self.solver.forward_propagators_jnp, + self.solver.frechet_deriv_propagators_jnp, + self.solver.reversed_propagators_jnp, + jnp.sqrt(self.solver.forward_propagators_jnp.shape[1]).astype(int), + self.computational_states, + ) + elif self.fidelity_measure == 'entanglement': + derivative_fid = _derivative_entanglement_fidelity_with_du_jnp( + self._target_jnp, + self.solver.forward_propagators_jnp, + self.solver.frechet_deriv_propagators_jnp, + self.solver.reversed_propagators_jnp, + self.computational_states, + self.map_to_closest_unitary + ) + else: + raise NotImplementedError('Only the average and entanglement' + 'fidelity is implemented in this ' + 'version.') + return -1 * jnp.real(derivative_fid) + + +class OperationNoiseInfidelityJAX(CostFunction): + """See docstring of class w/o JAX. Requires solver with JAX""" + + def __init__(self, + solver: solver_algorithms.SchroedingerSMonteCarloJAX, + target: Optional[matrix.OperatorMatrix], + label: Optional[List[str]] = None, + fidelity_measure: str = 'entanglement', + computational_states: Optional[List[int]] = None, + map_to_closest_unitary: bool = False, + neglect_systematic_errors: bool = True): + if not _HAS_JAX: + raise ImportError("JAX not available") + if label is None: + label = ['Operator Noise Infidelity'] + super().__init__(solver=solver, label=label) + self.solver = solver + self.target = target + + self._target_jnp = jnp.array(target.data) + if computational_states is None: + self.computational_states = None + else: + self.computational_states = tuple(computational_states) + self.map_to_closest_unitary = map_to_closest_unitary + self.fidelity_measure = fidelity_measure + + self.neglect_systematic_errors = neglect_systematic_errors + if target is None and not neglect_systematic_errors: + print('The systematic errors must be neglected if no target is ' + 'set!') + self.neglect_systematic_errors = True + + def _to_comp_space(self, dynamic_target: jnp.ndarray) -> jnp.ndarray: + """Map an operator to the computational space""" + if self.computational_states is not None: + return _truncate_to_subspace_jnp(dynamic_target, + subspace_indices=self.computational_states, + map_to_closest_unitary=self.map_to_closest_unitary, + ) + else: + return dynamic_target + + def _effective_target(self) -> jnp.ndarray: + if self.neglect_systematic_errors: + return self._to_comp_space(self.solver.forward_propagators_jnp[-1]) + else: + return self._target_jnp + + def costs(self): + """See base class. """ + n_traces = self.solver.noise_trace_generator.n_traces + infidelities = np.zeros((n_traces,)) + + target = self._effective_target() + + if self.fidelity_measure == 'entanglement': + # for i in range(n_traces): + final = self.solver.forward_propagators_noise_jnp[:,-1] + + infidelities = 1 - jit(vmap( + _entanglement_fidelity_jnp, + in_axes=(None,0,None,None)),static_argnums=(2,))( + target,final, + self.computational_states, + self.map_to_closest_unitary + ) + else: + raise NotImplementedError('Only the entanglement fidelity is ' + 'currently implemented in this class.') + + return jnp.mean(jnp.real(infidelities)) + + def grad(self): + """See base class. """ + target = self._effective_target() + + temp = _derivative_entanglement_fidelity_with_du_noise_jnp( + target, + self.solver.forward_propagators_noise_jnp, + self.solver.frechet_deriv_propagators_noise_jnp, + self.solver.reversed_propagators_noise_jnp, + self.computational_states, + self.map_to_closest_unitary + ) + + if self.neglect_systematic_errors: + temp_target = vmap(self._to_comp_space,in_axes=(0,))( + self.solver.forward_propagators_noise_jnp[:,-1]) + + temp += _derivative_entanglement_fidelity_with_du_noise_sys_jnp( + temp_target, + self.solver.forward_propagators_jnp, + self.solver.frechet_deriv_propagators_jnp, + self.solver.reversed_propagators_jnp, + self.computational_states, + self.map_to_closest_unitary + ) + + return jnp.mean(-jnp.real(temp), axis=0) + + +@partial(jit,static_argnums=(4,5)) +def _derivative_entanglement_fidelity_with_du_noise_jnp( + target,fwd_props,prop_der,reversed_props,comp_states,map_to_closest): + """Return derivative of entanglement fidelity with vmap over traces""" + return vmap(_derivative_entanglement_fidelity_with_du_jnp, + in_axes=(None,0,0,0,None,None))( + target,fwd_props,prop_der,reversed_props, + comp_states,map_to_closest) + + +@partial(jit,static_argnums=(4,5)) +def _derivative_entanglement_fidelity_with_du_noise_sys_jnp( + target,fwd_props,prop_der,reversed_props,comp_states,map_to_closest): + """Return additional product rule part of derivative of entanglement + fidelity if systematic errors neglected""" + return vmap(_derivative_entanglement_fidelity_with_du_jnp, + in_axes=(0,None,None,None,None,None))( + target,fwd_props,prop_der,reversed_props, + comp_states,map_to_closest) + + +class LeakageErrorJAX(CostFunction): + """See docstring of class w/o JAX. Requires solver with JAX""" + + def __init__(self, solver: solver_algorithms.SolverJAX, + computational_states: List[int], + label: Optional[List[str]] = None): + if not _HAS_JAX: + raise ImportError("JAX not available") + if label is None: + label = ["Leakage Error", ] + super().__init__(solver=solver, label=label) + if computational_states is None: + self.computational_states = None + else: + self.computational_states = tuple(computational_states) + + def costs(self): + """See base class. """ + final_prop = self.solver.forward_propagators_jnp[-1] + clipped_prop = _truncate_to_subspace_jnp(final_prop, + self.computational_states,map_to_closest_unitary=False) + temp = jnp.conj(clipped_prop).T @ clipped_prop + + # the result should always be positive within numerical accuracy + return max(0, 1 - temp.trace().real / clipped_prop.shape[0]) + + def grad(self): + """See base class. """ + final = self.solver.forward_propagators_jnp[-1] + final_leak_dag = _truncate_to_subspace_jnp(jnp.conj(final).T, + self.computational_states,map_to_closest_unitary=False) + d = final_leak_dag.shape[0] + + derivative_fidelity = -2./d*jnp.real( + _der_leak_comp_states( + self.solver.frechet_deriv_propagators_jnp, + self.solver.reversed_propagators_jnp[::-1][1:], + self.solver.forward_propagators_jnp[:-1], + self.computational_states, + final_leak_dag).T) + + return derivative_fidelity + + +def _der_leak_comp_states_loop(prop_der,rev_prop_rev,fwd_prop,comp_states, + final_leak_dag): + """Internal loop of derivative of leakage""" + return (_truncate_to_subspace_jnp( + rev_prop_rev @ prop_der @ fwd_prop,subspace_indices=comp_states, + map_to_closest_unitary=False) @ final_leak_dag).trace() + +#(to be used with additional .T for previous shape) +@partial(jit,static_argnums=3) +def _der_leak_comp_states(prop_der,rev_prop_rev,fwd_prop,comp_states, + final_leak_dag): + """Derivative of leakage, n_ctrl&n_timesteps on first two axes""" + return vmap(vmap(_der_leak_comp_states_loop,in_axes=(0,0,0,None,None)), + in_axes=(0,None,None,None,None))( + prop_der,rev_prop_rev,fwd_prop,comp_states,final_leak_dag) + + +class IncoherentLeakageErrorJAX(CostFunction): + """See docstring of class w/o JAX. Requires solver with JAX""" + + def __init__(self, solver: solver_algorithms.SchroedingerSMonteCarloJAX, + computational_states: List[int], + label: Optional[List[str]] = None): + if not _HAS_JAX: + raise ImportError("JAX not available") + if label is None: + label = ["Incoherent Leakage Error", ] + super().__init__(solver=solver, label=label) + self.solver = solver + if computational_states is None: + self.computational_states = None + else: + self.computational_states = tuple(computational_states) + + def costs(self): + """See base class. """ + final_props = self.solver.forward_propagators_noise_jnp[:,-1] + + clipped_props = vmap(_truncate_to_subspace_jnp,in_axes=(0,None,None))( + final_props,self.computational_states,False) + + result = 1-jnp.real( + jnp.trace(jnp.transpose(jnp.conj(clipped_props),axes=(0,2,1))@ + clipped_props,axis1=1,axis2=2))/len( + self.computational_states) + + return jnp.mean(result) + + def grad(self): + """See base class. """ + raise NotImplementedError('Derivatives only implemented for the ' + 'coherent leakage.') + + +class LeakageLiouvilleJAX(CostFunction): + """See docstring of class w/o JAX. Requires solver with JAX""" + + def __init__(self, solver: solver_algorithms.SolverJAX, + computational_states: List[int], + label: Optional[List[str]] = None, + verbose: int = 0): + if not _HAS_JAX: + raise ImportError("JAX not available") + if label is None: + label = ["Leakage Error Lindblad", ] + super().__init__(solver=solver, label=label) + + self.computational_states = tuple(computational_states) + dim = self.solver.h_ctrl[0].shape[0] + self.dim_comp = len(self.computational_states) + self.verbose = verbose + # operator_class = type(self.solver.h_ctrl[0]) + + # create projectors + projector_comp = np.diag(np.ones([dim, ], dtype=complex)) + projector_leakage = np.diag(np.ones([dim, ], dtype=complex)) + + for state in computational_states: + projector_leakage[state, state] = 0 + projector_comp -= projector_leakage + + # vectorize projectors + self.projector_leakage_bra = jnp.asarray(ket_vectorize_density_matrix( + projector_leakage).transpose()) + + self.projector_comp_ket = jnp.asarray( + ket_vectorize_density_matrix(projector_comp)) + + + def costs(self): + """See base class. """ + leakage = (1 / self.dim_comp) * ( + self.projector_leakage_bra + @ self.solver.forward_propagators_jnp[-1] + @ self.projector_comp_ket + ) + + if self.verbose > 0: + print('leakage:') + print(leakage[0, 0]) + + # the result should always be positive within numerical accuracy + return leakage.real[0] + + def grad(self): + """See base class. """ + raise NotImplementedError('The derivative of the cost function ' + 'LeakageLiouville has not been implemented' + 'yet.') + + + + +@deprecated +def derivative_entanglement_fidelity( + control_hamiltonians: List[matrix.OperatorMatrix], + forward_propagators: List[matrix.OperatorMatrix], + reversed_propagators: List[matrix.OperatorMatrix], + delta_t: List[float], + target_unitary: matrix.OperatorMatrix) -> np.ndarray: + """ + Derivative of the entanglement fidelity using the grape approximation. + + dU / du = -i delta_t H_ctrl U + + Parameters + ---------- + control_hamiltonians: List[ControlMatrix], len: num_ctrl + The control hamiltonians of the simulation. + + forward_propagators: List[ControlMatrix], len: num_t +1 + The forward propagators calculated in the systems simulation. + + reversed_propagators: List[ControlMatrix] + The reversed propagators calculated in the systems simulation. + + delta_t: List[float], len: num_t + The durations of the time steps. + + target_unitary: ControlMatrix + The target unitary evolution. + + Returns + ------- + derivative_fidelity: np.ndarray, shape: (num_t, num_ctrl) + The derivatives of the entanglement fidelity. + + """ + target_unitary_dag = target_unitary.dag(do_copy=True) + trace = np.conj(((forward_propagators[-1] * target_unitary_dag).tr())) + num_ctrls = len(control_hamiltonians) + num_time_steps = len(delta_t) + d = target_unitary.shape[0] + + derivative_fidelity = np.zeros(shape=(num_time_steps, num_ctrls), + dtype=complex) + + for ctrl in range(num_ctrls): + for t in range(num_time_steps): + # we take the imaginary part because we took a factor of i out + derivative_fidelity[t, ctrl] = 2 / d / d * delta_t * np.imag( + trace * (reversed_propagators[::-1][t + 1] + * control_hamiltonians[ctrl] + * forward_propagators[t + 1] + * target_unitary_dag).tr()) + return derivative_fidelity + + +@needs_refactoring +def averge_gate_fidelity(unitary: matrix.OperatorMatrix, + target_unitary: matrix.OperatorMatrix): + """ + Average gate fidelity. + + Parameters + ---------- + unitary: ControlMatrix + The evolution matrix of the system. + + target_unitary: ControlMatrix + The target unitary to which the evolution is compared. Returns @@ -1998,8 +3336,7 @@ def default_set_orthorgonal(dim: int) -> List[matrix.OperatorMatrix]: @deprecated def derivative_average_gate_fidelity(control_hamiltonians, propagators, - propagators_past, delta_t, - target_unitary): + propagators_past, delta_t, target_unitary): """ The derivative of the average gate fidelity. """ @@ -2018,13 +3355,13 @@ def derivative_average_gate_fidelity(control_hamiltonians, propagators, dtype=complex) for ctrl in range(num_ctrls): for t in range(num_time_steps): - bkwd_prop_target = propagators_future[t + 1].dag() * target_unitary + bkwd_prop_target = propagators_future[t+1].dag() * target_unitary temp = 0 for ort in orthogonal_operators: lambda_ = bkwd_prop_target * ort.dag(do_copy=True) lambda_ *= bkwd_prop_target.dag() - rho = propagators_past[t + 1] * ort - rho *= propagators_past[t + 1].dag() + rho = propagators_past[t+1] * ort + rho *= propagators_past[t+1].dag() # everything rewritten to operate in place temp_mat2 = control_hamiltonians[t, ctrl] * rho temp_mat2 -= rho * control_hamiltonians[t, ctrl] @@ -2033,9 +3370,7 @@ def derivative_average_gate_fidelity(control_hamiltonians, propagators, temp_mat *= delta_t temp_mat *= temp_mat2 temp += temp_mat.tr() - # temp += (lambda_ * -1j * delta_t * ( - # control_hamiltonians[t, ctrl] * rho - # - rho * control_hamiltonians[t, ctrl])).tr() + derivative_fidelity[t, ctrl] = temp / (dim ** 2 * (dim + 1)) return derivative_fidelity @@ -2074,3 +3409,545 @@ def derivative_average_gate_fid_with_du(propagators, propagators_past, temp += lambda_.tr() derivative_fidelity[t, ctrl] = temp / (dim ** 2 * (dim + 1)) return derivative_fidelity + + +############################################################################### + +class OperationInfidelityJAXSpecial(OperationInfidelityJAX): + """ + """ + def __init__(self, + solver: solver_algorithms.Solver, + target: matrix.OperatorMatrix, + rot_frame_ang_freq: float, + fidelity_measure: str = 'entanglement', + super_operator_formalism: bool = False, + label: Optional[List[str]] = None, + computational_states: Optional[List[int]] = None, + map_to_closest_unitary: bool = False + ): + + super().__init__(solver=solver, + target=target, + fidelity_measure=fidelity_measure, + super_operator_formalism=super_operator_formalism, + label=label, + computational_states=computational_states, + map_to_closest_unitary=map_to_closest_unitary) + + + self.end_time = sum(solver.transferred_time)-0.5*solver.transferred_time[-1] + self.freq = rot_frame_ang_freq + + def rot_op_4(self,time): + return jnp.array([[np.exp(-1j*2*self.freq/2*time),0,0,0], + [0,np.exp(0*self.freq/2*time),0,0], + [0,0,np.exp(0*self.freq/2*time),0], + [0,0,0,np.exp(1j*2*self.freq/2*time)]]) + + def rot_op_4_der_t(self,time): + return 1j*2*self.freq/2*jnp.array([[-np.exp(-1j*2*self.freq/2*time),0,0,0], + [0,np.exp(0*self.freq/2*time),0,0], + [0,0,np.exp(0*self.freq/2*time),0], + [0,0,0,np.exp(1j*2*self.freq/2*time)]]) + + def costs(self,time_fact) -> float: + """Calculates the costs by the selected fidelity measure. """ + final = self.solver.forward_propagators_jnp[-1] + + if self.fidelity_measure == 'entanglement' and self.super_operator: + # raise NotImplementedError + infid = 1 - _entanglement_fidelity_super_operator_jnp( + self._target_jnp, + final, + jnp.sqrt(final.shape[0]).astype(int), + self.computational_states, + + ) + elif self.fidelity_measure == 'entanglement': + infid = 1 - _entanglement_fidelity_jnp( + self.rot_op_4(time_fact*self.end_time)@self._target_jnp, + final, + self.computational_states, + self.map_to_closest_unitary + ) + else: + raise NotImplementedError('Only the entanglement fidelity is ' + 'implemented in this version.') + return jnp.real(infid) + + + def grad(self, time_fact) -> jnp.ndarray: + """Calculates the derivatives of the selected fidelity measure with + respect to the control amplitudes. """ + if self.fidelity_measure == 'entanglement' and self.super_operator: + raise NotImplementedError + derivative_fid = _deriv_entanglement_fid_sup_op_with_du_jnp( + self._target_jnp, + self.solver.forward_propagators_jnp, + self.solver.frechet_deriv_propagators_jnp, + self.solver.reversed_propagators_jnp, + jnp.sqrt(self.solver.forward_propagators_jnp.shape[1]).astype(int), + self.computational_states, + ) + elif self.fidelity_measure == 'entanglement': + # raise NotImplementedError + derivative_fid = _derivative_entanglement_fidelity_with_du_jnp( + self.rot_op_4(time_fact*self.end_time)@self._target_jnp, + self.solver.forward_propagators_jnp, + self.solver.frechet_deriv_propagators_jnp, + self.solver.reversed_propagators_jnp, + self.computational_states, + self.map_to_closest_unitary + ) + else: + raise NotImplementedError('Only the average and entanglement' + 'fidelity is implemented in this ' + 'version.') + return -1 * jnp.real(derivative_fid) + + def der_time_fact(self,time_fact): + + + if self.fidelity_measure == 'entanglement' and self.super_operator: + raise NotImplementedError + + elif self.fidelity_measure == 'entanglement': + derivative_fid = _derivative_entanglement_fidelity_with_dtf_jnp( + self.rot_op_4(time_fact*self.end_time)@self._target_jnp, + self.end_time*self.rot_op_4_der_t(time_fact*self.end_time)@self._target_jnp, + self.solver.forward_propagators_jnp, + self.computational_states, + self.map_to_closest_unitary + ) + + else: + raise NotImplementedError('Only the average and entanglement' + 'fidelity is implemented in this ' + 'version.') + return -1 * np.real(derivative_fid) + + +@partial(jit,static_argnums=(3,4)) +def _derivative_entanglement_fidelity_with_dtf_jnp( + target: jnp.ndarray, + target_der: jnp.ndarray, + forward_propagators_jnp: jnp.ndarray, + computational_states: Optional[tuple] = None, + map_to_closest_unitary: bool = False +) -> jnp.ndarray: + """ + + """ + target_unitary_dag = jnp.conj(target).T + if computational_states is not None: + trace = jnp.conj( + ((_truncate_to_subspace_jnp(forward_propagators_jnp[-1], + computational_states, + map_to_closest_unitary=map_to_closest_unitary) + @ target_unitary_dag).trace()) + ) + else: + trace = jnp.conj(((forward_propagators_jnp[-1] @ target_unitary_dag).trace())) + # num_ctrls,num_time_steps = propagator_derivatives_jnp.shape[:2] + d = target.shape[0] + + # here we need to take the real part. + if computational_states: + derivative_fidelity = 2/d/d * jnp.real(trace*( + jnp.conj(target_der).T @ _truncate_to_subspace_jnp(forward_propagators_jnp[-1], + computational_states, + map_to_closest_unitary)).trace()) + + else: + derivative_fidelity = 2/d/d * jnp.real(trace*( + jnp.conj(target_der).T @ forward_propagators_jnp[-1]).trace()) + + return derivative_fidelity + + +class OperationInfidelityJAXSpecial2(OperationInfidelityJAX): + """ + """ + def __init__(self, + solver: solver_algorithms.Solver, + target: matrix.OperatorMatrix, + # rot_frame_ang_freq: float, + fidelity_measure: str = 'entanglement', + super_operator_formalism: bool = False, + label: Optional[List[str]] = None, + computational_states: Optional[List[int]] = None, + map_to_closest_unitary: bool = False + ): + + super().__init__(solver=solver, + target=target, + fidelity_measure=fidelity_measure, + super_operator_formalism=super_operator_formalism, + label=label, + computational_states=computational_states, + map_to_closest_unitary=map_to_closest_unitary) + + + # self.end_time = sum(solver.transferred_time)-0.5*solver.transferred_time[-1] + # self.freq = rot_frame_ang_freq + + # def rot_op_4(self,time): + # return jnp.array([[np.exp(-1j*2*self.freq/2*time),0,0,0], + # [0,np.exp(0*self.freq/2*time),0,0], + # [0,0,np.exp(0*self.freq/2*time),0], + # [0,0,0,np.exp(1j*2*self.freq/2*time)]]) + + # def rot_op_4_der_t(self,time): + # return 1j*2*self.freq/2*jnp.array([[-np.exp(-1j*2*self.freq/2*time),0,0,0], + # [0,np.exp(0*self.freq/2*time),0,0], + # [0,0,np.exp(0*self.freq/2*time),0], + # [0,0,0,np.exp(1j*2*self.freq/2*time)]]) + + def costs(self) -> float: + """Calculates the costs by the selected fidelity measure. """ + final = self.solver.forward_propagators_jnp[-1] + + if self.fidelity_measure == 'entanglement' and self.super_operator: + raise NotImplementedError + infid = 1 - _entanglement_fidelity_super_op_jnp_zphase( + self._target_jnp, + final, + jnp.sqrt(final.shape[0]).astype(int), + self.computational_states, + ) + elif self.fidelity_measure == 'entanglement': + infid = 1 - _entanglement_fidelity_jnp_zphase( + self._target_jnp, + final, + self.computational_states, + self.map_to_closest_unitary + ) + else: + raise NotImplementedError('Only the entanglement fidelity is ' + 'implemented in this version.') + return jnp.real(infid) + + + def grad(self) -> jnp.ndarray: + raise NotImplementedError + + +@jit +def _rot_op_p(ph_arr): + return jnp.diagflat(jnp.exp(1j*(ph_arr[0].real*jnp.array([1,1,-1,-1])+ph_arr[1].real*jnp.array([1,-1,1,-1])))) + +@partial(jit,static_argnums=(3,4)) +def _entanglement_infidelity_jnp_zphase_wrapper(ph_arr,target,prop,comp_states,to_closest): + return 1-_entanglement_fidelity_jnp(_rot_op_p(ph_arr)@target,prop,comp_states,to_closest) + +import jax.scipy.optimize as jsco + +@partial(jit,static_argnums=(2,3)) +def _entanglement_fidelity_jnp_zphase(target,prop,comp_states,to_closest): + res = jsco.minimize(_entanglement_infidelity_jnp_zphase_wrapper, + x0=jnp.array([0.,0.],dtype=jnp.float64),args=(target,prop,comp_states,to_closest), + method="BFGS") + return 1-res.fun + +@partial(jit,static_argnums=(2,3)) +def _entanglement_fidelity_jnp_zphase_returnopt(target,prop,comp_states,to_closest): + res = jsco.minimize(_entanglement_infidelity_jnp_zphase_wrapper, + x0=jnp.array([0.,0.],dtype=jnp.float64),args=(target,prop,comp_states,to_closest), + method="BFGS") + return 1-res.fun, res.x + +@partial(jit,static_argnums=(3,4,5)) +def _entanglement_infidelity_super_op_jnp_zphase_wrapper(ph_arr,target,prop,dim_prop,comp_states): + return 1-_entanglement_fidelity_super_operator_jnp(_rot_op_p(ph_arr)@target,prop,dim_prop,comp_states) + +@partial(jit,static_argnums=(2,3,4)) +def _entanglement_fidelity_super_op_jnp_zphase(target,prop,dim_prop,comp_states): + res = jsco.minimize(_entanglement_infidelity_super_op_jnp_zphase_wrapper, + x0=jnp.array([0.,0.],dtype=jnp.float64),args=(target,prop,dim_prop,comp_states), + method="BFGS") + return 1-res.fun + + +class TwoQubitEquivalenceClass(CostFunction): + """ + + """ + def __init__(self, + solver: solver_algorithms.Solver, + local_invariants: np.ndarray, + super_operator_formalism: bool = False, + label: Optional[List[str]] = None, + computational_states: Optional[List[int]] = None, + map_to_closest_unitary: bool = False + ): + if label is None: + label = ['Two Qubit Equivalence Class', ] + + super().__init__(solver=solver, label=label) + self.target_g = local_invariants + self._target_g_jnp = jnp.array(self.target_g) + self._target_g_c_jnp = jnp.array([self.target_g[0]+1j*self.target_g[1],self.target_g[2]]) + if computational_states is None: + self.computational_states = None + else: + self.computational_states = tuple(computational_states) + self.map_to_closest_unitary = map_to_closest_unitary + + # if fidelity_measure == 'entanglement': + # self.fidelity_measure = fidelity_measure + # else: + # raise NotImplementedError('Only the entanglement fidelity is ' + # 'currently supported.') + + self.super_operator = super_operator_formalism + + self._q_mat = 1/2**0.5*jnp.array([[1,0,0,1j], + [0,1j,1,0], + [0,1j,-1,0], + [1,0,0,-1j]]) + self._qq = jnp.conj(self._q_mat)@jnp.conj(self._q_mat).T + + + def costs(self) -> float: + """Calculates the costs by the selected fidelity measure. """ + final = self.solver.forward_propagators_jnp[-1] + + if self.computational_states is not None: + final = _truncate_to_subspace_jnp(final,self.computational_states,self.map_to_closest_unitary) + + m = _calc_m(final,self._q_mat) + g_arr_c = _calc_g_c(m,final) + l_sq_abs = jnp.sum(jnp.abs(g_arr_c-self._target_g_c_jnp)**2)**0.5 + return l_sq_abs + + + def grad(self) -> jnp.ndarray: + """Calculates the derivatives of the selected fidelity measure with + respect to the control amplitudes. """ + + final = self.solver.forward_propagators_jnp[-1] + + rev_prop_rev = self.solver.reversed_propagators_jnp[::-1][1:] + prop_der = self.solver.frechet_deriv_propagators_jnp + fwd_props = self.solver.forward_propagators_jnp[:-1] + + if self.computational_states is not None: + final = _truncate_to_subspace_jnp(final,self.computational_states,self.map_to_closest_unitary) + rpr_pd_fp = _truncate_to_subspace_jnp_dvmap(rev_prop_rev@prop_der@fwd_props,self.computational_states,self.map_to_closest_unitary) + + m = _calc_m(final,self._q_mat) + g_arr_c = _calc_g_c(m,final) + l_sq_abs = jnp.sum(jnp.abs(g_arr_c-self._target_g_c_jnp)**2)**0.5 + + derivative_lsq = _dlsq_du_c(m,self._q_mat,self._qq, + rpr_pd_fp, + final, + self._target_g_c_jnp, + g_arr_c, + l_sq_abs).T + + # should be shape: (num_t, num_ctrl) + return jnp.real(derivative_lsq) + +@jit +def _calc_m(arr,q): + ub = (jnp.conj(q).T)@arr@q + return (ub.T)@ub + +@jit +def _g_to_s_d(g_arr): + z_arr = jnp.roots([1,-g_arr[2],(4*(g_arr[0]**2+g_arr[1]**2)**0.5-1),(g_arr[2]-4*g_arr[0])]) + return jnp.pi-jnp.arccos(z_arr[0])-jnp.arccos(z_arr[2]), g_arr[2]*(g_arr[0]**2+g_arr[1]**2)**0.5-g_arr[0] + +@jit +def _calc_g_c(m,u): + g1 = 1/16 * jnp.trace(m)**2 + g3 = 1/4 * (jnp.trace(m)**2-jnp.trace(m@m)) + return jnp.asarray([g1,g3]) * jnp.linalg.det(jnp.conj(u).T) + +@jit +def _dm_dukj(q,qq,rpr_pd_fp,final): + return q.T@(rpr_pd_fp).T@qq@final@q+\ + q.T@final.T@qq@rpr_pd_fp@q + +@jit +def _ddetU_dukj(U,dUdukj): + return jnp.linalg.det(U)*jnp.trace(jnp.linalg.inv(U)@dUdukj) + +@jit +def _dg12_dukj(m,q,qq,rpr_pd_fp,final): + return 1/16*(2*m.trace()*_dm_dukj(q,qq,rpr_pd_fp,final).trace()*jnp.linalg.det(jnp.conj(final).T) + +m.trace()**2*_ddetU_dukj(jnp.conj(final).T,jnp.conj(rpr_pd_fp).T)) + +@jit +def _dg3_dukj(m,q,qq,rpr_pd_fp,final): + return 0.25*(2*(m.trace()*_dm_dukj(q,qq,rpr_pd_fp,final).trace()- + (m@_dm_dukj(q,qq,rpr_pd_fp,final)).trace())*jnp.linalg.det(jnp.conj(final).T) + +(m.trace()**2-(m@m).trace())*_ddetU_dukj(jnp.conj(final).T,jnp.conj(rpr_pd_fp).T)) + +@jit +def _dlsq_dukj_c(m,q,qq,rpr_pd_fp,final,g0_arr_c,g_arr_c,l_sq_abs): + dg12 = _dg12_dukj(m,q,qq,rpr_pd_fp,final) + dg3 = _dg3_dukj(m,q,qq,rpr_pd_fp,final) + return 1/l_sq_abs*jnp.sum(jnp.real((g_arr_c-g0_arr_c)*jnp.conj(jnp.array([dg12,dg3])))) + + +#(to be used with additional .T for previous shape) +@jit +def _dlsq_du_c(m,q,qq,rpr_pd_fp,final,g0_arr_c,g_arr_c,l_sq_abs): + return vmap(vmap(_dlsq_dukj_c,in_axes=(None,None,None,0,None,None,None,None)), + in_axes=(None,None,None,0,None,None,None,None))( + m,q,qq,rpr_pd_fp,final,g0_arr_c,g_arr_c,l_sq_abs) + +@partial(jit,static_argnums=(1,2)) +def _truncate_to_subspace_jnp_vmap(arr,subspace_indices,map_to_closest_unitary): + return vmap(_truncate_to_subspace_jnp,in_axes=(0,None,None))(arr,subspace_indices,map_to_closest_unitary) + +@partial(jit,static_argnums=(1,2)) +def _truncate_to_subspace_jnp_dvmap(arr,subspace_indices,map_to_closest_unitary): + return vmap(_truncate_to_subspace_jnp_vmap,in_axes=(0,None,None))(arr,subspace_indices,map_to_closest_unitary) + + + +############################################################################### + +class OperationInfidelityJAXzphase1Q(OperationInfidelityJAX): + """ + """ + def __init__(self, + solver: solver_algorithms.Solver, + target: matrix.OperatorMatrix, + # rot_frame_ang_freq: float, + fidelity_measure: str = 'entanglement', + super_operator_formalism: bool = False, + label: Optional[List[str]] = None, + computational_states: Optional[List[int]] = None, + map_to_closest_unitary: bool = False, + basis_change_op = None + ): + + super().__init__(solver=solver, + target=target, + fidelity_measure=fidelity_measure, + super_operator_formalism=super_operator_formalism, + label=label, + computational_states=computational_states, + map_to_closest_unitary=map_to_closest_unitary) + + self.basis_change_op = basis_change_op + + def costs(self) -> float: + """Calculates the costs by the selected fidelity measure. """ + if self.basis_change_op is not None: + final = self.basis_change_op @ self.solver.forward_propagators_jnp[-1] + else: + final = self.solver.forward_propagators_jnp[-1] + + if self.fidelity_measure == 'entanglement' and self.super_operator: + raise NotImplementedError + # infid = 1 - _entanglement_fidelity_super_op_jnp_zphase_1q( + # self._target_jnp, + # final, + # jnp.sqrt(final.shape[0]).astype(int), + # self.computational_states, + # ) + elif self.fidelity_measure == 'entanglement': + infid = 1 - _entanglement_fidelity_jnp_zphase_1q( + self._target_jnp, + final, + self.computational_states, + self.map_to_closest_unitary + ) + else: + raise NotImplementedError('Only the entanglement fidelity is ' + 'implemented in this version.') + return jnp.real(infid) + + + def grad(self) -> jnp.ndarray: + raise NotImplementedError + + +@jit +def _rot_op_p_1q(ph_arr): + return jnp.diagflat(jnp.exp(1j*(ph_arr[0].real*jnp.array([1,-1])))) + +@partial(jit,static_argnums=(3,4)) +def _entanglement_infidelity_jnp_zphase_wrapper_1q(ph_arr,target,prop,comp_states,to_closest): + return 1-_entanglement_fidelity_jnp(_rot_op_p_1q(ph_arr)@target,prop,comp_states,to_closest) + +@partial(jit,static_argnums=(2,3)) +def _entanglement_fidelity_jnp_zphase_1q(target,prop,comp_states,to_closest): + res = jsco.minimize(_entanglement_infidelity_jnp_zphase_wrapper_1q, + x0=jnp.array([0.,],dtype=jnp.float64),args=(target,prop,comp_states,to_closest), + method="BFGS") + return 1-res.fun + + + +class LeakageErrorBaseChangeJAX(CostFunction): + """See docstring of class w/o JAX. Requires solver with JAX""" + + def __init__(self, solver: solver_algorithms.SolverJAX, + computational_states: List[int], + label: Optional[List[str]] = None, + basis_change_op = None + ): + + if not _HAS_JAX: + raise ImportError("JAX not available") + if label is None: + label = ["Leakage Error", ] + super().__init__(solver=solver, label=label) + if computational_states is None: + self.computational_states = None + else: + self.computational_states = tuple(computational_states) + + self.basis_change_op = basis_change_op + + def costs(self): + """See base class. """ + if self.basis_change_op is not None: + final_prop = self.basis_change_op @ self.solver.forward_propagators_jnp[-1] + else: + final_prop = self.solver.forward_propagators_jnp[-1] + + clipped_prop = _truncate_to_subspace_jnp(final_prop, + self.computational_states,map_to_closest_unitary=False) + temp = jnp.conj(clipped_prop).T @ clipped_prop + + # the result should always be positive within numerical accuracy + return max(0, 1 - temp.trace().real / clipped_prop.shape[0]) + + def grad(self): + """See base class. """ + if self.basis_change_op is not None: + final = self.basis_change_op @ self.solver.forward_propagators_jnp[-1] + else: + final = self.solver.forward_propagators_jnp[-1] + + final_leak_dag = _truncate_to_subspace_jnp(jnp.conj(final).T, + self.computational_states,map_to_closest_unitary=False) + d = final_leak_dag.shape[0] + + if self.basis_change_op is not None: + derivative_fidelity = -2./d*jnp.real( + _der_leak_comp_states( + self.basis_change_op @ self.solver.frechet_deriv_propagators_jnp, + self.basis_change_op @ self.solver.reversed_propagators_jnp[::-1][1:], + self.basis_change_op @ self.solver.forward_propagators_jnp[:-1], + self.computational_states, + final_leak_dag).T) + + else: + derivative_fidelity = -2./d*jnp.real( + _der_leak_comp_states( + self.solver.frechet_deriv_propagators_jnp, + self.solver.reversed_propagators_jnp[::-1][1:], + self.solver.forward_propagators_jnp[:-1], + self.computational_states, + final_leak_dag).T) + + return derivative_fidelity diff --git a/qopt/matrix.py b/qopt/matrix.py index c77968d..88bb5b9 100644 --- a/qopt/matrix.py +++ b/qopt/matrix.py @@ -1451,3 +1451,668 @@ def closest_unitary(matrix: OperatorMatrix): left_singular_vec, __, right_singular_vec_h = scipy.linalg.svd( matrix.data) return type(matrix)(left_singular_vec.dot(right_singular_vec_h)) + +############################################################################### + +try: + import jax.numpy as jnp + from jax import jit, vmap + import jax + _HAS_JAX = True +except ImportError: + from unittest import mock + jit = mock.Mock() + jnp = mock.Mock() + vmap = mock.Mock() + jax = mock.Mock() + _HAS_JAX = False + + +class DenseOperatorJAX(OperatorMatrix): + """See docstring of class w/o JAX. Works with jnp arrays""" + + __slots__ = ("data",) + + def __init__( + self, + obj: Union[Qobj, np.ndarray, jnp.ndarray, + sp.csr_matrix, 'DenseOperator']) \ + -> None: + if not _HAS_JAX: + raise ImportError("JAX not available") + super().__init__() + self.data = None + if isinstance(obj,jnp.ndarray): + self.data = obj.astype(jnp.complex128) + elif type(obj) is DenseOperatorJAX: + self.data = obj.data + elif type(obj) is DenseOperator: + self.data = obj.data.astype(jnp.complex128) + elif type(obj) is np.ndarray: + self.data = obj.astype(np.complex128) + elif type(obj) is Qobj: + self.data = jnp.array(obj.data.todense(),dtype=jnp.complex128) + elif type(obj) is sp.csr_matrix: + self.data = obj.toarray() + self.data = jnp.array(self.data,dtype=jnp.complex128) + else: + raise ValueError("Data of this type can not be broadcasted into a " + "dense control matrix. Type: " + str(type(obj))) + + def copy(self): + """See base class. """ + copy_ = DenseOperatorJAX(jnp.array(self.data,copy=True)) + # numpy copies are deep + return copy_ + + def __imul__( + self, + other: Union['DenseOperatorJAX', 'DenseOperator', complex, float, + int, np.generic, jnp.ndarray] + ) -> 'DenseOperatorJAX': + """See base class. """ + + if type(other) == DenseOperatorJAX or type(other) == DenseOperator: + jnp.matmul(self.data, other.data, out=self.data) + elif isinstance(other,jnp.ndarray) or isinstance(other,np.ndarray): + jnp.matmul(self.data, other, out=self.data) + elif type(other) in VALID_SCALARS: + self.data *= other + else: + raise NotImplementedError(str(type(other))) + return self + + def __mul__( + self, + other: Union['DenseOperatorJAX', 'DenseOperator', complex, float, + int, np.generic, jnp.ndarray] + ) -> 'DenseOperatorJAX': + """See base class. """ + + if type(other) in VALID_SCALARS: + out = self.copy() + out *= other + if type(other) == DenseOperatorJAX or type(other) == DenseOperator: + out = DenseOperatorJAX(jnp.matmul(self.data, other.data)) + elif type(other) == np.ndarray: + out = DenseOperatorJAX(jnp.matmul(self.data, jnp.array(other))) + elif isinstance(other,jnp.ndarray): + if other.shape==(): + out = DenseOperatorJAX(self.data*other) + else: + out = DenseOperatorJAX(jnp.matmul(self.data, jnp.array(other))) + else: + raise NotImplementedError(str(type(other))) + return out + + def __rmul__( + self, + other: Union['DenseOperatorJAX', 'DenseOperator', complex, float, + int, np.generic, jnp.ndarray] + ) -> 'DenseOperatorJAX': + """See base class. """ + + if isinstance(other,jnp.ndarray) or isinstance(other,np.ndarray): + out = DenseOperatorJAX(jnp.matmul(other, self.data)) + elif type(other) in VALID_SCALARS: + out = self.copy() + out *= other + else: + raise NotImplementedError(str(type(other))) + return out + + def __iadd__(self, other: 'DenseOperatorJAX') -> 'DenseOperatorJAX': + """See base class. """ + if type(other) is DenseOperatorJAX: + self.data += other.data + elif isinstance(other,jnp.ndarray) or isinstance(other,np.ndarray): + self.data += other + elif type(other) in VALID_SCALARS: + self.data += other + else: + raise NotImplementedError(str(type(other))) + return self + + def __isub__(self, other: 'DenseOperatorJAX') -> 'DenseOperatorJAX': + """See base class. """ + + if type(other) is DenseOperatorJAX: + self.data -= other.data + elif isinstance(other,jnp.ndarray) or isinstance(other,np.ndarray): + self.data -= other + elif type(other) in VALID_SCALARS: + self.data -= other + else: + raise NotImplementedError(str(type(other))) + return self + + def __truediv__(self, other: 'DenseOperatorJAX') -> 'DenseOperatorJAX': + if isinstance(other, (np.ndarray,jnp.ndarray, *VALID_SCALARS)): + return DenseOperatorJAX(self.data / other) + raise NotImplementedError(str(type(other))) + + def __itruediv__(self, other: 'DenseOperatorJAX') -> 'DenseOperatorJAX': + if isinstance(other, (np.ndarray,jnp.ndarray, *VALID_SCALARS)): + self.data /= other + return self + raise NotImplementedError(str(type(other))) + + def __getitem__(self, index: tuple) -> jnp.complex128: + """See base class. """ + return self.data[index] + + def __setitem__(self, key, value) -> None: + """See base class. """ + self.data = self.data.at[key].set(value) + + def __repr__(self): + """Representation as numpy array. """ + return 'DenseOperatorJAX with data: \n' + self.data.__repr__() + + def dag(self, do_copy: bool = True) -> Optional['DenseOperatorJAX']: + """See base class. """ + if do_copy: + cp = self.copy() + #was additional statement with "out" before, not in jnp? + + cp.data = jnp.conj(cp.data).T + return cp + else: + self.data = jnp.conj(self.data).T + return self + + def conj(self, do_copy: bool = True) -> Optional['DenseOperatorJAX']: + """See base class. """ + if do_copy: + copy = self.copy() + copy.data = jnp.conj(copy.data) + return copy + else: + self.data = jnp.conj(self.data) + return self + + def transpose(self, do_copy: bool = True) -> Optional['DenseOperatorJAX']: + """See base class. """ + if do_copy: + out = self.copy() + else: + out = self + out.data = out.data.transpose() + return out + + def flatten(self) -> jnp.ndarray: + """See base class. """ + return self.data.flatten() + + def norm(self, ord: Union[str, None, int] = 'fro') -> jnp.float64: + """ + Calulates the norm of the matrix. + + Uses the implementation of numpy.linalg.norm. + + Parameters + ---------- + ord: string + Defines the norm which is calculated. Defaults to the Frobenius norm + 'fro'. + + Returns + ------- + norm: float + Norm of the Matrix. + + """ + return jnp.linalg.norm(self.data, ord=ord) + + def tr(self) -> complex: + """See base class. """ + return self.data.trace() + + def ptrace(self, + dims: Sequence[int], + remove: Sequence[int], + do_copy: bool = True) -> 'DenseOperatorJAX': + """ + Partial trace of the matrix. + + If the matrix describes a ket, the corresponding density matrix is + calculated and used for the partial trace. + + This implementation closely follows that of QuTip's qobj._ptrace_dense. + Parameters + ---------- + dims : list of int + Dimensions of the subspaces making up the total space on which + the matrix operates. The product of elements in 'dims' must be + equal to the matrix' dimension. + remove : list of int + The selected subspaces as indices over which the partial trace is + formed. The given indices correspond to the ordering of + subspaces specified in the 'dim' argument. + do_copy : bool, optional + If false, the operation is executed inplace. Otherwise returns + a new instance. Defaults to True. + + Returns + ------- + pmat : OperatorMatrix + The partially traced OperatorMatrix. + + Raises + ------ + AssertionError: + If matrix dimension does not match specified dimensions. + + Examples + -------- + ghz_ket = DenseOperator(np.array([[1,0,0,0,0,0,0,1]]).T) / np.sqrt(2) + ghz_rho = ghz_ket * ghz_ket.dag() + ghz_rho.ptrace(dims=[2,2,2], remove=[0,2]) + DenseOperator with data: + array([[0.5+0.j, 0. +0.j], + [0. +0.j, 0.5+0.j]]) + """ + + if self.shape[1] == 1: + mat = (self * self.dag()).data + else: + mat = self.data + if mat.shape[0] != jnp.prod(dims): + raise AssertionError("Specified dimensions do not match " + "matrix dimension.") + n_dim = len(dims) # number of subspaces + dims = jnp.asarray(dims, dtype=int) + + remove = list(jnp.sort(remove)) + # indices of subspace that are kept + keep = list(set(np.arange(n_dim)) - set(remove)) + + dims_rm = (dims[remove]).tolist() + dims_keep = (dims[keep]).tolist() + dims = list(dims) + + # 1. Reshape: Split matrix into subspaces + # 2. Transpose: Change subspace/index ordering such that the subspaces + # over which is traced correspond to the first axes + # 3. Reshape: Merge each, subspaces to be removed (A) and to be kept + # (B), common spaces/axes. + # The trace of the merged spaces (A \otimes B) can then be + # calculated as Tr_A(mat) using np.trace for input with + # more than two axes effectively resulting in + # pmat[j,k] = Sum_i mat[i,i,j,k] for all j,k = 0..prod(dims_keep) + pmat = jnp.trace(mat.reshape(dims + dims) + .transpose(remove + [n_dim + q for q in remove] + + keep + [n_dim + q for q in keep]) + .reshape([jnp.prod(dims_rm), + jnp.prod(dims_rm), + jnp.prod(dims_keep), + jnp.prod(dims_keep)]) + ) + + if do_copy: + return DenseOperatorJAX(pmat) + else: + self.data = pmat + return self + + def kron(self, other: 'DenseOperatorJAX') -> 'DenseOperatorJAX': + """See base class. """ + if type(other) == DenseOperatorJAX: + out = jnp.kron(self.data, other.data) + elif isinstance(other,jnp.ndarray) or isinstance(other,np.ndarray): + out = jnp.kron(self.data, other) + else: + raise ValueError('The kronecker product of dense control matrices' + 'is not defined for: ' + str(type(other))) + return DenseOperatorJAX(out) + + def _exp_diagonalize(self, tau: complex = 1, + is_skew_hermitian: bool = False) -> 'DenseOperatorJAX': + """ Calculates the matrix exponential by spectral decomposition. + + Refactored version of _spectral_decomp. + + Parameters + ---------- + tau : complex + The matrix is multiplied by tau. + + is_skew_hermitian : bool + If True, the matrix is expected to be skew hermitian. + + Returns + ------- + exp: DenseOperator + Dense operator matrix containing the matrix exponential. + + """ + if is_skew_hermitian: + eig_val, eig_vec = jnp.linalg.eigh(-1j * self.data) + eig_val = 1j * eig_val + else: + eig_val, eig_vec = jnp.linalg.eig(self.data) + + # apply the exponential function to the eigenvalues and invert the + # diagonalization transformation + exp = jnp.einsum('ij,j,kj->ik', eig_vec, jnp.exp(tau * eig_val), + eig_vec.conj()) + + return DenseOperatorJAX(exp) + + def _dexp_diagonalization(self, + direction: 'DenseOperatorJAX', tau: complex = 1, + is_skew_hermitian: bool = False, + compute_expm: bool = False): + """ Calculates the matrix exponential by spectral decomposition. + + Refactored version of _spectral_decomp. + + Parameters + ---------- + direction: DenseOperator + Direction in which the frechet derivative is calculated. Must be of + the same shape as self. + + tau : complex + The matrix is multiplied by tau. + + is_skew_hermitian : bool + If True, the matrix is expected to be skew hermitian. + + compute_expm : bool + If True, the matrix exponential is calculated as well. + + Returns + ------- + exp: DenseOperator + The matrix exponential. Only returned if compute_expm is set to + True. + + dexp: DenseOperator + Frechet derivative of the matrix exponential. + + """ + if is_skew_hermitian: + eig_val, eig_vec = jnp.linalg.eigh(-1j * self.data) + eig_val = 1j * eig_val + else: + eig_val, eig_vec = jnp.linalg.eig(self.data) + + eig_vec_dag = eig_vec.conj().T + + eig_val_cols = eig_val * jnp.ones(self.shape) + eig_val_diffs = eig_val_cols - eig_val_cols.T + + # avoid devision by zero + eig_val_diffs += jnp.eye(self.data.shape[0]) + + omega = (jnp.exp(eig_val_diffs * tau) - 1.) / eig_val_diffs + + # override the false diagonal elements. + np.fill_diagonal(omega, tau) + + direction_transformed = eig_vec @ direction.data @ eig_vec_dag + dk_dalpha = direction_transformed * omega + + exp = jnp.einsum('ij,j,jk->ik', eig_vec, jnp.exp(tau * eig_val), + eig_vec_dag) + # einsum might be less accurate than the @ operator + dv_dalpha = eig_vec_dag @ dk_dalpha @ eig_vec + du_dalpha = exp @ dv_dalpha + + if compute_expm: + return exp, du_dalpha + else: + return du_dalpha + + def spectral_decomposition(self, hermitian: bool = False): + """See base class. """ + if hermitian is False: + eig_val, eig_vec = jax.scipy.linalg.eig(self.data) + else: + eig_val, eig_vec = jax.scipy.linalg.eigh(self.data) + + return eig_val, eig_vec + + def exp(self, tau: complex = 1, + method: str = "spectral", + is_skew_hermitian: bool = False) -> 'DenseOperatorJAX': + """ + Matrix exponential. + + Parameters + ---------- + tau: complex + The matrix is multiplied by tau before calculating the exponential. + + method: string + Numerical method used for the calculation of the matrix + exponential. + Currently the following are implemented: + - 'approx', 'Frechet': use the scipy linalg matrix exponential + - 'first_order': First order taylor approximation + - 'second_order': Second order taylor approximation + - 'third_order': Third order taylor approximation + - 'spectral': Use the self implemented spectral decomposition + + is_skew_hermitian: bool + Only important for the method 'spectral'. If set to True then the + matrix is assumed to be skew hermitian in the spectral + decomposition. + + Returns + ------- + prop: DenseOperator + The matrix exponential. + + Raises + ------ + NotImplementedError: + If the method given as parameter is not implemented. + + """ + + if method == "spectral": + prop = self._exp_diagonalize(tau=tau, + is_skew_hermitian=is_skew_hermitian) + + elif method in ["approx", "Frechet"]: + prop = jax.scipy.linalg.expm(self.data * tau) + + elif method == "first_order": + prop = jnp.eye(self.data.shape[0]) + self.data * tau + + elif method == "second_order": + prop = jnp.eye(self.data.shape[0]) + self.data * tau + prop += self.data @ self.data * (tau * tau * 0.5) + + elif method == "third_order": + b = self.data * tau + prop = jnp.eye(self.data.shape[0]) + b + bb = b @ b * 0.5 + prop += bb + prop += bb @ b * 0.3333333333333333333 + else: + raise ValueError("Unknown or not specified method for the " + "calculation of the matrix exponential:" + + str(method)) + return DenseOperatorJAX(prop) + + def prop(self, tau: complex = 1) -> 'DenseOperatorJAX': + """See base class. """ + return DenseOperatorJAX(self.exp(tau)) + + def dexp(self, + direction: 'DenseOperatorJAX', + tau: complex = 1, + compute_expm: bool = False, + method: str = "spectral", + is_skew_hermitian: bool = False, + epsilon: float = 1e-10, + ) \ + -> Union['DenseOperatorJAX', Tuple['DenseOperatorJAX']]: + """ + Frechet derivative of the matrix exponential. + + Parameters + ---------- + direction: DenseOperator + Direction in which the frechet derivative is calculated. Must be of + the same shape as self. + + tau: complex + The matrix is multiplied by tau before calculating the exponential. + + compute_expm: bool + If true, then the matrix exponential is calculated and returned as + well. + + method: string + Numerical method used for the calculation of the matrix + exponential. + Currently the following are implemented: + - 'Frechet': Uses the scipy linalg matrix exponential for + simultaniously calculation of the frechet derivative expm_frechet + - 'approx': Approximates the Derivative by finite differences. + - 'first_order': First order taylor approximation + - 'second_order': Second order taylor approximation + - 'third_order': Third order taylor approximation + - 'spectral': Use the self implemented spectral decomposition + + is_skew_hermitian: bool + Only required, for the method 'spectral'. If set to True, then the + matrix is assumed to be skew hermitian in the spectral + decomposition. + + epsilon: float + Width of the finite difference. Only relevant for the method + 'approx'. + + Returns + ------- + prop: DenseOperator + The matrix exponential. Only returned if compute_expm is True! + prop_grad: DenseOperator + The frechet derivative d exp(Ax + B)/dx at x=0 where A is the + direction and B is the matrix stored in self. + + Raises + ------ + NotImplementedError: + If the method given as parameter is not implemented. + + """ + prop = None + + if type(direction) != DenseOperatorJAX: + direction = DenseOperatorJAX(direction) + + if method == "Frechet": + a = self.data * tau + e = direction.data * tau + if compute_expm: + prop, prop_grad = jax.scipy.linalg.expm_frechet( + a, e, compute_expm=True) + prop_grad = DenseOperatorJAX(prop_grad) + prop = DenseOperatorJAX(prop) + + else: + prop_grad = jax.scipy.linalg.expm_frechet( + a, e, compute_expm=False) + prop_grad = DenseOperatorJAX(prop_grad) + + + elif method == "spectral": + if compute_expm: + prop, prop_grad = self._dexp_diagonalization( + direction=direction, tau=tau, + is_skew_hermitian=is_skew_hermitian, + compute_expm=compute_expm + ) + else: + prop_grad = self._dexp_diagonalization( + direction=direction, tau=tau, + is_skew_hermitian=is_skew_hermitian, + compute_expm=compute_expm + ) + + elif method == "approx": + d_m = (self.data + epsilon * direction.data) * tau + dprop = jax.scipy.linalg.expm(d_m) + prop = self.exp(tau) + prop_grad = (dprop - prop) * (1 / epsilon) + + elif method == "first_order": + if compute_expm: + prop = self.exp(tau) + prop_grad = direction.data * tau + + elif method == "second_order": + if compute_expm: + prop = self.exp(tau) + prop_grad = direction.data * tau + prop_grad += (self.data @ direction.data + + direction.data @ self.data) * (tau * tau * 0.5) + + elif method == "third_order": + if compute_expm: + prop = self.exp(tau) + prop_grad = direction.data * tau + prop_grad += (self.data @ direction.data + + direction.data @ self.data) * tau * tau * 0.5 + prop_grad += ( + self.data @ self.data @ direction.data + + direction.data @ self.data @ self.data + + self.data @ direction.data @ self.data + ) * (tau * tau * tau * 0.16666666666666666) + else: + raise NotImplementedError( + 'The specified method ' + method + "is not implemented!") + if compute_expm: + if type(prop) != DenseOperatorJAX: + prop = DenseOperatorJAX(prop) + if type(prop_grad) != DenseOperatorJAX: + prop_grad = DenseOperatorJAX(prop_grad) + if compute_expm: + return prop, prop_grad + else: + return prop_grad + + def identity_like(self) -> 'DenseOperatorJAX': + """See base class. """ + assert self.shape[0] == self.shape[1] + return DenseOperatorJAX(jnp.eye(self.shape[0], dtype=complex)) + + def truncate_to_subspace( + self, subspace_indices: Optional[Sequence[int]], + map_to_closest_unitary: bool = False + ) -> 'DenseOperatorJAX': + """See base class. """ + if subspace_indices is None: + return self + elif self.shape[0] == self.shape[1]: + # square matrix + out = type(self)( + self.data[jnp.ix_(jnp.array(subspace_indices), + jnp.array(subspace_indices))]) + if map_to_closest_unitary: + out = closest_unitary(out) + elif self.shape[0] == 1: + # bra-vector + out = type(self)(self.data[jnp.ix_(jnp.array([0]), + jnp.array(subspace_indices))]) + if map_to_closest_unitary: + out *= 1 / out.norm('fro') + elif self.shape[0] == 1: + # ket-vector + out = type(self)(self.data[jnp.ix_(jnp.array(subspace_indices), + jnp.array([0]))]) + if map_to_closest_unitary: + out *= 1 / out.norm('fro') + else: + out = type(self)(self.data[jnp.ix_(jnp.array(subspace_indices))]) + + return out + + + diff --git a/qopt/noise.py b/qopt/noise.py index bc81527..cbe77f3 100644 --- a/qopt/noise.py +++ b/qopt/noise.py @@ -77,6 +77,8 @@ from qopt.util import deprecated +import random +from functools import partial def bell_curve_1dim(x: Union[np.ndarray, float], stdx: float) -> Union[np.ndarray, float]: @@ -691,3 +693,370 @@ def plot_periodogram(self, n_average: int, scaling: str = 'density', np.mean(spectral_density_or_spectrum, axis=0)[1:-1] - self.noise_spectral_density(sample_frequencies)[1:-1]) return deviation_norm + + +############################################################################### + +try: + import jax.numpy as jnp + from jax import jit, vmap + import jax + _HAS_JAX = True +except ImportError: + from unittest import mock + jit = mock.Mock() + jnp = mock.Mock() + vmap = mock.Mock() + jax = mock.Mock() + _HAS_JAX = False + + +@jit +def _inverse_cumulative_gaussian_distribution_function_jnp( + z: Union[float, np.array, jnp.ndarray], std: float, mean: float): + """ + Calculates the inverse cumulative function for the gaussian distribution. + + Parameters + ---------- + z: Union[float, np.array, jnp.array] + Function value. + + std: float + Standard deviation of the bell curve. + + mean: float + Mean value of the gaussian distribution. Defaults to 0. + + Returns + ------- + selected_x: list of float + Noise samples. + + """ + return std * jnp.sqrt(2) * jax.scipy.special.erfinv(2 * z - 1) + mean + + +@partial(jit,static_argnums=1) +def _sample_1dim_gaussian_distribution_jnp(std: float, n_samples: int, mean: float = 0)\ + -> jnp.ndarray: + """ + Returns 'n_samples' samples from the one dimensional bell curve. + + The samples are chosen such, that the integral over the bell curve between + two adjacent samples is always the same. The samples reproduce the correct + standard deviation only in the limit n_samples -> inf due to the + discreteness of the approximation. The error is to good approximation + 1/n_samples. + + Parameters + ---------- + std: float + Standard deviation of the bell curve. + + n_samples: int + Number of samples returned. + + mean: float + Mean value of the gaussian distribution. Defaults to 0. + + Returns + ------- + selected_x: numpy array of shape:(n_samples, ) + Noise samples. + + """ + z = jnp.linspace(start=0, stop=1, num=n_samples, endpoint=False) + z += 1 / (2 * n_samples) + # we distribute the total probability of 1 into n_samples equal parts. + # The z-values are in the center of each part. + + x = _inverse_cumulative_gaussian_distribution_function_jnp( + z=jnp.expand_dims(z,0), std=jnp.expand_dims(std,1), mean=mean + ) + # We use the inverse cumulative gaussian distribution to find the values x. + # The integral over the Gaussian distribution between x[i] and x[i+1] + # now always equals 1/n_samples. + return x + + +class NTGQuasiStaticJAX(NoiseTraceGenerator): + """See docstring of class w/o JAX. + + Additional parameter: seed: int, optional: seed for jax.random.PRNGKey + """ + + + def __init__(self, standard_deviation: List[float], + n_samples_per_trace: int, + n_traces: int = 1, + noise_samples: Optional[np.ndarray] = None, + always_redraw_samples: bool = True, + correct_std_for_discrete_sampling: bool = True, + sampling_mode: str = 'uncorrelated_deterministic', + seed: Optional[int] = None): + if not _HAS_JAX: + raise ImportError("JAX not available") + n_noise_operators = len(standard_deviation) + super().__init__(noise_samples=noise_samples, + n_samples_per_trace=n_samples_per_trace, + n_traces=n_traces, + n_noise_operators=n_noise_operators, + always_redraw_samples=always_redraw_samples) + self.standard_deviation = jnp.asarray(standard_deviation) + + self.sampling_mode = sampling_mode + self.seed = seed if seed is not None else random.randint(0,2**32-1) + self.rnd_key_first = jax.random.PRNGKey(self.seed) + self.rnd_key_arr = [self.rnd_key_first] + + if correct_std_for_discrete_sampling: + if self.n_traces == 1: + raise RuntimeWarning('Standard deviation cannot be estimated' + 'for a single trace!') + elif self.sampling_mode == 'uncorrelated_deterministic': + + + n_std_dev = len(self.standard_deviation) + _noise_samples = _sample_1dim_gaussian_distribution_jnp( + self.standard_deviation, self._n_traces) + _noise_samples = jnp.broadcast_to( + jnp.expand_dims(jnp.tile(_noise_samples,n_std_dev)* + jnp.repeat(jnp.eye(n_std_dev),self._n_traces,axis=1),2), + (n_std_dev,self._n_traces*n_std_dev,self.n_samples_per_trace)) + + actual_std = jnp.std(_noise_samples,axis=(1,2)) + if jnp.any(actual_std < 1e-20): + raise RuntimeError('The standard deviation was ' + 'estimated close to 0!') + self.standard_deviation *= \ + self.standard_deviation / actual_std + + @property + def n_traces(self) -> int: + """Number of traces. + + The number of requested traces must be multiplied with the number of + standard deviations because if standard deviation is sampled + separately. + + """ + if self._n_traces: + if self.sampling_mode == 'uncorrelated_deterministic': + return self._n_traces * len(self.standard_deviation) + elif self.sampling_mode == 'monte_carlo': + return self._n_traces + else: + raise ValueError('Unsupported sampling mode!') + else: + return self.noise_samples.shape[1] + + def _sample_noise(self) -> None: + """ + Draws quasi static noise samples from a normal distribution. + + Each noise contribution (corresponding to one noise operator) is + sampled separately. For each standard deviation n_traces traces are + calculated. + + """ + if self.sampling_mode == 'uncorrelated_deterministic': + + n_std_dev = len(self.standard_deviation) + _noise_samples = _sample_1dim_gaussian_distribution_jnp( + self.standard_deviation, self._n_traces) + self._noise_samples = jnp.broadcast_to( + jnp.expand_dims(jnp.tile(_noise_samples,n_std_dev)* + jnp.repeat(jnp.eye(n_std_dev),self._n_traces,axis=1),2), + (n_std_dev,self._n_traces*n_std_dev,self.n_samples_per_trace)) + + elif self.sampling_mode == 'monte_carlo': + + self._noise_samples = jnp.einsum( + 'i,ijk->ijk', + self.standard_deviation, + jax.random.normal( + key=self.rnd_key_arr[-1], + shape=(len(self.standard_deviation),self.n_traces,1)) + ) + self._noise_samples = jnp.repeat( + self._noise_samples, self.n_samples_per_trace, axis=2) + + self.rnd_key_arr.append( + jax.random.split(self.rnd_key_arr[-1],num=2)[1]) + + else: + raise ValueError('Unsupported sampling mode!') + + +def _fast_colored_noise_jnp(spectral_density: Callable, dt: float, + n_samples: int, output_shape: tuple, key, + r_power_of_two=False + ) -> jnp.ndarray: + """See docstring of function without _jnp""" + f_max = 1 / dt + f_nyquist = f_max / 2 + s0 = 1 / f_nyquist + if r_power_of_two: + actual_n_samples = int(2 ** jnp.ceil(jnp.log2(n_samples))) + else: + actual_n_samples = int(n_samples) + + delta_white = jax.random.normal(key,(*output_shape, actual_n_samples)) + delta_white_ft = jnp.fft.rfft(delta_white, axis=-1) + # Only positive frequencies since FFT is real and therefore symmetric + f = jnp.linspace(0, f_nyquist, actual_n_samples // 2 + 1) + f = spectral_density(f[1:]) + f = jnp.pad(f,((1, 0),)) + delta_colored = jnp.fft.irfft(delta_white_ft * jnp.sqrt(f / s0), + n=actual_n_samples, axis=-1) + # the ifft takes r//2 + 1 inputs to generate r outputs + + return delta_colored + + +class NTGColoredNoiseJAX(NoiseTraceGenerator): + """See docstring of class w/o JAX. + + Additional parameter: seed: int, optional: seed for jax.random.PRNGKey + """ + + def __init__(self, + n_samples_per_trace: int, + noise_spectral_density: Callable, + dt: float, + n_traces: int = 1, + n_noise_operators: int = 1, + always_redraw_samples: bool = True, + low_frequency_extension_ratio: int = 1, + seed: Optional[int] = None): + if not _HAS_JAX: + raise ImportError("JAX not available") + super().__init__(n_traces=n_traces, + n_samples_per_trace=n_samples_per_trace, + noise_samples=None, + n_noise_operators=n_noise_operators, + always_redraw_samples=always_redraw_samples) + self.noise_spectral_density = noise_spectral_density + self.dt = dt + if low_frequency_extension_ratio < 1: + raise ValueError("The low frequency extension ratio must be " + "greater or equal to 1.") + self.low_frequency_extension_ratio = low_frequency_extension_ratio + if hasattr(dt, "__len__"): + raise ValueError('dt is supposed to be a scalar value!') + + self.seed = seed if seed is not None else random.randint(0,2**32-1) + self.rnd_key_first = jax.random.PRNGKey(self.seed) + self.rnd_key_arr = [self.rnd_key_first] + + def _sample_noise(self, **kwargs) -> None: + """Samples noise from an arbitrary colored spectrum. """ + if self._n_noise_operators is None: + raise ValueError('Please specify the number of noise operators!') + if self._n_traces is None: + raise ValueError('Please specify the number of noise traces!') + if self._n_samples_per_trace is None: + raise ValueError('Please specify the number of noise samples per' + 'trace!') + + + noise_samples = _fast_colored_noise_jnp( + spectral_density=self.noise_spectral_density, + n_samples= + self.n_samples_per_trace * self.low_frequency_extension_ratio, + output_shape=(self.n_noise_operators, self.n_traces), + r_power_of_two=False, + dt=self.dt, + key=self.rnd_key_arr[-1]) + self._noise_samples = noise_samples[:, :, :self.n_samples_per_trace] + + self.rnd_key_arr.append( + jax.random.split(self.rnd_key_arr[-1],num=2)[1]) + + def plot_periodogram(self, n_average: int, scaling: str = 'density', + log_plot: Optional[str] = None, draw_plot=True): + """Creates noise samples and plots the corresponding periodogram. + + Parameters + ---------- + n_average: int + Number of Periodograms which are averaged. + + scaling: {'density', 'spectrum'}, optional + If 'density' then the power spectral density in units of V**2/Hz is + plotted. + If 'spectral' then the power spectrum in units of V**2 is plotted. + Defaults to 'density'. + + log_plot: {None, 'semilogy', 'semilogx', 'loglog'}, optional + If None, then the plot is not plotted logarithmically. If + 'semilogy' only the y-axis is plotted logarithmically, if + 'semilogx' only the x-axis is plotted logarithmically, if 'loglog' + both axis are plotted logarithmically. Defaults to None. + + draw_plot: bool, optional + If true, then the periodogram is plotted. Defaults to True. + + Returns + ------- + deviation_norm: float + The vector norm of the deviation between the actual power spectral + density and the power spectral densitry found in the periodogram. + + """ + + noise_samples = fast_colored_noise( + spectral_density=self.noise_spectral_density, + n_samples=self.n_samples_per_trace, + output_shape=(n_average,), + r_power_of_two=False, + dt=self.dt + ) + + sample_frequencies, spectral_density_or_spectrum = signal.periodogram( + x=noise_samples, + fs=1 / self.dt, + return_onesided=True, + scaling=scaling, + axis=-1 + ) + + if scaling == 'density': + y_label = 'Power Spectral Density (V**2/Hz)' + elif scaling == 'spectrum': + y_label = 'Power Spectrum (V**2)' + else: + raise ValueError('Unexpected scaling argument.') + + if draw_plot: + plt.figure() + + if log_plot is None: + plot_function = plt.plot + elif log_plot == 'semilogy': + plot_function = plt.semilogy + elif log_plot == 'semilogx': + plot_function = plt.semilogx + elif log_plot == 'loglog': + plot_function = plt.loglog + else: + raise ValueError('Unexpected plotting mode') + + plot_function(sample_frequencies, + np.mean(spectral_density_or_spectrum, axis=0), + label='Periodogram') + plot_function(sample_frequencies, + self.noise_spectral_density(sample_frequencies), + label='Spectral Noise Density') + + plt.ylabel(y_label) + plt.xlabel('Frequency (Hz)') + plt.legend(['Periodogram', 'Spectral Noise Density']) + plt.show() + + deviation_norm = np.linalg.norm( + np.mean(spectral_density_or_spectrum, axis=0)[1:-1] - + self.noise_spectral_density(sample_frequencies)[1:-1]) + return deviation_norm + diff --git a/qopt/optimize.py b/qopt/optimize.py index c191a1a..dbfe309 100644 --- a/qopt/optimize.py +++ b/qopt/optimize.py @@ -125,7 +125,7 @@ class Optimizer(ABC): use_jacobian_function: bool, optional If set to true, then the jacobians are calculated analytically. - Defaults to True. + Defaults to False. store_optimizer: bool, optional If True, then the optimizer stores itself in the result class. @@ -266,6 +266,86 @@ def cost_jacobian_wrapper(self, optimization_parameters): self._n_jac_fkt_eval += 1 return jacobian + def cost_func_wrapper_global(self, optimization_parameters): + """Wraps the cost function given by the simulator class. + + The relevant information for the analysis is saved. + + Parameters + ---------- + optimization_parameters: np.array + Raw optimization parameters in a linear array. + + Returns + ------- + costs: np.array, shape (n_fun) + Cost values. + + """ + if (time.time() - self._opt_start_time) \ + > self.termination_conditions['max_wall_time']: + raise WallTimeExceeded + + costs = self.system_simulator.wrapped_cost_functions_test( + optimization_parameters.reshape(self.pulse_shape[::-1]).T) + + if self.save_intermediary_steps: + self.optim_iter_summary.iter_num += 1 + self.optim_iter_summary.costs.append(costs) + self.optim_iter_summary.parameters.append( + optimization_parameters.reshape(self.pulse_shape[::-1]).T + ) + if np.linalg.norm(costs) < np.linalg.norm(self._min_costs): + self._min_costs = costs + self._min_costs_par = optimization_parameters.reshape( + self.pulse_shape[::-1]).T + + # apply the cost function weights after saving the values. + if self.cost_func_weights is not None: + costs *= self.cost_func_weights + + self._n_cost_fkt_eval += 1 + return costs + + def cost_jacobian_wrapper_global(self, optimization_parameters, scale_ind=[]): + """Wraps the cost Jacobian function given by the simulator class. + + The relevant information for the analysis is saved. + + Parameters + ---------- + optimization_parameters: np.array + Raw optimization parameters in a linear array. + + Returns + ------- + jacobian: np.array, shape (num_func, num_t * num_amp) + Jacobian of the cost functions. + + """ + jacobian = self.system_simulator.wrapped_jac_function_test( + optimization_parameters.reshape(self.pulse_shape[::-1]).T) + + if self.save_intermediary_steps: + self.optim_iter_summary.gradients.append(jacobian) + + jacobian[:,:,scale_ind] = jacobian[:,:,scale_ind]/(1+self._n_jac_fkt_eval) + + # jacobian shape (num_t, num_f, num_ctrl) -> (num_f, num_t * num_ctrl) + jacobian = jacobian.transpose([1, 2, 0]) + jacobian = jacobian.reshape( + (jacobian.shape[0], jacobian.shape[1] * jacobian.shape[2])) + + # apply the cost function weights after saving the values. + if self.cost_func_weights is not None: + jacobian = np.einsum('ab, a -> ab', jacobian, + self.cost_func_weights) + + self._n_jac_fkt_eval += 1 + return jacobian + + ########################### + @abstractmethod def run_optimization(self, initial_control_amplitudes: np.ndarray, verbose) \ @@ -302,7 +382,11 @@ def prepare_optimization(self, self._min_costs_par = None self._n_cost_fkt_eval = 0 self._n_jac_fkt_eval = 0 - self.pulse_shape = initial_optimization_parameters.shape + try: + self.pulse_shape = initial_optimization_parameters.shape + except: + self.pulse_shape = len(initial_optimization_parameters) + if self.save_intermediary_steps: self.optim_iter_summary = \ optimization_data.OptimizationSummary( @@ -459,6 +543,111 @@ def run_optimization(self, initial_control_amplitudes: np.array, return optim_result +class LeastSquaresOptimizerGlobal(Optimizer): + """ + Uses the scipy least squares method for optimization. + + Parameters + ---------- + system_simulator: `Simulator` + The systems simulator. + + termination_cond: dict + Termination conditions. + + save_intermediary_steps: bool, optional + If False, only the simulation result is stored. Defaults to False. + + method: str, optional + The optimization method used. Currently implemented are: + - 'trf': A trust region optimization algorithm. This is the default. + + bounds: array or list of boundaries, optional + The boundary conditions for the pulse optimizations. If none are given + then the pulse is assumed to take any real value. + + """ + + def __init__( + self, + n_time_steps_ctrl: int, + system_simulator: Optional[simulator.Simulator] = None, + termination_cond: Optional[Dict] = None, + save_intermediary_steps: bool = True, + method: str = 'trf', + bounds: Union[np.ndarray, List, None] = None, + use_jacobian_function=True, + cost_func_weights: Optional[Sequence[float]] = None, + store_optimizer: bool = False, + scale_down_grad_ind = []): + super().__init__(system_simulator=system_simulator, + termination_cond=termination_cond, + save_intermediary_steps=save_intermediary_steps, + cost_func_weights=cost_func_weights, + use_jacobian_function=use_jacobian_function, + store_optimizer=store_optimizer) + self.method = method + self.bounds = bounds + self.n_time_steps_ctrl = n_time_steps_ctrl + + self.scale_down_grad_ind = scale_down_grad_ind + + def cost_jacobian_wrapper_test(self,optimization_parameters): + + return super().cost_jacobian_wrapper_global(optimization_parameters,self.scale_down_grad_ind) + + def run_optimization(self, initial_control_amplitudes: np.array, + verbose: int = 0) -> optimization_data.OptimizationResult: + """See base class. + """ + super().prepare_optimization( + initial_optimization_parameters=initial_control_amplitudes) + + if self.use_jacobian_function: + jac = self.cost_jacobian_wrapper_test + else: + jac = '2-point' + + try: + result = scipy.optimize.least_squares( + fun=super().cost_func_wrapper_global, + x0=initial_control_amplitudes.T.flatten(), + jac=jac, + bounds=self.bounds, + method=self.method, + ftol=self.termination_conditions["min_cost_gain"], + xtol=self.termination_conditions["min_amplitude_change"], + gtol=self.termination_conditions["min_gradient_norm"], + max_nfev=self.termination_conditions["max_iterations"], + verbose=verbose, + x_scale="jac" + ) + + if self.system_simulator.stats is not None: + self.system_simulator.stats.end_t_opt = time.time() + + if self.store_optimizer: + storage_opt = self + else: + storage_opt = None + + optim_result = optimization_data.OptimizationResult( + final_cost=result.fun, + indices=self.system_simulator.cost_indices, + final_parameters=[result.x]*self.n_time_steps_ctrl, + final_grad_norm=np.linalg.norm(result.grad), + num_iter=result.nfev, + termination_reason=result.message, + status=result.status, + optimizer=storage_opt, + optim_summary=self.optim_iter_summary, + optimization_stats=self.system_simulator.stats + ) + except WallTimeExceeded: + optim_result = self.write_state_to_result() + + return optim_result + class ScalarMinimizingOptimizer(Optimizer): """ Interfaces to the minimize functions of the optimization package in @@ -879,3 +1068,624 @@ def prepare_optimization(self, super().prepare_optimization( initial_optimization_parameters=initial_optimization_parameters) self.annealer.state = initial_optimization_parameters + + +class SimulatedAnnealingScipy(Optimizer): + """ + This class uses simulated annealing for discrete optimization. + + Parameters + ---------- + temperature: float + Initial temperature for the annealing algorithm. + + step_size: int + Initial stepsize. + + interval: int + Number of optimization iterations before the step size is reduced. + + bounds: array of boundaries, shape: (2, num_t, num_ctrl) + The boundary conditions for the pulse optimizations. bounds[0] should + be the lower bounds, and bounds[1] the upper ones. + + """ + + def __init__( + self, + system_simulator: Optional[simulator.Simulator] = None, + termination_cond: Optional[Dict] = None, + save_intermediary_steps: bool = False, + store_optimizer: bool = False, + temperature: float = 1., + step_size: int = 1, + interval: int = 50, + bounds: Optional[np.ndarray] = None + ): + super().__init__( + system_simulator=system_simulator, + termination_cond=termination_cond, + save_intermediary_steps=save_intermediary_steps, + store_optimizer=store_optimizer + ) + self.temperature = temperature + self.step_size = step_size + self.interval = interval + self.bounds = bounds + + def run_optimization(self, initial_control_amplitudes: np.ndarray, + verbose: bool = False): + """See base class. """ + + super().prepare_optimization( + initial_optimization_parameters=initial_control_amplitudes) + + if self.store_optimizer: + storage_opt = self + else: + storage_opt = None + + try: + result = scipy.optimize.basinhopping( + func=self.cost_func_wrapper, + x0=initial_control_amplitudes.T.flatten(), + niter=self.termination_conditions["max_iterations"], + T=self.temperature, + stepsize=self.step_size, + take_step=self._take_step, + callback=None, + interval=self.interval, + disp=verbose + ) + + if self.system_simulator.stats is not None: + self.system_simulator.stats.end_t_opt = time.time() + + optim_result = optimization_data.OptimizationResult( + final_cost=result.fun, + indices=self.system_simulator.cost_indices, + final_parameters=result.x.reshape(self.pulse_shape[::-1]).T, + num_iter=result.nfev, + termination_reason=result.message, + status=result.status, + optimizer=storage_opt, + optim_summary=self.optim_iter_summary, + optimization_stats=self.system_simulator.stats + ) + + except WallTimeExceeded: + if self.system_simulator.stats is not None: + self.system_simulator.stats.end_t_opt = time.time() + + optim_result = optimization_data.OptimizationResult( + final_cost=self._min_costs, + indices=self.system_simulator.cost_indices, + final_parameters=self._min_costs_par, + num_iter=self._n_cost_fkt_eval, + termination_reason='Maximum Wall Time Exceeded', + status=5, + optimizer=storage_opt, + optim_summary=self.optim_iter_summary, + optimization_stats=self.system_simulator.stats + ) + + return optim_result + + def _take_step(self, current_pulse: np.ndarray) -> np.ndarray: + """ + This function applies a random discrete variation to the pulse. + + Parameters + ---------- + current_pulse: array of int + The pulse before the application of the take step function. + + Returns + ------- + new_pulse: array of int + The pulse initial pulse plus a random variation. + + """ + pulse = current_pulse.reshape(self.pulse_shape[::-1]).T + + if type(self.step_size) != int: + raise ValueError("The step size must be integer! But it is: " + + str(self.step_size)) + + if self.step_size == 0: + raise ValueError("The step size has been set to 0.") + + random_step = np.random.randint( + low=-1 * self.step_size, + high=self.step_size + 1, + size=pulse.shape + ) + + new_pulse = pulse + random_step + + # if a limit is exceeded, set the value to the limit + lower_limit_exceeded = new_pulse < self.bounds[0] + upper_limit_exceeded = new_pulse > self.bounds[1] + + new_pulse[lower_limit_exceeded] = self.bounds[0][lower_limit_exceeded] + new_pulse[upper_limit_exceeded] = self.bounds[1][upper_limit_exceeded] + + return new_pulse.T.flatten() + + +############################################################################### + +try: + import jax.numpy as jnp + from jax import jit, vmap + import jax + _HAS_JAX = True +except ImportError: + from unittest import mock + jit = mock.Mock() + jnp = mock.Mock() + vmap = mock.Mock() + jax = mock.Mock() + _HAS_JAX = False + + +class OptimizerJAX(ABC): + """See docstring of class w/o JAX. Requires simulator with JAX""" + + def __init__( + self, + system_simulator: Optional[simulator.SimulatorJAX] = None, + termination_cond: Optional[Dict] = None, + save_intermediary_steps: bool = True, + cost_func_weights: Optional[Sequence[float]] = None, + use_jacobian_function=True, + store_optimizer: bool = False + ): + if not _HAS_JAX: + raise ImportError("JAX not available") + self.system_simulator = system_simulator + self.use_jacobian_function = use_jacobian_function + self.termination_conditions = default_termination_conditions + if termination_cond is not None: + self.termination_conditions.update(**termination_cond) + + self.optim_iter_summary = None + self.pulse_shape = () + + self._opt_start_time = 0 + self._min_costs = jnp.inf + self._min_costs_par = None + self._n_cost_fkt_eval = 0 + self._n_jac_fkt_eval = 0 + + # flags: + self.save_intermediary_steps = save_intermediary_steps + self.store_optimizer = store_optimizer + + self.cost_func_weights = cost_func_weights + + if self.cost_func_weights is not None: + self.cost_func_weights = jnp.asarray( + self.cost_func_weights).flatten() + if len(self.cost_func_weights) == 0: + self.cost_func_weights = None + elif not len(self.system_simulator.cost_funcs) == len( + self.cost_func_weights): + raise ValueError('A cost function weight must be specified for' + 'each cost function or for none at all.') + + def cost_func_wrapper(self, optimization_parameters): + """Wraps the cost function given by the simulator class. + + The relevant information for the analysis is saved. + + Parameters + ---------- + optimization_parameters: Union[np.array, jnp.ndarray] + Raw optimization parameters in a linear array. + + Returns + ------- + costs: jnp.array, shape (n_fun) + Cost values. + + """ + if (time.time() - self._opt_start_time) \ + > self.termination_conditions['max_wall_time']: + raise WallTimeExceeded + + costs = self.system_simulator.wrapped_cost_functions( + optimization_parameters.reshape(self.pulse_shape[::-1]).T) + + if self.save_intermediary_steps: + self.optim_iter_summary.iter_num += 1 + self.optim_iter_summary.costs.append(costs) + self.optim_iter_summary.parameters.append( + optimization_parameters.reshape(self.pulse_shape[::-1]).T + ) + if jnp.linalg.norm(costs) < jnp.linalg.norm(self._min_costs): + self._min_costs = costs + self._min_costs_par = optimization_parameters.reshape( + self.pulse_shape[::-1]).T + + # apply the cost function weights after saving the values. + if self.cost_func_weights is not None: + costs *= self.cost_func_weights + + self._n_cost_fkt_eval += 1 + return costs + + def cost_jacobian_wrapper(self, optimization_parameters): + """Wraps the cost Jacobian function given by the simulator class. + + The relevant information for the analysis is saved. + + Parameters + ---------- + optimization_parameters: Union[np.array, jnp.ndarray] + Raw optimization parameters in a linear array. + + Returns + ------- + jacobian: jnp.array, shape (num_func, num_t * num_amp) + Jacobian of the cost functions. + + """ + jacobian = self.system_simulator.wrapped_jac_function( + optimization_parameters.reshape(self.pulse_shape[::-1]).T) + + if self.save_intermediary_steps: + self.optim_iter_summary.gradients.append(jacobian) + + # jacobian shape (num_t, num_f, num_ctrl) -> (num_f, num_t * num_ctrl) + jacobian = jacobian.transpose([1, 2, 0]) + jacobian = jacobian.reshape( + (jacobian.shape[0], jacobian.shape[1] * jacobian.shape[2])) + + # apply the cost function weights after saving the values. + if self.cost_func_weights is not None: + jacobian = jnp.einsum('ab, a -> ab', jacobian, + self.cost_func_weights) + + self._n_jac_fkt_eval += 1 + return jacobian + + @abstractmethod + def run_optimization( + self, + initial_control_amplitudes: Union[np.ndarray,jnp.ndarray], + verbose) \ + -> optimization_data.OptimizationResult: + """Runs the optimization of the control amplitudes. + + Parameters + ---------- + initial_control_amplitudes : array + shape (num_t, num_ctrl) + verbose + Verbosity of the run. Depends on which optimizer is used. + + Returns + ------- + optimization_result : `OptimizationResult` + The resulting data of the simulation. + + """ + pass + + def prepare_optimization( + self, + initial_optimization_parameters: Union[np.ndarray,jnp.ndarray]): + """Prepare for the next optimization. + + Parameters + ---------- + initial_optimization_parameters : array + shape (num_t, num_ctrl) + + Data stored in this class might be overwritten. + """ + self._min_costs = jnp.inf + self._min_costs_par = None + self._n_cost_fkt_eval = 0 + self._n_jac_fkt_eval = 0 + self.pulse_shape = initial_optimization_parameters.shape + if self.save_intermediary_steps: + self.optim_iter_summary = \ + optimization_data.OptimizationSummary( + indices=self.system_simulator.cost_indices + ) + self._opt_start_time = time.time() + if self.system_simulator.stats is not None: + # If the system simulator wants to write down statistics, then + # initialise a fresh instance + self.system_simulator.stats = \ + performance_statistics.PerformanceStatistics() + self.system_simulator.stats.start_t_opt = float( + self._opt_start_time) + self.system_simulator.stats.indices = \ + self.system_simulator.cost_indices + + def write_state_to_result(self): + """ Writes the current state into an instance of 'OptimizationResult'. + + Intended for saving progress when terminating the optimization in an + unexpected way. + + Returns + ------- + result: optimization_data.OptimizationResult + The current result of the optimization. + + """ + if self.system_simulator.stats is not None: + self.system_simulator.stats.end_t_opt = time.time() + + if self.use_jacobian_function: + jac_norm = jnp.linalg.norm( + self.cost_jacobian_wrapper(self._min_costs_par)) + else: + jac_norm = 0 + + if self.store_optimizer: + storage_opt = self + else: + storage_opt = None + + optim_result = optimization_data.OptimizationResult( + final_cost=self._min_costs, + indices=self.system_simulator.cost_indices, + final_parameters=self._min_costs_par, + final_grad_norm=jac_norm, + num_iter=self._n_cost_fkt_eval, + termination_reason='Maximum Wall Time Exceeded', + status=5, + optimizer=storage_opt, + optim_summary=self.optim_iter_summary, + optimization_stats=self.system_simulator.stats + ) + return optim_result + + +#only changes are np.array() on jax arrays in the end to be picklable, +#jax.scipy.optimize not usable in qopt workflow (?) +class LeastSquaresOptimizerJAX(OptimizerJAX): + """See docstring of class w/o JAX.""" + + def __init__( + self, + system_simulator: Optional[simulator.SimulatorJAX] = None, + termination_cond: Optional[Dict] = None, + save_intermediary_steps: bool = True, + method: str = 'trf', + bounds: Union[np.ndarray, jnp.array, List, None] = None, + use_jacobian_function=True, + cost_func_weights: Optional[Sequence[float]] = None, + store_optimizer: bool = False, + x_scale = 1.): + super().__init__(system_simulator=system_simulator, + termination_cond=termination_cond, + save_intermediary_steps=save_intermediary_steps, + cost_func_weights=cost_func_weights, + use_jacobian_function=use_jacobian_function, + store_optimizer=store_optimizer) + self.method = method + self.bounds = bounds + self.x_scale = x_scale + + def run_optimization(self, + initial_control_amplitudes: Union[np.array,jnp.array], + verbose: int = 0 + ) -> optimization_data.OptimizationResult: + """See base class. """ + super().prepare_optimization( + initial_optimization_parameters=initial_control_amplitudes) + + if self.use_jacobian_function: + jac = super().cost_jacobian_wrapper + else: + jac = '2-point' + + try: + result = scipy.optimize.least_squares( + fun=super().cost_func_wrapper, + x0=initial_control_amplitudes.T.flatten(), + jac=jac, + bounds=self.bounds, + method=self.method, + ftol=self.termination_conditions["min_cost_gain"], + xtol=self.termination_conditions["min_amplitude_change"], + gtol=self.termination_conditions["min_gradient_norm"], + max_nfev=self.termination_conditions["max_iterations"], + verbose=verbose, + x_scale=self.x_scale + ) + + if self.system_simulator.stats is not None: + self.system_simulator.stats.end_t_opt = time.time() + + if self.store_optimizer: + storage_opt = self + else: + storage_opt = None + + optim_result = optimization_data.OptimizationResult( + final_cost=np.array(result.fun), + indices=self.system_simulator.cost_indices, + final_parameters=np.array(result.x.reshape( + self.pulse_shape[::-1]).T), + final_grad_norm=np.linalg.norm(np.array(result.grad)), + num_iter=result.nfev, + termination_reason=result.message, + status=result.status, + optimizer=storage_opt, + optim_summary=self.optim_iter_summary, + optimization_stats=self.system_simulator.stats + ) + except WallTimeExceeded: + optim_result = self.write_state_to_result() + + return optim_result + + +class ScalarMinimizingOptimizerJAX(OptimizerJAX): + """See docstring of class w/o JAX.""" + + def __init__( + self, + system_simulator: Optional[simulator.SimulatorJAX] = None, + termination_cond: Optional[Dict] = None, + save_intermediary_steps: bool = True, + method: str = 'L-BFGS-B', + bounds: Union[np.ndarray, List, None] = None, + use_jacobian_function=True, + cost_func_weights: Optional[Sequence[float]] = None, + store_optimizer: bool = False, + ): + super().__init__(system_simulator=system_simulator, + termination_cond=termination_cond, + save_intermediary_steps=save_intermediary_steps, + cost_func_weights=cost_func_weights, + use_jacobian_function=use_jacobian_function, + store_optimizer=store_optimizer) + self.method = method + self.bounds = bounds + + + def cost_func_wrapper(self, optimization_parameters): + """ Evalutes the cost function. + + The total cost function is defined as the sum of cost functions. + + """ + costs = super().cost_func_wrapper(optimization_parameters) + scalar_costs = jnp.sum(costs) + #need to convert devicearray to float (?) + return float(scalar_costs) + + def cost_jacobian_wrapper(self, optimization_parameters): + """ The Jacobian reduced to the gradient. + + The gradient is calculated by summation over the Jacobian along the + function axis, because the total cost function is defined as the sum + of cost functions. + + Returns + ------- + gradient: numpy array, shape (num_t * num_amp) + The gradient of the costs in the 2 norm. + + """ + jac = super().cost_jacobian_wrapper(optimization_parameters) + grad = (jnp.sum(jac, axis=0)) + return np.array(grad,copy=True) + + def run_optimization(self, + initial_control_amplitudes: Union[np.array,jnp.array], + verbose: bool = False + ) -> optimization_data.OptimizationResult: + super().prepare_optimization( + initial_optimization_parameters=initial_control_amplitudes) + + if self.use_jacobian_function: + jac = self.cost_jacobian_wrapper + else: + jac = None + + if self.method == 'L-BFGS-B': + try: + result = scipy.optimize.minimize( + fun=self.cost_func_wrapper, + x0=initial_control_amplitudes.T.flatten(), + jac=jac, + bounds=self.bounds, + method=self.method, + options={ + 'ftol': self.termination_conditions["min_cost_gain"], + 'gtol': self.termination_conditions["min_gradient_norm"], + 'maxiter': self.termination_conditions["max_iterations"], + 'disp': verbose + } + ) + + if self.store_optimizer: + storage_opt = self + else: + storage_opt = None + + optim_result = optimization_data.OptimizationResult( + final_cost=np.array(result.fun), + indices=self.system_simulator.cost_indices, + final_parameters=np.array(result.x.reshape( + self.pulse_shape[::-1]).T), + final_grad_norm=np.linalg.norm(np.array(result.jac)), + num_iter=result.nfev, + termination_reason=result.status, + status=result.status, + optimizer=storage_opt, + optim_summary=self.optim_iter_summary, + optimization_stats=self.system_simulator.stats + ) + except WallTimeExceeded: + optim_result = self.write_state_to_result() + + elif self.method == 'Nelder-Mead': + try: + result = scipy.optimize.minimize( + fun=self.cost_func_wrapper, + x0=initial_control_amplitudes.T.flatten(), + bounds=self.bounds, + method=self.method, + options={ + 'maxiter': self.termination_conditions[ + "max_iterations"]}, + ) + + if self.store_optimizer: + storage_opt = self + else: + storage_opt = None + + optim_result = optimization_data.OptimizationResult( + final_cost=np.array(result.fun), + indices=self.system_simulator.cost_indices, + final_parameters=np.array(result.x.reshape( + self.pulse_shape[::-1]).T), + num_iter=result.nfev, + termination_reason=result.message, + status=result.status, + optimizer=storage_opt, + optim_summary=self.optim_iter_summary, + optimization_stats=self.system_simulator.stats + ) + except WallTimeExceeded: + optim_result = self.write_state_to_result() + + else: + try: + result = scipy.optimize.minimize( + fun=self.cost_func_wrapper, + x0=initial_control_amplitudes.T.flatten(), + bounds=self.bounds, + method=self.method + ) + + optim_result = optimization_data.OptimizationResult( + final_cost=np.array(result.fun), + indices=self.system_simulator.cost_indices, + final_parameters=np.array(result.x.reshape( + self.pulse_shape[::-1]).T), + num_iter=result.nfev, + termination_reason=result.message, + status=result.status, + optimizer=self, + optim_summary=self.optim_iter_summary, + optimization_stats=self.system_simulator.stats + ) + except WallTimeExceeded: + optim_result = self.write_state_to_result() + + if self.system_simulator.stats is not None: + self.system_simulator.stats.end_t_opt = time.time() + + return optim_result diff --git a/qopt/plotting.py b/qopt/plotting.py index ff39a5e..2435bdd 100644 --- a/qopt/plotting.py +++ b/qopt/plotting.py @@ -112,7 +112,7 @@ def plot_bloch_vector_evolution( states = [ qt.Qobj((prop * initial_state).data) for prop in forward_propagators ] - a = np.empty((3, len(states))) + a = np.empty((3, len(states)),dtype=complex) # for numerical integrity x, y, z = qt.sigmax(), qt.sigmay(), qt.sigmaz() for i, state in enumerate(states): a[:, i] = [qt.expect(x, state), diff --git a/qopt/simulator.py b/qopt/simulator.py index f248460..1331675 100644 --- a/qopt/simulator.py +++ b/qopt/simulator.py @@ -59,7 +59,6 @@ from qopt.util import needs_refactoring - class Simulator(object): """ The Dynamics class provides the interface for the Optimizer class. @@ -292,6 +291,155 @@ def wrapped_jac_function(self, pulse=None): return total_jac + def wrapped_cost_functions_test(self, pulse=None): + """ + Wraps the cost functions of the fidelity computer. + + This function coordinates the complete simulation including the + application of the transfer function, the execution of the time + slot computer and the evaluation of the actual cost functions. + + Parameters + ---------- + pulse: numpy array optional + If no pulse is specified the cost function is evaluated for the + attribute pulse. + + Returns + ------- + costs: numpy array, shape (n_fun) + Array of costs (i.e. infidelities). + + costs_indices: list of str + Names of the costs. + + """ + if pulse is None: + pulse = self.pulse + + for solver in self.solvers: + solver.set_optimization_parameters(pulse) + + costs = [] + + if self.stats: + self.stats.cost_func_eval_times.append([]) + for i, cost_func in enumerate(self.cost_funcs): + t_start = time.time() + #second argument is frequency [[amp,freq,phase],...,] + if type(cost_func).__name__ == "OperationInfidelity" or type(cost_func).__name__ == "OperationNoiseInfidelity" : + cost = cost_func.costs(pulse[0][1]) + elif type(cost_func).__name__ == "LeakageError" or type(cost_func).__name__ == "LeakageLiouville" or type(cost_func).__name__ == "StateInfidelity2": + cost = cost_func.costs() + else: + raise RuntimeError + t_end = time.time() + self.stats.cost_func_eval_times[-1].append(t_end - t_start) + + # reimplement the block below + costs.append(np.asarray(cost).flatten()) + + """ + I do not understand this block anymore. The cost can be an + array or a scalar, but the scalar can not be reshaped. + if hasattr(cost, "__len__"): + costs.append(cost) + else: + costs.append(cost.reshape(1)) + """ + costs = np.concatenate(costs, axis=0) + else: + for i, cost_func in enumerate(self.cost_funcs): + + if cost_func.__name__ == "OperationInfidelity" or cost_func.__name__ == "OperationNoiseInfidelity" : + cost = cost_func.costs(pulse[0][1]) + elif cost_func.__name__ == "LeakageError" or type(cost_func).__name__ == "LeakageLiouville" or type(cost_func).__name__ == "StateInfidelity2": + cost = cost_func.costs() + else: + raise RuntimeError + + costs.append(np.asarray(cost).flatten()) + """ + if hasattr(cost, "__len__"): + costs.append(cost) + else: + costs.append(cost.reshape(1)) + """ + costs = np.concatenate(costs, axis=0) + + return np.asarray(costs) + + def wrapped_jac_function_test(self, pulse=None): + """ + Wraps the gradient calculation functions of the fidelity computer. + + Parameters + ---------- + pulse: numpy array, optional + shape: (num_t, num_ctrl) If no pulse is specified the cost function + is evaluated for the attribute pulse. + + Returns + ------- + jac: numpy array + Array of gradients of shape (num_t, num_func, num_amp). + """ + + if self.numeric_jacobian: + return self.numeric_gradient(pulse=pulse) + + if pulse is None: + pulse = self.pulse + + for solver in self.solvers: + solver.set_optimization_parameters(pulse) + + jacobians = [] + + record_evaluation_times = bool(self.stats) + + if record_evaluation_times: + self.stats.grad_func_eval_times.append([]) + + for i, cost_func in enumerate(self.cost_funcs): + if record_evaluation_times: + t_start = time.time() + + if type(cost_func).__name__ == "OperationInfidelity" or type(cost_func).__name__ == "OperationNoiseInfidelity" : + jac_u = cost_func.grad(pulse[0][1]) + elif type(cost_func).__name__ == "LeakageError" or type(cost_func).__name__ == "LeakageLiouville" or type(cost_func).__name__ == "StateInfidelity2": + jac_u = cost_func.grad() + else: + raise RuntimeError + + # if the cost function is scalar, an extra dimension is inserted + if len(jac_u.shape) == 2: + jac_u = np.expand_dims(jac_u, axis=1) + + # apply the chain rule to the derivatives + jac_x = cost_func.solver.amplitude_function.derivative_by_chain_rule( + jac_u, cost_func.solver.transfer_function(pulse)) + jac_x_transferred = \ + cost_func.solver.transfer_function.gradient_chain_rule( + jac_x + ) + + if type(cost_func).__name__ == "OperationInfidelity" or type(cost_func).__name__ == "OperationNoiseInfidelity" : + jac_x_transferred[0,0,1] += cost_func.der_freq_test(pulse[0][1])[0,0] + elif type(cost_func).__name__ != "LeakageError" and type(cost_func).__name__ != "LeakageLiouville" and type(cost_func).__name__ != "StateInfidelity2": + raise RuntimeWarning + + jacobians.append(jac_x_transferred) + if record_evaluation_times: + t_end = time.time() + self.stats.grad_func_eval_times[-1].append(t_end - t_start) + + # two dimensional form as required by scipy solvers + total_jac = np.concatenate(jacobians, axis=1) + + return total_jac + + def compare_numeric_to_analytic_gradient( self, pulse: Optional[np.ndarray] = None, delta_eps: float = 1e-8, @@ -369,14 +517,14 @@ def numeric_gradient( central_costs = self.wrapped_cost_functions(pulse=test_pulse) - n_times, n_operators = test_pulse.shape + n_times, n_operators = np.asarray(test_pulse).shape n_cost_funcs = len(central_costs) gradients = np.zeros((n_times, n_cost_funcs, n_operators)) for n_time in range(n_times): for n_operator in range(n_operators): - delta = np.zeros_like(test_pulse, dtype=float) + delta = np.zeros_like(test_pulse) delta[n_time, n_operator] = delta_eps fwd_val = self.wrapped_cost_functions(test_pulse + delta) if symmetric: @@ -388,3 +536,320 @@ def numeric_gradient( (fwd_val - central_costs) / delta_eps return gradients + + +############################################################################### + +try: + import jax.numpy as jnp + _HAS_JAX = True +except ImportError: + from unittest import mock + jnp = mock.Mock() + _HAS_JAX = False + +class SimulatorJAX(Simulator): + """See docstring of class w/o JAX. Requires solver with JAX""" + + def __init__( + self, + solvers: Optional[Sequence[solver_algorithms.SolverJAX]], + cost_funcs: Optional[Sequence[cost_functions.CostFunction]], + optimization_parameters=None, + num_ctrl=None, + times=None, + num_times=None, + record_performance_statistics: bool = True, + numeric_jacobian: bool = False + ): + if not _HAS_JAX: + raise ImportError("JAX not available") + super().__init__(solvers,cost_funcs,optimization_parameters,num_ctrl, + times,num_times,record_performance_statistics, + numeric_jacobian) + + def wrapped_cost_functions(self, pulse=None): + """ + Wraps the cost functions of the fidelity computer. + + This function coordinates the complete simulation including the + application of the transfer function, the execution of the time + slot computer and the evaluation of the actual cost functions. + + Parameters + ---------- + pulse: (j)np array optional + If no pulse is specified the cost function is evaluated for the + attribute pulse. + + Returns + ------- + costs: jnp array, shape (n_fun) + Array of costs (i.e. infidelities). + + costs_indices: list of str + Names of the costs. + + """ + if pulse is None: + pulse = self.pulse + + for solver in self.solvers: + solver.set_optimization_parameters(pulse) + + costs = [] + + if self.stats: + self.stats.cost_func_eval_times.append([]) + for i, cost_func in enumerate(self.cost_funcs): + t_start = time.time() + cost = cost_func.costs() + t_end = time.time() + self.stats.cost_func_eval_times[-1].append(t_end - t_start) + + # reimplement the block below + costs.append(jnp.asarray(cost).flatten()) + + """ + I do not understand this block anymore. The cost can be an + array or a scalar, but the scalar can not be reshaped. + if hasattr(cost, "__len__"): + costs.append(cost) + else: + costs.append(cost.reshape(1)) + """ + costs = jnp.concatenate(costs, axis=0) + else: + for i, cost_func in enumerate(self.cost_funcs): + cost = cost_func.costs() + + costs.append(jnp.asarray(cost).flatten()) + """ + if hasattr(cost, "__len__"): + costs.append(cost) + else: + costs.append(cost.reshape(1)) + """ + costs = jnp.concatenate(costs, axis=0) + + return jnp.asarray(costs) + + def wrapped_jac_function(self, pulse=None): + """ + Wraps the gradient calculation functions of the fidelity computer. + + Parameters + ---------- + pulse: (j)np array, optional + shape: (num_t, num_ctrl) If no pulse is specified the cost function + is evaluated for the attribute pulse. + + Returns + ------- + jac: jnp array + Array of gradients of shape (num_t, num_func, num_amp). + """ + + if self.numeric_jacobian: + return self.numeric_gradient(pulse=pulse) + + if pulse is None: + pulse = self.pulse + + for solver in self.solvers: + solver.set_optimization_parameters(pulse) + + jacobians = [] + + record_evaluation_times = bool(self.stats) + + if record_evaluation_times: + self.stats.grad_func_eval_times.append([]) + + for i, cost_func in enumerate(self.cost_funcs): + if record_evaluation_times: + t_start = time.time() + jac_u = cost_func.grad() + + # if the cost function is scalar, an extra dimension is inserted + if len(jac_u.shape) == 2: + jac_u = jnp.expand_dims(jac_u, axis=1) + + # apply the chain rule to the derivatives + jac_x = cost_func.solver.amplitude_function.derivative_by_chain_rule( + jac_u, cost_func.solver.transfer_function(pulse)) + jac_x_transferred = \ + cost_func.solver.transfer_function.gradient_chain_rule( + jac_x + ) + jacobians.append(jac_x_transferred) + if record_evaluation_times: + t_end = time.time() + self.stats.grad_func_eval_times[-1].append(t_end - t_start) + + # two dimensional form as required by scipy solvers + total_jac = jnp.concatenate(jacobians, axis=1) + + return total_jac + +############################################################################### + +class SimulatorJAXSpecial(SimulatorJAX): + """ + + + """ + + def __init__( + self, + solvers: Optional[Sequence[solver_algorithms.Solver]], + cost_funcs: Optional[Sequence[cost_functions.CostFunction]], + optimization_parameters=None, + num_ctrl=None, + times=None, + num_times=None, + record_performance_statistics: bool = True, + numeric_jacobian: bool = False + ): + super().__init__(solvers,cost_funcs,optimization_parameters,num_ctrl,times,num_times,record_performance_statistics,numeric_jacobian) + + def wrapped_cost_functions(self, pulse=None): + """ + Wraps the cost functions of the fidelity computer. + + This function coordinates the complete simulation including the + application of the transfer function, the execution of the time + slot computer and the evaluation of the actual cost functions. + + Parameters + ---------- + pulse: numpy array optional + If no pulse is specified the cost function is evaluated for the + attribute pulse. + + Returns + ------- + costs: numpy array, shape (n_fun) + Array of costs (i.e. infidelities). + + costs_indices: list of str + Names of the costs. + + """ + if pulse is None: + pulse = self.pulse + + for solver in self.solvers: + solver.set_optimization_parameters(pulse) + + costs = [] + + if self.stats: + self.stats.cost_func_eval_times.append([]) + for i, cost_func in enumerate(self.cost_funcs): + t_start = time.time() + if type(cost_func).__name__ == "OperationInfidelityJAXSpecial" or type(cost_func).__name__ == "OperationNoiseInfidelityJAXSpecial" : + cost = cost_func.costs(pulse[0][-1]) + else: + raise RuntimeError + + t_end = time.time() + self.stats.cost_func_eval_times[-1].append(t_end - t_start) + + # reimplement the block below + costs.append(jnp.asarray(cost).flatten()) + + """ + I do not understand this block anymore. The cost can be an + array or a scalar, but the scalar can not be reshaped. + if hasattr(cost, "__len__"): + costs.append(cost) + else: + costs.append(cost.reshape(1)) + """ + costs = jnp.concatenate(costs, axis=0) + else: + for i, cost_func in enumerate(self.cost_funcs): + if type(cost_func).__name__ == "OperationInfidelityJAXSpecial" or type(cost_func).__name__ == "OperationNoiseInfidelityJAXSpecial" : + cost = cost_func.costs(pulse[0][-1]) + else: + raise RuntimeError + + costs.append(jnp.asarray(cost).flatten()) + """ + if hasattr(cost, "__len__"): + costs.append(cost) + else: + costs.append(cost.reshape(1)) + """ + costs = jnp.concatenate(costs, axis=0) + + return jnp.asarray(costs) + + def wrapped_jac_function(self, pulse=None): + """ + Wraps the gradient calculation functions of the fidelity computer. + + Parameters + ---------- + pulse: numpy array, optional + shape: (num_t, num_ctrl) If no pulse is specified the cost function + is evaluated for the attribute pulse. + + Returns + ------- + jac: numpy array + Array of gradients of shape (num_t, num_func, num_amp). + """ + + if self.numeric_jacobian: + return self.numeric_gradient(pulse=pulse) + + if pulse is None: + pulse = self.pulse + + for solver in self.solvers: + solver.set_optimization_parameters(pulse) + + jacobians = [] + + record_evaluation_times = bool(self.stats) + + if record_evaluation_times: + self.stats.grad_func_eval_times.append([]) + + for i, cost_func in enumerate(self.cost_funcs): + if record_evaluation_times: + t_start = time.time() + + if type(cost_func).__name__ == "OperationInfidelityJAXSpecial": + jac_u = cost_func.grad(pulse[0][-1]) + else: + raise RuntimeError + + + # if the cost function is scalar, an extra dimension is inserted + if len(jac_u.shape) == 2: + jac_u = jnp.expand_dims(jac_u, axis=1) + + # apply the chain rule to the derivatives + jac_x = cost_func.solver.amplitude_function.derivative_by_chain_rule( + jac_u, cost_func.solver.transfer_function(pulse)) + jac_x_transferred = \ + cost_func.solver.transfer_function.gradient_chain_rule( + jac_x + ) + + if type(cost_func).__name__ == "OperationInfidelityJAXSpecial": + ### jac_x_transferred=jac_x_transferred.at[0,0,-1].set(jac_x_transferred.at[0,0,-1] + cost_func.der_time_fact(pulse[0][-1])) + jac_x_transferred[0,0,-1] += cost_func.der_time_fact(pulse[0][-1]) + + jacobians.append(jac_x_transferred) + if record_evaluation_times: + t_end = time.time() + self.stats.grad_func_eval_times[-1].append(t_end - t_start) + + # two dimensional form as required by scipy solvers + total_jac = jnp.concatenate(jacobians, axis=1) + + return total_jac diff --git a/qopt/solver_algorithms.py b/qopt/solver_algorithms.py index 5ce5f68..7f3eaf0 100644 --- a/qopt/solver_algorithms.py +++ b/qopt/solver_algorithms.py @@ -78,6 +78,7 @@ from qopt.amplitude_functions import AmplitudeFunction, IdentityAmpFunc from qopt.util import needs_refactoring +from jax import grad, jit class Solver(ABC): r""" @@ -494,7 +495,8 @@ def forward_propagators(self) -> List[q_mat.OperatorMatrix]: """ if self._fwd_prop is None: - self._compute_forward_propagation() + # self._compute_forward_propagation() + jit(self._compute_forward_propagation)() return self._fwd_prop @property @@ -510,7 +512,8 @@ def frechet_deriv_propagators(self) -> List[List[q_mat.OperatorMatrix]]: """ if self._derivative_prop is None: - self._compute_propagation_derivatives() + # self._compute_propagation_derivatives() + jit(self._compute_propagation_derivatives)() return self._derivative_prop @property @@ -870,7 +873,7 @@ def __init__(self, self.frechet_deriv_approx_method = frechet_deriv_approx_method self._dyn_gen = None - + def set_optimization_parameters(self, y: np.array) -> None: """See base class. """ if not np.array_equal(self._opt_pars, y): @@ -1210,7 +1213,8 @@ def forward_propagators_noise(self) -> List[List[q_mat.OperatorMatrix]]: """ if self._fwd_prop_noise is None: - self._compute_forward_propagation() + # self._compute_forward_propagation() + jit(self._compute_forward_propagation)() return self._fwd_prop_noise @property @@ -1228,7 +1232,8 @@ def frechet_deriv_propagators_noise(self) \ """ if self._derivative_prop_noise is None: - self._compute_propagation_derivatives() + # self._compute_propagation_derivatives() + jit(self._compute_propagation_derivatives)() return self._derivative_prop_noise @property @@ -1525,7 +1530,7 @@ def noise_amplitude_function(noise_samples: np.array, Parameters ---------- - noise_samples: np.array, shape() + noise_samples: np.array Noise samples calculated by the noise trace generator. transferred_parameters: np.array @@ -1535,11 +1540,7 @@ def noise_amplitude_function(noise_samples: np.array, Control amplitudes. """ - # noise_amplitudes = np.zeros_like(noise_samples, dtype=complex) - noise_amplitudes = np.zeros( - (noise_samples.shape[0], noise_samples.shape[1], - control_amplitudes.shape[1]), dtype=complex) - + noise_amplitudes = np.zeros((noise_samples.shape[0],noise_samples.shape[1],control_amplitudes.shape[1]), dtype=complex) # complex values were requested. for trace_num in range(noise_samples.shape[1]): noise_amplitudes[:, trace_num, :] = self.amplitude_function( @@ -1872,6 +1873,7 @@ def reset_cached_propagators(self): self._diss_sup_op = None self._diss_sup_op_deriv = None + def _calc_diss_sup_op(self) -> List[q_mat.OperatorMatrix]: r""" Calculates the dissipative super operator as described in the class @@ -2195,3 +2197,1377 @@ def _compute_propagation(self): self._fwd.append(self._prop[t] * self._fwd[t]) self.prop_calculated = True + + +############################################################################### + +try: + import jax.numpy as jnp + from jax import jit, vmap + import jax + _HAS_JAX = True +except ImportError: + from unittest import mock + jit = mock.Mock() + jnp = mock.Mock() + vmap = mock.Mock() + jax = mock.Mock() + _HAS_JAX = False + +def _compute_propagation_expm_both_loop(transferred_time,dyn_gen, + derivative_directions): + """Internal loop of exponentiation of propagator and derivative""" + return jax.scipy.linalg.expm_frechet( + dyn_gen*transferred_time, + derivative_directions*transferred_time, + compute_expm=True) + +#from profiling with simple optimization example +#(could be different for complex problems): +#here all the runtime is (probably) in, but no faster way seems available +@jit +def _compute_propagation_expm_both(transferred_time,dyn_gen, + derivative_directions): + """Exponentiation of propagator and derivative, n_ctrl&n_timesteps on + first two axes + """ + return vmap(vmap(_compute_propagation_expm_both_loop,in_axes=(0,0,None)), + in_axes=(None,None,0))( + transferred_time,dyn_gen,derivative_directions) + +@jit +def _compute_propagation_expm_both_lind(transferred_time,dyn_gen, + derivative_directions): + """Exponentiation of propagator and derivative in super-operator formalism, + n_ctrl&n_timesteps on first two axes + """ + return vmap(vmap(_compute_propagation_expm_both_loop,in_axes=(0,0,0)), + in_axes=(None,None,1))( + transferred_time,dyn_gen,derivative_directions) + +@jit +def _compute_propagation_expm_both_noise(transferred_time,dyn_gen_noise, + derivative_directions): + """Exponentiation of propagator and derivative for Monte-Carlo, + n_traces on first axis + """ + return vmap(_compute_propagation_expm_both,in_axes=(None,0,None))( + transferred_time,dyn_gen_noise,derivative_directions) + +def _compute_propagation_expm_loop(transferred_time,dyn_gen): + """Internal loop of exponentiation of propagator""" + return jax.scipy.linalg.expm(dyn_gen*transferred_time) + +#if no derivatives runtime also here +@jit +def _compute_propagation_expm(transferred_time,dyn_gen): + """Exponentiation of propagator, n_ctrl&n_timesteps on first two axes""" + return vmap(_compute_propagation_expm_loop,in_axes=(0,0))( + transferred_time,dyn_gen) + +@jit +def _compute_propagation_expm_noise(transferred_time,dyn_gen_noise): + """Exponentiation of propagator for Monte-Carlo, n_traces on first axis""" + return vmap(_compute_propagation_expm,in_axes=(None,0))( + transferred_time,dyn_gen_noise) + +def _cumprod_loop(res,el): + """Internal loop of cumulative product of propagators""" + res = jnp.dot(el,res) + return res,res + +@jit +def _cumprod(init,prop): + """Cumulative product of propagators of single timesteps""" + _, cum_prod = jax.lax.scan(_cumprod_loop,init,prop) + return cum_prod + +@jit +def _cumprod_noise(init,prop_noise): + """Cumulative product of propagators of single timesteps for Monte-Carlo""" + return vmap(_cumprod,in_axes=(None,0))(init,prop_noise) + +def _cumprod_reversed_loop(res,el): + """Internal loop of reversed cumulative product of propagators""" + res = jnp.dot(res,el) + return res,res + +@jit +def _cumprod_reversed(init,prop): + """Reversed cumulative product of propagators of single timesteps""" + _, cum_prod = jax.lax.scan(_cumprod_reversed_loop,init,prop) + return cum_prod + +@jit +def _cumprod_reversed_noise(init,prop_noise): + """Reversed cumulative product of propagators of single timesteps for MC""" + return vmap(_cumprod_reversed,in_axes=(None,0))(init,prop_noise) + + +class SolverJAX(Solver): + """See docstring of class w/o JAX.""" + def __init__( + self, + h_ctrl: List[q_mat.OperatorMatrix], + h_drift: List[q_mat.OperatorMatrix], + tau: np.array, + initial_state: q_mat.OperatorMatrix = None, + opt_pars: Optional[Union[jnp.ndarray,np.ndarray]] = None, + ctrl_amps: Optional[Union[jnp.ndarray,np.ndarray]] = None, + filter_function_h_n: Union[ + Callable, List[List], None] = None, + filter_function_basis: Optional[basis.Basis] = None, + filter_function_n_coeffs_deriv: Optional[ + Callable[[np.ndarray], np.ndarray]] = None, + exponential_method: Optional[str] = None, + is_skew_hermitian: bool = True, + transfer_function: Optional[TransferFunction] = None, + amplitude_function: Optional[AmplitudeFunction] = None, + paranoia_level: int = 2 + ): + if not _HAS_JAX: + raise ImportError("JAX not available") + super().__init__( + h_ctrl, + h_drift, + tau, + initial_state, + opt_pars, + ctrl_amps, + filter_function_h_n, + filter_function_basis, + filter_function_n_coeffs_deriv, + exponential_method, + is_skew_hermitian, + transfer_function, + amplitude_function, + paranoia_level) + + if type(h_drift) in [matrix.DenseOperator, matrix.SparseOperator, + matrix.DenseOperatorJAX]: + self._h_drift_jnp = jnp.expand_dims(h_drift.data,0) + self.h_drift = [h_drift, ] * self.transfer_function.num_x + elif len(h_drift) == 1: + self._h_drift_jnp = jnp.expand_dims(h_drift[0].data,0) + self.h_drift = h_drift * self.transfer_function.num_x + else: + self._h_drift_jnp = jnp.array([h.data for h in h_drift]) + self.h_drift = h_drift + + self._h_ctrl_jnp = jnp.array([h.data for h in h_ctrl]) + self._transferred_time_jnp = jnp.array(self.transferred_time) + if initial_state is None: + dim = self.h_ctrl[0].shape[0] + self.initial_state = matrix.DenseOperatorJAX(jnp.eye(dim)) + else: + self.initial_state = matrix.DenseOperatorJAX(initial_state) + self._initial_state_jnp = self.initial_state.data + + self._prop_jnp = None + self._reversed_prop_jnp = None + self._fwd_prop_jnp = None + self._derivative_prop_jnp = None + + + def set_optimization_parameters(self, y: Union[jnp.ndarray,np.ndarray] + ) -> None: + """ + Set the control amplitudes. + + All computation flags are set to false. + + The new control amplitudes u are calculated: + u: np.array, shape (num_t, num_ctrl) + + Parameters + ---------- + y: Union[jnp.ndarray,np.ndarray], shape (num_x, num_ctrl) + Raw optimization parameters. + + """ + + if jnp.array_equal(self._opt_pars, y): + return + else: + #previously with copy (?) + self._opt_pars = y + + if self.transfer_function is not None: + self.transferred_parameters = self.transfer_function(y) + else: + #previously with copy (?) + self.transferred_parameters = y + + if self.amplitude_function is not None: + u = self.amplitude_function( + self.transferred_parameters) + else: + u = self.transferred_parameters + + if len(u.shape) != 2: + raise ValueError('The new control amplitudes must have two ' + 'dimensions! ' + '(time, control operator)') + + if u.shape[0] != len(self.transferred_time): + raise ValueError('The new control amplitudes do not have the ' + 'correct number of entries on the time axis!'+ + str(u.shape[0])+" "+str(len(self.transferred_time))) + + if u.shape[1] != len(self.h_ctrl): + raise ValueError('The new control amplitudes do not have the ' + 'correnct number of entries on the control axis!') + + self._ctrl_amps = u + self.reset_cached_propagators() + + def reset_cached_propagators(self): + """ Resets all cached propagators. """ + + self._prop = None #perhaps nonexistent? + self._fwd_prop = None + self._derivative_prop = None + self._reversed_prop = None + self.pulse_sequence = None + + self._prop_jnp = None + self._reversed_prop_jnp = None + self._fwd_prop_jnp = None + self._derivative_prop_jnp = None + + @property + def forward_propagators_jnp(self) -> jnp.ndarray: + + if self._fwd_prop_jnp is None: + self._compute_forward_propagation_jnp() + return self._fwd_prop_jnp + + @property + def reversed_propagators_jnp(self) -> jnp.ndarray: + + if self._reversed_prop_jnp is None: + self._compute_reversed_propagation_jnp() + return self._reversed_prop_jnp + + @property + def frechet_deriv_propagators_jnp(self) -> jnp.ndarray: + + if self._derivative_prop_jnp is None: + self._compute_propagation_derivatives_jnp() + return self._derivative_prop_jnp + + @abstractmethod + def _compute_propagation(self) -> None: + if self._prop_jnp is None: + self._compute_propagation_jnp() + + self._prop = [matrix.DenseOperatorJAX(p) for p in self._prop_jnp] + + def _compute_forward_propagation(self) -> None: + """Computes the forward propagators. """ + + + self._fwd_prop = [matrix.DenseOperatorJAX(p) + for p in self.forward_propagators_jnp] + + def _compute_reversed_propagation(self) -> None: + """Compute the reversed propagation. """ + + self._reversed_prop = [matrix.DenseOperatorJAX(p) + for p in self.reversed_propagators_jnp] + + @abstractmethod + def _compute_propagation_derivatives(self) -> None: + + if self._derivative_prop_jnp is None: + self._compute_propagation_derivatives_jnp() + + self._derivative_prop = [[matrix.DenseOperatorJAX(p) for p in der_t] + for der_t in self._derivative_prop_jnp] + + def _compute_forward_propagation_jnp(self) -> None: + + if self._prop_jnp is None: + self._compute_propagation_jnp() + + self._fwd_prop_jnp = jnp.append( + jnp.expand_dims(self._initial_state_jnp.copy(),0), + _cumprod(self._initial_state_jnp.copy(),self._prop_jnp),axis=0) + + def _compute_reversed_propagation_jnp(self) -> None: + + if self._prop_jnp is None: + self._compute_propagation_jnp() + + _initial_state_rev_jnp = jnp.eye(self._prop_jnp[0].shape[0]) * (1+0j) + + self._reversed_prop_jnp = jnp.append( + jnp.expand_dims(_initial_state_rev_jnp.copy(),0), + _cumprod_reversed(_initial_state_rev_jnp.copy(), + self._prop_jnp[::-1]),axis=0) + + @abstractmethod + def _compute_propagation_jnp(self) -> None: + """ + Computes the propagators. Must set self._prop! + + Raises + ------ + ValueError + If the control amplitudes are not set. + + """ + if self._ctrl_amps is None: + raise ValueError("The control amplitudes must be set to calculate " + "the propagation!") + + @abstractmethod + def _compute_propagation_derivatives_jnp(self) -> None: + """Compute the derivatives of the propagators by the control + amplitudes. + """ + pass + + def _calc_error(self): + + if self._dyn_gen is None: + self._dyn_gen = self._compute_dyn_gen() + + return (self._transferred_time_jnp[0])**2/2*jnp.linalg.norm( + self._dyn_gen[1:]@self._dyn_gen[:-1] + -self._dyn_gen[:-1]@self._dyn_gen[1:],axis=(1,2)) + + +class SchroedingerSolverJAX(SolverJAX): + """See docstring of class w/o JAX.""" + + def __init__(self, + h_drift: List[q_mat.OperatorMatrix], + h_ctrl: List[q_mat.OperatorMatrix], + tau: Union[jnp.array,np.array], + initial_state: q_mat.OperatorMatrix = None, + ctrl_amps: Optional[Union[jnp.array,np.array]] = None, + calculate_propagator_derivatives: bool = True, + filter_function_h_n: Union[ + Callable, List[List], None] = None, + filter_function_basis: Optional[basis.Basis] = None, + filter_function_n_coeffs_deriv: Optional[ + Callable[[np.ndarray], Union[jnp.array,np.array]]] = None, + exponential_method: Optional[str] = None, + frechet_deriv_approx_method: Optional[str] = None, + is_skew_hermitian: bool = True, + transfer_function: Optional[TransferFunction] = None, + amplitude_function: Optional[AmplitudeFunction] = None): + super().__init__( + h_drift=h_drift, h_ctrl=h_ctrl, initial_state=initial_state, + tau=tau, ctrl_amps=ctrl_amps, + filter_function_h_n=filter_function_h_n, + filter_function_basis=filter_function_basis, + filter_function_n_coeffs_deriv=filter_function_n_coeffs_deriv, + exponential_method=exponential_method, + is_skew_hermitian=is_skew_hermitian, + transfer_function=transfer_function, + amplitude_function=amplitude_function + ) + + + if self.exponential_method != "Frechet": + print("Other than Frechet ignored") + + self.id_text = 'ALL' + self.cache_text = 'Save' + self.calculate_propagator_derivatives = \ + calculate_propagator_derivatives + self.frechet_deriv_approx_method = frechet_deriv_approx_method + + self._dyn_gen = None + + + def set_optimization_parameters(self, y: Union[np.ndarray,jnp.ndarray]) -> None: + """See base class. """ + if not jnp.array_equal(self._opt_pars, y): + self.reset_cached_propagators() + super().set_optimization_parameters(y) + + def reset_cached_propagators(self): + """See base class. """ + self._dyn_gen = None + super().reset_cached_propagators() + + def _compute_dyn_gen(self) -> jnp.ndarray: + """ + Computes the dynamics generators. + + Returns + ------- + dyn_gen: List[ControlMatrix], len num_t + This is basically the total Hamiltonian. + + """ + + self._dyn_gen = -1j*(self._h_drift_jnp+jnp.einsum("ij,jkl->ikl", + self._ctrl_amps, + self._h_ctrl_jnp)) + #internally now only jax tensors? + return self._dyn_gen + + def _compute_derivative_directions( + self) -> jnp.ndarray: + """ + The directions of the frechet derivatives are the control operators. + + No deep copy is required because the result is not used for in-place + operations. + + """ + # The list is multiplied (copied by reference) because the elements + # will not be manipulated in place. (only as copy) + return -1j*jnp.expand_dims(self._h_ctrl_jnp,0) + + def _compute_propagation(self) -> None: + super()._compute_propagation() + + def _compute_propagation_derivatives(self) -> None: + super()._compute_propagation_derivatives() + + def _compute_propagation_jnp( + self, calculate_propagator_derivatives: Optional[bool] = None) \ + -> None: + """See base class. """ + super()._compute_propagation_jnp() + + if self._dyn_gen is None: + self._dyn_gen = self._compute_dyn_gen() + + if calculate_propagator_derivatives is None: + calculate_propagator_derivatives = \ + self.calculate_propagator_derivatives + + if calculate_propagator_derivatives: + derivative_directions = self._compute_derivative_directions() + + #TODO: behavior is not exactly reproduced as now + # derivative_directions[0] is taken; however only relevant for + #LindbladSolver (in special cases) (?) + self._prop_jnp,self._derivative_prop_jnp = \ + _compute_propagation_expm_both( + self._transferred_time_jnp, + self._dyn_gen,derivative_directions[0]) + self._prop_jnp = self._prop_jnp[0,:,:,:] + + else: + self._prop_jnp = _compute_propagation_expm( + self._transferred_time_jnp,self._dyn_gen) + + + def _compute_propagation_derivatives_jnp(self) -> None: + """ + Computes the frechet derivatives of the propagators. + + The derivatives are not returned but cached. Since the function is only + called when no derivatives are cached, the approximation is + prioritised. + """ + if not self.frechet_deriv_approx_method: + self._compute_propagation_jnp( + calculate_propagator_derivatives=True) + elif self.frechet_deriv_approx_method == 'grape': + if self._prop_jnp is None: + self._compute_propagation_jnp( + calculate_propagator_derivatives=False) + + self._derivative_prop_jnp = jnp.swapaxes( + jnp.expand_dims(self._transferred_time_jnp,(1,2,3))* + self._compute_derivative_directions()@ + jnp.expand_dims(self._prop_jnp,axis=1),0,1) + + else: + raise ValueError('Unknown gradient derivative approximation ' + 'method:' + + str(self.frechet_deriv_approx_method)) + + +class SchroedingerSMonteCarloJAX(SchroedingerSolverJAX): + """See docstring of class w/o JAX.""" + def __init__( + self, h_drift: List[q_mat.OperatorMatrix], + h_ctrl: List[q_mat.OperatorMatrix], + tau: Union[jnp.array,np.array], + h_noise: List[q_mat.OperatorMatrix], + noise_trace_generator: + Optional[noise.NoiseTraceGenerator], + initial_state: q_mat.OperatorMatrix = None, + ctrl_amps: Optional[Union[jnp.array,np.array]] = None, + calculate_propagator_derivatives: bool = False, + processes: Optional[int] = 1, + filter_function_h_n: Union[ + Callable, List[List], None] = None, + filter_function_basis: Optional[basis.Basis] = None, + filter_function_n_coeffs_deriv: Optional[ + Callable[[np.ndarray], np.ndarray]] = None, + exponential_method: Optional[str] = None, + frechet_deriv_approx_method: Optional[str] = None, + is_skew_hermitian: bool = True, + transfer_function: Optional[TransferFunction] = None, + amplitude_function: Optional[AmplitudeFunction] = None, + noise_amplitude_function: Optional[Callable[ + [np.array, np.array, np.array, + np.array], np.array]] = None + ): + + super().__init__( + h_drift=h_drift, h_ctrl=h_ctrl, initial_state=initial_state, + tau=tau, ctrl_amps=ctrl_amps, + filter_function_h_n=filter_function_h_n, + filter_function_basis=filter_function_basis, + filter_function_n_coeffs_deriv=filter_function_n_coeffs_deriv, + exponential_method=exponential_method, + calculate_propagator_derivatives=calculate_propagator_derivatives, + frechet_deriv_approx_method=frechet_deriv_approx_method, + is_skew_hermitian=is_skew_hermitian, + transfer_function=transfer_function, + amplitude_function=amplitude_function) + + self.h_noise = h_noise + self._h_noise_jnp = jnp.array([h.data for h in h_noise]) + self.noise_trace_generator = noise_trace_generator + self.noise_amplitude_function = noise_amplitude_function + self.processes = processes + + self._dyn_gen_noise = None + self._prop_noise = None + self._derivative_prop_noise = None + self._fwd_prop_noise = None + self._reversed_prop_noise = None + + self._prop_noise_jnp = None + self._derivative_prop_noise_jnp = None + self._fwd_prop_noise_jnp = None + self._reversed_prop_noise_jnp = None + + def set_optimization_parameters(self, + y: Union[np.ndarray,jnp.ndarray] + ) -> None: + """See base class. """ + if not jnp.array_equal(self._opt_pars, y): + self.reset_cached_propagators() + super().set_optimization_parameters(y) + + def reset_cached_propagators(self): + """See base class. """ + super().reset_cached_propagators() + self._dyn_gen_noise = None + self._prop_noise = None + self._prop_noise_jnp = None + self._derivative_prop_noise = None + self._derivative_prop_noise_jnp = None + self._fwd_prop_noise = None + self._reversed_prop_noise = None + self._fwd_prop_noise_jnp = None + self._reversed_prop_noise_jnp = None + + + @property + def propagators_noise(self) -> List[List[q_mat.OperatorMatrix]]: + """ + Returns the propagators of the system for each noise trace and + calculates them if necessary. + + Returns + ------- + propagators_noise: List[List[ControlMatrix]], + shape [[] * num_t] * num_noise_traces + Propagators of the system for each noise trace. + + """ + if self._prop_noise is None: + self._compute_propagation() + return self._prop_noise + + @property + def propagators_noise_jnp(self) -> jnp.ndarray: + """See docstring of function without _jnp. Now as jnp-array.""" + if self._prop_noise_jnp is None: + self._compute_propagation_jnp() + return self._prop_noise_jnp + + @property + def forward_propagators_noise(self) -> List[List[q_mat.OperatorMatrix]]: + """ + Returns the forward propagation of the initial state for every time + slice and every noise trace and calculate it if necessary. If the + initial state is the identity matrix, then the cumulative propagators + are given. The element forward_propagators[k][i] propagates a state by + the first i time steps under the kth noise trace, if the initial state + is the identity matrix. + + Returns + ------- + forward_propagation:List[List[ControlMatrix]], + shape [[] * (num_t + 1)] * num_noise_traces + Propagation of the initial state of the system. fwd[0] gives the + initial state itself. + + """ + if self._fwd_prop_noise is None: + self._compute_forward_propagation() + return self._fwd_prop_noise + + @property + def forward_propagators_noise_jnp(self) -> jnp.ndarray: + """See docstring of function without _jnp. Now as jnp-array.""" + if self._fwd_prop_noise_jnp is None: + self._compute_forward_propagation_jnp() + return self._fwd_prop_noise_jnp + + @property + def frechet_deriv_propagators_noise(self) \ + -> List[List[List[q_mat.OperatorMatrix]]]: + """ + Returns the frechet derivatives of the propagators with respect to the + control amplitudes for each noise trace. + + Returns + ------- + derivative_prop_noise: List[List[List[ControlMatrix]]], + shape [[[] * num_t] * num_ctrl] * num_noise_traces + Frechet derivatives of the propagators by the control amplitudes. + + """ + if self._derivative_prop_noise is None: + self._compute_propagation_derivatives() + return self._derivative_prop_noise + + @property + def frechet_deriv_propagators_noise_jnp(self) -> jnp.ndarray: + """See docstring of function without _jnp. Now as jnp-array.""" + if self._derivative_prop_noise_jnp is None: + self._compute_propagation_derivatives_jnp() + return self._derivative_prop_noise_jnp + + @property + def reversed_propagators_noise(self) -> List[List[q_mat.OperatorMatrix]]: + """ + Returns the reversed propagation of the initial state for every noise + trace and calculate it if necessary. If the initial state is the + identity matrix, then the reversed cumulative propagators are given. + The element forward_propagators[k][i] propagates a state by the first i + time steps under the kth noise trace, if the initial state is the + identity matrix. + + Returns + ------- + reversed_propagation_noise: List[List[ControlMatrix]], + shape [[] * (num_t + 1)] * num_noise_traces + Propagation of the initial state of the system. reversed[k][0] + gives the initial state itself. + + """ + if self._reversed_prop_noise is None: + self._compute_reversed_propagation() + return self._reversed_prop_noise + + @property + def reversed_propagators_noise_jnp(self) -> jnp.ndarray: + """See docstring of function without _jnp. Now as jnp-array.""" + if self._reversed_prop_noise_jnp is None: + self._compute_reversed_propagation_jnp() + return self._reversed_prop_noise_jnp + + def _compute_dyn_gen_noise(self) -> jnp.ndarray: + """ + Computes the dynamics generators for the perturbed and unperturbed + Schroedinger equation. + + Returns + ------- + dyn_gen_noise: List[List[q_mat.ControlMatrix]], + shape [[] * num_t] * num_noise_traces + Dynamics generators for each noise trace. + + """ + # compute the generators of the unperturbed dynamics + self._dyn_gen = super()._compute_dyn_gen() + + # compute the generators for the noise traces. + # n_noise_traces = self.noise_trace_generator.n_traces + + noise_samples = jnp.array(self.noise_trace_generator.noise_samples) + # we transpose, so we iterate over the time last + noise_samples = jnp.transpose(noise_samples, (2, 1, 0)) + + if self.noise_amplitude_function: + noise_samples = self.noise_amplitude_function( + noise_samples=noise_samples, + optimization_parameters=self._opt_pars, + transferred_parameters=self.transferred_parameters, + control_amplitudes=self._ctrl_amps + ) + + # i: n_samples_per_trace, j: n_traces, k: n_noise_ops, + # l: first dim of ham, m: second dim of ham + self._dyn_gen_noise = jnp.expand_dims(self._dyn_gen,axis=0) \ + - 1j*(jnp.einsum("ijk,klm->jilm",noise_samples,self._h_noise_jnp)) + + # -> (n_traces,n_samples_per_trace || t ?,d,d) + return self._dyn_gen_noise + + def _compute_propagation(self) -> None: + """ + Computes the propagators for the perturbed Schroedinger equation and + the derivatives on demand. + + Parameters + ---------- + calculate_propagator_derivatives: bool, optional + Calculate the derivatives of the propagators with respect to the + control amplitudes if true. + + """ + super()._compute_propagation() + self._prop_noise = [[matrix.DenseOperatorJAX(p) for p in trace] + for trace in self.propagators_noise_jnp] + + def _compute_propagation_derivatives(self) -> None: + """Computes propagator derivatives.""" + super()._compute_propagation_derivatives() + if self._derivative_prop_noise_jnp is None: + self._compute_propagation_derivatives_jnp() + + self._derivative_prop_noise_jnp = \ + [[[matrix.DenseOperatorJAX(p) for p in ctrl] for ctrl in der_t] + for der_t in self._derivative_prop_noise_jnp] + + + def _compute_propagation_jnp( + self, calculate_propagator_derivatives: Optional[bool] = None + ) -> None: + """See docstring of function without _jnp. Now as jnp-array.""" + + if self._dyn_gen_noise is None: + self._dyn_gen_noise = self._compute_dyn_gen_noise() + + if calculate_propagator_derivatives is None: + calculate_propagator_derivatives = \ + self.calculate_propagator_derivatives + + # parallelization of following code probably unnecessary + if calculate_propagator_derivatives: + + derivative_directions = self._compute_derivative_directions() + + # call the parent method for the noiseless propagators + super()._compute_propagation_jnp( + calculate_propagator_derivatives=calculate_propagator_derivatives) + + if self.processes == 1: + if calculate_propagator_derivatives: + + self._prop_noise_jnp, self._derivative_prop_noise_jnp = \ + _compute_propagation_expm_both_noise( + self._transferred_time_jnp, + self._dyn_gen_noise, + derivative_directions[0]) + self._prop_noise_jnp = self._prop_noise_jnp[:,0,:,:,:] + else: + + self._prop_noise_jnp = _compute_propagation_expm_noise( + self._transferred_time_jnp,self._dyn_gen_noise) + + elif (type(self.processes) == int and self.processes > 0) \ + or self.processes is None: + + raise NotImplementedError("No pool-multiprocess with jax calc, \ + (TODO) perhaps add with pmap (?)") + + + else: + raise ValueError('Invalid number of processes for parallel ' + 'computation!') + + def _compute_forward_propagation_jnp(self) -> None: + """Computes the forward propagators. """ + super()._compute_forward_propagation_jnp() + if self._prop_noise_jnp is None: + self._compute_propagation_jnp() + + cum_prop_noise = _cumprod_noise(self._initial_state_jnp.copy(), + self._prop_noise_jnp) + sh = cum_prop_noise.shape + + self._fwd_prop_noise_jnp = jnp.append(jnp.broadcast_to( + self._initial_state_jnp.copy(),(sh[0],1,*sh[2:])), + cum_prop_noise,axis=1) + + def _compute_forward_propagation(self) -> None: + """Computes the forward propagators. """ + super()._compute_forward_propagation() + + self._fwd_prop_noise = [[matrix.DenseOperatorJAX(p) for p in trace] + for trace in self.forward_propagators_noise_jnp] + + def _compute_reversed_propagation_jnp(self) -> None: + """Compute the reversed propagation. For the perturbed and unperturbed + Schroedinger equation. """ + super()._compute_reversed_propagation_jnp() + if self._prop_noise_jnp is None: + self._compute_propagation_jnp() + + _initial_state_rev_jnp = jnp.eye(self._prop_jnp[0].shape[0]) * (1+0j) + + cum_prop_reversed_noise = _cumprod_reversed_noise( + _initial_state_rev_jnp,self._prop_noise_jnp[::-1]) + + sh = cum_prop_reversed_noise.shape + + self._reversed_prop_noise_jnp = jnp.append( + jnp.broadcast_to(_initial_state_rev_jnp,(sh[0],1,*sh[2:])), + cum_prop_reversed_noise,axis=1) + + + def _compute_reversed_propagation(self) -> None: + """Compute the reversed propagation. For the perturbed and unperturbed + Schroedinger equation. """ + super()._compute_reversed_propagation() + + self._reversed_prop_noise = \ + [[matrix.DenseOperatorJAX(p) for p in trace] + for trace in self.reversed_propagators_noise_jnp] + + + def _compute_propagation_derivatives_jnp(self) -> None: + """ + Computes the frechet derivatives of the propagators. + + The derivatives are not returned but cached. Since the function is only + called when no derivatives are cached, the approximation is + prioritised. + """ + if not self.frechet_deriv_approx_method: + self._compute_propagation_jnp(calculate_propagator_derivatives=True) + + elif self.frechet_deriv_approx_method == 'grape': + super()._compute_propagation_derivatives_jnp() + + if self._prop_noise_jnp is None: + self._compute_propagation_jnp( + calculate_propagator_derivatives=False) + + derivative_directions = self._compute_derivative_directions() + + #broadcasting explicitly + self._derivative_prop_noise_jnp = \ + jnp.swapaxes( + jnp.expand_dims(self._transferred_time_jnp,(0,2,3,4))* + jnp.expand_dims(derivative_directions,0)@ + jnp.expand_dims(self._prop_noise_jnp,axis=2),1,2) + + else: + raise ValueError('Unknown gradient derivative approximation ' + 'method:' + + str(self.frechet_deriv_approx_method)) + + +class SchroedingerSMCControlNoiseJAX(SchroedingerSMonteCarloJAX): + """See docstring of class w/o JAX.""" + + def __init__( + self, + h_drift: List[q_mat.OperatorMatrix], + h_ctrl: List[q_mat.OperatorMatrix], + tau: Union[jnp.array,np.array], + noise_trace_generator: + Optional[noise.NoiseTraceGenerator], + initial_state: q_mat.OperatorMatrix = None, + ctrl_amps: Optional[np.array] = None, + calculate_propagator_derivatives: bool = False, + processes: Optional[int] = 1, + filter_function_h_n: Union[ + Callable, List[List], None] = None, + filter_function_basis: Optional[basis.Basis] = None, + filter_function_n_coeffs_deriv: Optional[ + Callable[[np.ndarray], np.ndarray]] = None, + exponential_method: Optional[str] = None, + frechet_deriv_approx_method: Optional[str] = None, + is_skew_hermitian: bool = True, + transfer_function: Optional[TransferFunction] = None, + amplitude_function: Optional[AmplitudeFunction] = None): + + def noise_amplitude_function( + noise_samples: Union[np.array,jnp.array], + transferred_parameters: Union[np.array,jnp.array], + control_amplitudes: Union[np.array,jnp.array], + **_): + """Calculates the noise amplitudes. + + Takes into account the actual optimization parameters and random + variations. + + Parameters + ---------- + noise_samples: np.array, shape() + Noise samples calculated by the noise trace generator. + + transferred_parameters: np.array + Transferred optimization parameters. + + control_amplitudes: np.array + Control amplitudes. + + """ + noise_amplitudes = jnp.zeros( + (noise_samples.shape[0], noise_samples.shape[1], + control_amplitudes.shape[1]), dtype=complex) + + + for trace_num in range(noise_samples.shape[1]): + #jnp cannot be updated in place + #->copy every time; inefficient in for loop? + noise_amplitudes = noise_amplitudes.at[:,trace_num,:].set(self.amplitude_function( + transferred_parameters + noise_samples[:, trace_num, :]) \ + - control_amplitudes) + return noise_amplitudes + + super().__init__( + h_drift=h_drift, + h_ctrl=h_ctrl, + initial_state=initial_state, + tau=tau, + h_noise=h_ctrl, + noise_trace_generator=noise_trace_generator, + ctrl_amps=ctrl_amps, + calculate_propagator_derivatives=calculate_propagator_derivatives, + processes=processes, + filter_function_h_n=filter_function_h_n, + filter_function_basis=filter_function_basis, + filter_function_n_coeffs_deriv=filter_function_n_coeffs_deriv, + exponential_method=exponential_method, + frechet_deriv_approx_method=frechet_deriv_approx_method, + is_skew_hermitian=is_skew_hermitian, + transfer_function=transfer_function, + amplitude_function=amplitude_function,) + + +class LindbladSolverJAX(SchroedingerSolverJAX): + """See docstring of class w/o JAX.""" + + def __init__( + self, + h_drift: List[q_mat.OperatorMatrix], + h_ctrl: List[q_mat.OperatorMatrix], + tau: np.array, + initial_state: q_mat.OperatorMatrix = None, + ctrl_amps: Optional[np.array] = None, + calculate_unitary_derivatives: bool = False, + filter_function_h_n: Union[ + Callable, List[List], None] = None, + filter_function_basis: Optional[basis.Basis] = None, + filter_function_n_coeffs_deriv: Optional[ + Callable[[np.ndarray], np.ndarray]] = None, + exponential_method: Optional[str] = None, + frechet_deriv_approx_method: Optional[str] = None, + initial_diss_super_op: List[q_mat.OperatorMatrix] = None, + lindblad_operators: List[q_mat.OperatorMatrix] = None, + prefactor_function: Callable[[np.array,np.array],np.array] = None, + prefactor_derivative_function: + Callable[[np.array, np.array], np.array] = None, + super_operator_function: + Callable[[np.array, np.array], List[q_mat.OperatorMatrix]] = None, + super_operator_derivative_function: + Callable[[np.array, np.array], + List[List[q_mat.OperatorMatrix]]] = None, + is_skew_hermitian: bool = False, + transfer_function: Optional[TransferFunction] = None, + amplitude_function: Optional[AmplitudeFunction] = None) \ + -> None: + + if initial_state is None: + dim = h_ctrl[0].shape[0] + initial_state = type(h_ctrl[0])(np.eye(dim ** 2)) + + self._diss_sup_op_jnp = None + self._diss_sup_op_deriv_jnp = None + + # we do not throw away any operators or functions, just in case + self._initial_diss_super_op = initial_diss_super_op + self._lindblad_operators = lindblad_operators + + self._prefactor_function = prefactor_function + self._prefactor_deriv_function = prefactor_derivative_function + self._sup_op_func = super_operator_function + self._sup_op_deriv_func = super_operator_derivative_function + self._is_hermitian = is_skew_hermitian + + super().__init__( + h_drift=h_drift, h_ctrl=h_ctrl, initial_state=initial_state, + tau=tau, ctrl_amps=ctrl_amps, + calculate_propagator_derivatives=calculate_unitary_derivatives, + filter_function_h_n=filter_function_h_n, + filter_function_basis=filter_function_basis, + filter_function_n_coeffs_deriv=filter_function_n_coeffs_deriv, + exponential_method=exponential_method, + frechet_deriv_approx_method=frechet_deriv_approx_method, + is_skew_hermitian=is_skew_hermitian, + transfer_function=transfer_function, + amplitude_function=amplitude_function) + + def set_optimization_parameters(self, y: Union[jnp.array,np.array] + ) -> None: + """See base class. """ + if not np.array_equal(self._opt_pars, y): + super().set_optimization_parameters(y) + self.reset_cached_propagators() + + def reset_cached_propagators(self): + """ See base class. """ + super().reset_cached_propagators() + if self._prefactor_function is not None \ + or self._sup_op_func is not None: + self._diss_sup_op_jnp = None + self._diss_sup_op_deriv_jnp = None + + def _calc_diss_sup_op_jnp(self) -> jnp.ndarray: + r""" + Calculates the dissipative super operator as described in the class + doc string. + + Returns + ------- + diss_sup_op: jnp.ndarray, len num_t + Dissipation super operator; Where num_t is the number of timesteps + """ + if self._sup_op_func is None: + # use Lindblad operators + if self._lindblad_operators is None: + # use dissipation_sup_op + const_diss_sup_op = self._initial_diss_super_op + else: + # Calculate the time constant dissipation super operators + # without time dependence + const_diss_sup_op = [] + identity = self._lindblad_operators[0].identity_like() + + for lindblad in self._lindblad_operators: + const_diss_sup_op.append( + (lindblad.conj(do_copy=True)).kron(lindblad)) + const_diss_sup_op[-1] -= .5 * identity.kron( + lindblad.dag(do_copy=True) * lindblad) + const_diss_sup_op[-1] -= .5 * ( + lindblad.transpose(do_copy=True) + * lindblad.conj(do_copy=True)).kron(identity) + + # Add the time dependence + if self._prefactor_function is not None: + self._diss_sup_op = [] + prefactors = self._prefactor_function( + copy.deepcopy(self._ctrl_amps), + copy.deepcopy(self.transferred_parameters)) + for factor_at_time_t in prefactors: + self._diss_sup_op.append( + const_diss_sup_op[0] * factor_at_time_t[0]) + for sup_op, factor \ + in zip(const_diss_sup_op[1:], + factor_at_time_t[1:]): + self._diss_sup_op[-1] += sup_op * factor + else: + self._diss_sup_op = [const_diss_sup_op[0], ] + for sup_op in const_diss_sup_op[1:]: + self._diss_sup_op[0] += sup_op + self._diss_sup_op *= len(self.transferred_time) + else: + self._diss_sup_op = self._sup_op_func( + copy.deepcopy(self._ctrl_amps), + copy.deepcopy(self.transferred_parameters)) + + if isinstance(self._diss_sup_op,jnp.ndarray) or isinstance(self._diss_sup_op,np.ndarray): + self._diss_sup_op_jnp = self._diss_sup_op + else: + self._diss_sup_op_jnp = jnp.array([l.data for l in self._diss_sup_op]) + del self._diss_sup_op + #would be complicated to rewrite as jnp cause many in-place assignments? + #not the most efficient, but ok if not insane amounts of lindblad ops? + return self._diss_sup_op_jnp + + def _calc_diss_sup_op_deriv_jnp(self) \ + -> Optional[jnp.ndarray]: + r""" + Calculates the derivatives of the dissipation super operator with + respect to the control amplitudes. + + If the dissipation super operator is given as constant (1.) or as + lindblad operators (2.) they are assumed not to depend on the control + parameters and only the derivative of the prefactor is to be taken into + account. In order to do so, a function handle containing the + derivatives must be given. This function receives the control + amplitudes as num_t x num_ctrl numpy array and returns the derivatives + as num_t x num_l x num_ctrl array. + + If the dissipation super operator is given as function handle (3.), + then the derivatives must also be given as function handle receiving + the control amplitudes and returning a nested list of super operators + as control matrices. + + If the requested derivative functions are not provided (None), then + the dissipation super operator is considered constant in the control + amplitudes and the function returns None. + + Returns + ------- + diss_sup_op_deriv: jnp.array + The derivatives of the dissipation super operator with respect to + the control variables. + + """ + if self._sup_op_deriv_func is not None: + self._diss_sup_op_deriv = \ + self._sup_op_deriv_func( + copy.deepcopy(self._ctrl_amps), + copy.deepcopy(self.transferred_parameters)) + + if isinstance(self._diss_sup_op_deriv,jnp.ndarray) or isinstance(self._diss_sup_op_deriv,np.ndarray): + self._diss_sup_op_deriv_jnp = self._diss_sup_op_deriv + else: + self._diss_sup_op_deriv_jnp = \ + jnp.array([[l.data for l in lm] + for lm in self._diss_sup_op_deriv]) + del self._diss_sup_op_deriv + return self._diss_sup_op_deriv_jnp + + elif self._prefactor_deriv_function is not None: + if self._lindblad_operators is None: + # use dissipation_sup_op + const_diss_sup_op = self._initial_diss_super_op + else: + # Calculate the time constant dissipation super operators + # without time dependence + const_diss_sup_op = [] + identity = self._lindblad_operators[0].identity_like() + + for lindblad in self._lindblad_operators: + const_diss_sup_op.append( + (lindblad.conj(do_copy=True)).kron(lindblad)) + const_diss_sup_op[-1] -= .5 * identity.kron( + lindblad.dag(do_copy=True) * lindblad) + const_diss_sup_op[-1] -= .5 * ( + lindblad.transpose(do_copy=True) + * lindblad.conj(do_copy=True)).kron(identity) + + prefactor_derivatives = \ + self._prefactor_deriv_function( + copy.deepcopy(self._ctrl_amps), + copy.deepcopy(self.transferred_parameters)) + + # Todo: Assert that the prefactor returns the right dimension + + # prefactor_derivatives: shape (num_t, num_ctrl, num_l) + diss_sup_op_deriv = [] + for factor_per_ctrl_lind in prefactor_derivatives: + # create new sub list for eacht time step + diss_sup_op_deriv.append([]) + for factor_per_lind in factor_per_ctrl_lind: + # add the first term for each control direction + diss_sup_op_deriv[-1].append( + const_diss_sup_op[0] * factor_per_lind[0]) + for diss_sup_op, factor in zip( + const_diss_sup_op[1:], factor_per_lind[1:]): + # add the remaining terms + diss_sup_op_deriv[-1][-1] += diss_sup_op * factor + + if isinstance(diss_sup_op_deriv,jnp.ndarray) or isinstance(diss_sup_op_deriv,np.ndarray): + self._diss_sup_op_deriv_jnp = diss_sup_op_deriv + else: + self._diss_sup_op_deriv_jnp = \ + jnp.array([[l.data for l in lm] for lm in diss_sup_op_deriv]) + + return self._diss_sup_op_deriv_jnp + + else: + return None + + def _compute_derivative_directions( + self) -> jnp.ndarray: + r""" + Computes the derivative directions of the total dynamics generator. + + Returns + ------- + deriv_directions: jnp.array + """ + + identity_times_i = -1j*jnp.identity(self._h_ctrl_jnp[0].shape[0]) + h_ctrl_sup_op_jnp = jnp.kron(identity_times_i,self._h_ctrl_jnp) \ + -jnp.kron(jnp.transpose(self._h_ctrl_jnp,(0,2,1)),identity_times_i) + + # add derivative of the dissipation part + if self._diss_sup_op_deriv_jnp is None: + self._diss_sup_op_deriv_jnp = self._calc_diss_sup_op_deriv_jnp() + if self._diss_sup_op_deriv_jnp is not None: + dh_by_ctrl = self._diss_sup_op_deriv_jnp + h_ctrl_sup_op_jnp + else: + dh_by_ctrl = jnp.broadcast_to(h_ctrl_sup_op_jnp, + self._transferred_time_jnp.shape \ + +h_ctrl_sup_op_jnp.shape) + + return dh_by_ctrl + + def _parse_dissipative_super_operator(self) -> None: + r""" + check the dissipative super operator for dimensional consistency + (maybe even physical properties) + - not implemented yet - + """ + pass + + def _compute_dyn_gen(self) -> jnp.ndarray: + r""" + Computes the dynamics generator for the Lindblad master equation. + + The Hamiltonian is translated into the master equation formalism as + + .. math:: + + \mathcal{H} = I \otimes H - H^\ast \otimes I + + Then the dissipation super operator is added. + + Returns + ------- + dyn_gen: jnp.array, len num_t + Dynamics generators for the master equation. + + Raises + ------ + ValueError: + The computation is only defined for the use of dense control + matrices. + + """ + self._dyn_gen = super()._compute_dyn_gen() + + if self._diss_sup_op_jnp is None: + self._diss_sup_op_jnp = self._calc_diss_sup_op_jnp() + + identity_operator = jnp.identity(self._dyn_gen[0].shape[0]) + sup_op_dyn_gen = [] + + assert(len(self._dyn_gen) == len(self._diss_sup_op_jnp)) + + sup_op_dyn_gen = jnp.kron(identity_operator,self._dyn_gen) \ + + jnp.kron(jnp.conj(self._dyn_gen),identity_operator) \ + + self._diss_sup_op_jnp + + self._dyn_gen = sup_op_dyn_gen + return sup_op_dyn_gen + + def _compute_propagation_jnp( + self, calculate_propagator_derivatives: Optional[bool] = None) \ + -> None: + """See base class. """ + super(SchroedingerSolverJAX,self)._compute_propagation_jnp() + + if self._dyn_gen is None: + self._dyn_gen = self._compute_dyn_gen() + + if calculate_propagator_derivatives is None: + calculate_propagator_derivatives = \ + self.calculate_propagator_derivatives + + if calculate_propagator_derivatives: + derivative_directions = self._compute_derivative_directions() + + #previously with derivative_directions[0] due to being + #time-constant in normal SchroedingerSolver; however in + #LindbladSolver is maybe not(?) + self._prop_jnp, self._derivative_prop_jnp = \ + _compute_propagation_expm_both_lind(self._transferred_time_jnp, + self._dyn_gen, + derivative_directions) + self._prop_jnp = self._prop_jnp[0,:,:,:] + + else: + self._prop_jnp = _compute_propagation_expm( + self._transferred_time_jnp, + self._dyn_gen) + + +class LindbladSControlNoiseJAX(LindbladSolverJAX): + """See docstring of class w/o JAX.""" + + @needs_refactoring + def __init__(self, h_drift, h_ctrl, initial_state, tau, + ctrl_amps, transfer_function=None, + calculate_unitary_derivatives=True, filter_function_h_n=None, + exponential_method=None, lindblad_operators=None, + constant_lindblad_operators=False, noise_psd=1): + super().__init__( + h_drift=h_drift, h_ctrl=h_ctrl, initial_state=initial_state, + tau=tau, ctrl_amps=ctrl_amps, + calculate_unitary_derivatives=calculate_unitary_derivatives, + filter_function_h_n=filter_function_h_n, + exponential_method=exponential_method) + + if lindblad_operators is None: + self.lindblad_super_operator = None + else: + d = lindblad_operators[0].shape[0] + self.lindblad_super_operator = np.zeros( + (len(lindblad_operators), d**2, d**2)) + for i, l in enumerate(lindblad_operators): + self.lindblad_super_operator[i, :, :] += np.kron(np.conj(l), l) + self.lindblad_super_operator[i, :, :] += -.5 * np.kron( + np.eye(d), l.T.conj() @ l) + self.lindblad_super_operator[i, :, :] += -.5 * np.kron( + l.T @ l.conj(), np.eye(d)) + + self.transfer_function = transfer_function + # if no transfer function is given it might be consider to be identity + # its not necessarily required + + self.constant_lindblad_operators = constant_lindblad_operators + self.noise_psd = noise_psd + self.incoherent_dyn_gen = None + + def _compute_propagation(self): + """Computes propagators.""" + # Compute and cache all dyn_gen (basically the total hamiltonian) + self._dyn_gen = self._h_drift_jnp + self._dyn_gen += jnp.sum(self._ctrl_amps * self._h_ctrl_jnp, axis=1) + + # super operator calculation + # this is the special case for charge noise on the control parameters + # the required filter function contains + if not self.constant_lindblad_operators or \ + self.incoherent_dyn_gen is None: + transfer_matrix = self.transfer_function.transfer_matrix + self.incoherent_dyn_gen = jnp.einsum('ijk,klm,k->ilm', + transfer_matrix, + self.lindblad_super_operator, + self.noise_psd) + dim = self._dyn_gen[0].shape[0] + identity_operator = jnp.identity(dim) + + self._dyn_gen = -1j*jnp.kron(identity_operator,self._dyn_gen) \ + -jnp.kron(self._dyn_gen,identity_operator) + self._dyn_gen += self.incoherent_dyn_gen + + # calculation of the propagators + # for t in range(len(self.num_t)): + if self.calculate_propagator_derivatives: + derivative_directions = jnp.kron( + identity_operator,self._h_ctrl_jnp) \ + -jnp.kron(self._h_ctrl_jnp,identity_operator) + self._prop_jnp, _derivative_prop_jnp = \ + _compute_propagation_expm_both_lind(self._transferred_time_jnp, + self._dyn_gen, + derivative_directions) + self._prop_jnp = self._prop_jnp[0,:,:,:] + #why this convention now? + self._dU = jnp.swapaxes(_derivative_prop_jnp,0,1) + + else: + self._prop_jnp = _compute_propagation_expm( + self._transferred_time_jnp,self._dyn_gen) + + self.prop_calculated = True + + \ No newline at end of file diff --git a/qopt/solver_algorithms_copy_original.py b/qopt/solver_algorithms_copy_original.py new file mode 100644 index 0000000..f138c77 --- /dev/null +++ b/qopt/solver_algorithms_copy_original.py @@ -0,0 +1,2105 @@ +# -*- coding: utf-8 -*- +# ============================================================================= +# qopt +# Copyright (C) 2020 Julian Teske, Forschungszentrum Juelich +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# +# Contact email: j.teske@fz-juelich.de +# ============================================================================= +""" Implements the algorithms to solve differential equations like +Schroedinger's equation or a master equation. + +The `Solver` class is the central piece of the actual simulation. It calculates +propagators from the differential equations describing the quantum dynamics. +The abstract base class inherits among other things an interface to the +`PulseSequence` class of the filter_functions package. + +The `Solver` classes can have an amplitude and a transfer function as attribute +and automate their use. The Monte Carlo solvers also hold an instance of a +noise trace generator. + +If requested, also derivatives of the propagators by the control amplitudes are +calculated or approximated. + +Classes +------- +:class:`Solver` + Abstract base class of the time slot computers. + +:class:`SchroedingerSolver` + Solver for the the unperturbed Schroedinger equation. + +:class:`SchroedingerSMonteCarlo` + Solver for the Schroedinger equation under the influence of noise. + +:class:`SchroedingerSMCControlNoise` + Solver for the Schroedinger equation under the influence of noise affecting + the control terms. + +:class:`LindbladSolver` + Solves the master equation in Lindblad form. + +Notes +----- +The implementation was inspired by the optimal control package of QuTiP [1]_ +(Quantum Toolbox in Python) + +References +---------- +.. [1] J. R. Johansson, P. D. Nation, and F. Nori: "QuTiP 2: A Python framework + for the dynamics of open quantum systems.", Comp. Phys. Comm. 184, 1234 + (2013) [DOI: 10.1016/j.cpc.2012.11.019]. + +""" + +import numpy as np +import copy +from typing import Optional, List, Callable, Union +from abc import ABC, abstractmethod +from multiprocessing import Pool + +from filter_functions import pulse_sequence, plotting, basis, numeric + +from qopt import noise, matrix, matrix as q_mat +from qopt.transfer_function import TransferFunction, IdentityTF +from qopt.amplitude_functions import AmplitudeFunction, IdentityAmpFunc +from qopt.util import needs_refactoring + + +class Solver(ABC): + r""" + Abstract base class for Solvers. + + Parameters + ---------- + h_ctrl: List[ControlMatrix], len num_ctrl + Control operators in the Hamiltonian as nested list of + shape n_t, num_ctrl. + + h_drift: List[ControlMatrix], len num_t or 1 + Drift operators in the Hamiltonian. You can either give a single element + or one for each transferred time step. + + initial_state : ControlMatrix + Initial state of the system as state vector. Can also be set to the + identity matrix. Then the forward propagation gives the total + propagator of the system. + + tau: array of float, shape (num_t, ) + Durations of the time slices. + + opt_pars: np.array, shape (num_y, num_par), optional + Raw optimization parameters. + + ctrl_amps: np.array, shape (num_t, num_ctrl), optional + The initial control amplitudes. + + filter_function_h_n: List[List[np.array]] or List[List[Qobj]] or callable + Nested list of noise Operators. Used in the filter function + formalism. _filter_function_h_n should look something like this: + + >>> H = [[n_oper1, n_coeff1, n_oper_identifier1], + >>> [n_oper2, n_coeff2, n_oper_identifier2], ...] + + The operators may be given either as NumPy arrays or QuTiP Qobjs + and each coefficient array should have the same number of elements + as *dt*, and should be given in units of :math:`\hbar`. If not every + sublist (read operator) was given a identifier, they are automatically + filled up with 'A_i' where i is the position of the operator. + Alternatively the create_ff_h_n may be a function handle creating + such an object when called with the optimization parameters. + + filter_function_basis: Basis, shape (d**2, d, d), optional + The operator basis in which to calculate. If a Generalized Gell-Mann + basis (see :meth:`~basis.Basis.ggm`) is chosen, some calculations will + be faster for large dimensions due to a simpler basis expansion. + However, when extending the pulse sequence to larger qubit registers, + cached filter functions cannot be retained since the GGM basis does not + factor into tensor products. In this case a Pauli basis is preferable. + + filter_function_n_coeffs_deriv: Callable numpy array to numpy array + This function calculates the derivatives of the noise susceptibility in + the filter function formalism. It receives the optimization parameters + as array of shape (num_opt, num_t) and returns the derivatives as array + of shape (num_noise_op, n_ctrl, num_t). + + exponential_method: string, optional + Method used by the ControlMatrix class for the calculation of the + matrix exponential. The default is 'Frechet'. See also the Docstring of + the file 'qopt.matrix'. + + is_skew_hermitian: bool + Only important for the exponential_method 'spectral'. If set to true, + the dynamical generator is assumed to be skew hermitian during the + spectral decomposition. + + transfer_function: TransferFunction + The transfer function for reshaping the optimization parameters. + + amplitude_function: AmplitudeFunction + The amplitude function connecting the transferred optimization + parameters to the control amplitudes. + + paranoia_level: int + The paranoia_level determines how many checks are conducted. + 0 No tests + 1 Some tests + 2 Exhaustive tests, dimension checks + + Attributes + ---------- + h_ctrl : List[ControlMatrix], len num_ctrl + Control operators in the Hamiltonian as list of length num_ctrl. + + h_drift : List[ControlMatrix], len num_t + Drift operators in the Hamiltonian. + + initial_state : ControlMatrix + Initial state of the system as state vector. Can also be set to the + identity matrix. Then the forward propagation gives the total + propagator of the system. + + transferred_time: List[float] + Durations of the time slices. + + filter_function_h_n: List[List[np.array]] or List[List[Qobj]] + Nested list of noise Operators. Used in the filter function + formalism. + + filter_function_basis: Basis + The filter function pulse sequence will be expressed in this basis. + See documentation of the filter function package. + + exponential_method: string, optional + Method used by the ControlMatrix class for the calculation of the + matrix exponential. The default is 'Frechet'. See also the Docstring of + the file 'qopt.matrix'. + + transfer_function: TransferFunction + The transfer function for reshaping the optimization parameters. + + amplitude_function: AmplitudeFunction + The amplitude function connecting the transferred optimization + parameters to the control amplitudes. + + _prop: List[ControlMatrix], len num_t + Propagators of the system. + + _fwd_prop: List[ControlMatrix], len num_t + 1 + Ordered product of the propagators. They describe the forward + propagation of the systems state. + + _reversed_prop: List[ControlMatrix], len num_t + 1 + Ordered product of propagators in reversed order. + + _derivative_prop: List[List[ControlMatrix]], shape [[] * num_t] * num_ctrl + Frechet derivatives of the propagators by the control amplitudes. + + Methods + ------- + propagators: List[ControlMatrix], len num_t + Returns the propagators of the system. + + forward_propagators: List[ControlMatrix], len num_t + 1 + Returns the forward propagation of the initial state. The element + forward_propagators[i] propagates a state by the first i time steps, if + the initial state is the identity matrix. + + frechet_deriv_propagators: List[List[ControlMatrix]], + shape [[] * num_t] * num_ctrl + Returns the frechet derivatives of the propagators by the control + amplitudes. + + reversed_propagators: List[ControlMatrix], len num_t + 1 + Returns the reversed propagation of the initial state. The element + reversed_propagators[i] propagates a state by the last i time steps, if + the initial state is the identity matrix. + + _compute_propagation: abstract method + Computes the propagators. + + _compute_forward_propagation + Compute the forward propagation of the initial state / system. + + _compute_reversed_propagation + Compute the reversed propagation of the initial state / system. + + _compute_propagation_derivatives: abstract method + Compute the derivatives of the propagators by the control amplitudes. + + create_pulse_sequence(new_amps): PulseSequence + Creates a pulse sequence instance corresponding to the current control + amplitudes. + + `Todo` + * Write parser + * setter for new hamiltonians + * make hamiltonians private + * also for the initial state + * extend constant drift hamiltonian + * Implement the drift operator with an amplitude. Right now, + * the operator is already multiplied with the amplitude, which is + * not coherent with the pulse sequence interface. Alternatively + * amplitude=1? + * transferred_time should be taken from the transfer function + * Use own plotting for the plotting + * Consequent try catches for the computation of the matrix exponential + + """ + + def __init__( + self, + h_ctrl: List[q_mat.OperatorMatrix], + h_drift: List[q_mat.OperatorMatrix], + tau: np.array, + initial_state: q_mat.OperatorMatrix = None, + opt_pars: Optional[np.array] = None, + ctrl_amps: Optional[np.array] = None, + filter_function_h_n: Union[ + Callable, List[List], None] = None, + filter_function_basis: Optional[basis.Basis] = None, + filter_function_n_coeffs_deriv: Optional[ + Callable[[np.ndarray], np.ndarray]] = None, + exponential_method: Optional[str] = None, + is_skew_hermitian: bool = True, + transfer_function: Optional[TransferFunction] = None, + amplitude_function: Optional[AmplitudeFunction] = None, + paranoia_level: int = 2 + ): + + self.h_ctrl = h_ctrl + self._ctrl_amps = ctrl_amps + self._opt_pars = opt_pars + + if initial_state is None: + dim = self.h_ctrl[0].shape[0] + self.initial_state = type(self.h_ctrl[0])(np.eye(dim)) + else: + self.initial_state = initial_state + + if exponential_method is None: + self.exponential_method = 'Frechet' + else: + self.exponential_method = exponential_method + + self._prop = None + self._fwd_prop = None + self._reversed_prop = None + self._derivative_prop = None + + self.pulse_sequence = None + + if filter_function_h_n is None: + self._filter_function_h_n = [] + else: + self._filter_function_h_n = filter_function_h_n + self.filter_function_basis = filter_function_basis + self.filter_function_n_coeffs_deriv = filter_function_n_coeffs_deriv + + self._is_skew_hermitian = is_skew_hermitian + + if transfer_function is None: + self.transfer_function = IdentityTF(num_ctrls=len(h_ctrl)) + else: + self.transfer_function = transfer_function + + self.transferred_time = None + self.set_times(tau=tau) + + if type(h_drift) in [matrix.DenseOperator, matrix.SparseOperator]: + self.h_drift = [h_drift, ] * self.transfer_function.num_x + elif len(h_drift) == 1: + self.h_drift = h_drift * self.transfer_function.num_x + else: + self.h_drift = h_drift + + if amplitude_function is None: + self.amplitude_function = IdentityAmpFunc() + else: + self.amplitude_function = amplitude_function + + self.transferred_parameters = None + + self.consistency_checks(paranoia_level=paranoia_level) + + def set_times(self, tau): + """ Set time values by passing them to the transfer function. + + Parameters + ---------- + tau: array of float, shape (num_t, ) + Durations of the time slices. + + """ + self.transfer_function.set_times(tau) + self.transferred_time = self.transfer_function.x_times + self.reset_cached_propagators() + + def set_optimization_parameters(self, y: np.array) -> None: + """ + Set the control amplitudes. + + All computation flags are set to false. + + The new control amplitudes u are calculated: + u: np.array, shape (num_t, num_ctrl) + + Parameters + ---------- + y: np.array, shape (num_x, num_ctrl) + Raw optimization parameters. + + """ + + if np.array_equal(self._opt_pars, y): + return + else: + self._opt_pars = np.copy(y) + + if self.transfer_function is not None: + self.transferred_parameters = self.transfer_function(y) + else: + self.transferred_parameters = np.copy(y) + + if self.amplitude_function is not None: + u = self.amplitude_function( + self.transferred_parameters) + else: + u = self.transferred_parameters + + if len(u.shape) != 2: + raise ValueError('The new control amplitudes must have two ' + 'dimensions! ' + '(time, control operator)') + + if u.shape[0] != len(self.transferred_time): + raise ValueError('The new control amplitudes do not have the ' + 'correct number of entries on the time axis!') + + if u.shape[1] != len(self.h_ctrl): + raise ValueError('The new control amplitudes do not have the ' + 'correnct number of entries on the control axis!') + + self._ctrl_amps = u + self.reset_cached_propagators() + + def reset_cached_propagators(self): + """ Resets all cached propagators. """ + self._prop = None + self._fwd_prop = None + self._derivative_prop = None + self._reversed_prop = None + self.pulse_sequence = None + + def consistency_checks(self, paranoia_level: int): + """Checks attributes for inner consistency. + + Parameters + ---------- + paranoia_level: int + The paranoia_level determines how many checks are conducted. + 0: No tests + 1: Some tests + 2: Exhaustive tests, dimension checks + + """ + if paranoia_level == 0: + return + + elif paranoia_level >= 1: + # check whether the hamiltonian is correct for the number of time + # steps + if isinstance(self.transferred_time, List): + self.transferred_time = np.asarray(self.transferred_time) + if len(self.transferred_time.shape) > 1: + raise ValueError("Tau must be a one dimensional numpy array or" + "a list.") + n_time_steps = self.transferred_time.shape[0] + + if len(self.h_drift) == 1: + self.h_drift = self.h_drift * n_time_steps + + if not (n_time_steps == len(self.h_drift) + or len(self.h_drift) == 0): + raise ValueError("The drift hamiltonian must have exactly one " + "entry for each transferred time step or no " + "entry at all or a single entry.") + if paranoia_level >= 2: + # check whether the Hamiltonian has the correct dimensions + dim = self.h_ctrl[0].shape[0] + + for ctrl_matrix in self.h_ctrl: + assert(dim == ctrl_matrix.shape[0]) + assert(dim == ctrl_matrix.shape[1]) + + for drift_matrx in self.h_drift: + assert(dim == drift_matrx.shape[0]) + assert(dim == drift_matrx.shape[1]) + + else: + raise ValueError("The paranoia level must be a positive integer.") + + @property + def propagators(self) -> List[q_mat.OperatorMatrix]: + """ + Returns the propagators of the system and calculates them if necessary. + + Returns + ------- + propagators: List[ControlMatrix], len num_t + Propagators of the system. + + """ + if self._prop is None: + self._compute_propagation() + return self._prop + + @property + def forward_propagators(self) -> List[q_mat.OperatorMatrix]: + """ + Returns the forward propagation of the initial state for every time + slice and calculate it if necessary. If the initial state is the + identity matrix, then the cumulative propagators are given. The element + forward_propagators[i] propagates a state by the first i time steps, if + the initial state is the identity matrix. + + Returns + ------- + forward_propagation: List[ControlMatrix], len num_t + 1 + Propagation of the initial state of the system. fwd[0] gives the + initial state itself. + + """ + if self._fwd_prop is None: + self._compute_forward_propagation() + return self._fwd_prop + + @property + def frechet_deriv_propagators(self) -> List[List[q_mat.OperatorMatrix]]: + """ + Returns the frechet derivatives of the propagators. + + Returns + ------- + derivative_prop: List[List[ControlMatrix]], + shape [[] * num_t] * num_ctrl + Frechet derivatives of the propagators by the control amplitudes + + """ + if self._derivative_prop is None: + self._compute_propagation_derivatives() + return self._derivative_prop + + @property + def reversed_propagators(self) -> List[q_mat.OperatorMatrix]: + """ + Returns the reversed propagation of the initial state for every time + slice and calculate it if necessary. If the initial state is the + identity matrix, then the reversed cumulative propagators are given. + The element forward_propagators[i] propagates a state by the first i + time steps, if the initial state is the identity matrix. + + Returns + ------- + reversed_propagation: List[ControlMatrix], len num_t + 1 + Propagation of the initial state of the system. reversed[0] gives + the initial state itself. + + """ + if self._reversed_prop is None: + self._compute_reversed_propagation() + return self._reversed_prop + + @property + def filter_function_n_coeffs_deriv_vals(self) -> Optional[np.ndarray]: + """ + Calculates the derivatives of the noise susceptibilities from the filter + function formalism. + + Returns + ------- + n_coeffs_deriv: numpy array of shape (num_noise_op, n_ctrl, num_t) + Derivatives of the noise susceptibilities by the control amplitudes. + + """ + if self.filter_function_n_coeffs_deriv is None: + return None + else: + return self.filter_function_n_coeffs_deriv(self._ctrl_amps) + + @property + def create_ff_h_n(self) -> list: + """Creates the noise hamiltonian of the filter function formalism. + + Returns + ------- + create_ff_h_n: nested list + Noise Hamiltonian of the filter function formalism. + + """ + if type(self._filter_function_h_n) == list: + h_n = self._filter_function_h_n + else: + h_n = self._filter_function_h_n(self._ctrl_amps) + + if not h_n: + h_n = [[np.zeros(self.h_ctrl[0].shape), + np.zeros((len(self.transferred_time),))]] + + return h_n + + @abstractmethod + def _compute_propagation(self) -> None: + """ + Computes the propagators. Must set self._prop! + + Raises + ------ + ValueError + If the control amplitudes are not set. + + """ + if self._ctrl_amps is None: + raise ValueError("The control amplitudes must be set to calculate " + "the propagation!") + + def _compute_forward_propagation(self) -> None: + """Computes the forward propagators. """ + if self._prop is None: + self._compute_propagation() + self._fwd_prop = [self.initial_state.copy(), ] + for prop in self._prop: + self._fwd_prop.append(prop * self._fwd_prop[-1]) + + def _compute_reversed_propagation(self) -> None: + """Compute the reversed propagation. """ + if self._prop is None: + self._compute_propagation() + + if type(self.initial_state) == matrix.DenseOperator: + self._reversed_prop = [matrix.DenseOperator( + np.eye(self._prop[0].shape[0])) * (1 + 0j), ] + elif type(self.initial_state) == matrix.SparseOperator: + raise NotImplementedError + # self._reversed_prop = [matrix.SparseOperator( + # np.eye(self._prop[0].shape[0])) * (1 + 0j), ] + else: + raise TypeError("The initial state should be either a dense or " + "sparse control matrix.") + + for prop in self._prop[::-1]: + self._reversed_prop.append(self._reversed_prop[-1] * prop) + + @abstractmethod + def _compute_propagation_derivatives(self) -> None: + """Compute the derivatives of the propagators by the control + amplitudes. + """ + pass + + def _diagonalize_and_propagate_pulse_sequence(self) -> None: + """Manually set eigendecomposition of the PulseSequence. + + Work around incompatibility of drift Hamiltonian + representations.""" + ps = self.pulse_sequence + drift_hamiltonian = np.array([h.data for h in self.h_drift]) + control_hamiltonian = np.einsum('ijk,il->ljk', ps.c_opers, ps.c_coeffs) + ps.eigvals, ps.eigvecs, ps.propagators = numeric.diagonalize( + drift_hamiltonian + control_hamiltonian, ps.dt + ) + ps.total_propagator = ps.propagators[-1] + + def create_pulse_sequence( + self, new_amps: Optional[np.array] = None, + ff_basis: Optional[basis.Basis] = None + ) -> pulse_sequence.PulseSequence: + """ + Create a pulse sequence of the filter function package written by + Tobias Hangleiter. + + See the documentation of the filter function package. + + Parameters + ---------- + new_amps: np.array, shape (num_t, num_ctrl), optional + New control amplitudes can be set before the pulse sequence is + initialized. + + ff_basis: Basis + The pulse sequence will be expanded in this basis. See + documentation of the filter function package. + + Returns + ------- + pulse_sequence: filter_functions.pulse_sequence.PulseSequence + The pulse sequence corresponding to the control model and the + control amplitudes set. + + """ + if new_amps is not None: + self.set_optimization_parameters(new_amps) + else: + if self._ctrl_amps is None: + raise ValueError('No optimization parameters set. ' + 'Please supply new_amps argument') + + if ff_basis is not None: + basis = ff_basis + elif self.filter_function_basis is not None: + basis = self.filter_function_basis + else: + basis = None + + # We have to work around different interfaces for the drift + # operators. Since in qopt the drift can be arbitrary (incl. + # nonlinear coupling), but in filter_functions the form H = + # a(t) A is imposed, we don't tell the PulseSequence object + # about H_drift and set the eigendecomposition after the fact. + if self.pulse_sequence is None: + h_c = list(zip( + self.h_ctrl, + self._ctrl_amps.T, + [f'Control{i}' for i in range(len(self.h_ctrl))] + )) + self.pulse_sequence = pulse_sequence.PulseSequence( + h_c, self.create_ff_h_n, self.transferred_time, basis + ) + else: + # Clean up the caches and update coefficients + self.pulse_sequence.cleanup('all') + self.pulse_sequence.c_coeffs = self._ctrl_amps.T + # Not the most elegant, but necessary for the current + # implementation. + self.pulse_sequence.n_coeffs = pulse_sequence._parse_Hamiltonian( + self.create_ff_h_n, + len(self.transferred_time), 'H_n')[2] + + if basis is not None: + self.pulse_sequence.basis = basis + + self._diagonalize_and_propagate_pulse_sequence() + return self.pulse_sequence + + def plot_bloch_sphere( + self, new_amps=None, return_Bloch: bool = False) -> None: + """ + Uses the pulse sequence to plot the systems evolution on the bloch + sphere. + + Only available for two dimensional systems. + + Parameters + ---------- + new_amps: np.array, shape (num_t, num_ctrl), optional + New control amplitudes can be set before the pulse sequence is + initialized. + + return_Bloch: bool + If True, then qutips Bloch object is returned. + + Returns + ------- + b: Bloch + Qutips Bloch object. Only returned if return_Bloch is set to True. + + """ + # Already takes care of updating and cleaning the PulseSequence object + pulse_sequence = self.create_pulse_sequence(new_amps=new_amps) + return plotting.plot_bloch_vector_evolution(pulse_sequence, + n_samples=500, + return_Bloch=return_Bloch) + + +class SchroedingerSolver(Solver): + """ + This time slot computer solves the unperturbed Schroedinger equation. + + All intermediary propagators are calculated and cached. Takes also input + parameters of the base class. + + Parameters + ---------- + calculate_propagator_derivatives: bool + If true, the derivatives of the propagators by the control amplitudes + are always calculated. Otherwise only on demand. + + frechet_deriv_approx_method: Optional[str] + Method for the approximation of the derivatives of the propagators, if + they are not calculated analytically. Note that this method is never + used if calculate_propagator_derivatives is set to True! + Methods: + None: The derivatives are not approximated by calculated by the control + matrix class. + 'grape': use the approximation given in the original grape paper. + + Attributes + ---------- + _dyn_gen: List[ControlMatrix], len num_t + The generators of the systems dynamics + + calculate_propagator_derivatives: bool + If true, the derivatives of the propagators by the control amplitudes + are always calculated. Otherwise only on demand. + + frechet_deriv_approx_method: Optional[str] + Method for the approximation of the derivatives of the propagators, if + they are not calculated analytically. Note that this method is never + used if calculate_propagator_derivatives is set to True! + Methods: + 'grape': use the approximation given in the original grape paper. + + Methods + ------- + _compute_derivative_directions: List[List[q_mat.ControlMatrix]], + shape [[] * num_ctrl] * num_t + Computes the directions of change with respect to the control + parameters. + + _compute_dyn_gen: List[ControlMatrix], len num_t + Computes the dynamics generators. + + `Todo` + * raise a warning if the approximation method although the gradient + is always calculated. + * raise a warning if the grape approximation is chosen but its + requirement of small time steps is not met. + + """ + + def __init__(self, + h_drift: List[q_mat.OperatorMatrix], + h_ctrl: List[q_mat.OperatorMatrix], + tau: np.array, + initial_state: q_mat.OperatorMatrix = None, + ctrl_amps: Optional[np.array] = None, + calculate_propagator_derivatives: bool = True, + filter_function_h_n: Union[ + Callable, List[List], None] = None, + filter_function_basis: Optional[basis.Basis] = None, + filter_function_n_coeffs_deriv: Optional[ + Callable[[np.ndarray], np.ndarray]] = None, + exponential_method: Optional[str] = None, + frechet_deriv_approx_method: Optional[str] = None, + is_skew_hermitian: bool = True, + transfer_function: Optional[TransferFunction] = None, + amplitude_function: Optional[AmplitudeFunction] = None): + super().__init__( + h_drift=h_drift, h_ctrl=h_ctrl, initial_state=initial_state, + tau=tau, ctrl_amps=ctrl_amps, + filter_function_h_n=filter_function_h_n, + filter_function_basis=filter_function_basis, + filter_function_n_coeffs_deriv=filter_function_n_coeffs_deriv, + exponential_method=exponential_method, + is_skew_hermitian=is_skew_hermitian, + transfer_function=transfer_function, + amplitude_function=amplitude_function + ) + self.id_text = 'ALL' + self.cache_text = 'Save' + self.calculate_propagator_derivatives = \ + calculate_propagator_derivatives + self.frechet_deriv_approx_method = frechet_deriv_approx_method + + self._dyn_gen = None + + def set_optimization_parameters(self, y: np.array) -> None: + """See base class. """ + if not np.array_equal(self._opt_pars, y): + self.reset_cached_propagators() + super().set_optimization_parameters(y) + + def reset_cached_propagators(self): + """See base class. """ + self._dyn_gen = None + super().reset_cached_propagators() + + def _compute_dyn_gen(self) -> List[q_mat.OperatorMatrix]: + """ + Computes the dynamics generators. + + Returns + ------- + dyn_gen: List[ControlMatrix], len num_t + This is basically the total Hamiltonian. + + """ + self._dyn_gen = [-1j * h for h in self.h_drift] + for ctrl, ctrl_op in enumerate(self.h_ctrl): + for dyn_gen, ctrl_amp in \ + zip(self._dyn_gen, self._ctrl_amps[:, ctrl]): + dyn_gen += -1j * ctrl_amp * ctrl_op + return self._dyn_gen + + def _compute_derivative_directions( + self) -> List[List[q_mat.OperatorMatrix]]: + """ + The directions of the frechet derivatives are the control operators. + + No deep copy is required because the result is not used for in-place + operations. + + """ + # The list is multiplied (copied by reference) because the elements + # will not be manipulated in place. (only as copy) + return [[operator * -1j for operator in self.h_ctrl], ] * len(self.transferred_time) + + def _compute_propagation( + self, calculate_propagator_derivatives: Optional[bool] = None) \ + -> None: + """See base class. """ + super()._compute_propagation() + + if self._dyn_gen is None: + self._dyn_gen = self._compute_dyn_gen() + + if calculate_propagator_derivatives is None: + calculate_propagator_derivatives = \ + self.calculate_propagator_derivatives + + # initialize the attributes + self._prop = [None for _ in range(len(self.transferred_time))] + + if calculate_propagator_derivatives: + derivative_directions = self._compute_derivative_directions() + self._derivative_prop = [ + [None for _ in range(len(self.transferred_time))] + for _2 in range(len(self.h_ctrl))] + for t in range(len(self.transferred_time)): + for ctrl in range(len(self.h_ctrl)): + try: + self._prop[t], self._derivative_prop[ctrl][t] \ + = self._dyn_gen[t].dexp( + derivative_directions[t][ctrl], + self.transferred_time[t], + compute_expm=True, method=self.exponential_method, + is_skew_hermitian=self._is_skew_hermitian) + except ValueError: + raise ValueError('The computation has failed with ' + 'a value error. Try another ' + 'exponentiation method.') + else: + for t in range(len(self.transferred_time)): + self._prop[t] = self._dyn_gen[t].exp( + tau=self.transferred_time[t], method=self.exponential_method, + is_skew_hermitian=self._is_skew_hermitian) + + def _compute_propagation_derivatives(self) -> None: + """ + Computes the frechet derivatives of the propagators. + + The derivatives are not returned but cached. Since the function is only + called when no derivatives are cached, the approximation is + prioritised. + """ + if not self.frechet_deriv_approx_method: + self._compute_propagation(calculate_propagator_derivatives=True) + elif self.frechet_deriv_approx_method == 'grape': + if self._prop is None: + self._compute_propagation( + calculate_propagator_derivatives=False) + self._derivative_prop = [[None for _ in range(len(self.h_ctrl))] + for _2 in range(len(self.transferred_time))] + derivative_directions = self._compute_derivative_directions() + for t in range(len(self.transferred_time)): + for ctrl in range(len(self.h_ctrl)): + self._derivative_prop[t][ctrl] = \ + self.transferred_time[t] * derivative_directions[t][ctrl] \ + * self._prop[t] + else: + raise ValueError('Unknown gradient derivative approximation ' + 'method:' + + str(self.frechet_deriv_approx_method)) + + +def _compute_matrix_exponentials(input_dict): + """Computes the propagator of the Schroedinger equation by evaluation of + a matrix exponential. + + Parameters + ---------- + input_dict: dict + Holds the parameters in a single dict, because the function + multiprocessing.Pool.map requires a single input argument. The dict + has the fields time, matrices, method and is_skew_hermitian. See also + _compute_propagator. + + Returns + ------- + exponentials: list of ControlMatrix + A list of the propagators. + + """ + time = input_dict['time'] + matrices = input_dict['matrices'] + method = input_dict['method'] + is_skew_hermitian = input_dict['is_skew_hermitian'] + + exponentials = [None, ] * len(time) + for i, m, t in zip(range(len(matrices)), matrices, time): + exponentials[i] = m.exp( + tau=t, + method=method, + is_skew_hermitian=is_skew_hermitian) + return exponentials + + +class SchroedingerSMonteCarlo(SchroedingerSolver): + r""" + Solves Schroedinger's equation for explicit noise realisations as Monte + Carlo experiment. + + This time slot computer solves the Schroedinger equation explicitly for + concrete noise realizations. The noise traces are generated by an instance + of the Noise Trace Generator Class. Then they can be processed by the + noise amplitude function, before they are multiplied by the noise + hamiltionians. + + Parameters + ---------- + h_noise: List[ControlMatrix], len num_noise_operators + List of noise operators occurring in the Hamiltonian. + + noise_trace_generator: noise.NoiseTraceGenerator + Noise trace generator object. + + processes: int, optional + If an integer is given, then the propagation is calculated in + this number of parallel processes. If 1 then no parallel + computing is applied. If None then cpu_count() is called to use + all cores available. Defaults to 1. + + noise_amplitude_function: Callable[[noise_samples: np.array, + optimization_parameters: np.array, + transferred_parameters: np.array, + control_amplitudes: np.array], np.array] + The noise amplitude function calculated the noisy control amplitudes + corresponding to the noise samples. They recieve 4 keyword arguments + being the noise samples, the optimization parameters, the transferred + optimization parameters and the control amplitudes in this order. + The noise samples are given with the shape (n_samples_per_trace, + n_traces, n_noise_operators), the optimization parameters + (num_x, num_ctrl), the transferred parameters (num_t, num_ctrl) and + the control amplitudes (num_t, num_ctrl). The returned noise amplitudes + should be of the shape (num_t, n_traces, n_noise_operators). + + Attributes + ---------- + h_noise: List[ControlMatrix], len num_noise_operators + List of noise operators occurring in the Hamiltonian. + + noise_trace_generator: noise.NoiseTraceGenerator + Noise trace generator object. + + _dyn_gen_noise: List[List[ControlMatrix]], + shape [[] * num_t] * num_noise_traces + Dynamics generators for the individual noise traces. + + _prop_noise: List[List[ControlMatrix]], + shape [[] * num_t] * num_noise_traces + Propagators for the individual noise traces. + + _fwd_prop_noise: List[List[ControlMatrix]], + shape [[] * (num_t + 1)] * num_noise_traces + Cumulation of the propagators for the individual noise traces. They + describe the forward propagation of the systems state. + + _reversed_prop_noise: List[List[ControlMatrix]], + shape [[] * (num_t + 1)] * num_noise_traces + Cumulation of propagators in reversed order for the individual noise + traces. + + _derivative_prop_noise: List[List[List[ControlMatrix]]], + shape [[[] * num_t] * num_ctrl] * num_noise_traces + Frechet derivatives of the propagators by the control amplitudes for + the individual noise traces. + + Methods + ------- + propagators_noise: List[List[ControlMatrix]], + shape [[] * num_t] * num_noise_traces + Propagators for the individual noise traces. + + forward_propagators_noise: List[List[ControlMatrix]], + shape [[] * (num_t + 1)] * num_noise_traces + Cumulation of the propagators for the individual noise traces. They + describe the forward propagation of the systems state. + + reversed_propagators_noise: List[List[ControlMatrix]], + shape [[] * (num_t + 1)] * num_noise_traces + Cumulation of propagators in reversed order for the individual noise + traces. + + frechet_deriv_propagators_noise: List[List[List[ControlMatrix]]], + shape [[[] * num_t] * num_ctrl] * num_noise_traces + Frechet derivatives of the propagators by the control amplitudes for + the individual noise traces. + + """ + def __init__( + self, h_drift: List[q_mat.OperatorMatrix], + h_ctrl: List[q_mat.OperatorMatrix], + tau: np.array, + h_noise: List[q_mat.OperatorMatrix], + noise_trace_generator: + Optional[noise.NoiseTraceGenerator], + initial_state: q_mat.OperatorMatrix = None, + ctrl_amps: Optional[np.array] = None, + calculate_propagator_derivatives: bool = False, + processes: Optional[int] = 1, + filter_function_h_n: Union[ + Callable, List[List], None] = None, + filter_function_basis: Optional[basis.Basis] = None, + filter_function_n_coeffs_deriv: Optional[ + Callable[[np.ndarray], np.ndarray]] = None, + exponential_method: Optional[str] = None, + frechet_deriv_approx_method: Optional[str] = None, + is_skew_hermitian: bool = True, + transfer_function: Optional[TransferFunction] = None, + amplitude_function: Optional[AmplitudeFunction] = None, + noise_amplitude_function: Optional[Callable[ + [np.array, np.array, np.array, + np.array], np.array]] = None + ): + + super().__init__( + h_drift=h_drift, h_ctrl=h_ctrl, initial_state=initial_state, + tau=tau, ctrl_amps=ctrl_amps, + filter_function_h_n=filter_function_h_n, + filter_function_basis=filter_function_basis, + filter_function_n_coeffs_deriv=filter_function_n_coeffs_deriv, + exponential_method=exponential_method, + calculate_propagator_derivatives=calculate_propagator_derivatives, + frechet_deriv_approx_method=frechet_deriv_approx_method, + is_skew_hermitian=is_skew_hermitian, + transfer_function=transfer_function, + amplitude_function=amplitude_function) + + self.h_noise = h_noise + self.noise_trace_generator = noise_trace_generator + self.noise_amplitude_function = noise_amplitude_function + self.processes = processes + + self._dyn_gen_noise = None + self._prop_noise = None + self._derivative_prop_noise = None + self._fwd_prop_noise = None + self._reversed_prop_noise = None + + def set_optimization_parameters(self, y: np.array) -> None: + """See base class. """ + if not np.array_equal(self._opt_pars, y): + self.reset_cached_propagators() + super().set_optimization_parameters(y) + + def reset_cached_propagators(self): + """See base class. """ + super().reset_cached_propagators() + self._dyn_gen_noise = None + self._prop_noise = None + self._derivative_prop_noise = None + self._fwd_prop_noise = None + self._reversed_prop_noise = None + + + @property + def propagators_noise(self) -> List[List[q_mat.OperatorMatrix]]: + """ + Returns the propagators of the system for each noise trace and + calculates them if necessary. + + Returns + ------- + propagators_noise: List[List[ControlMatrix]], + shape [[] * num_t] * num_noise_traces + Propagators of the system for each noise trace. + + """ + if self._prop_noise is None: + self._compute_propagation() + return self._prop_noise + + @property + def forward_propagators_noise(self) -> List[List[q_mat.OperatorMatrix]]: + """ + Returns the forward propagation of the initial state for every time + slice and every noise trace and calculate it if necessary. If the + initial state is the identity matrix, then the cumulative propagators + are given. The element forward_propagators[k][i] propagates a state by + the first i time steps under the kth noise trace, if the initial state + is the identity matrix. + + Returns + ------- + forward_propagation:List[List[ControlMatrix]], + shape [[] * (num_t + 1)] * num_noise_traces + Propagation of the initial state of the system. fwd[0] gives the + initial state itself. + + """ + if self._fwd_prop_noise is None: + self._compute_forward_propagation() + return self._fwd_prop_noise + + @property + def frechet_deriv_propagators_noise(self) \ + -> List[List[List[q_mat.OperatorMatrix]]]: + """ + Returns the frechet derivatives of the propagators with respect to the + control amplitudes for each noise trace. + + Returns + ------- + derivative_prop_noise: List[List[List[ControlMatrix]]], + shape [[[] * num_t] * num_ctrl] * num_noise_traces + Frechet derivatives of the propagators by the control amplitudes. + + """ + if self._derivative_prop_noise is None: + self._compute_propagation_derivatives() + return self._derivative_prop_noise + + @property + def reversed_propagators_noise(self) -> List[List[q_mat.OperatorMatrix]]: + """ + Returns the reversed propagation of the initial state for every noise + trace and calculate it if necessary. If the initial state is the + identity matrix, then the reversed cumulative propagators are given. + The element forward_propagators[k][i] propagates a state by the first i + time steps under the kth noise trace, if the initial state is the + identity matrix. + + Returns + ------- + reversed_propagation_noise: List[List[ControlMatrix]], + shape [[] * (num_t + 1)] * num_noise_traces + Propagation of the initial state of the system. reversed[k][0] + gives the initial state itself. + + """ + if self._reversed_prop_noise is None: + self._compute_reversed_propagation() + return self._reversed_prop_noise + + def _compute_dyn_gen_noise(self) -> List[List[q_mat.OperatorMatrix]]: + """ + Computes the dynamics generators for the perturbed and unperturbed + Schroedinger equation. + + Returns + ------- + dyn_gen_noise: List[List[q_mat.ControlMatrix]], + shape [[] * num_t] * num_noise_traces + Dynamics generators for each noise trace. + + """ + # compute the generators of the unperturbed dynamics + self._dyn_gen = super()._compute_dyn_gen() + + # compute the generators for the noise traces. + n_noise_traces = self.noise_trace_generator.n_traces + + noise_samples = self.noise_trace_generator.noise_samples + # we transpose, so we iterate over the time last + noise_samples = np.transpose(noise_samples, (2, 1, 0)) + + if self.noise_amplitude_function: + noise_samples = self.noise_amplitude_function( + noise_samples=noise_samples, + optimization_parameters=self._opt_pars, + transferred_parameters=self.transferred_parameters, + control_amplitudes=self._ctrl_amps + ) + + self._dyn_gen_noise = [[dyn_gen.copy() for dyn_gen in self._dyn_gen] + for _ in range(n_noise_traces)] + + for t, sample_stack in enumerate(noise_samples): + for n_trace, trace in enumerate(sample_stack): + for operator_sample, operator in zip(trace, self.h_noise): + self._dyn_gen_noise[n_trace][t] += \ + (-1j * operator_sample) * operator + return self._dyn_gen_noise + + def _compute_propagation( + self, calculate_propagator_derivatives: Optional[bool] = None + ) -> None: + """ + Computes the propagators for the perturbed Schroedinger equation and + the derivatives on demand. + + Parameters + ---------- + calculate_propagator_derivatives: bool, optional + Calculate the derivatives of the propagators with respect to the + control amplitudes if true. + + """ + + if self._dyn_gen_noise is None: + self._dyn_gen_noise = self._compute_dyn_gen_noise() + + n_noise_traces = self.noise_trace_generator.n_traces + num_t = len(self.transferred_time) + num_ctrl = len(self.h_ctrl) + + self._prop_noise = [[None for _ in range(num_t)] + for _2 in range(n_noise_traces)] + + if calculate_propagator_derivatives is None: + calculate_propagator_derivatives = \ + self.calculate_propagator_derivatives + + # parallelization of following code probably unnecessary + if calculate_propagator_derivatives: + self._derivative_prop_noise = \ + [[[None for _ in range(num_t)] + for _2 in range(num_ctrl)] + for _3 in range(n_noise_traces)] + derivative_directions = self._compute_derivative_directions() + + # call the parent method for the noiseless propagators + super()._compute_propagation( + calculate_propagator_derivatives=calculate_propagator_derivatives) + + if self.processes == 1: + if calculate_propagator_derivatives: + for k in range(n_noise_traces): + for t in range(num_t): + for ctrl in range(len(self.h_ctrl)): + self._prop_noise[k][t], \ + self._derivative_prop_noise[k][ctrl][t] \ + = self._dyn_gen_noise[k][t].dexp( + derivative_directions[t][ctrl], + self.transferred_time[t], + compute_expm=True, + method=self.exponential_method, + is_skew_hermitian=self._is_skew_hermitian) + else: + for k in range(n_noise_traces): + for t in range(num_t): + self._prop_noise[k][t] = self._dyn_gen_noise[k][t].exp( + tau=self.transferred_time[t], + method=self.exponential_method, + is_skew_hermitian=self._is_skew_hermitian) + + elif (type(self.processes) == int and self.processes > 0) \ + or self.processes is None: + + if calculate_propagator_derivatives: + raise NotImplementedError + else: + input_dicts = [] + for k in range(n_noise_traces): + input_dicts.append(dict()) + input_dicts[-1]['time'] = self.transferred_time + input_dicts[-1]['matrices'] = self._dyn_gen_noise[k] + input_dicts[-1]['method'] = self.exponential_method + input_dicts[-1][ + 'is_skew_hermitian'] = self._is_skew_hermitian + + with Pool(processes=self.processes) as pool: + self._prop_noise = pool.map( + _compute_matrix_exponentials, input_dicts) + + else: + raise ValueError('Invalid number of processes for parallel ' + 'computation!') + + def _compute_forward_propagation(self) -> None: + """Computes the forward propagators. """ + super()._compute_forward_propagation() + if self._prop_noise is None: + self._compute_propagation() + + self._fwd_prop_noise = [ + [self.initial_state.copy(), ] + for _ in range(self.noise_trace_generator.n_traces)] + + for fwd_per_trace, prop_per_trace in zip(self._fwd_prop_noise, + self._prop_noise): + for prop in prop_per_trace: + fwd_per_trace.append(prop * fwd_per_trace[-1]) + + def _compute_reversed_propagation(self) -> None: + """Compute the reversed propagation. For the perturbed and unperturbed + Schroedinger equation. """ + super()._compute_reversed_propagation() + if self._prop_noise is None: + self._compute_propagation() + + self._reversed_prop_noise = [ + [self._prop[0].identity_like(), ] + for _ in range(self.noise_trace_generator.n_traces)] + + for rev_per_trace, prop_per_trace in zip(self._reversed_prop_noise, + self._prop_noise): + for prop in prop_per_trace[::-1]: + rev_per_trace.append(rev_per_trace[-1] * prop) + + def _compute_propagation_derivatives(self) -> None: + """ + Computes the frechet derivatives of the propagators. + + The derivatives are not returned but cached. Since the function is only + called when no derivatives are cached, the approximation is + prioritised. + """ + if not self.frechet_deriv_approx_method: + self._compute_propagation(calculate_propagator_derivatives=True) + elif self.frechet_deriv_approx_method == 'grape': + super()._compute_propagation_derivatives() + + if self._prop_noise is None: + self._compute_propagation( + calculate_propagator_derivatives=False) + + n_noise_traces = self.noise_trace_generator.n_traces + num_t = len(self.transferred_time) + num_ctrl = len(self.h_ctrl) + + self._derivative_prop_noise = [ + [[None for _ in range(num_t)] + for _2 in range(num_ctrl)] + for _3 in range(n_noise_traces)] + + derivative_directions = self._compute_derivative_directions() + + for k in range(n_noise_traces): + for t in range(len(self.transferred_time)): + for ctrl in range(num_ctrl): + self._derivative_prop_noise[k][ctrl][t] = \ + self.transferred_time[t] * derivative_directions[t][ctrl] \ + * self._prop_noise[k][t] + else: + raise ValueError('Unknown gradient derivative approximation ' + 'method:' + + str(self.frechet_deriv_approx_method)) + + +class SchroedingerSMCControlNoise(SchroedingerSMonteCarlo): + """ + Convenience class like `SchroedingerSMonteCarlo` but with noise on the + optimization parameters. + + This time slot computer solves the Schroedinger equation explicitly for + concrete control noise realizations. This time slot computer assumes, + that the noise is sampled on the time scale of the already transferred + optimization parameters. The control Hamiltionians are also used as noise + Hamiltionians and the noise amplitude function adds the noise samples to + the unperturbed transferred optimization parameters and applies the + amplitude function of the control amplitudes. + + """ + def __init__( + self, + h_drift: List[q_mat.OperatorMatrix], + h_ctrl: List[q_mat.OperatorMatrix], + tau: np.array, + noise_trace_generator: + Optional[noise.NoiseTraceGenerator], + initial_state: q_mat.OperatorMatrix = None, + ctrl_amps: Optional[np.array] = None, + calculate_propagator_derivatives: bool = False, + processes: Optional[int] = 1, + filter_function_h_n: Union[ + Callable, List[List], None] = None, + filter_function_basis: Optional[basis.Basis] = None, + filter_function_n_coeffs_deriv: Optional[ + Callable[[np.ndarray], np.ndarray]] = None, + exponential_method: Optional[str] = None, + frechet_deriv_approx_method: Optional[str] = None, + is_skew_hermitian: bool = True, + transfer_function: Optional[TransferFunction] = None, + amplitude_function: Optional[AmplitudeFunction] = None): + + def noise_amplitude_function(noise_samples: np.array, + transferred_parameters: np.array, + control_amplitudes: np.array, + **_): + """Calculates the noise amplitudes. + + Takes into account the actual optimization parameters and random + variations. + + Parameters + ---------- + noise_samples: np.array + Noise samples calculated by the noise trace generator. + + transferred_parameters: np.array + Transferred optimization parameters. + + control_amplitudes: np.array + Control amplitudes. + + """ + noise_amplitudes = np.zeros((noise_samples.shape[0],noise_samples.shape[1],control_amplitudes.shape[1]), dtype=complex) + # complex values were requested. + for trace_num in range(noise_samples.shape[1]): + noise_amplitudes[:, trace_num, :] = self.amplitude_function( + transferred_parameters + noise_samples[:, trace_num, :]) \ + - control_amplitudes + return noise_amplitudes + + super().__init__( + h_drift=h_drift, + h_ctrl=h_ctrl, + initial_state=initial_state, + tau=tau, + h_noise=h_ctrl, + noise_trace_generator=noise_trace_generator, + ctrl_amps=ctrl_amps, + calculate_propagator_derivatives=calculate_propagator_derivatives, + processes=processes, + filter_function_h_n=filter_function_h_n, + filter_function_basis=filter_function_basis, + filter_function_n_coeffs_deriv=filter_function_n_coeffs_deriv, + exponential_method=exponential_method, + frechet_deriv_approx_method=frechet_deriv_approx_method, + is_skew_hermitian=is_skew_hermitian, + transfer_function=transfer_function, + amplitude_function=amplitude_function, + noise_amplitude_function=noise_amplitude_function + ) + + +class LindbladSolver(SchroedingerSolver): + r""" + Solves a master equation for an open quantum system in the Markov + approximation using the Lindblad super operator formalism. + + The master equation to be solved is + + .. math:: + + d \rho / dt = i [\rho, H] + \sum_k (L_k \rho L_k^\dagger + - .5 L_k^\dagger L_k \rho - .5 \rho L_k^\dagger L_k) + + + with the Lindblad operators L_k. The solution is calculated as + + .. math:: + + \rho(t) = exp[(-i \mathcal{H} + \mathcal{G})t] \rho(0) + + with the dissipative super operator + + .. math:: + + \mathcal{G} = \sum_k D(L_k) + + .. math:: + + D(L) = L^\ast \otimes L - .5 I \otimes (L^\dagger L) + - .5 (L^T L^\ast) \otimes I + + The dissipation super operator can be given in three different ways. + + 1. A nested list of dissipation super operators D(L_k) as control + matrices. + 2. A nested list of Lindblad operators L as control matrices. + 3. A function handle receiving the control amplitudes as sole argument and + returning a dissipation super operator as list of control matrices. + + Optionally a prefactor function can be given for 1. and 2. This function + receives the control parameters and returns an array of the shape + num_t x num_l where num_t is the number of time steps in the control and + num_l is the number of Lindblad operators or dissipation super operators. + + If multiple construction arguments are given, the implementation + prioritises the function (3.) over the Lindblad operators (2.) over the + dissipation super operator (1.). + + Parameters + ---------- + initial_diss_super_op: List[ControlMatrix], len num_l + Initial dissipation super operator; num_l is the number of + Lindbladians. Set if you want to use (1.) (See documentation above!). + The control matrices are expected to be of shape (dim, dim) where dim + is the dimension of the system. + + lindblad_operators: List[ControlMatrix], len num_l + Lindblad operators; num_l is the number of Lindbladians. Set if you + want to use (2.) (See documentation above!). The Lindblad operators are + assumend to be of shape (dim, dim) where dim is the dimension of the + system. + + prefactor_function: Callable[[np.array, np.array], np.array] + Receives the control amplitudes u (as numpy array of shape + (num_t, num_ctrl)) and the transferred optimization parameters (as + numpy array of shape (num_t, num_opt)) and returns prefactors as numpy + array of shape (num_t, num_l). The prefactors a_k are used as weights in + the sum of the total dissipation operator. + + .. math:: + + \mathcal{G} = \sum_k a_k * D(L_k) + + If the Lindblad operator is for example given by a complex number b_k + times a constant (in time) matrix C_k. + + .. math:: + + L_k = b_k * C_k + + Then the prefactor is the squared absolute value of this number: + + .. math:: + + a_k = |b_k|^2 + + Set if you want to use method (1.) or (2.). (See class documentation.) + + prefactor_derivative_function: Callable[[np.array, np.array], np.array] + Receives the control amplitudes u (as numpy array of shape + (num_t, num_ctrl)) and the transferred optimization parameters (as + numpy array of shape (num_t, num_opt)) and returns the derivatives of + the prefactors as numpy array of shape (num_t, num_ctrl, num_l). The + derivatives d_k are used as weights in the sum of the derivative of the + total dissipation operator. + + .. math:: + + d \mathcal{G} / d u_k = \sum_k d_k * D(L_k) + + If the Lindblad operator is for example given by a complex number b_k + times a constant (in time) matrix C_k. And this number depends on the + control amplitudes u_k + + .. math:: + + L_k = b_k (u_k) * C_k + + Then the derivative of the prefactor is the derivative of the squared + absolute value of this number: + + .. math:: + + d_k = d |b_k|^2 / d u_k + + Set if you want to use method (1.) or (2.). (See class documentation.) + + super_operator_function: Callable[[np.array, np.array], List[ControlMatrix]] + Receives the control amlitudes u (as numpy array of shape + (num_t, num_ctrl)) and the transferred optimization parameters (as + numpy array of shape (num_t, num_opt)) and returns the total dissipation + operators as list of length num_t. Set if you want to use method (3.). + (See class documentation.) + + super_operator_derivative_function: Callable[[np.array, np.array], + List[List[ControlMatrix]]] + Receives the control amlitudes u (as numpy array of shape + (num_t, num_ctrl)) and the transferred optimization parameters (as + numpy array of shape (num_t, num_opt)) and returns the derivatives of + the total dissipation operators as nested list of + shape [[] * num_ctrl] * num_t. Set if you + want to use method (3.). (See class documentation.) + + is_skew_hermitian: bool + If True, then the total dynamics generator is assumed to be skew + hermitian. + + Attributes + ---------- + _diss_sup_op: List[ControlMatrix], len num_t + Total dissipaton super operator. + + _diss_sup_op_deriv: List[List[ControlMatrix]], + shape [[] * num_ctrl] * num_t + Derivative of the total dissipation operator with respect to the + control amplitudes. + + _initial_diss_super_op: List[ControlMatrix], len num_l + Initial dissipation super operator; num_l is the number of + Lindbladians. + + _lindblad_operatorsList[ControlMatrix], len num_l + Lindblad operators; num_l is the number of Lindbladians. + + _prefactor_function: Callable[[np.array], np.array] + Receives the control amplitudes u (as numpy array of shape + (num_t, num_ctrl)) and returns prefactors as numpy array + of shape (num_t, num_l). The prefactors a_k are used as weights in the + sum of the total dissipation operator. + + .. math:: + + \mathcal{G} = \sum_k a_k * D(L_k) + + If the Lindblad operator is for example given by a complex number b_k + times a constant (in time) matrix C_k. + + .. math:: + + L_k = b_k * C_k + + Then the prefactor is the squared absolute value of this number: + + .. math:: + + a_k = |b_k|^2 + + Set if you want to use method (1.) or (2.). (See class documentation.) + + _prefactor_deriv_function: Callable[[np.array], np.array] + Receives the control amplitudes u (as numpy array of shape + (num_t, num_ctrl)) and returns the derivatives of the + prefactors as numpy array of shape (num_t, num_ctrl, num_l). The + derivatives d_k are used as weights in the sum of the derivative of the + total dissipation operator. + + .. math:: + + d \mathcal{G} / d u_k = \sum_k d_k * D(L_k) + + If the Lindblad operator is for example given by a complex number b_k + times a constant (in time) matrix C_k. And this number depends on the + control amplitudes u_k + + .. math:: + + L_k = b_k (u_k) * C_k + + Then the derivative of the prefactor is the derivative of the squared + absolute value of this number: + + .. math:: + + d_k = d |b_k|^2 / d u_k + + _sup_op_func: Callable[[np.array], List[ControlMatrix]] + Receives the control amplitudes u (as numpy array of shape + (num_t, num_ctrl)) and returns the total dissipation + operators as list of length num_t. + + _sup_op_deriv_func: Callable[[np.array], List[List[ControlMatrix]]] + Receives the control amplitudes u (as numpy array of shape + (num_t, num_ctrl)) and returns the derivatives of the total dissipation + operators as nested list of shape [[] * num_ctrl] * num_t. + + Methods + ------- + _parse_dissipative_super_operator: None + + _calc_diss_sup_op: List[ControlMatrix] + Calculates the total dissipation super operator. + + _calc_diss_sup_op_deriv: Optional[List[List[ControlMatrix]]] + Calculates the derivatives of the total dissipation super operators + with respect to the control amplitudes. + + `Todo` + * Write parser + + """ + + def __init__( + self, + h_drift: List[q_mat.OperatorMatrix], + h_ctrl: List[q_mat.OperatorMatrix], + tau: np.array, + initial_state: q_mat.OperatorMatrix = None, + ctrl_amps: Optional[np.array] = None, + calculate_unitary_derivatives: bool = False, + filter_function_h_n: Union[ + Callable, List[List], None] = None, + filter_function_basis: Optional[basis.Basis] = None, + filter_function_n_coeffs_deriv: Optional[ + Callable[[np.ndarray], np.ndarray]] = None, + exponential_method: Optional[str] = None, + frechet_deriv_approx_method: Optional[str] = None, + initial_diss_super_op: List[q_mat.OperatorMatrix] = None, + lindblad_operators: List[q_mat.OperatorMatrix] = None, + prefactor_function: Callable[[np.array, np.array], np.array] = None, + prefactor_derivative_function: + Callable[[np.array, np.array], np.array] = None, + super_operator_function: + Callable[[np.array, np.array], List[q_mat.OperatorMatrix]] = None, + super_operator_derivative_function: + Callable[[np.array, np.array], + List[List[q_mat.OperatorMatrix]]] = None, + is_skew_hermitian: bool = False, + transfer_function: Optional[TransferFunction] = None, + amplitude_function: Optional[AmplitudeFunction] = None) \ + -> None: + + if initial_state is None: + dim = h_ctrl[0].shape[0] + initial_state = type(h_ctrl[0])(np.eye(dim ** 2)) + + self._diss_sup_op = None + self._diss_sup_op_deriv = None + + # we do not throw away any operators or functions, just in case + self._initial_diss_super_op = initial_diss_super_op + self._lindblad_operators = lindblad_operators + self._prefactor_function = prefactor_function + self._prefactor_deriv_function = prefactor_derivative_function + self._sup_op_func = super_operator_function + self._sup_op_deriv_func = super_operator_derivative_function + self._is_hermitian = is_skew_hermitian + + super().__init__( + h_drift=h_drift, h_ctrl=h_ctrl, initial_state=initial_state, + tau=tau, ctrl_amps=ctrl_amps, + calculate_propagator_derivatives=calculate_unitary_derivatives, + filter_function_h_n=filter_function_h_n, + filter_function_basis=filter_function_basis, + filter_function_n_coeffs_deriv=filter_function_n_coeffs_deriv, + exponential_method=exponential_method, + frechet_deriv_approx_method=frechet_deriv_approx_method, + is_skew_hermitian=is_skew_hermitian, + transfer_function=transfer_function, + amplitude_function=amplitude_function) + + def set_optimization_parameters(self, y: np.array) -> None: + """See base class. """ + if not np.array_equal(self._opt_pars, y): + super().set_optimization_parameters(y) + self.reset_cached_propagators() + + def reset_cached_propagators(self): + """ See base class. """ + super().reset_cached_propagators() + if self._prefactor_function is not None \ + or self._sup_op_func is not None: + self._diss_sup_op = None + self._diss_sup_op_deriv = None + + + def _calc_diss_sup_op(self) -> List[q_mat.OperatorMatrix]: + r""" + Calculates the dissipative super operator as described in the class + doc string. + + Returns + ------- + diss_sup_op: List[ControlMatrix], len num_l + Dissipation super operator; Where num_l is the number of Lindblad + terms. + + """ + if self._sup_op_func is None: + # use Lindblad operators + if self._lindblad_operators is None: + # use dissipation_sup_op + const_diss_sup_op = self._initial_diss_super_op + else: + # Calculate the time constant dissipation super operators + # without time dependence + const_diss_sup_op = [] + identity = self._lindblad_operators[0].identity_like() + + for lindblad in self._lindblad_operators: + const_diss_sup_op.append( + (lindblad.conj(do_copy=True)).kron(lindblad)) + const_diss_sup_op[-1] -= .5 * identity.kron( + lindblad.dag(do_copy=True) * lindblad) + const_diss_sup_op[-1] -= .5 * ( + lindblad.transpose(do_copy=True) + * lindblad.conj(do_copy=True)).kron(identity) + + # Add the time dependence + if self._prefactor_function is not None: + self._diss_sup_op = [] + prefactors = self._prefactor_function( + copy.deepcopy(self._ctrl_amps), + copy.deepcopy(self.transferred_parameters)) + for factor_at_time_t in prefactors: + self._diss_sup_op.append( + const_diss_sup_op[0] * factor_at_time_t[0]) + for sup_op, factor \ + in zip(const_diss_sup_op[1:], + factor_at_time_t[1:]): + self._diss_sup_op[-1] += sup_op * factor + else: + self._diss_sup_op = [const_diss_sup_op[0], ] + for sup_op in const_diss_sup_op[1:]: + self._diss_sup_op[0] += sup_op + self._diss_sup_op *= len(self.transferred_time) + else: + self._diss_sup_op = self._sup_op_func( + copy.deepcopy(self._ctrl_amps), + copy.deepcopy(self.transferred_parameters)) + return self._diss_sup_op + + def _calc_diss_sup_op_deriv(self) \ + -> Optional[List[List[q_mat.OperatorMatrix]]]: + r""" + Calculates the derivatives of the dissipation super operator with + respect to the control amplitudes. + + If the dissipation super operator is given as constant (1.) or as + lindblad operators (2.) they are assumed not to depend on the control + parameters and only the derivative of the prefactor is to be taken into + account. In order to do so, a function handle containing the + derivatives must be given. This function receives the control + amplitudes as num_t x num_ctrl numpy array and returns the derivatives + as num_t x num_l x num_ctrl array. + + If the dissipation super operator is given as function handle (3.), + then the derivatives must also be given as function handle receiving + the control amplitudes and returning a nested list of super operators + as control matrices. + + If the requested derivative functions are not provided (None), then + the dissipation super operator is considered constant in the control + amplitudes and the function returns None. + + Returns + ------- + diss_sup_op_deriv: Optional[List[List[q_mat.ControlMatrix]]], + shape [[] * num_ctrl] * num_t + The derivatives of the dissipation super operator with respect to + the control variables. + + """ + + if self._sup_op_deriv_func is not None: + self._diss_sup_op_deriv = \ + self._sup_op_deriv_func( + copy.deepcopy(self._ctrl_amps), + copy.deepcopy(self.transferred_parameters)) + return self._diss_sup_op_deriv + + elif self._prefactor_deriv_function is not None: + if self._lindblad_operators is None: + # use dissipation_sup_op + const_diss_sup_op = self._initial_diss_super_op + else: + # Calculate the time constant dissipation super operators + # without time dependence + const_diss_sup_op = [] + identity = self._lindblad_operators[0].identity_like() + + for lindblad in self._lindblad_operators: + const_diss_sup_op.append( + (lindblad.conj(do_copy=True)).kron(lindblad)) + const_diss_sup_op[-1] -= .5 * identity.kron( + lindblad.dag(do_copy=True) * lindblad) + const_diss_sup_op[-1] -= .5 * ( + lindblad.transpose(do_copy=True) + * lindblad.conj(do_copy=True)).kron(identity) + + prefactor_derivatives = \ + self._prefactor_deriv_function( + copy.deepcopy(self._ctrl_amps), + copy.deepcopy(self.transferred_parameters)) + + # Todo: Assert that the prefactor returns the right dimension + + # prefactor_derivatives: shape (num_t, num_ctrl, num_l) + diss_sup_op_deriv = [] + for factor_per_ctrl_lind in prefactor_derivatives: + # create new sub list for eacht time step + diss_sup_op_deriv.append([]) + for factor_per_lind in factor_per_ctrl_lind: + # add the first term for each control direction + diss_sup_op_deriv[-1].append( + const_diss_sup_op[0] * factor_per_lind[0]) + for diss_sup_op, factor in zip( + const_diss_sup_op[1:], factor_per_lind[1:]): + # add the remaining terms + diss_sup_op_deriv[-1][-1] += diss_sup_op * factor + self._diss_sup_op_deriv = diss_sup_op_deriv + return diss_sup_op_deriv + else: + return None + + def _compute_derivative_directions( + self) -> List[List[q_mat.OperatorMatrix]]: + r""" + Computes the derivative directions of the total dynamics generator. + + Returns + ------- + deriv_directions: List[List[q_mat.ControlMatrix]], + shape [[] * num_ctrl] * num_t + Derivative directions given by + + .. math:: + + -1j * (I \otimes H_k - H_k \otimes I) + d \mathcal{G} / d u_k + + """ + # derivative of the coherent part + identity_times_i = self.h_ctrl[0].identity_like() + identity_times_i *= -1j + h_ctrl_sup_op = [] + for ctrl_op in self.h_ctrl: + h_ctrl_sup_op.append(identity_times_i.kron(ctrl_op)) + h_ctrl_sup_op[-1] -= (ctrl_op.transpose(do_copy=True)).kron( + identity_times_i) + + # add derivative of the dissipation part + if self._diss_sup_op_deriv is None: + self._diss_sup_op_deriv = self._calc_diss_sup_op_deriv() + if self._diss_sup_op_deriv is not None: + dh_by_ctrl = [] + for diss_sup_op_deriv_at_t in self._diss_sup_op_deriv: + dh_by_ctrl.append([]) + for diss_sup_op_deriv, ctrl_sup_op \ + in zip(diss_sup_op_deriv_at_t, h_ctrl_sup_op): + dh_by_ctrl[-1].append(diss_sup_op_deriv + ctrl_sup_op) + else: + dh_by_ctrl = [h_ctrl_sup_op, ] * len(self.transferred_time) + + return dh_by_ctrl + + def _parse_dissipative_super_operator(self) -> None: + r""" + check the dissipative super operator for dimensional consistency + (maybe even physical properties) + - not implemented yet - + """ + pass + + def _compute_dyn_gen(self) -> List[q_mat.OperatorMatrix]: + r""" + Computes the dynamics generator for the Lindblad master equation. + + The Hamiltonian is translated into the master equation formalism as + + .. math:: + + \mathcal{H} = I \otimes H - H^\ast \otimes I + + Then the dissipation super operator is added. + + Returns + ------- + dyn_gen: List[ControlMatrix], len num_t + Dynamics generators for the master equation. + + Raises + ------ + ValueError: + The computation is only defined for the use of dense control + matrices. + + """ + self._dyn_gen = super()._compute_dyn_gen() + + if self._diss_sup_op is None: + self._diss_sup_op = self._calc_diss_sup_op() + + identiy_operator = self._dyn_gen[0].identity_like() + sup_op_dyn_gen = [] + + assert(len(self._dyn_gen) == len(self._diss_sup_op)) + + for dyn_gen, diss_sup_op in zip(self._dyn_gen, self._diss_sup_op): + sup_op_dyn_gen.append(identiy_operator.kron(dyn_gen)) + # the cancelling minus sign accounts for the -i factor, which is + # also conjugated (included in the dyn gen) + sup_op_dyn_gen[-1] += dyn_gen.conj(do_copy=True).kron( + identiy_operator) + sup_op_dyn_gen[-1] += diss_sup_op + + self._dyn_gen = sup_op_dyn_gen + return sup_op_dyn_gen + + +class LindbladSControlNoise(LindbladSolver): + """ + Special case of the Lindblad master equation. It considers white noise on + the control parameters. The same functionality should be implementable + with the parent class, but less convenient. + """ + + @needs_refactoring + def __init__(self, h_drift, h_ctrl, initial_state, tau, + ctrl_amps, transfer_function=None, + calculate_unitary_derivatives=True, filter_function_h_n=None, + exponential_method=None, lindblad_operators=None, + constant_lindblad_operators=False, noise_psd=1): + super().__init__( + h_drift=h_drift, h_ctrl=h_ctrl, initial_state=initial_state, + tau=tau, ctrl_amps=ctrl_amps, + calculate_unitary_derivatives=calculate_unitary_derivatives, + filter_function_h_n=filter_function_h_n, + exponential_method=exponential_method) + + if lindblad_operators is None: + self.lindblad_super_operator = None + else: + d = lindblad_operators[0].shape[0] + self.lindblad_super_operator = np.zeros( + (len(lindblad_operators), d**2, d**2)) + for i, l in enumerate(lindblad_operators): + self.lindblad_super_operator[i, :, :] += np.kron(np.conj(l), l) + self.lindblad_super_operator[i, :, :] += -.5 * np.kron( + np.eye(d), l.T.conj() @ l) + self.lindblad_super_operator[i, :, :] += -.5 * np.kron( + l.T @ l.conj(), np.eye(d)) + + self.transfer_function = transfer_function + # if no transfer function is given it might be consider to be identity + # its not necessarily required + + self.constant_lindblad_operators = constant_lindblad_operators + self.noise_psd = noise_psd + self.incoherent_dyn_gen = None + + def _compute_propagation(self): + """ + + """ + # Compute and cache all dyn_gen (basically the total hamiltonian) + self._dyn_gen = copy.deepcopy(self.h_drift) + self._dyn_gen += np.sum(self._ctrl_amps * self.h_ctrl, axis=1) + + # initialize the attributes + self._prop = [None] * self.num_t + self._dU = np.array(shape=(self.num_t, self.num_ctrl), + dtype=matrix.DenseOperator) + self._fwd = [self.initial_state] + + # super operator calculation + # this is the special case for charge noise on the control parameters + # the required filter function contains + if not self.constant_lindblad_operators or \ + self.incoherent_dyn_gen is None: + transfer_matrix = self.transfer_function.transfer_matrix + self.incoherent_dyn_gen = np.einsum('ijk,klm,k->ilm', + transfer_matrix, + self.lindblad_super_operator, + self.noise_psd) + dim = self._dyn_gen[0].shape[0] + for i, gen in enumerate(self._dyn_gen): + gen = -1j * np.kron( + np.eye(dim), gen.data) - np.kron(gen.data, np.eye(dim)) + gen += self.incoherent_dyn_gen[i, :, :] + gen = matrix.DenseOperator(gen) + + # calculation of the propagators + for t in range(len(self.num_t)): + if self.calculate_propagator_derivatives: + for ctrl in range(self.num_ctrl): + direction = np.kron( + np.eye(dim), self.h_ctrl[t][ctrl]) - np.kron( + self.h_ctrl[t][ctrl], np.eye(dim)) + self._prop[t], self._dU[t, ctrl] = self._dyn_gen[t].dexp( + direction=direction, tau=self.transferred_time[t], + compute_expm=True, method=self.exponential_method) + + else: + self._prop[t] = self._dyn_gen[t].exp( + tau=self.transferred_time[t], method=self.exponential_method) + + self._fwd.append(self._prop[t] * self._fwd[t]) + + self.prop_calculated = True diff --git a/qopt/transfer_function.py b/qopt/transfer_function.py index d1ccc38..ff8a21c 100644 --- a/qopt/transfer_function.py +++ b/qopt/transfer_function.py @@ -116,7 +116,6 @@ from qopt.util import deprecated, needs_refactoring - class TransferFunction(ABC): """ A class for representing transfer functions, between optimization @@ -1714,3 +1713,467 @@ def set_times(self, times): super().set_times(times) # TODO: properly implement 'w' + + +############################################################################### + +try: + import jax.numpy as jnp + from jax import vmap + _HAS_JAX = True +except ImportError: + from unittest import mock + jnp = mock.Mock() + vmap = mock.Mock() + _HAS_JAX = False + +class TransferFunctionJAX(TransferFunction): + """See docstring of class w/o JAX.""" + + def __init__(self, + num_ctrls: int = 1, + bound_type: Optional[Tuple[str, int]] = None, + oversampling: int = 1, + offset: Optional[float] = None + ): + if not _HAS_JAX: + raise ImportError("JAX not available") + super().__init__(num_ctrls,bound_type,oversampling,offset) + + @abstractmethod + def __call__(self, y: Union[np.array,jnp.array]) -> jnp.array: + """Calculate the transferred optimization parameters (x). + + Evaluates the transfer function at the raw optimization parameters (y) + to calculate the transferred optimization parameters (x). + + Parameters + ---------- + y: Union[np.array,jnp.array], shape (num_y, num_par) + Raw optimization variables; num_y is the number of time slices of + the raw optimization parameters and num_par is the number of + distinct raw optimization parameters. + + Returns + ------- + u: jnp.array, shape (num_x, num_par) + Control parameters; num_u is the number of times slices for the + transferred optimization parameters. + + """ + pass + + @property + def num_padding_elements(self) -> (int, int): + """ + Convenience function. Returns the number of elements padded to the + beginning and the end of the control amplitude times. + + Returns + ------- + num_padding_elements: (int, int) + (elements padded to the beginning, elements padded to the end) + + """ + if self.bound_type is None: + return 0, 0 + elif self.bound_type[0] == 'n': + return self.bound_type[1], self.bound_type[1] + elif self.bound_type[0] == 'x': + return self.bound_type[1] * self.oversampling, \ + self.bound_type[1] * self.oversampling + elif self.bound_type[0] == 'right_n': + return 0, self.bound_type[1] + else: + raise ValueError('Unknown bound type ' + str(self.bound_type[0])) + + @abstractmethod + def gradient_chain_rule( + self, deriv_by_transferred_par: Union[np.array,jnp.array] + ) -> jnp.array: + """ + Obtain the derivatives of a quantity a i.e. da/dy by the optimization + variables from the derivatives by the amplitude of the control fields. + + The chain rule applies: df/dy = df/dx * dx/dy. + + Parameters + ---------- + deriv_by_transferred_par: Union[np.array,jnp.array], + shape (num_x, num_f, num_par) + The gradients of num_f functions by num_par optimization parameters + at num_x different time steps. + + Returns + ------- + deriv_by_opt_par: np.array, shape: (num_y, num_f, num_par) + The derivatives by the optimization parameters at num_y time steps. + + """ + pass + + def set_times(self, y_times: Union[np.array,jnp.array]) -> None: + """ + Generate the time_slot duration array 'transferred_time' + (here: x_times). + + The time slices depend on the oversampling of the control variables + and the boundary conditions. The times are for the intended use cases + only set once. + + Parameters + ---------- + y_times: Union[np.ndarray, jnp.ndarray, list], shape (num_y) + The time steps / durations of constant optimization variables. + num_y is the number of time steps for the raw optimization + variables. + + """ + if isinstance(y_times, list): + y_times = jnp.array(y_times) + if not isinstance(y_times, (np.ndarray,jnp.ndarray)): + raise Exception("times must be a list or (j)np.array") + + y_times = jnp.atleast_1d(jnp.squeeze(y_times)) + + if len(y_times.shape) > 1: + raise ValueError('The x_times should not have more than one ' + 'dimension!') + + self._num_y = y_times.size + self._y_times = y_times + + if self.bound_type is None: + self.num_x = self.oversampling * self._num_y + self.x_times = jnp.repeat( + self._y_times, self.oversampling) / self.oversampling + + elif self.bound_type[0] == 'n': + self.num_x = self.oversampling * self._num_y + 2 \ + * self.bound_type[1] + self.x_times = jnp.concatenate(( + self._y_times[0] / self.oversampling + * jnp.ones(self.bound_type[1]), + jnp.repeat( + self._y_times / self.oversampling, self.oversampling), + self._y_times[-1] / self.oversampling + * jnp.ones(self.bound_type[1]))) + + elif self.bound_type[0] == 'x': + self.num_x = self.oversampling * (self._num_y + + 2 * self.bound_type[1]) + self.x_times = jnp.concatenate(( + self._y_times[0] / self.oversampling + * jnp.ones(self.bound_type[1] * self.oversampling), + jnp.repeat(self._y_times / self.oversampling, + self.oversampling), + self._y_times[-1] / self.oversampling + * jnp.ones(self.bound_type[1] * self.oversampling))) + + elif self.bound_type[0] == 'right_n': + self.num_x = self.oversampling * self._num_y + self.bound_type[1] + self.x_times = np.concatenate(( + jnp.repeat(self._y_times / self.oversampling, + self.oversampling), + self._y_times[-1] / self.oversampling + * jnp.ones(self.bound_type[1]))) + + else: + raise ValueError('The boundary type ' + str(self.bound_type[0]) + + ' is not implemented!') + + def set_absolute_times(self, + absolute_y_times: Union[np.array,jnp.array,list] + ) -> None: + """ + Generate the time_slot duration array 'transferred_time' + (here: x_times) + + This time slices depend on the oversampling of the control variables + and the boundary conditions. The differences of the absolute times + give the time steps x_times. + + Parameters + ---------- + absolute_y_times: Union[np.array,jnp.array,list] + Absolute times of the start / end of each time segment for the raw + optimization parameters. + + """ + if isinstance(absolute_y_times, list): + absolute_y_times = jnp.array(absolute_y_times) + if not isinstance(absolute_y_times, Union[np.array,jnp.array]): + raise Exception("times must be a list or (j)np.array") + if not jnp.all(jnp.diff(absolute_y_times) >= 0): + raise Exception("times must be sorted") + + self._absolute_y_times = absolute_y_times + self.set_times(jnp.diff(absolute_y_times)) + + def plot_pulse(self, y: Union[np.array,jnp.array]) -> None: + """ + + Plot the control amplitudes corresponding to the given optimisation + variables. + + Parameters + ---------- + y: array, shape (num_y, num_par) + Raw optimization parameters. + + """ + + x = self(y) + #plotting not good with jnp(?) + x, y = np.array(x), np.array(y) + n_padding_start, n_padding_end = self.num_padding_elements + for y_per_control, x_per_control in zip(y.T, x.T): + plt.figure() + plt.bar(np.cumsum(self.x_times) - .5 * self.x_times[0], + x_per_control, self.x_times[0]) + plt.bar(np.cumsum(self._y_times) - .5 * self._y_times[0] + + np.cumsum(self._y_times)[n_padding_start] + - self._y_times[n_padding_start], + y_per_control, self._y_times[0], + fill=False) + plt.show() + + +class IdentityTFJAX(TransferFunctionJAX): + """See docstring of class w/o JAX.""" + + def __init__(self, num_ctrls=1): + super().__init__( + bound_type=None, + oversampling=1, + num_ctrls=num_ctrls, + offset=0. + ) + self.name = 'Identity' + + def __call__(self, y: Union[np.array,jnp.array]) -> jnp.array: + """See base class. """ + return jnp.asarray(y) + + def gradient_chain_rule( + self, deriv_by_transferred_par: Union[np.array,jnp.array] + ) -> jnp.array: + """See base class. """ + return jnp.asarray(deriv_by_transferred_par) + + +class OversamplingTFJAX(TransferFunctionJAX): + """See docstring of class w/o JAX.""" + + def __init__(self, + num_ctrls: int = 1, + bound_type: Optional[Tuple[str, int]] = None, + oversampling: int = 1 + ): + super().__init__( + num_ctrls=num_ctrls, + bound_type=bound_type, + oversampling=oversampling + ) + + def _calculate_transfer_matrix(self): + """Overrides the base class method. """ + raise NotImplementedError + + def __call__(self, y: Union[np.array,jnp.array]) -> jnp.array: + """Calculate the transferred optimization parameters (x). + + Only the oversampling and boundaries are taken into account. + + Parameters + ---------- + y: Union[np.array,jnp.array], shape (num_y, num_par) + Raw optimization variables; num_y is the number of time slices of + the raw optimization parameters and num_par is the number of + distinct raw optimization parameters. + + Returns + ------- + u: jnp.array, shape (num_x, num_par) + Control parameters; num_u is the number of times slices for the + transferred optimization parameters. + + """ + # oversample pulse by repetition + u = jnp.repeat(y, self.oversampling, axis=0) + + # add the padding elements + padding_start, padding_end = self.num_padding_elements + + u = jnp.concatenate( + (jnp.zeros((padding_start, self.num_ctrls)), + u, + jnp.zeros((padding_end, self.num_ctrls))), axis=0) + + return u + + def gradient_chain_rule( + self, deriv_by_transferred_par: Union[np.array,jnp.array] + ) -> jnp.array: + """ + See base class. + + Processing without transfer matrix. + + Parameters + ---------- + deriv_by_transferred_par: Union[np.array,jnp.array], + shape (num_x, num_f, num_par) + The gradients of num_f functions by num_par optimization parameters + at num_x different time steps. + + Returns + ------- + deriv_by_opt_par: jnp.array, shape: (num_y, num_f, num_par) + The derivatives by the optimization parameters at num_y time steps. + + """ + + shape = deriv_by_transferred_par.shape + assert len(shape) == 3 + assert shape[0] == self.num_x + assert shape[2] == self.num_ctrls + + # delete the padding elements + padding_start, padding_end = self.num_padding_elements + + # deriv_by_ctrl_amps: shape (num_x, num_f, num_par) + if padding_end > 0: + cropped_derivs = deriv_by_transferred_par[ + padding_start:-padding_end, :, :] + else: + cropped_derivs = deriv_by_transferred_par[ + padding_start:, :, :] + + cropped_derivs = jnp.expand_dims(cropped_derivs, axis=1) + cropped_derivs = jnp.reshape( + cropped_derivs, ( + self._num_y, + self.oversampling, + cropped_derivs.shape[2], + cropped_derivs.shape[3] + ) + ) + deriv_by_opt_par = jnp.sum(cropped_derivs, axis=1) + return deriv_by_opt_par + + +#### + +class LinearInterpTFJAX(TransferFunctionJAX): + """See docstring of class w/o JAX.""" + + def __init__(self, + num_ctrls: int = 1, + bound_type: Optional[Tuple[str, int]] = None, + oversampling: int = 1 + ): + super().__init__( + num_ctrls=num_ctrls, + bound_type=bound_type, + oversampling=oversampling + ) + + def _calculate_transfer_matrix(self): + """Overrides the base class method. """ + raise NotImplementedError + + def __call__(self, y: Union[np.array,jnp.array]) -> jnp.array: + """Calculate the transferred optimization parameters (x). + + Only the oversampling and boundaries are taken into account. + + Parameters + ---------- + y: Union[np.array,jnp.array], shape (num_y, num_par) + Raw optimization variables; num_y is the number of time slices of + the raw optimization parameters and num_par is the number of + distinct raw optimization parameters. + + Returns + ------- + u: jnp.array, shape (num_x, num_par) + Control parameters; num_u is the number of times slices for the + transferred optimization parameters. + + """ + # oversample pulse by repetition + # u = jnp.repeat(y, self.oversampling, axis=0) + + x_arr_old, x_arr_new = \ + jnp.linspace(0,y.shape[0],y.shape[0],endpoint=False), \ + jnp.linspace(0,y.shape[0],y.shape[0]*self.oversampling,endpoint=False) + #as coded now has base at beginning of time interval + u = jnp.moveaxis(vmap(jnp.interp,in_axes=(None,None,1))(x_arr_new,x_arr_old,y),0,1) + + # add the padding elements + #TODO: not implemented as not used so far + if self.num_padding_elements[0] != 0 or self.num_padding_elements[1] != 0: + raise NotImplementedError + # padding_start, padding_end = self.num_padding_elements + + # u = jnp.concatenate( + # (jnp.zeros((padding_start, self.num_ctrls)), + # u, + # jnp.zeros((padding_end, self.num_ctrls))), axis=0) + + return u + + def gradient_chain_rule( + self, deriv_by_transferred_par: Union[np.array,jnp.array] + ) -> jnp.array: + """ + See base class. + + Processing without transfer matrix. + + Parameters + ---------- + deriv_by_transferred_par: Union[np.array,jnp.array], + shape (num_x, num_f, num_par) + The gradients of num_f functions by num_par optimization parameters + at num_x different time steps. + + Returns + ------- + deriv_by_opt_par: jnp.array, shape: (num_y, num_f, num_par) + The derivatives by the optimization parameters at num_y time steps. + + """ + + shape = deriv_by_transferred_par.shape + + assert len(shape) == 3 + assert shape[0] == self.num_x + assert shape[2] == self.num_ctrls + # assert self.num_x//self.oversampling > 3 #to avoid complications + # assert self.x//self.oversampling == + + # delete the padding elements + if self.num_padding_elements[0] != 0 or self.num_padding_elements[1] != 0: + raise NotImplementedError + # padding_start, padding_end = self.num_padding_elements + m_arr = jnp.arange(0,self.oversampling)/self.oversampling + len_m = len(m_arr) + + deriv_by_opt_par = np.empty((self.num_x//self.oversampling,shape[1],shape[2])) + + + deriv_by_opt_par[0,:,:] = jnp.sum(deriv_by_transferred_par[0:self.oversampling]*(1-m_arr[:,np.newaxis,np.newaxis]),axis=0) + + deriv_by_opt_par[self.num_x//self.oversampling-1,:,:] = jnp.sum(deriv_by_transferred_par[self.oversampling*(self.num_x//self.oversampling-2):self.oversampling*(self.num_x//self.oversampling -1)]*m_arr[:,np.newaxis,np.newaxis],axis=0) + + + #slow but less memory consumption to avoid y*x shape + for i in range(1,self.num_x//self.oversampling -1): + deriv_by_opt_par[i,:,:] = jnp.sum(deriv_by_transferred_par[self.oversampling*(i-1):self.oversampling*i]*m_arr[:,np.newaxis,np.newaxis],axis=0) +\ + jnp.sum(deriv_by_transferred_par[self.oversampling*i:self.oversampling*(i+1)]*(1-m_arr[:,np.newaxis,np.newaxis]),axis=0) + + + # deriv_by_opt_par = jnp.sum(cropped_derivs, axis=1) + return jnp.asarray(deriv_by_opt_par) \ No newline at end of file