Skip to content

Commit b7f8dce

Browse files
authored
Add callback for oauth (@stevenylai, confluentinc#960)
1 parent 8060bd0 commit b7f8dce

File tree

7 files changed

+244
-0
lines changed

7 files changed

+244
-0
lines changed

docs/index.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,14 @@ The Python bindings also provide some additional configuration properties:
459459
This callback is served upon calling ``client.poll()`` or ``producer.flush()``. See
460460
https://github.com/edenhill/librdkafka/wiki/Statistics for more information.
461461

462+
* ``oauth_cb(config_str)``: Callback for retrieving OAuth Bearer token.
463+
Function argument ``config_str`` is the string value of the ``sasl.oauthbearer.config`` property.
464+
Return value of this callback is expected to be ``(token_str, expiry_time)`` tuple
465+
where ``expiry_time`` is the time in seconds since the epoch as a floating point number.
466+
This callback is useful only when ``sasl.mechanisms=OAUTHBEARER`` is set and
467+
is served to get the initial token before a successful broker connection can be made.
468+
The callback can be triggered by calling ``client.poll()`` or ``producer.flush()``.
469+
462470
* ``on_delivery(kafka.KafkaError, kafka.Message)`` (**Producer**): value is a Python function reference
463471
that is called once for each produced message to indicate the final
464472
delivery result (success or failure).

examples/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ The scripts in this directory provide code examples using Confluent's Python cli
1414
* [protobuf_consumer.py](protobuf_consumer.py): DeserializingConsumer with ProtobufDeserializer
1515
* [sasl_producer.py](sasl_producer.py): SerializingProducer with SASL Authentication
1616
* [list_offsets.py](list_offsets.py): List committed offsets and consumer lag for group and topics
17+
* [oauth_producer.py](oauth_producer.py): SerializingProducer with OAuth Authentication (client credentials)
1718

1819
Additional examples for [Confluent Cloud](https://www.confluent.io/confluent-cloud/):
1920

examples/oauth_producer.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright 2020 Confluent Inc.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
19+
#
20+
# This uses OAuth client credentials grant:
21+
# https://www.oauth.com/oauth2-servers/access-tokens/client-credentials/
22+
# where client_id and client_secret are passed as HTTP Authorization header
23+
#
24+
25+
import logging
26+
import functools
27+
import argparse
28+
import time
29+
from confluent_kafka import SerializingProducer
30+
from confluent_kafka.serialization import StringSerializer
31+
import requests
32+
33+
34+
def _get_token(args, config):
    """Fetch an OAuth token using the client-credentials grant.

    Args:
        args: Parsed command-line arguments providing ``token_url``,
            ``client_id``, ``client_secret`` and ``scopes``.
        config (str): Value of the ``sasl.oauthbearer.config`` property.
            It is not used in this example but you can put arbitrary values
            there to configure how you get the token (e.g. which token URL
            to use).

    Returns:
        tuple: ``(token_str, expiry_time)`` where ``expiry_time`` is the
        absolute expiry in seconds since the epoch, as required by the
        ``oauth_cb`` contract.
    """
    payload = {
        'grant_type': 'client_credentials',
        'scope': ' '.join(args.scopes)
    }
    resp = requests.post(args.token_url,
                         auth=(args.client_id, args.client_secret),
                         data=payload)
    # Fail fast with a descriptive HTTP error instead of an opaque
    # KeyError on 'access_token' when the token endpoint rejects us.
    resp.raise_for_status()
    token = resp.json()
    return token['access_token'], time.time() + float(token['expires_in'])
48+
49+
50+
def producer_config(args):
    """Build the SerializingProducer configuration for OAUTHBEARER auth.

    The user-supplied token retriever is bound to the command-line
    arguments via functools.partial so it matches the oauth_cb signature.
    """
    logger = logging.getLogger(__name__)
    conf = {
        'bootstrap.servers': args.bootstrap_servers,
        'key.serializer': StringSerializer('utf_8'),
        'value.serializer': StringSerializer('utf_8'),
        'security.protocol': 'sasl_plaintext',
        'sasl.mechanisms': 'OAUTHBEARER',
        # sasl.oauthbearer.config can be used to pass argument to your oauth_cb
        # It is not used in this example since we are passing all the arguments
        # from command line
        # 'sasl.oauthbearer.config': 'not-used',
        'oauth_cb': functools.partial(_get_token, args),
        'logger': logger,
    }
    return conf
65+
66+
67+
def delivery_report(err, msg):
    """
    Reports the failure or success of a message delivery.

    Args:
        err (KafkaError): The error that occurred, or None on success.

        msg (Message): The message that was produced or failed.

    Note:
        In the delivery report callback the Message.key() and Message.value()
        will be the binary format as encoded by any configured Serializers and
        not the same object that was passed to produce().
        If you wish to pass the original object(s) for key and value to delivery
        report callback we recommend a bound callback or lambda where you pass
        the objects along.

    """
    if err is None:
        print('User record {} successfully produced to {} [{}] at offset {}'.format(
            msg.key(), msg.topic(), msg.partition(), msg.offset()))
    else:
        print('Delivery failed for User record {}: {}'.format(msg.key(), err))
90+
91+
92+
def main(args):
    """Interactively read key/value lines from stdin and produce them.

    Loops until the user hits ^C, then flushes outstanding deliveries.
    """
    topic = args.topic
    delimiter = args.delimiter

    producer_conf = producer_config(args)

    producer = SerializingProducer(producer_conf)

    print('Producing records to topic {}. ^C to exit.'.format(topic))
    interrupted = False
    while not interrupted:
        # Serve on_delivery callbacks from previous calls to produce()
        producer.poll(0.0)
        try:
            line = input(">")
            fields = line.split(delimiter)
            if len(fields) == 2:
                # Two fields: first is the record key, second the value.
                producer.produce(topic=topic, key=fields[0], value=fields[1],
                                 on_delivery=delivery_report)
            else:
                # No delimiter: treat the whole line as the value.
                producer.produce(topic=topic, value=fields[0],
                                 on_delivery=delivery_report)
        except KeyboardInterrupt:
            interrupted = True

    print('\nFlushing {} records...'.format(len(producer)))
    producer.flush()
118+
119+
120+
if __name__ == '__main__':
    # Command-line interface for the OAuth client-credentials producer example.
    arg_parser = argparse.ArgumentParser(
        description="SerializingProducer OAUTH Example"
                    " with client credentials grant")
    arg_parser.add_argument('-b', dest="bootstrap_servers", required=True,
                            help="Bootstrap broker(s) (host[:port])")
    arg_parser.add_argument('-t', dest="topic", default="example_producer_oauth",
                            help="Topic name")
    arg_parser.add_argument('-d', dest="delimiter", default="|",
                            help="Key-Value delimiter. Defaults to '|'")
    arg_parser.add_argument('--client', dest="client_id", required=True,
                            help="Client ID for client credentials flow")
    arg_parser.add_argument('--secret', dest="client_secret", required=True,
                            help="Client secret for client credentials flow.")
    arg_parser.add_argument('--token-url', dest="token_url", required=True,
                            help="Token URL.")
    arg_parser.add_argument('--scopes', dest="scopes", required=True, nargs='+',
                            help="Scopes requested from OAuth server.")

    main(arg_parser.parse_args())

examples/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ pyrsistent==0.16.1;python_version<"3.0"
77
pyrsistent;python_version>"3.0"
88
jsonschema
99
protobuf
10+
requests

src/confluent_kafka/src/confluent_kafka.c

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,6 +1521,50 @@ static void log_cb (const rd_kafka_t *rk, int level,
15211521
CallState_resume(cs);
15221522
}
15231523

1524+
/**
 * @brief librdkafka OAUTHBEARER token refresh callback.
 *
 * Bridges librdkafka's token-refresh request to the user-supplied Python
 * \c oauth_cb: calls \c h->oauth_cb(oauthbearer_config) and expects a
 * \c (token_str, expiry_time) tuple in return, then installs the token
 * with rd_kafka_oauthbearer_set_token().
 *
 * On any failure (Python exception from the callback, non-tuple return,
 * or set_token error) the CallState is crashed and rd_kafka_yield() is
 * called so the error propagates back to the polling caller.
 */
static void oauth_cb (rd_kafka_t *rk, const char *oauthbearer_config,
                      void *opaque) {
        Handle *h = opaque;
        PyObject *eo, *result;
        CallState *cs;
        const char *token;
        double expiry;
        char err_msg[2048];
        rd_kafka_resp_err_t err_code;

        /* Re-acquire the GIL for the duration of the Python callback. */
        cs = CallState_get(h);

        eo = Py_BuildValue("s", oauthbearer_config);
        result = PyObject_CallFunctionObjArgs(h->oauth_cb, eo, NULL);
        Py_DECREF(eo);

        if (!result) {
                /* Callback raised: its exception is already set on cs. */
                goto err;
        }
        if (!PyArg_ParseTuple(result, "sd", &token, &expiry)) {
                Py_DECREF(result);
                PyErr_Format(PyExc_TypeError,
                             "expect returned value from oauth_cb "
                             "to be (token_str, expiry_time) tuple");
                goto err;
        }
        /* expiry is in seconds since the epoch; librdkafka expects
         * milliseconds. `token` borrows storage from `result`, so the
         * set_token call must happen before Py_DECREF(result). */
        err_code = rd_kafka_oauthbearer_set_token(h->rk, token,
                                                  (int64_t)(expiry * 1000),
                                                  "", NULL, 0, err_msg,
                                                  sizeof(err_msg));
        Py_DECREF(result);
        if (err_code) {
                PyErr_Format(PyExc_ValueError, "%s", err_msg);
                goto err;
        }
        goto done;

 err:
        CallState_crash(cs);
        rd_kafka_yield(h->rk);
 done:
        /* Release the GIL / restore the thread state. */
        CallState_resume(cs);
}
1567+
15241568
/****************************************************************************
15251569
*
15261570
*
@@ -1949,6 +1993,25 @@ rd_kafka_conf_t *common_conf_setup (rd_kafka_type_t ktype,
19491993
Py_XDECREF(ks8);
19501994
Py_DECREF(ks);
19511995
continue;
1996+
} else if (!strcmp(k, "oauth_cb")) {
1997+
if (!PyCallable_Check(vo)) {
1998+
PyErr_SetString(PyExc_TypeError,
1999+
"expected oauth_cb property "
2000+
"as a callable function");
2001+
goto inner_err;
2002+
}
2003+
if (h->oauth_cb) {
2004+
Py_DECREF(h->oauth_cb);
2005+
h->oauth_cb = NULL;
2006+
}
2007+
2008+
if (vo != Py_None) {
2009+
h->oauth_cb = vo;
2010+
Py_INCREF(h->oauth_cb);
2011+
}
2012+
Py_XDECREF(ks8);
2013+
Py_DECREF(ks);
2014+
continue;
19522015
}
19532016

19542017
/* Special handling for certain config keys. */
@@ -2019,6 +2082,9 @@ rd_kafka_conf_t *common_conf_setup (rd_kafka_type_t ktype,
20192082
rd_kafka_conf_set_log_cb(conf, log_cb);
20202083
}
20212084

2085+
if (h->oauth_cb)
2086+
rd_kafka_conf_set_oauthbearer_token_refresh_cb(conf, oauth_cb);
2087+
20222088
rd_kafka_conf_set_opaque(conf, h);
20232089

20242090
#ifdef WITH_PY_TSS

src/confluent_kafka/src/confluent_kafka.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ typedef struct {
236236
rd_kafka_type_t type; /* Producer or consumer */
237237

238238
PyObject *logger;
239+
PyObject *oauth_cb;
239240

240241
union {
241242
/**

tests/test_misc.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,35 @@ def test_throttle_event_types():
130130
assert str(throttle_event) == "broker/0 throttled for 10000 ms"
131131

132132

133+
# Module-level flag flipped by the oauth_cb callback so the test can
# observe that it was invoked.
seen_oauth_cb = False


def test_oauth_cb():
    """ Tests oauth_cb. """

    def oauth_cb(oauth_config):
        global seen_oauth_cb
        seen_oauth_cb = True
        # Value comes straight from sasl.oauthbearer.config below.
        assert oauth_config == 'oauth_cb'
        return 'token', time.time() + 300.0

    conf = {
        'group.id': 'test',
        'security.protocol': 'sasl_plaintext',
        'sasl.mechanisms': 'OAUTHBEARER',
        'socket.timeout.ms': '100',
        # Avoid close() blocking too long
        'session.timeout.ms': 1000,
        'sasl.oauthbearer.config': 'oauth_cb',
        'oauth_cb': oauth_cb,
    }

    kc = confluent_kafka.Consumer(**conf)

    # Poll until the broker connection attempt triggers the token callback.
    while not seen_oauth_cb:
        kc.poll(timeout=1)
    kc.close()
160+
161+
133162
def skip_interceptors():
134163
# Run interceptor test if monitoring-interceptor is found
135164
for path in ["/usr/lib", "/usr/local/lib", "staging/libs", "."]:

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy