Skip to content

Commit

Permalink
Fix open_local returning list for pathlib.Path (#1418)
Browse files Browse the repository at this point in the history
---------
Co-authored-by: Martin Durant <[email protected]>
  • Loading branch information
BENR0 authored Nov 9, 2023
1 parent 5f268e4 commit 4f70f1b
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 3 deletions.
11 changes: 9 additions & 2 deletions fsspec/core.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

import io
import logging
import os
import re
from glob import has_magic
from pathlib import Path

# for backwards compat, we export cache things from here too
from .caching import ( # noqa: F401
Expand Down Expand Up @@ -469,7 +472,11 @@ def open(
return out[0]


def open_local(url, mode="rb", **storage_options):
def open_local(
url: str | list[str] | Path | list[Path],
mode: str = "rb",
**storage_options: dict,
) -> str | list[str]:
"""Open file(s) which can be resolved to local
For files which either are local, or get downloaded upon open
Expand All @@ -493,7 +500,7 @@ def open_local(url, mode="rb", **storage_options):
)
with of as files:
paths = [f.name for f in files]
if isinstance(url, str) and not has_magic(url):
if (isinstance(url, str) and not has_magic(url)) or isinstance(url, Path):
return paths[0]
return paths

Expand Down
42 changes: 41 additions & 1 deletion fsspec/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import tempfile
import zipfile
from contextlib import contextmanager
from pathlib import Path

import pytest

Expand Down Expand Up @@ -101,7 +102,7 @@ def test_openfile_open(m):
assert m.size("somepath") == 5


def test_open_local():
def test_open_local_w_cache():
d1 = str(tempfile.mkdtemp())
f1 = os.path.join(d1, "f1")
open(f1, "w").write("test1")
Expand All @@ -112,6 +113,45 @@ def test_open_local():
assert d2 in fn


def test_open_local_w_magic():
d1 = str(tempfile.mkdtemp())
f1 = os.path.join(d1, "f1")
open(f1, "w").write("test1")
fn = open_local(os.path.join(d1, "f*"))
assert len(fn) == 1
assert isinstance(fn, list)


def test_open_local_w_list_of_str():
d1 = str(tempfile.mkdtemp())
f1 = os.path.join(d1, "f1")
open(f1, "w").write("test1")
fn = open_local([f1, f1])
assert len(fn) == 2
assert isinstance(fn, list)
assert all(isinstance(elem, str) for elem in fn)


def test_open_local_w_path():
d1 = str(tempfile.mkdtemp())
f1 = os.path.join(d1, "f1")
open(f1, "w").write("test1")
p = Path(f1)
fn = open_local(p)
assert isinstance(fn, str)


def test_open_local_w_list_of_path():
d1 = str(tempfile.mkdtemp())
f1 = os.path.join(d1, "f1")
open(f1, "w").write("test1")
p = Path(f1)
fn = open_local([p, p])
assert len(fn) == 2
assert isinstance(fn, list)
assert all(isinstance(elem, str) for elem in fn)


def test_xz_lzma_compressions():
pytest.importorskip("lzma")
# Ensure that both 'xz' and 'lzma' compression names can be parsed
Expand Down

0 comments on commit 4f70f1b

Please sign in to comment.