📦 encode / django-rest-framework

📄 test_decorators.py · 457 lines
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457import sys

import pytest
from django.test import TestCase

from rest_framework import status
from rest_framework.authentication import BasicAuthentication
from rest_framework.decorators import (
    action, api_view, authentication_classes, content_negotiation_class,
    metadata_class, parser_classes, permission_classes, renderer_classes,
    schema, throttle_classes, versioning_class
)
from rest_framework.negotiation import BaseContentNegotiation
from rest_framework.parsers import JSONParser
from rest_framework.permissions import IsAuthenticated
from rest_framework.renderers import JSONRenderer
from rest_framework.response import Response
from rest_framework.schemas import AutoSchema
from rest_framework.test import APIRequestFactory
from rest_framework.throttling import UserRateThrottle
from rest_framework.versioning import QueryParameterVersioning
from rest_framework.views import APIView


class DecoratorTestCase(TestCase):
    """
    Tests for the function-based-view decorators.

    Covers ``@api_view`` itself, the policy decorators stacked beneath it
    (renderers, parsers, authentication, permissions, throttling, versioning,
    metadata, content negotiation, schema), and the ``TypeError`` raised when
    a policy decorator is placed above ``@api_view`` instead of below it.
    """

    def setUp(self):
        self.factory = APIRequestFactory()

    def test_api_view_incorrect(self):
        """
        If @api_view is not applied correctly (used without parentheses),
        we should raise an assertion when the resulting view is called.
        """

        @api_view
        def view(request):
            return Response()

        request = self.factory.get('/')
        self.assertRaises(AssertionError, view, request)

    def test_api_view_incorrect_arguments(self):
        """
        If @api_view is given a single string instead of a list of HTTP
        method names, we should raise an assertion.
        """

        with self.assertRaises(AssertionError):
            @api_view('GET')
            def view(request):
                return Response()

    def test_calling_method(self):

        @api_view(['GET'])
        def view(request):
            return Response({})

        request = self.factory.get('/')
        response = view(request)
        assert response.status_code == status.HTTP_200_OK

        request = self.factory.post('/')
        response = view(request)
        assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED

    def test_calling_put_method(self):

        @api_view(['GET', 'PUT'])
        def view(request):
            return Response({})

        request = self.factory.put('/')
        response = view(request)
        assert response.status_code == status.HTTP_200_OK

        request = self.factory.post('/')
        response = view(request)
        assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED

    def test_calling_patch_method(self):

        @api_view(['GET', 'PATCH'])
        def view(request):
            return Response({})

        request = self.factory.patch('/')
        response = view(request)
        assert response.status_code == status.HTTP_200_OK

        request = self.factory.post('/')
        response = view(request)
        assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED

    def test_renderer_classes(self):

        @api_view(['GET'])
        @renderer_classes([JSONRenderer])
        def view(request):
            return Response({})

        request = self.factory.get('/')
        response = view(request)
        # Check the request succeeded, so the renderer below is the one
        # chosen for the view's own response rather than for an error.
        assert response.status_code == status.HTTP_200_OK
        assert isinstance(response.accepted_renderer, JSONRenderer)

    def test_parser_classes(self):

        @api_view(['GET'])
        @parser_classes([JSONParser])
        def view(request):
            assert len(request.parsers) == 1
            assert isinstance(request.parsers[0], JSONParser)
            return Response({})

        request = self.factory.get('/')
        response = view(request)
        # The assertions inside the view only run if the request reaches the
        # view body; a 200 proves they were actually executed.
        assert response.status_code == status.HTTP_200_OK

    def test_authentication_classes(self):

        @api_view(['GET'])
        @authentication_classes([BasicAuthentication])
        def view(request):
            assert len(request.authenticators) == 1
            assert isinstance(request.authenticators[0], BasicAuthentication)
            return Response({})

        request = self.factory.get('/')
        response = view(request)
        # The assertions inside the view only run if the request reaches the
        # view body; a 200 proves they were actually executed.
        assert response.status_code == status.HTTP_200_OK

    def test_permission_classes(self):

        @api_view(['GET'])
        @permission_classes([IsAuthenticated])
        def view(request):
            return Response({})

        request = self.factory.get('/')
        response = view(request)
        assert response.status_code == status.HTTP_403_FORBIDDEN

    def test_throttle_classes(self):
        class OncePerDayUserThrottle(UserRateThrottle):
            rate = '1/day'

        @api_view(['GET'])
        @throttle_classes([OncePerDayUserThrottle])
        def view(request):
            return Response({})

        request = self.factory.get('/')
        response = view(request)
        assert response.status_code == status.HTTP_200_OK

        # A second request within the same day exceeds the '1/day' rate.
        response = view(request)
        assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS

    def test_versioning_class(self):
        @api_view(["GET"])
        @versioning_class(QueryParameterVersioning)
        def view(request):
            return Response({"version": request.version})

        request = self.factory.get("/?version=1.2.3")
        response = view(request)
        assert response.data == {"version": "1.2.3"}

    def test_metadata_class(self):
        # From TestMetadata.test_none_metadata()
        @api_view()
        @metadata_class(None)
        def view(request):
            return Response({})

        request = self.factory.options('/')
        response = view(request)
        assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
        assert response.data == {'detail': 'Method "OPTIONS" not allowed.'}

    def test_content_negotiation(self):
        class CustomContentNegotiation(BaseContentNegotiation):
            def select_renderer(self, request, renderers, format_suffix):
                assert request.META['HTTP_ACCEPT'] == 'custom/type'
                return (renderers[0], renderers[0].media_type)

        @api_view(["GET"])
        @content_negotiation_class(CustomContentNegotiation)
        def view(request):
            return Response({})

        request = self.factory.get('/', HTTP_ACCEPT='custom/type')
        response = view(request)
        assert response.status_code == status.HTTP_200_OK

    def test_schema(self):
        """
        Checks CustomSchema class is set on view
        """
        class CustomSchema(AutoSchema):
            pass

        @api_view(['GET'])
        @schema(CustomSchema())
        def view(request):
            return Response({})

        assert isinstance(view.cls.schema, CustomSchema)

    def test_incorrect_decorator_order_permission_classes(self):
        """
        If @permission_classes is applied after @api_view, we should raise a TypeError.
        """
        with self.assertRaises(TypeError) as cm:
            @permission_classes([IsAuthenticated])
            @api_view(['GET'])
            def view(request):
                return Response({})

        assert '@permission_classes must come after (below) the @api_view decorator' in str(cm.exception)

    def test_incorrect_decorator_order_renderer_classes(self):
        """
        If @renderer_classes is applied after @api_view, we should raise a TypeError.
        """
        with self.assertRaises(TypeError) as cm:
            @renderer_classes([JSONRenderer])
            @api_view(['GET'])
            def view(request):
                return Response({})

        assert '@renderer_classes must come after (below) the @api_view decorator' in str(cm.exception)

    def test_incorrect_decorator_order_parser_classes(self):
        """
        If @parser_classes is applied after @api_view, we should raise a TypeError.
        """
        with self.assertRaises(TypeError) as cm:
            @parser_classes([JSONParser])
            @api_view(['GET'])
            def view(request):
                return Response({})

        assert '@parser_classes must come after (below) the @api_view decorator' in str(cm.exception)

    def test_incorrect_decorator_order_authentication_classes(self):
        """
        If @authentication_classes is applied after @api_view, we should raise a TypeError.
        """
        with self.assertRaises(TypeError) as cm:
            @authentication_classes([BasicAuthentication])
            @api_view(['GET'])
            def view(request):
                return Response({})

        assert '@authentication_classes must come after (below) the @api_view decorator' in str(cm.exception)

    def test_incorrect_decorator_order_throttle_classes(self):
        """
        If @throttle_classes is applied after @api_view, we should raise a TypeError.
        """
        class OncePerDayUserThrottle(UserRateThrottle):
            rate = '1/day'

        with self.assertRaises(TypeError) as cm:
            @throttle_classes([OncePerDayUserThrottle])
            @api_view(['GET'])
            def view(request):
                return Response({})

        assert '@throttle_classes must come after (below) the @api_view decorator' in str(cm.exception)

    def test_incorrect_decorator_order_versioning_class(self):
        """
        If @versioning_class is applied after @api_view, we should raise a TypeError.
        """
        with self.assertRaises(TypeError) as cm:
            @versioning_class(QueryParameterVersioning)
            @api_view(['GET'])
            def view(request):
                return Response({})

        assert '@versioning_class must come after (below) the @api_view decorator' in str(cm.exception)

    def test_incorrect_decorator_order_metadata_class(self):
        """
        If @metadata_class is applied after @api_view, we should raise a TypeError.
        """
        with self.assertRaises(TypeError) as cm:
            @metadata_class(None)
            @api_view(['GET'])
            def view(request):
                return Response({})

        assert '@metadata_class must come after (below) the @api_view decorator' in str(cm.exception)

    def test_incorrect_decorator_order_content_negotiation_class(self):
        """
        If @content_negotiation_class is applied after @api_view, we should raise a TypeError.
        """
        class CustomContentNegotiation(BaseContentNegotiation):
            def select_renderer(self, request, renderers, format_suffix):
                return (renderers[0], renderers[0].media_type)

        with self.assertRaises(TypeError) as cm:
            @content_negotiation_class(CustomContentNegotiation)
            @api_view(['GET'])
            def view(request):
                return Response({})

        assert '@content_negotiation_class must come after (below) the @api_view decorator' in str(cm.exception)

    def test_incorrect_decorator_order_schema(self):
        """
        If @schema is applied after @api_view, we should raise a TypeError.
        """
        class CustomSchema(AutoSchema):
            pass

        with self.assertRaises(TypeError) as cm:
            @schema(CustomSchema())
            @api_view(['GET'])
            def view(request):
                return Response({})

        assert '@schema must come after (below) the @api_view decorator' in str(cm.exception)


class ActionDecoratorTestCase(TestCase):
    """
    Tests for the ``@action`` decorator: the attributes it sets on the
    decorated function (``mapping``, ``detail``, ``url_path``, ``url_name``,
    ``kwargs``) and the ``.mapping`` helper that maps HTTP method names to
    handler method names.
    """

    def test_defaults(self):
        @action(detail=True)
        def test_action(request):
            """Description"""

        assert test_action.mapping == {'get': 'test_action'}
        assert test_action.detail is True
        assert test_action.url_path == 'test_action'
        assert test_action.url_name == 'test-action'
        assert test_action.kwargs == {
            'name': 'Test action',
            'description': 'Description',
        }

    def test_detail_required(self):
        with pytest.raises(AssertionError) as excinfo:
            @action()
            def test_action(request):
                raise NotImplementedError

        assert str(excinfo.value) == "@action() missing required argument: 'detail'"

    @pytest.mark.skipif(sys.version_info < (3, 11), reason="HTTPMethod was added in Python 3.11")
    def test_method_mapping_http_method(self):
        from http import HTTPMethod

        # Pass HTTPMethod enum members (rather than plain strings) for every
        # method name APIView supports.
        method_names = [getattr(HTTPMethod, name.upper()) for name in APIView.http_method_names]

        @action(detail=False, methods=method_names)
        def test_action():
            raise NotImplementedError

        expected_mapping = {name: test_action.__name__ for name in APIView.http_method_names}

        assert test_action.mapping == expected_mapping

    def test_method_mapping_http_methods(self):
        # All HTTP methods should be mappable
        @action(detail=False, methods=[])
        def test_action():
            raise NotImplementedError

        for name in APIView.http_method_names:
            def method():
                raise NotImplementedError

            # Give each handler a distinct name so the mapping can be checked
            # against it below.
            method.__name__ = name
            getattr(test_action.mapping, name)(method)

        # ensure the mapping returns the correct method name
        for name in APIView.http_method_names:
            assert test_action.mapping[name] == name

    def test_view_name_kwargs(self):
        """
        'name' and 'suffix' are mutually exclusive kwargs used for generating
        a view's display name.
        """
        # by default, generate name from method
        @action(detail=True)
        def test_action(request):
            raise NotImplementedError

        assert test_action.kwargs == {
            'description': None,
            'name': 'Test action',
        }

        # name kwarg supersedes name generation
        @action(detail=True, name='test name')
        def test_action(request):
            raise NotImplementedError

        assert test_action.kwargs == {
            'description': None,
            'name': 'test name',
        }

        # suffix kwarg supersedes name generation
        @action(detail=True, suffix='Suffix')
        def test_action(request):
            raise NotImplementedError

        assert test_action.kwargs == {
            'description': None,
            'suffix': 'Suffix',
        }

        # name + suffix is a conflict.
        with pytest.raises(TypeError) as excinfo:
            action(detail=True, name='test name', suffix='Suffix')

        assert str(excinfo.value) == "`name` and `suffix` are mutually exclusive arguments."

    def test_method_mapping(self):
        @action(detail=False)
        def test_action(request):
            raise NotImplementedError

        @test_action.mapping.post
        def test_action_post(request):
            raise NotImplementedError

        # The secondary handler methods should not have the action attributes.
        # Check each side separately so a failure names the offending attribute
        # and the function that was wrong.
        for name in ['mapping', 'detail', 'url_path', 'url_name', 'kwargs']:
            assert hasattr(test_action, name), (
                f"@action-decorated function is missing {name!r}"
            )
            assert not hasattr(test_action_post, name), (
                f"secondary mapped handler unexpectedly has {name!r}"
            )

    def test_method_mapping_already_mapped(self):
        @action(detail=True)
        def test_action(request):
            raise NotImplementedError

        msg = "Method 'get' has already been mapped to '.test_action'."
        with self.assertRaisesMessage(AssertionError, msg):
            @test_action.mapping.get
            def test_action_get(request):
                raise NotImplementedError

    def test_method_mapping_overwrite(self):
        @action(detail=True)
        def test_action():
            raise NotImplementedError

        msg = ("Method mapping does not behave like the property decorator. You "
               "cannot use the same method name for each mapping declaration.")
        with self.assertRaisesMessage(AssertionError, msg):
            # Reusing the action's own name for the extra handler is rejected.
            @test_action.mapping.post
            def test_action():
                raise NotImplementedError