Skip to content

Commit

Permalink
ENH: Add on_bad_lines for pyarrow (SQUASHED)
Browse files Browse the repository at this point in the history
  • Loading branch information
amithkk committed Aug 20, 2023
1 parent 9d70a49 commit 62c873b
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 6 deletions.
26 changes: 26 additions & 0 deletions pandas/io/parsers/arrow_parser_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

from typing import TYPE_CHECKING
import warnings

from pandas._config import using_pyarrow_string_dtype

from pandas._libs import lib
from pandas.compat._optional import import_optional_dependency
from pandas.errors import ParserWarning
from pandas.util._exceptions import find_stack_level

from pandas.core.dtypes.inference import is_integer

Expand Down Expand Up @@ -85,6 +88,29 @@ def _get_pyarrow_options(self) -> None:
and option_name
in ("delimiter", "quote_char", "escape_char", "ignore_empty_lines")
}

if "on_bad_lines" in self.kwds:
if callable(self.kwds["on_bad_lines"]):
self.parse_options["invalid_row_handler"] = self.kwds["on_bad_lines"]
elif self.kwds["on_bad_lines"] == ParserBase.BadLineHandleMethod.ERROR:
self.parse_options[
"invalid_row_handler"
] = None # PyArrow raises an exception by default
elif self.kwds["on_bad_lines"] == ParserBase.BadLineHandleMethod.WARN:

def handle_warning(invalid_row):
warnings.warn(
f"Expected {invalid_row.expected_columns} columns, but found "
f"{invalid_row.actual_columns}: {invalid_row.text}",
ParserWarning,
stacklevel=find_stack_level(),
)
return "skip"

self.parse_options["invalid_row_handler"] = handle_warning
elif self.kwds["on_bad_lines"] == ParserBase.BadLineHandleMethod.SKIP:
self.parse_options["invalid_row_handler"] = lambda _: "skip"

self.convert_options = {
option_name: option_value
for option_name, option_value in self.kwds.items()
Expand Down
13 changes: 10 additions & 3 deletions pandas/io/parsers/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,13 @@
expected, a ``ParserWarning`` will be emitted while dropping extra elements.
Only supported when ``engine='python'``
.. versionchanged:: 1.4.1
- Callable, function with signature
as described in `pyarrow documentation
<https://arrow.apache.org/docs/python/generated/pyarrow.csv.ParseOptions.html
#pyarrow.csv.ParseOptions.invalid_row_handler>_` when ``engine='pyarrow'``
delim_whitespace : bool, default False
Specifies whether or not whitespace (e.g. ``' '`` or ``'\\t'``) will be
used as the ``sep`` delimiter. Equivalent to setting ``sep='\\s+'``. If this option
Expand Down Expand Up @@ -483,7 +490,6 @@ class _Fwf_Defaults(TypedDict):
"thousands",
"memory_map",
"dialect",
"on_bad_lines",
"delim_whitespace",
"quoting",
"lineterminator",
Expand Down Expand Up @@ -2038,9 +2044,10 @@ def _refine_defaults_read(
elif on_bad_lines == "skip":
kwds["on_bad_lines"] = ParserBase.BadLineHandleMethod.SKIP
elif callable(on_bad_lines):
if engine != "python":
if engine not in ["python", "pyarrow"]:
raise ValueError(
"on_bad_line can only be a callable function if engine='python'"
"on_bad_line can only be a callable function "
"if engine='python' or 'pyarrow'"
)
kwds["on_bad_lines"] = on_bad_lines
else:
Expand Down
10 changes: 7 additions & 3 deletions pandas/tests/io/parser/test_unsupported.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,17 @@ def test_pyarrow_engine(self):
with pytest.raises(ValueError, match=msg):
read_csv(StringIO(data), engine="pyarrow", **kwargs)

def test_on_bad_lines_callable_python_only(self, all_parsers):
def test_on_bad_lines_callable_python_or_pyarrow(self, all_parsers):
# GH 5686
# GH 54643
sio = StringIO("a,b\n1,2")
bad_lines_func = lambda x: x
parser = all_parsers
if all_parsers.engine != "python":
msg = "on_bad_line can only be a callable function if engine='python'"
if all_parsers.engine not in ["python", "pyarrow"]:
msg = (
"on_bad_line can only be a callable "
"function if engine='python' or 'pyarrow'"
)
with pytest.raises(ValueError, match=msg):
parser.read_csv(sio, on_bad_lines=bad_lines_func)
else:
Expand Down

0 comments on commit 62c873b

Please sign in to comment.