Skip to content

Commit

Permalink
refactor and add frequency training
Browse files Browse the repository at this point in the history
  • Loading branch information
jpmoutinho committed Dec 19, 2023
1 parent 3d34caa commit 5c4175b
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 19 deletions.
54 changes: 37 additions & 17 deletions qadence/constructors/feature_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from qadence.blocks import AbstractBlock, KronBlock, chain, kron, tag
from qadence.logger import get_logger
from qadence.operations import PHASE, RX, RY, RZ, H
from qadence.parameters import FeatureParameter, Parameter
from qadence.parameters import FeatureParameter, Parameter, VariationalParameter
from qadence.types import BasisSet, ReuploadScaling, TParameter

logger = get_logger(__name__)
Expand All @@ -36,13 +36,10 @@ def _set_range(fm_type: BasisSet | type[Function] | str) -> tuple[float, float]:
}


def fm_parameter(
def backwards_compatibility(
fm_type: BasisSet | type[Function] | str,
param: Parameter | str = "phi",
feature_range: tuple[float, float] | None = None,
target_range: tuple[float, float] | None = None,
) -> Parameter | Basic:
# Backwards compatibility
reupload_scaling: ReuploadScaling | Callable | str,
) -> tuple:
if fm_type in ("fourier", "chebyshev", "tower"):
logger.warning(
"Selecting `fm_type` as 'fourier', 'chebyshev' or 'tower' is deprecated. "
Expand All @@ -55,7 +52,17 @@ def fm_parameter(
fm_type = BasisSet.CHEBYSHEV
elif fm_type == "tower":
fm_type = BasisSet.CHEBYSHEV
reupload_scaling = ReuploadScaling.TOWER

return fm_type, reupload_scaling


def fm_parameter_scaling(
fm_type: BasisSet | type[Function] | str,
param: Parameter | str = "phi",
feature_range: tuple[float, float] | None = None,
target_range: tuple[float, float] | None = None,
) -> Parameter | Basic:
if isinstance(param, Parameter):
fparam = param
fparam.trainable = False
Expand All @@ -76,21 +83,28 @@ def fm_parameter(
else:
scaled_fparam = scaling * fparam + shift

return scaled_fparam


def fm_parameter_func(fm_type: BasisSet | type[Function] | str) -> type[Function]:
def ident_fn(x: TParameter) -> TParameter:
return x

# Transform feature parameter
if fm_type == BasisSet.FOURIER:
transformed_feature = scaled_fparam
transform_func = ident_fn
elif fm_type == BasisSet.CHEBYSHEV:
transformed_feature = acos(scaled_fparam)
transform_func = acos
elif inspect.isclass(fm_type) and issubclass(fm_type, Function):
transformed_feature = fm_type(scaled_fparam)
transform_func = fm_type
else:
raise NotImplementedError(
f"Feature map type {fm_type} not implemented. Choose an item from the BasisSet "
f"enum: {[bs.name for bs in BasisSet]}, or your own sympy.Function to wrap "
"the given feature parameter with."
)

return transformed_feature
return transform_func # type: ignore [return-value]


def fm_reupload_scaling_fn(
Expand Down Expand Up @@ -126,6 +140,8 @@ def feature_map(
feature_range: tuple[float, float] | None = None,
target_range: tuple[float, float] | None = None,
multiplier: Parameter | TParameter | None = None,
train_freq: bool = False,
freq_prefix: str = "w",
) -> KronBlock:
"""Construct a feature map of a given type.
Expand Down Expand Up @@ -175,14 +191,14 @@ def feature_map(
f"Please provide one from {[rot.__name__ for rot in ROTATIONS]}."
)

transformed_feature = fm_parameter(
# Backwards compatibility
fm_type, reupload_scaling = backwards_compatibility(fm_type, reupload_scaling)

scaled_fparam = fm_parameter_scaling(
fm_type, param, feature_range=feature_range, target_range=target_range
)

# Backwards compatibility
if fm_type == "tower":
logger.warning("Forcing reupload scaling strategy to TOWER")
reupload_scaling = ReuploadScaling.TOWER
transform_func = fm_parameter_func(fm_type)

basis_tag = fm_type.value if isinstance(fm_type, BasisSet) else str(fm_type)
rs_func, rs_tag = fm_reupload_scaling_fn(reupload_scaling)
Expand All @@ -192,8 +208,12 @@ def feature_map(

# Build feature map
op_list = []
fparam = scaled_fparam
for i, qubit in enumerate(support):
op_list.append(op(qubit, multiplier * rs_func(i) * transformed_feature))
if train_freq:
freq_param = VariationalParameter(freq_prefix + f"_{i}")
fparam = freq_param * scaled_fparam
op_list.append(op(qubit, multiplier * rs_func(i) * transform_func(fparam)))
fm = kron(*op_list)

fm.tag = rs_tag + " " + basis_tag + " FM"
Expand Down
10 changes: 8 additions & 2 deletions qadence/constructors/rydberg_feature_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sympy import Basic, Function

from qadence.blocks import AnalogBlock, KronBlock, kron
from qadence.constructors.feature_maps import fm_parameter
from qadence.constructors.feature_maps import fm_parameter_func, fm_parameter_scaling
from qadence.logger import get_logger
from qadence.operations import AnalogRot, AnalogRX, AnalogRY, AnalogRZ
from qadence.parameters import FeatureParameter, Parameter, VariationalParameter
Expand Down Expand Up @@ -98,9 +98,15 @@ def analog_feature_map(
multiplier: overall multiplier; this is useful for reuploading the feature map serially with
different scalings; can be a number or parameter/expression.
"""
transformed_feature = fm_parameter(

scaled_fparam = fm_parameter_scaling(
fm_type, param, feature_range=feature_range, target_range=target_range
)

transform_func = fm_parameter_func(fm_type)

transformed_feature = transform_func(scaled_fparam)

multiplier = 1.0 if multiplier is None else Parameter(multiplier)

if callable(reupload_scaling):
Expand Down

0 comments on commit 5c4175b

Please sign in to comment.