Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add on_bad_lines for pyarrow #54643

Merged
merged 19 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ enhancement2

Other enhancements
^^^^^^^^^^^^^^^^^^
-
-
- :func:`read_csv` now supports ``on_bad_lines`` parameter with ``engine="pyarrow"``. (:issue:`54480`)


.. ---------------------------------------------------------------------------
.. _whatsnew_220.notable_bug_fixes:
Expand Down
27 changes: 27 additions & 0 deletions pandas/io/parsers/arrow_parser_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

from typing import TYPE_CHECKING
import warnings

from pandas._config import using_pyarrow_string_dtype

from pandas._libs import lib
from pandas.compat._optional import import_optional_dependency
from pandas.errors import ParserWarning
from pandas.util._exceptions import find_stack_level

from pandas.core.dtypes.inference import is_integer

Expand Down Expand Up @@ -85,6 +88,30 @@ def _get_pyarrow_options(self) -> None:
and option_name
in ("delimiter", "quote_char", "escape_char", "ignore_empty_lines")
}

on_bad_lines = self.kwds.get("on_bad_lines")
if on_bad_lines is not None:
if callable(on_bad_lines):
self.parse_options["invalid_row_handler"] = on_bad_lines
elif on_bad_lines == ParserBase.BadLineHandleMethod.ERROR:
self.parse_options[
"invalid_row_handler"
] = None # PyArrow raises an exception by default
elif on_bad_lines == ParserBase.BadLineHandleMethod.WARN:

def handle_warning(invalid_row):
warnings.warn(
f"Expected {invalid_row.expected_columns} columns, but found "
f"{invalid_row.actual_columns}: {invalid_row.text}",
ParserWarning,
stacklevel=find_stack_level(),
)
return "skip"

self.parse_options["invalid_row_handler"] = handle_warning
elif on_bad_lines == ParserBase.BadLineHandleMethod.SKIP:
self.parse_options["invalid_row_handler"] = lambda _: "skip"

self.convert_options = {
option_name: option_value
for option_name, option_value in self.kwds.items()
Expand Down
13 changes: 10 additions & 3 deletions pandas/io/parsers/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,13 @@
expected, a ``ParserWarning`` will be emitted while dropping extra elements.
Only supported when ``engine='python'``

.. versionchanged:: 2.2.0

- Callable, function with signature
as described in `pyarrow documentation
<https://arrow.apache.org/docs/python/generated/pyarrow.csv.ParseOptions.html
#pyarrow.csv.ParseOptions.invalid_row_handler>_` when ``engine='pyarrow'``

delim_whitespace : bool, default False
Specifies whether or not whitespace (e.g. ``' '`` or ``'\\t'``) will be
used as the ``sep`` delimiter. Equivalent to setting ``sep='\\s+'``. If this option
Expand Down Expand Up @@ -493,7 +500,6 @@ class _Fwf_Defaults(TypedDict):
"thousands",
"memory_map",
"dialect",
"on_bad_lines",
"delim_whitespace",
"quoting",
"lineterminator",
Expand Down Expand Up @@ -2062,9 +2068,10 @@ def _refine_defaults_read(
elif on_bad_lines == "skip":
kwds["on_bad_lines"] = ParserBase.BadLineHandleMethod.SKIP
elif callable(on_bad_lines):
if engine != "python":
if engine not in ["python", "pyarrow"]:
raise ValueError(
"on_bad_line can only be a callable function if engine='python'"
"on_bad_line can only be a callable function "
"if engine='python' or 'pyarrow'"
)
kwds["on_bad_lines"] = on_bad_lines
else:
Expand Down
101 changes: 79 additions & 22 deletions pandas/tests/io/parser/common/test_read_errors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Tests that work on both the Python and C engines but do not have a
Tests that work on the Python, C and PyArrow engines but do not have a
specific classification into the other test modules.
"""
import codecs
Expand All @@ -15,12 +15,14 @@
from pandas.errors import (
EmptyDataError,
ParserError,
ParserWarning,
)

from pandas import DataFrame
import pandas._testing as tm

pytestmark = pytest.mark.usefixtures("pyarrow_skip")
xfail_pyarrow = pytest.mark.usefixtures("pyarrow_xfail")
skip_pyarrow = pytest.mark.usefixtures("pyarrow_skip")


def test_empty_decimal_marker(all_parsers):
Expand All @@ -32,10 +34,17 @@ def test_empty_decimal_marker(all_parsers):
msg = "Only length-1 decimal markers supported"
parser = all_parsers

if parser.engine == "pyarrow":
msg = (
"only single character unicode strings can be "
"converted to Py_UCS4, got length 0"
)

with pytest.raises(ValueError, match=msg):
parser.read_csv(StringIO(data), decimal="")


@skip_pyarrow
mroeschke marked this conversation as resolved.
Show resolved Hide resolved
def test_bad_stream_exception(all_parsers, csv_dir_path):
# see gh-13652
#
Expand All @@ -56,6 +65,7 @@ def test_bad_stream_exception(all_parsers, csv_dir_path):
parser.read_csv(stream)


@skip_pyarrow
def test_malformed(all_parsers):
# see gh-6607
parser = all_parsers
Expand All @@ -70,6 +80,7 @@ def test_malformed(all_parsers):
parser.read_csv(StringIO(data), header=1, comment="#")


@skip_pyarrow
@pytest.mark.parametrize("nrows", [5, 3, None])
def test_malformed_chunks(all_parsers, nrows):
data = """ignore
Expand All @@ -89,6 +100,7 @@ def test_malformed_chunks(all_parsers, nrows):
reader.read(nrows)


@skip_pyarrow
def test_catch_too_many_names(all_parsers):
# see gh-5156
data = """\
Expand All @@ -108,6 +120,7 @@ def test_catch_too_many_names(all_parsers):
parser.read_csv(StringIO(data), header=0, names=["a", "b", "c", "d"])


@skip_pyarrow
@pytest.mark.parametrize("nrows", [0, 1, 2, 3, 4, 5])
def test_raise_on_no_columns(all_parsers, nrows):
parser = all_parsers
Expand Down Expand Up @@ -135,11 +148,16 @@ def test_suppress_error_output(all_parsers, capsys):
data = "a\n1\n1,2,3\n4\n5,6,7"
expected = DataFrame({"a": [1, 4]})

result = parser.read_csv(StringIO(data), on_bad_lines="skip")
tm.assert_frame_equal(result, expected)
if parser.engine == "pyarrow":
with tm.assert_produces_warning(False):
amithkk marked this conversation as resolved.
Show resolved Hide resolved
result = parser.read_csv(StringIO(data), on_bad_lines="skip")
tm.assert_frame_equal(result, expected)
else:
result = parser.read_csv(StringIO(data), on_bad_lines="skip")
tm.assert_frame_equal(result, expected)

captured = capsys.readouterr()
assert captured.err == ""
captured = capsys.readouterr()
assert captured.err == ""


def test_error_bad_lines(all_parsers):
Expand All @@ -148,7 +166,14 @@ def test_error_bad_lines(all_parsers):
data = "a\n1\n1,2,3\n4\n5,6,7"

msg = "Expected 1 fields in line 3, saw 3"
with pytest.raises(ParserError, match=msg):
ex_type = ParserError
amithkk marked this conversation as resolved.
Show resolved Hide resolved

if parser.engine == "pyarrow":
pa = pytest.importorskip("pyarrow")
ex_type = pa.ArrowInvalid
msg = "CSV parse error: Expected 1 columns, got 3: 1,2,3"

with pytest.raises(ex_type, match=msg):
parser.read_csv(StringIO(data), on_bad_lines="error")


Expand All @@ -158,12 +183,21 @@ def test_warn_bad_lines(all_parsers, capsys):
data = "a\n1\n1,2,3\n4\n5,6,7"
expected = DataFrame({"a": [1, 4]})

result = parser.read_csv(StringIO(data), on_bad_lines="warn")
tm.assert_frame_equal(result, expected)
if parser.engine == "pyarrow":
with tm.assert_produces_warning(
ParserWarning,
check_stacklevel=False,
match="Expected 1 columns, but found 3: 1,2,3",
):
result = parser.read_csv(StringIO(data), on_bad_lines="warn")
tm.assert_frame_equal(result, expected)
else:
result = parser.read_csv(StringIO(data), on_bad_lines="warn")
tm.assert_frame_equal(result, expected)

captured = capsys.readouterr()
assert "Skipping line 3" in captured.err
assert "Skipping line 5" in captured.err
captured = capsys.readouterr()
assert "Skipping line 3" in captured.err
assert "Skipping line 5" in captured.err


def test_read_csv_wrong_num_columns(all_parsers):
Expand All @@ -175,11 +209,18 @@ def test_read_csv_wrong_num_columns(all_parsers):
"""
parser = all_parsers
msg = "Expected 6 fields in line 3, saw 7"
ex_type = ParserError

with pytest.raises(ParserError, match=msg):
if parser.engine == "pyarrow":
pa = pytest.importorskip("pyarrow")
ex_type = pa.ArrowInvalid
msg = "Expected 6 columns, got 7: 6,7,8,9,10,11,12"

with pytest.raises(ex_type, match=msg):
parser.read_csv(StringIO(data))


@skip_pyarrow
def test_null_byte_char(request, all_parsers):
# see gh-2741
data = "\x00,foo"
Expand All @@ -202,6 +243,7 @@ def test_null_byte_char(request, all_parsers):
parser.read_csv(StringIO(data), names=names)


@skip_pyarrow
@pytest.mark.filterwarnings("always::ResourceWarning")
def test_open_file(request, all_parsers):
# GH 39024
Expand Down Expand Up @@ -235,13 +277,18 @@ def test_bad_header_uniform_error(all_parsers):
parser = all_parsers
data = "+++123456789...\ncol1,col2,col3,col4\n1,2,3,4\n"
msg = "Expected 2 fields in line 2, saw 4"
ex_type = ParserError
if parser.engine == "c":
msg = (
"Could not construct index. Requested to use 1 "
"number of columns, but 3 left to parse."
)
elif parser.engine == "pyarrow":
pa = pytest.importorskip("pyarrow")
ex_type = pa.ArrowInvalid
msg = "CSV parse error: Expected 1 columns, got 4: col1,col2,col3,col4"

with pytest.raises(ParserError, match=msg):
with pytest.raises(ex_type, match=msg):
parser.read_csv(StringIO(data), index_col=0, on_bad_lines="error")


Expand All @@ -256,17 +303,27 @@ def test_on_bad_lines_warn_correct_formatting(all_parsers, capsys):
"""
expected = DataFrame({"1": "a", "2": ["b"] * 2})

result = parser.read_csv(StringIO(data), on_bad_lines="warn")
tm.assert_frame_equal(result, expected)
# pyarrow engine uses warnings instead of directly printing to stderr
if parser.engine == "pyarrow":
with tm.assert_produces_warning(
ParserWarning,
check_stacklevel=False,
match="Expected 2 columns, but found 3: a,b,c",
):
result = parser.read_csv(StringIO(data), on_bad_lines="warn")
tm.assert_frame_equal(result, expected)
else:
result = parser.read_csv(StringIO(data), on_bad_lines="warn")
tm.assert_frame_equal(result, expected)

captured = capsys.readouterr()
if parser.engine == "c":
warn = """Skipping line 3: expected 2 fields, saw 3
captured = capsys.readouterr()
if parser.engine == "c":
warn = """Skipping line 3: expected 2 fields, saw 3
Skipping line 4: expected 2 fields, saw 3

"""
else:
warn = """Skipping line 3: Expected 2 fields in line 3, saw 3
else:
warn = """Skipping line 3: Expected 2 fields in line 3, saw 3
Skipping line 4: Expected 2 fields in line 4, saw 3
"""
assert captured.err == warn
assert captured.err == warn
10 changes: 7 additions & 3 deletions pandas/tests/io/parser/test_unsupported.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,17 @@ def test_pyarrow_engine(self):
with pytest.raises(ValueError, match=msg):
read_csv(StringIO(data), engine="pyarrow", **kwargs)

def test_on_bad_lines_callable_python_only(self, all_parsers):
def test_on_bad_lines_callable_python_or_pyarrow(self, all_parsers):
# GH 5686
# GH 54643
sio = StringIO("a,b\n1,2")
bad_lines_func = lambda x: x
parser = all_parsers
if all_parsers.engine != "python":
msg = "on_bad_line can only be a callable function if engine='python'"
if all_parsers.engine not in ["python", "pyarrow"]:
msg = (
"on_bad_line can only be a callable "
"function if engine='python' or 'pyarrow'"
)
with pytest.raises(ValueError, match=msg):
parser.read_csv(sio, on_bad_lines=bad_lines_func)
else:
Expand Down