Skip to content

Commit

Permalink
Cleanup composite analysis (qiskit-community#1397)
Browse files Browse the repository at this point in the history
### Summary

Thanks to qiskit-community#1342 we can cleanup internals of `CompositeCurveAnalysis`.
No API breaks and no feature upgrades are included in this PR.

### Details and comments

Previously, the curve data and fit summary data were created internally in
`CurveAnalysis` but immediately discarded. The implementation in
`CurveAnalysis._run_analysis` was manually copied into
`CompositeCurveAnalysis._run_analysis` in order to access these artifact data
and build composite artifact data from them. This made the code fragile, since
developers needed to manually keep both base classes in sync. With this PR,
the implementation of the component analysis is encapsulated.
  • Loading branch information
nkanazawa1989 authored Apr 22, 2024
1 parent e9acd22 commit cb37d42
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 106 deletions.
2 changes: 1 addition & 1 deletion qiskit_experiments/curve_analysis/base_curve_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def _create_curve_data(
"""
samples = []

for model_name, sub_data in list(curve_data.groupby("model_name")):
for model_name, sub_data in list(curve_data.dataframe.groupby("model_name")):
raw_datum = AnalysisResultData(
name=DATA_ENTRY_PREFIX + self.__class__.__name__,
value={
Expand Down
150 changes: 45 additions & 105 deletions qiskit_experiments/curve_analysis/composite_curve_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
# pylint: disable=invalid-name
import warnings
from typing import Dict, List, Optional, Tuple, Union
from collections import defaultdict

import lmfit
import numpy as np
import pandas as pd
from uncertainties import unumpy as unp

from qiskit.utils.deprecation import deprecate_func

Expand All @@ -39,10 +39,9 @@
)

from qiskit_experiments.framework.containers import FigureType, ArtifactData
from .base_curve_analysis import DATA_ENTRY_PREFIX, BaseCurveAnalysis, PARAMS_ENTRY_PREFIX
from .base_curve_analysis import BaseCurveAnalysis
from .curve_data import CurveFitResult
from .scatter_table import ScatterTable
from .utils import eval_with_uncertainties


class CompositeCurveAnalysis(BaseAnalysis):
Expand Down Expand Up @@ -344,123 +343,64 @@ def _run_analysis(
else:
plot = getattr(self, "_generate_figures", "always")

fit_dataset = {}
curve_data_set = []
for analysis in self._analyses:
analysis._initialize(experiment_data)
analysis.set_options(plot=False)

metadata = analysis.options.extra.copy()
sub_artifacts = defaultdict(list)
for source_analysis in self._analyses:
analysis = source_analysis.copy()
metadata = analysis.options.extra
metadata["group"] = analysis.name
analysis.set_options(
plot=False,
extra=metadata,
return_fit_parameters=self.options.return_fit_parameters,
return_data_points=self.options.return_data_points,
)
results, _ = analysis._run_analysis(experiment_data)
for res in results:
if isinstance(res, ArtifactData):
sub_artifacts[res.name].append((analysis.name, res.data))
else:
result_data.append(res)

if "curve_data" in sub_artifacts:
combined_curve_data = ScatterTable.from_dataframe(
data=pd.concat([d.dataframe for _, d in sub_artifacts["curve_data"]])
)
artifacts.append(ArtifactData(name="curve_data", data=combined_curve_data))
else:
combined_curve_data = None

table = analysis._format_data(analysis._run_data_processing(experiment_data.data()))
formatted_subset = table.filter(category=analysis.options.fit_category)
fit_data = analysis._run_curve_fit(formatted_subset)
fit_dataset[analysis.name] = fit_data

if fit_data.success:
quality = analysis._evaluate_quality(fit_data)
else:
quality = "bad"

if self.options.return_fit_parameters:
# Store fit status overview entry regardless of success.
# This is sometime useful when debugging the fitting code.
overview = AnalysisResultData(
name=PARAMS_ENTRY_PREFIX + analysis.name,
value=fit_data,
quality=quality,
extra=metadata,
)
result_data.append(overview)

if fit_data.success:
# Add fit data to curve data table
model_names = analysis.model_names()
for series_id, sub_data in formatted_subset.iter_by_series_id():
xval = sub_data.x
if len(xval) == 0:
# If data is empty, skip drawing this model.
# This is the case when fit model exist but no data to fit is provided.
continue
# Compute X, Y values with fit parameters.
xval_arr_fit = np.linspace(np.min(xval), np.max(xval), num=100, dtype=float)
uval_arr_fit = eval_with_uncertainties(
x=xval_arr_fit,
model=analysis.models[series_id],
params=fit_data.ufloat_params,
)
yval_arr_fit = unp.nominal_values(uval_arr_fit)
if fit_data.covar is not None:
yerr_arr_fit = unp.std_devs(uval_arr_fit)
else:
yerr_arr_fit = np.zeros_like(xval_arr_fit)
for xval, yval, yerr in zip(xval_arr_fit, yval_arr_fit, yerr_arr_fit):
table.add_row(
xval=xval,
yval=yval,
yerr=yerr,
series_name=model_names[series_id],
series_id=series_id,
category="fitted",
analysis=analysis.name,
)
result_data.extend(
analysis._create_analysis_results(
fit_data=fit_data,
quality=quality,
**metadata.copy(),
)
)

if self.options.return_data_points:
# Add raw data points
warnings.warn(
f"{DATA_ENTRY_PREFIX + self.name} has been moved to experiment data artifacts. "
"Saving this result with 'return_data_points'=True will be disabled in "
"Qiskit Experiments 0.7.",
DeprecationWarning,
)
result_data.extend(
analysis._create_curve_data(curve_data=formatted_subset, **metadata)
)

curve_data_set.append(table)

combined_curve_data = ScatterTable.from_dataframe(
pd.concat([d.dataframe for d in curve_data_set])
)
total_quality = self._evaluate_quality(fit_dataset)
if "fit_summary" in sub_artifacts:
combined_summary = dict(sub_artifacts["fit_summary"])
artifacts.append(ArtifactData(name="fit_summary", data=combined_summary))
total_quality = self._evaluate_quality(combined_summary)
else:
combined_summary = None
total_quality = "No Information"

# After the quality is determined, plot can become a boolean flag for whether
# to generate the figure
plot_bool = plot == "always" or (plot == "selective" and total_quality == "bad")

# Create analysis results by combining all fit data
if all(fit_data.success for fit_data in fit_dataset.values()):
if combined_summary and all(fit_data.success for fit_data in combined_summary.values()):
composite_results = self._create_analysis_results(
fit_data=fit_dataset, quality=total_quality, **self.options.extra.copy()
fit_data=combined_summary,
quality=total_quality,
**self.options.extra.copy(),
)
result_data.extend(composite_results)
else:
composite_results = []

artifacts.append(
ArtifactData(
name="curve_data",
data=combined_curve_data,
)
)
artifacts.append(
ArtifactData(
name="fit_summary",
data=fit_dataset,
)
)

if plot_bool:
if plot_bool and combined_curve_data:
if combined_summary:
red_chi_dict = {
k: v.reduced_chisq for k, v in combined_summary.items() if v.success
}
else:
red_chi_dict = {}
self.plotter.set_supplementary_data(
fit_red_chi={k: v.reduced_chisq for k, v in fit_dataset.items() if v.success},
fit_red_chi=red_chi_dict,
primary_results=composite_results,
)
figures.extend(self._create_figures(curve_data=combined_curve_data))
Expand Down

0 comments on commit cb37d42

Please sign in to comment.