Skip to content

Commit

Permalink
Fix annotation for api marker (#3099)
Browse files Browse the repository at this point in the history
### Changes

- Rework api decorator to pass annotation of object
- Remove unused 'is_api' function

### Reason for changes

Broken annotation of object
  • Loading branch information
AlexanderDokuchaev authored Nov 21, 2024
1 parent 472331c commit 8d501c7
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 19 deletions.
9 changes: 5 additions & 4 deletions docs/api/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def collect_api_entities() -> APIInfo:
except Exception as e:
skipped_modules[modname] = str(e)

from nncf.common.utils.api_marker import api
from nncf.common.utils.api_marker import API_MARKER_ATTR
from nncf.common.utils.api_marker import CANONICAL_ALIAS_ATTR

canonical_imports_seen = set()

Expand All @@ -86,7 +87,7 @@ def collect_api_entities() -> APIInfo:
if (
objects_module == modname
and (inspect.isclass(obj) or inspect.isfunction(obj))
and hasattr(obj, api.API_MARKER_ATTR)
and hasattr(obj, API_MARKER_ATTR)
):
marked_object_name = obj._nncf_api_marker
# Check the actual name of the originally marked object
Expand All @@ -95,8 +96,8 @@ def collect_api_entities() -> APIInfo:
if marked_object_name != obj.__name__:
continue
fqn = f"{modname}.{obj_name}"
if hasattr(obj, api.CANONICAL_ALIAS_ATTR):
canonical_import_name = getattr(obj, api.CANONICAL_ALIAS_ATTR)
if hasattr(obj, CANONICAL_ALIAS_ATTR):
canonical_import_name = getattr(obj, CANONICAL_ALIAS_ATTR)
if canonical_import_name in canonical_imports_seen:
assert False, f"Duplicate canonical_alias detected: {canonical_import_name}"
retval.fqn_vs_canonical_name[fqn] = canonical_import_name
Expand Down
41 changes: 26 additions & 15 deletions nncf/common/utils/api_marker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

from typing import Any, Callable, TypeVar, Union

class api:
API_MARKER_ATTR = "_nncf_api_marker"
CANONICAL_ALIAS_ATTR = "_nncf_canonical_alias"
TObj = TypeVar("TObj", bound=Union[Callable[..., Any], type])

def __init__(self, canonical_alias: str = None):
self._canonical_alias = canonical_alias
API_MARKER_ATTR = "_nncf_api_marker"
CANONICAL_ALIAS_ATTR = "_nncf_canonical_alias"

def __call__(self, obj: Any) -> Any:
# The value of the marker will be useful in determining
# whether we are handling a base class or a derived one.
setattr(obj, api.API_MARKER_ATTR, obj.__name__)
if self._canonical_alias is not None:
setattr(obj, api.CANONICAL_ALIAS_ATTR, self._canonical_alias)
return obj

def api(canonical_alias: str = None) -> Callable[[TObj], TObj]:
"""
Decorator function used to mark a object as an API.
Example:
@api(canonical_alias="alias")
class Class:
pass
@api(canonical_alias="alias")
def function():
pass
:param canonical_alias: The canonical alias for the API class.
"""

def decorator(obj: TObj) -> TObj:
setattr(obj, API_MARKER_ATTR, obj.__name__)
if canonical_alias is not None:
setattr(obj, CANONICAL_ALIAS_ATTR, canonical_alias)
return obj

def is_api(obj: Any) -> bool:
return hasattr(obj, api.API_MARKER_ATTR)
return decorator

0 comments on commit 8d501c7

Please sign in to comment.