@@ -65,44 +65,52 @@ def __init__(self, title=None, url=None, patterns=None, urlconf=None):
65
65
urls = import_module (urlconf )
66
66
else :
67
67
urls = urlconf
68
- patterns = urls .urlpatterns
68
+ self . patterns = urls .urlpatterns
69
69
elif patterns is None and urlconf is None :
70
70
urls = import_module (settings .ROOT_URLCONF )
71
- patterns = urls .urlpatterns
71
+ self .patterns = urls .urlpatterns
72
+ else :
73
+ self .patterns = patterns
72
74
73
75
if url and not url .endswith ('/' ):
74
76
url += '/'
75
77
76
78
self .title = title
77
79
self .url = url
78
- self .endpoints = self . get_api_endpoints ( patterns )
80
+ self .endpoints = None
79
81
80
82
def get_schema (self , request = None ):
81
- if request is None :
82
- endpoints = self .endpoints
83
- else :
84
- # Filter the list of endpoints to only include those that
85
- # the user has permission on.
86
- endpoints = []
87
- for key , link , callback in self .endpoints :
88
- method = link .action .upper ()
89
- view = callback .cls ()
83
+ if self .endpoints is None :
84
+ self .endpoints = self .get_api_endpoints (self .patterns )
85
+
86
+ links = []
87
+ for key , path , method , callback in self .endpoints :
88
+ view = callback .cls ()
89
+ for attr , val in getattr (callback , 'initkwargs' , {}).items ():
90
+ setattr (view , attr , val )
91
+ view .args = ()
92
+ view .kwargs = {}
93
+ view .format_kwarg = None
94
+
95
+ if request is not None :
90
96
view .request = clone_request (request , method )
91
- view .format_kwarg = None
92
97
try :
93
98
view .check_permissions (view .request )
94
99
except exceptions .APIException :
95
- pass
96
- else :
97
- endpoints .append ((key , link , callback ))
100
+ continue
101
+ else :
102
+ view .request = None
103
+
104
+ link = self .get_link (path , method , callback , view )
105
+ links .append ((key , link ))
98
106
99
- if not endpoints :
107
+ if not link :
100
108
return None
101
109
102
110
# Generate the schema content structure, from the endpoints.
103
111
# ('users', 'list'), Link -> {'users': {'list': Link()}}
104
112
content = {}
105
- for key , link , callback in endpoints :
113
+ for key , link in links :
106
114
insert_into (content , key , link )
107
115
108
116
# Return the schema document.
@@ -122,8 +130,7 @@ def get_api_endpoints(self, patterns, prefix=''):
122
130
if self .should_include_endpoint (path , callback ):
123
131
for method in self .get_allowed_methods (callback ):
124
132
key = self .get_key (path , method , callback )
125
- link = self .get_link (path , method , callback )
126
- endpoint = (key , link , callback )
133
+ endpoint = (key , path , method , callback )
127
134
api_endpoints .append (endpoint )
128
135
129
136
elif isinstance (pattern , RegexURLResolver ):
@@ -190,14 +197,10 @@ def get_key(self, path, method, callback):
190
197
191
198
# Methods for generating each individual `Link` instance...
192
199
193
- def get_link (self , path , method , callback ):
200
+ def get_link (self , path , method , callback , view ):
194
201
"""
195
202
Return a `coreapi.Link` instance for the given endpoint.
196
203
"""
197
- view = callback .cls ()
198
- for attr , val in getattr (callback , 'initkwargs' , {}).items ():
199
- setattr (view , attr , val )
200
-
201
204
fields = self .get_path_fields (path , method , callback , view )
202
205
fields += self .get_serializer_fields (path , method , callback , view )
203
206
fields += self .get_pagination_fields (path , method , callback , view )
@@ -260,20 +263,18 @@ def get_serializer_fields(self, path, method, callback, view):
260
263
if method not in ('PUT' , 'PATCH' , 'POST' ):
261
264
return []
262
265
263
- if not hasattr (view , 'get_serializer_class ' ):
266
+ if not hasattr (view , 'get_serializer ' ):
264
267
return []
265
268
266
- fields = []
267
-
268
- serializer_class = view .get_serializer_class ()
269
- serializer = serializer_class ()
269
+ serializer = view .get_serializer ()
270
270
271
271
if isinstance (serializer , serializers .ListSerializer ):
272
- return coreapi .Field (name = 'data' , location = 'body' , required = True )
272
+ return [ coreapi .Field (name = 'data' , location = 'body' , required = True )]
273
273
274
274
if not isinstance (serializer , serializers .Serializer ):
275
275
return []
276
276
277
+ fields = []
277
278
for field in serializer .fields .values ():
278
279
if field .read_only :
279
280
continue
0 commit comments