From e6bbae30a8f931b0ec8d881d03c916902777a506 Mon Sep 17 00:00:00 2001 From: Alan Crosswell Date: Thu, 3 Sep 2020 20:38:25 -0400 Subject: [PATCH] add OAS securitySchemes and security objects --- docs/api-guide/schemas.md | 20 +++++++ rest_framework/authentication.py | 93 +++++++++++++++++++++++++++++++ rest_framework/schemas/openapi.py | 58 +++++++++++++++++++ tests/schemas/test_openapi.py | 49 +++++++++++++++- 4 files changed, 219 insertions(+), 1 deletion(-) diff --git a/docs/api-guide/schemas.md b/docs/api-guide/schemas.md index b9de6745fe..937eb4d425 100644 --- a/docs/api-guide/schemas.md +++ b/docs/api-guide/schemas.md @@ -389,6 +389,26 @@ differentiate between request and response objects. By default returns `get_serializer()` but can be overridden to differentiate between request and response objects. +#### `get_security_schemes()` + +Generates the OpenAPI `securitySchemes` components based on: +- Your default `authentication_classes` (`settings.DEFAULT_AUTHENTICATION_CLASSES`) +- Per-view non-default `authentication_classes` + +These are generated using the authentication classes' `openapi_security_scheme()` class method. If you +extend `BaseAuthentication` with your own authentication class, you can add this class method to return +the appropriate security scheme object. + +#### `get_security_requirements()` + +Root-level security requirements (the top-level `security` object) are generated based on the +default authentication classes. Operation-level security requirements are generated only if the given view's +`authentication_classes` differ from the defaults. + +These are generated using the authentication classes' `openapi_security_requirement()` class +method. If you extended `BaseAuthentication` with your own authentication class, you can add this +class method to return the appropriate list of security requirements objects. + ### `AutoSchema.__init__()` kwargs `AutoSchema` provides a number of `__init__()` kwargs that can be used for diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 382abf1580..ffff1f4567 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -49,6 +49,32 @@ def authenticate_header(self, request): """ pass + #: Name of openapi security scheme. Override if you want to customize it. + openapi_security_scheme_name = None + + @classmethod + def openapi_security_scheme(cls): + """ + Override this to return an Open API Specification `securityScheme object + `_ + """ + return {} + + @classmethod + def openapi_security_requirement(cls, view, method): + """ + Override this to return an Open API Specification `security requirement object + `_ + + :param view: used to find view attributes used by a permission class or None for root-level + :param method: used to distinguish among method-specific permissions or None for root-level + :return:list: [security requirement objects] + """ + # At this point, none of the built-in DRF authentication classes fill in the + # requirement list: OAuth2/OIDC are the only security types that currently uses the list + # (for scopes). See http://spec.openapis.org/oas/v3.0.3#patterned-fields-2. + return [{}] + class BasicAuthentication(BaseAuthentication): """ @@ -108,6 +134,22 @@ def authenticate_credentials(self, userid, password, request=None): def authenticate_header(self, request): return 'Basic realm="%s"' % self.www_authenticate_realm + openapi_security_scheme_name = 'basicAuth' + + @classmethod + def openapi_security_scheme(cls): + return { + cls.openapi_security_scheme_name: { + 'type': 'http', + 'scheme': 'basic', + 'description': 'Basic Authentication' + } + } + + @classmethod + def openapi_security_requirement(cls, view, method): + return [{cls.openapi_security_scheme_name: []}] + class SessionAuthentication(BaseAuthentication): """ @@ -147,6 +189,23 @@ def dummy_get_response(request): # pragma: no cover # CSRF failed, bail with explicit error message raise exceptions.PermissionDenied('CSRF Failed: %s' % reason) + openapi_security_scheme_name = 'sessionAuth' + + @classmethod + def openapi_security_scheme(cls): + return { + cls.openapi_security_scheme_name: { + 'type': 'apiKey', + 'in': 'cookie', + 'name': 'JSESSIONID', + 'description': 'Session authentication' + } + } + + @classmethod + def openapi_security_requirement(cls, view, method): + return [{cls.openapi_security_scheme_name: []}] + class TokenAuthentication(BaseAuthentication): """ @@ -210,6 +269,23 @@ def authenticate_credentials(self, key): def authenticate_header(self, request): return self.keyword + openapi_security_scheme_name = 'tokenAuth' + + @classmethod + def openapi_security_scheme(cls): + return { + cls.openapi_security_scheme_name: { + 'type': 'http', + 'in': 'header', + 'name': 'Authorization', # Authorization: token ... + 'description': 'Token authentication' + } + } + + @classmethod + def openapi_security_requirement(cls, view, method): + return [{cls.openapi_security_scheme_name: []}] + class RemoteUserAuthentication(BaseAuthentication): """ @@ -230,3 +306,20 @@ def authenticate(self, request): user = authenticate(request=request, remote_user=request.META.get(self.header)) if user and user.is_active: return (user, None) + + openapi_security_scheme_name = 'remoteUserAuth' + + @classmethod + def openapi_security_scheme(cls): + return { + cls.openapi_security_scheme_name: { + 'type': 'http', + 'in': 'header', + 'name': 'REMOTE_USER', + 'description': 'Remote User authentication' + } + } + + @classmethod + def openapi_security_requirement(cls, view, method): + return [{cls.openapi_security_scheme_name: []}] diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py index 5e9d59f8bf..481fe4c42b 100644 --- a/rest_framework/schemas/openapi.py +++ b/rest_framework/schemas/openapi.py @@ -70,6 +70,14 @@ def get_schema(self, request=None, public=False): """ self._initialise_endpoints() components_schemas = {} + security_schemes_schemas = {} + root_security_requirements = [] + + if api_settings.DEFAULT_AUTHENTICATION_CLASSES: + for auth_class in api_settings.DEFAULT_AUTHENTICATION_CLASSES: + req = auth_class.openapi_security_requirement(None, None) + if req: + root_security_requirements += req # Iterate endpoints generating per method path operations. paths = {} @@ -80,6 +88,7 @@ def get_schema(self, request=None, public=False): operation = view.schema.get_operation(path, method) components = view.schema.get_components(path, method) + for k in components.keys(): if k not in components_schemas: continue @@ -89,6 +98,16 @@ def get_schema(self, request=None, public=False): components_schemas.update(components) + security_schemes = view.schema.get_security_schemes(path, method) + for k in security_schemes.keys(): + if k not in security_schemes_schemas: + continue + if security_schemes_schemas[k] == security_schemes[k]: + continue + warnings.warn('Security scheme component "{}" has been overriden with a different ' + 'value.'.format(k)) + security_schemes_schemas.update(security_schemes) + # Normalise path for any provided mount url. if path.startswith('/'): path = path[1:] @@ -111,6 +130,14 @@ def get_schema(self, request=None, public=False): 'schemas': components_schemas } + if len(security_schemes_schemas) > 0: + if 'components' not in schema: + schema['components'] = {} + schema['components']['securitySchemes'] = security_schemes_schemas + + if len(root_security_requirements) > 0: + schema['security'] = root_security_requirements + return schema # View Inspectors @@ -146,6 +173,9 @@ def get_operation(self, path, method): operation['operationId'] = self.get_operation_id(path, method) operation['description'] = self.get_description(path, method) + security = self.get_security_requirements(path, method) + if security is not None: + operation['security'] = security parameters = [] parameters += self.get_path_parameters(path, method) @@ -713,6 +743,34 @@ def get_tags(self, path, method): return [path.split('/')[0].replace('_', '-')] + def get_security_schemes(self, path, method): + """ + Get components.schemas.securitySchemes required by this path. + returns dict of securitySchemes. + """ + schemes = {} + for auth_class in self.view.authentication_classes: + if hasattr(auth_class, 'openapi_security_scheme'): + schemes.update(auth_class.openapi_security_scheme()) + return schemes + + def get_security_requirements(self, path, method): + """ + Get Security Requirement Object list for this operation. + Returns a list of security requirement objects based on the view's authentication classes + unless this view's authentication classes are the same as the root-level defaults. + """ + # references the securityScheme names described above in get_security_schemes() + security = [] + if self.view.authentication_classes == api_settings.DEFAULT_AUTHENTICATION_CLASSES: + return None + for auth_class in self.view.authentication_classes: + if hasattr(auth_class, 'openapi_security_requirement'): + req = auth_class.openapi_security_requirement(self.view, method) + if req: + security += req + return security + def _get_path_parameters(self, path, method): warnings.warn( "Method `_get_path_parameters()` has been renamed to `get_path_parameters()`. " diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py index daa035a3f3..c62ae1a081 100644 --- a/tests/schemas/test_openapi.py +++ b/tests/schemas/test_openapi.py @@ -8,6 +8,7 @@ from django.utils.translation import gettext_lazy as _ from rest_framework import filters, generics, pagination, routers, serializers +from rest_framework.authentication import TokenAuthentication from rest_framework.authtoken.views import obtain_auth_token from rest_framework.compat import uritemplate from rest_framework.parsers import JSONParser, MultiPartParser @@ -1235,5 +1236,51 @@ class ExampleView(generics.DestroyAPIView): ] generator = SchemaGenerator(patterns=url_patterns) schema = generator.get_schema(request=create_request('/')) - assert 'components' not in schema + assert 'schemas' not in schema['components'] assert 'content' not in schema['paths']['/example/']['delete']['responses']['204'] + + def test_default_root_security_schemes(self): + patterns = [ + path('^example/?$', views.ExampleAutoSchemaComponentName.as_view()), + ] + + generator = SchemaGenerator(patterns=patterns) + + request = create_request('/') + schema = generator.get_schema(request=request) + assert 'security' in schema + assert {'sessionAuth': []} in schema['security'] + assert {'basicAuth': []} in schema['security'] + assert 'security' not in schema['paths']['/example/']['get'] + + @override_settings(REST_FRAMEWORK={'DEFAULT_AUTHENTICATION_CLASSES': None}) + def test_no_default_root_security_schemes(self): + patterns = [ + path('^example/?$', views.ExampleAutoSchemaComponentName.as_view()), + ] + + generator = SchemaGenerator(patterns=patterns) + + request = create_request('/') + schema = generator.get_schema(request=request) + assert 'security' not in schema + + def test_operation_security_schemes(self): + class MyExample(views.ExampleAutoSchemaComponentName): + authentication_classes = [TokenAuthentication] + + patterns = [ + path('^example/?$', MyExample.as_view()), + ] + + generator = SchemaGenerator(patterns=patterns) + + request = create_request('/') + schema = generator.get_schema(request=request) + assert 'security' in schema + assert {'sessionAuth': []} in schema['security'] + assert {'basicAuth': []} in schema['security'] + get_operation = schema['paths']['/example/']['get'] + assert 'security' in get_operation + assert {'tokenAuth': []} in get_operation['security'] + assert len(get_operation['security']) == 1 pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy