Skip to content

Commit b50d895

Browse files
authored
Pass request to schema generation (encode#4383)
Pass request to schema generation
1 parent 3698d9e commit b50d895

File tree

2 files changed

+36
-31
lines changed

2 files changed

+36
-31
lines changed

rest_framework/schemas.py

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -65,44 +65,52 @@ def __init__(self, title=None, url=None, patterns=None, urlconf=None):
6565
urls = import_module(urlconf)
6666
else:
6767
urls = urlconf
68-
patterns = urls.urlpatterns
68+
self.patterns = urls.urlpatterns
6969
elif patterns is None and urlconf is None:
7070
urls = import_module(settings.ROOT_URLCONF)
71-
patterns = urls.urlpatterns
71+
self.patterns = urls.urlpatterns
72+
else:
73+
self.patterns = patterns
7274

7375
if url and not url.endswith('/'):
7476
url += '/'
7577

7678
self.title = title
7779
self.url = url
78-
self.endpoints = self.get_api_endpoints(patterns)
80+
self.endpoints = None
7981

8082
def get_schema(self, request=None):
81-
if request is None:
82-
endpoints = self.endpoints
83-
else:
84-
# Filter the list of endpoints to only include those that
85-
# the user has permission on.
86-
endpoints = []
87-
for key, link, callback in self.endpoints:
88-
method = link.action.upper()
89-
view = callback.cls()
83+
if self.endpoints is None:
84+
self.endpoints = self.get_api_endpoints(self.patterns)
85+
86+
links = []
87+
for key, path, method, callback in self.endpoints:
88+
view = callback.cls()
89+
for attr, val in getattr(callback, 'initkwargs', {}).items():
90+
setattr(view, attr, val)
91+
view.args = ()
92+
view.kwargs = {}
93+
view.format_kwarg = None
94+
95+
if request is not None:
9096
view.request = clone_request(request, method)
91-
view.format_kwarg = None
9297
try:
9398
view.check_permissions(view.request)
9499
except exceptions.APIException:
95-
pass
96-
else:
97-
endpoints.append((key, link, callback))
100+
continue
101+
else:
102+
view.request = None
103+
104+
link = self.get_link(path, method, callback, view)
105+
links.append((key, link))
98106

99-
if not endpoints:
107+
if not link:
100108
return None
101109

102110
# Generate the schema content structure, from the endpoints.
103111
# ('users', 'list'), Link -> {'users': {'list': Link()}}
104112
content = {}
105-
for key, link, callback in endpoints:
113+
for key, link in links:
106114
insert_into(content, key, link)
107115

108116
# Return the schema document.
@@ -122,8 +130,7 @@ def get_api_endpoints(self, patterns, prefix=''):
122130
if self.should_include_endpoint(path, callback):
123131
for method in self.get_allowed_methods(callback):
124132
key = self.get_key(path, method, callback)
125-
link = self.get_link(path, method, callback)
126-
endpoint = (key, link, callback)
133+
endpoint = (key, path, method, callback)
127134
api_endpoints.append(endpoint)
128135

129136
elif isinstance(pattern, RegexURLResolver):
@@ -190,14 +197,10 @@ def get_key(self, path, method, callback):
190197

191198
# Methods for generating each individual `Link` instance...
192199

193-
def get_link(self, path, method, callback):
200+
def get_link(self, path, method, callback, view):
194201
"""
195202
Return a `coreapi.Link` instance for the given endpoint.
196203
"""
197-
view = callback.cls()
198-
for attr, val in getattr(callback, 'initkwargs', {}).items():
199-
setattr(view, attr, val)
200-
201204
fields = self.get_path_fields(path, method, callback, view)
202205
fields += self.get_serializer_fields(path, method, callback, view)
203206
fields += self.get_pagination_fields(path, method, callback, view)
@@ -260,20 +263,18 @@ def get_serializer_fields(self, path, method, callback, view):
260263
if method not in ('PUT', 'PATCH', 'POST'):
261264
return []
262265

263-
if not hasattr(view, 'get_serializer_class'):
266+
if not hasattr(view, 'get_serializer'):
264267
return []
265268

266-
fields = []
267-
268-
serializer_class = view.get_serializer_class()
269-
serializer = serializer_class()
269+
serializer = view.get_serializer()
270270

271271
if isinstance(serializer, serializers.ListSerializer):
272-
return coreapi.Field(name='data', location='body', required=True)
272+
return [coreapi.Field(name='data', location='body', required=True)]
273273

274274
if not isinstance(serializer, serializers.Serializer):
275275
return []
276276

277+
fields = []
277278
for field in serializer.fields.values():
278279
if field.read_only:
279280
continue

tests/test_schemas.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ class ExampleViewSet(ModelViewSet):
4343
def custom_action(self, request, pk):
4444
return super(ExampleSerializer, self).retrieve(self, request)
4545

46+
def get_serializer(self, *args, **kwargs):
47+
assert self.request
48+
return super(ExampleViewSet, self).get_serializer(*args, **kwargs)
49+
4650

4751
class ExampleView(APIView):
4852
permission_classes = [permissions.IsAuthenticatedOrReadOnly]

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy