completely remove aiostream
aquamatthias committed Nov 7, 2024
1 parent 3953c3b commit 881de7f
Showing 20 changed files with 200 additions and 238 deletions.
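
This commit replaces aiostream's operator pipelines (stream.iterate(...) | pipe.chunks(...), consumed through an explicit .stream() context manager) with a fluent Stream wrapper imported from fixlib.asynchronous.stream. The fixlib implementation itself is not part of this diff; the following is a minimal, self-contained stand-in that only mirrors the surface the changed call sites rely on (iterate, chunks, direct async iteration), to show how the new pipelines read.

```python
# Hypothetical stand-in for the new fixlib.asynchronous.stream.Stream surface
# used throughout this diff (iterate, chunks, direct async iteration).
# This is NOT the real fixlib code, only an illustration of the call shape.
import asyncio
from typing import Any, AsyncIterator, Iterable, List, Union


class MiniStream:
    def __init__(self, source: AsyncIterator[Any]) -> None:
        self.source = source

    @classmethod
    def iterate(cls, it: Union[Iterable[Any], AsyncIterator[Any]]) -> "MiniStream":
        async def gen() -> AsyncIterator[Any]:
            if hasattr(it, "__aiter__"):
                async for x in it:
                    yield x
            else:
                for x in it:
                    yield x

        return cls(gen())

    def chunks(self, size: int) -> "MiniStream":
        async def gen() -> AsyncIterator[List[Any]]:
            batch: List[Any] = []
            async for x in self.source:
                batch.append(x)
                if len(batch) == size:
                    yield batch
                    batch = []
            if batch:
                yield batch

        return MiniStream(gen())

    def __aiter__(self) -> AsyncIterator[Any]:
        return self.source


async def main() -> None:
    # aiostream style (removed):  async with (stream.iterate(xs) | pipe.chunks(3)).stream() as s: ...
    # new style (this commit):    async for part in Stream.iterate(xs).chunks(3): ...
    async for part in MiniStream.iterate(range(8)).chunks(3):
        print(part)  # [0, 1, 2], [3, 4, 5], [6, 7]


asyncio.run(main())
```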
2 changes: 1 addition & 1 deletion fixcore/.pylintrc
@@ -246,7 +246,7 @@ ignored-modules=

# List of classes names for which member attributes should not be checked
# (useful for classes with attributes dynamically set).
ignored-classes=SQLObject, optparse.Values, thread._local, _thread._local, aiostream.pipe
ignored-classes=SQLObject, optparse.Values, thread._local, _thread._local

# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
34 changes: 17 additions & 17 deletions fixcore/fixcore/cli/cli.py
@@ -10,14 +10,13 @@
from typing import Dict, List, Tuple, Union, Sequence
from typing import Optional, Any, TYPE_CHECKING

from aiostream import stream
from attrs import evolve
from parsy import Parser
from rich.padding import Padding

from fixcore import version
from fixcore.analytics import CoreEvent
from fixcore.cli import cmd_with_args_parser, key_values_parser, T, Sink, args_values_parser, JsGen
from fixcore.cli import cmd_with_args_parser, key_values_parser, T, Sink, args_values_parser, JsStream
from fixcore.cli.command import (
SearchPart,
PredecessorsPart,
@@ -78,6 +77,7 @@
from fixcore.types import JsonElement
from fixcore.user.model import Permission
from fixcore.util import group_by
from fixlib.asynchronous.stream import Stream
from fixlib.parse_util import make_parser, pipe_p, semicolon_p

if TYPE_CHECKING:
@@ -104,7 +104,7 @@ def command_line_parser() -> Parser:
return ParsedCommands(commands, maybe_env if maybe_env else {})


# multiple piped commands are separated by semicolon
# semicolon separates multiple piped commands
multi_command_parser = command_line_parser.sep_by(semicolon_p)


@@ -187,7 +187,7 @@ def overview() -> str:
logo = ctx.render_console(Padding(WelcomeCommand.ck, pad=(0, 0, 0, middle))) if ctx.supports_color() else ""
return headline + logo + ctx.render_console(result)

def help_command() -> JsGen:
def help_command() -> JsStream:
if not arg:
result = overview()
elif arg == "placeholders":
@@ -209,7 +209,7 @@ def help_command() -> JsGen:
else:
result = f"No command found with this name: {arg}"

return stream.just(result)
return Stream.just(result)

return CLISource.single(help_command, required_permissions={Permission.read})
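
stream.just(result) in aiostream produced a one-element stream; the new Stream.just(result) is used the same way here. Assuming it keeps those single-element semantics (the fixlib implementation is not shown in this hunk), it reduces to a one-shot async generator:

```python
# Minimal sketch of the assumed Stream.just semantics: a stream that yields
# exactly one element and then completes.
import asyncio
from typing import Any, AsyncIterator


async def just(value: Any) -> AsyncIterator[Any]:
    yield value


async def main() -> None:
    async for x in just("help text"):
        print(x)  # prints "help text" once


asyncio.run(main())
```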

@@ -352,11 +352,11 @@ def command(
self, name: str, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwargs: Any
) -> ExecutableCommand:
"""
Create an executable command for given command name, args and context.
:param name: the name of the command to execute (must be a known command)
:param arg: the arg of the command (must be parsable by the command)
:param ctx: the context of this command.
:return: the ready to run executable command.
Create an executable command for given command name, args, and context.
:param name: The name of the command to execute (must be a known command).
:param arg: The arg of the command (must be parsable by the command).
:param ctx: The context of this command.
:return: The ready to run executable command.
:raises:
CLIParseError: if the name of the command is not known, or the argument fails to parse.
"""
@@ -377,9 +377,9 @@ async def create_query(
Takes a list of query part commands and combine them to a single executable query command.
This process can also introduce new commands that should run after the query is finished.
Therefore, a list of executable commands is returned.
:param commands: the incoming executable commands, which actions are all instances of SearchCLIPart.
:param ctx: the context to execute within.
:return: the resulting list of commands to execute.
:param commands: The incoming executable commands, which actions are all instances of SearchCLIPart.
:param ctx: The context to execute within.
:return: The resulting list of commands to execute.
"""

# Pass parsed options to execute query
@@ -484,8 +484,8 @@ async def parse_query(query_arg: str) -> Query:
first_head_tail_in_a_row = None
head_tail_keep_order = True

# Define default sort order, if not already defined
# A sort order is required to always return the result in a deterministic way to the user.
# Define default sort order, if not already defined.
# A sort order is required to always return the result deterministically to the user.
# Deterministic order is required for head/tail to work
if query.is_simple_fulltext_search():
# Do not define any additional sort order for fulltext searches
@@ -494,7 +494,7 @@ async def parse_query(query_arg: str) -> Query:
parts = [pt if pt.sort else evolve(pt, sort=default_sort) for pt in query.parts]
query = evolve(query, parts=parts)

# If the last part is a navigation, we need to add sort which will ingest a new part.
# If the last part is a navigation, we need to add a sort which will ingest a new part.
with_sort = query.set_sort(*default_sort) if query.current_part.navigation else query
section = ctx.env.get("section", PathRoot)
# If this is an aggregate query, the default sort needs to be changed
@@ -534,7 +534,7 @@ def rewrite_command_line(cmds: List[ExecutableCommand], ctx: CLIContext) -> List
Rules:
- add the list command if no output format is defined
- add a format to write commands if no output format is defined
- report benchmark run will be formatted as benchmark result automatically
- report benchmark run will be formatted as a benchmark result automatically
"""
if ctx.env.get("no_rewrite") or len(cmds) == 0:
return cmds
22 changes: 11 additions & 11 deletions fixcore/fixcore/cli/command.py
@@ -1593,7 +1593,7 @@ def args_info(self) -> ArgsInfo:
def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwargs: Any) -> CLIFlow:
size = int(arg) if arg else 100
return CLIFlow(
lambda in_stream: Stream(in_stream).chunks(size).map(Stream.as_list),
lambda in_stream: Stream(in_stream).chunks(size),
required_permissions={Permission.read},
)

@@ -1978,12 +1978,12 @@ def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwa
return CLIFlow(lambda i: Stream(i).chunks(buffer_size).flatmap(func), required_permissions={Permission.write})

async def set_desired(
self, arg: Optional[str], graph_name: GraphName, patch: Json, items: Stream[Json]
self, arg: Optional[str], graph_name: GraphName, patch: Json, items: List[Json]
) -> AsyncIterator[JsonElement]:
model = await self.dependencies.model_handler.load_model(graph_name)
db = self.dependencies.db_access.get_graph_db(graph_name)
node_ids = []
async for item in items:
for item in items:
if "id" in item:
node_ids.append(item["id"])
elif isinstance(item, str):
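
set_desired now receives the plain List[Json] emitted by each chunk, so it iterates with a regular for loop instead of async for. Combined with the pipeline built in parse above (Stream(i).chunks(buffer_size).flatmap(func)), each chunk of nodes is handed to the callback as a list and the yielded results are flattened back into the stream. A minimal sketch of that shape (the helpers below are illustrative, not fixcore's actual code):

```python
# Sketch of the chunk -> flatmap pattern: the callback gets a List of items
# per chunk and yields results; flatmap concatenates all yielded elements.
import asyncio
from typing import Any, AsyncIterator, Dict, List

Json = Dict[str, Any]


async def set_desired_sketch(patch: Json, items: List[Json]) -> AsyncIterator[Json]:
    # collect ids from the chunk, then emit one patched element per id
    node_ids = [item["id"] for item in items if "id" in item]
    for node_id in node_ids:
        yield {"id": node_id, "desired": patch}


async def chunks(source: AsyncIterator[Json], size: int) -> AsyncIterator[List[Json]]:
    batch: List[Json] = []
    async for element in source:
        batch.append(element)
        if len(batch) == size:
            yield batch
            batch = []
    if batch:
        yield batch


async def main() -> None:
    async def nodes() -> AsyncIterator[Json]:
        for i in range(5):
            yield {"id": f"node-{i}"}

    # flatmap: run the callback per chunk and flatten the results
    async for batch in chunks(nodes(), 2):
        async for out in set_desired_sketch({"clean": True}, batch):
            print(out)


asyncio.run(main())
```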
@@ -2090,7 +2090,7 @@ def patch(self, arg: Optional[str], ctx: CLIContext) -> Json:
return {"clean": True}

async def set_desired(
self, arg: Optional[str], graph_name: GraphName, patch: Json, items: Stream[Json]
self, arg: Optional[str], graph_name: GraphName, patch: Json, items: List[Json]
) -> AsyncIterator[JsonElement]:
reason = f"Reason: {strip_quotes(arg)}" if arg else "No reason provided."
async for elem in super().set_desired(arg, graph_name, patch, items):
@@ -2113,11 +2113,11 @@ def parse(self, arg: Optional[str] = None, ctx: CLIContext = EmptyContext, **kwa
func = partial(self.set_metadata, ctx.graph_name, self.patch(arg, ctx))
return CLIFlow(lambda i: Stream(i).chunks(buffer_size).flatmap(func), required_permissions={Permission.write})

async def set_metadata(self, graph_name: GraphName, patch: Json, items: Stream[Json]) -> AsyncIterator[JsonElement]:
async def set_metadata(self, graph_name: GraphName, patch: Json, items: List[Json]) -> AsyncIterator[JsonElement]:
model = await self.dependencies.model_handler.load_model(graph_name)
db = self.dependencies.db_access.get_graph_db(graph_name)
node_ids = []
async for item in items:
for item in items:
if "id" in item:
node_ids.append(item["id"])
elif isinstance(item, str):
@@ -2864,7 +2864,7 @@ def extract_values(elem: JsonElement) -> List[Any | None]:
result.append(value)
return result

async def generate_markdown(chunk: Tuple[int, Stream[List[Any]]]) -> JsGen:
async def generate_markdown(chunk: Tuple[int, List[List[Any]]]) -> JsGen:
idx, rows = chunk

def to_str(elem: Any) -> str:
@@ -2896,7 +2896,7 @@ def to_str(elem: Any) -> str:
line += "|"
yield line

async for row in rows:
for row in rows:
line = ""
for value, padding in zip(row, columns_padding):
line += f"|{to_str(value).ljust(padding)}"
@@ -3260,12 +3260,12 @@ def load_by_id_merged(
expected_kind: Optional[str] = None,
**env: str,
) -> JsStream:
async def load_element(items: JsStream) -> AsyncIterator[JsonElement]:
async def load_element(items: List[JsonElement]) -> AsyncIterator[JsonElement]:
# collect ids either from json dict or string
ids: List[str] = [i["id"] if is_node(i) else i async for i in items] # type: ignore
ids: List[str] = [i["id"] if is_node(i) else i for i in items] # type: ignore
# if there is an entry which is not a string, use the list as is (e.g. chunked)
if any(a for a in ids if not isinstance(a, str)):
async for a in items:
for a in items:
yield a
else:
# one query to load all items that match given ids (max 1000 as defined in chunk size)
7 changes: 3 additions & 4 deletions fixcore/fixcore/db/graphdb.py
@@ -23,7 +23,6 @@
Union,
)

from aiostream import stream, pipe
from arango import AnalyzerGetError
from arango.collection import VertexCollection, StandardCollection, EdgeCollection
from arango.graph import Graph
@@ -67,6 +66,7 @@
set_value_in_path,
if_set,
)
from fixlib.asynchronous.stream import Stream

log = logging.getLogger(__name__)

@@ -675,9 +675,8 @@ async def move_security_temp_to_proper() -> None:

try:
# stream updates to the temp collection
async with (stream.iterate(iterator) | pipe.chunks(1000)).stream() as streamer:
async for part in streamer:
await update_chunk(dict(part))
async for part in Stream.iterate(iterator).chunks(1000):
await update_chunk(dict(part))
# move temp collection to proper and history collection
await move_security_temp_to_proper()
finally:
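
With aiostream, a pipeline had to be materialized through an explicit .stream() async context manager before it could be iterated; the new Stream appears to be directly async-iterable, so the chunked update loop above collapses into a single async for. A runnable illustration of the removed idiom, with the replacement shown in comments (assuming the new Stream exposes iterate/chunks exactly as the hunk suggests):

```python
# Runnable illustration of the aiostream idiom this commit removes:
# compose with `|`, then materialize via .stream() before iterating.
# Requires `pip install aiostream`.
import asyncio
from aiostream import stream, pipe


async def key_values():
    for i in range(5):
        yield (f"node-{i}", {"checks": i})


async def update_chunk(part: dict) -> None:
    print(f"updating {len(part)} documents")


async def main() -> None:
    async with (stream.iterate(key_values()) | pipe.chunks(2)).stream() as streamer:
        async for part in streamer:
            await update_chunk(dict(part))
    # The replacement in this commit reads:
    #     async for part in Stream.iterate(key_values()).chunks(2):
    #         await update_chunk(dict(part))
    # assuming fixlib's Stream is directly async-iterable, as this hunk suggests.


asyncio.run(main())
```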
2 changes: 1 addition & 1 deletion fixcore/fixcore/infra_apps/local_runtime.py
@@ -116,4 +116,4 @@ async def _interpret_line(self, line: str, ctx: CLIContext) -> JsStream:
total_nr_outputs = total_nr_outputs + (src_ctx.count or 0)
command_streams.append(command_output_stream)

return Stream.iterate(command_streams).concat(task_limit=1) # type: ignore
return Stream.iterate(command_streams).concat() # type: ignore
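
concat flattens a stream of sub-streams; dropping task_limit=1 suggests the new concat drains sub-streams sequentially by default, which is what the local runtime needs to keep command outputs in order. That reading is inferred from the call site, not from documentation. A minimal sequential-concat sketch:

```python
# Minimal sketch of sequential concat: flatten a sequence of async iterators,
# fully draining one before starting the next (the assumed default here).
import asyncio
from typing import AsyncIterator, Iterable


async def concat(streams: Iterable[AsyncIterator[str]]) -> AsyncIterator[str]:
    for sub in streams:
        async for item in sub:
            yield item


async def command_output(name: str) -> AsyncIterator[str]:
    for i in range(2):
        yield f"{name}:{i}"


async def main() -> None:
    async for line in concat([command_output("search"), command_output("format")]):
        print(line)  # search:0, search:1, format:0, format:1


asyncio.run(main())
```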
54 changes: 23 additions & 31 deletions fixcore/fixcore/model/db_updater.py
@@ -13,11 +13,9 @@
from multiprocessing import Process, Queue
from pathlib import Path
from queue import Empty
from typing import Optional, Union, Any, Generator, List, AsyncIterator, Dict
from typing import Optional, Union, Any, List, AsyncIterator, Dict

import aiofiles
from aiostream import stream, pipe
from aiostream.core import Stream
from attrs import define

from fixcore.analytics import AnalyticsEventSender, InMemoryEventSender, AnalyticsEvent
@@ -36,6 +34,7 @@
from fixcore.system_start import db_access, setup_process, reset_process_start_method
from fixcore.types import Json
from fixcore.util import utc, uuid_str, shutdown_process
from fixlib.asynchronous.stream import Stream

log = logging.getLogger(__name__)

@@ -56,9 +55,9 @@ class ReadFile(ProcessAction):
path: Path
task_id: Optional[str]

def jsons(self) -> Generator[Json, Any, None]:
with open(self.path, "r", encoding="utf-8") as f:
for line in f:
async def jsons(self) -> AsyncIterator[Json]:
async with aiofiles.open(self.path, "r", encoding="utf-8") as f:
async for line in f:
if line.strip():
yield json.loads(line)
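
ReadFile.jsons becomes an async generator that reads the newline-delimited JSON file through aiofiles, so file I/O no longer blocks the event loop during a graph import. The same pattern in a standalone form (file name and payload are illustrative):

```python
# Read newline-delimited JSON (NDJSON) without blocking the event loop.
# Requires `pip install aiofiles`.
import asyncio
import json
from typing import Any, AsyncIterator, Dict

import aiofiles

Json = Dict[str, Any]


async def read_ndjson(path: str) -> AsyncIterator[Json]:
    async with aiofiles.open(path, "r", encoding="utf-8") as f:
        async for line in f:
            if line.strip():
                yield json.loads(line)


async def main() -> None:
    # illustrative fixture file
    with open("nodes.ndjson", "w", encoding="utf-8") as f:
        f.write('{"id": "node-1"}\n\n{"id": "node-2"}\n')
    async for element in read_ndjson("nodes.ndjson"):
        print(element)


asyncio.run(main())
```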

@@ -75,8 +74,8 @@ class ReadElement(ProcessAction):
elements: List[Union[bytes, Json]]
task_id: Optional[str]

def jsons(self) -> Generator[Json, Any, None]:
return (e if isinstance(e, dict) else json.loads(e) for e in self.elements)
def jsons(self) -> AsyncIterator[Json]:
return Stream.iterate(self.elements).map(lambda e: e if isinstance(e, dict) else json.loads(e))


@define
@@ -125,15 +124,15 @@ def get_value(self) -> GraphUpdate:

class DbUpdaterProcess(Process):
"""
This update class implements Process and is supposed to run as separate process.
This update class implements Process and is supposed to run as a separate process.
Note: default starting method is supposed to be "spawn".
This process has 2 queues to read input from and write output to.
All elements in either queues are of type ProcessAction.
All elements in all queues are of type ProcessAction.
The parent process should stream the raw commands of graph to this process via ReadElement objects.
Once the MergeGraph action is received, the graph gets imported.
From here the parent expects result messages from the child.
From here, the parent expects result messages from the child.
All events happen in the child are forwarded to the parent via EmitEvent.
Once the graph update is done, a result is send.
The result is either an exception in case of failure or a graph update in success case.
@@ -156,8 +155,8 @@ def __init__(

def next_action(self) -> ProcessAction:
try:
# graph is read into memory. If the sender does not send data in a given amount of time,
# we raise an exception and abort the update.
# The graph is read into memory.
# If the sender does not send data in a given amount of time, we raise an exception and abort the update.
return self.read_queue.get(True, 90)
except Empty as ex:
raise ImportAborted("Merge process did not receive any data for more than 90 seconds. Abort.") from ex
@@ -168,12 +167,12 @@ async def merge_graph(self, db: DbAccess) -> GraphUpdate: # type: ignore
builder = GraphBuilder(model, self.change_id)
nxt = self.next_action()
if isinstance(nxt, ReadFile):
for element in nxt.jsons():
async for element in nxt.jsons():
builder.add_from_json(element)
nxt = self.next_action()
elif isinstance(nxt, ReadElement):
while isinstance(nxt, ReadElement):
for element in nxt.jsons():
async for element in nxt.jsons():
builder.add_from_json(element)
log.debug(f"Read {int(BatchSize / 1000)}K elements in process")
nxt = self.next_action()
@@ -276,16 +275,11 @@ async def __process_item(self, item: GraphUpdateTask) -> Union[GraphUpdate, Exce
async def start(self) -> None:
async def wait_for_update() -> None:
log.info("Start waiting for graph updates")
fl = (
stream.call(self.update_queue.get) # type: ignore
| pipe.cycle()
| pipe.map(self.__process_item, task_limit=self.config.graph.parallel_imports) # type: ignore
)
fl = Stream.for_ever(self.update_queue.get).map(self.__process_item, task_limit=self.config.graph.parallel_imports) # type: ignore # noqa
with suppress(CancelledError):
async with fl.stream() as streamer:
async for update in streamer:
if isinstance(update, GraphUpdate):
log.info(f"Finished spawned graph merge: {update}")
async for update in fl:
if isinstance(update, GraphUpdate):
log.info(f"Finished spawned graph merge: {update}")

self.handler_task = asyncio.create_task(wait_for_update())
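
The old pipeline combined stream.call with pipe.cycle (call the queue getter again and again, forever) and pipe.map with a task_limit to bound concurrent imports; Stream.for_ever(...).map(..., task_limit=...) reads as the same idea in fluent form. A self-contained sketch of those assumed semantics, modelled with a plain loop and a semaphore rather than fixlib's actual code:

```python
# Poll a queue forever and process items with bounded concurrency
# (a stand-in for the assumed Stream.for_ever(...).map(..., task_limit=n) behavior).
import asyncio
from typing import Awaitable, Callable


async def for_ever_map(
    source: Callable[[], Awaitable[str]],
    fn: Callable[[str], Awaitable[None]],
    task_limit: int,
) -> None:
    limit = asyncio.Semaphore(task_limit)

    async def run(item: str) -> None:
        async with limit:
            await fn(item)

    while True:
        item = await source()
        asyncio.create_task(run(item))


async def main() -> None:
    queue: asyncio.Queue[str] = asyncio.Queue()
    for name in ("update-1", "update-2", "update-3"):
        queue.put_nowait(name)

    async def process(item: str) -> None:
        await asyncio.sleep(0.1)
        print(f"finished graph merge: {item}")

    worker = asyncio.create_task(for_ever_map(queue.get, process, task_limit=2))
    await asyncio.sleep(0.5)  # let the queued items drain, then stop the demo
    worker.cancel()
    try:
        await worker
    except asyncio.CancelledError:
        pass


asyncio.run(main())
```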

@@ -373,19 +367,17 @@ async def read_forever() -> GraphUpdate:
task: Optional[Task[GraphUpdate]] = None
result: Optional[GraphUpdate] = None
try:
reset_process_start_method() # other libraries might have tampered the value in the mean time
reset_process_start_method() # other libraries might have tampered the value in the meantime
updater.start()
task = read_results() # concurrently read result queue
# Either send a file or stream the content directly
if isinstance(content, Path):
await send_to_child(ReadFile(content, task_id))
else:
chunked: Stream[List[Union[bytes, Json]]] = stream.chunks(content, BatchSize) # type: ignore
async with chunked.stream() as streamer:
async for lines in streamer:
if not await send_to_child(ReadElement(lines, task_id)):
# in case the child is dead, we should stop
break
async for lines in Stream.iterate(content).chunks(BatchSize):
if not await send_to_child(ReadElement(lines, task_id)):
# in case the child is dead, we should stop
break
await send_to_child(MergeGraph(db.name, change_id, maybe_batch is not None, task_id))
result = await task # wait for final result
await self.model_handler.load_model(db.name, force=True) # reload model to get the latest changes