From 8a69040c2420c58501b0502ce9e86d70974af89b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul=20Renvois=C3=A9?= Date: Mon, 3 Jun 2024 12:49:31 +0200 Subject: [PATCH] fix(iterating/flatten) --- flashback/iterating/flatten.py | 5 ++--- tests/iterating/test_flatten.py | 5 +++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/flashback/iterating/flatten.py b/flashback/iterating/flatten.py index 6a29b58..adb2a5c 100644 --- a/flashback/iterating/flatten.py +++ b/flashback/iterating/flatten.py @@ -1,12 +1,11 @@ from __future__ import annotations -from collections.abc import Iterable from typing import TypeVar T = TypeVar("T") -def flatten(iterable: Iterable[T]) -> tuple[T, ...]: +def flatten(iterable: list[T] | tuple[T, ...] | set[T] | frozenset[T] | range) -> tuple[T, ...]: """ Unpacks nested iterables into the root `iterable`. @@ -33,7 +32,7 @@ def flatten(iterable: Iterable[T]) -> tuple[T, ...]: """ items = [] for item in iterable: - if isinstance(item, Iterable) and not isinstance(item, str): + if isinstance(item, (list, tuple, set, frozenset, range)): # noqa: UP038 for nested_item in flatten(item): items.append(nested_item) # noqa: PERF402 else: diff --git a/tests/iterating/test_flatten.py b/tests/iterating/test_flatten.py index 6b63904..c4b5ff5 100644 --- a/tests/iterating/test_flatten.py +++ b/tests/iterating/test_flatten.py @@ -21,3 +21,8 @@ def test_mixed_types(self): flattened = flatten([1, (2,), {3, 4}, range(5, 6)]) assert flattened == (1, 2, 3, 4, 5) + + def test_nested_dicts(self): + flattened = flatten([[{"key1": 1}], [{"key2": 2}, {"key3": 3}]]) + + assert flattened == ({"key1": 1}, {"key2": 2}, {"key3": 3})