From f68fb9dcded6c93faf463d0425b4d59b7df49b05 Mon Sep 17 00:00:00 2001 From: Robert Yokota Date: Thu, 3 Jul 2025 12:19:18 -0700 Subject: [PATCH] Refactor encryption executor --- .../schema_registry/_async/avro.py | 18 +++- .../schema_registry/_async/json_schema.py | 28 ++++-- .../schema_registry/_async/protobuf.py | 21 +++- .../schema_registry/_async/serde.py | 44 +++++++-- .../schema_registry/_sync/avro.py | 18 +++- .../schema_registry/_sync/json_schema.py | 28 ++++-- .../schema_registry/_sync/protobuf.py | 21 +++- .../schema_registry/_sync/serde.py | 44 +++++++-- .../common/schema_registry_client.py | 30 +++++- .../rules/encryption/encrypt_executor.py | 87 ++++++++++++---- .../_async/test_avro_serdes.py | 98 +++++++++++++++---- .../_async/test_json_serdes.py | 80 +++++++++++++-- .../_async/test_proto_serdes.py | 60 +++++++++++- .../schema_registry/_sync/test_avro_serdes.py | 98 +++++++++++++++---- .../schema_registry/_sync/test_json_serdes.py | 80 +++++++++++++-- .../_sync/test_proto_serdes.py | 60 +++++++++++- 16 files changed, 698 insertions(+), 117 deletions(-) diff --git a/src/confluent_kafka/schema_registry/_async/avro.py b/src/confluent_kafka/schema_registry/_async/avro.py index 91016f1ec..fc8ff0749 100644 --- a/src/confluent_kafka/schema_registry/_async/avro.py +++ b/src/confluent_kafka/schema_registry/_async/avro.py @@ -14,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import io from json import loads from typing import Dict, Union, Optional, Callable @@ -29,6 +29,8 @@ AsyncSchemaRegistryClient, prefix_schema_id_serializer, dual_schema_id_deserializer) +from confluent_kafka.schema_registry.common.schema_registry_client import \ + RulePhase from confluent_kafka.serialization import (SerializationError, SerializationContext) from confluent_kafka.schema_registry.rule_registry import RuleRegistry @@ -348,8 +350,14 @@ def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 with _ContextStringIO() as fo: # write the record to the rest of the buffer schemaless_writer(fo, parsed_schema, value) + buffer = fo.getvalue() + + if latest_schema is not None: + buffer = self._execute_rules_with_phase( + ctx, subject, RulePhase.ENCODING, RuleMode.WRITE, + None, latest_schema.schema, buffer, None, None) - return self._schema_id_serializer(fo.getvalue(), ctx, self._schema_id) + return self._schema_id_serializer(buffer, ctx, self._schema_id) async def _get_parsed_schema(self, schema: Schema) -> AvroSchema: parsed_schema = self._parsed_schemas.get_parsed_schema(schema) @@ -557,6 +565,12 @@ async def __deserialize( if subject is not None: latest_schema = await self._get_reader_schema(subject) + payload = self._execute_rules_with_phase( + ctx, subject, RulePhase.ENCODING, RuleMode.READ, + None, writer_schema_raw, payload, None, None) + if isinstance(payload, bytes): + payload = io.BytesIO(payload) + if latest_schema is not None: migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None) reader_schema_raw = latest_schema.schema diff --git a/src/confluent_kafka/schema_registry/_async/json_schema.py b/src/confluent_kafka/schema_registry/_async/json_schema.py index 8aa7bc17d..99c1f7d59 100644 --- a/src/confluent_kafka/schema_registry/_async/json_schema.py +++ b/src/confluent_kafka/schema_registry/_async/json_schema.py @@ -14,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. - +import io import json from typing import Union, Optional, Tuple, Callable @@ -33,6 +33,8 @@ from confluent_kafka.schema_registry.common.json_schema import ( DEFAULT_SPEC, JsonSchema, _retrieve_via_httpx, transform, _ContextStringIO, JSON_TYPE ) +from confluent_kafka.schema_registry.common.schema_registry_client import \ + RulePhase from confluent_kafka.schema_registry.rule_registry import RuleRegistry from confluent_kafka.schema_registry.serde import AsyncBaseSerializer, AsyncBaseDeserializer, \ ParsedSchemaCache, SchemaId @@ -374,8 +376,14 @@ def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 if isinstance(encoded_value, str): encoded_value = encoded_value.encode("utf8") fo.write(encoded_value) + buffer = fo.getvalue() + + if latest_schema is not None: + buffer = self._execute_rules_with_phase( + ctx, subject, RulePhase.ENCODING, RuleMode.WRITE, + None, latest_schema.schema, buffer, None, None) - return self._schema_id_serializer(fo.getvalue(), ctx, self._schema_id) + return self._schema_id_serializer(buffer, ctx, self._schema_id) async def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema], Optional[Registry]]: if schema is None: @@ -552,9 +560,9 @@ async def __init_impl( __init__ = __init_impl def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: - return self.__serialize(data, ctx) + return self.__deserialize(data, ctx) - async def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + async def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: """ Deserialize a JSON encoded record with Confluent Schema Registry framing to a dict, or object instance according to from_dict if from_dict is specified. 
@@ -583,9 +591,6 @@ async def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = N schema_id = SchemaId(JSON_TYPE) payload = self._schema_id_deserializer(data, ctx, schema_id) - # JSON documents are self-describing; no need to query schema - obj_dict = self._json_decode(payload.read()) - if self._registry is not None: writer_schema_raw = await self._get_writer_schema(schema_id, subject) writer_schema, writer_ref_registry = await self._get_parsed_schema(writer_schema_raw) @@ -597,6 +602,15 @@ async def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = N writer_schema_raw = None writer_schema, writer_ref_registry = None, None + payload = self._execute_rules_with_phase( + ctx, subject, RulePhase.ENCODING, RuleMode.READ, + None, writer_schema_raw, payload, None, None) + if isinstance(payload, bytes): + payload = io.BytesIO(payload) + + # JSON documents are self-describing; no need to query schema + obj_dict = self._json_decode(payload.read()) + if latest_schema is not None: migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None) reader_schema_raw = latest_schema.schema diff --git a/src/confluent_kafka/schema_registry/_async/protobuf.py b/src/confluent_kafka/schema_registry/_async/protobuf.py index 20ecaa57d..f7d1dc8b3 100644 --- a/src/confluent_kafka/schema_registry/_async/protobuf.py +++ b/src/confluent_kafka/schema_registry/_async/protobuf.py @@ -27,6 +27,8 @@ from confluent_kafka.schema_registry import (reference_subject_name_strategy, topic_subject_name_strategy, prefix_schema_id_serializer, dual_schema_id_deserializer) +from confluent_kafka.schema_registry.common.schema_registry_client import \ + RulePhase from confluent_kafka.schema_registry.schema_registry_client import AsyncSchemaRegistryClient from confluent_kafka.schema_registry.common.protobuf import _bytes, _create_index_array, \ _init_pool, _is_builtin, _schema_to_str, _str_to_proto, transform, _ContextStringIO, PROTOBUF_TYPE @@ -426,7 +428,14 @@ def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 with _ContextStringIO() as fo: fo.write(message.SerializeToString()) self._schema_id.message_indexes = self._index_array - return self._schema_id_serializer(fo.getvalue(), ctx, self._schema_id) + buffer = fo.getvalue() + + if latest_schema is not None: + buffer = self._execute_rules_with_phase( + ctx, subject, RulePhase.ENCODING, RuleMode.WRITE, + None, latest_schema.schema, buffer, None, None) + + return self._schema_id_serializer(buffer, ctx, self._schema_id) async def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescriptorProto, DescriptorPool]: result = self._parsed_schemas.get_parsed_schema(schema) @@ -549,9 +558,9 @@ async def __init_impl( __init__ = __init_impl def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: - return self.__serialize(data, ctx) + return self.__deserialize(data, ctx) - async def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + async def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: """ Deserialize a serialized protobuf message with Confluent Schema Registry framing. 
@@ -596,6 +605,12 @@ async def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = N writer_schema_raw = None writer_schema = None + payload = self._execute_rules_with_phase( + ctx, subject, RulePhase.ENCODING, RuleMode.READ, + None, writer_schema_raw, payload, None, None) + if isinstance(payload, bytes): + payload = io.BytesIO(payload) + if latest_schema is not None: migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None) reader_schema_raw = latest_schema.schema diff --git a/src/confluent_kafka/schema_registry/_async/serde.py b/src/confluent_kafka/schema_registry/_async/serde.py index 792761430..5b208c39c 100644 --- a/src/confluent_kafka/schema_registry/_async/serde.py +++ b/src/confluent_kafka/schema_registry/_async/serde.py @@ -20,6 +20,8 @@ from typing import List, Optional, Set, Dict, Any from confluent_kafka.schema_registry import RegisteredSchema +from confluent_kafka.schema_registry.common.schema_registry_client import \ + RulePhase from confluent_kafka.schema_registry.common.serde import ErrorAction, \ FieldTransformer, Migration, NoneAction, RuleAction, \ RuleConditionError, RuleContext, RuleError, SchemaId @@ -59,6 +61,17 @@ def _execute_rules( source: Optional[Schema], target: Optional[Schema], message: Any, inline_tags: Optional[Dict[str, Set[str]]], field_transformer: Optional[FieldTransformer] + ) -> Any: + return self._execute_rules_with_phase( + ser_ctx, subject, RulePhase.DOMAIN, rule_mode, + source, target, message, inline_tags, field_transformer) + + def _execute_rules_with_phase( + self, ser_ctx: SerializationContext, subject: str, + rule_phase: RulePhase, rule_mode: RuleMode, + source: Optional[Schema], target: Optional[Schema], + message: Any, inline_tags: Optional[Dict[str, Set[str]]], + field_transformer: Optional[FieldTransformer] ) -> Any: if message is None or target is None: return message @@ -73,7 +86,10 @@ def _execute_rules( rules.reverse() else: if target is not None and target.rule_set is not None: - rules = target.rule_set.domain_rules + if rule_phase == RulePhase.ENCODING: + rules = target.rule_set.encoding_rules + else: + rules = target.rule_set.domain_rules if rule_mode == RuleMode.READ: # Execute read rules in reverse order for symmetry rules = rules[:] if rules else [] @@ -197,19 +213,25 @@ async def _get_writer_schema( else: raise SerializationError("Schema ID or GUID is not set") - def _has_rules(self, rule_set: RuleSet, mode: RuleMode) -> bool: + def _has_rules(self, rule_set: RuleSet, phase: RulePhase, mode: RuleMode) -> bool: if rule_set is None: return False + if phase == RulePhase.MIGRATION: + rules = rule_set.migration_rules + elif phase == RulePhase.DOMAIN: + rules = rule_set.domain_rules + elif phase == RulePhase.ENCODING: + rules = rule_set.encoding_rules if mode in (RuleMode.UPGRADE, RuleMode.DOWNGRADE): return any(rule.mode == mode or rule.mode == RuleMode.UPDOWN - for rule in rule_set.migration_rules or []) + for rule in rules or []) elif mode == RuleMode.UPDOWN: - return any(rule.mode == mode for rule in rule_set.migration_rules or []) + return any(rule.mode == mode for rule in rules or []) elif mode in (RuleMode.WRITE, RuleMode.READ): return any(rule.mode == mode or rule.mode == RuleMode.WRITEREAD - for rule in rule_set.domain_rules or []) + for rule in rules or []) elif mode == RuleMode.WRITEREAD: - return any(rule.mode == mode for rule in rule_set.migration_rules or []) + return any(rule.mode == mode for rule in rules or []) return False async def _get_migrations( @@ -235,7 +257,8 @@ 
async def _get_migrations( if i == 0: previous = version continue - if version.schema.rule_set is not None and self._has_rules(version.schema.rule_set, migration_mode): + if version.schema.rule_set is not None and self._has_rules( + version.schema.rule_set, RulePhase.MIGRATION, migration_mode): if migration_mode == RuleMode.UPGRADE: migration = Migration(migration_mode, previous, version) else: @@ -265,7 +288,8 @@ def _execute_migrations( migrations: List[Migration], message: Any ) -> Any: for migration in migrations: - message = self._execute_rules(ser_ctx, subject, migration.rule_mode, - migration.source.schema, migration.target.schema, - message, None, None) + message = self._execute_rules_with_phase( + ser_ctx, subject, RulePhase.MIGRATION, migration.rule_mode, + migration.source.schema, migration.target.schema, + message, None, None) return message diff --git a/src/confluent_kafka/schema_registry/_sync/avro.py b/src/confluent_kafka/schema_registry/_sync/avro.py index e1e2ac915..78e7dd8ea 100644 --- a/src/confluent_kafka/schema_registry/_sync/avro.py +++ b/src/confluent_kafka/schema_registry/_sync/avro.py @@ -14,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import io from json import loads from typing import Dict, Union, Optional, Callable @@ -29,6 +29,8 @@ SchemaRegistryClient, prefix_schema_id_serializer, dual_schema_id_deserializer) +from confluent_kafka.schema_registry.common.schema_registry_client import \ + RulePhase from confluent_kafka.serialization import (SerializationError, SerializationContext) from confluent_kafka.schema_registry.rule_registry import RuleRegistry @@ -348,8 +350,14 @@ def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 with _ContextStringIO() as fo: # write the record to the rest of the buffer schemaless_writer(fo, parsed_schema, value) + buffer = fo.getvalue() + + if latest_schema is not None: + buffer = self._execute_rules_with_phase( + ctx, subject, RulePhase.ENCODING, RuleMode.WRITE, + None, latest_schema.schema, buffer, None, None) - return self._schema_id_serializer(fo.getvalue(), ctx, self._schema_id) + return self._schema_id_serializer(buffer, ctx, self._schema_id) def _get_parsed_schema(self, schema: Schema) -> AvroSchema: parsed_schema = self._parsed_schemas.get_parsed_schema(schema) @@ -557,6 +565,12 @@ def __deserialize( if subject is not None: latest_schema = self._get_reader_schema(subject) + payload = self._execute_rules_with_phase( + ctx, subject, RulePhase.ENCODING, RuleMode.READ, + None, writer_schema_raw, payload, None, None) + if isinstance(payload, bytes): + payload = io.BytesIO(payload) + if latest_schema is not None: migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) reader_schema_raw = latest_schema.schema diff --git a/src/confluent_kafka/schema_registry/_sync/json_schema.py b/src/confluent_kafka/schema_registry/_sync/json_schema.py index 78ea5efd9..88da6322c 100644 --- a/src/confluent_kafka/schema_registry/_sync/json_schema.py +++ b/src/confluent_kafka/schema_registry/_sync/json_schema.py @@ -14,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +import io import json from typing import Union, Optional, Tuple, Callable @@ -33,6 +33,8 @@ from confluent_kafka.schema_registry.common.json_schema import ( DEFAULT_SPEC, JsonSchema, _retrieve_via_httpx, transform, _ContextStringIO, JSON_TYPE ) +from confluent_kafka.schema_registry.common.schema_registry_client import \ + RulePhase from confluent_kafka.schema_registry.rule_registry import RuleRegistry from confluent_kafka.schema_registry.serde import BaseSerializer, BaseDeserializer, \ ParsedSchemaCache, SchemaId @@ -374,8 +376,14 @@ def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 if isinstance(encoded_value, str): encoded_value = encoded_value.encode("utf8") fo.write(encoded_value) + buffer = fo.getvalue() + + if latest_schema is not None: + buffer = self._execute_rules_with_phase( + ctx, subject, RulePhase.ENCODING, RuleMode.WRITE, + None, latest_schema.schema, buffer, None, None) - return self._schema_id_serializer(fo.getvalue(), ctx, self._schema_id) + return self._schema_id_serializer(buffer, ctx, self._schema_id) def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema], Optional[Registry]]: if schema is None: @@ -552,9 +560,9 @@ def __init_impl( __init__ = __init_impl def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: - return self.__serialize(data, ctx) + return self.__deserialize(data, ctx) - def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: """ Deserialize a JSON encoded record with Confluent Schema Registry framing to a dict, or object instance according to from_dict if from_dict is specified. @@ -583,9 +591,6 @@ def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = None) - schema_id = SchemaId(JSON_TYPE) payload = self._schema_id_deserializer(data, ctx, schema_id) - # JSON documents are self-describing; no need to query schema - obj_dict = self._json_decode(payload.read()) - if self._registry is not None: writer_schema_raw = self._get_writer_schema(schema_id, subject) writer_schema, writer_ref_registry = self._get_parsed_schema(writer_schema_raw) @@ -597,6 +602,15 @@ def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = None) - writer_schema_raw = None writer_schema, writer_ref_registry = None, None + payload = self._execute_rules_with_phase( + ctx, subject, RulePhase.ENCODING, RuleMode.READ, + None, writer_schema_raw, payload, None, None) + if isinstance(payload, bytes): + payload = io.BytesIO(payload) + + # JSON documents are self-describing; no need to query schema + obj_dict = self._json_decode(payload.read()) + if latest_schema is not None: migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) reader_schema_raw = latest_schema.schema diff --git a/src/confluent_kafka/schema_registry/_sync/protobuf.py b/src/confluent_kafka/schema_registry/_sync/protobuf.py index 08c64bf44..f202f0e99 100644 --- a/src/confluent_kafka/schema_registry/_sync/protobuf.py +++ b/src/confluent_kafka/schema_registry/_sync/protobuf.py @@ -27,6 +27,8 @@ from confluent_kafka.schema_registry import (reference_subject_name_strategy, topic_subject_name_strategy, prefix_schema_id_serializer, dual_schema_id_deserializer) +from confluent_kafka.schema_registry.common.schema_registry_client import \ + RulePhase from confluent_kafka.schema_registry.schema_registry_client import SchemaRegistryClient from 
confluent_kafka.schema_registry.common.protobuf import _bytes, _create_index_array, \ _init_pool, _is_builtin, _schema_to_str, _str_to_proto, transform, _ContextStringIO, PROTOBUF_TYPE @@ -426,7 +428,14 @@ def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 with _ContextStringIO() as fo: fo.write(message.SerializeToString()) self._schema_id.message_indexes = self._index_array - return self._schema_id_serializer(fo.getvalue(), ctx, self._schema_id) + buffer = fo.getvalue() + + if latest_schema is not None: + buffer = self._execute_rules_with_phase( + ctx, subject, RulePhase.ENCODING, RuleMode.WRITE, + None, latest_schema.schema, buffer, None, None) + + return self._schema_id_serializer(buffer, ctx, self._schema_id) def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescriptorProto, DescriptorPool]: result = self._parsed_schemas.get_parsed_schema(schema) @@ -549,9 +558,9 @@ def __init_impl( __init__ = __init_impl def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: - return self.__serialize(data, ctx) + return self.__deserialize(data, ctx) - def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: """ Deserialize a serialized protobuf message with Confluent Schema Registry framing. @@ -596,6 +605,12 @@ def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = None) - writer_schema_raw = None writer_schema = None + payload = self._execute_rules_with_phase( + ctx, subject, RulePhase.ENCODING, RuleMode.READ, + None, writer_schema_raw, payload, None, None) + if isinstance(payload, bytes): + payload = io.BytesIO(payload) + if latest_schema is not None: migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) reader_schema_raw = latest_schema.schema diff --git a/src/confluent_kafka/schema_registry/_sync/serde.py b/src/confluent_kafka/schema_registry/_sync/serde.py index f0512f872..df192ed5e 100644 --- a/src/confluent_kafka/schema_registry/_sync/serde.py +++ b/src/confluent_kafka/schema_registry/_sync/serde.py @@ -20,6 +20,8 @@ from typing import List, Optional, Set, Dict, Any from confluent_kafka.schema_registry import RegisteredSchema +from confluent_kafka.schema_registry.common.schema_registry_client import \ + RulePhase from confluent_kafka.schema_registry.common.serde import ErrorAction, \ FieldTransformer, Migration, NoneAction, RuleAction, \ RuleConditionError, RuleContext, RuleError, SchemaId @@ -59,6 +61,17 @@ def _execute_rules( source: Optional[Schema], target: Optional[Schema], message: Any, inline_tags: Optional[Dict[str, Set[str]]], field_transformer: Optional[FieldTransformer] + ) -> Any: + return self._execute_rules_with_phase( + ser_ctx, subject, RulePhase.DOMAIN, rule_mode, + source, target, message, inline_tags, field_transformer) + + def _execute_rules_with_phase( + self, ser_ctx: SerializationContext, subject: str, + rule_phase: RulePhase, rule_mode: RuleMode, + source: Optional[Schema], target: Optional[Schema], + message: Any, inline_tags: Optional[Dict[str, Set[str]]], + field_transformer: Optional[FieldTransformer] ) -> Any: if message is None or target is None: return message @@ -73,7 +86,10 @@ def _execute_rules( rules.reverse() else: if target is not None and target.rule_set is not None: - rules = target.rule_set.domain_rules + if rule_phase == RulePhase.ENCODING: + rules = 
target.rule_set.encoding_rules + else: + rules = target.rule_set.domain_rules if rule_mode == RuleMode.READ: # Execute read rules in reverse order for symmetry rules = rules[:] if rules else [] @@ -197,19 +213,25 @@ def _get_writer_schema( else: raise SerializationError("Schema ID or GUID is not set") - def _has_rules(self, rule_set: RuleSet, mode: RuleMode) -> bool: + def _has_rules(self, rule_set: RuleSet, phase: RulePhase, mode: RuleMode) -> bool: if rule_set is None: return False + if phase == RulePhase.MIGRATION: + rules = rule_set.migration_rules + elif phase == RulePhase.DOMAIN: + rules = rule_set.domain_rules + elif phase == RulePhase.ENCODING: + rules = rule_set.encoding_rules if mode in (RuleMode.UPGRADE, RuleMode.DOWNGRADE): return any(rule.mode == mode or rule.mode == RuleMode.UPDOWN - for rule in rule_set.migration_rules or []) + for rule in rules or []) elif mode == RuleMode.UPDOWN: - return any(rule.mode == mode for rule in rule_set.migration_rules or []) + return any(rule.mode == mode for rule in rules or []) elif mode in (RuleMode.WRITE, RuleMode.READ): return any(rule.mode == mode or rule.mode == RuleMode.WRITEREAD - for rule in rule_set.domain_rules or []) + for rule in rules or []) elif mode == RuleMode.WRITEREAD: - return any(rule.mode == mode for rule in rule_set.migration_rules or []) + return any(rule.mode == mode for rule in rules or []) return False def _get_migrations( @@ -235,7 +257,8 @@ def _get_migrations( if i == 0: previous = version continue - if version.schema.rule_set is not None and self._has_rules(version.schema.rule_set, migration_mode): + if version.schema.rule_set is not None and self._has_rules( + version.schema.rule_set, RulePhase.MIGRATION, migration_mode): if migration_mode == RuleMode.UPGRADE: migration = Migration(migration_mode, previous, version) else: @@ -265,7 +288,8 @@ def _execute_migrations( migrations: List[Migration], message: Any ) -> Any: for migration in migrations: - message = self._execute_rules(ser_ctx, subject, migration.rule_mode, - migration.source.schema, migration.target.schema, - message, None, None) + message = self._execute_rules_with_phase( + ser_ctx, subject, RulePhase.MIGRATION, migration.rule_mode, + migration.source.schema, migration.target.schema, + message, None, None) return message diff --git a/src/confluent_kafka/schema_registry/common/schema_registry_client.py b/src/confluent_kafka/schema_registry/common/schema_registry_client.py index edbd38d02..27f9d946a 100644 --- a/src/confluent_kafka/schema_registry/common/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/common/schema_registry_client.py @@ -312,6 +312,15 @@ def __str__(self) -> str: return str(self.value) +class RulePhase(str, Enum): + MIGRATION = "MIGRATION" + DOMAIN = "DOMAIN" + ENCODING = "ENCODING" + + def __str__(self) -> str: + return str(self.value) + + class RuleMode(str, Enum): UPGRADE = "UPGRADE" DOWNGRADE = "DOWNGRADE" @@ -471,6 +480,7 @@ def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: class RuleSet: migration_rules: Optional[List["Rule"]] = _attrs_field(hash=False) domain_rules: Optional[List["Rule"]] = _attrs_field(hash=False) + encoding_rules: Optional[List["Rule"]] = _attrs_field(hash=False, default=None) def to_dict(self) -> Dict[str, Any]: _migration_rules: Optional[List[Dict[str, Any]]] = None @@ -487,12 +497,21 @@ def to_dict(self) -> Dict[str, Any]: domain_rules_item = domain_rules_item_data.to_dict() _domain_rules.append(domain_rules_item) + _encoding_rules: Optional[List[Dict[str, Any]]] = None + if 
self.encoding_rules is not None: + _encoding_rules = [] + for encoding_rules_item_data in self.encoding_rules: + encoding_rules_item = encoding_rules_item_data.to_dict() + _encoding_rules.append(encoding_rules_item) + field_dict: Dict[str, Any] = {} field_dict.update({}) if _migration_rules is not None: field_dict["migrationRules"] = _migration_rules if _domain_rules is not None: field_dict["domainRules"] = _domain_rules + if _encoding_rules is not None: + field_dict["encodingRules"] = _encoding_rules return field_dict @@ -511,15 +530,24 @@ def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: domain_rules_item = Rule.from_dict(domain_rules_item_data) domain_rules.append(domain_rules_item) + encoding_rules = [] + _encoding_rules = d.pop("encodingRules", None) + for encoding_rules_item_data in _encoding_rules or []: + encoding_rules_item = Rule.from_dict(encoding_rules_item_data) + encoding_rules.append(encoding_rules_item) + rule_set = cls( migration_rules=migration_rules, domain_rules=domain_rules, + encoding_rules=encoding_rules, ) return rule_set def __hash__(self): - return hash(frozenset((self.migration_rules or []) + (self.domain_rules or []))) + return hash(frozenset((self.migration_rules or []) + + (self.domain_rules or []) + + (self.encoding_rules or []))) @_attrs_define diff --git a/src/confluent_kafka/schema_registry/rules/encryption/encrypt_executor.py b/src/confluent_kafka/schema_registry/rules/encryption/encrypt_executor.py index cb480f145..8a05e3077 100644 --- a/src/confluent_kafka/schema_registry/rules/encryption/encrypt_executor.py +++ b/src/confluent_kafka/schema_registry/rules/encryption/encrypt_executor.py @@ -13,6 +13,7 @@ # limitations under the License. import base64 +import io import logging import time from typing import Optional, Tuple, Any @@ -30,8 +31,8 @@ from confluent_kafka.schema_registry.rules.encryption.kms_driver_registry import \ get_kms_driver, KmsDriver from confluent_kafka.schema_registry.serde import RuleContext, \ - FieldRuleExecutor, FieldTransform, RuleError, FieldContext, FieldType - + RuleError, RuleExecutor, FieldType, FieldRuleExecutor, FieldTransform, \ + FieldContext log = logging.getLogger(__name__) @@ -53,7 +54,7 @@ def now(self) -> int: return int(round(time.time() * 1000)) -class FieldEncryptionExecutor(FieldRuleExecutor): +class EncryptionExecutor(RuleExecutor): def __init__(self, clock: Clock = Clock()): self.client = None @@ -81,15 +82,19 @@ def configure(self, client_conf: dict, rule_conf: dict): self.config = rule_conf if rule_conf else {} def type(self) -> str: - return "ENCRYPT" + return "ENCRYPT_PAYLOAD" - def new_transform(self, ctx: RuleContext) -> FieldTransform: + def transform(self, ctx: RuleContext, message: Any) -> Any: + executor = self.new_transform(ctx) + return executor.transform(ctx, FieldType.BYTES, message) + + def new_transform(self, ctx: RuleContext) -> 'EncryptionExecutorTransform': cryptor = self._get_cryptor(ctx) kek_name = self._get_kek_name(ctx) dek_expiry_days = self._get_dek_expiry_days(ctx) - transform = FieldEncryptionExecutorTransform( + transform = EncryptionExecutorTransform( self, cryptor, kek_name, dek_expiry_days) - return transform.transform + return transform def close(self): if self.client is not None: @@ -125,11 +130,11 @@ def _get_dek_expiry_days(self, ctx: RuleContext) -> int: @classmethod def register(cls): - RuleRegistry.register_rule_executor(FieldEncryptionExecutor()) + RuleRegistry.register_rule_executor(EncryptionExecutor()) @classmethod - def register_with_clock(cls, clock: Clock) 
-> 'FieldEncryptionExecutor': - executor = FieldEncryptionExecutor(clock) + def register_with_clock(cls, clock: Clock) -> 'EncryptionExecutor': + executor = EncryptionExecutor(clock) RuleRegistry.register_rule_executor(executor) return executor @@ -191,9 +196,9 @@ def decrypt(self, dek: bytes, ciphertext: bytes, associated_data: bytes) -> byte return primitive.decrypt(ciphertext, associated_data) -class FieldEncryptionExecutorTransform(object): +class EncryptionExecutorTransform(object): - def __init__(self, executor: FieldEncryptionExecutor, cryptor: Cryptor, kek_name: str, dek_expiry_days: int): + def __init__(self, executor: EncryptionExecutor, cryptor: Cryptor, kek_name: str, dek_expiry_days: int): self._executor = executor self._cryptor = cryptor self._kek_name = kek_name @@ -341,13 +346,13 @@ def _is_expired(self, ctx: RuleContext, dek: Optional[Dek]) -> bool: and dek is not None and (now - dek.ts) / MILLIS_IN_DAY > self._dek_expiry_days) - def transform(self, ctx: RuleContext, field_ctx: FieldContext, field_value: Any) -> Any: + def transform(self, ctx: RuleContext, field_type: FieldType, field_value: Any) -> Any: if field_value is None: return None if ctx.rule_mode == RuleMode.WRITE: - plaintext = self._to_bytes(field_ctx.field_type, field_value) + plaintext = self._to_bytes(field_type, field_value) if plaintext is None: - raise RuleError(f"type {field_ctx.field_type} not supported for encryption") + raise RuleError(f"type {field_type} not supported for encryption") version = None if self._is_dek_rotated(): version = -1 @@ -356,15 +361,15 @@ def transform(self, ctx: RuleContext, field_ctx: FieldContext, field_value: Any) ciphertext = self._cryptor.encrypt(key_material_bytes, plaintext, Cryptor.EMPTY_AAD) if self._is_dek_rotated(): ciphertext = self._prefix_version(dek.version, ciphertext) - if field_ctx.field_type == FieldType.STRING: + if field_type == FieldType.STRING: return base64.b64encode(ciphertext).decode("utf-8") else: - return self._to_object(field_ctx.field_type, ciphertext) + return self._to_object(field_type, ciphertext) elif ctx.rule_mode == RuleMode.READ: - if field_ctx.field_type == FieldType.STRING: + if field_type == FieldType.STRING: ciphertext = base64.b64decode(field_value) else: - ciphertext = self._to_bytes(field_ctx.field_type, field_value) + ciphertext = self._to_bytes(field_type, field_value) if ciphertext is None: return field_value @@ -376,7 +381,7 @@ def transform(self, ctx: RuleContext, field_ctx: FieldContext, field_value: Any) dek = self._get_or_create_dek(ctx, version) key_material_bytes = dek.get_key_material_bytes() plaintext = self._cryptor.decrypt(key_material_bytes, ciphertext, Cryptor.EMPTY_AAD) - return self._to_object(field_ctx.field_type, plaintext) + return self._to_object(field_type, plaintext) else: raise RuleError(f"unsupported rule mode {ctx.rule_mode}") @@ -393,6 +398,8 @@ def _to_bytes(self, field_type: FieldType, value: Any) -> Optional[bytes]: if field_type == FieldType.STRING: return value.encode("utf-8") elif field_type == FieldType.BYTES: + if isinstance(value, io.BytesIO): + return value.read() return value return None @@ -420,3 +427,43 @@ def _register_kms_client(self, kms_driver: KmsDriver, config: dict, kek_url: str kms_client = kms_driver.new_kms_client(config, kek_url) register_kms_client(kms_client) return kms_client + + +class FieldEncryptionExecutor(FieldRuleExecutor): + + def __init__(self, clock: Clock = Clock()): + self.executor = EncryptionExecutor(clock) + + def configure(self, client_conf: dict, rule_conf: dict): 
+ self.executor.configure(client_conf, rule_conf) + + def type(self) -> str: + return "ENCRYPT" + + def new_transform(self, ctx: RuleContext) -> FieldTransform: + executor_transform = self.executor.new_transform(ctx) + transform = FieldEncryptionExecutorTransform(executor_transform) + return transform.transform + + def close(self): + self.executor.close() + + @classmethod + def register(cls): + RuleRegistry.register_rule_executor(FieldEncryptionExecutor()) + + @classmethod + def register_with_clock(cls, clock: Clock) -> 'FieldEncryptionExecutor': + executor = FieldEncryptionExecutor(clock) + RuleRegistry.register_rule_executor(executor) + return executor + + + class FieldEncryptionExecutorTransform(object): + + def __init__(self, executor_transform: 'EncryptionExecutorTransform'): + self.executor_transform = executor_transform + + def transform(self, ctx: RuleContext, field_ctx: FieldContext, field_value: Any) -> Any: + return self.executor_transform.transform(ctx, field_ctx.field_type, field_value) diff --git a/tests/schema_registry/_async/test_avro_serdes.py b/tests/schema_registry/_async/test_avro_serdes.py index bc06fcd64..bb95f4098 100644 --- a/tests/schema_registry/_async/test_avro_serdes.py +++ b/tests/schema_registry/_async/test_avro_serdes.py @@ -38,7 +38,7 @@ from confluent_kafka.schema_registry.rules.encryption.dek_registry.dek_registry_client import \ DekRegistryClient, DekAlgorithm from confluent_kafka.schema_registry.rules.encryption.encrypt_executor import \ - FieldEncryptionExecutor, Clock + FieldEncryptionExecutor, Clock, EncryptionExecutor from confluent_kafka.schema_registry.rules.encryption.gcpkms.gcp_driver import \ GcpKmsDriver from confluent_kafka.schema_registry.rules.encryption.hcvault.hcvault_driver import \ @@ -1125,7 +1125,7 @@ async def test_avro_encryption(): 'bytesField': b'foobar', } ser = await AsyncAvroSerializer(client, schema_str=None, conf=ser_conf, rule_conf=rule_conf) - dek_client = executor.client + dek_client = executor.executor.client ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE) obj_bytes = await ser(obj, ser_ctx) @@ -1134,6 +1134,68 @@ obj['stringField'] = 'hi' obj['bytesField'] = b'foobar' + deser = await AsyncAvroDeserializer(client, rule_conf=rule_conf) + executor.executor.client = dek_client + obj2 = await deser(obj_bytes, ser_ctx) + assert obj == obj2 + + +async def test_avro_payload_encryption(): + executor = EncryptionExecutor.register_with_clock(FakeClock()) + + conf = {'url': _BASE_URL} + client = AsyncSchemaRegistryClient.new_client(conf) + ser_conf = {'auto.register.schemas': False, 'use.latest.version': True} + rule_conf = {'secret': 'mysecret'} + schema = { + 'type': 'record', + 'name': 'test', + 'fields': [ + {'name': 'intField', 'type': 'int'}, + {'name': 'doubleField', 'type': 'double'}, + {'name': 'stringField', 'type': 'string', 'confluent:tags': ['PII']}, + {'name': 'booleanField', 'type': 'boolean'}, + {'name': 'bytesField', 'type': 'bytes', 'confluent:tags': ['PII']}, + ] + } + + rule = Rule( + "test-encrypt", + "", + RuleKind.TRANSFORM, + RuleMode.WRITEREAD, + "ENCRYPT_PAYLOAD", + None, + RuleParams({ + "encrypt.kek.name": "kek1", + "encrypt.kms.type": "local-kms", + "encrypt.kms.key.id": "mykey" + }), + None, + None, + "ERROR,NONE", + False + ) + await client.register_schema(_SUBJECT, Schema( + json.dumps(schema), + "AVRO", + [], + None, + RuleSet(None, None, [rule]) + )) + + obj = { + 'intField': 123, + 'doubleField': 45.67, + 'stringField': 'hi',
+ 'booleanField': True, + 'bytesField': b'foobar', + } + ser = await AsyncAvroSerializer(client, schema_str=None, conf=ser_conf, rule_conf=rule_conf) + dek_client = executor.client + ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE) + obj_bytes = await ser(obj, ser_ctx) + deser = await AsyncAvroDeserializer(client, rule_conf=rule_conf) executor.client = dek_client obj2 = await deser(obj_bytes, ser_ctx) @@ -1193,7 +1255,7 @@ async def test_avro_encryption_deterministic(): 'bytesField': b'foobar', } ser = await AsyncAvroSerializer(client, schema_str=None, conf=ser_conf, rule_conf=rule_conf) - dek_client = executor.client + dek_client = executor.executor.client ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE) obj_bytes = await ser(obj, ser_ctx) @@ -1203,7 +1265,7 @@ async def test_avro_encryption_deterministic(): obj['bytesField'] = b'foobar' deser = await AsyncAvroDeserializer(client, rule_conf=rule_conf) - executor.client = dek_client + executor.executor.client = dek_client obj2 = await deser(obj_bytes, ser_ctx) assert obj == obj2 @@ -1273,7 +1335,7 @@ async def test_avro_encryption_cel(): 'bytesField': b'foobar', } ser = await AsyncAvroSerializer(client, schema_str=None, conf=ser_conf, rule_conf=rule_conf) - dek_client = executor.client + dek_client = executor.executor.client ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE) obj_bytes = await ser(obj, ser_ctx) @@ -1283,7 +1345,7 @@ async def test_avro_encryption_cel(): obj['bytesField'] = b'foobar' deser = await AsyncAvroDeserializer(client, rule_conf=rule_conf) - executor.client = dek_client + executor.executor.client = dek_client obj2 = await deser(obj_bytes, ser_ctx) assert obj == obj2 @@ -1341,7 +1403,7 @@ async def test_avro_encryption_dek_rotation(): 'bytesField': b'foobar', } ser = await AsyncAvroSerializer(client, schema_str=None, conf=ser_conf, rule_conf=rule_conf) - dek_client: DekRegistryClient = executor.client + dek_client: DekRegistryClient = executor.executor.client ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE) obj_bytes = await ser(obj, ser_ctx) @@ -1350,17 +1412,17 @@ async def test_avro_encryption_dek_rotation(): obj['stringField'] = 'hi' deser = await AsyncAvroDeserializer(client, rule_conf=rule_conf) - executor.client = dek_client + executor.executor.client = dek_client obj2 = await deser(obj_bytes, ser_ctx) assert obj == obj2 - dek_client = executor.client + dek_client = executor.executor.client dek = dek_client.get_dek("kek1-rot", _SUBJECT, version=-1) assert dek.version == 1 # advance 2 days now = datetime.now() + timedelta(days=2) - executor.clock.fixed_now = int(round(now.timestamp() * 1000)) + executor.executor.clock.fixed_now = int(round(now.timestamp() * 1000)) obj_bytes = await ser(obj, ser_ctx) @@ -1376,7 +1438,7 @@ async def test_avro_encryption_dek_rotation(): # advance 2 days now = datetime.now() + timedelta(days=2) - executor.clock.fixed_now = int(round(now.timestamp() * 1000)) + executor.executor.clock.fixed_now = int(round(now.timestamp() * 1000)) obj_bytes = await ser(obj, ser_ctx) @@ -1437,7 +1499,7 @@ async def test_avro_encryption_f1_preserialized(): ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE) deser = await AsyncAvroDeserializer(client, rule_conf=rule_conf) - dek_client: DekRegistryClient = executor.client + dek_client: DekRegistryClient = executor.executor.client dek_client.register_kek("kek1-f1", "local-kms", "mykey") encrypted_dek = "07V2ndh02DA73p+dTybwZFm7DKQSZN1tEwQh+FoX1DZLk4Yj2LLu4omYjp/84tAg3BYlkfGSz+zZacJHIE4=" @@ -1500,7 +1562,7 @@ 
async def test_avro_encryption_deterministic_f1_preserialized(): ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE) deser = await AsyncAvroDeserializer(client, rule_conf=rule_conf) - dek_client: DekRegistryClient = executor.client + dek_client: DekRegistryClient = executor.executor.client dek_client.register_kek("kek1-det-f1", "local-kms", "mykey") encrypted_dek = ("YSx3DTlAHrmpoDChquJMifmPntBzxgRVdMzgYL82rgWBKn7aUSnG+WIu9oz" @@ -1562,7 +1624,7 @@ async def test_avro_encryption_dek_rotation_f1_preserialized(): ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE) deser = await AsyncAvroDeserializer(client, rule_conf=rule_conf) - dek_client: DekRegistryClient = executor.client + dek_client: DekRegistryClient = executor.executor.client dek_client.register_kek("kek1-rot-f1", "local-kms", "mykey") encrypted_dek = "W/v6hOQYq1idVAcs1pPWz9UUONMVZW4IrglTnG88TsWjeCjxmtRQ4VaNe/I5dCfm2zyY9Cu0nqdvqImtUk4=" @@ -1642,7 +1704,7 @@ async def test_avro_encryption_references(): )) ser = await AsyncAvroSerializer(client, schema_str=None, conf=ser_conf, rule_conf=rule_conf) - dek_client = executor.client + dek_client = executor.executor.client ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE) obj_bytes = await ser(obj, ser_ctx) @@ -1652,7 +1714,7 @@ async def test_avro_encryption_references(): obj['refField']['bytesField'] = b'foobar' deser = await AsyncAvroDeserializer(client, rule_conf=rule_conf) - executor.client = dek_client + executor.executor.client = dek_client obj2 = await deser(obj_bytes, ser_ctx) assert obj == obj2 @@ -1709,7 +1771,7 @@ async def test_avro_encryption_with_union(): 'bytesField': b'foobar', } ser = await AsyncAvroSerializer(client, schema_str=None, conf=ser_conf, rule_conf=rule_conf) - dek_client = executor.client + dek_client = executor.executor.client ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE) obj_bytes = await ser(obj, ser_ctx) @@ -1719,7 +1781,7 @@ async def test_avro_encryption_with_union(): obj['bytesField'] = b'foobar' deser = await AsyncAvroDeserializer(client, rule_conf=rule_conf) - executor.client = dek_client + executor.executor.client = dek_client obj2 = await deser(obj_bytes, ser_ctx) assert obj == obj2 diff --git a/tests/schema_registry/_async/test_json_serdes.py b/tests/schema_registry/_async/test_json_serdes.py index 69e984bb8..31a4aed55 100644 --- a/tests/schema_registry/_async/test_json_serdes.py +++ b/tests/schema_registry/_async/test_json_serdes.py @@ -32,7 +32,7 @@ from confluent_kafka.schema_registry.rules.encryption.azurekms.azure_driver import \ AzureKmsDriver from confluent_kafka.schema_registry.rules.encryption.encrypt_executor import \ - FieldEncryptionExecutor + FieldEncryptionExecutor, EncryptionExecutor from confluent_kafka.schema_registry.rules.encryption.gcpkms.gcp_driver import \ GcpKmsDriver from confluent_kafka.schema_registry.rules.encryption.hcvault.hcvault_driver import \ @@ -980,7 +980,7 @@ async def test_json_encryption(): 'bytesField': base64.b64encode(b'foobar').decode('utf-8'), } ser = await AsyncJSONSerializer(json.dumps(schema), client, conf=ser_conf, rule_conf=rule_conf) - dek_client = executor.client + dek_client = executor.executor.client ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE) obj_bytes = await ser(obj, ser_ctx) @@ -989,6 +989,74 @@ async def test_json_encryption(): obj['stringField'] = 'hi' obj['bytesField'] = base64.b64encode(b'foobar').decode('utf-8') + deser = await AsyncJSONDeserializer(None, schema_registry_client=client, rule_conf=rule_conf) + executor.executor.client = 
dek_client + obj2 = await deser(obj_bytes, ser_ctx) + assert obj == obj2 + + +async def test_json_payload_encryption(): + executor = EncryptionExecutor.register_with_clock(FakeClock()) + + conf = {'url': _BASE_URL} + client = AsyncSchemaRegistryClient.new_client(conf) + ser_conf = {'auto.register.schemas': False, 'use.latest.version': True} + rule_conf = {'secret': 'mysecret'} + schema = { + "type": "object", + "properties": { + "intField": {"type": "integer"}, + "doubleField": {"type": "number"}, + "stringField": { + "type": "string", + "confluent:tags": ["PII"] + }, + "booleanField": {"type": "boolean"}, + "bytesField": { + "type": "string", + "contentEncoding": "base64", + "confluent:tags": ["PII"] + } + } + } + + rule = Rule( + "test-encrypt", + "", + RuleKind.TRANSFORM, + RuleMode.WRITEREAD, + "ENCRYPT_PAYLOAD", + None, + RuleParams({ + "encrypt.kek.name": "kek1", + "encrypt.kms.type": "local-kms", + "encrypt.kms.key.id": "mykey" + }), + None, + None, + "ERROR,NONE", + False + ) + await client.register_schema(_SUBJECT, Schema( + json.dumps(schema), + "JSON", + [], + None, + RuleSet(None, None, [rule]) + )) + + obj = { + 'intField': 123, + 'doubleField': 45.67, + 'stringField': 'hi', + 'booleanField': True, + 'bytesField': base64.b64encode(b'foobar').decode('utf-8'), + } + ser = await AsyncJSONSerializer(json.dumps(schema), client, conf=ser_conf, rule_conf=rule_conf) + dek_client = executor.client + ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE) + obj_bytes = await ser(obj, ser_ctx) + deser = await AsyncJSONDeserializer(None, schema_registry_client=client, rule_conf=rule_conf) executor.client = dek_client obj2 = await deser(obj_bytes, ser_ctx) @@ -1060,7 +1128,7 @@ async def test_json_encryption_with_union(): 'bytesField': base64.b64encode(b'foobar').decode('utf-8'), } ser = await AsyncJSONSerializer(json.dumps(schema), client, conf=ser_conf, rule_conf=rule_conf) - dek_client = executor.client + dek_client = executor.executor.client ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE) obj_bytes = await ser(obj, ser_ctx) @@ -1070,7 +1138,7 @@ async def test_json_encryption_with_union(): obj['bytesField'] = base64.b64encode(b'foobar').decode('utf-8') deser = await AsyncJSONDeserializer(None, schema_registry_client=client, rule_conf=rule_conf) - executor.client = dek_client + executor.executor.client = dek_client obj2 = await deser(obj_bytes, ser_ctx) assert obj == obj2 @@ -1151,7 +1219,7 @@ async def test_json_encryption_with_references(): 'otherField': nested } ser = await AsyncJSONSerializer(json.dumps(schema), client, conf=ser_conf, rule_conf=rule_conf) - dek_client = executor.client + dek_client = executor.executor.client ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE) obj_bytes = await ser(obj, ser_ctx) @@ -1161,7 +1229,7 @@ async def test_json_encryption_with_references(): obj['otherField']['bytesField'] = base64.b64encode(b'foobar').decode('utf-8') deser = await AsyncJSONDeserializer(None, schema_registry_client=client, rule_conf=rule_conf) - executor.client = dek_client + executor.executor.client = dek_client obj2 = await deser(obj_bytes, ser_ctx) assert obj == obj2 diff --git a/tests/schema_registry/_async/test_proto_serdes.py b/tests/schema_registry/_async/test_proto_serdes.py index 053db03ff..5c742ac38 100644 --- a/tests/schema_registry/_async/test_proto_serdes.py +++ b/tests/schema_registry/_async/test_proto_serdes.py @@ -34,7 +34,7 @@ from confluent_kafka.schema_registry.rules.encryption.azurekms.azure_driver import \ AzureKmsDriver from
confluent_kafka.schema_registry.rules.encryption.encrypt_executor import \ - FieldEncryptionExecutor, Clock + FieldEncryptionExecutor, Clock, EncryptionExecutor from confluent_kafka.schema_registry.rules.encryption.gcpkms.gcp_driver import \ GcpKmsDriver from confluent_kafka.schema_registry.rules.encryption.hcvault.hcvault_driver import \ @@ -568,7 +568,7 @@ async def test_proto_encryption(): oneof_string='oneof' ) ser = await AsyncProtobufSerializer(example_pb2.Author, client, conf=ser_conf, rule_conf=rule_conf) - dek_client = executor.client + dek_client = executor.executor.client ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE) obj_bytes = await ser(obj, ser_ctx) @@ -582,6 +582,62 @@ async def test_proto_encryption(): oneof_string='oneof' ) + deser_conf = { + 'use.deprecated.format': False + } + deser = await AsyncProtobufDeserializer(example_pb2.Author, deser_conf, client, rule_conf=rule_conf) + executor.executor.client = dek_client + obj2 = await deser(obj_bytes, ser_ctx) + assert obj == obj2 + + +async def test_proto_payload_encryption(): + executor = EncryptionExecutor.register_with_clock(FakeClock()) + + conf = {'url': _BASE_URL} + client = AsyncSchemaRegistryClient.new_client(conf) + ser_conf = { + 'auto.register.schemas': False, + 'use.latest.version': True, + 'use.deprecated.format': False + } + rule_conf = {'secret': 'mysecret'} + rule = Rule( + "test-encrypt", + "", + RuleKind.TRANSFORM, + RuleMode.WRITEREAD, + "ENCRYPT_PAYLOAD", + None, + RuleParams({ + "encrypt.kek.name": "kek1", + "encrypt.kms.type": "local-kms", + "encrypt.kms.key.id": "mykey" + }), + None, + None, + "ERROR,NONE", + False + ) + await client.register_schema(_SUBJECT, Schema( + _schema_to_str(example_pb2.Author.DESCRIPTOR.file), + "PROTOBUF", + [], + None, + RuleSet(None, None, [rule]) + )) + obj = example_pb2.Author( + name='Kafka', + id=123, + picture=b'foobar', + works=['The Castle', 'TheTrial'], + oneof_string='oneof' + ) + ser = await AsyncProtobufSerializer(example_pb2.Author, client, conf=ser_conf, rule_conf=rule_conf) + dek_client = executor.client + ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE) + obj_bytes = await ser(obj, ser_ctx) + deser_conf = { 'use.deprecated.format': False } diff --git a/tests/schema_registry/_sync/test_avro_serdes.py b/tests/schema_registry/_sync/test_avro_serdes.py index dadf530db..735d0f29a 100644 --- a/tests/schema_registry/_sync/test_avro_serdes.py +++ b/tests/schema_registry/_sync/test_avro_serdes.py @@ -38,7 +38,7 @@ from confluent_kafka.schema_registry.rules.encryption.dek_registry.dek_registry_client import \ DekRegistryClient, DekAlgorithm from confluent_kafka.schema_registry.rules.encryption.encrypt_executor import \ - FieldEncryptionExecutor, Clock + FieldEncryptionExecutor, Clock, EncryptionExecutor from confluent_kafka.schema_registry.rules.encryption.gcpkms.gcp_driver import \ GcpKmsDriver from confluent_kafka.schema_registry.rules.encryption.hcvault.hcvault_driver import \ @@ -1125,7 +1125,7 @@ def test_avro_encryption(): 'bytesField': b'foobar', } ser = AvroSerializer(client, schema_str=None, conf=ser_conf, rule_conf=rule_conf) - dek_client = executor.client + dek_client = executor.executor.client ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE) obj_bytes = ser(obj, ser_ctx) @@ -1134,6 +1134,68 @@ def test_avro_encryption(): obj['stringField'] = 'hi' obj['bytesField'] = b'foobar' + deser = AvroDeserializer(client, rule_conf=rule_conf) + executor.executor.client = dek_client + obj2 = deser(obj_bytes, ser_ctx) + assert obj == 
obj2
+
+
+def test_avro_payload_encryption():
+    executor = EncryptionExecutor.register_with_clock(FakeClock())
+
+    conf = {'url': _BASE_URL}
+    client = SchemaRegistryClient.new_client(conf)
+    ser_conf = {'auto.register.schemas': False, 'use.latest.version': True}
+    rule_conf = {'secret': 'mysecret'}
+    schema = {
+        'type': 'record',
+        'name': 'test',
+        'fields': [
+            {'name': 'intField', 'type': 'int'},
+            {'name': 'doubleField', 'type': 'double'},
+            {'name': 'stringField', 'type': 'string', 'confluent:tags': ['PII']},
+            {'name': 'booleanField', 'type': 'boolean'},
+            {'name': 'bytesField', 'type': 'bytes', 'confluent:tags': ['PII']},
+        ]
+    }
+
+    rule = Rule(
+        "test-encrypt",
+        "",
+        RuleKind.TRANSFORM,
+        RuleMode.WRITEREAD,
+        "ENCRYPT_PAYLOAD",
+        None,
+        RuleParams({
+            "encrypt.kek.name": "kek1",
+            "encrypt.kms.type": "local-kms",
+            "encrypt.kms.key.id": "mykey"
+        }),
+        None,
+        None,
+        "ERROR,NONE",
+        False
+    )
+    client.register_schema(_SUBJECT, Schema(
+        json.dumps(schema),
+        "AVRO",
+        [],
+        None,
+        RuleSet(None, None, [rule])
+    ))
+
+    obj = {
+        'intField': 123,
+        'doubleField': 45.67,
+        'stringField': 'hi',
+        'booleanField': True,
+        'bytesField': b'foobar',
+    }
+    ser = AvroSerializer(client, schema_str=None, conf=ser_conf, rule_conf=rule_conf)
+    dek_client = executor.client
+    ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE)
+    obj_bytes = ser(obj, ser_ctx)
+
     deser = AvroDeserializer(client, rule_conf=rule_conf)
     executor.client = dek_client
     obj2 = deser(obj_bytes, ser_ctx)
@@ -1193,7 +1255,7 @@ def test_avro_encryption_deterministic():
         'bytesField': b'foobar',
     }
     ser = AvroSerializer(client, schema_str=None, conf=ser_conf, rule_conf=rule_conf)
-    dek_client = executor.client
+    dek_client = executor.executor.client
     ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE)
     obj_bytes = ser(obj, ser_ctx)
 
@@ -1203,7 +1265,7 @@ def test_avro_encryption_deterministic():
     obj['bytesField'] = b'foobar'
 
     deser = AvroDeserializer(client, rule_conf=rule_conf)
-    executor.client = dek_client
+    executor.executor.client = dek_client
     obj2 = deser(obj_bytes, ser_ctx)
     assert obj == obj2
 
@@ -1273,7 +1335,7 @@ def test_avro_encryption_cel():
         'bytesField': b'foobar',
     }
     ser = AvroSerializer(client, schema_str=None, conf=ser_conf, rule_conf=rule_conf)
-    dek_client = executor.client
+    dek_client = executor.executor.client
     ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE)
     obj_bytes = ser(obj, ser_ctx)
 
@@ -1283,7 +1345,7 @@ def test_avro_encryption_cel():
     obj['bytesField'] = b'foobar'
 
     deser = AvroDeserializer(client, rule_conf=rule_conf)
-    executor.client = dek_client
+    executor.executor.client = dek_client
     obj2 = deser(obj_bytes, ser_ctx)
     assert obj == obj2
 
@@ -1341,7 +1403,7 @@ def test_avro_encryption_dek_rotation():
         'bytesField': b'foobar',
     }
     ser = AvroSerializer(client, schema_str=None, conf=ser_conf, rule_conf=rule_conf)
-    dek_client: DekRegistryClient = executor.client
+    dek_client: DekRegistryClient = executor.executor.client
     ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE)
     obj_bytes = ser(obj, ser_ctx)
 
@@ -1350,17 +1412,17 @@ def test_avro_encryption_dek_rotation():
     obj['stringField'] = 'hi'
 
     deser = AvroDeserializer(client, rule_conf=rule_conf)
-    executor.client = dek_client
+    executor.executor.client = dek_client
     obj2 = deser(obj_bytes, ser_ctx)
     assert obj == obj2
 
-    dek_client = executor.client
+    dek_client = executor.executor.client
     dek = dek_client.get_dek("kek1-rot", _SUBJECT, version=-1)
     assert dek.version == 1
 
     # advance 2 days
     now = datetime.now() + timedelta(days=2)
-    executor.clock.fixed_now = int(round(now.timestamp() * 1000))
+    executor.executor.clock.fixed_now = int(round(now.timestamp() * 1000))
 
     obj_bytes = ser(obj, ser_ctx)
 
@@ -1376,7 +1438,7 @@ def test_avro_encryption_dek_rotation():
 
     # advance 2 days
     now = datetime.now() + timedelta(days=2)
-    executor.clock.fixed_now = int(round(now.timestamp() * 1000))
+    executor.executor.clock.fixed_now = int(round(now.timestamp() * 1000))
 
     obj_bytes = ser(obj, ser_ctx)
 
@@ -1437,7 +1499,7 @@ def test_avro_encryption_f1_preserialized():
     ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE)
 
     deser = AvroDeserializer(client, rule_conf=rule_conf)
-    dek_client: DekRegistryClient = executor.client
+    dek_client: DekRegistryClient = executor.executor.client
     dek_client.register_kek("kek1-f1", "local-kms", "mykey")
     encrypted_dek = "07V2ndh02DA73p+dTybwZFm7DKQSZN1tEwQh+FoX1DZLk4Yj2LLu4omYjp/84tAg3BYlkfGSz+zZacJHIE4="
@@ -1500,7 +1562,7 @@ def test_avro_encryption_deterministic_f1_preserialized():
     ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE)
 
     deser = AvroDeserializer(client, rule_conf=rule_conf)
-    dek_client: DekRegistryClient = executor.client
+    dek_client: DekRegistryClient = executor.executor.client
     dek_client.register_kek("kek1-det-f1", "local-kms", "mykey")
     encrypted_dek = ("YSx3DTlAHrmpoDChquJMifmPntBzxgRVdMzgYL82rgWBKn7aUSnG+WIu9oz"
@@ -1562,7 +1624,7 @@ def test_avro_encryption_dek_rotation_f1_preserialized():
     ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE)
 
     deser = AvroDeserializer(client, rule_conf=rule_conf)
-    dek_client: DekRegistryClient = executor.client
+    dek_client: DekRegistryClient = executor.executor.client
     dek_client.register_kek("kek1-rot-f1", "local-kms", "mykey")
     encrypted_dek = "W/v6hOQYq1idVAcs1pPWz9UUONMVZW4IrglTnG88TsWjeCjxmtRQ4VaNe/I5dCfm2zyY9Cu0nqdvqImtUk4="
@@ -1642,7 +1704,7 @@ def test_avro_encryption_references():
     ))
 
     ser = AvroSerializer(client, schema_str=None, conf=ser_conf, rule_conf=rule_conf)
-    dek_client = executor.client
+    dek_client = executor.executor.client
     ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE)
     obj_bytes = ser(obj, ser_ctx)
 
@@ -1652,7 +1714,7 @@ def test_avro_encryption_references():
     obj['refField']['bytesField'] = b'foobar'
 
     deser = AvroDeserializer(client, rule_conf=rule_conf)
-    executor.client = dek_client
+    executor.executor.client = dek_client
    obj2 = deser(obj_bytes, ser_ctx)
     assert obj == obj2
 
@@ -1709,7 +1771,7 @@ def test_avro_encryption_with_union():
         'bytesField': b'foobar',
     }
     ser = AvroSerializer(client, schema_str=None, conf=ser_conf, rule_conf=rule_conf)
-    dek_client = executor.client
+    dek_client = executor.executor.client
     ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE)
     obj_bytes = ser(obj, ser_ctx)
 
@@ -1719,7 +1781,7 @@ def test_avro_encryption_with_union():
     obj['bytesField'] = b'foobar'
 
     deser = AvroDeserializer(client, rule_conf=rule_conf)
-    executor.client = dek_client
+    executor.executor.client = dek_client
     obj2 = deser(obj_bytes, ser_ctx)
     assert obj == obj2
 
diff --git a/tests/schema_registry/_sync/test_json_serdes.py b/tests/schema_registry/_sync/test_json_serdes.py
index 56b3714a2..3e1c7b4b3 100644
--- a/tests/schema_registry/_sync/test_json_serdes.py
+++ b/tests/schema_registry/_sync/test_json_serdes.py
@@ -32,7 +32,7 @@
 from confluent_kafka.schema_registry.rules.encryption.azurekms.azure_driver import \
     AzureKmsDriver
 from confluent_kafka.schema_registry.rules.encryption.encrypt_executor import \
-    FieldEncryptionExecutor
+    FieldEncryptionExecutor, EncryptionExecutor
 from confluent_kafka.schema_registry.rules.encryption.gcpkms.gcp_driver import \
     GcpKmsDriver
 from confluent_kafka.schema_registry.rules.encryption.hcvault.hcvault_driver import \
@@ -980,7 +980,7 @@ def test_json_encryption():
         'bytesField': base64.b64encode(b'foobar').decode('utf-8'),
     }
     ser = JSONSerializer(json.dumps(schema), client, conf=ser_conf, rule_conf=rule_conf)
-    dek_client = executor.client
+    dek_client = executor.executor.client
     ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE)
     obj_bytes = ser(obj, ser_ctx)
 
@@ -989,6 +989,74 @@ def test_json_encryption():
     obj['stringField'] = 'hi'
     obj['bytesField'] = base64.b64encode(b'foobar').decode('utf-8')
 
+    deser = JSONDeserializer(None, schema_registry_client=client, rule_conf=rule_conf)
+    executor.executor.client = dek_client
+    obj2 = deser(obj_bytes, ser_ctx)
+    assert obj == obj2
+
+
+def test_json_payload_encryption():
+    executor = EncryptionExecutor.register_with_clock(FakeClock())
+
+    conf = {'url': _BASE_URL}
+    client = SchemaRegistryClient.new_client(conf)
+    ser_conf = {'auto.register.schemas': False, 'use.latest.version': True}
+    rule_conf = {'secret': 'mysecret'}
+    schema = {
+        "type": "object",
+        "properties": {
+            "intField": {"type": "integer"},
+            "doubleField": {"type": "number"},
+            "stringField": {
+                "type": "string",
+                "confluent:tags": ["PII"]
+            },
+            "booleanField": {"type": "boolean"},
+            "bytesField": {
+                "type": "string",
+                "contentEncoding": "base64",
+                "confluent:tags": ["PII"]
+            }
+        }
+    }
+
+    rule = Rule(
+        "test-encrypt",
+        "",
+        RuleKind.TRANSFORM,
+        RuleMode.WRITEREAD,
+        "ENCRYPT_PAYLOAD",
+        None,
+        RuleParams({
+            "encrypt.kek.name": "kek1",
+            "encrypt.kms.type": "local-kms",
+            "encrypt.kms.key.id": "mykey"
+        }),
+        None,
+        None,
+        "ERROR,NONE",
+        False
+    )
+    client.register_schema(_SUBJECT, Schema(
+        json.dumps(schema),
+        "JSON",
+        [],
+        None,
+        RuleSet(None, None, [rule])
+    ))
+
+    obj = {
+        'intField': 123,
+        'doubleField': 45.67,
+        'stringField': 'hi',
+        'booleanField': True,
+        'bytesField': base64.b64encode(b'foobar').decode('utf-8'),
+    }
+    ser = JSONSerializer(json.dumps(schema), client, conf=ser_conf, rule_conf=rule_conf)
+    dek_client = executor.client
+    ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE)
+    obj_bytes = ser(obj, ser_ctx)
+
     deser = JSONDeserializer(None, schema_registry_client=client, rule_conf=rule_conf)
     executor.client = dek_client
     obj2 = deser(obj_bytes, ser_ctx)
@@ -1060,7 +1128,7 @@ def test_json_encryption_with_union():
         'bytesField': base64.b64encode(b'foobar').decode('utf-8'),
     }
     ser = JSONSerializer(json.dumps(schema), client, conf=ser_conf, rule_conf=rule_conf)
-    dek_client = executor.client
+    dek_client = executor.executor.client
     ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE)
     obj_bytes = ser(obj, ser_ctx)
 
@@ -1070,7 +1138,7 @@ def test_json_encryption_with_union():
     obj['bytesField'] = base64.b64encode(b'foobar').decode('utf-8')
 
     deser = JSONDeserializer(None, schema_registry_client=client, rule_conf=rule_conf)
-    executor.client = dek_client
+    executor.executor.client = dek_client
     obj2 = deser(obj_bytes, ser_ctx)
     assert obj == obj2
 
@@ -1151,7 +1219,7 @@ def test_json_encryption_with_references():
         'otherField': nested
     }
     ser = JSONSerializer(json.dumps(schema), client, conf=ser_conf, rule_conf=rule_conf)
-    dek_client = executor.client
+    dek_client = executor.executor.client
     ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE)
     obj_bytes = ser(obj, ser_ctx)
 
@@ -1161,7 +1229,7 @@ def test_json_encryption_with_references():
     obj['otherField']['bytesField'] = base64.b64encode(b'foobar').decode('utf-8')
 
     deser = JSONDeserializer(None, schema_registry_client=client, rule_conf=rule_conf)
-    executor.client = dek_client
+    executor.executor.client = dek_client
     obj2 = deser(obj_bytes, ser_ctx)
     assert obj == obj2
 
diff --git a/tests/schema_registry/_sync/test_proto_serdes.py b/tests/schema_registry/_sync/test_proto_serdes.py
index 5448240c2..1ce23c807 100644
--- a/tests/schema_registry/_sync/test_proto_serdes.py
+++ b/tests/schema_registry/_sync/test_proto_serdes.py
@@ -34,7 +34,7 @@
 from confluent_kafka.schema_registry.rules.encryption.azurekms.azure_driver import \
     AzureKmsDriver
 from confluent_kafka.schema_registry.rules.encryption.encrypt_executor import \
-    FieldEncryptionExecutor, Clock
+    FieldEncryptionExecutor, Clock, EncryptionExecutor
 from confluent_kafka.schema_registry.rules.encryption.gcpkms.gcp_driver import \
     GcpKmsDriver
 from confluent_kafka.schema_registry.rules.encryption.hcvault.hcvault_driver import \
@@ -568,7 +568,7 @@ def test_proto_encryption():
         oneof_string='oneof'
     )
     ser = ProtobufSerializer(example_pb2.Author, client, conf=ser_conf, rule_conf=rule_conf)
-    dek_client = executor.client
+    dek_client = executor.executor.client
     ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE)
     obj_bytes = ser(obj, ser_ctx)
 
@@ -582,6 +582,62 @@ def test_proto_encryption():
         oneof_string='oneof'
     )
 
+    deser_conf = {
+        'use.deprecated.format': False
+    }
+    deser = ProtobufDeserializer(example_pb2.Author, deser_conf, client, rule_conf=rule_conf)
+    executor.executor.client = dek_client
+    obj2 = deser(obj_bytes, ser_ctx)
+    assert obj == obj2
+
+
+def test_proto_payload_encryption():
+    executor = EncryptionExecutor.register_with_clock(FakeClock())
+
+    conf = {'url': _BASE_URL}
+    client = SchemaRegistryClient.new_client(conf)
+    ser_conf = {
+        'auto.register.schemas': False,
+        'use.latest.version': True,
+        'use.deprecated.format': False
+    }
+    rule_conf = {'secret': 'mysecret'}
+    rule = Rule(
+        "test-encrypt",
+        "",
+        RuleKind.TRANSFORM,
+        RuleMode.WRITEREAD,
+        "ENCRYPT_PAYLOAD",
+        None,
+        RuleParams({
+            "encrypt.kek.name": "kek1",
+            "encrypt.kms.type": "local-kms",
+            "encrypt.kms.key.id": "mykey"
+        }),
+        None,
+        None,
+        "ERROR,NONE",
+        False
+    )
+    client.register_schema(_SUBJECT, Schema(
+        _schema_to_str(example_pb2.Author.DESCRIPTOR.file),
+        "PROTOBUF",
+        [],
+        None,
+        RuleSet(None, None, [rule])
+    ))
+    obj = example_pb2.Author(
+        name='Kafka',
+        id=123,
+        picture=b'foobar',
+        works=['The Castle', 'TheTrial'],
+        oneof_string='oneof'
+    )
+    ser = ProtobufSerializer(example_pb2.Author, client, conf=ser_conf, rule_conf=rule_conf)
+    dek_client = executor.client
+    ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE)
+    obj_bytes = ser(obj, ser_ctx)
+
     deser_conf = {
         'use.deprecated.format': False
     }
