diff --git a/pyproject.toml b/pyproject.toml index 4a49ec99f4ec6..ef651db3d1849 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,6 @@ extend_skip_glob = [ # These files do not need to be formatted, # see .flake8 for more details "python/paddle/utils/gast/**", - "python/paddle/base/framework.py", ] [tool.ruff] diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index e73b9ae0cc309..f9d0d9a536e70 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -23,6 +23,9 @@ import paddle from the source directory; please install paddlepaddle*.whl firstly.''' ) +# NOTE(SigureMo): We should place the import of base.core before other modules, +# because there are some initialization codes in base/core/__init__.py. +from .base import core # noqa: F401 from .batch import batch # Do the *DUPLICATED* monkey-patch for the tensor object. @@ -532,8 +535,8 @@ from .pir_utils import IrGuard -ir_change = IrGuard() -ir_change._switch_to_pir() +ir_guard = IrGuard() +ir_guard._switch_to_pir() __all__ = [ 'iinfo', diff --git a/python/paddle/base/__init__.py b/python/paddle/base/__init__.py index 4acf21c465776..5bab0d5cf84f0 100644 --- a/python/paddle/base/__init__.py +++ b/python/paddle/base/__init__.py @@ -15,6 +15,7 @@ import os import sys import atexit +import platform # The legacy core need to be removed before "import core", # in case of users installing paddlepaddle without -U option @@ -32,6 +33,8 @@ except Exception as e: raise e +from . import core + # import all class inside framework into base module from . import framework from .framework import ( @@ -138,11 +141,6 @@ def __bootstrap__(): Returns: None """ - import sys - import os - import platform - from . import core - # NOTE(zhiqiu): When (1)numpy < 1.19; (2) python < 3.7, # unittest is always imported in numpy (maybe some versions not). # so is_test is True and p2p is not inited. diff --git a/python/paddle/base/framework.py b/python/paddle/base/framework.py index ec580ba50d246..5c86638a76627 100644 --- a/python/paddle/base/framework.py +++ b/python/paddle/base/framework.py @@ -12,33 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. -import textwrap import collections -from collections.abc import Iterable -from .wrapped_decorator import signature_safe_contextmanager, wrap_decorator +import copy +import functools +import multiprocessing import os import re +import subprocess +import sys +import textwrap +import threading import traceback -import copy -from types import MethodType, FunctionType +import warnings +from collections.abc import Iterable +from types import FunctionType, MethodType import numpy as np -import subprocess -import multiprocessing -import sys -from .proto import framework_pb2 -from .proto import data_feed_pb2 # noqa: F401 +import paddle.version as paddle_version -from . import core -from . import unique_name from .. import pir -from paddle.base.libpaddle import DataType -import paddle.version as fluid_version -import warnings -import functools -from .variable_index import _getitem_static, _setitem_static, _setitem_impl_ -import threading +from . import core, unique_name +from .libpaddle import DataType +from .proto import data_feed_pb2 # noqa: F401 +from .proto import framework_pb2 +from .variable_index import _getitem_static, _setitem_impl_, _setitem_static +from .wrapped_decorator import signature_safe_contextmanager, wrap_decorator __all__ = [] @@ -503,10 +502,10 @@ def require_version(min_version, max_version=None): ) version_installed = [ - fluid_version.major, - fluid_version.minor, - fluid_version.patch, - fluid_version.rc, + paddle_version.major, + paddle_version.minor, + paddle_version.patch, + paddle_version.rc, ] zero_version = ['0', '0', '0', '0'] @@ -524,7 +523,7 @@ def version_cmp(ver_a, ver_b): "PaddlePaddle version in [{}, {}] required, but {} installed. " "Maybe you are using a develop version, " "please make sure the version is good with your code.".format( - min_version, max_version, fluid_version.full_version + min_version, max_version, paddle_version.full_version ) ) else: @@ -532,7 +531,7 @@ def version_cmp(ver_a, ver_b): "PaddlePaddle version {} or higher is required, but {} installed, " "Maybe you are using a develop version, " "please make sure the version is good with your code.".format( - min_version, fluid_version.full_version + min_version, paddle_version.full_version ) ) return @@ -554,7 +553,7 @@ def version_cmp(ver_a, ver_b): ): raise Exception( "VersionError: PaddlePaddle version in [{}, {}] required, but {} installed.".format( - min_version, max_version, fluid_version.full_version + min_version, max_version, paddle_version.full_version ) ) else: @@ -562,7 +561,7 @@ def version_cmp(ver_a, ver_b): raise Exception( "VersionError: PaddlePaddle version {} or higher is required, but {} installed, " "please upgrade your PaddlePaddle to {} or other higher version.".format( - min_version, fluid_version.full_version, min_version + min_version, paddle_version.full_version, min_version ) ) @@ -7703,8 +7702,8 @@ def _get_var(name, program=None): @signature_safe_contextmanager def dygraph_guard_if_declarative(): - from .dygraph.base import in_to_static_mode from .dygraph import Tracer + from .dygraph.base import in_to_static_mode if in_to_static_mode(): # Under @paddle.jit.to_static decorator, we switch back dygraph mode temporarily. diff --git a/python/paddle/pir/__init__.py b/python/paddle/pir/__init__.py index 39b8c71ca5a2f..8a454c09e058d 100644 --- a/python/paddle/pir/__init__.py +++ b/python/paddle/pir/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from paddle.base.libpaddle.pir import ( +from paddle.base.libpaddle.pir import ( # noqa: F401 Program, Block, Operation, @@ -22,8 +22,8 @@ fake_op_result, is_fake_op_result, Type, -) # noqa: F401 -from paddle.base.libpaddle.pir import ( +) +from paddle.base.libpaddle.pir import ( # noqa: F401 translate_to_new_ir, set_global_program, set_insertion_point, @@ -32,7 +32,7 @@ check_unregistered_ops, register_paddle_dialect, PassManager, -) # noqa: F401 +) from . import core diff --git a/test/dygraph_to_static/test_origin_info.py b/test/dygraph_to_static/test_origin_info.py index c6415dff1ba1c..e2925d4fa1a4b 100644 --- a/test/dygraph_to_static/test_origin_info.py +++ b/test/dygraph_to_static/test_origin_info.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import sys import unittest @@ -23,11 +24,10 @@ OriginInfo, attach_origin_info, create_and_update_origin_info_map, - gast, - inspect, unwrap, ) from paddle.jit.dy2static.utils import ast_to_func +from paddle.utils import gast def simple_func(x):