Skip to content

Commit f8b6468

Browse files
Manicbenemasab
andauthored
Support setting principal and SASL extensions in oauth_cb, handle failures (confluentinc#1402)
* Support setting principal and SASL extensions in oauth_cb and handle token failures * removed global variables Co-authored-by: Emanuele Sabellico <esabellico@confluent.io>
1 parent 9ea3aae commit f8b6468

File tree

2 files changed

+164
-38
lines changed

2 files changed

+164
-38
lines changed

src/confluent_kafka/src/confluent_kafka.c

Lines changed: 97 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1522,13 +1522,73 @@ static void log_cb (const rd_kafka_t *rk, int level,
15221522
CallState_resume(cs);
15231523
}
15241524

1525+
/**
1526+
* @brief Translate Python \p key and \p value to C types and set on
1527+
* provided \p extensions char* array at the provided index.
1528+
*
1529+
* @returns 1 on success or 0 if an exception was raised.
1530+
*/
1531+
static int py_extensions_to_c (char **extensions, Py_ssize_t idx,
1532+
PyObject *key, PyObject *value) {
1533+
PyObject *ks, *ks8, *vo8 = NULL;
1534+
const char *k;
1535+
const char *v;
1536+
Py_ssize_t ksize = 0;
1537+
Py_ssize_t vsize = 0;
1538+
1539+
if (!(ks = cfl_PyObject_Unistr(key))) {
1540+
PyErr_SetString(PyExc_TypeError,
1541+
"expected extension key to be unicode "
1542+
"string");
1543+
return 0;
1544+
}
1545+
1546+
k = cfl_PyUnistr_AsUTF8(ks, &ks8);
1547+
ksize = (Py_ssize_t)strlen(k);
1548+
1549+
if (cfl_PyUnistr(_Check(value))) {
1550+
/* Unicode string, translate to utf-8. */
1551+
v = cfl_PyUnistr_AsUTF8(value, &vo8);
1552+
if (!v) {
1553+
Py_DECREF(ks);
1554+
Py_XDECREF(ks8);
1555+
return 0;
1556+
}
1557+
vsize = (Py_ssize_t)strlen(v);
1558+
} else {
1559+
PyErr_Format(PyExc_TypeError,
1560+
"expected extension value to be "
1561+
"unicode string, not %s",
1562+
((PyTypeObject *)PyObject_Type(value))->
1563+
tp_name);
1564+
Py_DECREF(ks);
1565+
Py_XDECREF(ks8);
1566+
return 0;
1567+
}
1568+
1569+
extensions[idx] = (char*)malloc(ksize);
1570+
strcpy(extensions[idx], k);
1571+
extensions[idx + 1] = (char*)malloc(vsize);
1572+
strcpy(extensions[idx + 1], v);
1573+
1574+
Py_DECREF(ks);
1575+
Py_XDECREF(ks8);
1576+
Py_XDECREF(vo8);
1577+
1578+
return 1;
1579+
}
1580+
15251581
static void oauth_cb (rd_kafka_t *rk, const char *oauthbearer_config,
15261582
void *opaque) {
15271583
Handle *h = opaque;
15281584
PyObject *eo, *result;
15291585
CallState *cs;
15301586
const char *token;
15311587
double expiry;
1588+
const char *principal = "";
1589+
PyObject *extensions = NULL;
1590+
char **rd_extensions = NULL;
1591+
Py_ssize_t rd_extensions_size = 0;
15321592
char err_msg[2048];
15331593
rd_kafka_resp_err_t err_code;
15341594

@@ -1539,26 +1599,57 @@ static void oauth_cb (rd_kafka_t *rk, const char *oauthbearer_config,
15391599
Py_DECREF(eo);
15401600

15411601
if (!result) {
1542-
goto err;
1602+
goto fail;
15431603
}
1544-
if (!PyArg_ParseTuple(result, "sd", &token, &expiry)) {
1604+
if (!PyArg_ParseTuple(result, "sd|sO!", &token, &expiry, &principal, &PyDict_Type, &extensions)) {
15451605
Py_DECREF(result);
1546-
PyErr_Format(PyExc_TypeError,
1606+
PyErr_SetString(PyExc_TypeError,
15471607
"expect returned value from oauth_cb "
15481608
"to be (token_str, expiry_time) tuple");
15491609
goto err;
15501610
}
1611+
1612+
if (extensions) {
1613+
int len = (int)PyDict_Size(extensions);
1614+
rd_extensions = (char **)malloc(2 * len * sizeof(char *));
1615+
Py_ssize_t pos = 0;
1616+
PyObject *ko, *vo;
1617+
while (PyDict_Next(extensions, &pos, &ko, &vo)) {
1618+
if (!py_extensions_to_c(rd_extensions, rd_extensions_size, ko, vo)) {
1619+
Py_DECREF(result);
1620+
free(rd_extensions);
1621+
goto err;
1622+
}
1623+
rd_extensions_size = rd_extensions_size + 2;
1624+
}
1625+
}
1626+
15511627
err_code = rd_kafka_oauthbearer_set_token(h->rk, token,
15521628
(int64_t)(expiry * 1000),
1553-
"", NULL, 0, err_msg,
1629+
principal, (const char **)rd_extensions, rd_extensions_size, err_msg,
15541630
sizeof(err_msg));
15551631
Py_DECREF(result);
1556-
if (err_code) {
1632+
if (rd_extensions) {
1633+
for(int i = 0; i < rd_extensions_size; i++) {
1634+
free(rd_extensions[i]);
1635+
}
1636+
free(rd_extensions);
1637+
}
1638+
1639+
if (err_code != RD_KAFKA_RESP_ERR_NO_ERROR) {
15571640
PyErr_Format(PyExc_ValueError, "%s", err_msg);
1558-
goto err;
1641+
goto fail;
15591642
}
15601643
goto done;
15611644

1645+
fail:
1646+
err_code = rd_kafka_oauthbearer_set_token_failure(h->rk, "OAuth callback raised exception");
1647+
if (err_code != RD_KAFKA_RESP_ERR_NO_ERROR) {
1648+
PyErr_SetString(PyExc_ValueError, "Failed to set token failure");
1649+
goto err;
1650+
}
1651+
PyErr_Clear();
1652+
goto done;
15621653
err:
15631654
CallState_crash(cs);
15641655
rd_kafka_yield(h->rk);

tests/test_misc.py

Lines changed: 67 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -24,49 +24,41 @@ def test_version():
2424
assert confluent_kafka.version()[0] == confluent_kafka.__version__
2525

2626

27-
# global variable for error_cb call back function
28-
seen_error_cb = False
29-
30-
3127
def test_error_cb():
3228
""" Tests error_cb. """
29+
seen_error_cb = False
3330

3431
def error_cb(error_msg):
35-
global seen_error_cb
32+
nonlocal seen_error_cb
3633
seen_error_cb = True
3734
acceptable_error_codes = (confluent_kafka.KafkaError._TRANSPORT, confluent_kafka.KafkaError._ALL_BROKERS_DOWN)
3835
assert error_msg.code() in acceptable_error_codes
3936

4037
conf = {'bootstrap.servers': 'localhost:65531', # Purposely cause connection refused error
4138
'group.id': 'test',
42-
'socket.timeout.ms': '100',
4339
'session.timeout.ms': 1000, # Avoid close() blocking too long
4440
'error_cb': error_cb
4541
}
4642

4743
kc = confluent_kafka.Consumer(**conf)
4844
kc.subscribe(["test"])
4945
while not seen_error_cb:
50-
kc.poll(timeout=1)
46+
kc.poll(timeout=0.1)
5147

5248
kc.close()
5349

5450

55-
# global variable for stats_cb call back function
56-
seen_stats_cb = False
57-
58-
5951
def test_stats_cb():
6052
""" Tests stats_cb. """
53+
seen_stats_cb = False
6154

6255
def stats_cb(stats_json_str):
63-
global seen_stats_cb
56+
nonlocal seen_stats_cb
6457
seen_stats_cb = True
6558
stats_json = json.loads(stats_json_str)
6659
assert len(stats_json['name']) > 0
6760

6861
conf = {'group.id': 'test',
69-
'socket.timeout.ms': '100',
7062
'session.timeout.ms': 1000, # Avoid close() blocking too long
7163
'statistics.interval.ms': 200,
7264
'stats_cb': stats_cb
@@ -76,22 +68,20 @@ def stats_cb(stats_json_str):
7668

7769
kc.subscribe(["test"])
7870
while not seen_stats_cb:
79-
kc.poll(timeout=1)
71+
kc.poll(timeout=0.1)
8072
kc.close()
8173

8274

83-
seen_stats_cb_check_no_brokers = False
84-
85-
8675
def test_conf_none():
8776
""" Issue #133
8877
Test that None can be passed for NULL by setting bootstrap.servers
8978
to None. If None would be converted to a string then a broker would
9079
show up in statistics. Verify that it doesnt. """
80+
seen_stats_cb_check_no_brokers = False
9181

9282
def stats_cb_check_no_brokers(stats_json_str):
9383
""" Make sure no brokers are reported in stats """
94-
global seen_stats_cb_check_no_brokers
84+
nonlocal seen_stats_cb_check_no_brokers
9585
stats = json.loads(stats_json_str)
9686
assert len(stats['brokers']) == 0, "expected no brokers in stats: %s" % stats_json_str
9787
seen_stats_cb_check_no_brokers = True
@@ -101,9 +91,8 @@ def stats_cb_check_no_brokers(stats_json_str):
10191
'stats_cb': stats_cb_check_no_brokers}
10292

10393
p = confluent_kafka.Producer(conf)
104-
p.poll(timeout=1)
94+
p.poll(timeout=0.1)
10595

106-
global seen_stats_cb_check_no_brokers
10796
assert seen_stats_cb_check_no_brokers
10897

10998

@@ -130,23 +119,19 @@ def test_throttle_event_types():
130119
assert str(throttle_event) == "broker/0 throttled for 10000 ms"
131120

132121

133-
# global variable for oauth_cb call back function
134-
seen_oauth_cb = False
135-
136-
137122
def test_oauth_cb():
138123
""" Tests oauth_cb. """
124+
seen_oauth_cb = False
139125

140126
def oauth_cb(oauth_config):
141-
global seen_oauth_cb
127+
nonlocal seen_oauth_cb
142128
seen_oauth_cb = True
143129
assert oauth_config == 'oauth_cb'
144130
return 'token', time.time() + 300.0
145131

146132
conf = {'group.id': 'test',
147133
'security.protocol': 'sasl_plaintext',
148134
'sasl.mechanisms': 'OAUTHBEARER',
149-
'socket.timeout.ms': '100',
150135
'session.timeout.ms': 1000, # Avoid close() blocking too long
151136
'sasl.oauthbearer.config': 'oauth_cb',
152137
'oauth_cb': oauth_cb
@@ -155,7 +140,59 @@ def oauth_cb(oauth_config):
155140
kc = confluent_kafka.Consumer(**conf)
156141

157142
while not seen_oauth_cb:
158-
kc.poll(timeout=1)
143+
kc.poll(timeout=0.1)
144+
kc.close()
145+
146+
147+
def test_oauth_cb_principal_sasl_extensions():
148+
""" Tests oauth_cb. """
149+
seen_oauth_cb = False
150+
151+
def oauth_cb(oauth_config):
152+
nonlocal seen_oauth_cb
153+
seen_oauth_cb = True
154+
assert oauth_config == 'oauth_cb'
155+
return 'token', time.time() + 300.0, oauth_config, {"extone": "extoneval", "exttwo": "exttwoval"}
156+
157+
conf = {'group.id': 'test',
158+
'security.protocol': 'sasl_plaintext',
159+
'sasl.mechanisms': 'OAUTHBEARER',
160+
'session.timeout.ms': 100, # Avoid close() blocking too long
161+
'sasl.oauthbearer.config': 'oauth_cb',
162+
'oauth_cb': oauth_cb
163+
}
164+
165+
kc = confluent_kafka.Consumer(**conf)
166+
167+
while not seen_oauth_cb:
168+
kc.poll(timeout=0.1)
169+
kc.close()
170+
171+
172+
def test_oauth_cb_failure():
173+
""" Tests oauth_cb. """
174+
oauth_cb_count = 0
175+
176+
def oauth_cb(oauth_config):
177+
nonlocal oauth_cb_count
178+
oauth_cb_count += 1
179+
assert oauth_config == 'oauth_cb'
180+
if oauth_cb_count == 2:
181+
return 'token', time.time() + 100.0, oauth_config, {"extthree": "extthreeval"}
182+
raise Exception
183+
184+
conf = {'group.id': 'test',
185+
'security.protocol': 'sasl_plaintext',
186+
'sasl.mechanisms': 'OAUTHBEARER',
187+
'session.timeout.ms': 1000, # Avoid close() blocking too long
188+
'sasl.oauthbearer.config': 'oauth_cb',
189+
'oauth_cb': oauth_cb
190+
}
191+
192+
kc = confluent_kafka.Consumer(**conf)
193+
194+
while oauth_cb_count < 2:
195+
kc.poll(timeout=0.1)
159196
kc.close()
160197

161198

@@ -194,11 +231,9 @@ def test_unordered_dict(init_func):
194231
client.poll(0)
195232

196233

197-
# global variable for on_delivery call back function
198-
seen_delivery_cb = False
199-
200-
201234
def test_topic_config_update():
235+
seen_delivery_cb = False
236+
202237
# *NOTE* default.topic.config has been deprecated.
203238
# This example remains to ensure backward-compatibility until its removal.
204239
confs = [{"message.timeout.ms": 600000, "default.topic.config": {"message.timeout.ms": 1000}},
@@ -207,7 +242,7 @@ def test_topic_config_update():
207242

208243
def on_delivery(err, msg):
209244
# Since there is no broker, produced messages should time out.
210-
global seen_delivery_cb
245+
nonlocal seen_delivery_cb
211246
seen_delivery_cb = True
212247
assert err.code() == confluent_kafka.KafkaError._MSG_TIMED_OUT
213248

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy