Skip to content

Commit

Permalink
when
Browse files Browse the repository at this point in the history
  • Loading branch information
jacgoldsm committed Sep 30, 2023
1 parent 5e03ac2 commit 2b8bbcc
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 24 deletions.
5 changes: 3 additions & 2 deletions osos/OsosSession.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,10 @@ def _parse_schema(schema):
for elem in schema:
cols.append(re.search('(.*):', elem).group(1))
return cols

elif isinstance(schema,list):
return schema
else:
raise TypeError("schema must be str or list")


def range(start: int, end: Optional[int] = None, step: int = 1, numSlices: Optional[int] = None):
Expand All @@ -54,7 +55,7 @@ def range(start: int, end: Optional[int] = None, step: int = 1, numSlices: Optio
end = start
start = 0

return DataFrame(pd.DataFrame({"0":np.arange(start,end,step)}))
return DataFrame(pd.DataFrame({"id":np.arange(start,end,step)}))


class read:
Expand Down
56 changes: 54 additions & 2 deletions osos/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,58 @@ def __ne__(self, other):
def alias(self, newname):
return Func(rename_series, self, NameString(newname, ()))

class When(Node):
"""
The `_args` in a `When` Node will have one of the following two structures:
- [condition,value,condition,value...condition,value]
- [condition,value,condition,value...True,value] (if `otherwise`)
In either case, even numbered elements are conditions, and odd-numbered
elements are values.
In the resolved form, `args` is a list of columns, where even-numbered elements are
boolean conditions, and odd-numbered elements are values.
"""
def __init__(self, condition, value):
self._name = self._when_func
if not isinstance(value,Node):
value = AbstractLit(value)
self._args = [condition,value,]

@staticmethod
def _when_func(*args: pd.Series, **kwargs):
predicted_dtype = args[-1].dtype
if np.issubdtype(predicted_dtype, np.number):
null_type = np.nan
else:
null_type = None
col = np.full(len(args[0].index), null_type)
conditions = [args[i] for i in range(len(args)) if i % 2 == 0] # even numbers
values = [args[i] for i in range(len(args)) if i % 2 == 1] # odd numbers

# `i` will loop over all the conditions in reverse order,
# so starting with `True` if `otherwise` exists
for i in reversed(range(len(conditions))):
col = np.where(conditions[i], values[i], col)

# make some effort to cast back to int if possible
# i.e. if all the replacement values were ints and there are no missings
if all(np.issubdtype(val,np.integer) for val in values) and not np.isnan(col).any():
col = col.astype(int)

return pd.Series(col)

def when(self, condition,value):
if not isinstance(value,Node):
value = AbstractLit(value)
self._args += [condition,value]
return self

def otherwise(self,value):
if not isinstance(value,Node):
value = AbstractLit(value)
self._args += [SimpleContainer(True, []),value]
return self


class ColumnList(Node):
def __init__(self, args: List["AbstractColOrLit"]):
Expand Down Expand Up @@ -186,10 +238,10 @@ def __str__(self):

class SimpleContainer(Node):
def __bool__(self):
return bool(self._args)
return bool(self._name)

def __str__(self):
return str(self._args)
return f"SimpleContainer: {str(self._name)}"

__repr__ = __str__

Expand Down
20 changes: 3 additions & 17 deletions osos/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
AbstractIndex,
Func,
SimpleContainer,
Node
Node,
When,
)
from .exceptions import AnalysisException, OsosValueError, OsosTypeError
from .dataframe import DataFrame
Expand Down Expand Up @@ -3792,10 +3793,6 @@ def when(condition: AbstractCol, value: Any) -> Func:
If :func:`osos.Col.otherwise` is not invoked, None is returned for unmatched
conditions.
Parameters
----------
Expand Down Expand Up @@ -3830,18 +3827,7 @@ def when(condition: AbstractCol, value: Any) -> Func:
| 3|
+----+
"""
# Explicitly not using AbstractColOrName type here to make reading condition less opaque
if not isinstance(condition, AbstractCol):
raise OsosTypeError(
error_class="NOT_AbstractCol",
message_parameters={
"arg_name": "condition",
"arg_type": type(condition).__name__,
},
)
v = value._jc if isinstance(value, AbstractCol) else value

raise NotImplementedError
return When(condition,value)


@overload # type: ignore[no-redef]
Expand Down
8 changes: 5 additions & 3 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from osos.dataframe import DataFrame
from osos.window import Window
from osos.functions import col
from osos import OsosSession


import numpy as np
Expand Down Expand Up @@ -57,6 +58,8 @@
u = one.withColumn("foosqrt", F.sqrt("foo"))
v = one.agg(F.median("baz").alias("baz"))
w = one.withColumn("tup", F.upper("tup"))
df = OsosSession.range(3)
x = df.select(F.when(df['id'] == 2, 3).otherwise(4).alias("age"))



Expand Down Expand Up @@ -108,11 +111,9 @@
up = one._data.assign(**{"foosqrt": np.sqrt(one._data.foo)})
vp = pd.DataFrame(one._data.agg({"baz": np.median})).T
wp = one._data.assign(tup=one._data["tup"].str.upper())
xp = pd.DataFrame({'age':[4,4,3]})



print(v._data)
print(vp)
def compares_equal(osos_dataframe: DataFrame, pandas_dataframe: pd.DataFrame) -> bool:
try:
assert_frame_equal(osos_dataframe.toPandas(), pandas_dataframe)
Expand Down Expand Up @@ -148,6 +149,7 @@ def test_functions():
assert compares_equal(s, sp)
assert compares_equal(t, tp)
assert compares_equal(u, up)
assert compares_equal(x,xp)


iris_pd = pd.read_csv(
Expand Down

0 comments on commit 2b8bbcc

Please sign in to comment.