Skip to content

Commit

Permalink
ENH: Pull as many messages from subscription as possible at once
Browse files Browse the repository at this point in the history
  • Loading branch information
cortadocodes committed Jan 30, 2024
1 parent 17c5899 commit 24d848e
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 25 deletions.
34 changes: 22 additions & 12 deletions octue/cloud/pub_sub/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
logger = logging.getLogger(__name__)


MAX_SIMULTANEOUS_MESSAGES_PULL = 50


class OrderedMessageHandler:
"""A handler for Google Pub/Sub messages received via a pull subscription that ensures messages are handled in the
order they were sent.
Expand Down Expand Up @@ -134,7 +137,7 @@ def handle_messages(self, timeout=60, maximum_heartbeat_interval=300, skip_first

while self._alive:
pull_timeout = self._check_timeout_and_get_pull_timeout(timeout)
self._pull_and_enqueue_message(timeout=pull_timeout)
self._pull_and_enqueue_messages(timeout=pull_timeout)
result = self._attempt_to_handle_queued_messages(skip_first_messages_after)

if result is not None:
Expand Down Expand Up @@ -186,9 +189,9 @@ def _check_timeout_and_get_pull_timeout(self, timeout):

return timeout - total_run_time

def _pull_and_enqueue_message(self, timeout):
"""Pull a message from the subscription and enqueue it in `self.waiting_messages`, raising a `TimeoutError` if
the timeout is exceeded before succeeding.
def _pull_and_enqueue_messages(self, timeout):
"""Pull as many messages as are available from the subscription and enqueue them in `self.waiting_messages`,
raising a `TimeoutError` if the timeout is exceeded before any message is received. At most
`MAX_SIMULTANEOUS_MESSAGES_PULL` messages are pulled per request.
:param float|None timeout: how long to wait in seconds for at least one message before raising a `TimeoutError`
:raise TimeoutError|concurrent.futures.TimeoutError: if the timeout is exceeded
Expand All @@ -201,15 +204,13 @@ def _pull_and_enqueue_message(self, timeout):
logger.debug("Pulling messages from Google Pub/Sub: attempt %d.", attempt)

pull_response = self._subscriber.pull(
request={"subscription": self.subscription.path, "max_messages": 1},
request={"subscription": self.subscription.path, "max_messages": MAX_SIMULTANEOUS_MESSAGES_PULL},
retry=retry.Retry(),
)

try:
answer = pull_response.received_messages[0]
if len(pull_response.received_messages) > 0:
break

except IndexError:
else:
logger.debug("Google Pub/Sub pull response timed out early.")
attempt += 1

Expand All @@ -220,10 +221,19 @@ def _pull_and_enqueue_message(self, timeout):
f"No message received from topic {self.subscription.topic.path!r} after {timeout} seconds.",
)

self._subscriber.acknowledge(request={"subscription": self.subscription.path, "ack_ids": [answer.ack_id]})
logger.debug("%r received a message related to question %r.", self.receiving_service, self.question_uuid)
self._subscriber.acknowledge(
request={
"subscription": self.subscription.path,
"ack_ids": [message.ack_id for message in pull_response.received_messages],
}
)

event, attributes = extract_event_and_attributes_from_pub_sub(answer.message)
for message in pull_response.received_messages:
self._extract_and_enqueue_event(message)

def _extract_and_enqueue_event(self, message):
logger.debug("%r received a message related to question %r.", self.receiving_service, self.question_uuid)
event, attributes = extract_event_and_attributes_from_pub_sub(message.message)

if not is_event_valid(
event=event,
Expand Down
26 changes: 13 additions & 13 deletions tests/cloud/pub_sub/test_message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_timeout(self):
)

with patch(
"octue.cloud.pub_sub.service.OrderedMessageHandler._pull_and_enqueue_message",
"octue.cloud.pub_sub.service.OrderedMessageHandler._pull_and_enqueue_messages",
new=MockMessagePuller(
messages=[MockMessage(b"")],
message_handler=message_handler,
Expand All @@ -65,7 +65,7 @@ def test_in_order_messages_are_handled_in_order(self):
]

with patch(
"octue.cloud.pub_sub.service.OrderedMessageHandler._pull_and_enqueue_message",
"octue.cloud.pub_sub.service.OrderedMessageHandler._pull_and_enqueue_messages",
new=MockMessagePuller(messages=messages, message_handler=message_handler).pull,
):
result = message_handler.handle_messages()
Expand All @@ -91,7 +91,7 @@ def test_out_of_order_messages_are_handled_in_order(self):
]

with patch(
"octue.cloud.pub_sub.service.OrderedMessageHandler._pull_and_enqueue_message",
"octue.cloud.pub_sub.service.OrderedMessageHandler._pull_and_enqueue_messages",
new=MockMessagePuller(messages=messages, message_handler=message_handler).pull,
):
result = message_handler.handle_messages()
Expand Down Expand Up @@ -120,7 +120,7 @@ def test_out_of_order_messages_with_end_message_first_are_handled_in_order(self)
)

with patch(
"octue.cloud.pub_sub.service.OrderedMessageHandler._pull_and_enqueue_message",
"octue.cloud.pub_sub.service.OrderedMessageHandler._pull_and_enqueue_messages",
new=MockMessagePuller(
messages=[
MockMessage.from_primitive({"kind": "finish-test", "order": 3}, attributes={"message_number": 3}),
Expand Down Expand Up @@ -162,7 +162,7 @@ def test_no_timeout(self):
]

with patch(
"octue.cloud.pub_sub.service.OrderedMessageHandler._pull_and_enqueue_message",
"octue.cloud.pub_sub.service.OrderedMessageHandler._pull_and_enqueue_messages",
new=MockMessagePuller(messages=messages, message_handler=message_handler).pull,
):
result = message_handler.handle_messages(timeout=None)
Expand All @@ -182,7 +182,7 @@ def test_delivery_acknowledgement(self):
)

with patch(
"octue.cloud.pub_sub.service.OrderedMessageHandler._pull_and_enqueue_message",
"octue.cloud.pub_sub.service.OrderedMessageHandler._pull_and_enqueue_messages",
new=MockMessagePuller(
[
MockMessage.from_primitive(
Expand Down Expand Up @@ -212,7 +212,7 @@ def test_error_raised_if_heartbeat_not_received_before_checked(self):
receiving_service=receiving_service,
)

with patch("octue.cloud.pub_sub.message_handler.OrderedMessageHandler._pull_and_enqueue_message"):
with patch("octue.cloud.pub_sub.message_handler.OrderedMessageHandler._pull_and_enqueue_messages"):
with self.assertRaises(TimeoutError) as error:
message_handler.handle_messages(maximum_heartbeat_interval=0)

Expand Down Expand Up @@ -245,7 +245,7 @@ def test_error_not_raised_if_heartbeat_has_been_received_in_maximum_allowed_inte
message_handler._last_heartbeat = datetime.datetime.now()

with patch(
"octue.cloud.pub_sub.service.OrderedMessageHandler._pull_and_enqueue_message",
"octue.cloud.pub_sub.service.OrderedMessageHandler._pull_and_enqueue_messages",
new=MockMessagePuller(
messages=[
MockMessage.from_primitive(
Expand Down Expand Up @@ -313,7 +313,7 @@ def test_handler_can_skip_first_n_messages_if_missed(self):
]

with patch(
"octue.cloud.pub_sub.service.OrderedMessageHandler._pull_and_enqueue_message",
"octue.cloud.pub_sub.service.OrderedMessageHandler._pull_and_enqueue_messages",
new=MockMessagePuller(messages=messages, message_handler=message_handler).pull,
):
result = message_handler.handle_messages(skip_first_messages_after=0)
Expand Down Expand Up @@ -347,7 +347,7 @@ def test_later_missing_messages_cannot_be_skipped(self):
]

with patch(
"octue.cloud.pub_sub.service.OrderedMessageHandler._pull_and_enqueue_message",
"octue.cloud.pub_sub.service.OrderedMessageHandler._pull_and_enqueue_messages",
new=MockMessagePuller(messages=messages, message_handler=message_handler).pull,
):
with self.assertRaises(TimeoutError):
Expand All @@ -365,7 +365,7 @@ def test_later_missing_messages_cannot_be_skipped(self):


class TestPullAndEnqueueMessage(BaseTestCase):
def test_pull_and_enqueue_message(self):
def test_pull_and_enqueue_messages(self):
"""Test that pulling and enqueuing a message works."""
question_uuid = "4d31bb46-66c4-4e68-831f-e51e17e651ef"

Expand Down Expand Up @@ -401,7 +401,7 @@ def test_pull_and_enqueue_message(self):
)
]

message_handler._pull_and_enqueue_message(timeout=10)
message_handler._pull_and_enqueue_messages(timeout=10)
self.assertEqual(message_handler.waiting_messages, {0: mock_message})
self.assertEqual(message_handler._earliest_message_number_received, 0)

Expand Down Expand Up @@ -430,6 +430,6 @@ def test_timeout_error_raised_if_result_message_not_received_in_time(self):
SUBSCRIPTIONS[mock_subscription.name] = []

with self.assertRaises(TimeoutError):
message_handler._pull_and_enqueue_message(timeout=1e-6)
message_handler._pull_and_enqueue_messages(timeout=1e-6)

self.assertEqual(message_handler._earliest_message_number_received, math.inf)

0 comments on commit 24d848e

Please sign in to comment.