diff --git a/per-message-s3-exporter/firehose_reader.py b/per-message-s3-exporter/firehose_reader.py
index b398864..990381b 100644
--- a/per-message-s3-exporter/firehose_reader.py
+++ b/per-message-s3-exporter/firehose_reader.py
@@ -115,19 +115,36 @@ def from_env(cls):
         keepalive = int(os.environ.get("KEEPALIVE", "60"))
         keepalive_stale_pitrs = int(os.environ.get("KEEPALIVE_STALE_PITRS", "5"))
         init_time = os.environ.get("INIT_CMD_TIME", "live")
+        init_time_split = init_time.split()
 
-        if init_time.split()[0] not in ["live", "pitr"]:
-            raise ValueError('$INIT_CMD_TIME value is invalid, should be "live" or "pitr <pitr>"')
+        if init_time_split[0] not in ("live", "pitr", "range"):
+            raise ValueError(
+                '$INIT_CMD_TIME value is invalid, should be "live", '
+                '"pitr <pitr>" or "range <start> <end>"'
+            )
 
         pitr_map = pitr_map_from_file(init_time)
         if pitr_map:
             min_pitr = min(pitr_map.values())
             logging.info(f"Based on PITR map {pitr_map}")
             logging.info(f"Using {min_pitr} ({format_epoch(min_pitr)}) as starting PITR value")
-            init_time = f"pitr {min_pitr}"
+
+            if "pitr" in init_time:
+                init_time = f"pitr {min_pitr}"
+            elif "range" in init_time:
+                init_time_split[1] = f"{min_pitr}"
+                init_time = " ".join(init_time_split)
 
         init_args = os.environ.get("INIT_CMD_ARGS", "")
-        for command in ["live", "pitr", "compression", "keepalive", "username", "password"]:
+        for command in [
+            "live",
+            "pitr",
+            "range",
+            "compression",
+            "keepalive",
+            "username",
+            "password",
+        ]:
             if command in init_args.split():
                 raise ValueError(
                     f'$INIT_CMD_ARGS should not contain the "{command}" command. '
@@ -287,6 +304,13 @@ def connection_error_limit(self) -> int:
         """How many Firehose read errors before stopping"""
         return int(os.environ.get("CONNECTION_ERROR_LIMIT", "3"))
 
+    async def _shutdown(self):
+        """When a range of PITRs is requested, we send a special shutdown
+        message to every queue to end cleanly"""
+        logging.info("Initiating shutdown procedure: propagating signal to per-message queues")
+        for queue in self.message_queues.values():
+            await queue.put(None)
+
     async def read_firehose(self):
         """Read Firehose until a threshold number of errors occurs"""
         await self._stats.update_stats(None, 0)
@@ -296,11 +320,23 @@
         errors = 0
         time_mode = self.config.init_time
+        reached_the_end = False
         while True:
             pitr = await self._read_until_error(time_mode)
             if pitr:
-                time_mode = f"pitr {pitr}"
+                time_mode_split = time_mode.split()
+                if time_mode_split[0] in ("live", "pitr"):
+                    time_mode = f"pitr {pitr}"
+                else:
+                    if pitr >= int(time_mode_split[-1]):
+                        logging.info("Reached the end of the range")
+                        reached_the_end = True
+                        break
+
+                    time_mode_split[1] = f"{pitr}"
+                    time_mode = " ".join(time_mode_split)
+
                 logging.info(f'Reconnecting with "{time_mode}"')
                 errors = 0
             elif errors < error_limit - 1:
@@ -319,6 +355,9 @@
         self._stats.finish()
         await asyncio.wait_for(stats_task, self.stats_period)
 
+        if reached_the_end:
+            return await self._shutdown()
+
         raise ReadFirehoseErrorThreshold
 
     async def _open_connection(
@@ -361,6 +400,10 @@ async def _read_until_error(self, time_mode: str) -> Optional[str]:
 
         time_mode may be either the string "live" or a pitr string that looks
         like "pitr <pitr>" where <pitr> is a value previously returned by this
         function
+        or a pitr string that looks like "range <start> <end>". In the case of
+        a range, if we get to the end value we stop cleanly and shut down,
+        otherwise we can resume from a start value previously returned by this
+        function.
""" context = ssl.create_default_context() diff --git a/per-message-s3-exporter/main.py b/per-message-s3-exporter/main.py index 8bc1e27..8c2025f 100644 --- a/per-message-s3-exporter/main.py +++ b/per-message-s3-exporter/main.py @@ -14,7 +14,7 @@ from pathlib import Path from signal import Signals, SIGINT, SIGTERM import sys -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple import aiofiles import attr @@ -249,25 +249,31 @@ def folder_prefix(self) -> str: return f"{self.args.s3_bucket_folder}/" - def record_pitr(self, record: str) -> int: + def record_pitr(self, record: Optional[str]) -> int: """Return the PITR for a Firehose message""" + if record is None: + record = self._current_batch[-1] return int(json.loads(record)["pitr"]) - async def ingest_record(self, record: str): + async def ingest_record(self, record: Optional[str]): """Ingest a record from Firehose, adding it to the current batch and writing a file to the S3 writer queue if necessary """ - if not self._current_batch: - self._start_pitr = self.record_pitr(record) - self._timer.start() + # An empty line is a signal that we need to shutdown + shutdown = record is None # Even though we might exceed the max bytes, that's okay since it's # only a rough threshold that we strive to maintain here and the number # of records is more strictly adhered to - self._current_batch.append(record) - self._current_batch_bytes += len(record) + if not shutdown: + if not self._current_batch: + self._start_pitr = self.record_pitr(record) + self._timer.start() - if self.should_write_batch_to_file(): + self._current_batch.append(record) + self._current_batch_bytes += len(record) + + if (shutdown and self.batch_length > 0) or self.should_write_batch_to_file(): self._end_pitr = self.record_pitr(record) await self.enqueue_batch_contents() @@ -277,6 +283,11 @@ async def ingest_record(self, record: str): self._current_batch = [] self._current_batch_bytes = 0 + # Propagate shutdown signal + if shutdown: + logging.info(f"Shutting down {self.message_type} queue: sending signal to S3 writer") + await self.s3_writer_queue.put(None) + def should_write_batch_to_file(self) -> bool: """Whether the current batch needs to be written to an S3 file In order to see less common message types, the bytes hit will be @@ -291,6 +302,10 @@ def should_write_batch_to_file(self) -> bool: async def enqueue_batch_contents(self): """Write the current batch of records to the S3 writer's queue""" + if self.batch_length == 0: + logging.warning(f"Current batch for {self.message_type} is empty, skipping") + return + filename = self.batch_filename() file_contents = b"".join(self._current_batch) @@ -304,6 +319,7 @@ async def enqueue_batch_contents(self): end_pitr=self._end_pitr, ) + logging.info(f"Writing a batch to the S3 writer for {self.message_type}") await self.s3_writer_queue.put(s3_object) def _s3_bucket_folder(self) -> str: @@ -344,10 +360,15 @@ async def build_batch_of_records_from_firehose( while True: # Use a "blocking" await on the queue with Firehose messages which will # wait indefinitely until data shows up in the queue - firehose_message = await firehose_queue.get() + firehose_message: Optional[str] = await firehose_queue.get() + await batcher.ingest_record(firehose_message) firehose_queue.task_done() + # We've reached the end of the PITR range and need to shutdown + if firehose_message is None: + break + async def load_pitr_map(pitr_map_path: Path) -> Dict[str, int]: """Load the PITR map from disk if available. 
     Returns an empty dict if
@@ -384,8 +405,19 @@ async def write_files_to_s3(
     pitr_map: Dict[str, int] = await load_pitr_map(args.pitr_map)
     pitr_map = {message_type: int(pitr) for message_type, pitr in pitr_map.items()}
 
-    while True:
-        s3_write_object: S3WriteObject = await s3_queue.get()
+    # Keep track of how many shutdown signals we receive for the case where
+    # we're only ingesting a range of values and not processing files
+    # indefinitely
+    total_shutdown_signals = len(FIREHOSE_MESSAGE_TYPES)
+    shutdown_signals_recvd = 0
+
+    while shutdown_signals_recvd < total_shutdown_signals:
+        s3_write_object: Optional[S3WriteObject] = await s3_queue.get()
+
+        # Check for a shutdown signal
+        if s3_write_object is None:
+            shutdown_signals_recvd += 1
+            continue
 
         # Get some timing stats on how long it takes to write to S3
         timer = Timer(
@@ -451,7 +483,7 @@ async def main(args: ap.Namespace):
     # Use a single S3 file writer for all message types
     tasks.append(write_files_to_s3(args, executor, s3_writer_queue))
 
-    # Run all the tasks in the event loop
+    # Run all the tasks in the event loop to completion
     await asyncio.gather(*tasks, return_exceptions=False)
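
Reviewer note: the clean-shutdown path added above threads a None sentinel from the Firehose reader, through each per-message-type batcher queue, into the single S3 writer queue, which exits only after it has seen one sentinel per message type. Below is a minimal, self-contained sketch of that pattern; the producer/consumer names and the sample records are illustrative, not part of this module's actual API.

    import asyncio
    from typing import List, Optional

    async def producer(queue: "asyncio.Queue[Optional[str]]", records: List[str]) -> None:
        """Stand-in for one per-message-type batcher feeding the shared writer queue"""
        for record in records:
            await queue.put(record)
        await queue.put(None)  # sentinel: this producer has exhausted its PITR range

    async def consumer(queue: "asyncio.Queue[Optional[str]]", n_producers: int) -> None:
        """Stand-in for the shared writer: drain until every producer has signalled"""
        sentinels_seen = 0
        while sentinels_seen < n_producers:
            item = await queue.get()
            queue.task_done()
            if item is None:
                # Count sentinels rather than stopping at the first one, since
                # other producers may still have records in flight
                sentinels_seen += 1
                continue
            print(f"writing {item!r}")

    async def main() -> None:
        queue: "asyncio.Queue[Optional[str]]" = asyncio.Queue()
        await asyncio.gather(
            producer(queue, ["a", "b"]),
            producer(queue, ["c"]),
            consumer(queue, n_producers=2),
        )

    asyncio.run(main())

Counting one sentinel per producer is the same design choice made in write_files_to_s3 above, which serves every batcher and therefore waits for len(FIREHOSE_MESSAGE_TYPES) shutdown signals before returning.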