diff --git a/fixcore/.pylintrc b/fixcore/.pylintrc index 91536aa44b..94fd5c5778 100644 --- a/fixcore/.pylintrc +++ b/fixcore/.pylintrc @@ -246,7 +246,7 @@ ignored-modules= # List of classes names for which member attributes should not be checked # (useful for classes with attributes dynamically set). -ignored-classes=SQLObject, optparse.Values, thread._local, _thread._local, aiostream.pipe +ignored-classes=SQLObject, optparse.Values, thread._local, _thread._local # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular diff --git a/fixcore/fixcore/cli/__init__.py b/fixcore/fixcore/cli/__init__.py index cb93c9334d..3af2a601e3 100644 --- a/fixcore/fixcore/cli/__init__.py +++ b/fixcore/fixcore/cli/__init__.py @@ -15,13 +15,12 @@ AsyncIterable, ) -from aiostream import stream -from aiostream.core import Stream from parsy import Parser, regex, string from fixcore.model.graph_access import Section from fixcore.types import JsonElement, Json from fixcore.util import utc, parse_utc, AnyT +from fixlib.asynchronous.stream import Stream from fixlib.durations import parse_duration, DurationRe from fixlib.parse_util import ( make_parser, @@ -47,7 +46,7 @@ # A sink function takes a stream and creates a result Sink = Callable[[JsStream], Awaitable[T]] -list_sink: Callable[[JsGen], Awaitable[Any]] = stream.list # type: ignore +list_sink: Callable[[JsGen], Awaitable[List[Any]]] = Stream.as_list @make_parser diff --git a/fixcore/fixcore/cli/cli.py b/fixcore/fixcore/cli/cli.py index 8c535bef82..178bb7233c 100644 --- a/fixcore/fixcore/cli/cli.py +++ b/fixcore/fixcore/cli/cli.py @@ -10,14 +10,13 @@ from typing import Dict, List, Tuple, Union, Sequence from typing import Optional, Any, TYPE_CHECKING -from aiostream import stream from attrs import evolve from parsy import Parser from rich.padding import Padding from fixcore import version from fixcore.analytics import CoreEvent -from fixcore.cli import cmd_with_args_parser, key_values_parser, T, Sink, args_values_parser, JsGen +from fixcore.cli import cmd_with_args_parser, key_values_parser, T, Sink, args_values_parser, JsStream from fixcore.cli.command import ( SearchPart, PredecessorsPart, @@ -78,6 +77,7 @@ from fixcore.types import JsonElement from fixcore.user.model import Permission from fixcore.util import group_by +from fixlib.asynchronous.stream import Stream from fixlib.parse_util import make_parser, pipe_p, semicolon_p if TYPE_CHECKING: @@ -104,7 +104,7 @@ def command_line_parser() -> Parser: return ParsedCommands(commands, maybe_env if maybe_env else {}) -# multiple piped commands are separated by semicolon +# semicolon separates multiple piped commands multi_command_parser = command_line_parser.sep_by(semicolon_p) @@ -187,7 +187,7 @@ def overview() -> str: logo = ctx.render_console(Padding(WelcomeCommand.ck, pad=(0, 0, 0, middle))) if ctx.supports_color() else "" return headline + logo + ctx.render_console(result) - def help_command() -> JsGen: + def help_command() -> JsStream: if not arg: result = overview() elif arg == "placeholders": @@ -209,7 +209,7 @@ def help_command() -> JsGen: else: result = f"No command found with this name: {arg}" - return stream.just(result) + return Stream.just(result) return CLISource.single(help_command, required_permissions={Permission.read}) @@ -352,11 +352,11 @@ def command( self, name: str, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwargs: Any ) -> ExecutableCommand: """ - Create an executable command for 
given command name, args and context. - :param name: the name of the command to execute (must be a known command) - :param arg: the arg of the command (must be parsable by the command) - :param ctx: the context of this command. - :return: the ready to run executable command. + Create an executable command for the given command name, args, and context. + :param name: The name of the command to execute (must be a known command). + :param arg: The arg of the command (must be parsable by the command). + :param ctx: The context of this command. + :return: The ready-to-run executable command. :raises: CLIParseError: if the name of the command is not known, or the argument fails to parse. """ @@ -377,9 +377,9 @@ async def create_query( Takes a list of query part commands and combine them to a single executable query command. This process can also introduce new commands that should run after the query is finished. Therefore, a list of executable commands is returned. - :param commands: the incoming executable commands, which actions are all instances of SearchCLIPart. - :param ctx: the context to execute within. - :return: the resulting list of commands to execute. + :param commands: The incoming executable commands, whose actions are all instances of SearchCLIPart. + :param ctx: The context to execute within. + :return: The resulting list of commands to execute. """ # Pass parsed options to execute query @@ -484,8 +484,8 @@ async def parse_query(query_arg: str) -> Query: first_head_tail_in_a_row = None head_tail_keep_order = True - # Define default sort order, if not already defined - # A sort order is required to always return the result in a deterministic way to the user. + # Define default sort order, if not already defined. + # A sort order is required to always return the result deterministically to the user. # Deterministic order is required for head/tail to work if query.is_simple_fulltext_search(): # Do not define any additional sort order for fulltext searches @@ -494,7 +494,7 @@ parts = [pt if pt.sort else evolve(pt, sort=default_sort) for pt in query.parts] query = evolve(query, parts=parts) - # If the last part is a navigation, we need to add sort which will ingest a new part. + # If the last part is a navigation, we need to add a sort which will ingest a new part. 
with_sort = query.set_sort(*default_sort) if query.current_part.navigation else query section = ctx.env.get("section", PathRoot) # If this is an aggregate query, the default sort needs to be changed @@ -534,7 +534,7 @@ def rewrite_command_line(cmds: List[ExecutableCommand], ctx: CLIContext) -> List Rules: - add the list command if no output format is defined - add a format to write commands if no output format is defined - - report benchmark run will be formatted as benchmark result automatically + - report benchmark run will be formatted as a benchmark result automatically """ if ctx.env.get("no_rewrite") or len(cmds) == 0: return cmds diff --git a/fixcore/fixcore/cli/command.py b/fixcore/fixcore/cli/command.py index b582cba90d..2645658266 100644 --- a/fixcore/fixcore/cli/command.py +++ b/fixcore/fixcore/cli/command.py @@ -29,7 +29,6 @@ Optional, Any, AsyncIterator, - Iterable, Callable, Awaitable, cast, @@ -46,9 +45,6 @@ import yaml from aiofiles.tempfile import TemporaryDirectory from aiohttp import ClientTimeout, JsonPayload, BasicAuth -from aiostream import stream, pipe -from aiostream.aiter_utils import is_async_iterable -from aiostream.core import Stream from attr import evolve, frozen from attrs import define, field from dateutil import parser as date_parser @@ -178,6 +174,7 @@ respond_cytoscape, ) from fixcore.worker_task_queue import WorkerTask, WorkerTaskName +from fixlib.asynchronous.stream import Stream from fixlib.core import CLIEnvelope from fixlib.durations import parse_duration from fixlib.parse_util import ( @@ -946,14 +943,13 @@ def group(keys: tuple[Any, ...]) -> Json: return result async def aggregate_data(content: JsStream) -> AsyncIterator[JsonElement]: - async with content.stream() as in_stream: - for key, value in (await self.aggregate_in(in_stream, var_names, aggregate.group_func)).items(): - entry: Json = {"group": group(key)} - for fn_name, (fn_val, fn_count) in value.fn_values.items(): - if fn_by_name.get(fn_name) == "avg" and fn_val is not None and fn_count > 0: - fn_val = fn_val / fn_count # type: ignore - entry[fn_name] = fn_val - yield entry + for key, value in (await self.aggregate_in(content, var_names, aggregate.group_func)).items(): + entry: Json = {"group": group(key)} + for fn_name, (fn_val, fn_count) in value.fn_values.items(): + if fn_by_name.get(fn_name) == "avg" and fn_val is not None and fn_count > 0: + fn_val = fn_val / fn_count # type: ignore + entry[fn_name] = fn_val + yield entry # noinspection PyTypeChecker return CLIFlow(aggregate_data) @@ -1000,7 +996,7 @@ def info(self) -> str: def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwargs: Any) -> CLIAction: size = self.parse_size(arg) - return CLIFlow(lambda in_stream: in_stream | pipe.take(size)) + return CLIFlow(lambda in_stream: Stream(in_stream).take(size)) def args_info(self) -> ArgsInfo: return [ArgInfo(expects_value=True, help_text="number of elements to take")] @@ -1054,7 +1050,7 @@ def args_info(self) -> ArgsInfo: def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwargs: Any) -> CLIAction: size = HeadCommand.parse_size(arg) - return CLIFlow(lambda in_stream: in_stream | pipe.takelast(size)) + return CLIFlow(lambda in_stream: Stream(in_stream).take_last(size)) class CountCommand(SearchCLIPart): @@ -1145,9 +1141,8 @@ def inc_identity(_: Any) -> None: fn = inc_prop if arg else inc_identity async def count_in_stream(content: JsStream) -> AsyncIterator[JsonElement]: - async with content.stream() as in_stream: - async for element in 
in_stream: - fn(element) + async for element in content: + fn(element) for key, value in sorted(counter.items(), key=lambda x: x[1]): yield f"{key}: {value}" @@ -1194,7 +1189,7 @@ def args_info(self) -> ArgsInfo: def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwargs: Any) -> CLISource: return CLISource.single( - lambda: stream.just(strip_quotes(arg if arg else "")), required_permissions={Permission.read} + lambda: Stream.just(strip_quotes(arg if arg else "")), required_permissions={Permission.read} ) @@ -1256,7 +1251,7 @@ def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwa else: raise AttributeError(f"json does not understand {arg}.") return CLISource.with_count( - lambda: stream.iterate(elements), len(elements), required_permissions={Permission.read} + lambda: Stream.iterate(elements), len(elements), required_permissions={Permission.read} ) @@ -1339,19 +1334,17 @@ def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwa async def to_count(in_stream: JsStream) -> AsyncIterator[JsonElement]: null_value = 0 total = 0 - in_streamer = in_stream if isinstance(in_stream, Stream) else stream.iterate(in_stream) - async with in_streamer.stream() as streamer: - async for elem in streamer: - name = js_value_at(elem, name_path) - count = js_value_get(elem, count_path, 0) - if name is None: - null_value = count - else: - total += count - yield f"{name}: {count}" - tm, tu = (total, null_value) if arg else (null_value + total, 0) - yield f"total matched: {tm}" - yield f"total unmatched: {tu}" + async for elem in in_stream: + name = js_value_at(elem, name_path) + count = js_value_get(elem, count_path, 0) + if name is None: + null_value = count + else: + total += count + yield f"{name}: {count}" + tm, tu = (total, null_value) if arg else (null_value + total, 0) + yield f"total matched: {tm}" + yield f"total unmatched: {tu}" return CLIFlow(to_count) @@ -1550,7 +1543,7 @@ def args_info(self) -> ArgsInfo: return [] def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwargs: Any) -> CLISource: - return CLISource.with_count(lambda: stream.just(ctx.env), len(ctx.env), required_permissions={Permission.read}) + return CLISource.with_count(lambda: Stream.just(ctx.env), len(ctx.env), required_permissions={Permission.read}) class ChunkCommand(CLICommand): @@ -1599,7 +1592,10 @@ def args_info(self) -> ArgsInfo: def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwargs: Any) -> CLIFlow: size = int(arg) if arg else 100 - return CLIFlow(lambda in_stream: in_stream | pipe.chunks(size), required_permissions={Permission.read}) + return CLIFlow( + lambda in_stream: Stream(in_stream).chunks(size), + required_permissions={Permission.read}, + ) class FlattenCommand(CLICommand): @@ -1646,13 +1642,7 @@ def args_info(self) -> ArgsInfo: return [] def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwargs: Any) -> CLIFlow: - def iterable(it: Any) -> bool: - return False if isinstance(it, str) else isinstance(it, Iterable) - - def iterate(it: Any) -> JsGen: - return stream.iterate(it) if is_async_iterable(it) or iterable(it) else stream.just(it) - - return CLIFlow(lambda i: i | pipe.flatmap(iterate), required_permissions={Permission.read}) # type: ignore + return CLIFlow(lambda i: Stream(i).flatten(), required_permissions={Permission.read}) class UniqCommand(CLICommand): @@ -1709,7 +1699,7 @@ def has_not_seen(item: Any) -> bool: visited.add(item) return True - return 
CLIFlow(lambda in_stream: stream.filter(in_stream, has_not_seen), required_permissions={Permission.read}) + return CLIFlow(lambda in_stream: Stream(in_stream).filter(has_not_seen), required_permissions={Permission.read}) class JqCommand(CLICommand, OutputTransformer): @@ -1809,7 +1799,7 @@ def process(in_json: JsonElement) -> JsonElement: result = out[0] if len(out) == 1 else out return cast(Json, result) - return CLIFlow(lambda i: i | pipe.map(process), required_permissions={Permission.read}) # type: ignore + return CLIFlow(lambda i: Stream(i).map(process), required_permissions={Permission.read}) class KindsCommand(CLICommand, PreserveOutputFormat): @@ -1962,16 +1952,16 @@ def show(k: ComplexKind) -> bool: result: JsonElement = ( kind_to_js(model, model[kind]) if kind in model else f"No kind with this name: {kind}" ) - return 1, stream.just(result) + return 1, Stream.just(result) elif args.property_path: no_section = Section.without_section(args.property_path) result = kind_to_js(model, model.kind_by_path(no_section)) if appears_in := property_defined_in(model, no_section): result["appears_in"] = appears_in - return 1, stream.just(result) + return 1, Stream.just(result) else: result = sorted([k.fqn for k in model.kinds.values() if isinstance(k, ComplexKind) and show(k)]) - return len(model.kinds), stream.iterate(result) + return len(model.kinds), Stream.iterate(result) return CLISource.only_count(source, required_permissions={Permission.read}) @@ -1986,7 +1976,8 @@ def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwa buffer_size = 1000 func = partial(self.set_desired, arg, ctx.graph_name, self.patch(arg, ctx)) return CLIFlow( - lambda i: i | pipe.chunks(buffer_size) | pipe.flatmap(func), required_permissions={Permission.write} + lambda i: Stream(i).chunks(buffer_size).flatmap(func, task_limit=10, ordered=False), + required_permissions={Permission.write}, ) async def set_desired( @@ -2124,7 +2115,8 @@ def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwa buffer_size = 1000 func = partial(self.set_metadata, ctx.graph_name, self.patch(arg, ctx)) return CLIFlow( - lambda i: i | pipe.chunks(buffer_size) | pipe.flatmap(func), required_permissions={Permission.write} + lambda i: Stream(i).chunks(buffer_size).flatmap(func, task_limit=10, ordered=False), + required_permissions={Permission.write}, ) async def set_metadata(self, graph_name: GraphName, patch: Json, items: List[Json]) -> AsyncIterator[JsonElement]: @@ -2331,9 +2323,8 @@ def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwa use = next(iter(format_to_use)) async def render_single(converter: ConvertFn, iss: JsStream) -> JsGen: - async with iss.stream() as streamer: - async for elem in converter(streamer): - yield elem + async for elem in converter(iss): + yield elem async def format_stream(in_stream: JsStream) -> JsGen: if use: @@ -2344,7 +2335,7 @@ async def format_stream(in_stream: JsStream) -> JsGen: else: raise ValueError(f"Unknown format: {use}") elif formatting_string: - return in_stream | pipe.map(ctx.formatter(arg)) if arg else in_stream # type: ignore + return in_stream.map(ctx.formatter(arg)) if arg else in_stream # type: ignore else: return in_stream @@ -2817,14 +2808,13 @@ def to_csv_string(lst: List[Any]) -> str: header_values = [prop.name for prop in props] yield to_csv_string(header_values) - async with in_stream.stream() as s: - async for elem in s: - if is_node(elem) or is_aggregate: - result = [] - for prop in props: - value = 
prop.value(elem) - result.append(value) - yield to_csv_string(result) + async for elem in in_stream: + if is_node(elem) or is_aggregate: + result = [] + for prop in props: + value = prop.value(elem) + result.append(value) + yield to_csv_string(result) async def json_table_stream(in_stream: JsStream, model: QueryModel) -> JsGen: def kind_of(path: str) -> Kind: @@ -2857,13 +2847,12 @@ def render_prop(elem: JsonElement) -> JsonElement: ], } # data columns - async with in_stream.stream() as s: - async for elem in s: - if isinstance(elem, dict) and (is_node(elem) or is_aggregate): - yield { - "id": None if is_aggregate else elem["id"], # aggregates have no id - "row": {prop.name: render_prop(prop.value(elem)) for prop in props}, - } + async for elem in in_stream: + if isinstance(elem, dict) and (is_node(elem) or is_aggregate): + yield { + "id": None if is_aggregate else elem["id"], # aggregates have no id + "row": {prop.name: render_prop(prop.value(elem)) for prop in props}, + } def markdown_stream(in_stream: JsStream) -> JsGen: chunk_size = 500 @@ -2922,12 +2911,11 @@ def to_str(elem: Any) -> str: # noinspection PyUnresolvedReferences markdown_chunks = ( - in_stream - | pipe.filter(lambda x: is_node(x) or is_aggregate) - | pipe.map(extract_values) # type: ignore - | pipe.chunks(chunk_size) - | pipe.enumerate() - | pipe.flatmap(generate_markdown) # type: ignore + in_stream.filter(lambda x: is_node(x) or is_aggregate) + .map(extract_values) + .chunks(chunk_size) + .enumerate() + .flatmap(generate_markdown) ) return markdown_chunks @@ -2943,12 +2931,9 @@ async def load_model() -> QueryModel: model = await self.dependencies.model_handler.load_model(ctx.graph_name) return QueryModel(ctx.query or Query.empty(), model, ctx.env) - return stream.call(load_model) | pipe.flatmap(partial(json_table_stream, in_stream)) # type: ignore + return Stream.call(load_model).flatmap(partial(json_table_stream, in_stream)) # type: ignore else: - return stream.map( - in_stream, - lambda elem: fmt_json(elem) if isinstance(elem, dict) else str(elem), # type: ignore - ) + return Stream(in_stream).map(lambda elem: fmt_json(elem) if isinstance(elem, dict) else str(elem)) return CLIFlow(fmt, produces=MediaType.String, required_permissions={Permission.read}) @@ -3208,7 +3193,7 @@ async def activate_deactivate_job(job_id: str, active: bool) -> AsyncIterator[Js async def running_jobs() -> Tuple[int, JsStream]: tasks = await self.dependencies.task_handler.running_tasks() - return len(tasks), stream.iterate( + return len(tasks), Stream.iterate( {"job": t.descriptor.id, "started_at": to_json(t.task_started_at), "task-id": t.id} for t in tasks if isinstance(t.descriptor, Job) @@ -3271,7 +3256,7 @@ async def send_to_queue(task_name: str, task_args: Dict[str, str], data: Json) - await self.dependencies.forked_tasks.put((result_task, f"WorkerTask {task_name}:{task.id}")) return f"Spawned WorkerTask {task_name}:{task.id}" - return in_stream | pipe.starmap(send_to_queue, ordered=False, task_limit=self.task_limit()) # type: ignore + return in_stream.starmap(send_to_queue, ordered=False, task_limit=self.task_limit()) def load_by_id_merged( self, @@ -3307,7 +3292,7 @@ async def load_element(items: List[JsonElement]) -> AsyncIterator[JsonElement]: async for a in crs: yield a - return stream.chunks(in_stream, 1000) | pipe.flatmap(load_element) # type: ignore + return in_stream.chunks(1000).flatmap(load_element, task_limit=10) async def no_update(self, _: WorkerTask, future_result: Future[Json]) -> Json: return await future_result @@ 
-3448,17 +3433,17 @@ def setup_stream(in_stream: JsStream) -> JsStream: def with_dependencies(model: Model) -> JsStream: load = self.load_by_id_merged(model, in_stream, variables, allowed_on_kind, **ctx.env) handler = self.update_node_in_graphdb(model, **ctx.env) if expect_node_result else self.no_update - return self.send_to_queue_stream(load | pipe.map(fn), handler, True) # type: ignore + return self.send_to_queue_stream(load.map(fn), handler, True) # type: ignore # dependencies are not resolved directly (no async function is allowed here) async def load_model() -> Model: return await self.dependencies.model_handler.load_model(ctx.graph_name) - return stream.call(load_model) | pipe.flatmap(with_dependencies) # type: ignore + return Stream.call(load_model).flatmap(with_dependencies) # type: ignore def setup_source() -> JsStream: arg = {"args": args_parts_unquoted_parser.parse(formatter({}))} - return self.send_to_queue_stream(stream.just((command_name, {}, arg)), self.no_update, True) + return self.send_to_queue_stream(Stream.just((command_name, {}, arg)), self.no_update, True) return ( CLISource.single(setup_source, required_permissions={Permission.write}) @@ -3575,13 +3560,13 @@ def setup_stream(in_stream: JsStream) -> JsStream: def with_dependencies(model: Model) -> JsStream: load = self.load_by_id_merged(model, in_stream, variables, **ctx.env) result_handler = self.update_node_in_graphdb(model, **ctx.env) - return self.send_to_queue_stream(load | pipe.map(fn), result_handler, not ns.nowait) # type: ignore + return self.send_to_queue_stream(load.map(fn), result_handler, not ns.nowait) # type: ignore async def load_model() -> Model: return await self.dependencies.model_handler.load_model(ctx.graph_name) # dependencies are not resolved directly (no async function is allowed here) - return stream.call(load_model) | pipe.flatmap(with_dependencies) # type: ignore + return Stream.call(load_model).flatmap(with_dependencies) # type: ignore return CLIFlow(setup_stream, required_permissions={Permission.write}) @@ -3594,7 +3579,7 @@ def file_command() -> JsStream: elif not os.path.exists(arg): raise AttributeError(f"file does not exist: {arg}!") else: - return stream.just(arg if arg else "") + return Stream.just(arg if arg else "") return CLISource.single(file_command, MediaType.FilePath, required_permissions={Permission.admin}) @@ -3618,7 +3603,7 @@ def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwa def upload_command() -> JsStream: if file_id in ctx.uploaded_files: file = ctx.uploaded_files[file_id] - return stream.just(f"Received file {file} of size {os.path.getsize(file)}") + return Stream.just(f"Received file {file} of size {os.path.getsize(file)}") else: raise AttributeError(f"file was not uploaded: {arg}!") @@ -3932,19 +3917,17 @@ async def write_result_to_file(ctx: CLIContext, in_stream: JsStream, file_name: async with TemporaryDirectory() as temp_dir: path = file_name if ctx.intern else os.path.join(temp_dir, uuid_str()) async with aiofiles.open(path, "w") as f: - async with in_stream.stream() as streamer: - async for out in streamer: - if isinstance(out, str): - await f.write(out + "\n") - else: - raise AttributeError("No output format is defined! Consider to use the format command.") + async for out in in_stream: + if isinstance(out, str): + await f.write(out + "\n") + else: + raise AttributeError("No output format is defined! 
Consider to use the format command.") yield FilePath.user_local(user=file_name, local=path).json() @staticmethod async def already_file_stream(in_stream: JsStream, file_name: str) -> AsyncIterator[JsonElement]: - async with in_stream.stream() as streamer: - async for out in streamer: - yield evolve(FilePath.from_path(out), user=Path(file_name)).json() + async for out in in_stream: + yield evolve(FilePath.from_path(out), user=Path(file_name)).json() def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwargs: Any) -> CLIAction: if arg is None: @@ -4040,7 +4023,7 @@ async def get_template(name: str) -> AsyncIterator[JsonElement]: async def list_templates() -> Tuple[int, Stream[str]]: templates = await self.dependencies.template_expander.list_templates() - return len(templates), stream.iterate(template_str(t) for t in templates) + return len(templates), Stream.iterate(template_str(t) for t in templates) async def put_template(name: str, template_query: str) -> AsyncIterator[str]: # try to render_console the template with dummy values and see if the search can be parsed @@ -4283,10 +4266,9 @@ async def perform_request(e: JsonElement) -> int: async def iterate_stream(in_stream: JsStream) -> AsyncIterator[JsonElement]: results: Dict[int, int] = defaultdict(lambda: 0) - async with in_stream.stream() as streamer: - async for elem in streamer: - status_code = await perform_request(elem) - results[status_code] += 1 + async for elem in in_stream: + status_code = await perform_request(elem) + results[status_code] += 1 summary = ", ".join(f"{count} requests with status {status}" for status, count in results.items()) if results: yield f"{summary} sent." @@ -4514,18 +4496,18 @@ def info(rt: RunningTask) -> JsonElement: **progress, } - return len(tasks), stream.iterate(info(t) for t in tasks if isinstance(t.descriptor, Workflow)) + return len(tasks), Stream.iterate(info(t) for t in tasks if isinstance(t.descriptor, Workflow)) async def show_log(wf_id: str) -> Tuple[int, JsStream]: rtd = await self.dependencies.db_access.running_task_db.get(wf_id) if rtd: messages = [msg.info() for msg in rtd.info_messages()] if messages: - return len(messages), stream.iterate(messages) + return len(messages), Stream.iterate(messages) else: - return 0, stream.just("No error messages for this run.") + return 0, Stream.just("No error messages for this run.") else: - return 0, stream.just(f"No workflow task with this id: {wf_id}") + return 0, Stream.just(f"No workflow task with this id: {wf_id}") def running_task_data(rtd: RunningTaskData) -> Json: result = { @@ -4539,7 +4521,7 @@ def running_task_data(rtd: RunningTaskData) -> Json: async def history_aggregation() -> JsStream: info = await self.dependencies.db_access.running_task_db.aggregated_history() - return stream.just(info) + return Stream.just(info) async def history_of(history_args: List[str]) -> Tuple[int, JsStream]: parser = NoExitArgumentParser() @@ -4558,7 +4540,7 @@ async def history_of(history_args: List[str]) -> Tuple[int, JsStream]: ) cursor: AsyncCursor = context.cursor try: - return cursor.count() or 0, stream.map(cursor, running_task_data) # type: ignore + return cursor.count() or 0, Stream(cursor).map(running_task_data) finally: cursor.close() @@ -4591,7 +4573,7 @@ async def stop_workflow(task_id: TaskId) -> AsyncIterator[str]: return CLISource.only_count(list_workflows, required_permissions={Permission.read}) else: return CLISource.single( - lambda: stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} + 
lambda: Stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} ) @@ -4763,7 +4745,7 @@ async def update_config(cfg_id: ConfigId) -> AsyncIterator[str]: async def list_configs() -> Tuple[int, JsStream]: ids = [i async for i in self.dependencies.config_handler.list_config_ids()] - return len(ids), stream.iterate(ids) + return len(ids), Stream.iterate(ids) args = re.split("\\s+", arg, maxsplit=2) if arg else [] if arg and len(args) == 2 and (args[0] == "show" or args[0] == "get"): @@ -4800,7 +4782,7 @@ async def list_configs() -> Tuple[int, JsStream]: return CLISource.only_count(list_configs, required_permissions={Permission.read}) else: return CLISource.single( - lambda: stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} + lambda: Stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} ) @@ -4889,7 +4871,7 @@ async def welcome() -> str: res = ctx.render_console(grid) return res - return CLISource.single(lambda: stream.just(welcome()), required_permissions={Permission.read}) # type: ignore + return CLISource.single(lambda: Stream.just(welcome()), required_permissions={Permission.read}) class TipOfTheDayCommand(CLICommand): @@ -4926,7 +4908,7 @@ async def totd() -> str: res = ctx.render_console(info) return res - return CLISource.single(lambda: stream.just(totd()), required_permissions={Permission.read}) # type: ignore + return CLISource.single(lambda: Stream.just(totd()), required_permissions={Permission.read}) class CertificateCommand(CLICommand): @@ -5004,7 +4986,7 @@ async def create_certificate( ) else: return CLISource.single( - lambda: stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} + lambda: Stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} ) @@ -5435,9 +5417,8 @@ async def app_run( raise ValueError(f"Config {config} not found.") async def stream_to_iterator() -> AsyncIterator[JsonElement]: - async with in_stream.stream() as streamer: - async for item in streamer: - yield item + async for item in in_stream: + yield item stdin = stream_to_iterator() if dry_run: @@ -5531,7 +5512,7 @@ async def stream_to_iterator() -> AsyncIterator[JsonElement]: return CLISource.no_count( partial( app_run, - in_stream=stream.empty(), + in_stream=Stream.empty(), app_name=InfraAppName(parsed.app_name), dry_run=parsed.dry_run, config=parsed.config, @@ -5552,7 +5533,7 @@ async def stream_to_iterator() -> AsyncIterator[JsonElement]: ) else: return CLISource.single( - lambda: stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} + lambda: Stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} ) @@ -5757,7 +5738,7 @@ def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwa return CLISource.no_count(partial(self.show_user, args[1]), required_permissions={Permission.read}) else: return CLISource.single( - lambda: stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} + lambda: Stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} ) @@ -5950,7 +5931,7 @@ async def lines_iterator() -> AsyncIterator[str]: else: return CLISource.single( - lambda: stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} + lambda: Stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} ) @@ -6102,7 +6083,7 @@ async def sync_database_result(p: Namespace, maybe_stream: Optional[JsStream]) - async with await graph_db.search_graph_gen( QueryModel(query, fix_model, 
ctx.env), timeout=timedelta(weeks=200000) ) as cursor: - await sync_fn(query=query, in_stream=stream.iterate(cursor)) + await sync_fn(query=query, in_stream=Stream.iterate(cursor)) if file_output is not None: assert p.database, "No database name provided. Use the --database argument." @@ -6160,11 +6141,10 @@ def key_fn(node: Json) -> Union[str, Tuple[str, str]]: kind_by_id[node["id"]] = node["reported"]["kind"] return cast(str, node["reported"]["kind"]) - async with in_stream.stream() as streamer: - batched = BatchStream(streamer, key_fn, engine_config.batch_size, engine_config.batch_size * 10) - await update_sql( - engine_config, rcm, batched, edges, swap_temp_tables=True, drop_existing_tables=drop_existing_tables - ) + batched = BatchStream(in_stream, key_fn, engine_config.batch_size, engine_config.batch_size * 10) + await update_sql( + engine_config, rcm, batched, edges, swap_temp_tables=True, drop_existing_tables=drop_existing_tables + ) args = arg.split(maxsplit=1) if arg else [] if len(args) == 2 and args[0] == "sync": @@ -6339,16 +6319,16 @@ def parse_duration_or_int(s: str) -> Union[int, timedelta]: async def list_ts() -> Tuple[int, JsGen]: ts = await self.dependencies.db_access.time_series_db.list_time_series() - return len(ts), stream.iterate([to_js(a) for a in ts]) + return len(ts), Stream.iterate([to_js(a) for a in ts]) async def downsample() -> Tuple[int, JsGen]: ts = await self.dependencies.db_access.time_series_db.downsample() if isinstance(ts, str): - return 1, stream.just(ts) + return 1, Stream.just(ts) elif ts: - return len(ts), stream.iterate([{k: v} for k, v in ts.items()]) + return len(ts), Stream.iterate([{k: v} for k, v in ts.items()]) else: - return 1, stream.just("No time series to downsample.") + return 1, Stream.just("No time series to downsample.") args = re.split("\\s+", arg, maxsplit=1) if arg else [] if arg and len(args) == 2 and args[0] == "snapshot": @@ -6363,7 +6343,7 @@ async def downsample() -> Tuple[int, JsGen]: return CLISource.only_count(downsample, required_permissions={Permission.read}) else: return CLISource.single( - lambda: stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} + lambda: Stream.just(self.rendered_help(ctx)), required_permissions={Permission.read} ) @@ -6459,29 +6439,28 @@ def walk_element(el: JsonElement) -> Iterator[Tuple[str, PotentialSecret]]: if r.startswith("True"): yield el, secret - async def detect_secrets_in(content: JsStream) -> JsGen: + async def detect_secrets_in(in_stream: JsStream) -> JsGen: self.configure_detect() # make sure all plugins are loaded - async with content.stream() as in_stream: - async for element in in_stream: - paths = [p for pl in parsed.path for p in pl] - paths = paths or [PropertyPath.from_list([ctx.section]) if is_node(element) else EmptyPath] - found_secrets = False - for path in paths: - if to_check_js := path.value_in(element): - for secret_string, possible_secret in walk_element(to_check_js): - found_secrets = True - if isinstance(element, dict): - element["info"] = { - "secret_detected": True, - "potential_secret": secret_string, - "secret_type": possible_secret.type, - } - yield element - break - if found_secrets: - break # no need to check other paths - if not found_secrets and not parsed.with_secrets: - yield element + async for element in in_stream: + paths = [p for pl in parsed.path for p in pl] + paths = paths or [PropertyPath.from_list([ctx.section]) if is_node(element) else EmptyPath] + found_secrets = False + for path in paths: + if to_check_js := 
path.value_in(element): + for secret_string, possible_secret in walk_element(to_check_js): + found_secrets = True + if isinstance(element, dict): + element["info"] = { + "secret_detected": True, + "potential_secret": secret_string, + "secret_type": possible_secret.type, + } + yield element + break + if found_secrets: + break # no need to check other paths + if not found_secrets and not parsed.with_secrets: + yield element return CLIFlow(detect_secrets_in) @@ -6535,7 +6514,7 @@ async def load_model() -> Model: def setup_stream(in_stream: JsStream) -> JsStream: def with_dependencies(model: Model) -> JsStream: - async def process_element(el: JsonElement) -> JsonElement: + def process_element(el: JsonElement) -> JsonElement: if ( is_node(el) and (fqn := value_in_path(el, NodePath.reported_kind)) @@ -6546,9 +6525,9 @@ async def process_element(el: JsonElement) -> JsonElement: set_value_in_path(refinement.value, refinement.path, el) # type: ignore return el - return in_stream | pipe.map(process_element) # type: ignore + return in_stream.map(process_element) - return stream.call(load_model) | pipe.flatmap(with_dependencies) # type: ignore + return Stream.call(load_model).flatmap(with_dependencies) # type: ignore return CLIFlow(setup_stream, required_permissions={Permission.read}) @@ -6590,7 +6569,7 @@ async def delete_node(node_id: NodeId, keep_history: bool) -> AsyncIterator[str] fn=partial(delete_node, node_id=parsed.node_id, keep_history=parsed.keep_history), required_permissions={Permission.write}, ) - return CLISource.single(lambda: stream.just(self.rendered_help(ctx)), required_permissions={Permission.read}) + return CLISource.single(lambda: Stream.just(self.rendered_help(ctx)), required_permissions={Permission.read}) def all_commands(d: TenantDependencies) -> List[CLICommand]: diff --git a/fixcore/fixcore/cli/model.py b/fixcore/fixcore/cli/model.py index f4afcde957..ec2cac1e45 100644 --- a/fixcore/fixcore/cli/model.py +++ b/fixcore/fixcore/cli/model.py @@ -26,8 +26,6 @@ TYPE_CHECKING, ) -from aiostream import stream -from aiostream.core import Stream from attrs import define, field from parsy import test_char, string from rich.jupyter import JupyterMixin @@ -42,6 +40,7 @@ from fixcore.query.template_expander import render_template from fixcore.types import Json, JsonElement from fixcore.util import AccessJson, uuid_str, from_utc, utc, utc_str +from fixlib.asynchronous.stream import Stream from fixlib.parse_util import l_curly_dp, r_curly_dp from fixlib.utils import get_local_tzinfo @@ -236,7 +235,7 @@ def __init__( @staticmethod def make_stream(in_stream: JsGen) -> JsStream: - return in_stream if isinstance(in_stream, Stream) else stream.iterate(in_stream) + return in_stream if isinstance(in_stream, Stream) else Stream.iterate(in_stream) @define @@ -316,7 +315,7 @@ def single( @staticmethod def empty() -> CLISource: - return CLISource.with_count(stream.empty, 0) + return CLISource.with_count(Stream.empty, 0) class CLIFlow(CLIAction): @@ -739,7 +738,7 @@ async def execute(self) -> Tuple[CLISourceContext, JsStream]: flow = await flow_action.flow(flow) return context, flow else: - return CLISourceContext(count=0), stream.empty() + return CLISourceContext(count=0), Stream.empty() class CLI(ABC): diff --git a/fixcore/fixcore/db/graphdb.py b/fixcore/fixcore/db/graphdb.py index 14b997c8a5..0424414776 100644 --- a/fixcore/fixcore/db/graphdb.py +++ b/fixcore/fixcore/db/graphdb.py @@ -23,7 +23,6 @@ Union, ) -from aiostream import stream, pipe from arango import AnalyzerGetError from 
arango.collection import VertexCollection, StandardCollection, EdgeCollection from arango.graph import Graph @@ -67,6 +66,7 @@ set_value_in_path, if_set, ) +from fixlib.asynchronous.stream import Stream log = logging.getLogger(__name__) @@ -675,9 +675,8 @@ async def move_security_temp_to_proper() -> None: try: # stream updates to the temp collection - async with (stream.iterate(iterator) | pipe.chunks(1000)).stream() as streamer: - async for part in streamer: - await update_chunk(dict(part)) + async for part in Stream.iterate(iterator).chunks(1000): + await update_chunk(dict(part)) # move temp collection to proper and history collection await move_security_temp_to_proper() finally: diff --git a/fixcore/fixcore/infra_apps/local_runtime.py b/fixcore/fixcore/infra_apps/local_runtime.py index c2d2376291..4ca160e8a1 100644 --- a/fixcore/fixcore/infra_apps/local_runtime.py +++ b/fixcore/fixcore/infra_apps/local_runtime.py @@ -3,7 +3,6 @@ from pydoc import locate from typing import List, AsyncIterator, Type, Optional, Any -from aiostream import stream, pipe from jinja2 import Environment from fixcore.cli import NoExitArgumentParser, JsStream, JsGen @@ -14,6 +13,7 @@ from fixcore.infra_apps.runtime import Runtime from fixcore.service import Service from fixcore.types import Json, JsonElement +from fixlib.asynchronous.stream import Stream from fixlib.asynchronous.utils import async_lines from fixlib.durations import parse_optional_duration @@ -46,9 +46,8 @@ async def execute( Runtime implementation that runs the app locally. """ async for line in self.generate_template(graph, manifest, config, stdin, argv): - async with (await self._interpret_line(line, ctx)).stream() as streamer: - async for item in streamer: - yield item + async for item in await self._interpret_line(line, ctx): + yield item async def generate_template( self, @@ -117,4 +116,4 @@ async def _interpret_line(self, line: str, ctx: CLIContext) -> JsStream: total_nr_outputs = total_nr_outputs + (src_ctx.count or 0) command_streams.append(command_output_stream) - return stream.iterate(command_streams) | pipe.concat(task_limit=1) + return Stream.iterate(command_streams).concat() # type: ignore diff --git a/fixcore/fixcore/model/db_updater.py b/fixcore/fixcore/model/db_updater.py index cab00efb37..9a29c95573 100644 --- a/fixcore/fixcore/model/db_updater.py +++ b/fixcore/fixcore/model/db_updater.py @@ -13,11 +13,9 @@ from multiprocessing import Process, Queue from pathlib import Path from queue import Empty -from typing import Optional, Union, Any, Generator, List, AsyncIterator, Dict +from typing import Optional, Union, Any, List, AsyncIterator, Dict import aiofiles -from aiostream import stream, pipe -from aiostream.core import Stream from attrs import define from fixcore.analytics import AnalyticsEventSender, InMemoryEventSender, AnalyticsEvent @@ -36,6 +34,7 @@ from fixcore.system_start import db_access, setup_process, reset_process_start_method from fixcore.types import Json from fixcore.util import utc, uuid_str, shutdown_process +from fixlib.asynchronous.stream import Stream log = logging.getLogger(__name__) @@ -56,9 +55,9 @@ class ReadFile(ProcessAction): path: Path task_id: Optional[str] - def jsons(self) -> Generator[Json, Any, None]: - with open(self.path, "r", encoding="utf-8") as f: - for line in f: + async def jsons(self) -> AsyncIterator[Json]: + async with aiofiles.open(self.path, "r", encoding="utf-8") as f: + async for line in f: if line.strip(): yield json.loads(line) @@ -75,8 +74,8 @@ class ReadElement(ProcessAction): 
elements: List[Union[bytes, Json]] task_id: Optional[str] - def jsons(self) -> Generator[Json, Any, None]: - return (e if isinstance(e, dict) else json.loads(e) for e in self.elements) + def jsons(self) -> AsyncIterator[Json]: + return Stream.iterate(self.elements).map(lambda e: e if isinstance(e, dict) else json.loads(e)) @define @@ -125,15 +124,15 @@ def get_value(self) -> GraphUpdate: class DbUpdaterProcess(Process): """ - This update class implements Process and is supposed to run as separate process. + This update class implements Process and is supposed to run as a separate process. Note: default starting method is supposed to be "spawn". This process has 2 queues to read input from and write output to. - All elements in either queues are of type ProcessAction. + All elements in both queues are of type ProcessAction. The parent process should stream the raw commands of graph to this process via ReadElement objects. Once the MergeGraph action is received, the graph gets imported. - From here the parent expects result messages from the child. + From here, the parent expects result messages from the child. All events happen in the child are forwarded to the parent via EmitEvent. Once the graph update is done, a result is send. The result is either an exception in case of failure or a graph update in success case. @@ -156,8 +155,8 @@ def __init__( def next_action(self) -> ProcessAction: try: - # graph is read into memory. If the sender does not send data in a given amount of time, - # we raise an exception and abort the update. + # The graph is read into memory. + # If the sender does not send data in a given amount of time, we raise an exception and abort the update. return self.read_queue.get(True, 90) except Empty as ex: raise ImportAborted("Merge process did not receive any data for more than 90 seconds. 
Abort.") from ex @@ -168,12 +167,12 @@ async def merge_graph(self, db: DbAccess) -> GraphUpdate: # type: ignore builder = GraphBuilder(model, self.change_id) nxt = self.next_action() if isinstance(nxt, ReadFile): - for element in nxt.jsons(): + async for element in nxt.jsons(): builder.add_from_json(element) nxt = self.next_action() elif isinstance(nxt, ReadElement): while isinstance(nxt, ReadElement): - for element in nxt.jsons(): + async for element in nxt.jsons(): builder.add_from_json(element) log.debug(f"Read {int(BatchSize / 1000)}K elements in process") nxt = self.next_action() @@ -276,16 +275,11 @@ async def __process_item(self, item: GraphUpdateTask) -> Union[GraphUpdate, Exce async def start(self) -> None: async def wait_for_update() -> None: log.info("Start waiting for graph updates") - fl = ( - stream.call(self.update_queue.get) # type: ignore - | pipe.cycle() - | pipe.map(self.__process_item, task_limit=self.config.graph.parallel_imports) # type: ignore - ) + fl = Stream.for_ever(self.update_queue.get).map(self.__process_item, task_limit=self.config.graph.parallel_imports) # type: ignore # noqa with suppress(CancelledError): - async with fl.stream() as streamer: - async for update in streamer: - if isinstance(update, GraphUpdate): - log.info(f"Finished spawned graph merge: {update}") + async for update in fl: + if isinstance(update, GraphUpdate): + log.info(f"Finished spawned graph merge: {update}") self.handler_task = asyncio.create_task(wait_for_update()) @@ -373,19 +367,17 @@ async def read_forever() -> GraphUpdate: task: Optional[Task[GraphUpdate]] = None result: Optional[GraphUpdate] = None try: - reset_process_start_method() # other libraries might have tampered the value in the mean time + reset_process_start_method() # other libraries might have tampered the value in the meantime updater.start() task = read_results() # concurrently read result queue # Either send a file or stream the content directly if isinstance(content, Path): await send_to_child(ReadFile(content, task_id)) else: - chunked: Stream[List[Union[bytes, Json]]] = stream.chunks(content, BatchSize) # type: ignore - async with chunked.stream() as streamer: - async for lines in streamer: - if not await send_to_child(ReadElement(lines, task_id)): - # in case the child is dead, we should stop - break + async for lines in Stream.iterate(content).chunks(BatchSize): + if not await send_to_child(ReadElement(lines, task_id)): + # in case the child is dead, we should stop + break await send_to_child(MergeGraph(db.name, change_id, maybe_batch is not None, task_id)) result = await task # wait for final result await self.model_handler.load_model(db.name, force=True) # reload model to get the latest changes diff --git a/fixcore/fixcore/report/benchmark_renderer.py b/fixcore/fixcore/report/benchmark_renderer.py index af40975da7..babd149d5d 100644 --- a/fixcore/fixcore/report/benchmark_renderer.py +++ b/fixcore/fixcore/report/benchmark_renderer.py @@ -1,6 +1,5 @@ from typing import AsyncGenerator, List, AsyncIterable -from aiostream import stream from networkx import DiGraph from rich._emoji_codes import EMOJI @@ -91,27 +90,26 @@ def render_check_result(check_result: CheckResult, account: str) -> str: async def respond_benchmark_result(gen: AsyncIterable[JsonElement]) -> AsyncGenerator[str, None]: - # step 1: read graph + # step 1: read graph graph = DiGraph() - async with stream.iterate(gen).stream() as streamer: - async for item in streamer: - if isinstance(item, dict): - type_name = item.get("type") - if type_name == 
"node": - uid = value_in_path(item, NodePath.node_id) - reported = value_in_path(item, NodePath.reported) - kind = value_in_path(item, NodePath.reported_kind) - if uid and reported and kind and (reader := kind_reader.get(kind)): - graph.add_node(uid, data=reader(item)) - elif type_name == "edge": - from_node = value_in_path(item, NodePath.from_node) - to_node = value_in_path(item, NodePath.to_node) - if from_node and to_node: - graph.add_edge(from_node, to_node) - else: - raise AttributeError(f"Expect json object but got: {type(item)}: {item}") + async for item in gen: + if isinstance(item, dict): + type_name = item.get("type") + if type_name == "node": + uid = value_in_path(item, NodePath.node_id) + reported = value_in_path(item, NodePath.reported) + kind = value_in_path(item, NodePath.reported_kind) + if uid and reported and kind and (reader := kind_reader.get(kind)): + graph.add_node(uid, data=reader(item)) + elif type_name == "edge": + from_node = value_in_path(item, NodePath.from_node) + to_node = value_in_path(item, NodePath.to_node) + if from_node and to_node: + graph.add_edge(from_node, to_node) else: raise AttributeError(f"Expect json object but got: {type(item)}: {item}") + else: + raise AttributeError(f"Expect json object but got: {type(item)}: {item}") # step 2: read benchmark result from graph def traverse(node_id: str, collection: CheckCollectionResult) -> None: diff --git a/fixcore/fixcore/report/inspector_service.py b/fixcore/fixcore/report/inspector_service.py index ceb9c8f5b4..f3b70cc77e 100644 --- a/fixcore/fixcore/report/inspector_service.py +++ b/fixcore/fixcore/report/inspector_service.py @@ -3,8 +3,6 @@ from functools import lru_cache from typing import Optional, List, Dict, Tuple, Callable, AsyncIterator, cast, Set -from aiostream import stream, pipe -from aiostream.core import Stream from attr import define from fixcore.analytics import CoreEvent @@ -40,6 +38,7 @@ from fixcore.service import Service from fixcore.types import Json from fixcore.util import value_in_path, uuid_str, value_in_path_get +from fixlib.asynchronous.stream import Stream from fixlib.json_bender import Bender, S, bend log = logging.getLogger(__name__) @@ -380,7 +379,7 @@ async def list_failing_resources( async def __list_failing_resources( self, graph: GraphName, model: Model, inspection: ReportCheck, context: CheckContext ) -> AsyncIterator[Json]: - # final environment: defaults are coming from the check and are eventually overriden in the config + # final environment: defaults are coming from the check and are eventually overridden in the config env = inspection.environment(context.override_values()) account_id_prop = "ancestors.account.reported.id" ignore_prop = "metadata.security_ignore" @@ -484,7 +483,7 @@ def to_result(cc: CheckCollection) -> CheckCollectionResult: node_id=next_node_id(), ) - async def __perform_checks( # type: ignore + async def __perform_checks( self, graph: GraphName, checks: List[ReportCheck], context: CheckContext ) -> Dict[str, SingleCheckResult]: # load model @@ -493,11 +492,10 @@ async def __perform_checks( # type: ignore async def perform_single(check: ReportCheck) -> Tuple[str, SingleCheckResult]: return check.id, await self.__perform_check(graph, model, check, context) - check_results: Stream[Tuple[str, SingleCheckResult]] = stream.iterate(checks) | pipe.map( - perform_single, ordered=False, task_limit=context.parallel_checks # type: ignore + check_results: Stream[Tuple[str, SingleCheckResult]] = Stream.iterate(checks).map( + perform_single, ordered=False, 
task_limit=context.parallel_checks ) - async with check_results.stream() as streamer: - return {key: value async for key, value in streamer} + return {key: value async for key, value in check_results} async def __perform_check( self, graph: GraphName, model: Model, inspection: ReportCheck, context: CheckContext diff --git a/fixcore/fixcore/task/task_handler.py b/fixcore/fixcore/task/task_handler.py index d8677fee54..4d7073afe1 100644 --- a/fixcore/fixcore/task/task_handler.py +++ b/fixcore/fixcore/task/task_handler.py @@ -8,7 +8,6 @@ from copy import copy from datetime import timedelta from typing import Optional, Any, Callable, Union, Sequence, Dict, List, Tuple -from aiostream import stream from attrs import evolve from fixcore.analytics import AnalyticsEventSender, CoreEvent @@ -57,6 +56,7 @@ ) from fixcore.util import first, Periodic, group_by, utc_str, utc, partition_by from fixcore.types import Json +from fixlib.asynchronous.stream import Stream log = logging.getLogger(__name__) @@ -89,7 +89,7 @@ def __init__( # note: the waiting queue is kept in memory and lost when the service is restarted. self.start_when_done: Dict[str, TaskDescription] = {} - # Step1: define all workflows and jobs in code: later it will be persisted and read from database + # Step 1: define all workflows and jobs in code: later they will be persisted and read from the database self.task_descriptions: Sequence[TaskDescription] = [*self.known_workflows(config), *self.known_jobs()] self.tasks: Dict[TaskId, RunningTask] = {} self.message_bus_watcher: Optional[Task[None]] = None @@ -496,7 +496,7 @@ async def execute_commands() -> None: results[command] = None elif isinstance(command, ExecuteOnCLI): ctx = evolve(self.cli_context, env={**command.env, **wi.descriptor.environment}) - result = await self.cli.execute_cli_command(command.command, stream.list, ctx) # type: ignore + result = await self.cli.execute_cli_command(command.command, Stream.as_list, ctx) results[command] = result else: raise AttributeError(f"Does not understand this command: {wi.descriptor.name}: {command}") diff --git a/fixcore/fixcore/web/api.py b/fixcore/fixcore/web/api.py index e7c830160b..06b95ccd77 100644 --- a/fixcore/fixcore/web/api.py +++ b/fixcore/fixcore/web/api.py @@ -54,7 +54,6 @@ from aiohttp.web_fileresponse import FileResponse from aiohttp.web_response import json_response from aiohttp_swagger3 import SwaggerFile, SwaggerUiSettings -from aiostream import stream from attrs import evolve from dateutil import parser as date_parser from multidict import MultiDict @@ -134,6 +133,7 @@ WorkerTaskResult, WorkerTaskInProgress, ) +from fixlib.asynchronous.stream import Stream from fixlib.asynchronous.web.ws_handler import accept_websocket, clean_ws_handler from fixlib.durations import parse_duration from fixlib.jwt import encode_jwt @@ -664,7 +664,7 @@ async def perform_benchmark_on_checks(self, request: Request, deps: TenantDependencies) -> StreamResponse: ) return await single_result(request, to_js(result)) - async def perform_benchmark(self, request: Request, deps: TenantDependencies) -> StreamResponse: # type: ignore + async def perform_benchmark(self, request: Request, deps: TenantDependencies) -> StreamResponse: benchmark = request.match_info["benchmark"] graph = GraphName(request.match_info["graph_id"]) acc = request.query.get("accounts") @@ -677,8 +677,8 @@ else: raise ValueError(f"Unknown action {action}. 
One of run or load is expected.") result_graph = results[benchmark].to_graph() - async with stream.iterate(result_graph).stream() as streamer: - return await self.stream_response_from_gen(request, streamer, count=len(result_graph)) + stream = Stream.iterate(result_graph) + return await self.stream_response_from_gen(request, stream, count=len(result_graph)) async def inspection_checks(self, request: Request, deps: TenantDependencies) -> StreamResponse: provider = request.query.get("provider") @@ -1433,7 +1433,7 @@ async def write_files(mpr: MultipartReader, tmp_dir: str) -> Dict[str, str]: if temp_dir: shutil.rmtree(temp_dir) - async def execute_parsed( # type: ignore + async def execute_parsed( self, request: Request, command: str, parsed: List[ParsedCommandLine], ctx: CLIContext ) -> StreamResponse: # what is the accepted content type @@ -1455,43 +1455,41 @@ async def execute_parsed( # type: ignore first_result = parsed[0] src_ctx, generator = await first_result.execute() # flat the results from 0 or 1 - async with generator.stream() as streamer: - gen = await force_gen(streamer) - if first_result.produces.text: - text_gen = ctx.text_generator(first_result, gen) - return await self.stream_response_from_gen( - request, - text_gen, - count=src_ctx.count, - total_count=src_ctx.total_count, - query_stats=src_ctx.stats, - additional_header=first_result.envelope, - ) - elif first_result.produces.file_path: - await mp_response.prepare(request) - await Api.multi_file_response(first_result, gen, boundary, mp_response) - await Api.close_multi_part_response(mp_response, boundary) - return mp_response - else: - raise AttributeError(f"Can not handle type: {first_result.produces}") + gen = await force_gen(generator) + if first_result.produces.text: + text_gen = ctx.text_generator(first_result, gen) + return await self.stream_response_from_gen( + request, + text_gen, + count=src_ctx.count, + total_count=src_ctx.total_count, + query_stats=src_ctx.stats, + additional_header=first_result.envelope, + ) + elif first_result.produces.file_path: + await mp_response.prepare(request) + await Api.multi_file_response(first_result, gen, boundary, mp_response) + await Api.close_multi_part_response(mp_response, boundary) + return mp_response + else: + raise AttributeError(f"Can not handle type: {first_result.produces}") elif len(parsed) > 1: await mp_response.prepare(request) for single in parsed: _, generator = await single.execute() - async with generator.stream() as streamer: - gen = await force_gen(streamer) - if single.produces.text: - with MultipartWriter(repr(single.produces), boundary) as mp: - text_gen = ctx.text_generator(single, gen) - content_type, result_stream = await result_binary_gen(request, text_gen) - mp.append_payload( - AsyncIterablePayload(result_stream, content_type=content_type, headers=single.envelope) - ) - await mp.write(mp_response, close_boundary=False) - elif single.produces.file_path: - await Api.multi_file_response(single, gen, boundary, mp_response) - else: - raise AttributeError(f"Can not handle type: {single.produces}") + gen = await force_gen(generator) + if single.produces.text: + with MultipartWriter(repr(single.produces), boundary) as mp: + text_gen = ctx.text_generator(single, gen) + content_type, result_stream = await result_binary_gen(request, text_gen) + mp.append_payload( + AsyncIterablePayload(result_stream, content_type=content_type, headers=single.envelope) + ) + await mp.write(mp_response, close_boundary=False) + elif single.produces.file_path: + await 
Api.multi_file_response(single, gen, boundary, mp_response) + else: + raise AttributeError(f"Can not handle type: {single.produces}") await Api.close_multi_part_response(mp_response, boundary) return mp_response else: diff --git a/fixcore/pyproject.toml b/fixcore/pyproject.toml index 5c85874c80..2ef84831f1 100644 --- a/fixcore/pyproject.toml +++ b/fixcore/pyproject.toml @@ -35,7 +35,6 @@ dependencies = [ "aiohttp-jinja2", "aiohttp-swagger3", "aiohttp[speedups]", - "aiostream", "cryptography", "deepdiff", "detect_secrets", diff --git a/fixcore/tests/fixcore/cli/command_test.py b/fixcore/tests/fixcore/cli/command_test.py index cf85a7ec38..514d0fc2c3 100644 --- a/fixcore/tests/fixcore/cli/command_test.py +++ b/fixcore/tests/fixcore/cli/command_test.py @@ -13,9 +13,9 @@ from _pytest.logging import LogCaptureFixture from aiohttp import ClientTimeout from aiohttp.web import Request -from aiostream import stream, pipe from attrs import evolve from pytest import fixture + from fixcore import version from fixcore.cli import is_node, JsStream, list_sink from fixcore.cli.cli import CLIService @@ -48,6 +48,7 @@ from fixcore.user import UsersConfigId from fixcore.util import AccessJson, utc_str, utc from fixcore.worker_task_queue import WorkerTask +from fixlib.asynchronous.stream import Stream from tests.fixcore.util_test import not_in_path @@ -279,7 +280,7 @@ async def test_list_sink(cli: CLI, dependencies: TenantDependencies) -> None: async def test_flat_sink(cli: CLI) -> None: parsed = await cli.evaluate_cli_command("json [1,2,3] | dump; json [4,5,6] | dump; json [7,8,9] | dump") expected = [1, 2, 3, 4, 5, 6, 7, 8, 9] - assert await stream.list(stream.iterate((await p.execute())[1] for p in parsed) | pipe.concat()) == expected + assert expected == await Stream.iterate(await (await p.execute())[1].collect() for p in parsed).flatten().collect() # type: ignore @pytest.mark.asyncio @@ -315,7 +316,7 @@ async def test_format(cli: CLI) -> None: async def test_workflows_command(cli: CLIService, task_handler: TaskHandlerService, test_workflow: Workflow) -> None: async def execute(cmd: str) -> List[JsonElement]: ctx = CLIContext(cli.cli_env) - return (await cli.execute_cli_command(cmd, list_sink, ctx))[0] # type: ignore + return (await cli.execute_cli_command(cmd, list_sink, ctx))[0] assert await execute("workflows list") == ["sleep_workflow", "wait_for_collect_done", "test_workflow"] assert await execute("workflows show test_workflow") == [to_js(test_workflow)] @@ -754,15 +755,14 @@ async def test_aggregation_to_count_command(cli: CLI) -> None: @pytest.mark.asyncio async def test_system_backup_command(cli: CLI) -> None: async def check_backup(res: JsStream) -> None: - async with res.stream() as streamer: - only_one = True - async for s in streamer: - path = FilePath.from_path(s) - assert path.local.exists() - # backup should have size between 30k and 1500k (adjust size if necessary) - assert 30000 < path.local.stat().st_size < 1500000 - assert only_one - only_one = False + only_one = True + async for s in res: + path = FilePath.from_path(s) + assert path.local.exists() + # backup should have size between 30k and 1500k (adjust size if necessary) + assert 30000 < path.local.stat().st_size < 1500000 + assert only_one + only_one = False await cli.execute_cli_command("system backup create", check_backup) @@ -781,10 +781,9 @@ async def test_system_restore_command(cli: CLI, tmp_directory: str) -> None: backup = os.path.join(tmp_directory, "backup") async def move_backup(res: JsStream) -> None: - async with 
res.stream() as streamer: - async for s in streamer: - path = FilePath.from_path(s) - os.rename(path.local, backup) + async for s in res: + path = FilePath.from_path(s) + os.rename(path.local, backup) await cli.execute_cli_command("system backup create", move_backup) ctx = CLIContext(uploaded_files={"backup": backup}) @@ -802,11 +801,10 @@ async def test_configs_command(cli: CLI, tmp_directory: str) -> None: config_file = os.path.join(tmp_directory, "config.yml") async def check_file_is_yaml(res: JsStream) -> None: - async with res.stream() as streamer: - async for s in streamer: - assert isinstance(s, str) - with open(s, "r") as file: - yaml.safe_load(file.read()) + async for s in res: + assert isinstance(s, str) + with open(s, "r") as file: + yaml.safe_load(file.read()) # create a new config entry create_result = await cli.execute_cli_command("configs set test_config t1=1, t2=2, t3=3 ", list_sink) @@ -865,19 +863,18 @@ async def test_templates_command(cli: CLI) -> None: @pytest.mark.asyncio async def test_write_command(cli: CLI) -> None: async def check_file(res: JsStream, check_content: Optional[str] = None) -> None: - async with res.stream() as streamer: - only_one = True - async for s in streamer: - fp = FilePath.from_path(s) - assert fp.local.exists() and fp.local.is_file() - assert 1 < fp.local.stat().st_size < 100000 - assert fp.user.name.startswith("write_test") - assert only_one - only_one = False - if check_content: - with open(fp.local, "r") as file: - data = file.read() - assert data == check_content + only_one = True + async for s in res: + fp = FilePath.from_path(s) + assert fp.local.exists() and fp.local.is_file() + assert 1 < fp.local.stat().st_size < 100000 + assert fp.user.name.startswith("write_test") + assert only_one + only_one = False + if check_content: + with open(fp.local, "r") as file: + data = file.read() + assert data == check_content # result can be read as json await cli.execute_cli_command("search all limit 3 | format --json | write write_test.json ", check_file) @@ -1095,14 +1092,12 @@ async def history_count(cmd: str) -> int: @pytest.mark.asyncio async def test_aggregate(dependencies: TenantDependencies) -> None: - in_stream = stream.iterate( - [{"a": 1, "b": 1, "c": 1}, {"a": 2, "b": 1, "c": 1}, {"a": 3, "b": 2, "c": 1}, {"a": 4, "b": 2, "c": 1}] - ) - - async def aggregate(agg_str: str) -> List[JsonElement]: # type: ignore + async def aggregate(agg_str: str) -> List[JsonElement]: + in_stream = Stream.iterate( + [{"a": 1, "b": 1, "c": 1}, {"a": 2, "b": 1, "c": 1}, {"a": 3, "b": 2, "c": 1}, {"a": 4, "b": 2, "c": 1}] + ) res = AggregateCommand(dependencies).parse(agg_str) - async with (await res.flow(in_stream)).stream() as flow: - return [s async for s in flow] + return [s async for s in (await res.flow(in_stream))] assert await aggregate("b as bla, c, r.d.f.name: sum(1) as count, min(a) as min, max(a) as max") == [ {"group": {"bla": 1, "c": 1, "r.d.f.name": None}, "count": 2, "min": 1, "max": 2}, @@ -1161,11 +1156,10 @@ async def execute(cmd: str, _: Type[T]) -> List[T]: return cast(List[T], result[0]) async def check_file_is_yaml(res: JsStream) -> None: - async with res.stream() as streamer: - async for s in streamer: - assert isinstance(s, str) - with open(s, "r") as file: - yaml.safe_load(file.read()) + async for s in res: + assert isinstance(s, str) + with open(s, "r") as file: + yaml.safe_load(file.read()) # install a package assert "installed successfully" in (await execute("apps install cleanup-untagged", str))[0] @@ -1235,7 +1229,7 @@ async def 
check_file_is_yaml(res: JsStream) -> None: async def test_user(cli: CLI) -> None: async def execute(cmd: str) -> List[JsonElement]: all_results = await cli.execute_cli_command(cmd, list_sink) - return all_results[0] # type: ignore + return all_results[0] # remove all existing users await cli.dependencies.config_handler.delete_config(UsersConfigId) @@ -1355,10 +1349,9 @@ async def execute(cmd: str, _: Type[T]) -> List[T]: dump = os.path.join(tmp_directory, "dump") async def move_dump(res: JsStream) -> None: - async with res.stream() as streamer: - async for s in streamer: - fp = FilePath.from_path(s) - os.rename(fp.local, dump) + async for s in res: + fp = FilePath.from_path(s) + os.rename(fp.local, dump) # graph export works await cli.execute_cli_command("graph export graphtest dump", move_dump) @@ -1387,28 +1380,27 @@ async def sync_and_check( ) -> Json: result: List[Json] = [] - async def check(in_: JsStream) -> None: - async with in_.stream() as streamer: - async for s in streamer: - assert isinstance(s, dict) - path = FilePath.from_path(s) - # open sqlite database - conn = sqlite3.connect(path.local) - c = conn.cursor() - tables = { - row[0] for row in c.execute("SELECT tbl_name FROM sqlite_master WHERE type='table'").fetchall() - } - if expected_tables is not None: - assert tables == expected_tables - if expected_table_count is not None: - assert len(tables) == expected_table_count - if expected_table is not None: - for table in tables: - count = c.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0] - assert expected_table(table, count), f"Table {table} has {count} rows" - c.close() - conn.close() - result.append(s) + async def check(streamer: JsStream) -> None: + async for s in streamer: + assert isinstance(s, dict) + path = FilePath.from_path(s) + # open sqlite database + conn = sqlite3.connect(path.local) + c = conn.cursor() + tables = { + row[0] for row in c.execute("SELECT tbl_name FROM sqlite_master WHERE type='table'").fetchall() + } + if expected_tables is not None: + assert tables == expected_tables + if expected_table_count is not None: + assert len(tables) == expected_table_count + if expected_table is not None: + for table in tables: + count = c.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0] + assert expected_table(table, count), f"Table {table} has {count} rows" + c.close() + conn.close() + result.append(s) await cli.execute_cli_command(cmd, check) assert len(result) == 1 diff --git a/fixcore/tests/fixcore/db/graphdb_test.py b/fixcore/tests/fixcore/db/graphdb_test.py index f5bc1b393b..a38a9d768b 100644 --- a/fixcore/tests/fixcore/db/graphdb_test.py +++ b/fixcore/tests/fixcore/db/graphdb_test.py @@ -291,7 +291,7 @@ async def check_usage() -> bool: expected = {"min": 42, "avg": 42, "max": 42} return node_usage == expected - await eventually(check_usage) + await eventually(check_usage, timeout=timedelta(seconds=30)) # exactly the same graph is updated: expect no changes assert await graph_db.merge_graph(create("yes or no"), foo_model) == (p, GraphUpdate(0, 0, 0, 0, 0, 0)) diff --git a/fixcore/tests/fixcore/hypothesis_extension.py b/fixcore/tests/fixcore/hypothesis_extension.py index 297376f6ea..db2cecf6b4 100644 --- a/fixcore/tests/fixcore/hypothesis_extension.py +++ b/fixcore/tests/fixcore/hypothesis_extension.py @@ -1,9 +1,7 @@ import string from datetime import datetime -from typing import TypeVar, Callable, Any, cast, Optional, List, Generator +from typing import TypeVar, Any, cast, Optional, List, Generator -from aiostream import stream -from aiostream.core 
import Stream from hypothesis.strategies import ( SearchStrategy, just, @@ -20,6 +18,7 @@ from fixcore.model.resolve_in_graph import NodePath from fixcore.types import JsonElement, Json from fixcore.util import value_in_path, interleave +from fixlib.asynchronous.stream import Stream T = TypeVar("T") @@ -71,4 +70,4 @@ def from_node() -> Generator[Json, Any, None]: for from_n, to_n in interleave(node_ids): yield {"type": "edge", "from": from_n, "to": to_n} - return stream.iterate(from_node()) + return Stream.iterate(from_node()) diff --git a/fixcore/tests/fixcore/report/benchmark_renderer_test.py b/fixcore/tests/fixcore/report/benchmark_renderer_test.py index 080740695c..f34c711124 100644 --- a/fixcore/tests/fixcore/report/benchmark_renderer_test.py +++ b/fixcore/tests/fixcore/report/benchmark_renderer_test.py @@ -1,16 +1,16 @@ import pytest -from aiostream import stream from fixcore.report.benchmark_renderer import respond_benchmark_result from fixcore.report.inspector_service import InspectorService from fixcore.ids import GraphName +from fixlib.asynchronous.stream import Stream @pytest.mark.asyncio async def test_benchmark_renderer(inspector_service: InspectorService) -> None: bench_results = await inspector_service.perform_benchmarks(GraphName("ns"), ["test"]) bench_result = bench_results["test"] - render_result = [elem async for elem in respond_benchmark_result(stream.iterate(bench_result.to_graph()))] + render_result = [elem async for elem in respond_benchmark_result(Stream.iterate(bench_result.to_graph()))] assert len(render_result) == 1 assert ( render_result[0] diff --git a/fixcore/tests/fixcore/util_test.py b/fixcore/tests/fixcore/util_test.py index bf64f33d8e..5a688f3366 100644 --- a/fixcore/tests/fixcore/util_test.py +++ b/fixcore/tests/fixcore/util_test.py @@ -5,7 +5,6 @@ import pytest import pytz -from aiostream import stream from fixcore.util import ( AccessJson, @@ -21,6 +20,7 @@ utc_str, parse_utc, ) +from fixlib.asynchronous.stream import Stream def not_in_path(name: str, *other: str) -> bool: @@ -107,17 +107,9 @@ def test_del_value_in_path() -> None: @pytest.mark.asyncio async def test_async_gen() -> None: - async with stream.empty().stream() as empty: - async for _ in await force_gen(empty): - pass - - with pytest.raises(Exception): - async with stream.throw(Exception(";)")).stream() as err: - async for _ in await force_gen(err): - pass - - async with stream.iterate(range(0, 100)).stream() as elems: - assert [x async for x in await force_gen(elems)] == list(range(0, 100)) + async for _ in await force_gen(Stream.empty()): + pass + assert [x async for x in await force_gen(Stream.iterate(range(0, 100)))] == list(range(0, 100)) def test_deep_merge() -> None: diff --git a/fixcore/tests/fixcore/web/content_renderer_test.py b/fixcore/tests/fixcore/web/content_renderer_test.py index 4d3c5c6724..f37cf276e5 100644 --- a/fixcore/tests/fixcore/web/content_renderer_test.py +++ b/fixcore/tests/fixcore/web/content_renderer_test.py @@ -4,7 +4,6 @@ import pytest import yaml -from aiostream import stream from hypothesis import given, settings, HealthCheck from hypothesis.strategies import lists @@ -18,6 +17,7 @@ respond_cytoscape, respond_graphml, ) +from fixlib.asynchronous.stream import Stream from tests.fixcore.hypothesis_extension import ( json_array_gen, json_simple_element_gen, @@ -30,79 +30,72 @@ @settings(max_examples=20, suppress_health_check=list(HealthCheck), deadline=1000) @pytest.mark.asyncio async def test_json(elements: List[JsonElement]) -> None: - async with 
stream.iterate(elements).stream() as streamer: - result = "" - async for elem in respond_json(streamer): - result += elem - assert json.loads(result) == elements + result = "" + async for elem in respond_json(Stream.iterate(elements)): + result += elem + assert json.loads(result) == elements @given(json_array_gen) @settings(max_examples=20, suppress_health_check=list(HealthCheck), deadline=1000) @pytest.mark.asyncio async def test_ndjson(elements: List[JsonElement]) -> None: - async with stream.iterate(elements).stream() as streamer: - result = [] - async for elem in respond_ndjson(streamer): - result.append(json.loads(elem.strip())) - assert result == elements + result = [] + async for elem in respond_ndjson(Stream.iterate(elements)): + result.append(json.loads(elem.strip())) + assert result == elements @given(json_array_gen) @settings(max_examples=20, suppress_health_check=list(HealthCheck), deadline=1000) @pytest.mark.asyncio async def test_yaml(elements: List[JsonElement]) -> None: - async with stream.iterate(elements).stream() as streamer: - result = "" - async for elem in respond_yaml(streamer): - result += elem + "\n" - assert [a for a in yaml.full_load_all(result)] == elements + result = "" + async for elem in respond_yaml(Stream.iterate(elements)): + result += elem + "\n" + assert [a for a in yaml.full_load_all(result)] == elements @given(lists(json_simple_element_gen, min_size=1, max_size=10)) @settings(max_examples=20, suppress_health_check=list(HealthCheck), deadline=1000) @pytest.mark.asyncio async def test_text_simple_elements(elements: List[JsonElement]) -> None: - async with stream.iterate(elements).stream() as streamer: - result = "" - async for elem in respond_text(streamer): - result += elem + "\n" - # every element is rendered as one or more line (string with \n is rendered as multiple lines) - assert len(elements) + 1 <= len(result.split("\n")) + result = "" + async for elem in respond_text(Stream.iterate(elements)): + result += elem + "\n" + # every element is rendered as one or more lines (a string with \n is rendered as multiple lines) + assert len(elements) + 1 <= len(result.split("\n")) @given(lists(node_gen(), min_size=1, max_size=10)) @settings(max_examples=20, suppress_health_check=list(HealthCheck), deadline=1000) @pytest.mark.asyncio async def test_text_complex_elements(elements: List[JsonElement]) -> None: - async with stream.iterate(elements).stream() as streamer: - result = "" - async for elem in respond_text(streamer): - result += elem - # every element is rendered as yaml with --- as object deliminator - assert len(elements) == len(result.split("---")) + result = "" + async for elem in respond_text(Stream.iterate(elements)): + result += elem + # every element is rendered as yaml with --- as object delimiter + assert len(elements) == len(result.split("---")) @given(lists(node_gen(), min_size=1, max_size=10)) @settings(max_examples=20, suppress_health_check=list(HealthCheck), deadline=1000) @pytest.mark.asyncio async def test_cytoscape(elements: List[Json]) -> None: - async with graph_stream(elements).stream() as streamer: - result = "" - async for elem in respond_cytoscape(streamer): - result += elem - # The resulting string can be parsed as json - assert json.loads(result) + result = "" + async for elem in respond_cytoscape(Stream.iterate(elements)): + result += elem + # The resulting string can be parsed as json + assert json.loads(result) @given(lists(node_gen(), min_size=1, max_size=10)) @settings(max_examples=20, 
suppress_health_check=list(HealthCheck), deadline=1000) @pytest.mark.asyncio async def test_graphml(elements: List[Json]) -> None: - async with graph_stream(elements).stream() as streamer: - result = "" - async for elem in respond_graphml(streamer): - result += elem + result = "" + async for elem in respond_graphml(Stream.iterate(elements)): + result += elem # The resulting string can be parsed as xml assert ElementTree.fromstring(result) is not None @@ -119,30 +112,29 @@ def edge(from_node: str, to_node: str) -> Json: nodes = [node("a", "acc1"), node("b", "acc1"), node("c", "acc2")] edges = [edge("a", "b"), edge("a", "c"), edge("b", "c")] - async with stream.iterate(nodes + edges).stream() as streamer: - result = "" - async for elem in respond_dot(streamer): - result += elem + "\n" - expected = ( - "digraph {\n" - "rankdir=LR\n" - "overlap=false\n" - "splines=true\n" - "node [shape=Mrecord colorscheme=paired12]\n" - "edge [arrowsize=0.5]\n" - ' "a" [label="a|a", style=filled fillcolor=1];\n' - ' "b" [label="b|b", style=filled fillcolor=2];\n' - ' "c" [label="c|c", style=filled fillcolor=3];\n' - ' "a" -> "b" [label="delete"]\n' - ' "a" -> "c" [label="delete"]\n' - ' "b" -> "c" [label="delete"]\n' - ' subgraph "acc1" {\n' - ' "a"\n' - ' "b"\n' - " }\n" - ' subgraph "acc2" {\n' - ' "c"\n' - " }\n" - "}\n" - ) - assert result == expected + result = "" + async for elem in respond_dot(Stream.iterate(nodes + edges)): + result += elem + "\n" + expected = ( + "digraph {\n" + "rankdir=LR\n" + "overlap=false\n" + "splines=true\n" + "node [shape=Mrecord colorscheme=paired12]\n" + "edge [arrowsize=0.5]\n" + ' "a" [label="a|a", style=filled fillcolor=1];\n' + ' "b" [label="b|b", style=filled fillcolor=2];\n' + ' "c" [label="c|c", style=filled fillcolor=3];\n' + ' "a" -> "b" [label="delete"]\n' + ' "a" -> "c" [label="delete"]\n' + ' "b" -> "c" [label="delete"]\n' + ' subgraph "acc1" {\n' + ' "a"\n' + ' "b"\n' + " }\n" + ' subgraph "acc2" {\n' + ' "c"\n' + " }\n" + "}\n" + ) + assert result == expected diff --git a/fixlib/fixlib/asynchronous/stream.py b/fixlib/fixlib/asynchronous/stream.py new file mode 100644 index 0000000000..a191b90305 --- /dev/null +++ b/fixlib/fixlib/asynchronous/stream.py @@ -0,0 +1,373 @@ +from __future__ import annotations + +import asyncio +from asyncio import TaskGroup, Task +from collections import deque +from typing import AsyncIterable, AsyncIterator, TypeVar, Optional, List, Dict, Callable, Generic, ParamSpec, TypeAlias +from typing import Iterable, Awaitable, Never, Tuple, Union + +T = TypeVar("T") +R = TypeVar("R", covariant=True) +P = ParamSpec("P") + +DirectOrAwaitable: TypeAlias = Union[T, Awaitable[T]] +IterOrAsyncIter: TypeAlias = Union[Iterable[T], AsyncIterable[T]] + + +def _async_iter(x: Iterable[T]) -> AsyncIterator[T]: + async def gen() -> AsyncIterator[T]: + for item in x: + yield item + + return gen() + + +def _to_async_iter(x: IterOrAsyncIter[T]) -> AsyncIterable[T]: + if isinstance(x, AsyncIterable): + return x + else: + return _async_iter(x) + + +def _flatmap( + source: AsyncIterable[IterOrAsyncIter[DirectOrAwaitable[T]]], + task_limit: Optional[int], + ordered: bool, +) -> AsyncIterator[T]: + if task_limit is None or task_limit == 1: + return _flatmap_direct(source) + elif ordered: + return _flatmap_ordered(source, task_limit) + else: + return _flatmap_unordered(source, task_limit) + + +async def _flatmap_direct( + source: AsyncIterable[IterOrAsyncIter[DirectOrAwaitable[T]]], +) -> AsyncIterator[T]: + async for sub_iter in source: + if 
isinstance(sub_iter, AsyncIterable): + async for item in sub_iter: + if isinstance(item, Awaitable): + item = await item + yield item + else: + for item in sub_iter: + if isinstance(item, Awaitable): + item = await item + yield item + + +async def _flatmap_unordered( + source: AsyncIterable[IterOrAsyncIter[DirectOrAwaitable[T]]], + task_limit: int, +) -> AsyncIterator[T]: + semaphore = asyncio.Semaphore(task_limit) + # The done marker wakes the consumer loop whenever a producer finishes, so the loop can re-check its + # exit condition. Without it, the consumer can block forever on queue.get() when the last workers + # finish without producing any items. + done_marker = object() + queue: asyncio.Queue[T | Exception | object] = asyncio.Queue() + tasks_in_flight = 0 + ingest_done = False + + async def worker(sub_iter: IterOrAsyncIter[DirectOrAwaitable[T]]) -> None: + nonlocal tasks_in_flight + try: + if isinstance(sub_iter, AsyncIterable): + async for si in sub_iter: + if isinstance(si, Awaitable): + si = await si + await queue.put(si) + else: + for si in sub_iter: + if isinstance(si, Awaitable): + si = await si + await queue.put(si) + except Exception as e: + await queue.put(e) # exception: put it in the queue to be handled + finally: + semaphore.release() + tasks_in_flight -= 1 + await queue.put(done_marker) # wake the consumer: this worker will not produce any more items + + async with TaskGroup() as tg: + + async def ingest_tasks() -> None: + nonlocal tasks_in_flight, ingest_done + # Start worker tasks + async for src in source: + await semaphore.acquire() + tg.create_task(worker(src)) + tasks_in_flight += 1 + ingest_done = True + await queue.put(done_marker) # wake the consumer in case all workers finished before ingestion was done + + # Consume items from the queue and yield them + tg.create_task(ingest_tasks()) + while True: + if ingest_done and tasks_in_flight == 0 and queue.empty(): + break + try: + item = await queue.get() + if item is done_marker: + continue + if isinstance(item, Exception): + raise item + yield item # type: ignore + except asyncio.CancelledError: + break + + +async def _flatmap_ordered( + source: AsyncIterable[IterOrAsyncIter[DirectOrAwaitable[T]]], + task_limit: int, +) -> AsyncIterator[T]: + semaphore = asyncio.Semaphore(task_limit) + tasks: Dict[int, Task[None]] = {} + results: Dict[int, List[T] | Exception] = {} + next_index_to_yield = 0 + source_iter = aiter(source) + max_index_started = -1 # Highest index of tasks started + source_exhausted = False + + async def worker(sub_iter: IterOrAsyncIter[T | Awaitable[T]], index: int) -> None: + items = [] + try: + if isinstance(sub_iter, AsyncIterable): + async for item in sub_iter: + if isinstance(item, Awaitable): + item = await item + items.append(item) + else: + for item in sub_iter: + if isinstance(item, Awaitable): + item = await item + items.append(item) + results[index] = items + except Exception as e: + results[index] = e # Store exception to be raised later + finally: + semaphore.release() + + async with TaskGroup() as tg: + while True: + # Start new tasks up to task_limit ahead of next_index_to_yield + while (not source_exhausted) and (max_index_started - next_index_to_yield + 1) < task_limit: + try: + await semaphore.acquire() + si = await anext(source_iter) + max_index_started += 1 + tasks[max_index_started] = tg.create_task(worker(_to_async_iter(si), max_index_started)) + except StopAsyncIteration: + source_exhausted = True + break + + if next_index_to_yield in results: + result = results.pop(next_index_to_yield) + if isinstance(result, Exception): + raise result + else: + for res in result: + yield res + # Remove completed task + tasks.pop(next_index_to_yield, None) # noqa + next_index_to_yield += 1 + else: + # Wait for the next task to complete + if next_index_to_yield in tasks: + task = tasks[next_index_to_yield] + await asyncio.wait({task}) + elif not tasks and source_exhausted: + # No more tasks to process + break + else: + # Yield control to the event loop + await asyncio.sleep(0.01) + + +class Stream(Generic[T], 
AsyncIterator[T]): + def __init__(self, iterator: AsyncIterator[T]): + self.iterator = iterator + + def __aiter__(self) -> AsyncIterator[T]: + return self + + async def __anext__(self) -> T: + return await anext(self.iterator) + + def filter(self, fn: Callable[[T], DirectOrAwaitable[bool]]) -> Stream[T]: + async def gen() -> AsyncIterator[T]: + async for item in self: + af = fn(item) + flag = await af if isinstance(af, Awaitable) else af + if flag: + yield item + + return Stream(gen()) + + def starmap( + self, + fn: Callable[..., DirectOrAwaitable[R]], + task_limit: Optional[int] = None, + ordered: bool = True, + ) -> Stream[R]: + if asyncio.iscoroutinefunction(fn): + # Wrap in another coroutine function: map detects async functions via iscoroutinefunction, + # and a plain lambda would hide that and silently force task_limit to 1. + async def star_fn(args: T) -> R: + return await fn(*args) # type: ignore + + return self.map(star_fn, task_limit, ordered) + return self.map(lambda args: fn(*args), task_limit, ordered) # type: ignore + + def map( + self, + fn: Callable[[T], DirectOrAwaitable[R]], + task_limit: Optional[int] = None, + ordered: bool = True, + ) -> Stream[R]: + async def gen() -> AsyncIterator[IterOrAsyncIter[DirectOrAwaitable[R]]]: + async for item in self: + res = fn(item) + yield [res] + + # in the case of a synchronous function, task_limit is ignored + task_limit = task_limit if asyncio.iscoroutinefunction(fn) else 1 + return Stream(_flatmap(gen(), task_limit, ordered)) + + def flatmap( + self, + fn: Callable[[T], DirectOrAwaitable[IterOrAsyncIter[DirectOrAwaitable[R]]]], + task_limit: Optional[int] = None, + ordered: bool = True, + ) -> Stream[R]: + async def gen() -> AsyncIterator[IterOrAsyncIter[DirectOrAwaitable[R]]]: + async for item in self: + res = fn(item) + if isinstance(res, Awaitable): + res = await res + yield res + + # in the case of a synchronous function, task_limit is ignored + task_limit = task_limit if asyncio.iscoroutinefunction(fn) else 1 + return Stream(_flatmap(gen(), task_limit, ordered)) + + def concat(self: Stream[Stream[T]], task_limit: Optional[int] = None, ordered: bool = True) -> Stream[T]: + return self.flatmap(lambda x: x, task_limit, ordered) + + def skip(self, num: int) -> Stream[T]: + async def gen() -> AsyncIterator[T]: + count = 0 + async for item in self: + if count < num: + count += 1 + continue + yield item + + return Stream(gen()) + + def take(self, num: int) -> Stream[T]: + async def gen() -> AsyncIterator[T]: + count = 0 + async for item in self: + if count >= num: + break + yield item + count += 1 + + return Stream(gen()) + + def take_last(self, num: int) -> Stream[T]: + async def gen() -> AsyncIterator[T]: + queue: deque[T] = deque(maxlen=num) + async for item in self: + queue.append(item) + for item in queue: + yield item + + return Stream(gen()) + + def enumerate(self) -> Stream[Tuple[int, T]]: + async def gen() -> AsyncIterator[Tuple[int, T]]: + i = 0 + async for item in self: + yield i, item + i += 1 + + return Stream(gen()) + + def chunks(self, num: int) -> Stream[List[T]]: + async def gen() -> AsyncIterator[List[T]]: + while True: + chunk_items: List[T] = [] + try: + for _ in range(num): + item = await anext(self.iterator) + chunk_items.append(item) + yield chunk_items + except StopAsyncIteration: + if chunk_items: + yield chunk_items + break + + return Stream(gen()) + + def flatten(self) -> Stream[T]: + async def gen() -> AsyncIterator[T]: + async for item in self: + if isinstance(item, AsyncIterator) or hasattr(item, "__aiter__"): + async for subitem in item: + yield subitem + elif isinstance(item, Iterable): + for subitem in item: + yield subitem + else: + yield item + + return Stream(gen()) + + async def collect(self) -> List[T]: + return [item async for item in self] + + @staticmethod + def just(x: T | Awaitable[T]) -> Stream[T]: + async def 
gen() -> AsyncIterator[T]: + if isinstance(x, Awaitable): + yield await x + else: + yield x + + return Stream(gen()) + + @staticmethod + def iterate(x: Iterable[T] | AsyncIterable[T] | AsyncIterator[T]) -> Stream[T]: + if isinstance(x, AsyncIterator): + return Stream(x) + elif isinstance(x, AsyncIterable): + return Stream(aiter(x)) + else: + return Stream(_async_iter(x)) + + @staticmethod + def empty() -> Stream[T]: + async def empty() -> AsyncIterator[Never]: + if False: + yield # noqa + + return Stream(empty()) + + @staticmethod + def for_ever(fn: Callable[P, Awaitable[R]] | Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> Stream[R]: + async def gen() -> AsyncIterator[R]: + while True: + if asyncio.iscoroutinefunction(fn): + yield await fn(*args, **kwargs) + else: + yield fn(*args, **kwargs) # type: ignore + + return Stream(gen()) + + @staticmethod + def call(fn: Callable[P, Awaitable[R]] | Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> Stream[R]: + async def gen() -> AsyncIterator[R]: + if asyncio.iscoroutinefunction(fn): + yield await fn(*args, **kwargs) + else: + yield fn(*args, **kwargs) # type: ignore + + return Stream(gen()) + + @staticmethod + async def as_list(x: Iterable[T] | AsyncIterable[T] | AsyncIterator[T]) -> List[T]: + if isinstance(x, AsyncIterator): + return [item async for item in x] + elif isinstance(x, AsyncIterable): + return [item async for item in aiter(x)] + else: + return [item for item in x] diff --git a/fixlib/test/asynchronous/stream_test.py b/fixlib/test/asynchronous/stream_test.py new file mode 100644 index 0000000000..342f7ad834 --- /dev/null +++ b/fixlib/test/asynchronous/stream_test.py @@ -0,0 +1,123 @@ +import asyncio +from typing import AsyncIterator, Iterator + +from fixlib.asynchronous.stream import Stream + + +async def example_gen() -> AsyncIterator[int]: + for i in range(5, 0, -1): + yield i + + +def example_stream() -> Stream[int]: + return Stream(example_gen()) + + +async def test_just() -> None: + assert await Stream.just(1).collect() == [1] + + +async def test_iterate() -> None: + assert await Stream.iterate([1, 2, 3]).collect() == [1, 2, 3] + assert await Stream.iterate(example_gen()).collect() == [5, 4, 3, 2, 1] + assert await Stream.iterate(example_stream()).collect() == [5, 4, 3, 2, 1] + + +async def test_filter() -> None: + assert await example_stream().filter(lambda x: x % 2).collect() == [5, 3, 1] + assert await example_stream().filter(lambda x: x is None).collect() == [] + assert await example_stream().filter(lambda x: True).collect() == [5, 4, 3, 2, 1] + + +async def test_map() -> None: + invoked = 0 + max_invoked = 0 + + def sync_fn(x: int) -> int: + return x * 2 + + async def async_fn(x: int) -> int: + await asyncio.sleep(x / 100) + return x * 2 + + async def count_invoked_fn(x: int) -> int: + nonlocal invoked, max_invoked + invoked += 1 + await asyncio.sleep(0.003) + max_invoked = max(max_invoked, invoked) + await asyncio.sleep(0.003) + invoked -= 1 + return x + + assert await example_stream().map(lambda x: x * 2).collect() == [10, 8, 6, 4, 2] + assert await example_stream().map(sync_fn).collect() == [10, 8, 6, 4, 2] + assert await example_stream().map(async_fn).collect() == [10, 8, 6, 4, 2] + # The function will wait depending on the streamed value. + # Since we start from biggest to smallest, the result should be reversed. + # This may be flaky, since it relies on timing.
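+ # (ordered=False yields each result as soon as its task finishes; async_fn sleeps x / 100 seconds, so smaller inputs complete first and the order flips.)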
+ assert await example_stream().map(async_fn, task_limit=100, ordered=False).collect() == [2, 4, 6, 8, 10] + # All items are processed in parallel, while the order is preserved. + assert await example_stream().map(async_fn, task_limit=100, ordered=True).collect() == [10, 8, 6, 4, 2] + # Make sure all items are processed in parallel. + max_invoked = invoked = 0 + assert await example_stream().map(count_invoked_fn, task_limit=100, ordered=False).collect() + assert max_invoked == 5 + # Limit the number of parallel tasks to 2. + max_invoked = invoked = 0 + assert await example_stream().map(count_invoked_fn, task_limit=2, ordered=False).collect() + assert max_invoked == 2 + # Make sure all items are processed in parallel. + max_invoked = invoked = 0 + assert await example_stream().map(count_invoked_fn, task_limit=100, ordered=True).collect() + assert max_invoked == 5 + # Limit the number of parallel tasks to 2. + max_invoked = invoked = 0 + assert await example_stream().map(count_invoked_fn, task_limit=2, ordered=True).collect() + assert max_invoked == 2 + + +async def test_flatmap() -> None: + def sync_gen(x: int) -> Iterator[int]: + for i in range(2): + yield x * 2 + + async def async_gen(x: int) -> AsyncIterator[int]: + await asyncio.sleep(0) + for i in range(2): + yield x * 2 + + assert await example_stream().flatmap(sync_gen).collect() == [10, 10, 8, 8, 6, 6, 4, 4, 2, 2] + assert await example_stream().flatmap(async_gen).collect() == [10, 10, 8, 8, 6, 6, 4, 4, 2, 2] + assert await Stream.empty().flatmap(sync_gen).collect() == [] + assert await Stream.empty().flatmap(async_gen).collect() == [] + assert await Stream.iterate([]).flatmap(sync_gen).collect() == [] + assert await Stream.iterate([]).flatmap(async_gen).collect() == [] + + +async def test_take() -> None: + assert await example_stream().take(3).collect() == [5, 4, 3] + + +async def test_take_last() -> None: + assert await example_stream().take_last(3).collect() == [3, 2, 1] + + +async def test_skip() -> None: + assert await example_stream().skip(2).collect() == [3, 2, 1] + assert await example_stream().skip(10).collect() == [] + + +async def test_call() -> None: + def fn(foo: int, bla: str) -> int: + return 123 + + def with_int(foo: int) -> int: + return foo + 1 + + assert await Stream.call(fn, 1, "bla").map(with_int).collect() == [124] + + +async def test_chunks() -> None: + assert len([chunk async for chunk in example_stream().chunks(2)]) == 3 + assert [chunk async for chunk in example_stream().chunks(2)] == await example_stream().chunks(2).collect() + assert await example_stream().chunks(2).collect() == [[5, 4], [3, 2], [1]] diff --git a/fixshell/.pylintrc b/fixshell/.pylintrc index fd1655a604..b2bce42c1c 100644 --- a/fixshell/.pylintrc +++ b/fixshell/.pylintrc @@ -245,7 +245,7 @@ ignored-modules= # List of classes names for which member attributes should not be checked # (useful for classes with attributes dynamically set). -ignored-classes=SQLObject, optparse.Values, thread._local, _thread._local, aiostream.pipe +ignored-classes=SQLObject, optparse.Values, thread._local, _thread._local # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. 
Python regular diff --git a/requirements-all.txt b/requirements-all.txt index 50791c998d..df9cb7a7f1 100644 --- a/requirements-all.txt +++ b/requirements-all.txt @@ -5,7 +5,6 @@ aiohttp[speedups]==3.10.10 aiohttp-jinja2==1.6 aiohttp-swagger3==0.9.0 aiosignal==1.3.1 -aiostream==0.6.3 appdirs==1.4.4 apscheduler==3.10.4 asn1crypto==1.5.1 diff --git a/requirements-extra.txt b/requirements-extra.txt index 98cebdde10..551a443508 100644 --- a/requirements-extra.txt +++ b/requirements-extra.txt @@ -5,7 +5,6 @@ aiohttp[speedups]==3.10.10 aiohttp-jinja2==1.6 aiohttp-swagger3==0.9.0 aiosignal==1.3.1 -aiostream==0.6.3 appdirs==1.4.4 apscheduler==3.10.4 asn1crypto==1.5.1 diff --git a/requirements.txt b/requirements.txt index 4b37f2983f..1af1fd3cf8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,6 @@ aiohttp[speedups]==3.10.10 aiohttp-jinja2==1.6 aiohttp-swagger3==0.9.0 aiosignal==1.3.1 -aiostream==0.6.3 appdirs==1.4.4 apscheduler==3.10.4 attrs==24.2.0
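A minimal usage sketch of the new fixlib.asynchronous.stream.Stream API introduced above; the helper coroutine double and the main wrapper are illustrative only, not part of the change:

    import asyncio

    from fixlib.asynchronous.stream import Stream


    async def double(x: int) -> int:
        await asyncio.sleep(0)  # stand-in for real async work
        return x * 2


    async def main() -> None:
        # map with a task_limit runs up to 4 coroutines concurrently; ordered=True (the default) keeps input order
        assert await Stream.iterate(range(5)).map(double, task_limit=4).collect() == [0, 2, 4, 6, 8]
        # flatmap expands each element into a sub-iterable and flattens the results into one stream
        assert await Stream.iterate([1, 2]).flatmap(lambda x: [x, x]).collect() == [1, 1, 2, 2]
        # chunks groups elements into fixed-size batches; the last batch may be smaller
        assert await Stream.iterate(range(5)).chunks(2).collect() == [[0, 1], [2, 3], [4]]


    asyncio.run(main())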