Skip to content

Commit c84a217

Browse files
authored
Add unit tests for async SR client and async serdes (confluentinc#1994)
* Add unit tests for async SR client and async serdes * fix flake8 issues * fix test_delivery_report_serialization * run unasync
1 parent 4cc24f3 commit c84a217

32 files changed

+6223
-133
lines changed

src/confluent_kafka/schema_registry/_async/avro.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ async def __deserialize(
558558
latest_schema = await self._get_reader_schema(subject)
559559

560560
if latest_schema is not None:
561-
migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None)
561+
migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None)
562562
reader_schema_raw = latest_schema.schema
563563
reader_schema = await self._get_parsed_schema(latest_schema.schema)
564564
elif self._schema is not None:

src/confluent_kafka/schema_registry/_async/json_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ async def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = N
598598
writer_schema, writer_ref_registry = None, None
599599

600600
if latest_schema is not None:
601-
migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None)
601+
migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None)
602602
reader_schema_raw = latest_schema.schema
603603
reader_schema, reader_ref_registry = await self._get_parsed_schema(latest_schema.schema)
604604
elif self._schema is not None:
Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright 2024 Confluent Inc.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
import uuid
19+
from collections import defaultdict
20+
from threading import Lock
21+
from typing import List, Dict, Optional
22+
23+
from .schema_registry_client import AsyncSchemaRegistryClient
24+
from ..common.schema_registry_client import RegisteredSchema, Schema, ServerConfig
25+
from ..error import SchemaRegistryError
26+
27+
28+
class _SchemaStore(object):
29+
30+
def __init__(self):
31+
self.lock = Lock()
32+
self.max_id = 0
33+
self.schema_id_index = {}
34+
self.schema_guid_index = {}
35+
self.schema_index = {}
36+
self.subject_schemas = defaultdict(set)
37+
38+
def set(self, registered_schema: RegisteredSchema) -> RegisteredSchema:
39+
with self.lock:
40+
self.max_id += 1
41+
rs = RegisteredSchema(
42+
schema_id=self.max_id,
43+
guid=registered_schema.guid,
44+
schema=registered_schema.schema,
45+
subject=registered_schema.subject,
46+
version=registered_schema.version
47+
)
48+
self.schema_id_index[rs.schema_id] = rs
49+
self.schema_guid_index[rs.guid] = rs
50+
self.schema_index[rs.schema] = rs.schema_id
51+
self.subject_schemas[rs.subject].add(rs)
52+
return rs
53+
54+
def get_schema(self, schema_id: int) -> Optional[Schema]:
55+
with self.lock:
56+
rs = self.schema_id_index.get(schema_id, None)
57+
return rs.schema if rs else None
58+
59+
def get_schema_by_guid(self, guid: str) -> Optional[Schema]:
60+
with self.lock:
61+
rs = self.schema_guid_index.get(guid, None)
62+
return rs.schema if rs else None
63+
64+
def get_registered_schema_by_schema(
65+
self,
66+
subject_name: str,
67+
schema: Schema
68+
) -> Optional[RegisteredSchema]:
69+
with self.lock:
70+
if subject_name in self.subject_schemas:
71+
for rs in self.subject_schemas[subject_name]:
72+
if rs.schema == schema:
73+
return rs
74+
return None
75+
76+
def get_version(self, subject_name: str, version: int) -> Optional[RegisteredSchema]:
77+
with self.lock:
78+
if subject_name in self.subject_schemas:
79+
for rs in self.subject_schemas[subject_name]:
80+
if rs.version == version:
81+
return rs
82+
return None
83+
84+
def get_latest_version(self, subject_name: str) -> Optional[RegisteredSchema]:
85+
with self.lock:
86+
if subject_name in self.subject_schemas:
87+
latest_version = 0
88+
latest_schema = None
89+
for rs in self.subject_schemas[subject_name]:
90+
if rs.version > latest_version:
91+
latest_version = rs.version
92+
latest_schema = rs
93+
return latest_schema
94+
return None
95+
96+
def get_latest_with_metadata(
97+
self, subject_name: str,
98+
metadata: Dict[str, str]
99+
) -> Optional[RegisteredSchema]:
100+
with self.lock:
101+
if subject_name in self.subject_schemas:
102+
rs: RegisteredSchema
103+
for rs in self.subject_schemas[subject_name]:
104+
if (rs.schema
105+
and rs.schema.metadata
106+
and rs.schema.metadata.properties
107+
and metadata.items() <= rs.schema.metadata.properties.properties.items()):
108+
return rs
109+
return None
110+
111+
def get_subjects(self) -> List[str]:
112+
with self.lock:
113+
return list(self.subject_schemas.keys())
114+
115+
def get_versions(self, subject_name: str) -> List[int]:
116+
with self.lock:
117+
if subject_name in self.subject_schemas:
118+
return [rs.version for rs in self.subject_schemas[subject_name]]
119+
return []
120+
121+
def remove_by_schema(self, registered_schema: RegisteredSchema):
122+
with self.lock:
123+
subject_name = registered_schema.subject
124+
if subject_name in self.subject_schemas:
125+
self.subject_schemas[subject_name].remove(registered_schema)
126+
127+
def remove_by_subject(self, subject_name: str) -> List[int]:
128+
with self.lock:
129+
versions = []
130+
if subject_name in self.subject_schemas:
131+
for rs in self.subject_schemas[subject_name]:
132+
versions.append(rs.version)
133+
schema_id = self.schema_index.pop(rs.schema, None)
134+
if schema_id is not None:
135+
self.schema_id_index.pop(schema_id, None)
136+
137+
del self.subject_schemas[subject_name]
138+
return versions
139+
140+
def clear(self):
141+
with self.lock:
142+
self.schema_id_index.clear()
143+
self.schema_guid_index.clear()
144+
self.schema_index.clear()
145+
self.subject_schemas.clear()
146+
147+
148+
class AsyncMockSchemaRegistryClient(AsyncSchemaRegistryClient):
149+
150+
def __init__(self, conf: dict):
151+
super().__init__(conf)
152+
self._store = _SchemaStore()
153+
154+
async def register_schema(
155+
self, subject_name: str, schema: 'Schema',
156+
normalize_schemas: bool = False
157+
) -> int:
158+
registered_schema = await self.register_schema_full_response(subject_name, schema, normalize_schemas)
159+
return registered_schema.schema_id
160+
161+
async def register_schema_full_response(
162+
self, subject_name: str, schema: 'Schema',
163+
normalize_schemas: bool = False
164+
) -> 'RegisteredSchema':
165+
registered_schema = self._store.get_registered_schema_by_schema(subject_name, schema)
166+
if registered_schema is not None:
167+
return registered_schema
168+
169+
latest_schema = self._store.get_latest_version(subject_name)
170+
latest_version = 1 if latest_schema is None else latest_schema.version + 1
171+
172+
registered_schema = RegisteredSchema(
173+
schema_id=1,
174+
guid=str(uuid.uuid4()),
175+
schema=schema,
176+
subject=subject_name,
177+
version=latest_version
178+
)
179+
180+
registered_schema = self._store.set(registered_schema)
181+
182+
return registered_schema
183+
184+
async def get_schema(
185+
self, schema_id: int, subject_name: Optional[str] = None,
186+
fmt: Optional[str] = None
187+
) -> 'Schema':
188+
schema = self._store.get_schema(schema_id)
189+
if schema is not None:
190+
return schema
191+
192+
raise SchemaRegistryError(404, 40400, "Schema Not Found")
193+
194+
async def get_schema_by_guid(
195+
self, guid: str, fmt: Optional[str] = None
196+
) -> 'Schema':
197+
schema = self._store.get_schema_by_guid(guid)
198+
if schema is not None:
199+
return schema
200+
201+
raise SchemaRegistryError(404, 40400, "Schema Not Found")
202+
203+
async def lookup_schema(
204+
self, subject_name: str, schema: 'Schema',
205+
normalize_schemas: bool = False, deleted: bool = False
206+
) -> 'RegisteredSchema':
207+
208+
registered_schema = self._store.get_registered_schema_by_schema(subject_name, schema)
209+
if registered_schema is not None:
210+
return registered_schema
211+
212+
raise SchemaRegistryError(404, 40400, "Schema Not Found")
213+
214+
async def get_subjects(self) -> List[str]:
215+
return self._store.get_subjects()
216+
217+
async def delete_subject(self, subject_name: str, permanent: bool = False) -> List[int]:
218+
return self._store.remove_by_subject(subject_name)
219+
220+
async def get_latest_version(self, subject_name: str, fmt: Optional[str] = None) -> 'RegisteredSchema':
221+
registered_schema = self._store.get_latest_version(subject_name)
222+
if registered_schema is not None:
223+
return registered_schema
224+
225+
raise SchemaRegistryError(404, 40400, "Schema Not Found")
226+
227+
async def get_latest_with_metadata(
228+
self, subject_name: str, metadata: Dict[str, str],
229+
deleted: bool = False, fmt: Optional[str] = None
230+
) -> 'RegisteredSchema':
231+
registered_schema = self._store.get_latest_with_metadata(subject_name, metadata)
232+
if registered_schema is not None:
233+
return registered_schema
234+
235+
raise SchemaRegistryError(404, 40400, "Schema Not Found")
236+
237+
async def get_version(
238+
self, subject_name: str, version: int,
239+
deleted: bool = False, fmt: Optional[str] = None
240+
) -> 'RegisteredSchema':
241+
registered_schema = self._store.get_version(subject_name, version)
242+
if registered_schema is not None:
243+
return registered_schema
244+
245+
raise SchemaRegistryError(404, 40400, "Schema Not Found")
246+
247+
async def get_versions(self, subject_name: str) -> List[int]:
248+
return self._store.get_versions(subject_name)
249+
250+
async def delete_version(self, subject_name: str, version: int, permanent: bool = False) -> int:
251+
registered_schema = self._store.get_version(subject_name, version)
252+
if registered_schema is not None:
253+
self._store.remove_by_schema(registered_schema)
254+
return registered_schema.schema_id
255+
256+
raise SchemaRegistryError(404, 40400, "Schema Not Found")
257+
258+
async def set_config(
259+
self, subject_name: Optional[str] = None, config: 'ServerConfig' = None # noqa F821
260+
) -> 'ServerConfig': # noqa F821
261+
return None
262+
263+
async def get_config(self, subject_name: Optional[str] = None) -> 'ServerConfig': # noqa F821
264+
return None

src/confluent_kafka/schema_registry/_async/protobuf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -597,7 +597,7 @@ async def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = N
597597
writer_schema = None
598598

599599
if latest_schema is not None:
600-
migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None)
600+
migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None)
601601
reader_schema_raw = latest_schema.schema
602602
fd_proto, pool = await self._get_parsed_schema(latest_schema.schema)
603603
reader_schema = pool.FindFileByName(fd_proto.name)

src/confluent_kafka/schema_registry/_async/schema_registry_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,8 +1158,8 @@ def clear_caches(self):
11581158

11591159
@staticmethod
11601160
def new_client(conf: dict) -> 'AsyncSchemaRegistryClient':
1161-
from confluent_kafka.schema_registry.mock_schema_registry_client import MockSchemaRegistryClient
1161+
from .mock_schema_registry_client import AsyncMockSchemaRegistryClient
11621162
url = conf.get("url")
11631163
if url.startswith("mock://"):
1164-
return MockSchemaRegistryClient(conf)
1164+
return AsyncMockSchemaRegistryClient(conf)
11651165
return AsyncSchemaRegistryClient(conf)

src/confluent_kafka/schema_registry/_async/serde.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ async def _get_migrations(
216216
self, subject: str, source_info: Schema,
217217
target: RegisteredSchema, fmt: Optional[str]
218218
) -> List[Migration]:
219-
source = self._registry.lookup_schema(subject, source_info, False, True)
219+
source = await self._registry.lookup_schema(subject, source_info, False, True)
220220
migrations = []
221221
if source.version < target.version:
222222
migration_mode = RuleMode.UPGRADE

src/confluent_kafka/schema_registry/mock_schema_registry_client.py renamed to src/confluent_kafka/schema_registry/_sync/mock_schema_registry_client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@
2020
from threading import Lock
2121
from typing import List, Dict, Optional
2222

23-
from . import SchemaRegistryClient, RegisteredSchema, Schema
24-
from .error import SchemaRegistryError
23+
from .schema_registry_client import SchemaRegistryClient
24+
from ..common.schema_registry_client import RegisteredSchema, Schema, ServerConfig
25+
from ..error import SchemaRegistryError
2526

2627

2728
class _SchemaStore(object):

src/confluent_kafka/schema_registry/_sync/schema_registry_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1158,7 +1158,7 @@ def clear_caches(self):
11581158

11591159
@staticmethod
11601160
def new_client(conf: dict) -> 'SchemaRegistryClient':
1161-
from confluent_kafka.schema_registry.mock_schema_registry_client import MockSchemaRegistryClient
1161+
from .mock_schema_registry_client import MockSchemaRegistryClient
11621162
url = conf.get("url")
11631163
if url.startswith("mock://"):
11641164
return MockSchemaRegistryClient(conf)

tests/integration/schema_registry/_async/test_avro_serializers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,8 @@ async def test_delivery_report_serialization(kafka_cluster, load_file, avsc, dat
278278
producer = kafka_cluster.async_producer(value_serializer=value_serializer)
279279

280280
async def assert_cb(err, msg):
281-
actual = value_deserializer(msg.value(),
282-
SerializationContext(topic, MessageField.VALUE, msg.headers()))
281+
actual = await value_deserializer(
282+
msg.value(), SerializationContext(topic, MessageField.VALUE, msg.headers()))
283283

284284
if record_type == "record":
285285
assert [v == actual[k] for k, v in data.items()]

tests/integration/schema_registry/_sync/test_avro_serializers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,8 @@ def test_delivery_report_serialization(kafka_cluster, load_file, avsc, data, rec
278278
producer = kafka_cluster.producer(value_serializer=value_serializer)
279279

280280
def assert_cb(err, msg):
281-
actual = value_deserializer(msg.value(),
282-
SerializationContext(topic, MessageField.VALUE, msg.headers()))
281+
actual = value_deserializer(
282+
msg.value(), SerializationContext(topic, MessageField.VALUE, msg.headers()))
283283

284284
if record_type == "record":
285285
assert [v == actual[k] for k, v in data.items()]

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy