diff --git a/builders/builder.py b/builders/builder.py index 4144b71..1d9b8e0 100644 --- a/builders/builder.py +++ b/builders/builder.py @@ -13,15 +13,15 @@ def flatten(l): Generator that flattens iterable infinitely. If an item is iterable, ``flatten`` descends on it. If it is callable, it descends on the call result (with no arguments), and it yields the item itself otherwise. """ - for el in l: - if isinstance(el, collections.Iterable) and not isinstance(el, basestring): + if isinstance(l, collections.Iterable) and not isinstance(l, basestring): + for el in l: for sub in flatten(el): yield sub - elif callable(el): - for sub in flatten(el()): - yield sub - else: - yield el + elif callable(l): + for sub in flatten(l()): + yield sub + else: + yield l class Builder: diff --git a/builders/tests/test_builder.py b/builders/tests/test_builder.py index a9430f8..264141f 100644 --- a/builders/tests/test_builder.py +++ b/builders/tests/test_builder.py @@ -175,6 +175,18 @@ def test_flatten(): assert list(flatten(l)) == [1, 2, 3, 4, 5, 'ololo'] +def test_flatten_noniterable(): + y = 100 + assert list(flatten(y)) == [y] + + +def test_flatten_function_returning_noniterable(): + def y(): + return 100 + + assert list(flatten(y)) == [100] + + def test_flatten_callable(): def x(): return [1, 2, 3]