Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Use StrCollection in more places #16301

Merged
merged 7 commits into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16301.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type hints.
12 changes: 5 additions & 7 deletions synapse/app/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@
Any,
Awaitable,
Callable,
Collection,
Dict,
Iterable,
List,
NoReturn,
Optional,
Expand Down Expand Up @@ -76,7 +74,7 @@
from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
load_legacy_third_party_event_rules,
)
from synapse.types import ISynapseReactor
from synapse.types import ISynapseReactor, StrCollection
from synapse.util import SYNAPSE_VERSION
from synapse.util.caches.lrucache import setup_expire_lru_cache_entries
from synapse.util.daemonize import daemonize_process
Expand Down Expand Up @@ -278,7 +276,7 @@ async def wrapper() -> None:
reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))


def listen_metrics(bind_addresses: Iterable[str], port: int) -> None:
def listen_metrics(bind_addresses: StrCollection, port: int) -> None:
"""
Start Prometheus metrics server.
"""
Expand Down Expand Up @@ -315,7 +313,7 @@ def _set_prometheus_client_use_created_metrics(new_value: bool) -> None:


def listen_manhole(
bind_addresses: Collection[str],
bind_addresses: StrCollection,
port: int,
manhole_settings: ManholeConfig,
manhole_globals: dict,
Expand All @@ -339,7 +337,7 @@ def listen_manhole(


def listen_tcp(
bind_addresses: Collection[str],
bind_addresses: StrCollection,
port: int,
factory: ServerFactory,
reactor: IReactorTCP = reactor,
Expand Down Expand Up @@ -448,7 +446,7 @@ def listen_http(


def listen_ssl(
bind_addresses: Collection[str],
bind_addresses: StrCollection,
port: int,
factory: ServerFactory,
context_factory: IOpenSSLContextFactory,
Expand Down
3 changes: 1 addition & 2 deletions synapse/config/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from typing import (
Any,
ClassVar,
Collection,
Dict,
Iterable,
Iterator,
Expand Down Expand Up @@ -384,7 +383,7 @@ class RootConfig:

config_classes: List[Type[Config]] = []

def __init__(self, config_files: Collection[str] = ()):
def __init__(self, config_files: StrSequence = ()):
# Capture absolute paths here, so we can reload config after we daemonize.
self.config_files = [os.path.abspath(path) for path in config_files]

Expand Down
5 changes: 2 additions & 3 deletions synapse/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Expand Down Expand Up @@ -408,7 +407,7 @@ def items(self) -> List[Tuple[str, Optional[Any]]]:
def keys(self) -> Iterable[str]:
return self._dict.keys()

def prev_event_ids(self) -> Sequence[str]:
def prev_event_ids(self) -> List[str]:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although I'm wondering if this should be StrSequence so it isn't mutable?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're returning a brand-new list with no references to any internal data I don't think there's a huge reason to return something immutable. My bias would be to leave it as List[str] for now, but that's only a weak preference.

"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.

Expand Down Expand Up @@ -553,7 +552,7 @@ def event_id(self) -> str:
self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
return self._event_id

def prev_event_ids(self) -> Sequence[str]:
def prev_event_ids(self) -> List[str]:
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.

Expand Down
8 changes: 4 additions & 4 deletions synapse/events/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import attr
from signedjson.types import SigningKey
Expand All @@ -28,7 +28,7 @@
from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore
from synapse.types import EventID, JsonDict
from synapse.types import EventID, JsonDict, StrCollection
from synapse.types.state import StateFilter
from synapse.util import Clock
from synapse.util.stringutils import random_string
Expand Down Expand Up @@ -103,7 +103,7 @@ def is_state(self) -> bool:

async def build(
self,
prev_event_ids: Collection[str],
prev_event_ids: StrCollection,
auth_event_ids: Optional[List[str]],
depth: Optional[int] = None,
) -> EventBase:
Expand Down Expand Up @@ -136,7 +136,7 @@ async def build(

format_version = self.room_version.event_format
# The types of auth/prev events changes between event versions.
prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]]
prev_events: Union[StrCollection, List[Tuple[str, Dict[str, str]]]]
auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
if format_version == EventFormatVersions.ROOM_V1_V2:
auth_events = await self._store.add_event_hashes(auth_event_ids)
Expand Down
6 changes: 3 additions & 3 deletions synapse/events/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections.abc
from typing import Iterable, List, Type, Union, cast
from typing import List, Type, Union, cast

import jsonschema
from pydantic import Field, StrictBool, StrictStr
Expand All @@ -36,7 +36,7 @@
from synapse.federation.federation_server import server_matches_acl_event
from synapse.http.servlet import validate_json_object
from synapse.rest.models import RequestBodyModel
from synapse.types import EventID, JsonDict, RoomID, UserID
from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID


class EventValidator:
Expand Down Expand Up @@ -225,7 +225,7 @@ def validate_builder(self, event: Union[EventBase, EventBuilder]) -> None:

self._ensure_state_event(event)

def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None:
def _ensure_strings(self, d: JsonDict, keys: StrCollection) -> None:
for s in keys:
if s not in d:
raise SynapseError(400, "'%s' not in content" % (s,))
Expand Down
8 changes: 3 additions & 5 deletions synapse/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.types import ISynapseReactor
from synapse.types import BytesSequence, ISynapseReactor, StrSequence
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred

Expand Down Expand Up @@ -108,11 +108,9 @@
# the value actually has to be a List, but List is invariant so we can't specify that
# the entries can either be Lists or bytes.
RawHeaderValue = Union[
List[str],
List[bytes],
BytesSequence,
StrSequence,
List[Union[str, bytes]],
Tuple[str, ...],
Tuple[bytes, ...],
Tuple[Union[str, bytes], ...],
]

Expand Down
33 changes: 16 additions & 17 deletions synapse/http/servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Iterable,
List,
Mapping,
Optional,
Expand All @@ -38,7 +37,7 @@
from synapse.api.errors import Codes, SynapseError
from synapse.http import redact_uri
from synapse.http.server import HttpServer
from synapse.types import JsonDict, RoomAlias, RoomID
from synapse.types import JsonDict, RoomAlias, RoomID, StrCollection
from synapse.util import json_decoder

if TYPE_CHECKING:
Expand Down Expand Up @@ -340,7 +339,7 @@ def parse_string(
name: str,
default: str,
*,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> str:
...
Expand All @@ -352,7 +351,7 @@ def parse_string(
name: str,
*,
required: Literal[True],
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> str:
...
Expand All @@ -365,7 +364,7 @@ def parse_string(
*,
default: Optional[str] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
...
Expand All @@ -376,7 +375,7 @@ def parse_string(
name: str,
default: Optional[str] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
"""
Expand Down Expand Up @@ -485,7 +484,7 @@ def parse_enum(

def _parse_string_value(
value: bytes,
allowed_values: Optional[Iterable[str]],
allowed_values: Optional[StrCollection],
name: str,
encoding: str,
) -> str:
Expand All @@ -511,7 +510,7 @@ def parse_strings_from_args(
args: Mapping[bytes, Sequence[bytes]],
name: str,
*,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[List[str]]:
...
Expand All @@ -523,7 +522,7 @@ def parse_strings_from_args(
name: str,
default: List[str],
*,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> List[str]:
...
Expand All @@ -535,7 +534,7 @@ def parse_strings_from_args(
name: str,
*,
required: Literal[True],
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> List[str]:
...
Expand All @@ -548,7 +547,7 @@ def parse_strings_from_args(
default: Optional[List[str]] = None,
*,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[List[str]]:
...
Expand All @@ -559,7 +558,7 @@ def parse_strings_from_args(
name: str,
default: Optional[List[str]] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[List[str]]:
"""
Expand Down Expand Up @@ -610,7 +609,7 @@ def parse_string_from_args(
name: str,
default: Optional[str] = None,
*,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
...
Expand All @@ -623,7 +622,7 @@ def parse_string_from_args(
default: Optional[str] = None,
*,
required: Literal[True],
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> str:
...
Expand All @@ -635,7 +634,7 @@ def parse_string_from_args(
name: str,
default: Optional[str] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
...
Expand All @@ -646,7 +645,7 @@ def parse_string_from_args(
name: str,
default: Optional[str] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
allowed_values: Optional[StrCollection] = None,
encoding: str = "ascii",
) -> Optional[str]:
"""
Expand Down Expand Up @@ -821,7 +820,7 @@ def parse_and_validate_json_object_from_request(
return validate_json_object(content, model_type)


def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
def assert_params_in_dict(body: JsonDict, required: StrCollection) -> None:
absent = []
for k in required:
if k not in body:
Expand Down
8 changes: 4 additions & 4 deletions synapse/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
Iterable,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Type,
Expand All @@ -49,6 +48,7 @@
from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager
from synapse.metrics._twisted_exposition import MetricsResource, generate_latest
from synapse.metrics._types import Collector
from synapse.types import StrSequence
from synapse.util import SYNAPSE_VERSION

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -81,7 +81,7 @@ class LaterGauge(Collector):

name: str
desc: str
labels: Optional[Sequence[str]] = attr.ib(hash=False)
labels: Optional[StrSequence] = attr.ib(hash=False)
# callback: should either return a value (if there are no labels for this metric),
# or dict mapping from a label tuple to a value
caller: Callable[
Expand Down Expand Up @@ -143,8 +143,8 @@ def __init__(
self,
name: str,
desc: str,
labels: Sequence[str],
sub_metrics: Sequence[str],
labels: StrSequence,
sub_metrics: StrSequence,
):
self.name = name
self.desc = desc
Expand Down
Loading