diff --git a/lib/charms/data_platform_libs/v0/data_interfaces.py b/lib/charms/data_platform_libs/v0/data_interfaces.py index 86d7521..0bb594e 100644 --- a/lib/charms/data_platform_libs/v0/data_interfaces.py +++ b/lib/charms/data_platform_libs/v0/data_interfaces.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Library to manage the relation for the data-platform products. +r"""Library to manage the relation for the data-platform products. This library contains the Requires and Provides classes for handling the relation between an application and multiple managed application supported by the data-team: @@ -291,22 +291,26 @@ def _on_topic_requested(self, event: TopicRequestedEvent): exchanged in the relation databag. """ +import copy import json import logging from abc import ABC, abstractmethod from collections import namedtuple from datetime import datetime -from typing import List, Optional +from enum import Enum +from typing import Callable, Dict, List, Optional, Set, Tuple, Union +from ops import JujuVersion, Secret, SecretInfo, SecretNotFoundError from ops.charm import ( CharmBase, CharmEvents, RelationChangedEvent, + RelationCreatedEvent, RelationEvent, - RelationJoinedEvent, + SecretChangedEvent, ) from ops.framework import EventSource, Object -from ops.model import Relation +from ops.model import Application, ModelError, Relation, Unit # The unique Charmhub library identifier, never change it LIBID = "6c3e6b6680d64e9c89e611d1a15f65be" @@ -316,7 +320,7 @@ def _on_topic_requested(self, event: TopicRequestedEvent): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 12 +LIBPATCH = 21 PYDEPS = ["ops>=2.0.0"] @@ -331,7 +335,79 @@ def _on_topic_requested(self, event: TopicRequestedEvent): deleted - key that were deleted""" -def diff(event: RelationChangedEvent, bucket: str) -> Diff: +PROV_SECRET_PREFIX = "secret-" +REQ_SECRET_FIELDS = "requested-secrets" + + +class SecretGroup(Enum): + """Secret groups as constants.""" + + USER = "user" + TLS = "tls" + EXTRA = "extra" + + +# Local map to associate mappings with secrets potentially as a group +SECRET_LABEL_MAP = { + "username": SecretGroup.USER, + "password": SecretGroup.USER, + "uris": SecretGroup.USER, + "tls": SecretGroup.TLS, + "tls-ca": SecretGroup.TLS, +} + + +class DataInterfacesError(Exception): + """Common ancestor for DataInterfaces related exceptions.""" + + +class SecretError(Exception): + """Common ancestor for Secrets related exceptions.""" + + +class SecretAlreadyExistsError(SecretError): + """A secret that was to be added already exists.""" + + +class SecretsUnavailableError(SecretError): + """Secrets aren't yet available for Juju version used.""" + + +class SecretsIllegalUpdateError(SecretError): + """Secrets aren't yet available for Juju version used.""" + + +def get_encoded_dict( + relation: Relation, member: Union[Unit, Application], field: str +) -> Optional[Dict[str, str]]: + """Retrieve and decode an encoded field from relation data.""" + data = json.loads(relation.data[member].get(field, "{}")) + if isinstance(data, dict): + return data + logger.error("Unexpected datatype for %s instead of dict.", str(data)) + + +def get_encoded_list( + relation: Relation, member: Union[Unit, Application], field: str +) -> Optional[List[str]]: + """Retrieve and decode an encoded field from relation data.""" + data = json.loads(relation.data[member].get(field, "[]")) + if isinstance(data, list): + return data + logger.error("Unexpected datatype for %s instead of list.", str(data)) + + +def set_encoded_field( + relation: Relation, + member: Union[Unit, Application], + field: str, + value: Union[str, list, Dict[str, str]], +) -> None: + """Set an encoded field from relation data.""" + relation.data[member].update({field: json.dumps(value)}) + + +def diff(event: RelationChangedEvent, bucket: Union[Unit, Application]) -> Diff: """Retrieves the diff of the data in the relation changed databag. Args: @@ -343,31 +419,164 @@ def diff(event: RelationChangedEvent, bucket: str) -> Diff: keys from the event relation databag. """ # Retrieve the old data from the data key in the application relation databag. - old_data = json.loads(event.relation.data[bucket].get("data", "{}")) + old_data = get_encoded_dict(event.relation, bucket, "data") + + if not old_data: + old_data = {} + # Retrieve the new data from the event relation databag. - new_data = { - key: value for key, value in event.relation.data[event.app].items() if key != "data" - } + new_data = ( + {key: value for key, value in event.relation.data[event.app].items() if key != "data"} + if event.app + else {} + ) # These are the keys that were added to the databag and triggered this event. - added = new_data.keys() - old_data.keys() + added = new_data.keys() - old_data.keys() # pyright: ignore [reportGeneralTypeIssues] # These are the keys that were removed from the databag and triggered this event. - deleted = old_data.keys() - new_data.keys() + deleted = old_data.keys() - new_data.keys() # pyright: ignore [reportGeneralTypeIssues] # These are the keys that already existed in the databag, # but had their values changed. - changed = {key for key in old_data.keys() & new_data.keys() if old_data[key] != new_data[key]} + changed = { + key + for key in old_data.keys() & new_data.keys() # pyright: ignore [reportGeneralTypeIssues] + if old_data[key] != new_data[key] # pyright: ignore [reportGeneralTypeIssues] + } # Convert the new_data to a serializable format and save it for a next diff check. - event.relation.data[bucket].update({"data": json.dumps(new_data)}) + set_encoded_field(event.relation, bucket, "data", new_data) # Return the diff with all possible changes. return Diff(added, changed, deleted) -# Base DataProvides and DataRequires +def leader_only(f): + """Decorator to ensure that only leader can perform given operation.""" + def wrapper(self, *args, **kwargs): + if not self.local_unit.is_leader(): + logger.error( + "This operation (%s()) can only be performed by the leader unit", f.__name__ + ) + return + return f(self, *args, **kwargs) + + return wrapper + + +def juju_secrets_only(f): + """Decorator to ensure that certain operations would be only executed on Juju3.""" + + def wrapper(self, *args, **kwargs): + if not self.secrets_enabled: + raise SecretsUnavailableError("Secrets unavailable on current Juju version") + return f(self, *args, **kwargs) + + return wrapper + + +class Scope(Enum): + """Peer relations scope.""" + + APP = "app" + UNIT = "unit" + + +class CachedSecret: + """Locally cache a secret. + + The data structure is precisely re-using/simulating as in the actual Secret Storage + """ + + def __init__(self, charm: CharmBase, label: str, secret_uri: Optional[str] = None): + self._secret_meta = None + self._secret_content = {} + self._secret_uri = secret_uri + self.label = label + self.charm = charm + + def add_secret(self, content: Dict[str, str], relation: Relation) -> Secret: + """Create a new secret.""" + if self._secret_uri: + raise SecretAlreadyExistsError( + "Secret is already defined with uri %s", self._secret_uri + ) + + secret = self.charm.app.add_secret(content, label=self.label) + secret.grant(relation) + self._secret_uri = secret.id + self._secret_meta = secret + return self._secret_meta + + @property + def meta(self) -> Optional[Secret]: + """Getting cached secret meta-information.""" + if not self._secret_meta: + if not (self._secret_uri or self.label): + return + try: + self._secret_meta = self.charm.model.get_secret(label=self.label) + except SecretNotFoundError: + if self._secret_uri: + self._secret_meta = self.charm.model.get_secret( + id=self._secret_uri, label=self.label + ) + return self._secret_meta + + def get_content(self) -> Dict[str, str]: + """Getting cached secret content.""" + if not self._secret_content: + if self.meta: + self._secret_content = self.meta.get_content() + return self._secret_content + + def set_content(self, content: Dict[str, str]) -> None: + """Setting cached secret content.""" + if not self.meta: + return + + if content: + self.meta.set_content(content) + self._secret_content = content + else: + self.meta.remove_all_revisions() + + def get_info(self) -> Optional[SecretInfo]: + """Wrapper function to apply the corresponding call on the Secret object within CachedSecret if any.""" + if self.meta: + return self.meta.get_info() + + +class SecretCache: + """A data structure storing CachedSecret objects.""" + + def __init__(self, charm): + self.charm = charm + self._secrets: Dict[str, CachedSecret] = {} + + def get(self, label: str, uri: Optional[str] = None) -> Optional[CachedSecret]: + """Getting a secret from Juju Secret store or cache.""" + if not self._secrets.get(label): + secret = CachedSecret(self.charm, label, uri) + if secret.meta: + self._secrets[label] = secret + return self._secrets.get(label) + + def add(self, label: str, content: Dict[str, str], relation: Relation) -> CachedSecret: + """Adding a secret to Juju Secret.""" + if self._secrets.get(label): + raise SecretAlreadyExistsError(f"Secret {label} already exists") + + secret = CachedSecret(self.charm, label) + secret.add_secret(content, relation) + self._secrets[label] = secret + return self._secrets[label] + + +# Base DataRelation -class DataProvides(Object, ABC): - """Base provides-side of the data products relation.""" + +class DataRelation(Object, ABC): + """Base relation data mainpulation (abstract) class.""" def __init__(self, charm: CharmBase, relation_name: str) -> None: super().__init__(charm, relation_name) @@ -377,62 +586,562 @@ def __init__(self, charm: CharmBase, relation_name: str) -> None: self.relation_name = relation_name self.framework.observe( charm.on[relation_name].relation_changed, - self._on_relation_changed, + self._on_relation_changed_event, ) + self._jujuversion = None + self.secrets = SecretCache(self.charm) - def _diff(self, event: RelationChangedEvent) -> Diff: - """Retrieves the diff of the data in the relation changed databag. + @property + def relations(self) -> List[Relation]: + """The list of Relation instances associated with this relation_name.""" + return [ + relation + for relation in self.charm.model.relations[self.relation_name] + if self._is_relation_active(relation) + ] - Args: - event: relation changed event. + @property + def secrets_enabled(self): + """Is this Juju version allowing for Secrets usage?""" + if not self._jujuversion: + self._jujuversion = JujuVersion.from_environ() + return self._jujuversion.has_secrets - Returns: - a Diff instance containing the added, deleted and changed - keys from the event relation databag. - """ - return diff(event, self.local_app) + # Mandatory overrides for internal/helper methods @abstractmethod - def _on_relation_changed(self, event: RelationChangedEvent) -> None: + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the relation data has changed.""" raise NotImplementedError - def fetch_relation_data(self) -> dict: + @abstractmethod + def _get_relation_secret( + self, relation_id: int, group_mapping: SecretGroup, relation_name: Optional[str] = None + ) -> Optional[CachedSecret]: + """Retrieve a Juju Secret that's been stored in the relation databag.""" + raise NotImplementedError + + @abstractmethod + def _fetch_specific_relation_data( + self, relation: Relation, fields: Optional[List[str]] + ) -> Dict[str, str]: + """Fetch data available (directily or indirectly -- i.e. secrets) from the relation.""" + raise NotImplementedError + + @abstractmethod + def _fetch_my_specific_relation_data( + self, relation: Relation, fields: Optional[List[str]] + ) -> Dict[str, str]: + """Fetch data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" + raise NotImplementedError + + @abstractmethod + def _update_relation_data(self, relation: Relation, data: Dict[str, str]) -> None: + """Update data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" + raise NotImplementedError + + @abstractmethod + def _delete_relation_data(self, relation: Relation, fields: List[str]) -> None: + """Delete data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" + raise NotImplementedError + + # Internal helper methods + + @staticmethod + def _is_relation_active(relation: Relation): + """Whether the relation is active based on contained data.""" + try: + _ = repr(relation.data) + return True + except (RuntimeError, ModelError): + return False + + @staticmethod + def _is_secret_field(field: str) -> bool: + """Is the field in question a secret reference (URI) field or not?""" + return field.startswith(PROV_SECRET_PREFIX) + + @staticmethod + def _generate_secret_label( + relation_name: str, relation_id: int, group_mapping: SecretGroup + ) -> str: + """Generate unique group_mappings for secrets within a relation context.""" + return f"{relation_name}.{relation_id}.{group_mapping.value}.secret" + + @staticmethod + def _generate_secret_field_name(group_mapping: SecretGroup) -> str: + """Generate unique group_mappings for secrets within a relation context.""" + return f"{PROV_SECRET_PREFIX}{group_mapping.value}" + + def _relation_from_secret_label(self, secret_label: str) -> Optional[Relation]: + """Retrieve the relation that belongs to a secret label.""" + contents = secret_label.split(".") + + if not (contents and len(contents) >= 3): + return + + contents.pop() # ".secret" at the end + contents.pop() # Group mapping + relation_id = contents.pop() + try: + relation_id = int(relation_id) + except ValueError: + return + + # In case '.' character appeared in relation name + relation_name = ".".join(contents) + + try: + return self.get_relation(relation_name, relation_id) + except ModelError: + return + + @staticmethod + def _group_secret_fields(secret_fields: List[str]) -> Dict[SecretGroup, List[str]]: + """Helper function to arrange secret mappings under their group. + + NOTE: All unrecognized items end up in the 'extra' secret bucket. + Make sure only secret fields are passed! + """ + secret_fieldnames_grouped = {} + for key in secret_fields: + if group := SECRET_LABEL_MAP.get(key): + secret_fieldnames_grouped.setdefault(group, []).append(key) + else: + secret_fieldnames_grouped.setdefault(SecretGroup.EXTRA, []).append(key) + return secret_fieldnames_grouped + + def _get_group_secret_contents( + self, + relation: Relation, + group: SecretGroup, + secret_fields: Optional[Union[Set[str], List[str]]] = None, + ) -> Dict[str, str]: + """Helper function to retrieve collective, requested contents of a secret.""" + if not secret_fields: + secret_fields = [] + + if (secret := self._get_relation_secret(relation.id, group)) and ( + secret_data := secret.get_content() + ): + return {k: v for k, v in secret_data.items() if k in secret_fields} + return {} + + @staticmethod + def _content_for_secret_group( + content: Dict[str, str], secret_fields: Set[str], group_mapping: SecretGroup + ) -> Dict[str, str]: + """Select : pairs from input, that belong to this particular Secret group.""" + if group_mapping == SecretGroup.EXTRA: + return { + k: v + for k, v in content.items() + if k in secret_fields and k not in SECRET_LABEL_MAP.keys() + } + + return { + k: v + for k, v in content.items() + if k in secret_fields and SECRET_LABEL_MAP.get(k) == group_mapping + } + + @juju_secrets_only + def _get_relation_secret_data( + self, relation_id: int, group_mapping: SecretGroup, relation_name: Optional[str] = None + ) -> Optional[Dict[str, str]]: + """Retrieve contents of a Juju Secret that's been stored in the relation databag.""" + secret = self._get_relation_secret(relation_id, group_mapping, relation_name) + if secret: + return secret.get_content() + + # Core operations on Relation Fields manipulations (regardless whether the field is in the databag or in a secret) + # Internal functions to be called directly from transparent public interface functions (+closely related helpers) + + def _process_secret_fields( + self, + relation: Relation, + req_secret_fields: Optional[List[str]], + impacted_rel_fields: List[str], + operation: Callable, + second_chance_as_normal_field: bool = True, + *args, + **kwargs, + ) -> Tuple[Dict[str, str], Set[str]]: + """Isolate target secret fields of manipulation, and execute requested operation by Secret Group.""" + result = {} + normal_fields = set(impacted_rel_fields) + if req_secret_fields and self.secrets_enabled: + normal_fields = normal_fields - set(req_secret_fields) + secret_fields = set(impacted_rel_fields) - set(normal_fields) + + secret_fieldnames_grouped = self._group_secret_fields(list(secret_fields)) + + for group in secret_fieldnames_grouped: + # operation() should return nothing when all goes well + if group_result := operation(relation, group, secret_fields, *args, **kwargs): + result.update(group_result) + elif second_chance_as_normal_field: + # If it wasn't found as a secret, let's give it a 2nd chance as "normal" field + # Needed when Juju3 Requires meets Juju2 Provider + normal_fields |= set(secret_fieldnames_grouped[group]) + return (result, normal_fields) + + def _fetch_relation_data_without_secrets( + self, app: Application, relation: Relation, fields: Optional[List[str]] + ) -> Dict[str, str]: + """Fetching databag contents when no secrets are involved. + + Since the Provider's databag is the only one holding secrest, we can apply + a simplified workflow to read the Require's side's databag. + This is used typically when the Provides side wants to read the Requires side's data, + or when the Requires side may want to read its own data. + """ + if fields: + return {k: relation.data[app][k] for k in fields if k in relation.data[app]} + else: + return dict(relation.data[app]) + + def _fetch_relation_data_with_secrets( + self, + app: Application, + req_secret_fields: Optional[List[str]], + relation: Relation, + fields: Optional[List[str]] = None, + ) -> Dict[str, str]: + """Fetching databag contents when secrets may be involved. + + This function has internal logic to resolve if a requested field may be "hidden" + within a Relation Secret, or directly available as a databag field. Typically + used to read the Provides side's databag (eigher by the Requires side, or by + Provides side itself). + """ + result = {} + normal_fields = [] + + if not fields: + all_fields = list(relation.data[app].keys()) + normal_fields = [field for field in all_fields if not self._is_secret_field(field)] + + # There must have been secrets there + if all_fields != normal_fields and req_secret_fields: + # So we assemble the full fields list (without 'secret-' fields) + fields = normal_fields + req_secret_fields + + if fields: + result, normal_fields = self._process_secret_fields( + relation, req_secret_fields, fields, self._get_group_secret_contents + ) + + # Processing "normal" fields. May include leftover from what we couldn't retrieve as a secret. + # (Typically when Juju3 Requires meets Juju2 Provides) + if normal_fields: + result.update( + self._fetch_relation_data_without_secrets(app, relation, list(normal_fields)) + ) + return result + + def _update_relation_data_without_secrets( + self, app: Application, relation: Relation, data: Dict[str, str] + ): + """Updating databag contents when no secrets are involved.""" + if any(self._is_secret_field(key) for key in data.keys()): + raise SecretsIllegalUpdateError("Can't update secret {key}.") + + if relation: + relation.data[app].update(data) + + def _delete_relation_data_without_secrets( + self, app: Application, relation: Relation, fields: List[str] + ) -> None: + """Remove databag fields 'fields' from Relation.""" + for field in fields: + relation.data[app].pop(field) + + # Public interface methods + # Handling Relation Fields seamlessly, regardless if in databag or a Juju Secret + + def get_relation(self, relation_name, relation_id) -> Relation: + """Safe way of retrieving a relation.""" + relation = self.charm.model.get_relation(relation_name, relation_id) + + if not relation: + raise DataInterfacesError( + "Relation %s %s couldn't be retrieved", relation_name, relation_id + ) + + if not relation.app: + raise DataInterfacesError("Relation's application missing") + + return relation + + def fetch_relation_data( + self, + relation_ids: Optional[List[int]] = None, + fields: Optional[List[str]] = None, + relation_name: Optional[str] = None, + ) -> Dict[int, Dict[str, str]]: """Retrieves data from relation. This function can be used to retrieve data from a relation in the charm code when outside an event callback. + Function cannot be used in `*-relation-broken` events and will raise an exception. Returns: a dict of the values stored in the relation data bag - for all relation instances (indexed by the relation id). + for all relation instances (indexed by the relation ID). """ + if not relation_name: + relation_name = self.relation_name + + relations = [] + if relation_ids: + relations = [ + self.get_relation(relation_name, relation_id) for relation_id in relation_ids + ] + else: + relations = self.relations + data = {} - for relation in self.relations: - data[relation.id] = { - key: value for key, value in relation.data[relation.app].items() if key != "data" - } + for relation in relations: + if not relation_ids or (relation_ids and relation.id in relation_ids): + data[relation.id] = self._fetch_specific_relation_data(relation, fields) return data - def _update_relation_data(self, relation_id: int, data: dict) -> None: - """Updates a set of key-value pairs in the relation. + def fetch_relation_field( + self, relation_id: int, field: str, relation_name: Optional[str] = None + ) -> Optional[str]: + """Get a single field from the relation data.""" + return ( + self.fetch_relation_data([relation_id], [field], relation_name) + .get(relation_id, {}) + .get(field) + ) - This function writes in the application data bag, therefore, - only the leader unit can call it. + @leader_only + def fetch_my_relation_data( + self, + relation_ids: Optional[List[int]] = None, + fields: Optional[List[str]] = None, + relation_name: Optional[str] = None, + ) -> Optional[Dict[int, Dict[str, str]]]: + """Fetch data of the 'owner' (or 'this app') side of the relation. + + NOTE: Since only the leader can read the relation's 'this_app'-side + Application databag, the functionality is limited to leaders + """ + if not relation_name: + relation_name = self.relation_name + + relations = [] + if relation_ids: + relations = [ + self.get_relation(relation_name, relation_id) for relation_id in relation_ids + ] + else: + relations = self.relations + + data = {} + for relation in relations: + if not relation_ids or relation.id in relation_ids: + data[relation.id] = self._fetch_my_specific_relation_data(relation, fields) + return data + + @leader_only + def fetch_my_relation_field( + self, relation_id: int, field: str, relation_name: Optional[str] = None + ) -> Optional[str]: + """Get a single field from the relation data -- owner side. + + NOTE: Since only the leader can read the relation's 'this_app'-side + Application databag, the functionality is limited to leaders + """ + if relation_data := self.fetch_my_relation_data([relation_id], [field], relation_name): + return relation_data.get(relation_id, {}).get(field) + + @leader_only + def update_relation_data(self, relation_id: int, data: dict) -> None: + """Update the data within the relation.""" + relation_name = self.relation_name + relation = self.get_relation(relation_name, relation_id) + return self._update_relation_data(relation, data) + + @leader_only + def delete_relation_data(self, relation_id: int, fields: List[str]) -> None: + """Remove field from the relation.""" + relation_name = self.relation_name + relation = self.get_relation(relation_name, relation_id) + return self._delete_relation_data(relation, fields) + + +# Base DataProvides and DataRequires + + +class DataProvides(DataRelation): + """Base provides-side of the data products relation.""" + + def __init__(self, charm: CharmBase, relation_name: str) -> None: + super().__init__(charm, relation_name) + + def _diff(self, event: RelationChangedEvent) -> Diff: + """Retrieves the diff of the data in the relation changed databag. Args: - relation_id: the identifier for a particular relation. - data: dict containing the key-value pairs - that should be updated in the relation. + event: relation changed event. + + Returns: + a Diff instance containing the added, deleted and changed + keys from the event relation databag. """ - if self.local_unit.is_leader(): - relation = self.charm.model.get_relation(self.relation_name, relation_id) - relation.data[self.local_app].update(data) + return diff(event, self.local_app) - @property - def relations(self) -> List[Relation]: - """The list of Relation instances associated with this relation_name.""" - return list(self.charm.model.relations[self.relation_name]) + # Private methods handling secrets + + @juju_secrets_only + def _add_relation_secret( + self, relation: Relation, content: Dict[str, str], group_mapping: SecretGroup + ) -> Optional[Secret]: + """Add a new Juju Secret that will be registered in the relation databag.""" + secret_field = self._generate_secret_field_name(group_mapping) + if relation.data[self.local_app].get(secret_field): + logging.error("Secret for relation %s already exists, not adding again", relation.id) + return + + label = self._generate_secret_label(self.relation_name, relation.id, group_mapping) + secret = self.secrets.add(label, content, relation) + + # According to lint we may not have a Secret ID + if secret.meta and secret.meta.id: + relation.data[self.local_app][secret_field] = secret.meta.id + + @juju_secrets_only + def _update_relation_secret( + self, relation: Relation, content: Dict[str, str], group_mapping: SecretGroup + ): + """Update the contents of an existing Juju Secret, referred in the relation databag.""" + secret = self._get_relation_secret(relation.id, group_mapping) + + if not secret: + logging.error("Can't update secret for relation %s", relation.id) + return + + old_content = secret.get_content() + full_content = copy.deepcopy(old_content) + full_content.update(content) + secret.set_content(full_content) + + def _add_or_update_relation_secrets( + self, + relation: Relation, + group: SecretGroup, + secret_fields: Set[str], + data: Dict[str, str], + ) -> None: + """Update contents for Secret group. If the Secret doesn't exist, create it.""" + secret_content = self._content_for_secret_group(data, secret_fields, group) + if self._get_relation_secret(relation.id, group): + self._update_relation_secret(relation, secret_content, group) + else: + self._add_relation_secret(relation, secret_content, group) + + @juju_secrets_only + def _delete_relation_secret( + self, relation: Relation, group: SecretGroup, secret_fields: List[str], fields: List[str] + ): + """Update the contents of an existing Juju Secret, referred in the relation databag.""" + secret = self._get_relation_secret(relation.id, group) + + if not secret: + logging.error("Can't update secret for relation %s", relation.id) + return + + old_content = secret.get_content() + new_content = copy.deepcopy(old_content) + for field in fields: + new_content.pop(field) + secret.set_content(new_content) + + if not new_content: + field = self._generate_secret_field_name(group) + relation.data[self.local_app].pop(field) + + # Mandatory internal overrides + + @juju_secrets_only + def _get_relation_secret( + self, relation_id: int, group_mapping: SecretGroup, relation_name: Optional[str] = None + ) -> Optional[CachedSecret]: + """Retrieve a Juju Secret that's been stored in the relation databag.""" + if not relation_name: + relation_name = self.relation_name + + label = self._generate_secret_label(relation_name, relation_id, group_mapping) + if secret := self.secrets.get(label): + return secret + + relation = self.charm.model.get_relation(relation_name, relation_id) + if not relation: + return + + secret_field = self._generate_secret_field_name(group_mapping) + if secret_uri := relation.data[self.local_app].get(secret_field): + return self.secrets.get(label, secret_uri) + + def _fetch_specific_relation_data( + self, relation: Relation, fields: Optional[List[str]] + ) -> Dict[str, str]: + """Fetching relation data for Provides. + + NOTE: Since all secret fields are in the Provides side of the databag, we don't need to worry about that + """ + if not relation.app: + return {} + + return self._fetch_relation_data_without_secrets(relation.app, relation, fields) + + def _fetch_my_specific_relation_data( + self, relation: Relation, fields: Optional[List[str]] + ) -> dict: + """Fetching our own relation data.""" + secret_fields = None + if relation.app: + secret_fields = get_encoded_list(relation, relation.app, REQ_SECRET_FIELDS) + + return self._fetch_relation_data_with_secrets( + self.local_app, + secret_fields, + relation, + fields, + ) + + def _update_relation_data(self, relation: Relation, data: Dict[str, str]) -> None: + """Set values for fields not caring whether it's a secret or not.""" + req_secret_fields = [] + if relation.app: + req_secret_fields = get_encoded_list(relation, relation.app, REQ_SECRET_FIELDS) + + _, normal_fields = self._process_secret_fields( + relation, + req_secret_fields, + list(data), + self._add_or_update_relation_secrets, + data=data, + ) + + normal_content = {k: v for k, v in data.items() if k in normal_fields} + self._update_relation_data_without_secrets(self.local_app, relation, normal_content) + + def _delete_relation_data(self, relation: Relation, fields: List[str]) -> None: + """Delete fields from the Relation not caring whether it's a secret or not.""" + req_secret_fields = [] + if relation.app: + req_secret_fields = get_encoded_list(relation, relation.app, REQ_SECRET_FIELDS) + + _, normal_fields = self._process_secret_fields( + relation, req_secret_fields, fields, self._delete_relation_secret, fields=fields + ) + self._delete_relation_data_without_secrets(self.local_app, relation, list(normal_fields)) + + # Public methods - "native" def set_credentials(self, relation_id: int, username: str, password: str) -> None: """Set credentials. @@ -445,13 +1154,7 @@ def set_credentials(self, relation_id: int, username: str, password: str) -> Non username: user that was created. password: password of the created user. """ - self._update_relation_data( - relation_id, - { - "username": username, - "password": password, - }, - ) + self.update_relation_data(relation_id, {"username": username, "password": password}) def set_tls(self, relation_id: int, tls: str) -> None: """Set whether TLS is enabled. @@ -460,7 +1163,7 @@ def set_tls(self, relation_id: int, tls: str) -> None: relation_id: the identifier for a particular relation. tls: whether tls is enabled (True or False). """ - self._update_relation_data(relation_id, {"tls": tls}) + self.update_relation_data(relation_id, {"tls": tls}) def set_tls_ca(self, relation_id: int, tls_ca: str) -> None: """Set the TLS CA in the application relation databag. @@ -469,108 +1172,98 @@ def set_tls_ca(self, relation_id: int, tls_ca: str) -> None: relation_id: the identifier for a particular relation. tls_ca: TLS certification authority. """ - self._update_relation_data(relation_id, {"tls-ca": tls_ca}) + self.update_relation_data(relation_id, {"tls-ca": tls_ca}) -class DataRequires(Object, ABC): +class DataRequires(DataRelation): """Requires-side of the relation.""" + SECRET_FIELDS = ["username", "password", "tls", "tls-ca", "uris"] + def __init__( self, charm, relation_name: str, - extra_user_roles: str = None, + extra_user_roles: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], ): """Manager of base client relations.""" super().__init__(charm, relation_name) - self.charm = charm self.extra_user_roles = extra_user_roles - self.local_app = self.charm.model.app - self.local_unit = self.charm.unit - self.relation_name = relation_name + self._secret_fields = list(self.SECRET_FIELDS) + if additional_secret_fields: + self._secret_fields += additional_secret_fields + self.framework.observe( - self.charm.on[relation_name].relation_joined, self._on_relation_joined_event + self.charm.on[relation_name].relation_created, self._on_relation_created_event ) self.framework.observe( - self.charm.on[relation_name].relation_changed, self._on_relation_changed_event + charm.on.secret_changed, + self._on_secret_changed_event, ) - @abstractmethod - def _on_relation_joined_event(self, event: RelationJoinedEvent) -> None: - """Event emitted when the application joins the relation.""" - raise NotImplementedError - - @abstractmethod - def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: - raise NotImplementedError + @property + def secret_fields(self) -> Optional[List[str]]: + """Local access to secrets field, in case they are being used.""" + if self.secrets_enabled: + return self._secret_fields - def fetch_relation_data(self) -> dict: - """Retrieves data from relation. + def _diff(self, event: RelationChangedEvent) -> Diff: + """Retrieves the diff of the data in the relation changed databag. - This function can be used to retrieve data from a relation - in the charm code when outside an event callback. - Function cannot be used in `*-relation-broken` events and will raise an exception. + Args: + event: relation changed event. Returns: - a dict of the values stored in the relation data bag - for all relation instances (indexed by the relation ID). + a Diff instance containing the added, deleted and changed + keys from the event relation databag. """ - data = {} - for relation in self.relations: - data[relation.id] = { - key: value for key, value in relation.data[relation.app].items() if key != "data" - } - return data + return diff(event, self.local_unit) - def _update_relation_data(self, relation_id: int, data: dict) -> None: - """Updates a set of key-value pairs in the relation. + # Internal helper functions - This function writes in the application data bag, therefore, - only the leader unit can call it. + def _register_secret_to_relation( + self, relation_name: str, relation_id: int, secret_id: str, group: SecretGroup + ): + """Fetch secrets and apply local label on them. - Args: - relation_id: the identifier for a particular relation. - data: dict containing the key-value pairs - that should be updated in the relation. + [MAGIC HERE] + If we fetch a secret using get_secret(id=, label=), + then will be "stuck" on the Secret object, whenever it may + appear (i.e. as an event attribute, or fetched manually) on future occasions. + + This will allow us to uniquely identify the secret on Provides side (typically on + 'secret-changed' events), and map it to the corresponding relation. """ - if self.local_unit.is_leader(): - relation = self.charm.model.get_relation(self.relation_name, relation_id) - relation.data[self.local_app].update(data) + label = self._generate_secret_label(relation_name, relation_id, group) - def _diff(self, event: RelationChangedEvent) -> Diff: - """Retrieves the diff of the data in the relation changed databag. + # Fetchin the Secret's meta information ensuring that it's locally getting registered with + CachedSecret(self.charm, label, secret_id).meta - Args: - event: relation changed event. + def _register_secrets_to_relation(self, relation: Relation, params_name_list: List[str]): + """Make sure that secrets of the provided list are locally 'registered' from the databag. - Returns: - a Diff instance containing the added, deleted and changed - keys from the event relation databag. + More on 'locally registered' magic is described in _register_secret_to_relation() method """ - return diff(event, self.local_unit) + if not relation.app: + return - @property - def relations(self) -> List[Relation]: - """The list of Relation instances associated with this relation_name.""" - return [ - relation - for relation in self.charm.model.relations[self.relation_name] - if self._is_relation_active(relation) - ] + for group in SecretGroup: + secret_field = self._generate_secret_field_name(group) + if secret_field in params_name_list: + if secret_uri := relation.data[relation.app].get(secret_field): + self._register_secret_to_relation( + relation.name, relation.id, secret_uri, group + ) - @staticmethod - def _is_relation_active(relation: Relation): - try: - _ = repr(relation.data) - return True - except RuntimeError: + def _is_resource_created_for_relation(self, relation: Relation) -> bool: + if not relation.app: return False - @staticmethod - def _is_resource_created_for_relation(relation: Relation): - return ( - "username" in relation.data[relation.app] and "password" in relation.data[relation.app] + data = self.fetch_relation_data([relation.id], ["username", "password"]).get( + relation.id, {} ) + return bool(data.get("username")) and bool(data.get("password")) def is_resource_created(self, relation_id: Optional[int] = None) -> bool: """Check if the resource has been created. @@ -599,15 +1292,81 @@ def is_resource_created(self, relation_id: Optional[int] = None) -> bool: else: return ( all( - [ - self._is_resource_created_for_relation(relation) - for relation in self.relations - ] + self._is_resource_created_for_relation(relation) for relation in self.relations ) if self.relations else False ) + # Event handlers + + def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: + """Event emitted when the relation is created.""" + if not self.local_unit.is_leader(): + return + + if self.secret_fields: + set_encoded_field( + event.relation, self.charm.app, REQ_SECRET_FIELDS, self.secret_fields + ) + + @abstractmethod + def _on_secret_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation data has changed.""" + raise NotImplementedError + + # Mandatory internal overrides + + @juju_secrets_only + def _get_relation_secret( + self, relation_id: int, group: SecretGroup, relation_name: Optional[str] = None + ) -> Optional[CachedSecret]: + """Retrieve a Juju Secret that's been stored in the relation databag.""" + if not relation_name: + relation_name = self.relation_name + + label = self._generate_secret_label(relation_name, relation_id, group) + return self.secrets.get(label) + + def _fetch_specific_relation_data( + self, relation, fields: Optional[List[str]] = None + ) -> Dict[str, str]: + """Fetching Requires data -- that may include secrets.""" + if not relation.app: + return {} + return self._fetch_relation_data_with_secrets( + relation.app, self.secret_fields, relation, fields + ) + + def _fetch_my_specific_relation_data(self, relation, fields: Optional[List[str]]) -> dict: + """Fetching our own relation data.""" + return self._fetch_relation_data_without_secrets(self.local_app, relation, fields) + + def _update_relation_data(self, relation: Relation, data: dict) -> None: + """Updates a set of key-value pairs in the relation. + + This function writes in the application data bag, therefore, + only the leader unit can call it. + + Args: + relation: the particular relation. + data: dict containing the key-value pairs + that should be updated in the relation. + """ + return self._update_relation_data_without_secrets(self.local_app, relation, data) + + def _delete_relation_data(self, relation: Relation, fields: List[str]) -> None: + """Deletes a set of fields from the relation. + + This function writes in the application data bag, therefore, + only the leader unit can call it. + + Args: + relation: the particular relation. + fields: list containing the field names that should be removed from the relation. + """ + return self._delete_relation_data_without_secrets(self.local_app, relation, fields) + # General events @@ -618,30 +1377,108 @@ class ExtraRoleEvent(RelationEvent): @property def extra_user_roles(self) -> Optional[str]: """Returns the extra user roles that were requested.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("extra-user-roles") class AuthenticationEvent(RelationEvent): - """Base class for authentication fields for events.""" + """Base class for authentication fields for events. + + The amount of logic added here is not ideal -- but this was the only way to preserve + the interface when moving to Juju Secrets + """ + + @property + def _secrets(self) -> dict: + """Caching secrets to avoid fetching them each time a field is referrd. + + DON'T USE the encapsulated helper variable outside of this function + """ + if not hasattr(self, "_cached_secrets"): + self._cached_secrets = {} + return self._cached_secrets + + @property + def _jujuversion(self) -> JujuVersion: + """Caching jujuversion to avoid a Juju call on each field evaluation. + + DON'T USE the encapsulated helper variable outside of this function + """ + if not hasattr(self, "_cached_jujuversion"): + self._cached_jujuversion = None + if not self._cached_jujuversion: + self._cached_jujuversion = JujuVersion.from_environ() + return self._cached_jujuversion + + def _get_secret(self, group) -> Optional[Dict[str, str]]: + """Retrieveing secrets.""" + if not self.app: + return + if not self._secrets.get(group): + self._secrets[group] = None + secret_field = f"{PROV_SECRET_PREFIX}{group}" + if secret_uri := self.relation.data[self.app].get(secret_field): + secret = self.framework.model.get_secret(id=secret_uri) + self._secrets[group] = secret.get_content() + return self._secrets[group] + + @property + def secrets_enabled(self): + """Is this Juju version allowing for Secrets usage?""" + return self._jujuversion.has_secrets @property def username(self) -> Optional[str]: """Returns the created username.""" + if not self.relation.app: + return None + + if self.secrets_enabled: + secret = self._get_secret("user") + if secret: + return secret.get("username") + return self.relation.data[self.relation.app].get("username") @property def password(self) -> Optional[str]: """Returns the password for the created user.""" + if not self.relation.app: + return None + + if self.secrets_enabled: + secret = self._get_secret("user") + if secret: + return secret.get("password") + return self.relation.data[self.relation.app].get("password") @property def tls(self) -> Optional[str]: """Returns whether TLS is configured.""" + if not self.relation.app: + return None + + if self.secrets_enabled: + secret = self._get_secret("tls") + if secret: + return secret.get("tls") + return self.relation.data[self.relation.app].get("tls") @property def tls_ca(self) -> Optional[str]: """Returns TLS CA.""" + if not self.relation.app: + return None + + if self.secrets_enabled: + secret = self._get_secret("tls") + if secret: + return secret.get("tls-ca") + return self.relation.data[self.relation.app].get("tls-ca") @@ -654,6 +1491,9 @@ class DatabaseProvidesEvent(RelationEvent): @property def database(self) -> Optional[str]: """Returns the database that was requested.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("database") @@ -676,6 +1516,9 @@ class DatabaseRequiresEvent(RelationEvent): @property def database(self) -> Optional[str]: """Returns the database name.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("database") @property @@ -685,6 +1528,9 @@ def endpoints(self) -> Optional[str]: In VM charms, this is the primary's address. In kubernetes charms, this is the service to the primary pod. """ + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("endpoints") @property @@ -694,6 +1540,9 @@ def read_only_endpoints(self) -> Optional[str]: In VM charms, this is the address of all the secondary instances. In kubernetes charms, this is the service to all replica pod instances. """ + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("read-only-endpoints") @property @@ -702,6 +1551,9 @@ def replset(self) -> Optional[str]: MongoDB only. """ + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("replset") @property @@ -710,6 +1562,9 @@ def uris(self) -> Optional[str]: MongoDB, Redis, OpenSearch. """ + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("uris") @property @@ -718,6 +1573,9 @@ def version(self) -> Optional[str]: Version as informed by the database daemon. """ + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("version") @@ -750,24 +1608,25 @@ class DatabaseRequiresEvents(CharmEvents): class DatabaseProvides(DataProvides): """Provider-side of the database relations.""" - on = DatabaseProvidesEvents() + on = DatabaseProvidesEvents() # pyright: ignore [reportGeneralTypeIssues] def __init__(self, charm: CharmBase, relation_name: str) -> None: super().__init__(charm, relation_name) - def _on_relation_changed(self, event: RelationChangedEvent) -> None: + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the relation has changed.""" - # Only the leader should handle this event. + # Leader only if not self.local_unit.is_leader(): return - # Check which data has changed to emit customs events. diff = self._diff(event) # Emit a database requested event if the setup key (database name and optional # extra user roles) was added to the relation databag by the application. if "database" in diff.added: - self.on.database_requested.emit(event.relation, app=event.app, unit=event.unit) + getattr(self.on, "database_requested").emit( + event.relation, app=event.app, unit=event.unit + ) def set_database(self, relation_id: int, database_name: str) -> None: """Set database name. @@ -779,7 +1638,7 @@ def set_database(self, relation_id: int, database_name: str) -> None: relation_id: the identifier for a particular relation. database_name: database name. """ - self._update_relation_data(relation_id, {"database": database_name}) + self.update_relation_data(relation_id, {"database": database_name}) def set_endpoints(self, relation_id: int, connection_strings: str) -> None: """Set database primary connections. @@ -795,7 +1654,7 @@ def set_endpoints(self, relation_id: int, connection_strings: str) -> None: relation_id: the identifier for a particular relation. connection_strings: database hosts and ports comma separated list. """ - self._update_relation_data(relation_id, {"endpoints": connection_strings}) + self.update_relation_data(relation_id, {"endpoints": connection_strings}) def set_read_only_endpoints(self, relation_id: int, connection_strings: str) -> None: """Set database replicas connection strings. @@ -807,7 +1666,7 @@ def set_read_only_endpoints(self, relation_id: int, connection_strings: str) -> relation_id: the identifier for a particular relation. connection_strings: database hosts and ports comma separated list. """ - self._update_relation_data(relation_id, {"read-only-endpoints": connection_strings}) + self.update_relation_data(relation_id, {"read-only-endpoints": connection_strings}) def set_replset(self, relation_id: int, replset: str) -> None: """Set replica set name in the application relation databag. @@ -818,7 +1677,7 @@ def set_replset(self, relation_id: int, replset: str) -> None: relation_id: the identifier for a particular relation. replset: replica set name. """ - self._update_relation_data(relation_id, {"replset": replset}) + self.update_relation_data(relation_id, {"replset": replset}) def set_uris(self, relation_id: int, uris: str) -> None: """Set the database connection URIs in the application relation databag. @@ -829,7 +1688,7 @@ def set_uris(self, relation_id: int, uris: str) -> None: relation_id: the identifier for a particular relation. uris: connection URIs. """ - self._update_relation_data(relation_id, {"uris": uris}) + self.update_relation_data(relation_id, {"uris": uris}) def set_version(self, relation_id: int, version: str) -> None: """Set the database version in the application relation databag. @@ -838,24 +1697,25 @@ def set_version(self, relation_id: int, version: str) -> None: relation_id: the identifier for a particular relation. version: database version. """ - self._update_relation_data(relation_id, {"version": version}) + self.update_relation_data(relation_id, {"version": version}) class DatabaseRequires(DataRequires): """Requires-side of the database relation.""" - on = DatabaseRequiresEvents() + on = DatabaseRequiresEvents() # pyright: ignore [reportGeneralTypeIssues] def __init__( self, charm, relation_name: str, database_name: str, - extra_user_roles: str = None, - relations_aliases: List[str] = None, + extra_user_roles: Optional[str] = None, + relations_aliases: Optional[List[str]] = None, + additional_secret_fields: Optional[List[str]] = [], ): """Manager of database client relations.""" - super().__init__(charm, relation_name, extra_user_roles) + super().__init__(charm, relation_name, extra_user_roles, additional_secret_fields) self.database = database_name self.relations_aliases = relations_aliases @@ -880,6 +1740,10 @@ def __init__( DatabaseReadOnlyEndpointsChangedEvent, ) + def _on_secret_changed_event(self, event: SecretChangedEvent): + """Event notifying about a new value of a secret.""" + pass + def _assign_relation_alias(self, relation_id: int) -> None: """Assigns an alias to a relation. @@ -894,11 +1758,8 @@ def _assign_relation_alias(self, relation_id: int) -> None: # Return if an alias was already assigned to this relation # (like when there are more than one unit joining the relation). - if ( - self.charm.model.get_relation(self.relation_name, relation_id) - .data[self.local_unit] - .get("alias") - ): + relation = self.charm.model.get_relation(self.relation_name, relation_id) + if relation and relation.data[self.local_unit].get("alias"): return # Retrieve the available aliases (the ones that weren't assigned to any relation). @@ -911,7 +1772,12 @@ def _assign_relation_alias(self, relation_id: int) -> None: # Set the alias in the unit relation databag of the specific relation. relation = self.charm.model.get_relation(self.relation_name, relation_id) - relation.data[self.local_unit].update({"alias": available_aliases[0]}) + if relation: + relation.data[self.local_unit].update({"alias": available_aliases[0]}) + + # We need to set relation alias also on the application level so, + # it will be accessible in show-unit juju command, executed for a consumer application unit + self.update_relation_data(relation_id, {"alias": available_aliases[0]}) def _emit_aliased_event(self, event: RelationChangedEvent, event_name: str) -> None: """Emit an aliased event to a particular relation if it has an alias. @@ -958,23 +1824,30 @@ def is_postgresql_plugin_enabled(self, plugin: str, relation_index: int = 0) -> if len(self.relations) == 0: return False - relation_data = self.fetch_relation_data()[self.relations[relation_index].id] - host = relation_data.get("endpoints") + relation_id = self.relations[relation_index].id + host = self.fetch_relation_field(relation_id, "endpoints") # Return False if there is no endpoint available. if host is None: return False host = host.split(":")[0] - user = relation_data.get("username") - password = relation_data.get("password") + + content = self.fetch_relation_data([relation_id], ["username", "password"]).get( + relation_id, {} + ) + user = content.get("username") + password = content.get("password") + connection_string = ( f"host='{host}' dbname='{self.database}' user='{user}' password='{password}'" ) try: with psycopg.connect(connection_string) as connection: with connection.cursor() as cursor: - cursor.execute(f"SELECT TRUE FROM pg_extension WHERE extname='{plugin}';") + cursor.execute( + "SELECT TRUE FROM pg_extension WHERE extname=%s::text;", (plugin,) + ) return cursor.fetchone() is not None except psycopg.Error as e: logger.exception( @@ -982,15 +1855,17 @@ def is_postgresql_plugin_enabled(self, plugin: str, relation_index: int = 0) -> ) return False - def _on_relation_joined_event(self, event: RelationJoinedEvent) -> None: - """Event emitted when the application joins the database relation.""" + def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: + """Event emitted when the database relation is created.""" + super()._on_relation_created_event(event) + # If relations aliases were provided, assign one to the relation. self._assign_relation_alias(event.relation.id) # Sets both database and extra user roles in the relation # if the roles are provided. Otherwise, sets only the database. if self.extra_user_roles: - self._update_relation_data( + self.update_relation_data( event.relation.id, { "database": self.database, @@ -998,19 +1873,28 @@ def _on_relation_joined_event(self, event: RelationJoinedEvent) -> None: }, ) else: - self._update_relation_data(event.relation.id, {"database": self.database}) + self.update_relation_data(event.relation.id, {"database": self.database}) def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the database relation has changed.""" # Check which data has changed to emit customs events. diff = self._diff(event) + # Register all new secrets with their labels + if any(newval for newval in diff.added if self._is_secret_field(newval)): + self._register_secrets_to_relation(event.relation, diff.added) + # Check if the database is created # (the database charm shared the credentials). - if "username" in diff.added and "password" in diff.added: + secret_field_user = self._generate_secret_field_name(SecretGroup.USER) + if ( + "username" in diff.added and "password" in diff.added + ) or secret_field_user in diff.added: # Emit the default event (the one without an alias). logger.info("database created at %s", datetime.now()) - self.on.database_created.emit(event.relation, app=event.app, unit=event.unit) + getattr(self.on, "database_created").emit( + event.relation, app=event.app, unit=event.unit + ) # Emit the aliased event (if any). self._emit_aliased_event(event, "database_created") @@ -1024,7 +1908,9 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: if "endpoints" in diff.added or "endpoints" in diff.changed: # Emit the default event (the one without an alias). logger.info("endpoints changed on %s", datetime.now()) - self.on.endpoints_changed.emit(event.relation, app=event.app, unit=event.unit) + getattr(self.on, "endpoints_changed").emit( + event.relation, app=event.app, unit=event.unit + ) # Emit the aliased event (if any). self._emit_aliased_event(event, "endpoints_changed") @@ -1038,7 +1924,7 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: if "read-only-endpoints" in diff.added or "read-only-endpoints" in diff.changed: # Emit the default event (the one without an alias). logger.info("read-only-endpoints changed on %s", datetime.now()) - self.on.read_only_endpoints_changed.emit( + getattr(self.on, "read_only_endpoints_changed").emit( event.relation, app=event.app, unit=event.unit ) @@ -1055,11 +1941,17 @@ class KafkaProvidesEvent(RelationEvent): @property def topic(self) -> Optional[str]: """Returns the topic that was requested.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("topic") @property def consumer_group_prefix(self) -> Optional[str]: """Returns the consumer-group-prefix that was requested.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("consumer-group-prefix") @@ -1082,21 +1974,33 @@ class KafkaRequiresEvent(RelationEvent): @property def topic(self) -> Optional[str]: """Returns the topic.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("topic") @property def bootstrap_server(self) -> Optional[str]: """Returns a comma-separated list of broker uris.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("endpoints") @property def consumer_group_prefix(self) -> Optional[str]: """Returns the consumer-group-prefix.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("consumer-group-prefix") @property def zookeeper_uris(self) -> Optional[str]: """Returns a comma separated list of Zookeeper uris.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("zookeeper-uris") @@ -1124,14 +2028,14 @@ class KafkaRequiresEvents(CharmEvents): class KafkaProvides(DataProvides): """Provider-side of the Kafka relation.""" - on = KafkaProvidesEvents() + on = KafkaProvidesEvents() # pyright: ignore [reportGeneralTypeIssues] def __init__(self, charm: CharmBase, relation_name: str) -> None: super().__init__(charm, relation_name) - def _on_relation_changed(self, event: RelationChangedEvent) -> None: + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the relation has changed.""" - # Only the leader should handle this event. + # Leader only if not self.local_unit.is_leader(): return @@ -1141,7 +2045,9 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: # Emit a topic requested event if the setup key (topic name and optional # extra user roles) was added to the relation databag by the application. if "topic" in diff.added: - self.on.topic_requested.emit(event.relation, app=event.app, unit=event.unit) + getattr(self.on, "topic_requested").emit( + event.relation, app=event.app, unit=event.unit + ) def set_topic(self, relation_id: int, topic: str) -> None: """Set topic name in the application relation databag. @@ -1150,7 +2056,7 @@ def set_topic(self, relation_id: int, topic: str) -> None: relation_id: the identifier for a particular relation. topic: the topic name. """ - self._update_relation_data(relation_id, {"topic": topic}) + self.update_relation_data(relation_id, {"topic": topic}) def set_bootstrap_server(self, relation_id: int, bootstrap_server: str) -> None: """Set the bootstrap server in the application relation databag. @@ -1159,7 +2065,7 @@ def set_bootstrap_server(self, relation_id: int, bootstrap_server: str) -> None: relation_id: the identifier for a particular relation. bootstrap_server: the bootstrap server address. """ - self._update_relation_data(relation_id, {"endpoints": bootstrap_server}) + self.update_relation_data(relation_id, {"endpoints": bootstrap_server}) def set_consumer_group_prefix(self, relation_id: int, consumer_group_prefix: str) -> None: """Set the consumer group prefix in the application relation databag. @@ -1168,7 +2074,7 @@ def set_consumer_group_prefix(self, relation_id: int, consumer_group_prefix: str relation_id: the identifier for a particular relation. consumer_group_prefix: the consumer group prefix string. """ - self._update_relation_data(relation_id, {"consumer-group-prefix": consumer_group_prefix}) + self.update_relation_data(relation_id, {"consumer-group-prefix": consumer_group_prefix}) def set_zookeeper_uris(self, relation_id: int, zookeeper_uris: str) -> None: """Set the zookeeper uris in the application relation databag. @@ -1177,13 +2083,13 @@ def set_zookeeper_uris(self, relation_id: int, zookeeper_uris: str) -> None: relation_id: the identifier for a particular relation. zookeeper_uris: comma-separated list of ZooKeeper server uris. """ - self._update_relation_data(relation_id, {"zookeeper-uris": zookeeper_uris}) + self.update_relation_data(relation_id, {"zookeeper-uris": zookeeper_uris}) class KafkaRequires(DataRequires): """Requires-side of the Kafka relation.""" - on = KafkaRequiresEvents() + on = KafkaRequiresEvents() # pyright: ignore [reportGeneralTypeIssues] def __init__( self, @@ -1192,23 +2098,42 @@ def __init__( topic: str, extra_user_roles: Optional[str] = None, consumer_group_prefix: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], ): """Manager of Kafka client relations.""" # super().__init__(charm, relation_name) - super().__init__(charm, relation_name, extra_user_roles) + super().__init__(charm, relation_name, extra_user_roles, additional_secret_fields) self.charm = charm self.topic = topic self.consumer_group_prefix = consumer_group_prefix or "" - def _on_relation_joined_event(self, event: RelationJoinedEvent) -> None: - """Event emitted when the application joins the Kafka relation.""" + @property + def topic(self): + """Topic to use in Kafka.""" + return self._topic + + @topic.setter + def topic(self, value): + # Avoid wildcards + if value == "*": + raise ValueError(f"Error on topic '{value}', cannot be a wildcard.") + self._topic = value + + def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: + """Event emitted when the Kafka relation is created.""" + super()._on_relation_created_event(event) + # Sets topic, extra user roles, and "consumer-group-prefix" in the relation relation_data = { f: getattr(self, f.replace("-", "_"), "") for f in ["consumer-group-prefix", "extra-user-roles", "topic"] } - self._update_relation_data(event.relation.id, relation_data) + self.update_relation_data(event.relation.id, relation_data) + + def _on_secret_changed_event(self, event: SecretChangedEvent): + """Event notifying about a new value of a secret.""" + pass def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the Kafka relation has changed.""" @@ -1217,10 +2142,18 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: # Check if the topic is created # (the Kafka charm shared the credentials). - if "username" in diff.added and "password" in diff.added: + + # Register all new secrets with their labels + if any(newval for newval in diff.added if self._is_secret_field(newval)): + self._register_secrets_to_relation(event.relation, diff.added) + + secret_field_user = self._generate_secret_field_name(SecretGroup.USER) + if ( + "username" in diff.added and "password" in diff.added + ) or secret_field_user in diff.added: # Emit the default event (the one without an alias). logger.info("topic created at %s", datetime.now()) - self.on.topic_created.emit(event.relation, app=event.app, unit=event.unit) + getattr(self.on, "topic_created").emit(event.relation, app=event.app, unit=event.unit) # To avoid unnecessary application restarts do not trigger # “endpoints_changed“ event if “topic_created“ is triggered. @@ -1231,7 +2164,7 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: if "endpoints" in diff.added or "endpoints" in diff.changed: # Emit the default event (the one without an alias). logger.info("endpoints changed on %s", datetime.now()) - self.on.bootstrap_server_changed.emit( + getattr(self.on, "bootstrap_server_changed").emit( event.relation, app=event.app, unit=event.unit ) # here check if this is the right design return @@ -1246,6 +2179,9 @@ class OpenSearchProvidesEvent(RelationEvent): @property def index(self) -> Optional[str]: """Returns the index that was requested.""" + if not self.relation.app: + return None + return self.relation.data[self.relation.app].get("index") @@ -1287,24 +2223,25 @@ class OpenSearchRequiresEvents(CharmEvents): class OpenSearchProvides(DataProvides): """Provider-side of the OpenSearch relation.""" - on = OpenSearchProvidesEvents() + on = OpenSearchProvidesEvents() # pyright: ignore[reportGeneralTypeIssues] def __init__(self, charm: CharmBase, relation_name: str) -> None: super().__init__(charm, relation_name) - def _on_relation_changed(self, event: RelationChangedEvent) -> None: + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the relation has changed.""" - # Only the leader should handle this event. + # Leader only if not self.local_unit.is_leader(): return - # Check which data has changed to emit customs events. diff = self._diff(event) # Emit an index requested event if the setup key (index name and optional extra user roles) # have been added to the relation databag by the application. if "index" in diff.added: - self.on.index_requested.emit(event.relation, app=event.app, unit=event.unit) + getattr(self.on, "index_requested").emit( + event.relation, app=event.app, unit=event.unit + ) def set_index(self, relation_id: int, index: str) -> None: """Set the index in the application relation databag. @@ -1315,7 +2252,7 @@ def set_index(self, relation_id: int, index: str) -> None: requested index, and can be used to present a different index name if, for example, the requested index is invalid. """ - self._update_relation_data(relation_id, {"index": index}) + self.update_relation_data(relation_id, {"index": index}) def set_endpoints(self, relation_id: int, endpoints: str) -> None: """Set the endpoints in the application relation databag. @@ -1324,7 +2261,7 @@ def set_endpoints(self, relation_id: int, endpoints: str) -> None: relation_id: the identifier for a particular relation. endpoints: the endpoint addresses for opensearch nodes. """ - self._update_relation_data(relation_id, {"endpoints": endpoints}) + self.update_relation_data(relation_id, {"endpoints": endpoints}) def set_version(self, relation_id: int, version: str) -> None: """Set the opensearch version in the application relation databag. @@ -1333,31 +2270,63 @@ def set_version(self, relation_id: int, version: str) -> None: relation_id: the identifier for a particular relation. version: database version. """ - self._update_relation_data(relation_id, {"version": version}) + self.update_relation_data(relation_id, {"version": version}) class OpenSearchRequires(DataRequires): """Requires-side of the OpenSearch relation.""" - on = OpenSearchRequiresEvents() + on = OpenSearchRequiresEvents() # pyright: ignore[reportGeneralTypeIssues] def __init__( - self, charm, relation_name: str, index: str, extra_user_roles: Optional[str] = None + self, + charm, + relation_name: str, + index: str, + extra_user_roles: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], ): """Manager of OpenSearch client relations.""" - super().__init__(charm, relation_name, extra_user_roles) + super().__init__(charm, relation_name, extra_user_roles, additional_secret_fields) self.charm = charm self.index = index - def _on_relation_joined_event(self, event: RelationJoinedEvent) -> None: - """Event emitted when the application joins the OpenSearch relation.""" + def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: + """Event emitted when the OpenSearch relation is created.""" + super()._on_relation_created_event(event) + # Sets both index and extra user roles in the relation if the roles are provided. # Otherwise, sets only the index. data = {"index": self.index} if self.extra_user_roles: data["extra-user-roles"] = self.extra_user_roles - self._update_relation_data(event.relation.id, data) + self.update_relation_data(event.relation.id, data) + + def _on_secret_changed_event(self, event: SecretChangedEvent): + """Event notifying about a new value of a secret.""" + if not event.secret.label: + return + + relation = self._relation_from_secret_label(event.secret.label) + if not relation: + logging.info( + f"Received secret {event.secret.label} but couldn't parse, seems irrelevant" + ) + return + + if relation.app == self.charm.app: + logging.info("Secret changed event ignored for Secret Owner") + + remote_unit = None + for unit in relation.units: + if unit.app != self.charm.app: + remote_unit = unit + + logger.info("authentication updated") + getattr(self.on, "authentication_updated").emit( + relation, app=relation.app, unit=remote_unit + ) def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the OpenSearch relation has changed. @@ -1367,18 +2336,27 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: # Check which data has changed to emit customs events. diff = self._diff(event) - # Check if authentication has updated, emit event if so - updates = {"username", "password", "tls", "tls-ca"} + # Register all new secrets with their labels + if any(newval for newval in diff.added if self._is_secret_field(newval)): + self._register_secrets_to_relation(event.relation, diff.added) + + secret_field_user = self._generate_secret_field_name(SecretGroup.USER) + secret_field_tls = self._generate_secret_field_name(SecretGroup.TLS) + updates = {"username", "password", "tls", "tls-ca", secret_field_user, secret_field_tls} if len(set(diff._asdict().keys()) - updates) < len(diff): logger.info("authentication updated at: %s", datetime.now()) - self.on.authentication_updated.emit(event.relation, app=event.app, unit=event.unit) + getattr(self.on, "authentication_updated").emit( + event.relation, app=event.app, unit=event.unit + ) # Check if the index is created # (the OpenSearch charm shares the credentials). - if "username" in diff.added and "password" in diff.added: + if ( + "username" in diff.added and "password" in diff.added + ) or secret_field_user in diff.added: # Emit the default event (the one without an alias). logger.info("index created at: %s", datetime.now()) - self.on.index_created.emit(event.relation, app=event.app, unit=event.unit) + getattr(self.on, "index_created").emit(event.relation, app=event.app, unit=event.unit) # To avoid unnecessary application restarts do not trigger # “endpoints_changed“ event if “index_created“ is triggered. @@ -1389,7 +2367,7 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: if "endpoints" in diff.added or "endpoints" in diff.changed: # Emit the default event (the one without an alias). logger.info("endpoints changed on %s", datetime.now()) - self.on.endpoints_changed.emit( + getattr(self.on, "endpoints_changed").emit( event.relation, app=event.app, unit=event.unit ) # here check if this is the right design return diff --git a/lib/charms/tls_certificates_interface/v2/tls_certificates.py b/lib/charms/tls_certificates_interface/v2/tls_certificates.py index e6358c5..1cbd3c0 100644 --- a/lib/charms/tls_certificates_interface/v2/tls_certificates.py +++ b/lib/charms/tls_certificates_interface/v2/tls_certificates.py @@ -276,9 +276,10 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven import json import logging import uuid +from contextlib import suppress from datetime import datetime, timedelta from ipaddress import IPv4Address -from typing import Dict, List, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, Union from cryptography import x509 from cryptography.hazmat._oid import ExtensionOID @@ -286,15 +287,18 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.serialization import pkcs12 from cryptography.x509.extensions import Extension, ExtensionNotFound -from jsonschema import exceptions, validate # type: ignore[import] +from jsonschema import exceptions, validate # type: ignore[import-untyped] from ops.charm import ( CharmBase, CharmEvents, RelationBrokenEvent, RelationChangedEvent, + SecretExpiredEvent, UpdateStatusEvent, ) from ops.framework import EventBase, EventSource, Handle, Object +from ops.jujuversion import JujuVersion +from ops.model import Relation, SecretNotFoundError # The unique Charmhub library identifier, never change it LIBID = "afd8c2bccf834997afce12c2706d2ede" @@ -304,7 +308,9 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 5 +LIBPATCH = 18 + +PYDEPS = ["cryptography", "jsonschema"] REQUIRER_JSON_SCHEMA = { "$schema": "http://json-schema.org/draft-04/schema#", @@ -634,6 +640,17 @@ def generate_ca( private_key_object.public_key() # type: ignore[arg-type] ) subject_identifier = key_identifier = subject_identifier_object.public_bytes() + key_usage = x509.KeyUsage( + digital_signature=True, + key_encipherment=True, + key_cert_sign=True, + key_agreement=False, + content_commitment=False, + data_encipherment=False, + crl_sign=False, + encipher_only=False, + decipher_only=False, + ) cert = ( x509.CertificateBuilder() .subject_name(subject) @@ -651,6 +668,7 @@ def generate_ca( ), critical=False, ) + .add_extension(key_usage, critical=True) .add_extension( x509.BasicConstraints(ca=True, path_length=None), critical=True, @@ -683,7 +701,8 @@ def generate_certificate( """ csr_object = x509.load_pem_x509_csr(csr) subject = csr_object.subject - issuer = x509.load_pem_x509_certificate(ca).issuer + ca_pem = x509.load_pem_x509_certificate(ca) + issuer = ca_pem.issuer private_key = serialization.load_pem_private_key(ca_key, password=ca_key_password) certificate_builder = ( @@ -694,6 +713,20 @@ def generate_certificate( .serial_number(x509.random_serial_number()) .not_valid_before(datetime.utcnow()) .not_valid_after(datetime.utcnow() + timedelta(days=validity)) + .add_extension( + x509.AuthorityKeyIdentifier( + key_identifier=ca_pem.extensions.get_extension_for_class( + x509.SubjectKeyIdentifier + ).value.key_identifier, + authority_cert_issuer=None, + authority_cert_serial_number=None, + ), + critical=False, + ) + .add_extension( + x509.SubjectKeyIdentifier.from_public_key(csr_object.public_key()), critical=False + ) + .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=False) ) extensions_list = csr_object.extensions @@ -724,6 +757,7 @@ def generate_certificate( extension.value, critical=extension.critical, ) + certificate_builder._version = x509.Version.v3 cert = certificate_builder.sign(private_key, hashes.SHA256()) # type: ignore[arg-type] return cert.public_bytes(serialization.Encoding.PEM) @@ -821,7 +855,7 @@ def generate_csr( sans_oid (list): List of registered ID SANs sans_dns (list): List of DNS subject alternative names (similar to the arg: sans) sans_ip (list): List of IP subject alternative names - additional_critical_extensions (list): List if critical additional extension objects. + additional_critical_extensions (list): List of critical additional extension objects. Object must be a x509 ExtensionType. Returns: @@ -891,6 +925,22 @@ def __init__(self, charm: CharmBase, relationship_name: str): self.charm = charm self.relationship_name = relationship_name + def _load_app_relation_data(self, relation: Relation) -> dict: + """Loads relation data from the application relation data bag. + + Json loads all data. + + Args: + relation_object: Relation data from the application databag + + Returns: + dict: Relation data in dict format. + """ + # If unit is not leader, it does not try to reach relation data. + if not self.model.unit.is_leader(): + return {} + return _load_relation_data(relation.data[self.charm.app]) + def _add_certificate( self, relation_id: int, @@ -925,7 +975,7 @@ def _add_certificate( "ca": ca, "chain": chain, } - provider_relation_data = _load_relation_data(relation.data[self.charm.app]) + provider_relation_data = self._load_app_relation_data(relation) provider_certificates = provider_relation_data.get("certificates", []) certificates = copy.deepcopy(provider_certificates) if new_certificate in certificates: @@ -958,7 +1008,7 @@ def _remove_certificate( raise RuntimeError( f"Relation {self.relationship_name} with relation id {relation_id} does not exist" ) - provider_relation_data = _load_relation_data(relation.data[self.charm.app]) + provider_relation_data = self._load_app_relation_data(relation) provider_certificates = provider_relation_data.get("certificates", []) certificates = copy.deepcopy(provider_certificates) for certificate_dict in certificates: @@ -993,7 +1043,7 @@ def revoke_all_certificates(self) -> None: This method is meant to be used when the Root CA has changed. """ for relation in self.model.relations[self.relationship_name]: - provider_relation_data = _load_relation_data(relation.data[self.charm.app]) + provider_relation_data = self._load_app_relation_data(relation) provider_certificates = copy.deepcopy(provider_relation_data.get("certificates", [])) for certificate in provider_certificates: certificate["revoked"] = True @@ -1053,6 +1103,43 @@ def remove_certificate(self, certificate: str) -> None: for certificate_relation in certificates_relation: self._remove_certificate(certificate=certificate, relation_id=certificate_relation.id) + def get_issued_certificates( + self, relation_id: Optional[int] = None + ) -> Dict[str, List[Dict[str, str]]]: + """Returns a dictionary of issued certificates. + + It returns certificates from all relations if relation_id is not specified. + Certificates are returned per application name and CSR. + + Returns: + dict: Certificates per application name. + """ + certificates: Dict[str, List[Dict[str, str]]] = {} + relations = ( + [ + relation + for relation in self.model.relations[self.relationship_name] + if relation.id == relation_id + ] + if relation_id is not None + else self.model.relations.get(self.relationship_name, []) + ) + for relation in relations: + provider_relation_data = self._load_app_relation_data(relation) + provider_certificates = provider_relation_data.get("certificates", []) + + certificates[relation.app.name] = [] # type: ignore[union-attr] + for certificate in provider_certificates: + if not certificate.get("revoked", False): + certificates[relation.app.name].append( # type: ignore[union-attr] + { + "csr": certificate["certificate_signing_request"], + "certificate": certificate["certificate"], + } + ) + + return certificates + def _on_relation_changed(self, event: RelationChangedEvent) -> None: """Handler triggered on relation changed event. @@ -1068,9 +1155,13 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: Returns: None """ - assert event.unit is not None + if event.unit is None: + logger.error("Relation_changed event does not have a unit.") + return + if not self.model.unit.is_leader(): + return requirer_relation_data = _load_relation_data(event.relation.data[event.unit]) - provider_relation_data = _load_relation_data(event.relation.data[self.charm.app]) + provider_relation_data = self._load_app_relation_data(event.relation) if not self._relation_data_is_valid(requirer_relation_data): logger.debug("Relation data did not pass JSON Schema validation") return @@ -1109,7 +1200,7 @@ def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None ) if not certificates_relation: raise RuntimeError(f"Relation {self.relationship_name} does not exist") - provider_relation_data = _load_relation_data(certificates_relation.data[self.charm.app]) + provider_relation_data = self._load_app_relation_data(certificates_relation) list_of_csrs: List[str] = [] for unit in certificates_relation.units: requirer_relation_data = _load_relation_data(certificates_relation.data[unit]) @@ -1126,6 +1217,90 @@ def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None ) self.remove_certificate(certificate=certificate["certificate"]) + def get_requirer_csrs_with_no_certs( + self, relation_id: Optional[int] = None + ) -> List[Dict[str, Union[int, str, List[Dict[str, str]]]]]: + """Filters the requirer's units csrs. + + Keeps the ones for which no certificate was provided. + + Args: + relation_id (int): Relation id + + Returns: + list: List of dictionaries that contain the unit's csrs + that don't have a certificate issued. + """ + all_unit_csr_mappings = copy.deepcopy(self.get_requirer_csrs(relation_id=relation_id)) + filtered_all_unit_csr_mappings: List[Dict[str, Union[int, str, List[Dict[str, str]]]]] = [] + for unit_csr_mapping in all_unit_csr_mappings: + csrs_without_certs = [] + for csr in unit_csr_mapping["unit_csrs"]: # type: ignore[union-attr] + if not self.certificate_issued_for_csr( + app_name=unit_csr_mapping["application_name"], # type: ignore[arg-type] + csr=csr["certificate_signing_request"], # type: ignore[index] + ): + csrs_without_certs.append(csr) + if csrs_without_certs: + unit_csr_mapping["unit_csrs"] = csrs_without_certs # type: ignore[assignment] + filtered_all_unit_csr_mappings.append(unit_csr_mapping) + return filtered_all_unit_csr_mappings + + def get_requirer_csrs( + self, relation_id: Optional[int] = None + ) -> List[Dict[str, Union[int, str, List[Dict[str, str]]]]]: + """Returns a list of requirers' CSRs grouped by unit. + + It returns CSRs from all relations if relation_id is not specified. + CSRs are returned per relation id, application name and unit name. + + Returns: + list: List of dictionaries that contain the unit's csrs + with the following information + relation_id, application_name and unit_name. + """ + unit_csr_mappings: List[Dict[str, Union[int, str, List[Dict[str, str]]]]] = [] + + relations = ( + [ + relation + for relation in self.model.relations[self.relationship_name] + if relation.id == relation_id + ] + if relation_id is not None + else self.model.relations.get(self.relationship_name, []) + ) + + for relation in relations: + for unit in relation.units: + requirer_relation_data = _load_relation_data(relation.data[unit]) + unit_csrs_list = requirer_relation_data.get("certificate_signing_requests", []) + unit_csr_mappings.append( + { + "relation_id": relation.id, + "application_name": relation.app.name, # type: ignore[union-attr] + "unit_name": unit.name, + "unit_csrs": unit_csrs_list, + } + ) + return unit_csr_mappings + + def certificate_issued_for_csr(self, app_name: str, csr: str) -> bool: + """Checks whether a certificate has been issued for a given CSR. + + Args: + app_name (str): Application name that the CSR belongs to. + csr (str): Certificate Signing Request. + + Returns: + bool: True/False depending on whether a certificate has been issued for the given CSR. + """ + issued_certificates_per_csr = self.get_issued_certificates()[app_name] + for issued_pair in issued_certificates_per_csr: + if "csr" in issued_pair and issued_pair["csr"] == csr: + return csr_matches_certificate(csr, issued_pair["certificate"]) + return False + class TLSCertificatesRequiresV2(Object): """TLS certificates requirer class to be instantiated by TLS certificates requirers.""" @@ -1156,11 +1331,14 @@ def __init__( self.framework.observe( charm.on[relationship_name].relation_broken, self._on_relation_broken ) - self.framework.observe(charm.on.update_status, self._on_update_status) + if JujuVersion.from_environ().has_secrets: + self.framework.observe(charm.on.secret_expired, self._on_secret_expired) + else: + self.framework.observe(charm.on.update_status, self._on_update_status) @property def _requirer_csrs(self) -> List[Dict[str, str]]: - """Returns list of requirer CSR's from relation data.""" + """Returns list of requirer's CSRs from relation data.""" relation = self.model.get_relation(self.relationship_name) if not relation: raise RuntimeError(f"Relation {self.relationship_name} does not exist") @@ -1169,13 +1347,18 @@ def _requirer_csrs(self) -> List[Dict[str, str]]: @property def _provider_certificates(self) -> List[Dict[str, str]]: - """Returns list of provider CSRs from relation data.""" + """Returns list of certificates from the provider's relation data.""" relation = self.model.get_relation(self.relationship_name) if not relation: - raise RuntimeError(f"Relation {self.relationship_name} does not exist") + logger.debug("No relation: %s", self.relationship_name) + return [] if not relation.app: - raise RuntimeError(f"Remote app for relation {self.relationship_name} does not exist") + logger.debug("No remote app in relation: %s", self.relationship_name) + return [] provider_relation_data = _load_relation_data(relation.data[relation.app]) + if not self._relation_data_is_valid(provider_relation_data): + logger.warning("Provider relation data did not pass JSON Schema validation") + return [] return provider_relation_data.get("certificates", []) def _add_requirer_csr(self, csr: str) -> None: @@ -1302,23 +1485,21 @@ def _relation_data_is_valid(certificates_data: dict) -> bool: def _on_relation_changed(self, event: RelationChangedEvent) -> None: """Handler triggered on relation changed events. + Goes through all providers certificates that match a requested CSR. + + If the provider certificate is revoked, emit a CertificateInvalidateEvent, + otherwise emit a CertificateAvailableEvent. + + When Juju secrets are available, remove the secret for revoked certificate, + or add a secret with the correct expiry time for new certificates. + + Args: event: Juju event Returns: None """ - relation = self.model.get_relation(self.relationship_name) - if not relation: - logger.warning("No relation: %s", self.relationship_name) - return - if not relation.app: - logger.warning("No remote app in relation: %s", self.relationship_name) - return - provider_relation_data = _load_relation_data(relation.data[relation.app]) - if not self._relation_data_is_valid(provider_relation_data): - logger.debug("Provider relation data did not pass JSON Schema validation") - return requirer_csrs = [ certificate_creation_request["certificate_signing_request"] for certificate_creation_request in self._requirer_csrs @@ -1326,6 +1507,12 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: for certificate in self._provider_certificates: if certificate["certificate_signing_request"] in requirer_csrs: if certificate.get("revoked", False): + if JujuVersion.from_environ().has_secrets: + with suppress(SecretNotFoundError): + secret = self.model.get_secret( + label=f"{LIBID}-{certificate['certificate_signing_request']}" + ) + secret.remove_all_revisions() self.on.certificate_invalidated.emit( reason="revoked", certificate=certificate["certificate"], @@ -1334,6 +1521,25 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: chain=certificate["chain"], ) else: + if JujuVersion.from_environ().has_secrets: + try: + secret = self.model.get_secret( + label=f"{LIBID}-{certificate['certificate_signing_request']}" + ) + secret.set_content({"certificate": certificate["certificate"]}) + secret.set_info( + expire=self._get_next_secret_expiry_time( + certificate["certificate"] + ), + ) + except SecretNotFoundError: + secret = self.charm.unit.add_secret( + {"certificate": certificate["certificate"]}, + label=f"{LIBID}-{certificate['certificate_signing_request']}", + expire=self._get_next_secret_expiry_time( + certificate["certificate"] + ), + ) self.on.certificate_available.emit( certificate_signing_request=certificate["certificate_signing_request"], certificate=certificate["certificate"], @@ -1341,6 +1547,26 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: chain=certificate["chain"], ) + def _get_next_secret_expiry_time(self, certificate: str) -> Optional[datetime]: + """Return the expiry time or expiry notification time. + + Extracts the expiry time from the provided certificate, calculates the + expiry notification time and return the closest of the two, that is in + the future. + + Args: + certificate: x509 certificate + + Returns: + Optional[datetime]: None if the certificate expiry time cannot be read, + next expiry time otherwise. + """ + expiry_time = _get_certificate_expiry_time(certificate) + if not expiry_time: + return None + expiry_notification_time = expiry_time - timedelta(hours=self.expiry_notification_time) + return _get_closest_future_time(expiry_notification_time, expiry_time) + def _on_relation_broken(self, event: RelationBrokenEvent) -> None: """Handler triggered on relation broken event. @@ -1353,12 +1579,67 @@ def _on_relation_broken(self, event: RelationBrokenEvent) -> None: Returns: None """ - relation = self.model.get_relation(self.relationship_name) - if not relation: + self.on.all_certificates_invalidated.emit() + + def _on_secret_expired(self, event: SecretExpiredEvent) -> None: + """Triggered when a certificate is set to expire. + + Loads the certificate from the secret, and will emit 1 of 2 + events. + + If the certificate is not yet expired, emits CertificateExpiringEvent + and updates the expiry time of the secret to the exact expiry time on + the certificate. + + If the certificate is expired, emits CertificateInvalidedEvent and + deletes the secret. + + Args: + event (SecretExpiredEvent): Juju event + """ + if not event.secret.label or not event.secret.label.startswith(f"{LIBID}-"): return - if not relation.app or not relation.app.name: + csr = event.secret.label[len(f"{LIBID}-") :] + certificate_dict = self._find_certificate_in_relation_data(csr) + if not certificate_dict: + # A secret expired but we did not find matching certificate. Cleaning up + event.secret.remove_all_revisions() + return + + expiry_time = _get_certificate_expiry_time(certificate_dict["certificate"]) + if not expiry_time: + # A secret expired but matching certificate is invalid. Cleaning up + event.secret.remove_all_revisions() return - self.on.all_certificates_invalidated.emit() + + if datetime.utcnow() < expiry_time: + logger.warning("Certificate almost expired") + self.on.certificate_expiring.emit( + certificate=certificate_dict["certificate"], + expiry=expiry_time.isoformat(), + ) + event.secret.set_info( + expire=_get_certificate_expiry_time(certificate_dict["certificate"]), + ) + else: + logger.warning("Certificate is expired") + self.on.certificate_invalidated.emit( + reason="expired", + certificate=certificate_dict["certificate"], + certificate_signing_request=certificate_dict["certificate_signing_request"], + ca=certificate_dict["ca"], + chain=certificate_dict["chain"], + ) + self.request_certificate_revocation(certificate_dict["certificate"].encode()) + event.secret.remove_all_revisions() + + def _find_certificate_in_relation_data(self, csr: str) -> Optional[Dict[str, Any]]: + """Returns the certificate that match the given CSR.""" + for certificate_dict in self._provider_certificates: + if certificate_dict["certificate_signing_request"] != csr: + continue + return certificate_dict + return None def _on_update_status(self, event: UpdateStatusEvent) -> None: """Triggered on update status event. @@ -1373,26 +1654,11 @@ def _on_update_status(self, event: UpdateStatusEvent) -> None: Returns: None """ - relation = self.model.get_relation(self.relationship_name) - if not relation: - logger.debug("No relation: %s", self.relationship_name) - return - if not relation.app: - logger.debug("No remote app in relation: %s", self.relationship_name) - return - provider_relation_data = _load_relation_data(relation.data[relation.app]) - if not self._relation_data_is_valid(provider_relation_data): - logger.debug("Provider relation data did not pass JSON Schema validation") - return for certificate_dict in self._provider_certificates: - try: - certificate_object = x509.load_pem_x509_certificate( - data=certificate_dict["certificate"].encode() - ) - except ValueError: - logger.warning("Could not load certificate.") + expiry_time = _get_certificate_expiry_time(certificate_dict["certificate"]) + if not expiry_time: continue - time_difference = certificate_object.not_valid_after - datetime.utcnow() + time_difference = expiry_time - datetime.utcnow() if time_difference.total_seconds() < 0: logger.warning("Certificate is expired") self.on.certificate_invalidated.emit( @@ -1408,5 +1674,73 @@ def _on_update_status(self, event: UpdateStatusEvent) -> None: logger.warning("Certificate almost expired") self.on.certificate_expiring.emit( certificate=certificate_dict["certificate"], - expiry=certificate_object.not_valid_after.isoformat(), + expiry=expiry_time.isoformat(), ) + + +def csr_matches_certificate(csr: str, cert: str) -> bool: + """Check if a CSR matches a certificate. + + expects to get the original string representations. + + Args: + csr (str): Certificate Signing Request + cert (str): Certificate + Returns: + bool: True/False depending on whether the CSR matches the certificate. + """ + try: + csr_object = x509.load_pem_x509_csr(csr.encode("utf-8")) + cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8")) + + if csr_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) != cert_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ): + return False + if ( + csr_object.public_key().public_numbers().n # type: ignore[union-attr] + != cert_object.public_key().public_numbers().n # type: ignore[union-attr] + ): + return False + except ValueError: + logger.warning("Could not load certificate or CSR.") + return False + return True + + +def _get_closest_future_time( + expiry_notification_time: datetime, expiry_time: datetime +) -> datetime: + """Return expiry_notification_time if not in the past, otherwise return expiry_time. + + Args: + expiry_notification_time (datetime): Notification time of impending expiration + expiry_time (datetime): Expiration time + + Returns: + datetime: expiry_notification_time if not in the past, expiry_time otherwise + """ + return ( + expiry_notification_time if datetime.utcnow() < expiry_notification_time else expiry_time + ) + + +def _get_certificate_expiry_time(certificate: str) -> Optional[datetime]: + """Extract expiry time from a certificate string. + + Args: + certificate (str): x509 certificate as a string + + Returns: + Optional[datetime]: Expiry datetime or None + """ + try: + certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) + return certificate_object.not_valid_after + except ValueError: + logger.warning("Could not load certificate.") + return None