From 0b131c9ae7cafd7a43f875b0df1fb714683bdcda Mon Sep 17 00:00:00 2001 From: caneff Date: Wed, 20 Sep 2023 15:16:51 -0400 Subject: [PATCH] Change handling of copy=None defaults for Pandas 2 (#28523) --- sdks/python/apache_beam/dataframe/frame_base.py | 8 ++++++++ .../apache_beam/dataframe/frame_base_test.py | 15 +++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/sdks/python/apache_beam/dataframe/frame_base.py b/sdks/python/apache_beam/dataframe/frame_base.py index 48a4c29d0589..4e89e473b730 100644 --- a/sdks/python/apache_beam/dataframe/frame_base.py +++ b/sdks/python/apache_beam/dataframe/frame_base.py @@ -674,11 +674,19 @@ def wrap(func): if removed_args: defaults_to_populate -= set(removed_args) + # In pandas 2, many methods rely on the default copy=None + # to mean that copy is the value of copy_on_write. Since + # copy_on_write will always be true for Beam, just fill it + # in here. In pandas 1, the default was True anyway. + if 'copy' in arg_to_default and arg_to_default['copy'] is None: + arg_to_default['copy'] = True + @functools.wraps(func) def wrapper(**kwargs): for name in defaults_to_populate: if name not in kwargs: kwargs[name] = arg_to_default[name] + return func(**kwargs) return wrapper diff --git a/sdks/python/apache_beam/dataframe/frame_base_test.py b/sdks/python/apache_beam/dataframe/frame_base_test.py index b3077320720f..0a73905339fd 100644 --- a/sdks/python/apache_beam/dataframe/frame_base_test.py +++ b/sdks/python/apache_beam/dataframe/frame_base_test.py @@ -174,6 +174,21 @@ def func(self, a, **kwargs): 'a': 2, 'b': 4, 'c': 6, 'kw_only': 8 }) + def test_populate_defaults_overwrites_copy(self): + class Base(object): + def func(self, a=1, b=2, c=3, *, copy=None): + pass + + class Proxy(object): + @frame_base.args_to_kwargs(Base) + @frame_base.populate_defaults(Base) + def func(self, a, copy, **kwargs): + return dict(kwargs, a=a, copy=copy) + + proxy = Proxy() + self.assertEqual(proxy.func(), {'a': 1, 'copy': True}) + self.assertEqual(proxy.func(copy=False), {'a': 1, 'copy': False}) + if __name__ == '__main__': unittest.main()