Skip to content

Commit

Permalink
ENH: Use new service communication schema
Browse files Browse the repository at this point in the history
  • Loading branch information
cortadocodes committed Nov 21, 2023
1 parent 6b86478 commit 67b0849
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 55 deletions.
33 changes: 17 additions & 16 deletions octue/cloud/emulators/_pub_sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def __init__(self, data, attributes=None):
self.attributes = attributes or {}

# Encode the attributes as they would be in a real Pub/Sub message.
for key, value in attributes.items():
for key, value in self.attributes.items():
if isinstance(value, bool):
value = str(int(value))
elif isinstance(value, (int, float)):
Expand Down Expand Up @@ -335,34 +335,35 @@ def ask(
timeout=timeout,
)

# Ignore any errors from the answering service as they will be raised on the remote service in practice, not
# locally as is done in this mock.
if input_manifest is not None:
input_manifest = input_manifest.serialise()

# Delete question from messages sent to topic so the parent doesn't pick it up as a response message. We do this
# as subscription filtering isn't implemented in this set of mocks.
subscription_name = ".".join((convert_service_id_to_pub_sub_form(service_id), ANSWERS_NAMESPACE, question_uuid))
SUBSCRIPTIONS["octue.services." + subscription_name].pop(0)

question = {"type": "question"}

if input_values is not None:
question["input_values"] = input_values

# Ignore any errors from the answering service as they will be raised on the remote service in practice, not
# locally as is done in this mock.
if input_manifest is not None:
question["input_manifest"] = input_manifest.serialise()

if children is not None:
question["children"] = children

try:
self.children[service_id].answer(
MockMessage(
data=json.dumps(
{
"type": "question",
"input_values": input_values,
"input_manifest": input_manifest,
"children": children,
},
cls=OctueJSONEncoder,
).encode(),
data=json.dumps(question, cls=OctueJSONEncoder).encode(),
attributes={
"is_question": True,
"question_uuid": question_uuid,
"forward_logs": subscribe_to_logs,
"octue_sdk_version": parent_sdk_version,
"allow_save_diagnostics_data_on_crash": allow_save_diagnostics_data_on_crash,
"is_question": True,
"message_number": 0,
},
)
)
Expand Down
15 changes: 8 additions & 7 deletions octue/cloud/pub_sub/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from google.cloud.pubsub_v1 import SubscriberClient

from octue.cloud import EXCEPTIONS_MAPPING
from octue.cloud.pub_sub.messages import extract_event_and_attributes_from_pub_sub
from octue.cloud.validation import SERVICE_COMMUNICATION_SCHEMA, is_message_valid
from octue.definitions import GOOGLE_COMPUTE_PROVIDERS
from octue.log_handlers import COLOUR_PALETTE
Expand Down Expand Up @@ -224,24 +225,24 @@ def _pull_and_enqueue_message(self, timeout):
self._subscriber.acknowledge(request={"subscription": self.subscription.path, "ack_ids": [answer.ack_id]})
logger.debug("%r received a message related to question %r.", self.receiving_service, self.question_uuid)

event, attributes = extract_event_and_attributes_from_pub_sub(answer.message)
message_number = attributes["message_number"]

# Get the child's Octue SDK version from the first message.
if not self._child_sdk_version:
self._child_sdk_version = answer.message.attributes.get("octue_sdk_version")

message_number = int(answer.message.attributes["message_number"])
message = json.loads(answer.message.data.decode(), cls=OctueJSONDecoder)
self._child_sdk_version = attributes["octue_sdk_version"]

if not is_message_valid(
message=message,
attributes=dict(answer.message.attributes),
message=event,
attributes=attributes,
receiving_service=self.receiving_service,
parent_sdk_version=importlib.metadata.version("octue"),
child_sdk_version=self._child_sdk_version,
schema=self.message_schema,
):
return

self.waiting_messages[message_number] = message
self.waiting_messages[message_number] = event
self._earliest_message_number_received = min(self._earliest_message_number_received, message_number)

def _attempt_to_handle_queued_messages(self, skip_first_messages_after=60):
Expand Down
45 changes: 45 additions & 0 deletions octue/cloud/pub_sub/messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import base64
import json

from octue.utils.decoders import OctueJSONDecoder
from octue.utils.objects import getattr_or_subscribe


def extract_event_and_attributes_from_pub_sub(message):
    """Extract the event and its attributes from a message received via Pub/Sub.

    The message's `attributes` and `data` may be exposed either as object attributes (the "directly from Pub/Sub or
    Dataflow" form) or as mapping keys with base64-encoded data (the "Google Cloud Run" form). The attribute values
    are converted from their encoded forms back into python primitives.

    :param message: the message to extract the event and attributes from
    :raise KeyError: if any of the `is_question`, `question_uuid`, `message_number`, or `octue_sdk_version` attributes is missing
    :raise ValueError: if a boolean or numeric attribute can't be converted to an integer
    :raise json.JSONDecodeError: if the message data isn't valid JSON
    :return (any, dict): the decoded event and a dictionary of the decoded attributes
    """
    # Cast attributes to dict to avoid defaultdict behaviour.
    attributes = dict(getattr_or_subscribe(message, "attributes"))

    # These attributes are required - a missing one raises a `KeyError`.
    converted_attributes = {
        "is_question": bool(int(attributes["is_question"])),
        "question_uuid": attributes["question_uuid"],
        "octue_sdk_version": attributes["octue_sdk_version"],
        "message_number": int(attributes["message_number"]),
    }

    # These attributes are optional and are only included in the output if they're present in the message.
    for optional_attribute in ("forward_logs", "allow_save_diagnostics_data_on_crash"):
        if optional_attribute in attributes:
            converted_attributes[optional_attribute] = bool(int(attributes[optional_attribute]))

    # Only fall back to the Cloud Run form if the message doesn't provide decodable `data` as an attribute. Errors
    # from malformed JSON or badly-encoded data are deliberately not caught so they aren't masked by an unrelated
    # error from the fallback.
    try:
        # Parse event directly from Pub/Sub or Dataflow.
        serialised_event = message.data.decode()
    except AttributeError:
        # Parse event from Google Cloud Run.
        serialised_event = base64.b64decode(message["data"]).decode("utf-8").strip()

    event = json.loads(serialised_event, cls=OctueJSONDecoder)
    return (event, converted_attributes)
53 changes: 23 additions & 30 deletions octue/cloud/pub_sub/service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import base64
import concurrent.futures
import copy
import datetime
import functools
import importlib.metadata
Expand All @@ -17,6 +17,7 @@
from octue.cloud.pub_sub import Subscription, Topic
from octue.cloud.pub_sub.logging import GooglePubSubHandler
from octue.cloud.pub_sub.message_handler import OrderedMessageHandler
from octue.cloud.pub_sub.messages import extract_event_and_attributes_from_pub_sub
from octue.cloud.service_id import (
convert_service_id_to_pub_sub_form,
create_sruid,
Expand All @@ -25,12 +26,11 @@
split_service_id,
validate_sruid,
)
from octue.cloud.validation import SERVICE_COMMUNICATION_SCHEMA, raise_if_message_is_invalid
from octue.cloud.validation import raise_if_message_is_invalid
from octue.compatibility import warn_if_incompatible
from octue.utils.decoders import OctueJSONDecoder
from octue.utils.encoders import OctueJSONEncoder
from octue.utils.exceptions import convert_exception_to_primitives
from octue.utils.objects import get_nested_attribute
from octue.utils.threads import RepeatingTimer


Expand Down Expand Up @@ -534,39 +534,32 @@ def _parse_question(self, question):
"""
logger.info("%r received a question.", self)

try:
# Parse question directly from Pub/Sub or Dataflow.
data = json.loads(question.data.decode(), cls=OctueJSONDecoder)

# Acknowledge it if it's directly from Pub/Sub
if hasattr(question, "ack"):
question.ack()
# Acknowledge it if it's directly from Pub/Sub
if hasattr(question, "ack"):
question.ack()

except Exception:
# Parse question from Google Cloud Run.
data = json.loads(base64.b64decode(question["data"]).decode("utf-8").strip(), cls=OctueJSONDecoder)
event, attributes = extract_event_and_attributes_from_pub_sub(question)
event_for_validation = copy.deepcopy(event)

question_uuid = get_nested_attribute(question, "attributes.question_uuid")
forward_logs = bool(int(get_nested_attribute(question, "attributes.forward_logs")))
parent_sdk_version = get_nested_attribute(question, "attributes.octue_sdk_version")

allow_save_diagnostics_data_on_crash = bool(
int(get_nested_attribute(question, "attributes.allow_save_diagnostics_data_on_crash"))
)
# Deserialise input manifest into primitives for validation but leave it serialised for the return value so
# Twine validation still works.
if event.get("input_manifest"):
event_for_validation["input_manifest"] = json.loads(event["input_manifest"], cls=OctueJSONDecoder)

raise_if_message_is_invalid(
message=data,
attributes={
"question_uuid": question_uuid,
"forward_logs": forward_logs,
"parent_sdk_version": parent_sdk_version,
"allow_save_diagnostics_data_on_crash": allow_save_diagnostics_data_on_crash,
},
message=event_for_validation,
attributes=attributes,
receiving_service=self,
parent_sdk_version=parent_sdk_version,
parent_sdk_version=attributes["octue_sdk_version"],
child_sdk_version=importlib.metadata.version("octue"),
schema={"$ref": SERVICE_COMMUNICATION_SCHEMA},
)

logger.info("%r parsed the question successfully.", self)
return data, question_uuid, forward_logs, parent_sdk_version, allow_save_diagnostics_data_on_crash

return (
event,
attributes["question_uuid"],
attributes["forward_logs"],
attributes["octue_sdk_version"],
attributes["allow_save_diagnostics_data_on_crash"],
)
2 changes: 1 addition & 1 deletion octue/cloud/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

logger = logging.getLogger(__name__)

SERVICE_COMMUNICATION_SCHEMA = "https://jsonschema.registry.octue.com/octue/service-communication/0.3.0.json"
SERVICE_COMMUNICATION_SCHEMA = "https://jsonschema.registry.octue.com/octue/service-communication/0.4.0.json"
SERVICE_COMMUNICATION_SCHEMA_INFO_URL = "https://strands.octue.com/octue/service-communication"


Expand Down
7 changes: 6 additions & 1 deletion tests/cloud/pub_sub/test_message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,12 @@ def test_pull_and_enqueue_message(self):
SUBSCRIPTIONS[mock_subscription.name] = [
MockMessage(
data=json.dumps(mock_message).encode(),
attributes={"is_question": False, "message_number": 0},
attributes={
"is_question": False,
"message_number": 0,
"question_uuid": question_uuid,
"octue_sdk_version": "0.50.0",
},
)
]

Expand Down

0 comments on commit 67b0849

Please sign in to comment.