diff --git a/nutils/function.py b/nutils/function.py index c00103214..bbd60d719 100644 --- a/nutils/function.py +++ b/nutils/function.py @@ -27,7 +27,7 @@ from typing import Tuple, Union, Type, Callable, Sequence, Any, Optional, Iterator, Iterable, Dict, Mapping, overload, List, Set, FrozenSet from . import evaluable, numeric, util, expression, types, warnings, debug_flags from .transform import EvaluableTransformChain -from .transformseq import Transforms +from .transformseq import Transforms, IdentifierTransforms import builtins, numpy, re, types as builtin_types, itertools, functools, operator, abc, numbers IntoArray = Union['Array', numpy.ndarray, bool, int, float] @@ -440,6 +440,23 @@ def argshapes(self) -> Mapping[str, Tuple[int, ...]]: shapes[arg._name] = tuple(map(int, arg.shape)) return shapes +class _SetSpace(Array): + + def __init__(self, arg: Array, space: str, value: Array): + if space not in arg.spaces: + raise ValueError('argument does not depend on space {!r}'.format(space)) + if value.ndim != 1: + raise ValueError('expected one-dimensional coordinate for space {!r}'.format(space)) + self._space = space + self._arg = arg + self._value = value + super().__init__(shape=arg.shape, dtype=arg.dtype, spaces=arg.spaces-{space}) + + def lower(self, points_shape: _PointsShape, transform_chains: _TransformChainsMap, coordinates: _CoordinatesMap) -> evaluable.Array: + value = self._value.lower(points_shape, transform_chains, coordinates) + chain = IdentifierTransforms(ndims=1, name='dummy', length=value.shape[-1]).get_evaluable(evaluable.asarray(0)) # <- identity + return self._arg.lower(points_shape, {self._space: (chain, chain), **transform_chains}, {self._space: value, **coordinates}) + class _Unlower(Array): def __init__(self, array: evaluable.Array, spaces: FrozenSet[str], points_shape: Tuple[evaluable.Array, ...], transform_chains: Tuple[EvaluableTransformChain, ...], coordinates: Tuple[evaluable.Array, ...]) -> None: