Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add on_bad_lines for pyarrow #54643

Merged
merged 19 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ enhancement2

Other enhancements
^^^^^^^^^^^^^^^^^^

- :func:`read_csv` now supports ``on_bad_lines`` parameter with ``engine="pyarrow"``. (:issue:`54480`)
- :meth:`ExtensionArray._explode` interface method added to allow extension type implementations of the ``explode`` method (:issue:`54833`)
- DataFrame.apply now allows the usage of numba (via ``engine="numba"``) to JIT compile the passed function, allowing for potential speedups (:issue:`54666`)
-
Expand Down
45 changes: 39 additions & 6 deletions pandas/io/parsers/arrow_parser_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from __future__ import annotations

from typing import TYPE_CHECKING
import warnings

from pandas._config import using_pyarrow_string_dtype

from pandas._libs import lib
from pandas.compat._optional import import_optional_dependency
from pandas.errors import (
ParserError,
ParserWarning,
)
from pandas.util._exceptions import find_stack_level

from pandas.core.dtypes.inference import is_integer

Expand Down Expand Up @@ -85,6 +91,30 @@ def _get_pyarrow_options(self) -> None:
and option_name
in ("delimiter", "quote_char", "escape_char", "ignore_empty_lines")
}

on_bad_lines = self.kwds.get("on_bad_lines")
if on_bad_lines is not None:
if callable(on_bad_lines):
self.parse_options["invalid_row_handler"] = on_bad_lines
elif on_bad_lines == ParserBase.BadLineHandleMethod.ERROR:
self.parse_options[
"invalid_row_handler"
] = None # PyArrow raises an exception by default
elif on_bad_lines == ParserBase.BadLineHandleMethod.WARN:

def handle_warning(invalid_row):
warnings.warn(
f"Expected {invalid_row.expected_columns} columns, but found "
f"{invalid_row.actual_columns}: {invalid_row.text}",
ParserWarning,
stacklevel=find_stack_level(),
)
return "skip"

self.parse_options["invalid_row_handler"] = handle_warning
elif on_bad_lines == ParserBase.BadLineHandleMethod.SKIP:
self.parse_options["invalid_row_handler"] = lambda _: "skip"

self.convert_options = {
option_name: option_value
for option_name, option_value in self.kwds.items()
Expand Down Expand Up @@ -190,12 +220,15 @@ def read(self) -> DataFrame:
pyarrow_csv = import_optional_dependency("pyarrow.csv")
self._get_pyarrow_options()

table = pyarrow_csv.read_csv(
self.src,
read_options=pyarrow_csv.ReadOptions(**self.read_options),
parse_options=pyarrow_csv.ParseOptions(**self.parse_options),
convert_options=pyarrow_csv.ConvertOptions(**self.convert_options),
)
try:
table = pyarrow_csv.read_csv(
self.src,
read_options=pyarrow_csv.ReadOptions(**self.read_options),
parse_options=pyarrow_csv.ParseOptions(**self.parse_options),
convert_options=pyarrow_csv.ConvertOptions(**self.convert_options),
)
except pa.ArrowInvalid as e:
raise ParserError(e) from e

dtype_backend = self.kwds["dtype_backend"]

Expand Down
13 changes: 10 additions & 3 deletions pandas/io/parsers/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,13 @@
expected, a ``ParserWarning`` will be emitted while dropping extra elements.
Only supported when ``engine='python'``

.. versionchanged:: 2.2.0

- Callable, function with signature
as described in `pyarrow documentation
<https://arrow.apache.org/docs/python/generated/pyarrow.csv.ParseOptions.html
#pyarrow.csv.ParseOptions.invalid_row_handler>_` when ``engine='pyarrow'``

delim_whitespace : bool, default False
Specifies whether or not whitespace (e.g. ``' '`` or ``'\\t'``) will be
used as the ``sep`` delimiter. Equivalent to setting ``sep='\\s+'``. If this option
Expand Down Expand Up @@ -494,7 +501,6 @@ class _Fwf_Defaults(TypedDict):
"thousands",
"memory_map",
"dialect",
"on_bad_lines",
"delim_whitespace",
"quoting",
"lineterminator",
Expand Down Expand Up @@ -2142,9 +2148,10 @@ def _refine_defaults_read(
elif on_bad_lines == "skip":
kwds["on_bad_lines"] = ParserBase.BadLineHandleMethod.SKIP
elif callable(on_bad_lines):
if engine != "python":
if engine not in ["python", "pyarrow"]:
raise ValueError(
"on_bad_line can only be a callable function if engine='python'"
"on_bad_line can only be a callable function "
"if engine='python' or 'pyarrow'"
)
kwds["on_bad_lines"] = on_bad_lines
else:
Expand Down
39 changes: 35 additions & 4 deletions pandas/tests/io/parser/common/test_read_errors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Tests that work on both the Python and C engines but do not have a
Tests that work on the Python, C and PyArrow engines but do not have a
specific classification into the other test modules.
"""
import codecs
Expand All @@ -21,7 +21,8 @@
from pandas import DataFrame
import pandas._testing as tm

pytestmark = pytest.mark.usefixtures("pyarrow_skip")
xfail_pyarrow = pytest.mark.usefixtures("pyarrow_xfail")
skip_pyarrow = pytest.mark.usefixtures("pyarrow_skip")


def test_empty_decimal_marker(all_parsers):
Expand All @@ -33,10 +34,17 @@ def test_empty_decimal_marker(all_parsers):
msg = "Only length-1 decimal markers supported"
parser = all_parsers

if parser.engine == "pyarrow":
msg = (
"only single character unicode strings can be "
"converted to Py_UCS4, got length 0"
)

with pytest.raises(ValueError, match=msg):
parser.read_csv(StringIO(data), decimal="")


@skip_pyarrow
mroeschke marked this conversation as resolved.
Show resolved Hide resolved
def test_bad_stream_exception(all_parsers, csv_dir_path):
# see gh-13652
#
Expand All @@ -57,6 +65,7 @@ def test_bad_stream_exception(all_parsers, csv_dir_path):
parser.read_csv(stream)


@skip_pyarrow
def test_malformed(all_parsers):
# see gh-6607
parser = all_parsers
Expand All @@ -71,6 +80,7 @@ def test_malformed(all_parsers):
parser.read_csv(StringIO(data), header=1, comment="#")


@skip_pyarrow
@pytest.mark.parametrize("nrows", [5, 3, None])
def test_malformed_chunks(all_parsers, nrows):
data = """ignore
Expand All @@ -90,6 +100,7 @@ def test_malformed_chunks(all_parsers, nrows):
reader.read(nrows)


@skip_pyarrow
def test_catch_too_many_names(all_parsers):
# see gh-5156
data = """\
Expand All @@ -109,6 +120,7 @@ def test_catch_too_many_names(all_parsers):
parser.read_csv(StringIO(data), header=0, names=["a", "b", "c", "d"])


@skip_pyarrow
@pytest.mark.parametrize("nrows", [0, 1, 2, 3, 4, 5])
def test_raise_on_no_columns(all_parsers, nrows):
parser = all_parsers
Expand Down Expand Up @@ -147,6 +159,10 @@ def test_error_bad_lines(all_parsers):
data = "a\n1\n1,2,3\n4\n5,6,7"

msg = "Expected 1 fields in line 3, saw 3"

if parser.engine == "pyarrow":
msg = "CSV parse error: Expected 1 columns, got 3: 1,2,3"

with pytest.raises(ParserError, match=msg):
parser.read_csv(StringIO(data), on_bad_lines="error")

Expand All @@ -156,9 +172,13 @@ def test_warn_bad_lines(all_parsers):
parser = all_parsers
data = "a\n1\n1,2,3\n4\n5,6,7"
expected = DataFrame({"a": [1, 4]})
match_msg = "Skipping line"

if parser.engine == "pyarrow":
match_msg = "Expected 1 columns, but found 3: 1,2,3"

with tm.assert_produces_warning(
ParserWarning, match="Skipping line", check_stacklevel=False
ParserWarning, match=match_msg, check_stacklevel=False
):
result = parser.read_csv(StringIO(data), on_bad_lines="warn")
tm.assert_frame_equal(result, expected)
Expand All @@ -174,10 +194,14 @@ def test_read_csv_wrong_num_columns(all_parsers):
parser = all_parsers
msg = "Expected 6 fields in line 3, saw 7"

if parser.engine == "pyarrow":
msg = "Expected 6 columns, got 7: 6,7,8,9,10,11,12"

with pytest.raises(ParserError, match=msg):
parser.read_csv(StringIO(data))


@skip_pyarrow
def test_null_byte_char(request, all_parsers):
# see gh-2741
data = "\x00,foo"
Expand All @@ -200,6 +224,7 @@ def test_null_byte_char(request, all_parsers):
parser.read_csv(StringIO(data), names=names)


@skip_pyarrow
@pytest.mark.filterwarnings("always::ResourceWarning")
def test_open_file(request, all_parsers):
# GH 39024
Expand Down Expand Up @@ -238,6 +263,8 @@ def test_bad_header_uniform_error(all_parsers):
"Could not construct index. Requested to use 1 "
"number of columns, but 3 left to parse."
)
elif parser.engine == "pyarrow":
msg = "CSV parse error: Expected 1 columns, got 4: col1,col2,col3,col4"

with pytest.raises(ParserError, match=msg):
parser.read_csv(StringIO(data), index_col=0, on_bad_lines="error")
Expand All @@ -253,9 +280,13 @@ def test_on_bad_lines_warn_correct_formatting(all_parsers):
a,b
"""
expected = DataFrame({"1": "a", "2": ["b"] * 2})
match_msg = "Skipping line"

if parser.engine == "pyarrow":
match_msg = "Expected 2 columns, but found 3: a,b,c"

with tm.assert_produces_warning(
ParserWarning, match="Skipping line", check_stacklevel=False
ParserWarning, match=match_msg, check_stacklevel=False
):
result = parser.read_csv(StringIO(data), on_bad_lines="warn")
tm.assert_frame_equal(result, expected)
10 changes: 7 additions & 3 deletions pandas/tests/io/parser/test_unsupported.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,17 @@ def test_pyarrow_engine(self):
with pytest.raises(ValueError, match=msg):
read_csv(StringIO(data), engine="pyarrow", **kwargs)

def test_on_bad_lines_callable_python_only(self, all_parsers):
def test_on_bad_lines_callable_python_or_pyarrow(self, all_parsers):
# GH 5686
# GH 54643
sio = StringIO("a,b\n1,2")
bad_lines_func = lambda x: x
parser = all_parsers
if all_parsers.engine != "python":
msg = "on_bad_line can only be a callable function if engine='python'"
if all_parsers.engine not in ["python", "pyarrow"]:
msg = (
"on_bad_line can only be a callable "
"function if engine='python' or 'pyarrow'"
)
with pytest.raises(ValueError, match=msg):
parser.read_csv(sio, on_bad_lines=bad_lines_func)
else:
Expand Down
Loading