From 2b8bbccc25c678ae75fe7a75552578bfff70ab9c Mon Sep 17 00:00:00 2001 From: jacgoldsm Date: Sat, 30 Sep 2023 12:12:43 -0400 Subject: [PATCH] when --- osos/OsosSession.py | 5 ++-- osos/column.py | 56 +++++++++++++++++++++++++++++++++++++++++++-- osos/functions.py | 20 +++------------- tests/test_basic.py | 8 ++++--- 4 files changed, 65 insertions(+), 24 deletions(-) diff --git a/osos/OsosSession.py b/osos/OsosSession.py index 011dbb6..0dcc7bb 100644 --- a/osos/OsosSession.py +++ b/osos/OsosSession.py @@ -43,9 +43,10 @@ def _parse_schema(schema): for elem in schema: cols.append(re.search('(.*):', elem).group(1)) return cols - elif isinstance(schema,list): return schema + else: + raise TypeError("schema must be str or list") def range(start: int, end: Optional[int] = None, step: int = 1, numSlices: Optional[int] = None): @@ -54,7 +55,7 @@ def range(start: int, end: Optional[int] = None, step: int = 1, numSlices: Optio end = start start = 0 - return DataFrame(pd.DataFrame({"0":np.arange(start,end,step)})) + return DataFrame(pd.DataFrame({"id":np.arange(start,end,step)})) class read: diff --git a/osos/column.py b/osos/column.py index cba7b17..fd575c6 100644 --- a/osos/column.py +++ b/osos/column.py @@ -93,6 +93,58 @@ def __ne__(self, other): def alias(self, newname): return Func(rename_series, self, NameString(newname, ())) +class When(Node): + """ + The `_args` in a `When` Node will have one of the following two structures: + - [condition,value,condition,value...condition,value] + - [condition,value,condition,value...True,value] (if `otherwise`) + In either case, even numbered elements are conditions, and odd-numbered + elements are values. + + In the resolved form, `args` is a list of columns, where even-numbered elements are + boolean conditions, and odd-numbered elements are values. + """ + def __init__(self, condition, value): + self._name = self._when_func + if not isinstance(value,Node): + value = AbstractLit(value) + self._args = [condition,value,] + + @staticmethod + def _when_func(*args: pd.Series, **kwargs): + predicted_dtype = args[-1].dtype + if np.issubdtype(predicted_dtype, np.number): + null_type = np.nan + else: + null_type = None + col = np.full(len(args[0].index), null_type) + conditions = [args[i] for i in range(len(args)) if i % 2 == 0] # even numbers + values = [args[i] for i in range(len(args)) if i % 2 == 1] # odd numbers + + # `i` will loop over all the conditions in reverse order, + # so starting with `True` if `otherwise` exists + for i in reversed(range(len(conditions))): + col = np.where(conditions[i], values[i], col) + + # make some effort to cast back to int if possible + # i.e. if all the replacement values were ints and there are no missings + if all(np.issubdtype(val,np.integer) for val in values) and not np.isnan(col).any(): + col = col.astype(int) + + return pd.Series(col) + + def when(self, condition,value): + if not isinstance(value,Node): + value = AbstractLit(value) + self._args += [condition,value] + return self + + def otherwise(self,value): + if not isinstance(value,Node): + value = AbstractLit(value) + self._args += [SimpleContainer(True, []),value] + return self + class ColumnList(Node): def __init__(self, args: List["AbstractColOrLit"]): @@ -186,10 +238,10 @@ def __str__(self): class SimpleContainer(Node): def __bool__(self): - return bool(self._args) + return bool(self._name) def __str__(self): - return str(self._args) + return f"SimpleContainer: {str(self._name)}" __repr__ = __str__ diff --git a/osos/functions.py b/osos/functions.py index 248e1c2..727f03f 100644 --- a/osos/functions.py +++ b/osos/functions.py @@ -9,7 +9,8 @@ AbstractIndex, Func, SimpleContainer, - Node + Node, + When, ) from .exceptions import AnalysisException, OsosValueError, OsosTypeError from .dataframe import DataFrame @@ -3792,10 +3793,6 @@ def when(condition: AbstractCol, value: Any) -> Func: If :func:`osos.Col.otherwise` is not invoked, None is returned for unmatched conditions. - - - - Parameters ---------- @@ -3830,18 +3827,7 @@ def when(condition: AbstractCol, value: Any) -> Func: | 3| +----+ """ - # Explicitly not using AbstractColOrName type here to make reading condition less opaque - if not isinstance(condition, AbstractCol): - raise OsosTypeError( - error_class="NOT_AbstractCol", - message_parameters={ - "arg_name": "condition", - "arg_type": type(condition).__name__, - }, - ) - v = value._jc if isinstance(value, AbstractCol) else value - - raise NotImplementedError + return When(condition,value) @overload # type: ignore[no-redef] diff --git a/tests/test_basic.py b/tests/test_basic.py index df09c53..e37a0e1 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -2,6 +2,7 @@ from osos.dataframe import DataFrame from osos.window import Window from osos.functions import col +from osos import OsosSession import numpy as np @@ -57,6 +58,8 @@ u = one.withColumn("foosqrt", F.sqrt("foo")) v = one.agg(F.median("baz").alias("baz")) w = one.withColumn("tup", F.upper("tup")) +df = OsosSession.range(3) +x = df.select(F.when(df['id'] == 2, 3).otherwise(4).alias("age")) @@ -108,11 +111,9 @@ up = one._data.assign(**{"foosqrt": np.sqrt(one._data.foo)}) vp = pd.DataFrame(one._data.agg({"baz": np.median})).T wp = one._data.assign(tup=one._data["tup"].str.upper()) +xp = pd.DataFrame({'age':[4,4,3]}) - -print(v._data) -print(vp) def compares_equal(osos_dataframe: DataFrame, pandas_dataframe: pd.DataFrame) -> bool: try: assert_frame_equal(osos_dataframe.toPandas(), pandas_dataframe) @@ -148,6 +149,7 @@ def test_functions(): assert compares_equal(s, sp) assert compares_equal(t, tp) assert compares_equal(u, up) + assert compares_equal(x,xp) iris_pd = pd.read_csv(