Skip to content

Commit

Permalink
Fixing series fold (#271)
Browse files Browse the repository at this point in the history
Fixing series fold to close #79
  • Loading branch information
Bidek56 authored Sep 30, 2024
1 parent ac4c358 commit 4bcc68c
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 26 deletions.
42 changes: 24 additions & 18 deletions __tests__/dataframe.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -338,24 +338,30 @@ describe("dataframe", () => {
const actual = df.fold((a, b) => a.concat(b));
expect(actual).toSeriesEqual(expected);
});
// test("fold", () => {
// const s1 = pl.Series([1, 2, 3]);
// const s2 = pl.Series([4, 5, 6]);
// const s3 = pl.Series([7, 8, 1]);
// const expected = pl.Series("foo", [true, true, false]);
// const df = pl.DataFrame([s1, s2, s3]);
// const actual = df.fold((a, b) => a.lessThan(b)).alias("foo");
// expect(actual).toSeriesEqual(expected);
// });
// test("fold-again", () => {
// const s1 = pl.Series([1, 2, 3]);
// const s2 = pl.Series([4, 5, 6]);
// const s3 = pl.Series([7, 8, 1]);
// const expected = pl.Series("foo", [12, 15, 10]);
// const df = pl.DataFrame([s1, s2, s3]);
// const actual = df.fold((a, b) => a.plus(b)).alias("foo");
// expect(actual).toSeriesEqual(expected);
// });
it.each`
name | actual | expected
${"fold:lessThan"} | ${df.fold((a, b) => a.lessThan(b)).alias("foo")} | ${pl.Series("foo", [true, false, false])}
${"fold:lt"} | ${df.fold((a, b) => a.lt(b)).alias("foo")} | ${pl.Series("foo", [true, false, false])}
${"fold:lessThanEquals"} | ${df.fold((a, b) => a.lessThanEquals(b)).alias("foo")} | ${pl.Series("foo", [true, true, false])}
${"fold:ltEq"} | ${df.fold((a, b) => a.ltEq(b)).alias("foo")} | ${pl.Series("foo", [true, true, false])}
${"fold:neq"} | ${df.fold((a, b) => a.neq(b)).alias("foo")} | ${pl.Series("foo", [true, false, true])}
${"fold:plus"} | ${df.fold((a, b) => a.plus(b)).alias("foo")} | ${pl.Series("foo", [7, 4, 17])}
${"fold:minus"} | ${df.fold((a, b) => a.minus(b)).alias("foo")} | ${pl.Series("foo", [-5, 0, 1])}
${"fold:mul"} | ${df.fold((a, b) => a.mul(b)).alias("foo")} | ${pl.Series("foo", [6, 4, 72])}
`("$# $name expected matches actual", ({ expected, actual }) => {
expect(expected).toSeriesEqual(actual);
});
test("fold:lt", () => {
const s1 = pl.Series([1, 2, 3]);
const s2 = pl.Series([4, 5, 6]);
const s3 = pl.Series([7, 8, 1]);
const df = pl.DataFrame([s1, s2, s3]);
const expected = pl.Series("foo", [true, true, false]);
let actual = df.fold((a, b) => a.lessThan(b)).alias("foo");
expect(actual).toSeriesEqual(expected);
actual = df.fold((a, b) => a.lt(b)).alias("foo");
expect(actual).toSeriesEqual(expected);
});
test("frameEqual:true", () => {
const df = pl.DataFrame({
foo: [1, 2, 3],
Expand Down
48 changes: 40 additions & 8 deletions polars/series/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1528,16 +1528,32 @@ export function _Series(_s: any): Series {
return this.length;
},
lt(field) {
return dtypeWrap("Lt", field);
if (typeof field === "number") return dtypeWrap("Lt", field);
if (Series.isSeries(field)) {
return wrap("lt", (field as any)._s);
}
throw new Error("Not a number nor a series");
},
lessThan(field) {
return dtypeWrap("Lt", field);
if (typeof field === "number") return dtypeWrap("Lt", field);
if (Series.isSeries(field)) {
return wrap("lt", (field as any)._s);
}
throw new Error("Not a number nor a series");
},
ltEq(field) {
return dtypeWrap("LtEq", field);
if (typeof field === "number") return dtypeWrap("LtEq", field);
if (Series.isSeries(field)) {
return wrap("ltEq", (field as any)._s);
}
throw new Error("Not a number nor a series");
},
lessThanEquals(field) {
return dtypeWrap("LtEq", field);
if (typeof field === "number") return dtypeWrap("LtEq", field);
if (Series.isSeries(field)) {
return wrap("ltEq", (field as any)._s);
}
throw new Error("Not a number nor a series");
},
limit(n = 10) {
return wrap("limit", n);
Expand All @@ -1558,16 +1574,28 @@ export function _Series(_s: any): Series {
return wrap("mode");
},
minus(other) {
return dtypeWrap("Sub", other);
if (typeof other === "number") return dtypeWrap("Sub", other);
if (Series.isSeries(other)) {
return wrap("sub", (other as any)._s);
}
throw new Error("Not a number nor a series");
},
mul(other) {
return dtypeWrap("Mul", other);
if (typeof other === "number") return dtypeWrap("Mul", other);
if (Series.isSeries(other)) {
return wrap("mul", (other as any)._s);
}
throw new Error("Not a number nor a series");
},
nChunks() {
return _s.nChunks();
},
neq(other) {
return dtypeWrap("Neq", other);
if (typeof other === "number") return dtypeWrap("Neq", other);
if (Series.isSeries(other)) {
return wrap("neq", (other as any)._s);
}
throw new Error("Not a number nor a series");
},
notEquals(other) {
return this.neq(other);
Expand All @@ -1585,7 +1613,11 @@ export function _Series(_s: any): Series {
return expr_op("peakMin");
},
plus(other) {
return dtypeWrap("Add", other);
if (typeof other === "number") return dtypeWrap("Add", other);
if (Series.isSeries(other)) {
return wrap("add", (other as any)._s);
}
throw new Error("Not a number nor a series");
},
quantile(quantile, interpolation = "nearest") {
return _s.quantile(quantile, interpolation);
Expand Down

0 comments on commit 4bcc68c

Please sign in to comment.