Skip to content

Commit

Permalink
chore(mypy): add override where needed
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBelthle committed Jan 3, 2025
1 parent 1996f62 commit 9cea805
Show file tree
Hide file tree
Showing 152 changed files with 731 additions and 2 deletions.
7 changes: 7 additions & 0 deletions antarest/core/cache/business/local_chache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import time
from typing import Dict, List, Optional

from typing_extensions import override

from antarest.core.config import CacheConfig
from antarest.core.interfaces.cache import ICache
from antarest.core.model import JSON
Expand All @@ -40,6 +42,7 @@ def __init__(self, config: CacheConfig = CacheConfig()):
daemon=True,
)

@override
def start(self) -> None:
self.checker_thread.start()

Expand All @@ -55,6 +58,7 @@ def checker(self) -> None:
for id in to_delete:
del self.cache[id]

@override
def put(self, id: str, data: JSON, duration: int = 3600) -> None: # Duration in second
with self.lock:
logger.info(f"Adding cache key {id}")
Expand All @@ -64,6 +68,7 @@ def put(self, id: str, data: JSON, duration: int = 3600) -> None: # Duration in
duration=duration,
)

@override
def get(self, id: str, refresh_duration: Optional[int] = None) -> Optional[JSON]:
res = None
with self.lock:
Expand All @@ -76,12 +81,14 @@ def get(self, id: str, refresh_duration: Optional[int] = None) -> Optional[JSON]
res = self.cache[id].data
return res

@override
def invalidate(self, id: str) -> None:
with self.lock:
logger.info(f"Removing cache key {id}")
if id in self.cache:
del self.cache[id]

@override
def invalidate_all(self, ids: List[str]) -> None:
with self.lock:
logger.info(f"Removing cache keys {ids}")
Expand Down
6 changes: 6 additions & 0 deletions antarest/core/cache/business/redis_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import List, Optional

from redis.client import Redis
from typing_extensions import override

from antarest.core.interfaces.cache import ICache
from antarest.core.model import JSON
Expand All @@ -31,17 +32,20 @@ class RedisCache(ICache):
def __init__(self, redis_client: Redis): # type: ignore
self.redis = redis_client

@override
def start(self) -> None:
# Assuming the Redis service is already running; no need to start it here.
pass

@override
def put(self, id: str, data: JSON, duration: int = 3600) -> None:
redis_element = RedisCacheElement(duration=duration, data=data)
redis_key = f"cache:{id}"
logger.info(f"Adding cache key {id}")
self.redis.set(redis_key, redis_element.model_dump_json())
self.redis.expire(redis_key, duration)

@override
def get(self, id: str, refresh_timeout: Optional[int] = None) -> Optional[JSON]:
redis_key = f"cache:{id}"
result = self.redis.get(redis_key)
Expand All @@ -58,10 +62,12 @@ def get(self, id: str, refresh_timeout: Optional[int] = None) -> Optional[JSON]:
logger.info(f"Cache key {id} not found")
return None

@override
def invalidate(self, id: str) -> None:
logger.info(f"Removing cache key {id}")
self.redis.delete(f"cache:{id}")

@override
def invalidate_all(self, ids: List[str]) -> None:
logger.info(f"Removing cache keys {ids}")
self.redis.delete(*[f"cache:{id}" for id in ids])
3 changes: 3 additions & 0 deletions antarest/core/configdata/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Any, Optional

from sqlalchemy import Column, Integer, String # type: ignore
from typing_extensions import override

from antarest.core.persistence import Base
from antarest.core.serialization import AntaresBaseModel
Expand All @@ -30,11 +31,13 @@ class ConfigData(Base): # type: ignore
key = Column(String(), primary_key=True)
value = Column(String(), nullable=True)

@override
def __eq__(self, other: Any) -> bool:
if not isinstance(other, ConfigData):
return False
return bool(other.key == self.key and other.value == self.value and other.owner == self.owner)

@override
def __repr__(self) -> str:
return f"key={self.key}, value={self.value}, owner={self.owner}"

Expand Down
11 changes: 11 additions & 0 deletions antarest/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from http import HTTPStatus

from fastapi.exceptions import HTTPException
from typing_extensions import override


class ShouldNotHappenException(Exception):
Expand Down Expand Up @@ -81,6 +82,7 @@ def __init__(self, path: str, *area_ids: str):
detail = f"{self.object_name.title()} {detail}"
super().__init__(HTTPStatus.NOT_FOUND, detail)

@override
def __str__(self) -> str:
"""Return a string representation of the exception."""
return self.detail
Expand Down Expand Up @@ -127,6 +129,7 @@ def __init__(self, path: str, section_id: str):
detail = f"{object_name.title()} '{section_id}' not found in '{path}'"
super().__init__(HTTPStatus.NOT_FOUND, detail)

@override
def __str__(self) -> str:
"""Return a string representation of the exception."""
return self.detail
Expand Down Expand Up @@ -172,6 +175,7 @@ def __init__(self, path: str):
detail = f"{self.object_name.title()} {detail}"
super().__init__(HTTPStatus.NOT_FOUND, detail)

@override
def __str__(self) -> str:
return self.detail

Expand Down Expand Up @@ -227,6 +231,7 @@ def __init__(self, area_id: str, *duplicates: str):
detail = f"{self.object_name.title()} {detail}"
super().__init__(HTTPStatus.CONFLICT, detail)

@override
def __str__(self) -> str:
"""Return a string representation of the exception."""
return self.detail
Expand Down Expand Up @@ -397,6 +402,7 @@ def __init__(self, object_id: str, binding_ids: t.Sequence[str], *, object_type:
)
super().__init__(HTTPStatus.FORBIDDEN, message)

@override
def __str__(self) -> str:
"""Return a string representation of the exception."""
return self.detail
Expand Down Expand Up @@ -429,6 +435,7 @@ def __init__(self, output_id: str) -> None:
message = f"Output '{output_id}' not found"
super().__init__(HTTPStatus.NOT_FOUND, message)

@override
def __str__(self) -> str:
"""Return a string representation of the exception."""
return self.detail
Expand Down Expand Up @@ -463,6 +470,7 @@ def __init__(self, output_id: str, mc_root: str) -> None:
message = f"The output '{output_id}' sub-folder '{mc_root}' does not exist"
super().__init__(HTTPStatus.NOT_FOUND, message)

@override
def __str__(self) -> str:
"""Return a string representation of the exception."""
return self.detail
Expand Down Expand Up @@ -552,6 +560,7 @@ def __init__(self, binding_constraint_id: str, *ids: str) -> None:
}[min(count, 2)]
super().__init__(HTTPStatus.NOT_FOUND, message)

@override
def __str__(self) -> str:
"""Return a string representation of the exception."""
return self.detail
Expand All @@ -572,6 +581,7 @@ def __init__(self, binding_constraint_id: str, *ids: str) -> None:
}[min(count, 2)]
super().__init__(HTTPStatus.CONFLICT, message)

@override
def __str__(self) -> str:
"""Return a string representation of the exception."""
return self.detail
Expand All @@ -589,6 +599,7 @@ def __init__(self, binding_constraint_id: str, term_json: str) -> None:
)
super().__init__(HTTPStatus.UNPROCESSABLE_ENTITY, message)

@override
def __str__(self) -> str:
"""Return a string representation of the exception."""
return self.detail
Expand Down
2 changes: 2 additions & 0 deletions antarest/core/filetransfer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Optional

from sqlalchemy import Boolean, Column, DateTime, Integer, String # type: ignore
from typing_extensions import override

from antarest.core.persistence import Base
from antarest.core.serialization import AntaresBaseModel
Expand Down Expand Up @@ -81,6 +82,7 @@ def to_dto(self) -> FileDownloadDTO:
error_message=self.error_message or "",
)

@override
def __repr__(self) -> str:
return (
f"(id={self.id},"
Expand Down
9 changes: 9 additions & 0 deletions antarest/core/interfaces/eventbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from enum import StrEnum
from typing import Any, Awaitable, Callable, List, Optional

from typing_extensions import override

from antarest.core.model import PermissionInfo
from antarest.core.serialization import AntaresBaseModel

Expand Down Expand Up @@ -140,32 +142,39 @@ class DummyEventBusService(IEventBus):
def __init__(self) -> None:
self.events: List[Event] = []

@override
def queue(self, event: Event, queue: str) -> None:
# Noop
pass

@override
def add_queue_consumer(self, listener: Callable[[Event], Awaitable[None]], queue: str) -> str:
return ""

@override
def remove_queue_consumer(self, listener_id: str) -> None:
# Noop
pass

@override
def push(self, event: Event) -> None:
# Noop
self.events.append(event)

@override
def add_listener(
self,
listener: Callable[[Event], Awaitable[None]],
type_filter: Optional[List[EventType]] = None,
) -> str:
return ""

@override
def remove_listener(self, listener_id: str) -> None:
# Noop
pass

@override
def start(self, threaded: bool = True) -> None:
# Noop
pass
4 changes: 4 additions & 0 deletions antarest/core/logging/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response
from typing_extensions import override

from antarest.core.config import Config

Expand All @@ -39,6 +40,7 @@ class CustomDefaultFormatter(logging.Formatter):
fields to the log record with a value of `None`.
"""

@override
def format(self, record: logging.LogRecord) -> str:
"""
Formats the specified log record using the custom formatter,
Expand Down Expand Up @@ -169,13 +171,15 @@ def configure_logger(config: Config, handler_cls: str = "logging.FileHandler") -


class LoggingMiddleware(BaseHTTPMiddleware):
@override
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
with RequestContext(request):
response = await call_next(request)
return response


class ContextFilter(logging.Filter):
@override
def filter(self, record: logging.LogRecord) -> bool:
request: Optional[Request] = _request.get()
request_id: Optional[str] = _request_id.get()
Expand Down
8 changes: 8 additions & 0 deletions antarest/core/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from fastapi import HTTPException
from markupsafe import escape
from ratelimit import Rule # type: ignore
from typing_extensions import override

from antarest.core.jwt import JWTUser

Expand All @@ -38,24 +39,30 @@ def __init__(self, data=None, **kwargs) -> None: # type: ignore
data = {}
self.update(data, **kwargs)

@override
def __setitem__(self, key: str, value: t.Any) -> None:
self._store[key.lower()] = (key, value)

@override
def __getitem__(self, key: str) -> t.Any:
return self._store[key.lower()][1]

@override
def __delitem__(self, key: str) -> None:
del self._store[key.lower()]

@override
def __iter__(self) -> t.Any:
return (casedkey for casedkey, mappedvalue in self._store.values())

@override
def __len__(self) -> int:
return len(self._store)

def lower_items(self) -> Generator[Tuple[Any, Any], Any, None]:
return ((lowerkey, keyval[1]) for (lowerkey, keyval) in self._store.items())

@override
def __eq__(self, other: t.Any) -> bool:
if isinstance(other, t.Mapping):
other = CaseInsensitiveDict(other)
Expand All @@ -66,6 +73,7 @@ def __eq__(self, other: t.Any) -> bool:
def copy(self) -> "CaseInsensitiveDict":
return CaseInsensitiveDict(self._store.values())

@override
def __repr__(self) -> str:
return str(dict(self.items()))

Expand Down
5 changes: 5 additions & 0 deletions antarest/core/tasks/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, Sequence, String # type: ignore
from sqlalchemy.engine.base import Engine # type: ignore
from sqlalchemy.orm import relationship, sessionmaker # type: ignore
from typing_extensions import override

from antarest.core.persistence import Base
from antarest.core.serialization import AntaresBaseModel
Expand Down Expand Up @@ -122,11 +123,13 @@ class TaskJobLog(Base): # type: ignore
# If the TaskJob is deleted, all attached logs must also be deleted in cascade.
job: "TaskJob" = relationship("TaskJob", back_populates="logs", uselist=False)

@override
def __eq__(self, other: t.Any) -> bool:
if not isinstance(other, TaskJobLog):
return False
return bool(other.id == self.id and other.message == self.message and other.task_id == self.task_id)

@override
def __repr__(self) -> str:
return f"id={self.id}, message={self.message}, task_id={self.task_id}"

Expand Down Expand Up @@ -198,6 +201,7 @@ def to_dto(self, with_logs: bool = False) -> TaskDTO:
progress=self.progress,
)

@override
def __eq__(self, other: t.Any) -> bool:
if not isinstance(other, TaskJob):
return False
Expand All @@ -213,6 +217,7 @@ def __eq__(self, other: t.Any) -> bool:
and other.logs == self.logs
)

@override
def __repr__(self) -> str:
return (
f"id={self.id},"
Expand Down
Loading

0 comments on commit 9cea805

Please sign in to comment.