diff --git a/charmcraft.yaml b/charmcraft.yaml index 66da824e9..f8ec195d8 100644 --- a/charmcraft.yaml +++ b/charmcraft.yaml @@ -10,6 +10,13 @@ bases: channel: "22.04" architectures: [arm64] parts: + files: + plugin: dump + source: . + prime: + - charm_version + - workload_version + - scripts charm: override-pull: | craftctl default @@ -20,10 +27,6 @@ parts: fi charm-strict-dependencies: true charm-entrypoint: src/kubernetes_charm.py - prime: - - charm_version - - workload_version - - scripts build-packages: - libffi-dev - libssl-dev diff --git a/lib/charms/data_platform_libs/v0/data_interfaces.py b/lib/charms/data_platform_libs/v0/data_interfaces.py index a2162aa0b..aaed2e528 100644 --- a/lib/charms/data_platform_libs/v0/data_interfaces.py +++ b/lib/charms/data_platform_libs/v0/data_interfaces.py @@ -331,10 +331,14 @@ def _on_topic_requested(self, event: TopicRequestedEvent): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 38 +LIBPATCH = 39 PYDEPS = ["ops>=2.0.0"] +# Starting from what LIBPATCH number to apply legacy solutions +# v0.17 was the last version without secrets +LEGACY_SUPPORT_FROM = 17 + logger = logging.getLogger(__name__) Diff = namedtuple("Diff", "added changed deleted") @@ -351,36 +355,16 @@ def _on_topic_requested(self, event: TopicRequestedEvent): GROUP_MAPPING_FIELD = "secret_group_mapping" GROUP_SEPARATOR = "@" +MODEL_ERRORS = { + "not_leader": "this unit is not the leader", + "no_label_and_uri": "ERROR either URI or label should be used for getting an owned secret but not both", + "owner_no_refresh": "ERROR secret owner cannot use --refresh", +} -class SecretGroup(str): - """Secret groups specific type.""" - - -class SecretGroupsAggregate(str): - """Secret groups with option to extend with additional constants.""" - - def __init__(self): - self.USER = SecretGroup("user") - self.TLS = SecretGroup("tls") - self.EXTRA = SecretGroup("extra") - - def __setattr__(self, name, value): - """Setting internal constants.""" - if name in self.__dict__: - raise RuntimeError("Can't set constant!") - else: - super().__setattr__(name, SecretGroup(value)) - - def groups(self) -> list: - """Return the list of stored SecretGroups.""" - return list(self.__dict__.values()) - - def get_group(self, group: str) -> Optional[SecretGroup]: - """If the input str translates to a group name, return that.""" - return SecretGroup(group) if group in self.groups() else None - -SECRET_GROUPS = SecretGroupsAggregate() +############################################################################## +# Exceptions +############################################################################## class DataInterfacesError(Exception): @@ -407,6 +391,15 @@ class IllegalOperationError(DataInterfacesError): """To be used when an operation is not allowed to be performed.""" +############################################################################## +# Global helpers / utilities +############################################################################## + +############################################################################## +# Databag handling and comparison methods +############################################################################## + + def get_encoded_dict( relation: Relation, member: Union[Unit, Application], field: str ) -> Optional[Dict[str, str]]: @@ -482,6 +475,11 @@ def diff(event: RelationChangedEvent, bucket: Optional[Union[Unit, Application]] return Diff(added, changed, deleted) +############################################################################## +# Module decorators +############################################################################## + + def leader_only(f): """Decorator to ensure that only leader can perform given operation.""" @@ -536,6 +534,36 @@ def wrapper(self, *args, **kwargs): return wrapper +def legacy_apply_from_version(version: int) -> Callable: + """Decorator to decide whether to apply a legacy function or not. + + Based on LEGACY_SUPPORT_FROM module variable value, the importer charm may only want + to apply legacy solutions starting from a specific LIBPATCH. + + NOTE: All 'legacy' functions have to be defined and called in a way that they return `None`. + This results in cleaner and more secure execution flows in case the function may be disabled. + This requirement implicitly means that legacy functions change the internal state strictly, + don't return information. + """ + + def decorator(f: Callable[..., None]): + """Signature is ensuring None return value.""" + f.legacy_version = version + + def wrapper(self, *args, **kwargs) -> None: + if version >= LEGACY_SUPPORT_FROM: + return f(self, *args, **kwargs) + + return wrapper + + return decorator + + +############################################################################## +# Helper classes +############################################################################## + + class Scope(Enum): """Peer relations scope.""" @@ -543,9 +571,35 @@ class Scope(Enum): UNIT = "unit" -################################################################################ -# Secrets internal caching -################################################################################ +class SecretGroup(str): + """Secret groups specific type.""" + + +class SecretGroupsAggregate(str): + """Secret groups with option to extend with additional constants.""" + + def __init__(self): + self.USER = SecretGroup("user") + self.TLS = SecretGroup("tls") + self.EXTRA = SecretGroup("extra") + + def __setattr__(self, name, value): + """Setting internal constants.""" + if name in self.__dict__: + raise RuntimeError("Can't set constant!") + else: + super().__setattr__(name, SecretGroup(value)) + + def groups(self) -> list: + """Return the list of stored SecretGroups.""" + return list(self.__dict__.values()) + + def get_group(self, group: str) -> Optional[SecretGroup]: + """If the input str translates to a group name, return that.""" + return SecretGroup(group) if group in self.groups() else None + + +SECRET_GROUPS = SecretGroupsAggregate() class CachedSecret: @@ -554,6 +608,8 @@ class CachedSecret: The data structure is precisely re-using/simulating as in the actual Secret Storage """ + KNOWN_MODEL_ERRORS = [MODEL_ERRORS["no_label_and_uri"], MODEL_ERRORS["owner_no_refresh"]] + def __init__( self, model: Model, @@ -571,6 +627,95 @@ def __init__( self.legacy_labels = legacy_labels self.current_label = None + @property + def meta(self) -> Optional[Secret]: + """Getting cached secret meta-information.""" + if not self._secret_meta: + if not (self._secret_uri or self.label): + return + + try: + self._secret_meta = self._model.get_secret(label=self.label) + except SecretNotFoundError: + # Falling back to seeking for potential legacy labels + self._legacy_compat_find_secret_by_old_label() + + # If still not found, to be checked by URI, to be labelled with the proposed label + if not self._secret_meta and self._secret_uri: + self._secret_meta = self._model.get_secret(id=self._secret_uri, label=self.label) + return self._secret_meta + + ########################################################################## + # Backwards compatibility / Upgrades + ########################################################################## + # These functions are used to keep backwards compatibility on rolling upgrades + # Policy: + # All data is kept intact until the first write operation. (This allows a minimal + # grace period during which rollbacks are fully safe. For more info see the spec.) + # All data involves: + # - databag contents + # - secrets content + # - secret labels (!!!) + # Legacy functions must return None, and leave an equally consistent state whether + # they are executed or skipped (as a high enough versioned execution environment may + # not require so) + + # Compatibility + + @legacy_apply_from_version(34) + def _legacy_compat_find_secret_by_old_label(self) -> None: + """Compatibility function, allowing to find a secret by a legacy label. + + This functionality is typically needed when secret labels changed over an upgrade. + Until the first write operation, we need to maintain data as it was, including keeping + the old secret label. In order to keep track of the old label currently used to access + the secret, and additional 'current_label' field is being defined. + """ + for label in self.legacy_labels: + try: + self._secret_meta = self._model.get_secret(label=label) + except SecretNotFoundError: + pass + else: + if label != self.label: + self.current_label = label + return + + # Migrations + + @legacy_apply_from_version(34) + def _legacy_migration_to_new_label_if_needed(self) -> None: + """Helper function to re-create the secret with a different label. + + Juju does not provide a way to change secret labels. + Thus whenever moving from secrets version that involves secret label changes, + we "re-create" the existing secret, and attach the new label to the new + secret, to be used from then on. + + Note: we replace the old secret with a new one "in place", as we can't + easily switch the containing SecretCache structure to point to a new secret. + Instead we are changing the 'self' (CachedSecret) object to point to the + new instance. + """ + if not self.current_label or not (self.meta and self._secret_meta): + return + + # Create a new secret with the new label + content = self._secret_meta.get_content() + self._secret_uri = None + + # It will be nice to have the possibility to check if we are the owners of the secret... + try: + self._secret_meta = self.add_secret(content, label=self.label) + except ModelError as err: + if MODEL_ERRORS["not_leader"] not in str(err): + raise + self.current_label = None + + ########################################################################## + # Public functions + ########################################################################## + def add_secret( self, content: Dict[str, str], @@ -593,28 +738,6 @@ def add_secret( self._secret_meta = secret return self._secret_meta - @property - def meta(self) -> Optional[Secret]: - """Getting cached secret meta-information.""" - if not self._secret_meta: - if not (self._secret_uri or self.label): - return - - for label in [self.label] + self.legacy_labels: - try: - self._secret_meta = self._model.get_secret(label=label) - except SecretNotFoundError: - pass - else: - if label != self.label: - self.current_label = label - break - - # If still not found, to be checked by URI, to be labelled with the proposed label - if not self._secret_meta and self._secret_uri: - self._secret_meta = self._model.get_secret(id=self._secret_uri, label=self.label) - return self._secret_meta - def get_content(self) -> Dict[str, str]: """Getting cached secret content.""" if not self._secret_content: @@ -624,35 +747,14 @@ def get_content(self) -> Dict[str, str]: except (ValueError, ModelError) as err: # https://bugs.launchpad.net/juju/+bug/2042596 # Only triggered when 'refresh' is set - known_model_errors = [ - "ERROR either URI or label should be used for getting an owned secret but not both", - "ERROR secret owner cannot use --refresh", - ] if isinstance(err, ModelError) and not any( - msg in str(err) for msg in known_model_errors + msg in str(err) for msg in self.KNOWN_MODEL_ERRORS ): raise # Due to: ValueError: Secret owner cannot use refresh=True self._secret_content = self.meta.get_content() return self._secret_content - def _move_to_new_label_if_needed(self): - """Helper function to re-create the secret with a different label.""" - if not self.current_label or not (self.meta and self._secret_meta): - return - - # Create a new secret with the new label - content = self._secret_meta.get_content() - self._secret_uri = None - - # I wish we could just check if we are the owners of the secret... - try: - self._secret_meta = self.add_secret(content, label=self.label) - except ModelError as err: - if "this unit is not the leader" not in str(err): - raise - self.current_label = None - def set_content(self, content: Dict[str, str]) -> None: """Setting cached secret content.""" if not self.meta: @@ -663,7 +765,7 @@ def set_content(self, content: Dict[str, str]) -> None: return if content: - self._move_to_new_label_if_needed() + self._legacy_migration_to_new_label_if_needed() self.meta.set_content(content) self._secret_content = content else: @@ -926,6 +1028,23 @@ def _delete_relation_data(self, relation: Relation, fields: List[str]) -> None: """Delete data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" raise NotImplementedError + # Optional overrides + + def _legacy_apply_on_fetch(self) -> None: + """This function should provide a list of compatibility functions to be applied when fetching (legacy) data.""" + pass + + def _legacy_apply_on_update(self, fields: List[str]) -> None: + """This function should provide a list of compatibility functions to be applied when writing data. + + Since data may be at a legacy version, migration may be mandatory. + """ + pass + + def _legacy_apply_on_delete(self, fields: List[str]) -> None: + """This function should provide a list of compatibility functions to be applied when deleting (legacy) data.""" + pass + # Internal helper methods @staticmethod @@ -1178,6 +1297,16 @@ def get_relation(self, relation_name, relation_id) -> Relation: return relation + def get_secret_uri(self, relation: Relation, group: SecretGroup) -> Optional[str]: + """Get the secret URI for the corresponding group.""" + secret_field = self._generate_secret_field_name(group) + return relation.data[self.component].get(secret_field) + + def set_secret_uri(self, relation: Relation, group: SecretGroup, secret_uri: str) -> None: + """Set the secret URI for the corresponding group.""" + secret_field = self._generate_secret_field_name(group) + relation.data[self.component][secret_field] = secret_uri + def fetch_relation_data( self, relation_ids: Optional[List[int]] = None, @@ -1194,6 +1323,8 @@ def fetch_relation_data( a dict of the values stored in the relation data bag for all relation instances (indexed by the relation ID). """ + self._legacy_apply_on_fetch() + if not relation_name: relation_name = self.relation_name @@ -1232,6 +1363,8 @@ def fetch_my_relation_data( NOTE: Since only the leader can read the relation's 'this_app'-side Application databag, the functionality is limited to leaders """ + self._legacy_apply_on_fetch() + if not relation_name: relation_name = self.relation_name @@ -1263,6 +1396,8 @@ def fetch_my_relation_field( @leader_only def update_relation_data(self, relation_id: int, data: dict) -> None: """Update the data within the relation.""" + self._legacy_apply_on_update(list(data.keys())) + relation_name = self.relation_name relation = self.get_relation(relation_name, relation_id) return self._update_relation_data(relation, data) @@ -1270,6 +1405,8 @@ def update_relation_data(self, relation_id: int, data: dict) -> None: @leader_only def delete_relation_data(self, relation_id: int, fields: List[str]) -> None: """Remove field from the relation.""" + self._legacy_apply_on_delete(fields) + relation_name = self.relation_name relation = self.get_relation(relation_name, relation_id) return self._delete_relation_data(relation, fields) @@ -1336,8 +1473,7 @@ def _add_relation_secret( uri_to_databag=True, ) -> bool: """Add a new Juju Secret that will be registered in the relation databag.""" - secret_field = self._generate_secret_field_name(group_mapping) - if uri_to_databag and relation.data[self.component].get(secret_field): + if uri_to_databag and self.get_secret_uri(relation, group_mapping): logging.error("Secret for relation %s already exists, not adding again", relation.id) return False @@ -1348,7 +1484,7 @@ def _add_relation_secret( # According to lint we may not have a Secret ID if uri_to_databag and secret.meta and secret.meta.id: - relation.data[self.component][secret_field] = secret.meta.id + self.set_secret_uri(relation, group_mapping, secret.meta.id) # Return the content that was added return True @@ -1449,8 +1585,7 @@ def _get_relation_secret( if not relation: return - secret_field = self._generate_secret_field_name(group_mapping) - if secret_uri := relation.data[self.local_app].get(secret_field): + if secret_uri := self.get_secret_uri(relation, group_mapping): return self.secrets.get(label, secret_uri) def _fetch_specific_relation_data( @@ -1603,11 +1738,10 @@ def _register_secrets_to_relation(self, relation: Relation, params_name_list: Li for group in SECRET_GROUPS.groups(): secret_field = self._generate_secret_field_name(group) - if secret_field in params_name_list: - if secret_uri := relation.data[relation.app].get(secret_field): - self._register_secret_to_relation( - relation.name, relation.id, secret_uri, group - ) + if secret_field in params_name_list and ( + secret_uri := self.get_secret_uri(relation, group) + ): + self._register_secret_to_relation(relation.name, relation.id, secret_uri, group) def _is_resource_created_for_relation(self, relation: Relation) -> bool: if not relation.app: @@ -1618,6 +1752,17 @@ def _is_resource_created_for_relation(self, relation: Relation) -> bool: ) return bool(data.get("username")) and bool(data.get("password")) + # Public functions + + def get_secret_uri(self, relation: Relation, group: SecretGroup) -> Optional[str]: + """Getting relation secret URI for the corresponding Secret Group.""" + secret_field = self._generate_secret_field_name(group) + return relation.data[relation.app].get(secret_field) + + def set_secret_uri(self, relation: Relation, group: SecretGroup, uri: str) -> None: + """Setting relation secret URI is not possible for a Requirer.""" + raise NotImplementedError("Requirer can not change the relation secret URI.") + def is_resource_created(self, relation_id: Optional[int] = None) -> bool: """Check if the resource has been created. @@ -1768,7 +1913,6 @@ def __init__( secret_field_name: Optional[str] = None, deleted_label: Optional[str] = None, ): - """Manager of base client relations.""" RequirerData.__init__( self, model, @@ -1779,6 +1923,11 @@ def __init__( self.secret_field_name = secret_field_name if secret_field_name else self.SECRET_FIELD_NAME self.deleted_label = deleted_label self._secret_label_map = {} + + # Legacy information holders + self._legacy_labels = [] + self._legacy_secret_uri = None + # Secrets that are being dynamically added within the scope of this event handler run self._new_secrets = [] self._additional_secret_group_mapping = additional_secret_group_mapping @@ -1853,10 +2002,12 @@ def set_secret( value: The string value of the secret group_mapping: The name of the "secret group", in case the field is to be added to an existing secret """ + self._legacy_apply_on_update([field]) + full_field = self._field_to_internal_name(field, group_mapping) if self.secrets_enabled and full_field not in self.current_secret_fields: self._new_secrets.append(full_field) - if self._no_group_with_databag(field, full_field): + if self.valid_field_pattern(field, full_field): self.update_relation_data(relation_id, {full_field: value}) # Unlike for set_secret(), there's no harm using this operation with static secrets @@ -1869,6 +2020,8 @@ def get_secret( group_mapping: Optional[SecretGroup] = None, ) -> Optional[str]: """Public interface method to fetch secrets only.""" + self._legacy_apply_on_fetch() + full_field = self._field_to_internal_name(field, group_mapping) if ( self.secrets_enabled @@ -1876,7 +2029,7 @@ def get_secret( and field not in self.current_secret_fields ): return - if self._no_group_with_databag(field, full_field): + if self.valid_field_pattern(field, full_field): return self.fetch_my_relation_field(relation_id, full_field) @dynamic_secrets_only @@ -1887,14 +2040,19 @@ def delete_secret( group_mapping: Optional[SecretGroup] = None, ) -> Optional[str]: """Public interface method to delete secrets only.""" + self._legacy_apply_on_delete([field]) + full_field = self._field_to_internal_name(field, group_mapping) if self.secrets_enabled and full_field not in self.current_secret_fields: logger.warning(f"Secret {field} from group {group_mapping} was not found") return - if self._no_group_with_databag(field, full_field): + + if self.valid_field_pattern(field, full_field): self.delete_relation_data(relation_id, [full_field]) + ########################################################################## # Helpers + ########################################################################## @staticmethod def _field_to_internal_name(field: str, group: Optional[SecretGroup]) -> str: @@ -1936,10 +2094,69 @@ def _content_for_secret_group( if k in self.secret_fields } - # Backwards compatibility + def valid_field_pattern(self, field: str, full_field: str) -> bool: + """Check that no secret group is attempted to be used together without secrets being enabled. + + Secrets groups are impossible to use with versions that are not yet supporting secrets. + """ + if not self.secrets_enabled and full_field != field: + logger.error( + f"Can't access {full_field}: no secrets available (i.e. no secret groups either)." + ) + return False + return True + + ########################################################################## + # Backwards compatibility / Upgrades + ########################################################################## + # These functions are used to keep backwards compatibility on upgrades + # Policy: + # All data is kept intact until the first write operation. (This allows a minimal + # grace period during which rollbacks are fully safe. For more info see spec.) + # All data involves: + # - databag + # - secrets content + # - secret labels (!!!) + # Legacy functions must return None, and leave an equally consistent state whether + # they are executed or skipped (as a high enough versioned execution environment may + # not require so) + + # Full legacy stack for each operation + + def _legacy_apply_on_fetch(self) -> None: + """All legacy functions to be applied on fetch.""" + relation = self._model.relations[self.relation_name][0] + self._legacy_compat_generate_prev_labels() + self._legacy_compat_secret_uri_from_databag(relation) + + def _legacy_apply_on_update(self, fields) -> None: + """All legacy functions to be applied on update.""" + relation = self._model.relations[self.relation_name][0] + self._legacy_compat_generate_prev_labels() + self._legacy_compat_secret_uri_from_databag(relation) + self._legacy_migration_remove_secret_from_databag(relation, fields) + self._legacy_migration_remove_secret_field_name_from_databag(relation) + + def _legacy_apply_on_delete(self, fields) -> None: + """All legacy functions to be applied on delete.""" + relation = self._model.relations[self.relation_name][0] + self._legacy_compat_generate_prev_labels() + self._legacy_compat_secret_uri_from_databag(relation) + self._legacy_compat_check_deleted_label(relation, fields) + + # Compatibility + + @legacy_apply_from_version(18) + def _legacy_compat_check_deleted_label(self, relation, fields) -> None: + """Helper function for legacy behavior. + + As long as https://bugs.launchpad.net/juju/+bug/2028094 wasn't fixed, + we did not delete fields but rather kept them in the secret with a string value + expressing invalidity. This function is maintainnig that behavior when needed. + """ + if not self.deleted_label: + return - def _check_deleted_label(self, relation, fields) -> None: - """Helper function for legacy behavior.""" current_data = self.fetch_my_relation_data([relation.id], fields) if current_data is not None: # Check if the secret we wanna delete actually exists @@ -1952,7 +2169,43 @@ def _check_deleted_label(self, relation, fields) -> None: ", ".join(non_existent), ) - def _remove_secret_from_databag(self, relation, fields: List[str]) -> None: + @legacy_apply_from_version(18) + def _legacy_compat_secret_uri_from_databag(self, relation) -> None: + """Fetching the secret URI from the databag, in case stored there.""" + self._legacy_secret_uri = relation.data[self.component].get( + self._generate_secret_field_name(), None + ) + + @legacy_apply_from_version(34) + def _legacy_compat_generate_prev_labels(self) -> None: + """Generator for legacy secret label names, for backwards compatibility. + + Secret label is part of the data that MUST be maintained across rolling upgrades. + In case there may be a change on a secret label, the old label must be recognized + after upgrades, and left intact until the first write operation -- when we roll over + to the new label. + + This function keeps "memory" of previously used secret labels. + NOTE: Return value takes decorator into account -- all 'legacy' functions may return `None` + + v0.34 (rev69): Fixing issue https://github.com/canonical/data-platform-libs/issues/155 + meant moving from '.' (i.e. 'mysql.app', 'mysql.unit') + to labels '..' (like 'peer.mysql.app') + """ + if self._legacy_labels: + return + + result = [] + members = [self._model.app.name] + if self.scope: + members.append(self.scope.value) + result.append(f"{'.'.join(members)}") + self._legacy_labels = result + + # Migration + + @legacy_apply_from_version(18) + def _legacy_migration_remove_secret_from_databag(self, relation, fields: List[str]) -> None: """For Rolling Upgrades -- when moving from databag to secrets usage. Practically what happens here is to remove stuff from the databag that is @@ -1966,10 +2219,16 @@ def _remove_secret_from_databag(self, relation, fields: List[str]) -> None: if self._fetch_relation_data_without_secrets(self.component, relation, [field]): self._delete_relation_data_without_secrets(self.component, relation, [field]) - def _remove_secret_field_name_from_databag(self, relation) -> None: + @legacy_apply_from_version(18) + def _legacy_migration_remove_secret_field_name_from_databag(self, relation) -> None: """Making sure that the old databag URI is gone. This action should not be executed more than once. + + There was a phase (before moving secrets usage to libs) when charms saved the peer + secret URI to the databag, and used this URI from then on to retrieve their secret. + When upgrading to charm versions using this library, we need to add a label to the + secret and access it via label from than on, and remove the old traces from the databag. """ # Nothing to do if 'internal-secret' is not in the databag if not (relation.data[self.component].get(self._generate_secret_field_name())): @@ -1985,25 +2244,9 @@ def _remove_secret_field_name_from_databag(self, relation) -> None: # Databag reference to the secret URI can be removed, now that it's labelled relation.data[self.component].pop(self._generate_secret_field_name(), None) - def _previous_labels(self) -> List[str]: - """Generator for legacy secret label names, for backwards compatibility.""" - result = [] - members = [self._model.app.name] - if self.scope: - members.append(self.scope.value) - result.append(f"{'.'.join(members)}") - return result - - def _no_group_with_databag(self, field: str, full_field: str) -> bool: - """Check that no secret group is attempted to be used together with databag.""" - if not self.secrets_enabled and full_field != field: - logger.error( - f"Can't access {full_field}: no secrets available (i.e. no secret groups either)." - ) - return False - return True - + ########################################################################## # Event handlers + ########################################################################## def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the relation has changed.""" @@ -2013,7 +2256,9 @@ def _on_secret_changed_event(self, event: SecretChangedEvent) -> None: """Event emitted when the secret has changed.""" pass + ########################################################################## # Overrides of Relation Data handling functions + ########################################################################## def _generate_secret_label( self, relation_name: str, relation_id: int, group_mapping: SecretGroup @@ -2050,13 +2295,14 @@ def _get_relation_secret( return label = self._generate_secret_label(relation_name, relation_id, group_mapping) - secret_uri = relation.data[self.component].get(self._generate_secret_field_name(), None) # URI or legacy label is only to applied when moving single legacy secret to a (new) label if group_mapping == SECRET_GROUPS.EXTRA: # Fetching the secret with fallback to URI (in case label is not yet known) # Label would we "stuck" on the secret in case it is found - return self.secrets.get(label, secret_uri, legacy_labels=self._previous_labels()) + return self.secrets.get( + label, self._legacy_secret_uri, legacy_labels=self._legacy_labels + ) return self.secrets.get(label) def _get_group_secret_contents( @@ -2086,7 +2332,6 @@ def _fetch_my_specific_relation_data( @either_static_or_dynamic_secrets def _update_relation_data(self, relation: Relation, data: Dict[str, str]) -> None: """Update data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" - self._remove_secret_from_databag(relation, list(data.keys())) _, normal_fields = self._process_secret_fields( relation, self.secret_fields, @@ -2095,7 +2340,6 @@ def _update_relation_data(self, relation: Relation, data: Dict[str, str]) -> Non data=data, uri_to_databag=False, ) - self._remove_secret_field_name_from_databag(relation) normal_content = {k: v for k, v in data.items() if k in normal_fields} self._update_relation_data_without_secrets(self.component, relation, normal_content) @@ -2104,8 +2348,6 @@ def _update_relation_data(self, relation: Relation, data: Dict[str, str]) -> Non def _delete_relation_data(self, relation: Relation, fields: List[str]) -> None: """Delete data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" if self.secret_fields and self.deleted_label: - # Legacy, backwards compatibility - self._check_deleted_label(relation, fields) _, normal_fields = self._process_secret_fields( relation, @@ -2141,7 +2383,9 @@ def fetch_relation_field( "fetch_my_relation_data() and fetch_my_relation_field()" ) + ########################################################################## # Public functions -- inherited + ########################################################################## fetch_my_relation_data = Data.fetch_my_relation_data fetch_my_relation_field = Data.fetch_my_relation_field diff --git a/lib/charms/loki_k8s/v1/loki_push_api.py b/lib/charms/loki_k8s/v1/loki_push_api.py index c3c1d086f..7f8372c47 100644 --- a/lib/charms/loki_k8s/v1/loki_push_api.py +++ b/lib/charms/loki_k8s/v1/loki_push_api.py @@ -527,7 +527,7 @@ def _alert_rules_error(self, event): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 11 +LIBPATCH = 12 PYDEPS = ["cosl"] @@ -2562,7 +2562,9 @@ def _update_logging(self, _): return for container in self._charm.unit.containers.values(): - self._update_endpoints(container, loki_endpoints) + if container.can_connect(): + self._update_endpoints(container, loki_endpoints) + # else: `_update_endpoints` will be called on pebble-ready anyway. def _retrieve_endpoints_from_relation(self) -> dict: loki_endpoints = {} diff --git a/lib/charms/tempo_k8s/v1/charm_tracing.py b/lib/charms/tempo_k8s/v1/charm_tracing.py index ebe022e00..fa926539b 100644 --- a/lib/charms/tempo_k8s/v1/charm_tracing.py +++ b/lib/charms/tempo_k8s/v1/charm_tracing.py @@ -172,6 +172,59 @@ def my_tracing_endpoint(self) -> Optional[str]: provide an *absolute* path to the certificate file instead. """ + +def _remove_stale_otel_sdk_packages(): + """Hack to remove stale opentelemetry sdk packages from the charm's python venv. + + See https://github.com/canonical/grafana-agent-operator/issues/146 and + https://bugs.launchpad.net/juju/+bug/2058335 for more context. This patch can be removed after + this juju issue is resolved and sufficient time has passed to expect most users of this library + have migrated to the patched version of juju. When this patch is removed, un-ignore rule E402 for this file in the pyproject.toml (see setting + [tool.ruff.lint.per-file-ignores] in pyproject.toml). + + This only has an effect if executed on an upgrade-charm event. + """ + # all imports are local to keep this function standalone, side-effect-free, and easy to revert later + import os + + if os.getenv("JUJU_DISPATCH_PATH") != "hooks/upgrade-charm": + return + + import logging + import shutil + from collections import defaultdict + + from importlib_metadata import distributions + + otel_logger = logging.getLogger("charm_tracing_otel_patcher") + otel_logger.debug("Applying _remove_stale_otel_sdk_packages patch on charm upgrade") + # group by name all distributions starting with "opentelemetry_" + otel_distributions = defaultdict(list) + for distribution in distributions(): + name = distribution._normalized_name # type: ignore + if name.startswith("opentelemetry_"): + otel_distributions[name].append(distribution) + + otel_logger.debug(f"Found {len(otel_distributions)} opentelemetry distributions") + + # If we have multiple distributions with the same name, remove any that have 0 associated files + for name, distributions_ in otel_distributions.items(): + if len(distributions_) <= 1: + continue + + otel_logger.debug(f"Package {name} has multiple ({len(distributions_)}) distributions.") + for distribution in distributions_: + if not distribution.files: # Not None or empty list + path = distribution._path # type: ignore + otel_logger.info(f"Removing empty distribution of {name} at {path}.") + shutil.rmtree(path) + + otel_logger.debug("Successfully applied _remove_stale_otel_sdk_packages patch. ") + + +_remove_stale_otel_sdk_packages() + + import functools import inspect import logging @@ -197,14 +250,15 @@ def my_tracing_endpoint(self) -> Optional[str]: from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import Span, TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor -from opentelemetry.trace import INVALID_SPAN, Tracer -from opentelemetry.trace import get_current_span as otlp_get_current_span from opentelemetry.trace import ( + INVALID_SPAN, + Tracer, get_tracer, get_tracer_provider, set_span_in_context, set_tracer_provider, ) +from opentelemetry.trace import get_current_span as otlp_get_current_span from ops.charm import CharmBase from ops.framework import Framework @@ -217,7 +271,7 @@ def my_tracing_endpoint(self) -> Optional[str]: # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 11 +LIBPATCH = 14 PYDEPS = ["opentelemetry-exporter-otlp-proto-http==1.21.0"] @@ -391,6 +445,9 @@ def wrap_init(self: CharmBase, framework: Framework, *args, **kwargs): _service_name = service_name or f"{self.app.name}-charm" unit_name = self.unit.name + # apply hacky patch to remove stale opentelemetry sdk packages on upgrade-charm. + # it could be trouble if someone ever decides to implement their own tracer parallel to + # ours and before the charm has inited. We assume they won't. resource = Resource.create( attributes={ "service.name": _service_name, @@ -612,38 +669,58 @@ def trace_type(cls: _T) -> _T: dev_logger.info(f"skipping {method} (dunder)") continue - new_method = trace_method(method) - if isinstance(inspect.getattr_static(cls, method.__name__), staticmethod): + # the span title in the general case should be: + # method call: MyCharmWrappedMethods.b + # if the method has a name (functools.wrapped or regular method), let + # _trace_callable use its default algorithm to determine what name to give the span. + trace_method_name = None + try: + qualname_c0 = method.__qualname__.split(".")[0] + if not hasattr(cls, method.__name__): + # if the callable doesn't have a __name__ (probably a decorated method), + # it probably has a bad qualname too (such as my_decorator..wrapper) which is not + # great for finding out what the trace is about. So we use the method name instead and + # add a reference to the decorator name. Result: + # method call: @my_decorator(MyCharmWrappedMethods.b) + trace_method_name = f"@{qualname_c0}({cls.__name__}.{name})" + except Exception: # noqa: failsafe + pass + + new_method = trace_method(method, name=trace_method_name) + + if isinstance(inspect.getattr_static(cls, name), staticmethod): new_method = staticmethod(new_method) setattr(cls, name, new_method) return cls -def trace_method(method: _F) -> _F: +def trace_method(method: _F, name: Optional[str] = None) -> _F: """Trace this method. A span will be opened when this method is called and closed when it returns. """ - return _trace_callable(method, "method") + return _trace_callable(method, "method", name=name) -def trace_function(function: _F) -> _F: +def trace_function(function: _F, name: Optional[str] = None) -> _F: """Trace this function. A span will be opened when this function is called and closed when it returns. """ - return _trace_callable(function, "function") + return _trace_callable(function, "function", name=name) -def _trace_callable(callable: _F, qualifier: str) -> _F: +def _trace_callable(callable: _F, qualifier: str, name: Optional[str] = None) -> _F: dev_logger.info(f"instrumenting {callable}") # sig = inspect.signature(callable) @functools.wraps(callable) def wrapped_function(*args, **kwargs): # type: ignore - name = getattr(callable, "__qualname__", getattr(callable, "__name__", str(callable))) - with _span(f"{qualifier} call: {name}"): # type: ignore + name_ = name or getattr( + callable, "__qualname__", getattr(callable, "__name__", str(callable)) + ) + with _span(f"{qualifier} call: {name_}"): # type: ignore return callable(*args, **kwargs) # type: ignore # wrapped_function.__signature__ = sig diff --git a/lib/charms/tempo_k8s/v2/tracing.py b/lib/charms/tempo_k8s/v2/tracing.py index 8b9fb4f32..dfb23365f 100644 --- a/lib/charms/tempo_k8s/v2/tracing.py +++ b/lib/charms/tempo_k8s/v2/tracing.py @@ -107,7 +107,7 @@ def __init__(self, *args): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 7 +LIBPATCH = 8 PYDEPS = ["pydantic"] @@ -116,14 +116,13 @@ def __init__(self, *args): DEFAULT_RELATION_NAME = "tracing" RELATION_INTERFACE_NAME = "tracing" +# Supported list rationale https://github.com/canonical/tempo-coordinator-k8s-operator/issues/8 ReceiverProtocol = Literal[ "zipkin", - "kafka", - "opencensus", - "tempo_http", - "tempo_grpc", "otlp_grpc", "otlp_http", + "jaeger_grpc", + "jaeger_thrift_http", ] RawReceiver = Tuple[ReceiverProtocol, str] @@ -141,14 +140,12 @@ class TransportProtocolType(str, enum.Enum): grpc = "grpc" -receiver_protocol_to_transport_protocol = { +receiver_protocol_to_transport_protocol: Dict[ReceiverProtocol, TransportProtocolType] = { "zipkin": TransportProtocolType.http, - "kafka": TransportProtocolType.http, - "opencensus": TransportProtocolType.http, - "tempo_http": TransportProtocolType.http, - "tempo_grpc": TransportProtocolType.grpc, "otlp_grpc": TransportProtocolType.grpc, "otlp_http": TransportProtocolType.http, + "jaeger_thrift_http": TransportProtocolType.http, + "jaeger_grpc": TransportProtocolType.grpc, } """A mapping between telemetry protocols and their corresponding transport protocol. """ diff --git a/lib/charms/tls_certificates_interface/v1/tls_certificates.py b/lib/charms/tls_certificates_interface/v2/tls_certificates.py similarity index 61% rename from lib/charms/tls_certificates_interface/v1/tls_certificates.py rename to lib/charms/tls_certificates_interface/v2/tls_certificates.py index be171d8e9..9f67833ba 100644 --- a/lib/charms/tls_certificates_interface/v1/tls_certificates.py +++ b/lib/charms/tls_certificates_interface/v2/tls_certificates.py @@ -11,7 +11,7 @@ From a charm directory, fetch the library using `charmcraft`: ```shell -charmcraft fetch-lib charms.tls_certificates_interface.v1.tls_certificates +charmcraft fetch-lib charms.tls_certificates_interface.v2.tls_certificates ``` Add the following libraries to the charm's `requirements.txt` file: @@ -36,10 +36,10 @@ Example: ```python -from charms.tls_certificates_interface.v1.tls_certificates import ( +from charms.tls_certificates_interface.v2.tls_certificates import ( CertificateCreationRequestEvent, CertificateRevocationRequestEvent, - TLSCertificatesProvidesV1, + TLSCertificatesProvidesV2, generate_private_key, ) from ops.charm import CharmBase, InstallEvent @@ -59,12 +59,14 @@ class ExampleProviderCharm(CharmBase): def __init__(self, *args): super().__init__(*args) - self.certificates = TLSCertificatesProvidesV1(self, "certificates") + self.certificates = TLSCertificatesProvidesV2(self, "certificates") self.framework.observe( - self.certificates.on.certificate_request, self._on_certificate_request + self.certificates.on.certificate_request, + self._on_certificate_request ) self.framework.observe( - self.certificates.on.certificate_revoked, self._on_certificate_revocation_request + self.certificates.on.certificate_revocation_request, + self._on_certificate_revocation_request ) self.framework.observe(self.on.install, self._on_install) @@ -124,17 +126,18 @@ def _on_certificate_revocation_request(self, event: CertificateRevocationRequest Example: ```python -from charms.tls_certificates_interface.v1.tls_certificates import ( +from charms.tls_certificates_interface.v2.tls_certificates import ( CertificateAvailableEvent, CertificateExpiringEvent, CertificateRevokedEvent, - TLSCertificatesRequiresV1, + TLSCertificatesRequiresV2, generate_csr, generate_private_key, ) from ops.charm import CharmBase, RelationJoinedEvent from ops.main import main from ops.model import ActiveStatus, WaitingStatus +from typing import Union class ExampleRequirerCharm(CharmBase): @@ -142,7 +145,7 @@ class ExampleRequirerCharm(CharmBase): def __init__(self, *args): super().__init__(*args) self.cert_subject = "whatever" - self.certificates = TLSCertificatesRequiresV1(self, "certificates") + self.certificates = TLSCertificatesRequiresV2(self, "certificates") self.framework.observe(self.on.install, self._on_install) self.framework.observe( self.on.certificates_relation_joined, self._on_certificates_relation_joined @@ -154,7 +157,11 @@ def __init__(self, *args): self.certificates.on.certificate_expiring, self._on_certificate_expiring ) self.framework.observe( - self.certificates.on.certificate_revoked, self._on_certificate_revoked + self.certificates.on.certificate_invalidated, self._on_certificate_invalidated + ) + self.framework.observe( + self.certificates.on.all_certificates_invalidated, + self._on_all_certificates_invalidated ) def _on_install(self, event) -> None: @@ -196,7 +203,9 @@ def _on_certificate_available(self, event: CertificateAvailableEvent) -> None: replicas_relation.data[self.app].update({"chain": event.chain}) self.unit.status = ActiveStatus() - def _on_certificate_expiring(self, event: CertificateExpiringEvent) -> None: + def _on_certificate_expiring( + self, event: Union[CertificateExpiringEvent, CertificateInvalidatedEvent] + ) -> None: replicas_relation = self.model.get_relation("replicas") if not replicas_relation: self.unit.status = WaitingStatus("Waiting for peer relation to be created") @@ -216,12 +225,7 @@ def _on_certificate_expiring(self, event: CertificateExpiringEvent) -> None: ) replicas_relation.data[self.app].update({"csr": new_csr.decode()}) - def _on_certificate_revoked(self, event: CertificateRevokedEvent) -> None: - replicas_relation = self.model.get_relation("replicas") - if not replicas_relation: - self.unit.status = WaitingStatus("Waiting for peer relation to be created") - event.defer() - return + def _certificate_revoked(self) -> None: old_csr = replicas_relation.data[self.app].get("csr") private_key_password = replicas_relation.data[self.app].get("private_key_password") private_key = replicas_relation.data[self.app].get("private_key") @@ -240,44 +244,76 @@ def _on_certificate_revoked(self, event: CertificateRevokedEvent) -> None: replicas_relation.data[self.app].pop("chain") self.unit.status = WaitingStatus("Waiting for new certificate") + def _on_certificate_invalidated(self, event: CertificateInvalidatedEvent) -> None: + replicas_relation = self.model.get_relation("replicas") + if not replicas_relation: + self.unit.status = WaitingStatus("Waiting for peer relation to be created") + event.defer() + return + if event.reason == "revoked": + self._certificate_revoked() + if event.reason == "expired": + self._on_certificate_expiring(event) + + def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEvent) -> None: + # Do what you want with this information, probably remove all certificates. + pass + if __name__ == "__main__": main(ExampleRequirerCharm) ``` + +You can relate both charms by running: + +```bash +juju relate +``` + """ # noqa: D405, D410, D411, D214, D416 import copy import json import logging import uuid -from datetime import datetime, timedelta +from contextlib import suppress +from datetime import datetime, timedelta, timezone from ipaddress import IPv4Address -from typing import Dict, List, Optional +from typing import Any, Dict, List, Literal, Optional, Union from cryptography import x509 from cryptography.hazmat._oid import ExtensionOID from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.serialization import pkcs12 -from cryptography.x509.extensions import Extension, ExtensionNotFound -from jsonschema import exceptions, validate # type: ignore[import] -from ops.charm import CharmBase, CharmEvents, RelationChangedEvent, UpdateStatusEvent +from jsonschema import exceptions, validate +from ops.charm import ( + CharmBase, + CharmEvents, + RelationBrokenEvent, + RelationChangedEvent, + SecretExpiredEvent, + UpdateStatusEvent, +) from ops.framework import EventBase, EventSource, Handle, Object +from ops.jujuversion import JujuVersion +from ops.model import ModelError, Relation, RelationDataContent, SecretNotFoundError # The unique Charmhub library identifier, never change it LIBID = "afd8c2bccf834997afce12c2706d2ede" # Increment this major API version when introducing breaking changes -LIBAPI = 1 +LIBAPI = 2 # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 12 +LIBPATCH = 28 +PYDEPS = ["cryptography", "jsonschema"] REQUIRER_JSON_SCHEMA = { "$schema": "http://json-schema.org/draft-04/schema#", - "$id": "https://canonical.github.io/charm-relation-interfaces/tls_certificates/v1/schemas/requirer.json", # noqa: E501 + "$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/tls_certificates/v1/schemas/requirer.json", "type": "object", "title": "`tls_certificates` requirer root schema", "description": "The `tls_certificates` root schema comprises the entire requirer databag for this interface.", # noqa: E501 @@ -298,7 +334,10 @@ def _on_certificate_revoked(self, event: CertificateRevokedEvent) -> None: "type": "array", "items": { "type": "object", - "properties": {"certificate_signing_request": {"type": "string"}}, + "properties": { + "certificate_signing_request": {"type": "string"}, + "ca": {"type": "boolean"}, + }, "required": ["certificate_signing_request"], }, } @@ -309,7 +348,7 @@ def _on_certificate_revoked(self, event: CertificateRevokedEvent) -> None: PROVIDER_JSON_SCHEMA = { "$schema": "http://json-schema.org/draft-04/schema#", - "$id": "https://canonical.github.io/charm-relation-interfaces/tls_certificates/v1/schemas/provider.json", # noqa: E501 + "$id": "https://canonical.github.io/charm-relation-interfaces/interfaces/tls_certificates/v1/schemas/provider.json", "type": "object", "title": "`tls_certificates` provider root schema", "description": "The `tls_certificates` root schema comprises the entire provider databag for this interface.", # noqa: E501 @@ -401,7 +440,7 @@ def __init__( self.chain = chain def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return { "certificate": self.certificate, "certificate_signing_request": self.certificate_signing_request, @@ -410,7 +449,7 @@ def snapshot(self) -> dict: } def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" self.certificate = snapshot["certificate"] self.certificate_signing_request = snapshot["certificate_signing_request"] self.ca = snapshot["ca"] @@ -426,7 +465,7 @@ def __init__(self, handle, certificate: str, expiry: str): Args: handle (Handle): Juju framework handle certificate (str): TLS Certificate - expiry (str): Datetime string reprensenting the time at which the certificate + expiry (str): Datetime string representing the time at which the certificate won't be valid anymore. """ super().__init__(handle) @@ -434,88 +473,96 @@ def __init__(self, handle, certificate: str, expiry: str): self.expiry = expiry def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return {"certificate": self.certificate, "expiry": self.expiry} def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" self.certificate = snapshot["certificate"] self.expiry = snapshot["expiry"] -class CertificateExpiredEvent(EventBase): - """Charm Event triggered when a TLS certificate is expired.""" - - def __init__(self, handle: Handle, certificate: str): - super().__init__(handle) - self.certificate = certificate - - def snapshot(self) -> dict: - """Returns snapshot.""" - return {"certificate": self.certificate} - - def restore(self, snapshot: dict): - """Restores snapshot.""" - self.certificate = snapshot["certificate"] - - -class CertificateRevokedEvent(EventBase): - """Charm Event triggered when a TLS certificate is revoked.""" +class CertificateInvalidatedEvent(EventBase): + """Charm Event triggered when a TLS certificate is invalidated.""" def __init__( self, handle: Handle, + reason: Literal["expired", "revoked"], certificate: str, certificate_signing_request: str, ca: str, chain: List[str], - revoked: bool, ): super().__init__(handle) - self.certificate = certificate + self.reason = reason self.certificate_signing_request = certificate_signing_request + self.certificate = certificate self.ca = ca self.chain = chain - self.revoked = revoked def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return { - "certificate": self.certificate, + "reason": self.reason, "certificate_signing_request": self.certificate_signing_request, + "certificate": self.certificate, "ca": self.ca, "chain": self.chain, - "revoked": self.revoked, } def restore(self, snapshot: dict): - """Restores snapshot.""" - self.certificate = snapshot["certificate"] + """Restore snapshot.""" + self.reason = snapshot["reason"] self.certificate_signing_request = snapshot["certificate_signing_request"] + self.certificate = snapshot["certificate"] self.ca = snapshot["ca"] self.chain = snapshot["chain"] - self.revoked = snapshot["revoked"] + + +class AllCertificatesInvalidatedEvent(EventBase): + """Charm Event triggered when all TLS certificates are invalidated.""" + + def __init__(self, handle: Handle): + super().__init__(handle) + + def snapshot(self) -> dict: + """Return snapshot.""" + return {} + + def restore(self, snapshot: dict): + """Restore snapshot.""" + pass class CertificateCreationRequestEvent(EventBase): """Charm Event triggered when a TLS certificate is required.""" - def __init__(self, handle: Handle, certificate_signing_request: str, relation_id: int): + def __init__( + self, + handle: Handle, + certificate_signing_request: str, + relation_id: int, + is_ca: bool = False, + ): super().__init__(handle) self.certificate_signing_request = certificate_signing_request self.relation_id = relation_id + self.is_ca = is_ca def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return { "certificate_signing_request": self.certificate_signing_request, "relation_id": self.relation_id, + "is_ca": self.is_ca, } def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" self.certificate_signing_request = snapshot["certificate_signing_request"] self.relation_id = snapshot["relation_id"] + self.is_ca = snapshot["is_ca"] class CertificateRevocationRequestEvent(EventBase): @@ -536,7 +583,7 @@ def __init__( self.chain = chain def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return { "certificate": self.certificate, "certificate_signing_request": self.certificate_signing_request, @@ -545,33 +592,72 @@ def snapshot(self) -> dict: } def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" self.certificate = snapshot["certificate"] self.certificate_signing_request = snapshot["certificate_signing_request"] self.ca = snapshot["ca"] self.chain = snapshot["chain"] -def _load_relation_data(raw_relation_data: dict) -> dict: - """Loads relation data from the relation data bag. +def _load_relation_data(relation_data_content: RelationDataContent) -> dict: + """Load relation data from the relation data bag. Json loads all data. Args: - raw_relation_data: Relation data from the databag + relation_data_content: Relation data from the databag Returns: dict: Relation data in dict format. """ - certificate_data = dict() - for key in raw_relation_data: - try: - certificate_data[key] = json.loads(raw_relation_data[key]) - except (json.decoder.JSONDecodeError, TypeError): - certificate_data[key] = raw_relation_data[key] + certificate_data = {} + try: + for key in relation_data_content: + try: + certificate_data[key] = json.loads(relation_data_content[key]) + except (json.decoder.JSONDecodeError, TypeError): + certificate_data[key] = relation_data_content[key] + except ModelError: + pass return certificate_data +def _get_closest_future_time( + expiry_notification_time: datetime, expiry_time: datetime +) -> datetime: + """Return expiry_notification_time if not in the past, otherwise return expiry_time. + + Args: + expiry_notification_time (datetime): Notification time of impending expiration + expiry_time (datetime): Expiration time + + Returns: + datetime: expiry_notification_time if not in the past, expiry_time otherwise + """ + return ( + expiry_notification_time + if datetime.now(timezone.utc) < expiry_notification_time + else expiry_time + ) + + +def _get_certificate_expiry_time(certificate: str) -> Optional[datetime]: + """Extract expiry time from a certificate string. + + Args: + certificate (str): x509 certificate as a string + + Returns: + Optional[datetime]: Expiry datetime or None + """ + try: + certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) + return certificate_object.not_valid_after_utc + except ValueError: + logger.warning("Could not load certificate.") + return None + + def generate_ca( private_key: bytes, subject: str, @@ -579,11 +665,11 @@ def generate_ca( validity: int = 365, country: str = "US", ) -> bytes: - """Generates a CA Certificate. + """Generate a CA Certificate. Args: private_key (bytes): Private key - subject (str): Certificate subject + subject (str): Common Name that can be an IP or a Full Qualified Domain Name (FQDN). private_key_password (bytes): Private key password validity (int): Certificate validity time (in days) country (str): Certificate Issuing country @@ -594,7 +680,7 @@ def generate_ca( private_key_object = serialization.load_pem_private_key( private_key, password=private_key_password ) - subject = issuer = x509.Name( + subject_name = x509.Name( [ x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country), x509.NameAttribute(x509.NameOID.COMMON_NAME, subject), @@ -604,14 +690,25 @@ def generate_ca( private_key_object.public_key() # type: ignore[arg-type] ) subject_identifier = key_identifier = subject_identifier_object.public_bytes() + key_usage = x509.KeyUsage( + digital_signature=True, + key_encipherment=True, + key_cert_sign=True, + key_agreement=False, + content_commitment=False, + data_encipherment=False, + crl_sign=False, + encipher_only=False, + decipher_only=False, + ) cert = ( x509.CertificateBuilder() - .subject_name(subject) - .issuer_name(issuer) + .subject_name(subject_name) + .issuer_name(subject_name) .public_key(private_key_object.public_key()) # type: ignore[arg-type] .serial_number(x509.random_serial_number()) - .not_valid_before(datetime.utcnow()) - .not_valid_after(datetime.utcnow() + timedelta(days=validity)) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=validity)) .add_extension(x509.SubjectKeyIdentifier(digest=subject_identifier), critical=False) .add_extension( x509.AuthorityKeyIdentifier( @@ -621,6 +718,7 @@ def generate_ca( ), critical=False, ) + .add_extension(key_usage, critical=True) .add_extension( x509.BasicConstraints(ca=True, path_length=None), critical=True, @@ -630,6 +728,105 @@ def generate_ca( return cert.public_bytes(serialization.Encoding.PEM) +def get_certificate_extensions( + authority_key_identifier: bytes, + csr: x509.CertificateSigningRequest, + alt_names: Optional[List[str]], + is_ca: bool, +) -> List[x509.Extension]: + """Generate a list of certificate extensions from a CSR and other known information. + + Args: + authority_key_identifier (bytes): Authority key identifier + csr (x509.CertificateSigningRequest): CSR + alt_names (list): List of alt names to put on cert - prefer putting SANs in CSR + is_ca (bool): Whether the certificate is a CA certificate + + Returns: + List[x509.Extension]: List of extensions + """ + cert_extensions_list: List[x509.Extension] = [ + x509.Extension( + oid=ExtensionOID.AUTHORITY_KEY_IDENTIFIER, + value=x509.AuthorityKeyIdentifier( + key_identifier=authority_key_identifier, + authority_cert_issuer=None, + authority_cert_serial_number=None, + ), + critical=False, + ), + x509.Extension( + oid=ExtensionOID.SUBJECT_KEY_IDENTIFIER, + value=x509.SubjectKeyIdentifier.from_public_key(csr.public_key()), + critical=False, + ), + x509.Extension( + oid=ExtensionOID.BASIC_CONSTRAINTS, + critical=True, + value=x509.BasicConstraints(ca=is_ca, path_length=None), + ), + ] + + sans: List[x509.GeneralName] = [] + san_alt_names = [x509.DNSName(name) for name in alt_names] if alt_names else [] + sans.extend(san_alt_names) + try: + loaded_san_ext = csr.extensions.get_extension_for_class(x509.SubjectAlternativeName) + sans.extend( + [x509.DNSName(name) for name in loaded_san_ext.value.get_values_for_type(x509.DNSName)] + ) + sans.extend( + [x509.IPAddress(ip) for ip in loaded_san_ext.value.get_values_for_type(x509.IPAddress)] + ) + sans.extend( + [ + x509.RegisteredID(oid) + for oid in loaded_san_ext.value.get_values_for_type(x509.RegisteredID) + ] + ) + except x509.ExtensionNotFound: + pass + + if sans: + cert_extensions_list.append( + x509.Extension( + oid=ExtensionOID.SUBJECT_ALTERNATIVE_NAME, + critical=False, + value=x509.SubjectAlternativeName(sans), + ) + ) + + if is_ca: + cert_extensions_list.append( + x509.Extension( + ExtensionOID.KEY_USAGE, + critical=True, + value=x509.KeyUsage( + digital_signature=False, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=True, + crl_sign=True, + encipher_only=False, + decipher_only=False, + ), + ) + ) + + existing_oids = {ext.oid for ext in cert_extensions_list} + for extension in csr.extensions: + if extension.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME: + continue + if extension.oid in existing_oids: + logger.warning("Extension %s is managed by the TLS provider, ignoring.", extension.oid) + continue + cert_extensions_list.append(extension) + + return cert_extensions_list + + def generate_certificate( csr: bytes, ca: bytes, @@ -637,8 +834,9 @@ def generate_certificate( ca_key_password: Optional[bytes] = None, validity: int = 365, alt_names: Optional[List[str]] = None, + is_ca: bool = False, ) -> bytes: - """Generates a TLS certificate based on a CSR. + """Generate a TLS certificate based on a CSR. Args: csr (bytes): CSR @@ -647,13 +845,15 @@ def generate_certificate( ca_key_password: CA private key password validity (int): Certificate validity (in days) alt_names (list): List of alt names to put on cert - prefer putting SANs in CSR + is_ca (bool): Whether the certificate is a CA certificate Returns: bytes: Certificate """ csr_object = x509.load_pem_x509_csr(csr) subject = csr_object.subject - issuer = x509.load_pem_x509_certificate(ca).issuer + ca_pem = x509.load_pem_x509_certificate(ca) + issuer = ca_pem.issuer private_key = serialization.load_pem_private_key(ca_key, password=ca_key_password) certificate_builder = ( @@ -662,39 +862,26 @@ def generate_certificate( .issuer_name(issuer) .public_key(csr_object.public_key()) .serial_number(x509.random_serial_number()) - .not_valid_before(datetime.utcnow()) - .not_valid_after(datetime.utcnow() + timedelta(days=validity)) + .not_valid_before(datetime.now(timezone.utc)) + .not_valid_after(datetime.now(timezone.utc) + timedelta(days=validity)) ) - - extensions_list = csr_object.extensions - san_ext: Optional[x509.Extension] = None - if alt_names: - full_sans_dns = alt_names.copy() + extensions = get_certificate_extensions( + authority_key_identifier=ca_pem.extensions.get_extension_for_class( + x509.SubjectKeyIdentifier + ).value.key_identifier, + csr=csr_object, + alt_names=alt_names, + is_ca=is_ca, + ) + for extension in extensions: try: - loaded_san_ext = csr_object.extensions.get_extension_for_class( - x509.SubjectAlternativeName + certificate_builder = certificate_builder.add_extension( + extval=extension.value, + critical=extension.critical, ) - full_sans_dns.extend(loaded_san_ext.value.get_values_for_type(x509.DNSName)) - except ExtensionNotFound: - pass - finally: - san_ext = Extension( - ExtensionOID.SUBJECT_ALTERNATIVE_NAME, - False, - x509.SubjectAlternativeName([x509.DNSName(name) for name in full_sans_dns]), - ) - if not extensions_list: - extensions_list = x509.Extensions([san_ext]) - - for extension in extensions_list: - if extension.value.oid == ExtensionOID.SUBJECT_ALTERNATIVE_NAME and san_ext: - extension = san_ext + except ValueError as e: + logger.warning("Failed to add extension %s: %s", extension.oid, e) - certificate_builder = certificate_builder.add_extension( - extension.value, - critical=extension.critical, - ) - certificate_builder._version = x509.Version.v3 cert = certificate_builder.sign(private_key, hashes.SHA256()) # type: ignore[arg-type] return cert.public_bytes(serialization.Encoding.PEM) @@ -705,7 +892,7 @@ def generate_pfx_package( package_password: str, private_key_password: Optional[bytes] = None, ) -> bytes: - """Generates a PFX package to contain the TLS certificate and private key. + """Generate a PFX package to contain the TLS certificate and private key. Args: certificate (bytes): TLS certificate @@ -736,7 +923,7 @@ def generate_private_key( key_size: int = 2048, public_exponent: int = 65537, ) -> bytes: - """Generates a private key. + """Generate a private key. Args: password (bytes): Password for decrypting the private key @@ -753,14 +940,16 @@ def generate_private_key( key_bytes = private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.BestAvailableEncryption(password) - if password - else serialization.NoEncryption(), + encryption_algorithm=( + serialization.BestAvailableEncryption(password) + if password + else serialization.NoEncryption() + ), ) return key_bytes -def generate_csr( +def generate_csr( # noqa: C901 private_key: bytes, subject: str, add_unique_id_to_subject_name: bool = True, @@ -774,11 +963,11 @@ def generate_csr( sans_dns: Optional[List[str]] = None, additional_critical_extensions: Optional[List] = None, ) -> bytes: - """Generates a CSR using private key and subject. + """Generate a CSR using private key and subject. Args: private_key (bytes): Private key - subject (str): CSR Subject. + subject (str): CSR Common Name that can be an IP or a Full Qualified Domain Name (FQDN). add_unique_id_to_subject_name (bool): Whether a unique ID must be added to the CSR's subject name. Always leave to "True" when the CSR is used to request certificates using the tls-certificates relation. @@ -791,7 +980,7 @@ def generate_csr( sans_oid (list): List of registered ID SANs sans_dns (list): List of DNS subject alternative names (similar to the arg: sans) sans_ip (list): List of IP subject alternative names - additional_critical_extensions (list): List if critical additional extension objects. + additional_critical_extensions (list): List of critical additional extension objects. Object must be a x509 ExtensionType. Returns: @@ -832,6 +1021,38 @@ def generate_csr( return signed_certificate.public_bytes(serialization.Encoding.PEM) +def csr_matches_certificate(csr: str, cert: str) -> bool: + """Check if a CSR matches a certificate. + + Args: + csr (str): Certificate Signing Request as a string + cert (str): Certificate as a string + Returns: + bool: True/False depending on whether the CSR matches the certificate. + """ + try: + csr_object = x509.load_pem_x509_csr(csr.encode("utf-8")) + cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8")) + + if csr_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) != cert_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ): + return False + if ( + csr_object.public_key().public_numbers().n # type: ignore[union-attr] + != cert_object.public_key().public_numbers().n # type: ignore[union-attr] + ): + return False + except ValueError: + logger.warning("Could not load certificate or CSR.") + return False + return True + + class CertificatesProviderCharmEvents(CharmEvents): """List of events that the TLS Certificates provider charm can leverage.""" @@ -844,14 +1065,14 @@ class CertificatesRequirerCharmEvents(CharmEvents): certificate_available = EventSource(CertificateAvailableEvent) certificate_expiring = EventSource(CertificateExpiringEvent) - certificate_expired = EventSource(CertificateExpiredEvent) - certificate_revoked = EventSource(CertificateRevokedEvent) + certificate_invalidated = EventSource(CertificateInvalidatedEvent) + all_certificates_invalidated = EventSource(AllCertificatesInvalidatedEvent) -class TLSCertificatesProvidesV1(Object): +class TLSCertificatesProvidesV2(Object): """TLS certificates provider class to be instantiated by TLS certificates providers.""" - on = CertificatesProviderCharmEvents() + on = CertificatesProviderCharmEvents() # type: ignore[reportAssignmentType] def __init__(self, charm: CharmBase, relationship_name: str): super().__init__(charm, relationship_name) @@ -861,6 +1082,22 @@ def __init__(self, charm: CharmBase, relationship_name: str): self.charm = charm self.relationship_name = relationship_name + def _load_app_relation_data(self, relation: Relation) -> dict: + """Load relation data from the application relation data bag. + + Json loads all data. + + Args: + relation: Relation data from the application databag + + Returns: + dict: Relation data in dict format. + """ + # If unit is not leader, it does not try to reach relation data. + if not self.model.unit.is_leader(): + return {} + return _load_relation_data(relation.data[self.charm.app]) + def _add_certificate( self, relation_id: int, @@ -869,7 +1106,7 @@ def _add_certificate( ca: str, chain: List[str], ) -> None: - """Adds certificate to relation data. + """Add certificate to relation data. Args: relation_id (int): Relation id @@ -895,7 +1132,7 @@ def _add_certificate( "ca": ca, "chain": chain, } - provider_relation_data = _load_relation_data(relation.data[self.charm.app]) + provider_relation_data = self._load_app_relation_data(relation) provider_certificates = provider_relation_data.get("certificates", []) certificates = copy.deepcopy(provider_certificates) if new_certificate in certificates: @@ -910,7 +1147,7 @@ def _remove_certificate( certificate: Optional[str] = None, certificate_signing_request: Optional[str] = None, ) -> None: - """Removes certificate from a given relation based on user provided certificate or csr. + """Remove certificate from a given relation based on user provided certificate or csr. Args: relation_id (int): Relation id @@ -928,7 +1165,7 @@ def _remove_certificate( raise RuntimeError( f"Relation {self.relationship_name} with relation id {relation_id} does not exist" ) - provider_relation_data = _load_relation_data(relation.data[self.charm.app]) + provider_relation_data = self._load_app_relation_data(relation) provider_certificates = provider_relation_data.get("certificates", []) certificates = copy.deepcopy(provider_certificates) for certificate_dict in certificates: @@ -943,7 +1180,7 @@ def _remove_certificate( @staticmethod def _relation_data_is_valid(certificates_data: dict) -> bool: - """Uses JSON schema validator to validate relation data content. + """Use JSON schema validator to validate relation data content. Args: certificates_data (dict): Certificate data dictionary as retrieved from relation data. @@ -958,12 +1195,12 @@ def _relation_data_is_valid(certificates_data: dict) -> bool: return False def revoke_all_certificates(self) -> None: - """Revokes all certificates of this provider. + """Revoke all certificates of this provider. This method is meant to be used when the Root CA has changed. """ for relation in self.model.relations[self.relationship_name]: - provider_relation_data = _load_relation_data(relation.data[self.charm.app]) + provider_relation_data = self._load_app_relation_data(relation) provider_certificates = copy.deepcopy(provider_relation_data.get("certificates", [])) for certificate in provider_certificates: certificate["revoked"] = True @@ -977,7 +1214,7 @@ def set_relation_certificate( chain: List[str], relation_id: int, ) -> None: - """Adds certificates to relation data. + """Add certificates to relation data. Args: certificate (str): Certificate @@ -989,6 +1226,8 @@ def set_relation_certificate( Returns: None """ + if not self.model.unit.is_leader(): + return certificates_relation = self.model.get_relation( relation_name=self.relationship_name, relation_id=relation_id ) @@ -1007,7 +1246,7 @@ def set_relation_certificate( ) def remove_certificate(self, certificate: str) -> None: - """Removes a given certificate from relation data. + """Remove a given certificate from relation data. Args: certificate (str): TLS Certificate @@ -1021,8 +1260,45 @@ def remove_certificate(self, certificate: str) -> None: for certificate_relation in certificates_relation: self._remove_certificate(certificate=certificate, relation_id=certificate_relation.id) + def get_issued_certificates( + self, relation_id: Optional[int] = None + ) -> Dict[str, List[Dict[str, str]]]: + """Return a dictionary of issued certificates. + + It returns certificates from all relations if relation_id is not specified. + Certificates are returned per application name and CSR. + + Returns: + dict: Certificates per application name. + """ + certificates: Dict[str, List[Dict[str, str]]] = {} + relations = ( + [ + relation + for relation in self.model.relations[self.relationship_name] + if relation.id == relation_id + ] + if relation_id is not None + else self.model.relations.get(self.relationship_name, []) + ) + for relation in relations: + provider_relation_data = self._load_app_relation_data(relation) + provider_certificates = provider_relation_data.get("certificates", []) + + certificates[relation.app.name] = [] # type: ignore[union-attr] + for certificate in provider_certificates: + if not certificate.get("revoked", False): + certificates[relation.app.name].append( # type: ignore[union-attr] + { + "csr": certificate["certificate_signing_request"], + "certificate": certificate["certificate"], + } + ) + + return certificates + def _on_relation_changed(self, event: RelationChangedEvent) -> None: - """Handler triggerred on relation changed event. + """Handle relation changed event. Looks at the relation data and either emits: - certificate request event: If the unit relation data contains a CSR for which @@ -1036,13 +1312,15 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: Returns: None """ - assert event.unit is not None + if event.unit is None: + logger.error("Relation_changed event does not have a unit.") + return + if not self.model.unit.is_leader(): + return requirer_relation_data = _load_relation_data(event.relation.data[event.unit]) - provider_relation_data = _load_relation_data(event.relation.data[self.charm.app]) + provider_relation_data = self._load_app_relation_data(event.relation) if not self._relation_data_is_valid(requirer_relation_data): - logger.warning( - f"Relation data did not pass JSON Schema validation: {requirer_relation_data}" - ) + logger.debug("Relation data did not pass JSON Schema validation") return provider_certificates = provider_relation_data.get("certificates", []) requirer_csrs = requirer_relation_data.get("certificate_signing_requests", []) @@ -1050,22 +1328,26 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: certificate_creation_request["certificate_signing_request"] for certificate_creation_request in provider_certificates ] - requirer_unit_csrs = [ - certificate_creation_request["certificate_signing_request"] + requirer_unit_certificate_requests = [ + { + "csr": certificate_creation_request["certificate_signing_request"], + "is_ca": certificate_creation_request.get("ca", False), + } for certificate_creation_request in requirer_csrs ] - for certificate_signing_request in requirer_unit_csrs: - if certificate_signing_request not in provider_csrs: + for certificate_request in requirer_unit_certificate_requests: + if certificate_request["csr"] not in provider_csrs: self.on.certificate_creation_request.emit( - certificate_signing_request=certificate_signing_request, + certificate_signing_request=certificate_request["csr"], relation_id=event.relation.id, + is_ca=certificate_request["is_ca"], ) self._revoke_certificates_for_which_no_csr_exists(relation_id=event.relation.id) def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None: - """Revokes certificates for which no unit has a CSR. + """Revoke certificates for which no unit has a CSR. - Goes through all generated certificates and compare agains the list of CSRS for all units + Goes through all generated certificates and compare against the list of CSRs for all units of a given relationship. Args: @@ -1079,7 +1361,7 @@ def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None ) if not certificates_relation: raise RuntimeError(f"Relation {self.relationship_name} does not exist") - provider_relation_data = _load_relation_data(certificates_relation.data[self.charm.app]) + provider_relation_data = self._load_app_relation_data(certificates_relation) list_of_csrs: List[str] = [] for unit in certificates_relation.units: requirer_relation_data = _load_relation_data(certificates_relation.data[unit]) @@ -1096,11 +1378,112 @@ def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None ) self.remove_certificate(certificate=certificate["certificate"]) + def get_outstanding_certificate_requests( + self, relation_id: Optional[int] = None + ) -> List[Dict[str, Union[int, str, List[Dict[str, str]]]]]: + """Return CSR's for which no certificate has been issued. + + Example return: [ + { + "relation_id": 0, + "application_name": "tls-certificates-requirer", + "unit_name": "tls-certificates-requirer/0", + "unit_csrs": [ + { + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + "is_ca": false + } + ] + } + ] + + Args: + relation_id (int): Relation id + + Returns: + list: List of dictionaries that contain the unit's csrs + that don't have a certificate issued. + """ + all_unit_csr_mappings = copy.deepcopy(self.get_requirer_csrs(relation_id=relation_id)) + filtered_all_unit_csr_mappings: List[Dict[str, Union[int, str, List[Dict[str, str]]]]] = [] + for unit_csr_mapping in all_unit_csr_mappings: + csrs_without_certs = [] + for csr in unit_csr_mapping["unit_csrs"]: # type: ignore[union-attr] + if not self.certificate_issued_for_csr( + app_name=unit_csr_mapping["application_name"], # type: ignore[arg-type] + csr=csr["certificate_signing_request"], # type: ignore[index] + relation_id=relation_id, + ): + csrs_without_certs.append(csr) + if csrs_without_certs: + unit_csr_mapping["unit_csrs"] = csrs_without_certs # type: ignore[assignment] + filtered_all_unit_csr_mappings.append(unit_csr_mapping) + return filtered_all_unit_csr_mappings + + def get_requirer_csrs( + self, relation_id: Optional[int] = None + ) -> List[Dict[str, Union[int, str, List[Dict[str, str]]]]]: + """Return a list of requirers' CSRs grouped by unit. + + It returns CSRs from all relations if relation_id is not specified. + CSRs are returned per relation id, application name and unit name. + + Returns: + list: List of dictionaries that contain the unit's csrs + with the following information + relation_id, application_name and unit_name. + """ + unit_csr_mappings: List[Dict[str, Union[int, str, List[Dict[str, str]]]]] = [] + + relations = ( + [ + relation + for relation in self.model.relations[self.relationship_name] + if relation.id == relation_id + ] + if relation_id is not None + else self.model.relations.get(self.relationship_name, []) + ) -class TLSCertificatesRequiresV1(Object): + for relation in relations: + for unit in relation.units: + requirer_relation_data = _load_relation_data(relation.data[unit]) + unit_csrs_list = requirer_relation_data.get("certificate_signing_requests", []) + unit_csr_mappings.append( + { + "relation_id": relation.id, + "application_name": relation.app.name, # type: ignore[union-attr] + "unit_name": unit.name, + "unit_csrs": unit_csrs_list, + } + ) + return unit_csr_mappings + + def certificate_issued_for_csr( + self, app_name: str, csr: str, relation_id: Optional[int] + ) -> bool: + """Check whether a certificate has been issued for a given CSR. + + Args: + app_name (str): Application name that the CSR belongs to. + csr (str): Certificate Signing Request. + relation_id (Optional[int]): Relation ID + Returns: + bool: True/False depending on whether a certificate has been issued for the given CSR. + """ + issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id)[ + app_name + ] + for issued_pair in issued_certificates_per_csr: + if "csr" in issued_pair and issued_pair["csr"] == csr: + return csr_matches_certificate(csr, issued_pair["certificate"]) + return False + + +class TLSCertificatesRequiresV2(Object): """TLS certificates requirer class to be instantiated by TLS certificates requirers.""" - on = CertificatesRequirerCharmEvents() + on = CertificatesRequirerCharmEvents() # type: ignore[reportAssignmentType] def __init__( self, @@ -1108,7 +1491,7 @@ def __init__( relationship_name: str, expiry_notification_time: int = 168, ): - """Generates/use private key and observes relation changed event. + """Generate/use private key and observes relation changed event. Args: charm: Charm object @@ -1123,11 +1506,26 @@ def __init__( self.framework.observe( charm.on[relationship_name].relation_changed, self._on_relation_changed ) - self.framework.observe(charm.on.update_status, self._on_update_status) + self.framework.observe( + charm.on[relationship_name].relation_broken, self._on_relation_broken + ) + if JujuVersion.from_environ().has_secrets: + self.framework.observe(charm.on.secret_expired, self._on_secret_expired) + else: + self.framework.observe(charm.on.update_status, self._on_update_status) @property - def _requirer_csrs(self) -> List[Dict[str, str]]: - """Returns list of requirer CSR's from relation data.""" + def _requirer_csrs(self) -> List[Dict[str, Union[bool, str]]]: + """Return list of requirer's CSRs from relation unit data. + + Example: + [ + { + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + "ca": false + } + ] + """ relation = self.model.get_relation(self.relationship_name) if not relation: raise RuntimeError(f"Relation {self.relationship_name} does not exist") @@ -1136,20 +1534,26 @@ def _requirer_csrs(self) -> List[Dict[str, str]]: @property def _provider_certificates(self) -> List[Dict[str, str]]: - """Returns list of provider CSR's from relation data.""" + """Return list of certificates from the provider's relation data.""" relation = self.model.get_relation(self.relationship_name) if not relation: - raise RuntimeError(f"Relation {self.relationship_name} does not exist") + logger.debug("No relation: %s", self.relationship_name) + return [] if not relation.app: - raise RuntimeError(f"Remote app for relation {self.relationship_name} does not exist") + logger.debug("No remote app in relation: %s", self.relationship_name) + return [] provider_relation_data = _load_relation_data(relation.data[relation.app]) + if not self._relation_data_is_valid(provider_relation_data): + logger.warning("Provider relation data did not pass JSON Schema validation") + return [] return provider_relation_data.get("certificates", []) - def _add_requirer_csr(self, csr: str) -> None: - """Adds CSR to relation data. + def _add_requirer_csr(self, csr: str, is_ca: bool) -> None: + """Add CSR to relation data. Args: csr (str): Certificate Signing Request + is_ca (bool): Whether the certificate is a CA certificate Returns: None @@ -1160,7 +1564,10 @@ def _add_requirer_csr(self, csr: str) -> None: f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - new_csr_dict = {"certificate_signing_request": csr} + new_csr_dict: Dict[str, Union[bool, str]] = { + "certificate_signing_request": csr, + "ca": is_ca, + } if new_csr_dict in self._requirer_csrs: logger.info("CSR already in relation data - Doing nothing") return @@ -1169,7 +1576,7 @@ def _add_requirer_csr(self, csr: str) -> None: relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(requirer_csrs) def _remove_requirer_csr(self, csr: str) -> None: - """Removes CSR from relation data. + """Remove CSR from relation data. Args: csr (str): Certificate signing request @@ -1184,35 +1591,37 @@ def _remove_requirer_csr(self, csr: str) -> None: f"The certificate request can't be completed" ) requirer_csrs = copy.deepcopy(self._requirer_csrs) - csr_dict = {"certificate_signing_request": csr} - if csr_dict not in requirer_csrs: - logger.info("CSR not in relation data - Doing nothing") + if not requirer_csrs: + logger.info("No CSRs in relation data - Doing nothing") return - requirer_csrs.remove(csr_dict) + for requirer_csr in requirer_csrs: + if requirer_csr["certificate_signing_request"] == csr: + requirer_csrs.remove(requirer_csr) relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(requirer_csrs) - def request_certificate_creation(self, certificate_signing_request: bytes) -> None: + def request_certificate_creation( + self, certificate_signing_request: bytes, is_ca: bool = False + ) -> None: """Request TLS certificate to provider charm. Args: certificate_signing_request (bytes): Certificate Signing Request + is_ca (bool): Whether the certificate is a CA certificate Returns: None """ relation = self.model.get_relation(self.relationship_name) if not relation: - message = ( + raise RuntimeError( f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - logger.error(message) - raise RuntimeError(message) - self._add_requirer_csr(certificate_signing_request.decode().strip()) + self._add_requirer_csr(certificate_signing_request.decode().strip(), is_ca=is_ca) logger.info("Certificate request sent to provider") def request_certificate_revocation(self, certificate_signing_request: bytes) -> None: - """Removes CSR from relation data. + """Remove CSR from relation data. The provider of this relation is then expected to remove certificates associated to this CSR from the relation data as well and emit a request_certificate_revocation event for the @@ -1230,7 +1639,7 @@ def request_certificate_revocation(self, certificate_signing_request: bytes) -> def request_certificate_renewal( self, old_certificate_signing_request: bytes, new_certificate_signing_request: bytes ) -> None: - """Renews certificate. + """Renew certificate. Removes old CSR from relation data and adds new one. @@ -1252,9 +1661,95 @@ def request_certificate_renewal( ) logger.info("Certificate renewal request completed.") + def get_assigned_certificates(self) -> List[Dict[str, str]]: + """Get a list of certificates that were assigned to this unit. + + Returns: + List of certificates. For example: + [ + { + "ca": "-----BEGIN CERTIFICATE-----...", + "chain": [ + "-----BEGIN CERTIFICATE-----..." + ], + "certificate": "-----BEGIN CERTIFICATE-----...", + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + } + ] + """ + final_list = [] + for csr in self.get_certificate_signing_requests(fulfilled_only=True): + assert isinstance(csr["certificate_signing_request"], str) + if cert := self._find_certificate_in_relation_data(csr["certificate_signing_request"]): + final_list.append(cert) + return final_list + + def get_expiring_certificates(self) -> List[Dict[str, str]]: + """Get a list of certificates that were assigned to this unit that are expiring or expired. + + Returns: + List of certificates. For example: + [ + { + "ca": "-----BEGIN CERTIFICATE-----...", + "chain": [ + "-----BEGIN CERTIFICATE-----..." + ], + "certificate": "-----BEGIN CERTIFICATE-----...", + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + } + ] + """ + final_list = [] + for csr in self.get_certificate_signing_requests(fulfilled_only=True): + assert isinstance(csr["certificate_signing_request"], str) + if cert := self._find_certificate_in_relation_data(csr["certificate_signing_request"]): + expiry_time = _get_certificate_expiry_time(cert["certificate"]) + if not expiry_time: + continue + expiry_notification_time = expiry_time - timedelta( + hours=self.expiry_notification_time + ) + if datetime.now(timezone.utc) > expiry_notification_time: + final_list.append(cert) + return final_list + + def get_certificate_signing_requests( + self, + fulfilled_only: bool = False, + unfulfilled_only: bool = False, + ) -> List[Dict[str, Union[bool, str]]]: + """Get the list of CSR's that were sent to the provider. + + You can choose to get only the CSR's that have a certificate assigned or only the CSR's + that don't. + + Args: + fulfilled_only (bool): This option will discard CSRs that don't have certificates yet. + unfulfilled_only (bool): This option will discard CSRs that have certificates signed. + + Returns: + List of CSR dictionaries. For example: + [ + { + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + "ca": false + } + ] + """ + final_list = [] + for csr in self._requirer_csrs: + assert isinstance(csr["certificate_signing_request"], str) + cert = self._find_certificate_in_relation_data(csr["certificate_signing_request"]) + if (unfulfilled_only and cert) or (fulfilled_only and not cert): + continue + final_list.append(csr) + + return final_list + @staticmethod def _relation_data_is_valid(certificates_data: dict) -> bool: - """Checks whether relation data is valid based on json schema. + """Check whether relation data is valid based on json schema. Args: certificates_data: Certificate data in dict format. @@ -1269,7 +1764,16 @@ def _relation_data_is_valid(certificates_data: dict) -> bool: return False def _on_relation_changed(self, event: RelationChangedEvent) -> None: - """Handler triggered on relation changed events. + """Handle relation changed event. + + Goes through all providers certificates that match a requested CSR. + + If the provider certificate is revoked, emit a CertificateInvalidateEvent, + otherwise emit a CertificateAvailableEvent. + + When Juju secrets are available, remove the secret for revoked certificate, + or add a secret with the correct expiry time for new certificates. + Args: event: Juju event @@ -1277,20 +1781,6 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: Returns: None """ - relation = self.model.get_relation(self.relationship_name) - if not relation: - logger.warning(f"No relation: {self.relationship_name}") - return - if not relation.app: - logger.warning(f"No remote app in relation: {self.relationship_name}") - return - provider_relation_data = _load_relation_data(relation.data[relation.app]) - if not self._relation_data_is_valid(provider_relation_data): - logger.warning( - f"Provider relation data did not pass JSON Schema validation: " - f"{event.relation.data[relation.app]}" - ) - return requirer_csrs = [ certificate_creation_request["certificate_signing_request"] for certificate_creation_request in self._requirer_csrs @@ -1298,14 +1788,39 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: for certificate in self._provider_certificates: if certificate["certificate_signing_request"] in requirer_csrs: if certificate.get("revoked", False): - self.on.certificate_revoked.emit( - certificate_signing_request=certificate["certificate_signing_request"], + if JujuVersion.from_environ().has_secrets: + with suppress(SecretNotFoundError): + secret = self.model.get_secret( + label=f"{LIBID}-{certificate['certificate_signing_request']}" + ) + secret.remove_all_revisions() + self.on.certificate_invalidated.emit( + reason="revoked", certificate=certificate["certificate"], + certificate_signing_request=certificate["certificate_signing_request"], ca=certificate["ca"], chain=certificate["chain"], - revoked=True, ) else: + if JujuVersion.from_environ().has_secrets: + try: + secret = self.model.get_secret( + label=f"{LIBID}-{certificate['certificate_signing_request']}" + ) + secret.set_content({"certificate": certificate["certificate"]}) + secret.set_info( + expire=self._get_next_secret_expiry_time( + certificate["certificate"] + ), + ) + except SecretNotFoundError: + secret = self.charm.unit.add_secret( + {"certificate": certificate["certificate"]}, + label=f"{LIBID}-{certificate['certificate_signing_request']}", + expire=self._get_next_secret_expiry_time( + certificate["certificate"] + ), + ) self.on.certificate_available.emit( certificate_signing_request=certificate["certificate_signing_request"], certificate=certificate["certificate"], @@ -1313,8 +1828,102 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: chain=certificate["chain"], ) + def _get_next_secret_expiry_time(self, certificate: str) -> Optional[datetime]: + """Return the expiry time or expiry notification time. + + Extracts the expiry time from the provided certificate, calculates the + expiry notification time and return the closest of the two, that is in + the future. + + Args: + certificate: x509 certificate + + Returns: + Optional[datetime]: None if the certificate expiry time cannot be read, + next expiry time otherwise. + """ + expiry_time = _get_certificate_expiry_time(certificate) + if not expiry_time: + return None + expiry_notification_time = expiry_time - timedelta(hours=self.expiry_notification_time) + return _get_closest_future_time(expiry_notification_time, expiry_time) + + def _on_relation_broken(self, event: RelationBrokenEvent) -> None: + """Handle relation broken event. + + Emitting `all_certificates_invalidated` from `relation-broken` rather + than `relation-departed` since certs are stored in app data. + + Args: + event: Juju event + + Returns: + None + """ + self.on.all_certificates_invalidated.emit() + + def _on_secret_expired(self, event: SecretExpiredEvent) -> None: + """Handle secret expired event. + + Loads the certificate from the secret, and will emit 1 of 2 + events. + + If the certificate is not yet expired, emits CertificateExpiringEvent + and updates the expiry time of the secret to the exact expiry time on + the certificate. + + If the certificate is expired, emits CertificateInvalidedEvent and + deletes the secret. + + Args: + event (SecretExpiredEvent): Juju event + """ + if not event.secret.label or not event.secret.label.startswith(f"{LIBID}-"): + return + csr = event.secret.label[len(f"{LIBID}-") :] + certificate_dict = self._find_certificate_in_relation_data(csr) + if not certificate_dict: + # A secret expired but we did not find matching certificate. Cleaning up + event.secret.remove_all_revisions() + return + + expiry_time = _get_certificate_expiry_time(certificate_dict["certificate"]) + if not expiry_time: + # A secret expired but matching certificate is invalid. Cleaning up + event.secret.remove_all_revisions() + return + + if datetime.now(timezone.utc) < expiry_time: + logger.warning("Certificate almost expired") + self.on.certificate_expiring.emit( + certificate=certificate_dict["certificate"], + expiry=expiry_time.isoformat(), + ) + event.secret.set_info( + expire=_get_certificate_expiry_time(certificate_dict["certificate"]), + ) + else: + logger.warning("Certificate is expired") + self.on.certificate_invalidated.emit( + reason="expired", + certificate=certificate_dict["certificate"], + certificate_signing_request=certificate_dict["certificate_signing_request"], + ca=certificate_dict["ca"], + chain=certificate_dict["chain"], + ) + self.request_certificate_revocation(certificate_dict["certificate"].encode()) + event.secret.remove_all_revisions() + + def _find_certificate_in_relation_data(self, csr: str) -> Optional[Dict[str, Any]]: + """Return the certificate that match the given CSR.""" + for certificate_dict in self._provider_certificates: + if certificate_dict["certificate_signing_request"] != csr: + continue + return certificate_dict + return None + def _on_update_status(self, event: UpdateStatusEvent) -> None: - """Triggered on update status event. + """Handle update status event. Goes through each certificate in the "certificates" relation and checks their expiry date. If they are close to expire (<7 days), emits a CertificateExpiringEvent event and if @@ -1326,35 +1935,25 @@ def _on_update_status(self, event: UpdateStatusEvent) -> None: Returns: None """ - relation = self.model.get_relation(self.relationship_name) - if not relation: - logger.warning(f"No relation: {self.relationship_name}") - return - if not relation.app: - logger.warning(f"No remote app in relation: {self.relationship_name}") - return - provider_relation_data = _load_relation_data(relation.data[relation.app]) - if not self._relation_data_is_valid(provider_relation_data): - logger.warning( - f"Provider relation data did not pass JSON Schema validation: " - f"{relation.data[relation.app]}" - ) - return for certificate_dict in self._provider_certificates: - certificate = certificate_dict["certificate"] - try: - certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) - except ValueError: - logger.warning("Could not load certificate.") + expiry_time = _get_certificate_expiry_time(certificate_dict["certificate"]) + if not expiry_time: continue - time_difference = certificate_object.not_valid_after - datetime.utcnow() + time_difference = expiry_time - datetime.now(timezone.utc) if time_difference.total_seconds() < 0: logger.warning("Certificate is expired") - self.on.certificate_expired.emit(certificate=certificate) - self.request_certificate_revocation(certificate.encode()) + self.on.certificate_invalidated.emit( + reason="expired", + certificate=certificate_dict["certificate"], + certificate_signing_request=certificate_dict["certificate_signing_request"], + ca=certificate_dict["ca"], + chain=certificate_dict["chain"], + ) + self.request_certificate_revocation(certificate_dict["certificate"].encode()) continue if time_difference.total_seconds() < (self.expiry_notification_time * 60 * 60): logger.warning("Certificate almost expired") self.on.certificate_expiring.emit( - certificate=certificate, expiry=certificate_object.not_valid_after.isoformat() + certificate=certificate_dict["certificate"], + expiry=expiry_time.isoformat(), ) diff --git a/poetry.lock b/poetry.lock index 7c678401c..0e6869fed 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1838,7 +1838,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, diff --git a/src/relations/tls.py b/src/relations/tls.py index 202fdcdf1..f4a4783f8 100644 --- a/src/relations/tls.py +++ b/src/relations/tls.py @@ -11,7 +11,7 @@ import socket import typing -import charms.tls_certificates_interface.v1.tls_certificates as tls_certificates +import charms.tls_certificates_interface.v2.tls_certificates as tls_certificates import ops import relations.secrets @@ -49,7 +49,7 @@ class _Relation: """Relation to TLS certificate provider""" _charm: "kubernetes_charm.KubernetesRouterCharm" - _interface: tls_certificates.TLSCertificatesRequiresV1 + _interface: tls_certificates.TLSCertificatesRequiresV2 _secrets: relations.secrets.RelationSecrets @property @@ -171,7 +171,7 @@ class RelationEndpoint(ops.Object): def __init__(self, charm_: "kubernetes_charm.KubernetesRouterCharm") -> None: super().__init__(charm_, self.NAME) self._charm = charm_ - self._interface = tls_certificates.TLSCertificatesRequiresV1(self._charm, self.NAME) + self._interface = tls_certificates.TLSCertificatesRequiresV2(self._charm, self.NAME) self._secrets = relations.secrets.RelationSecrets( charm_, self._interface.relationship_name, unit_secret_fields=[_TLS_PRIVATE_KEY] diff --git a/tests/integration/test_upgrade.py b/tests/integration/test_upgrade.py index 650a696ec..328e4eeb6 100644 --- a/tests/integration/test_upgrade.py +++ b/tests/integration/test_upgrade.py @@ -197,7 +197,13 @@ async def test_fail_and_rollback(ops_test: OpsTest, continuous_writes) -> None: logger.info("Re-refresh the charm") await mysql_router_application.refresh(path="./upgrade.charm") - time.sleep(30) + # sleep to ensure that active status from before re-refresh does not affect below check + time.sleep(15) + + await ops_test.model.block_until( + lambda: all(unit.workload_status == "active" for unit in mysql_router_application.units) + and all(unit.agent_status == "idle" for unit in mysql_router_application.units) + ) logger.info("Running resume-upgrade on the mysql router leader unit") mysql_router_leader_unit = await get_leader_unit(ops_test, MYSQL_ROUTER_APP_NAME)