Skip to content

Commit 83f0810

Browse files
committed
Rebase on #7127 and remove _get_item_schema refactoring
1 parent 67549c3 commit 83f0810

File tree

3 files changed

+102
-15
lines changed

3 files changed

+102
-15
lines changed

rest_framework/schemas/openapi.py

Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import warnings
2+
from enum import Enum
23
from operator import attrgetter
34
from urllib.parse import urljoin
45

@@ -37,16 +38,21 @@ def get_schema(self, request=None, public=False):
3738
Generate a OpenAPI schema.
3839
"""
3940
self._initialise_endpoints()
41+
components_schemas = {}
4042

4143
# Iterate endpoints generating per method path operations.
42-
# TODO: …and reference components.
4344
paths = {}
4445
_, view_endpoints = self._get_paths_and_endpoints(None if public else request)
4546
for path, method, view in view_endpoints:
4647
if not self.has_view_permissions(path, method, view):
4748
continue
4849

4950
operation = view.schema.get_operation(path, method)
51+
component = view.schema.get_components(path, method)
52+
53+
if component is not None:
54+
components_schemas.update(component)
55+
5056
# Normalise path for any provided mount url.
5157
if path.startswith('/'):
5258
path = path[1:]
@@ -59,9 +65,14 @@ def get_schema(self, request=None, public=False):
5965
schema = {
6066
'openapi': '3.0.2',
6167
'info': self.get_info(),
62-
'paths': paths,
68+
'paths': paths
6369
}
6470

71+
if len(components_schemas) > 0:
72+
schema['components'] = {
73+
'schemas': components_schemas
74+
}
75+
6576
return schema
6677

6778
# View Inspectors
@@ -99,6 +110,21 @@ def get_operation(self, path, method):
99110

100111
return operation
101112

113+
def get_components(self, path, method):
114+
serializer = self._get_serializer(path, method)
115+
116+
if not isinstance(serializer, serializers.Serializer):
117+
return None
118+
119+
# If the model has no model, then the serializer will be inlined
120+
if not hasattr(serializer, 'Meta') or not hasattr(serializer.Meta, 'model'):
121+
return None
122+
123+
model_name = serializer.Meta.model.__name__
124+
content = self._map_serializer(serializer)
125+
126+
return {model_name: content}
127+
102128
def _get_operation_id(self, path, method):
103129
"""
104130
Compute an operation ID from the model, serializer or view name.
@@ -470,6 +496,10 @@ def _get_serializer(self, method, path):
470496
.format(view.__class__.__name__, method, path))
471497
return None
472498

499+
def _get_reference(self, serializer):
500+
model_name = serializer.Meta.model.__name__
501+
return {'$ref': '#/components/schemas/{}'.format(model_name)}
502+
473503
def _get_request_body(self, path, method):
474504
if method not in ('PUT', 'PATCH', 'POST'):
475505
return {}
@@ -479,20 +509,30 @@ def _get_request_body(self, path, method):
479509
serializer = self._get_serializer(path, method)
480510

481511
if not isinstance(serializer, serializers.Serializer):
482-
return {}
483-
484-
content = self._map_serializer(serializer)
485-
# No required fields for PATCH
486-
if method == 'PATCH':
487-
content.pop('required', None)
488-
# No read_only fields for request.
489-
for name, schema in content['properties'].copy().items():
490-
if 'readOnly' in schema:
491-
del content['properties'][name]
512+
item_schema = {}
513+
elif hasattr(serializer, 'Meta') and hasattr(serializer.Meta, 'model'):
514+
# If the serializer uses a model, we should use a reference
515+
item_schema = self._get_reference(serializer)
516+
else:
517+
# There is no model, we'll map the serializer's fields
518+
item_schema = self._map_serializer(serializer)
519+
# No required fields for PATCH
520+
if method == 'PATCH':
521+
item_schema.pop('required', None)
522+
# No read_only fields for request.
523+
# No write_only fields for response.
524+
for name, schema in item_schema['properties'].copy().items():
525+
if 'writeOnly' in schema:
526+
del item_schema['properties'][name]
527+
if 'required' in item_schema:
528+
item_schema['required'] = [f for f in item_schema['required'] if f != name]
529+
for name, schema in item_schema['properties'].copy().items():
530+
if 'readOnly' in schema:
531+
del item_schema['properties'][name]
492532

493533
return {
494534
'content': {
495-
ct: {'schema': content}
535+
ct: {'schema': item_schema}
496536
for ct in self.request_media_types
497537
}
498538
}
@@ -508,10 +548,15 @@ def _get_responses(self, path, method):
508548

509549
self.response_media_types = self.map_renderers(path, method)
510550

511-
item_schema = {}
512551
serializer = self._get_serializer(path, method)
513552

514-
if isinstance(serializer, serializers.Serializer):
553+
if not isinstance(serializer, serializers.Serializer):
554+
item_schema = {}
555+
elif hasattr(serializer, 'Meta') and hasattr(serializer.Meta, 'model'):
556+
# If the serializer uses a model, we should use a reference
557+
item_schema = self._get_reference(serializer)
558+
else:
559+
# There is no model, we'll map the serializer's fields
515560
item_schema = self._map_serializer(serializer)
516561
# No write_only fields for response.
517562
for name, schema in item_schema['properties'].copy().items():

tests/schemas/test_openapi.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -742,3 +742,18 @@ def test_schema_information_empty(self):
742742

743743
assert schema['info']['title'] == ''
744744
assert schema['info']['version'] == ''
745+
746+
def test_serializer_model(self):
747+
"""Construction of the top level dictionary."""
748+
patterns = [
749+
url(r'^example/?$', views.ExampleGenericAPIViewModel.as_view()),
750+
]
751+
generator = SchemaGenerator(patterns=patterns)
752+
753+
request = create_request('/')
754+
schema = generator.get_schema(request=request)
755+
756+
print(schema)
757+
assert 'components' in schema
758+
assert 'schemas' in schema['components']
759+
assert 'OpenAPIExample' in schema['components']['schemas']

tests/schemas/views.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
DecimalValidator, MaxLengthValidator, MaxValueValidator,
55
MinLengthValidator, MinValueValidator, RegexValidator
66
)
7+
from django.db import models
78

89
from rest_framework import generics, permissions, serializers
910
from rest_framework.decorators import action
@@ -137,3 +138,29 @@ def get(self, *args, **kwargs):
137138
url='http://localhost', uuid=uuid.uuid4(), ip4='127.0.0.1', ip6='::1',
138139
ip='192.168.1.1')
139140
return Response(serializer.data)
141+
142+
143+
# Serializer with model.
144+
class OpenAPIExample(models.Model):
145+
first_name = models.CharField(max_length=30)
146+
147+
148+
class ExampleSerializerModel(serializers.Serializer):
149+
date = serializers.DateField()
150+
datetime = serializers.DateTimeField()
151+
hstore = serializers.HStoreField()
152+
uuid_field = serializers.UUIDField(default=uuid.uuid4)
153+
154+
class Meta:
155+
model = OpenAPIExample
156+
157+
158+
class ExampleGenericAPIViewModel(generics.GenericAPIView):
159+
serializer_class = ExampleSerializerModel
160+
161+
def get(self, *args, **kwargs):
162+
from datetime import datetime
163+
now = datetime.now()
164+
165+
serializer = self.get_serializer(data=now.date(), datetime=now)
166+
return Response(serializer.data)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy