Skip to content

Commit

Permalink
fix subpkg support, more cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
wpbonelli committed Oct 1, 2024
1 parent 4678597 commit a7fb655
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 25 deletions.
98 changes: 91 additions & 7 deletions flopy/mf6/utils/createpackages.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def base(self) -> Optional[type]:
"""A base class from which the input context should inherit."""
l, r = self
if self == ("sim", "nam"):
return MFSimulation
return MFSimulationBase
if r is None:
return MFModel
return MFPackage
Expand Down Expand Up @@ -442,6 +442,7 @@ class Subpkg:
abbr: str
param: str
parents: List[type]
description: Optional[str]

@classmethod
def from_dfn(cls, dfn: Dfn) -> Optional["Subpkg"]:
Expand All @@ -460,11 +461,13 @@ def from_dfn(cls, dfn: Dfn) -> Optional["Subpkg"]:
def _subpkg():
line = lines["subpkg"]
_, key, abbr, param, val = line.split()
descr = dfn.get(val, dict()).get("description", None)
return {
"key": key,
"val": val,
"abbr": abbr,
"param": param,
"description": descr,
}

def _parents():
Expand Down Expand Up @@ -506,13 +509,16 @@ class VarKind(Enum):
@classmethod
def from_type(cls, t: type) -> Optional["VarKind"]:
origin = get_origin(t)
args = get_args(t)
if t is np.ndarray or origin is NDArray or origin is ArrayLike:
return VarKind.Array
elif origin is collections.abc.Iterable or origin is list:
return VarKind.List
elif origin is tuple:
return VarKind.Record
elif origin is Union:
if len(args) == 2 and args[1] is type(None):
return cls.from_type(args[0])
return VarKind.Union
try:
if issubclass(t, (bool, int, float, str)):
Expand All @@ -527,7 +533,7 @@ class Var:
"""A variable in a MODFLOW 6 input context."""

name: str
_type: Optional[type]
_type: type
block: Optional[str]
description: Optional[str]
default: Optional[Any]
Expand All @@ -540,6 +546,7 @@ class Var:
init_assign: bool = False
init_build: bool = False
init_super: bool = False
class_attr: bool = False

def __init__(
self,
Expand All @@ -558,9 +565,10 @@ def __init__(
init_assign: bool = False,
init_build: bool = False,
init_super: bool = False,
class_attr: bool = False,
):
self.name = name
self._type = _type
self._type = _type or Any
self.block = block
self.description = description
self.default = default
Expand Down Expand Up @@ -589,6 +597,9 @@ def __init__(
self.init_build = init_build
# whether to pass arg to super().__init__()
self.init_super = init_super
# whether the variable has a corresponding
# class attribute
self.class_attr = True


Vars = Dict[str, Var]
Expand Down Expand Up @@ -927,6 +938,7 @@ def _is_implicit_record():
# check if the variable references a subpackage
subpkg = subpkgs.get(_name, None)
if subpkg:
var_.init_build = False
var_.subpkg = subpkg

return var_
Expand Down Expand Up @@ -1056,8 +1068,13 @@ def _add_exg_vars(_vars: Vars) -> Vars:
key = vars_.get(k, None)
if not key:
continue
vars_[subpkg.key].init_param = False
vars_[subpkg.key].class_attr = True
vars_[subpkg.key].subpkg = None
vars_[subpkg.val] = Var(
name=subpkg.val,
description=subpkg.description,
subpkg=subpkg,
init_param=True,
init_assign=False,
init_super=False,
Expand Down Expand Up @@ -1146,8 +1163,13 @@ def _add_pkg_vars(_vars: Vars) -> Vars:
key = vars_.get(k, None)
if not key:
continue
vars_[subpkg.key].init_param = False
vars_[subpkg.key].class_attr = True
vars_[subpkg.key].subpkg = None
vars_[subpkg.val] = Var(
name=subpkg.val,
description=subpkg.description,
subpkg=subpkg,
init_param=True,
init_assign=False,
init_super=False,
Expand All @@ -1163,7 +1185,8 @@ def _add_mdl_vars(_vars: Vars) -> Vars:
if packages:
packages.init_param = False
vars_["packages"] = packages
return {

vars_ = {
"simulation": Var(
name="simulation",
_type=MFSimulation,
Expand Down Expand Up @@ -1225,6 +1248,36 @@ def _add_mdl_vars(_vars: Vars) -> Vars:
**vars_,
}

# if a reference map is provided,
# find any variables referring to
# subpackages, and attach another
# "value" variable for them all..
# allows passing data directly to
# `__init__` instead of a path to
# load the subpackage from. maybe
# impossible if the data variable
# doesn't appear in the reference
# definition, though.
if subpkgs:
for k, subpkg in subpkgs.items():
key = vars_.get(k, None)
if not key:
continue
vars_[subpkg.key].init_param = False
vars_[subpkg.key].class_attr = True
vars_[subpkg.key].subpkg = None
vars_[subpkg.val] = Var(
name=subpkg.val,
description=subpkg.description,
subpkg=subpkg,
init_param=True,
init_assign=False,
init_super=False,
init_build=False,
)

return vars_

def _add_sim_params(_vars: Vars) -> Vars:
"""Add variables for a simulation context."""
vars_ = _vars.copy()
Expand All @@ -1239,8 +1292,8 @@ def _add_sim_params(_vars: Vars) -> Vars:
var = vars_.get(k, None)
if var:
var.init_param = False
vars_[k] = var
return {
vars_[k] = var
vars_ = {
"sim_name": Var(
name="sim_name",
_type=str,
Expand Down Expand Up @@ -1309,9 +1362,40 @@ def _add_sim_params(_vars: Vars) -> Vars:
**vars_,
}

# if a reference map is provided,
# find any variables referring to
# subpackages, and attach another
# "value" variable for them all..
# allows passing data directly to
# `__init__` instead of a path to
# load the subpackage from. maybe
# impossible if the data variable
# doesn't appear in the reference
# definition, though.
if subpkgs:
for k, subpkg in subpkgs.items():
key = vars_.get(k, None)
if not key:
continue
vars_[subpkg.key].init_param = False
vars_[subpkg.key].init_build = False
vars_[subpkg.key].class_attr = True
vars_[subpkg.key].subpkg = None
vars_[subpkg.param] = Var(
name=subpkg.param,
description=subpkg.description,
subpkg=subpkg,
init_param=True,
init_assign=False,
init_super=False,
init_build=False,
)

return vars_

# add initializer method parameters
# for this particular context type
if name.base is MFSimulation:
if name.base is MFSimulationBase:
vars_ = _add_sim_params(vars_)
elif name.base is MFModel:
vars_ = _add_mdl_vars(vars_)
Expand Down
10 changes: 6 additions & 4 deletions flopy/mf6/utils/templates/attrs.jinja
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
{% for name, var in variables.items() if var.class_attr %}
{%- if var.kind == "list" or var.kind == "record" %}
{{ var.name }} = ListTemplateGenerator(("{{ component }}6", "{{ subcomponent }}", "{{ var.block }}", "{{ var.name }}"))
{%- if base != "MFSimulationBase" %}
{% for var in variables.values() if var.class_attr %}
{%- if var.kind == "list" or var.kind == "record" or var.kind == "union" %}
{{ var.name }} = ListTemplateGenerator(("{{ name.l }}6", "{{ name.r }}", "{{ var.block }}", "{{ var.name }}"))
{%- elif var.kind == "array" %}
{{ var.name }} = ArrayTemplateGenerator(("{{ component }}6", "{{ subcomponent }}", "{{ var.block }}", "{{ var.name }}"))
{{ var.name }} = ArrayTemplateGenerator(("{{ name.l }}6", "{{ name.r }}", "{{ var.block }}", "{{ var.name }}"))
{%- endif -%}
{%- endfor %}
{% endif -%}
{%- if base == "MFModel" %}
model_type = "{{ name.title }}"
{%- elif base == "MFPackage" %}
Expand Down
2 changes: 1 addition & 1 deletion flopy/mf6/utils/templates/docstring_methods.jinja
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{% if base == "MFSimulation" %}
{% if base == "MFSimulationBase" %}
load : (sim_name : str, version : string,
exe_name : str or PathLike, sim_ws : str or PathLike, strict : bool,
verbosity_level : int, load_only : list, verify_data : bool,
Expand Down
2 changes: 1 addition & 1 deletion flopy/mf6/utils/templates/docstring_params.jinja
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{% for var in variables.values() recursive %}
{% if loop.depth > 1 %}* {% endif %}{{ var.name }}{% if var._type is defined %} : {{ var._type }}{% endif %}
{% if loop.depth > 1 %}* {% endif %}{{ var.name }} : {{ var._type }}
{%- if var.description is defined and not var.is_choice %}
{{ var.description|wordwrap|indent(4 + (loop.depth * 4), first=True) }}
{%- endif %}
Expand Down
28 changes: 17 additions & 11 deletions flopy/mf6/utils/templates/init.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,29 @@ def __init__(
self,
{%- for name, var in variables.items() if var.init_param %}
{%- if var.default is defined %}
{{ name }}{% if var._type is defined%}: {{ var._type }}{% endif %} = {{ var.default }},
{{ name }}: {{ var._type }} = {{ var.default }},
{%- else -%}
{{ name }}{% if var._type is defined%}: {{ var._type }}{% endif %},
{{ name }}: {{ var._type }},
{% endif -%}
{%- endfor %}
**kwargs,
):
{% if base == "MFSimulation" %}
{% if base == "MFSimulationBase" %}
super().__init__(
{%- for name, var in variables.items() if var.init_super %}
{{ name }}={{ name }},
{%- endfor %}
**kwargs
)
{%- for name, var in variables.items() if var.block == "options" %}
{%- for name, var in variables.items() %}
{%- if var.block == "options" and var.init_build %}
self.name_file.{{ name }}.set_data({{ name }})
self.{{ name }} = self.name_file.{{ name }},
self.{{ name }} = self.name_file.{{ name }}
{% endif -%}
{%- if var.subpkg is defined %}
self{{ var.subpkg.data_name }} = self._create_package(
"{{ var.subpkg.var_name }}",
self.{{ var.subpkg.data_name }}
self.{{ var.subpkg.param }} = self._create_package(
"{{ var.subpkg.abbr }}",
{{ var.subpkg.param }}
)
{% endif -%}
{% endfor -%}
Expand All @@ -34,14 +36,18 @@ def __init__(
{%- endfor %}
**kwargs,
)
{%- for name, var in variables.items() if var.block == "options" %}
{%- for name, var in variables.items() %}
{%- if var.block == "options" and var.init_build %}
self.name_file.{{ name }}.set_data({{ name }})
self.{{ name }} = self.name_file.{{ name }}
{% endif -%}
{%- endfor %}
{% elif base == "MFPackage" %}
super().__init__(
{% if name.l == "exg" or name.l == "sln" -%}
{% if parent == "MFSimulation" -%}
simulation,
{% elif parent == "MFModel" -%}
model,
{%- endif %}
package_type="{{ name.r }}",
{%- for name, var in variables.items() if var.init_super %}
Expand All @@ -57,7 +63,7 @@ def __init__(
self.{{ name }} = {{ name }}
{% endif -%}
{%- if var.subpkg is defined -%}
self._{{ name }} = self.build_mfdata("{{ name }}", {{ name }})
self._{{ name }} = self.build_mfdata("{{ name }}", None)
self._{{ var.subpkg.abbr }}_package = self.build_child_package(
"{{ var.subpkg.abbr }}",
{{ var.subpkg.val }},
Expand Down
2 changes: 1 addition & 1 deletion flopy/mf6/utils/templates/load.jinja
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{% if base == "MFSimulation" %}
{% if base == "MFSimulationBase" %}
@classmethod
def load(
cls,
Expand Down

0 comments on commit a7fb655

Please sign in to comment.