From e73481d288d3c86c230e21a5b0939d95aa640abb Mon Sep 17 00:00:00 2001 From: Alessandro Candido Date: Fri, 31 Mar 2023 12:55:54 +0200 Subject: [PATCH] Split managers creation out of main function --- src/eko/couplings.py | 4 +- src/eko/evolution_operator/grid.py | 3 +- src/eko/io/runcards.py | 2 +- src/eko/runner/commons.py | 49 +++++++++++++++++++++++ src/eko/runner/legacy.py | 46 +++++---------------- src/eko/thresholds.py | 14 ++++--- tests/eko/evolution_operator/test_grid.py | 7 ++-- 7 files changed, 76 insertions(+), 49 deletions(-) diff --git a/src/eko/couplings.py b/src/eko/couplings.py index 9983952bb..dd58d2319 100644 --- a/src/eko/couplings.py +++ b/src/eko/couplings.py @@ -9,7 +9,7 @@ """ import logging import warnings -from typing import List +from typing import Iterable, List import numba as nb import numpy as np @@ -426,7 +426,7 @@ def __init__( method: CouplingEvolutionMethod, masses: List[float], hqm_scheme: QuarkMassScheme, - thresholds_ratios: List[float], + thresholds_ratios: Iterable[float], ): # Sanity checks def assert_positive(name, var): diff --git a/src/eko/evolution_operator/grid.py b/src/eko/evolution_operator/grid.py index 3e292a49b..05bf295cb 100644 --- a/src/eko/evolution_operator/grid.py +++ b/src/eko/evolution_operator/grid.py @@ -7,6 +7,7 @@ import logging import numbers +from typing import List import numpy as np import numpy.typing as npt @@ -41,7 +42,7 @@ def __init__( self, mu2grid: npt.NDArray, order: tuple, - masses: tuple, + masses: List[float], mass_scheme, intrinsic_flavors: list, xif: float, diff --git a/src/eko/io/runcards.py b/src/eko/io/runcards.py index d8afdfc32..40f27678a 100644 --- a/src/eko/io/runcards.py +++ b/src/eko/io/runcards.py @@ -55,7 +55,7 @@ def masses(theory: TheoryCard, evmeth: EvolutionMethod): theory.couplings, theory.order, couplings_mod_ev(evmeth), - np.power(theory.heavy.matching_ratios, 2.0), + np.power(theory.heavy.matching_ratios, 2.0).tolist(), xif2=theory.xif**2, ).tolist() if theory.heavy.masses_scheme is QuarkMassScheme.POLE: diff --git a/src/eko/runner/commons.py b/src/eko/runner/commons.py index 24eb37da2..7b49fab28 100644 --- a/src/eko/runner/commons.py +++ b/src/eko/runner/commons.py @@ -1,4 +1,12 @@ """Runners common utilities.""" +import numpy as np + +from ..couplings import Couplings, couplings_mod_ev +from ..interpolation import InterpolatorDispatcher +from ..io import runcards +from ..io.runcards import OperatorCard, TheoryCard +from ..io.types import ScaleVariationsMethod +from ..thresholds import ThresholdsAtlas BANNER = r""" oooooooooooo oooo oooo \\ .oooooo. @@ -9,3 +17,44 @@ 888 o 888 `88b. `88b // d88' o888ooooood8 o888o o888o `Y8bood8P' """ + + +def interpolator(operator: OperatorCard) -> InterpolatorDispatcher: + """Create interpolator from runcards.""" + return InterpolatorDispatcher( + xgrid=operator.xgrid, + polynomial_degree=operator.configs.interpolation_polynomial_degree, + ) + + +def threshold_atlas(theory: TheoryCard, operator: OperatorCard) -> ThresholdsAtlas: + """Create thresholds atlas from runcards.""" + thresholds_ratios = np.power(theory.heavy.matching_ratios, 2.0) + # TODO: cache result + masses = runcards.masses(theory, operator.configs.evolution_method) + return ThresholdsAtlas( + masses=masses, + q2_ref=operator.mu20, + nf_ref=theory.heavy.num_flavs_init, + thresholds_ratios=thresholds_ratios, + max_nf=theory.heavy.num_flavs_max_pdf, + ) + + +def couplings(theory: TheoryCard, operator: OperatorCard) -> Couplings: + """Create couplings from runcards.""" + thresholds_ratios = np.power(theory.heavy.matching_ratios, 2.0) + masses = runcards.masses(theory, operator.configs.evolution_method) + return Couplings( + couplings=theory.couplings, + order=theory.order, + method=couplings_mod_ev(operator.configs.evolution_method), + masses=masses, + hqm_scheme=theory.heavy.masses_scheme, + thresholds_ratios=thresholds_ratios + * ( + theory.xif**2 + if operator.configs.scvar_method == ScaleVariationsMethod.EXPONENTIATED + else 1.0 + ), + ) diff --git a/src/eko/runner/legacy.py b/src/eko/runner/legacy.py index 94ae2281d..be49af4f6 100644 --- a/src/eko/runner/legacy.py +++ b/src/eko/runner/legacy.py @@ -3,14 +3,9 @@ import os from typing import Union -import numpy as np - -from .. import interpolation -from ..couplings import Couplings, couplings_mod_ev from ..evolution_operator.grid import OperatorGrid from ..io import EKO, Operator, runcards -from ..io.types import RawCard, ScaleVariationsMethod -from ..thresholds import ThresholdsAtlas +from ..io.types import RawCard from . import commons logger = logging.getLogger(__name__) @@ -55,41 +50,18 @@ def __init__( self._theory = new_theory # setup basis grid - bfd = interpolation.InterpolatorDispatcher( - xgrid=new_operator.xgrid, - polynomial_degree=new_operator.configs.interpolation_polynomial_degree, - ) - - # setup the Threshold path, compute masses if necessary - masses = runcards.masses(new_theory, new_operator.configs.evolution_method) + bfd = commons.interpolator(new_operator) # call explicitly iter to explain the static analyzer that is an # iterable - thresholds_ratios = np.power(list(iter(new_theory.heavy.matching_ratios)), 2.0) - tc = ThresholdsAtlas( - masses=masses, - q2_ref=new_operator.mu20, - nf_ref=new_theory.heavy.num_flavs_init, - thresholds_ratios=thresholds_ratios, - max_nf=new_theory.heavy.num_flavs_max_pdf, - ) + tc = commons.threshold_atlas(new_theory, new_operator) # strong coupling - sc = Couplings( - couplings=new_theory.couplings, - order=new_theory.order, - method=couplings_mod_ev(new_operator.configs.evolution_method), - masses=masses, - hqm_scheme=new_theory.heavy.masses_scheme, - thresholds_ratios=thresholds_ratios - * ( - new_theory.xif**2 - if new_operator.configs.scvar_method - == ScaleVariationsMethod.EXPONENTIATED - else 1.0 - ), - ) - # setup operator grid + cs = commons.couplings(new_theory, new_operator) # setup operator grid + + # compute masses if required + masses = runcards.masses(new_theory, new_operator.configs.evolution_method) + self.op_grid = OperatorGrid( mu2grid=new_operator.mu2grid, order=new_theory.order, @@ -100,7 +72,7 @@ def __init__( configs=new_operator.configs, debug=new_operator.debug, thresholds_config=tc, - couplings=sc, + couplings=cs, interpol_dispatcher=bfd, ) diff --git a/src/eko/thresholds.py b/src/eko/thresholds.py index eb0bd7fa2..763509352 100644 --- a/src/eko/thresholds.py +++ b/src/eko/thresholds.py @@ -1,7 +1,7 @@ r"""Holds the classes that define the |FNS|.""" import logging from dataclasses import astuple, dataclass -from typing import List, Optional +from typing import Iterable, List, Optional import numpy as np @@ -47,7 +47,7 @@ def __init__( masses: List[float], q2_ref: Optional[float] = None, nf_ref: Optional[int] = None, - thresholds_ratios: Optional[List[float]] = None, + thresholds_ratios: Optional[Iterable[float]] = None, max_nf: Optional[int] = None, ): """Create basic atlas. @@ -69,6 +69,12 @@ def __init__( sorted_masses = sorted(masses) if not np.allclose(masses, sorted_masses): raise ValueError("masses need to be sorted") + + if thresholds_ratios is None: + thresholds_ratios = [1.0, 1.0, 1.0] + else: + thresholds_ratios = list(thresholds_ratios) + # combine them thresholds = self.build_area_walls(sorted_masses, thresholds_ratios, max_nf) self.area_walls = [0] + thresholds + [np.inf] @@ -126,7 +132,7 @@ def ffns(cls, nf: int, q2_ref: Optional[float] = None): @staticmethod def build_area_walls( masses: List[float], - thresholds_ratios: Optional[List[float]] = None, + thresholds_ratios: List[float], max_nf: Optional[int] = None, ): r"""Create the object from the informations on the run card. @@ -150,8 +156,6 @@ def build_area_walls( """ if len(masses) != 3: raise ValueError("There have to be 3 quark masses") - if thresholds_ratios is None: - thresholds_ratios = [1.0, 1.0, 1.0] if len(thresholds_ratios) != 3: raise ValueError("There have to be 3 quark threshold ratios") if max_nf is None: diff --git a/tests/eko/evolution_operator/test_grid.py b/tests/eko/evolution_operator/test_grid.py index a57dffefe..88a1eea2a 100644 --- a/tests/eko/evolution_operator/test_grid.py +++ b/tests/eko/evolution_operator/test_grid.py @@ -12,8 +12,9 @@ import pytest import eko.io.types -from eko.runner import legacy +from eko import couplings from eko.quantities.couplings import CouplingEvolutionMethod +from eko.runner import legacy def test_init_errors(monkeypatch, theory_ffns, operator_card, tmp_path, caplog): @@ -22,12 +23,12 @@ class FakeEM(enum.Enum): BLUB = "blub" monkeypatch.setattr( - legacy, + couplings, "couplings_mod_ev", lambda *args: CouplingEvolutionMethod.EXACT, ) operator_card.configs.evolution_method = FakeEM.BLUB - with pytest.raises(ValueError, match="blub"): + with pytest.raises(ValueError, match="BLUB"): legacy.Runner(theory_ffns(3), operator_card, path=tmp_path / "eko.tar") # check LO