From b67feb12749ef8c01ef77142c3cd534bb3d87eda Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Mon, 4 Nov 2024 01:19:51 -0500 Subject: [PATCH] [Bugfix]Using the correct type hints (#9885) Signed-off-by: Gregory Shtrasberg --- vllm/sequence.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index ee547dde45394..44a9257c9a4c1 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -6,7 +6,8 @@ from collections import defaultdict from dataclasses import dataclass, field from functools import cached_property, reduce -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional +from typing import (TYPE_CHECKING, Any, Callable, DefaultDict, Dict, List, + Mapping, Optional) from typing import Sequence as GenericSequence from typing import Set, Tuple, Union, cast @@ -256,7 +257,8 @@ def output_token_ids(self) -> Tuple[int, ...]: return tuple(self._output_token_ids) @output_token_ids.setter - def output_token_ids(self, new_output_token_ids: List[int]) -> None: + def output_token_ids(self, + new_output_token_ids: GenericSequence[int]) -> None: self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, new_output_token_ids) self._update_cached_all_tokens() @@ -1173,7 +1175,7 @@ def get_all_seq_ids_and_request_ids( sequence ids. """ seq_ids: List[int] = [] - request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set) + request_id_seq_ids_mapping: DefaultDict[str, Set[int]] = defaultdict(set) for sg in seq_group_metadata_list: for seq_id in sg.seq_data: seq_ids.append(seq_id)