Extension Blueprints

Blueprints are a collection of schema fixes for Django and REST Framework apps. Some libraries/apps do not play well with drf-spectacular’s automatic introspection. With extensions you can manually provide the necessary information to generate a better schema.

Is there no blueprint for the app you are looking for? No problem — you can easily write extensions yourself. Take the blueprints here as examples and have a look at Workflow & schema customization. Feel free to contribute new ones or fixes with a PR. Blueprint files can be found here.

Note

Simply copy and paste the snippets into your codebase. The extensions register themselves automatically; just make sure the Python interpreter imports them at least once. To that end, we suggest creating a PROJECT/schema.py file and importing it in your PROJECT/__init__.py (same directory as settings.py and urls.py) with import PROJECT.schema. Now you are all set.

dj-stripe

Stripe Models for Django: dj-stripe

from djstripe.contrib.rest_framework.serializers import (
    CreateSubscriptionSerializer, SubscriptionSerializer
)

from drf_spectacular.extensions import OpenApiViewExtension
from drf_spectacular.utils import extend_schema


class FixDjstripeSubscriptionRestView(OpenApiViewExtension):
    """
    Schema fix for dj-stripe's ``SubscriptionRestView``.

    The view does not declare its serializers in a way drf-spectacular can
    introspect, so a schema-only replacement class declares them explicitly.
    """
    target_class = 'djstripe.contrib.rest_framework.views.SubscriptionRestView'

    def view_replacement(self):
        class Fixed(self.target_class):
            # default serializer for the handlers that are not overridden below
            serializer_class = SubscriptionSerializer

            # dj-stripe answers a successful POST with HTTP 201 and echoes the
            # validated CreateSubscriptionSerializer data. Passing a bare
            # serializer here would be documented as 200 for a plain APIView,
            # so the status code is given explicitly.
            @extend_schema(
                request=CreateSubscriptionSerializer,
                responses={201: CreateSubscriptionSerializer},
            )
            def post(self, request, *args, **kwargs):
                pass  # schema-only stub; only the decorator matters

        return Fixed

django-oscar-api

RESTful API for django-oscar: django-oscar-api

from rest_framework import serializers

from drf_spectacular.extensions import (
    OpenApiSerializerExtension, OpenApiSerializerFieldExtension, OpenApiViewExtension
)
from drf_spectacular.plumbing import build_basic_type
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_field


class Fix1(OpenApiViewExtension):
    """Document the response of oscarapi's ``api_root`` view as a generic object."""
    target_class = 'oscarapi.views.root.api_root'

    def view_replacement(self):
        # build the decorator first, then apply it to the target view
        annotate_as_object = extend_schema(responses=OpenApiTypes.OBJECT)
        return annotate_as_object(self.target_class)


class Fix2(OpenApiViewExtension):
    """Give oscarapi's ``ProductAvailability`` view an explicit serializer."""
    target_class = 'oscarapi.views.product.ProductAvailability'

    def view_replacement(self):
        # deferred import, presumably so oscarapi is only imported when the fix is used
        from oscarapi.serializers.product import AvailabilitySerializer

        class Fixed(self.target_class):
            serializer_class = AvailabilitySerializer

        return Fixed


class Fix3(OpenApiViewExtension):
    """Give oscarapi's ``ProductPrice`` view an explicit serializer."""
    target_class = 'oscarapi.views.product.ProductPrice'

    def view_replacement(self):
        # deferred import, presumably so oscarapi is only imported when the fix is used
        from oscarapi.serializers.checkout import PriceSerializer

        class Fixed(self.target_class):
            serializer_class = PriceSerializer

        return Fixed


class Fix4(OpenApiViewExtension):
    """Give oscarapi's ``UserAddressDetail`` view a queryset bound to ``UserAddress``."""
    target_class = 'oscarapi.views.checkout.UserAddressDetail'

    def view_replacement(self):
        from oscar.apps.address.models import UserAddress

        class Fixed(self.target_class):
            # an empty queryset still carries the model; presumably this is what
            # lets the schema generator infer model-based details — confirm
            queryset = UserAddress.objects.none()

        return Fixed


class Fix5(OpenApiViewExtension):
    """Declare the ``breadcrumbs`` path parameter of oscarapi's ``CategoryList`` view."""
    target_class = 'oscarapi.views.product.CategoryList'

    def view_replacement(self):
        breadcrumbs_param = OpenApiParameter(
            name='breadcrumbs',
            type=OpenApiTypes.STR,
            location=OpenApiParameter.PATH,
        )

        class Fixed(self.target_class):
            @extend_schema(parameters=[breadcrumbs_param])
            def get(self, request, *args, **kwargs):
                pass  # schema-only stub; only the decorator matters

        return Fixed


class Fix6(OpenApiSerializerExtension):
    """
    Annotate the method-field getters of oscarapi's ``OrderSerializer``.

    The getters are re-declared on a throw-away subclass only to attach
    ``extend_schema_field`` hints; the subclass is then mapped in place of
    the original serializer.
    """
    target_class = 'oscarapi.serializers.checkout.OrderSerializer'

    def map_serializer(self, auto_schema, direction):
        from oscarapi.serializers.checkout import OrderOfferDiscountSerializer, OrderVoucherOfferSerializer

        class Fixed(self.target_class):
            # DRF invokes SerializerMethodField getters as ``get_<field>(self, obj)``;
            # the overrides keep that signature so they stay compatible with the
            # methods they shadow. The stub bodies are irrelevant for the schema.
            @extend_schema_field(OrderOfferDiscountSerializer(many=True))
            def get_offer_discounts(self, obj):
                pass

            @extend_schema_field(OpenApiTypes.URI)
            def get_payment_url(self, obj):
                pass

            @extend_schema_field(OrderVoucherOfferSerializer(many=True))
            def get_voucher_discounts(self, obj):
                pass

        return auto_schema._map_serializer(Fixed, direction)


class Fix7(OpenApiSerializerFieldExtension):
    """Describe oscarapi's ``CategoryField`` as a plain string in both directions."""
    target_class = 'oscarapi.serializers.fields.CategoryField'

    def map_serializer_field(self, auto_schema, direction):
        string_schema = build_basic_type(OpenApiTypes.STR)
        return string_schema


class Fix8(OpenApiSerializerFieldExtension):
    """Describe oscarapi's ``AttributeValueField`` via a ``oneOf`` of value schemas."""
    target_class = 'oscarapi.serializers.fields.AttributeValueField'

    def map_serializer_field(self, auto_schema, direction):
        # accepted value schemas; currently only strings are listed
        value_schemas = [build_basic_type(OpenApiTypes.STR)]
        return {'oneOf': value_schemas}


class Fix9(OpenApiSerializerExtension):
    """
    Declare ``is_tax_known`` on oscarapi's ``BasketSerializer`` as a boolean.

    The field is re-declared as a ``SerializerMethodField`` on a throw-away
    subclass whose getter carries a ``-> bool`` return hint for the schema.
    """
    target_class = 'oscarapi.serializers.basket.BasketSerializer'

    def map_serializer(self, auto_schema, direction):
        class Fixed(self.target_class):
            is_tax_known = serializers.SerializerMethodField()

            # DRF invokes SerializerMethodField getters as ``get_<field>(self, obj)``;
            # keep that signature. Only the return type hint matters for the schema.
            def get_is_tax_known(self, obj) -> bool:
                pass

        return auto_schema._map_serializer(Fixed, direction)


class Fix10(Fix9):
    """Apply the inherited ``is_tax_known`` fix to oscarapi's ``BasketLineSerializer``."""
    target_class = 'oscarapi.serializers.basket.BasketLineSerializer'

djangorestframework-api-key

Since djangorestframework-api-key has no entry in authentication_classes, drf-spectacular cannot pick up this library. To alleviate this shortcoming, you can manually add the appropriate security scheme.

Note

Using the SECURITY setting is discouraged unless there are special circumstances, such as this one. In almost all cases OpenApiAuthenticationExtension is strongly preferred, because SECURITY gets appended to every endpoint in the schema regardless of whether it actually applies there.

SPECTACULAR_SETTINGS = {
    # djangorestframework-api-key contributes no authentication class, so the
    # security scheme is added to the schema components manually.
    "APPEND_COMPONENTS": {
        "securitySchemes": {
            "ApiKeyAuth": {
                "type": "apiKey",
                "in": "header",
                # NOTE(review): the library presumably expects the header value
                # as "Api-Key <key>" by default — confirm for your version.
                "name": "Authorization",
            }
        }
    },
    # Appended to every operation in the schema.
    "SECURITY": [{"ApiKeyAuth": []}],
    # ... your other drf-spectacular settings
}

Polymorphic models

Using polymorphic models/serializers unfortunately yields flat serializers due to the way the serializers are constructed. This means the polymorphic serializers have no inheritance hierarchy that represents common functionality. These extensions retroactively build a hierarchy by rolling up the “common denominator” fields into the base components, and importing those into the sub-components via allOf. This results in components that better represent the structure of the underlying serializers/models from which they originated.

The components work perfectly fine without this extension, but in some cases generated client code has a hard time with the disjunctive nature of the unmodified components. This blueprint is designed to fix that issue.

from drf_spectacular.contrib.rest_polymorphic import PolymorphicSerializerExtension
from drf_spectacular.plumbing import ResolvedComponent
from drf_spectacular.serializers import PolymorphicProxySerializerExtension
from drf_spectacular.settings import spectacular_settings


class RollupMixin:
    """
    This is a schema helper that pulls the "common denominator" fields from child
    components into their parent component. It only applies to PolymorphicSerializer
    as well as PolymorphicProxySerializer, where there is an (implicit) inheritance hierarchy.

    The actual functionality is realized via extensions defined below.

    Must precede a drf-spectacular serializer extension in the bases, so that
    ``super().map_serializer`` resolves to that extension's implementation.
    """
    def map_serializer(self, auto_schema, direction):
        """
        Build the polymorphic schema and schedule the property rollup.

        :param auto_schema: the AutoSchema instance generating the schema.
        :param direction: schema direction, forwarded to the sub serializer resolution.
        :return: the parent extension's schema, unchanged for now; it is mutated
            later by the postprocessing hook registered here.
        """
        # parent extension builds the schema; rollup_properties later expects it
        # to contain a 'oneOf' key
        schema = super().map_serializer(auto_schema, direction)

        # collect the sub serializers: the proxy serializer lists them directly,
        # the rest_polymorphic serializer derives them from its model mapping
        if isinstance(self, PolymorphicProxySerializerExtension):
            sub_serializers = self.target.serializers
        else:
            sub_serializers = [
                self.target._get_serializer_from_model_or_instance(sub_model)
                for sub_model in self.target.model_serializer_mapping
            ]

        # resolve each sub serializer to its component for this direction
        resolved_sub_serializers = [
            auto_schema.resolve_serializer(sub, direction) for sub in sub_serializers
        ]
        # this will only be generated on return of map_serializer so mock it for now
        # (the mock holds the very `schema` dict returned below, so in-place edits
        # made by rollup_properties presumably reach the registered component)
        mocked_component = ResolvedComponent(
            name=auto_schema._get_serializer_name(self.target, direction),
            type=ResolvedComponent.SCHEMA,
            object=self.target,
            schema=schema
        )

        # hack for recursive models. at the time of extension execution, not all sub
        # serializer schema have been generated, so no rollup is possible.
        # by registering a local variable scoped postproc hook, we delay this
        # execution to the end where all schemas are present.
        def postprocessing_rollup_hook(generator, result, **kwargs):
            """Run the rollup, then rebuild the result's components from the registry."""
            rollup_properties(mocked_component, resolved_sub_serializers)
            result['components'] = generator.registry.build({})
            return result

        # register postproc hook. must run before enum postproc due to rebuilding the registry
        # NOTE(review): this inserts into the global settings list on every call and
        # nothing here removes it, so hooks presumably accumulate across schema
        # generations (e.g. repeated schema requests). rollup_properties' "allOf"
        # guard skips repeated rollups, but the list keeps growing — confirm.
        spectacular_settings.POSTPROCESSING_HOOKS.insert(0, postprocessing_rollup_hook)
        # and do nothing for now
        return schema


def rollup_properties(component, resolved_sub_serializers):
    """
    Move the fields shared by all sub serializers into the parent component.

    Each sub serializer schema is replaced by
    ``{'allOf': [component.ref, <remaining sub schema>]}`` and the parent's
    ``oneOf`` is replaced by the shared ``properties``. A shared field is only
    marked ``required`` on the parent if *every* sub serializer requires it;
    the sub schemas keep their own ``required`` lists unchanged.

    :param component: parent (polymorphic) component; its ``schema`` dict is
        mutated in place and its ``ref`` is referenced from the sub schemas.
    :param resolved_sub_serializers: components of the sub serializers; their
        ``schema`` attribute is replaced.
    """
    # nothing to roll up without sub serializers
    if not resolved_sub_serializers:
        return
    # rollup already happened (spectacular bug and normally not needed)
    if any('allOf' in r.schema for r in resolved_sub_serializers):
        return

    sub_schemas = [r.schema for r in resolved_sub_serializers]
    # a sub serializer without fields may have no 'properties' key at all
    common_fields = set.intersection(
        *(set(s.get('properties', {})) for s in sub_schemas)
    )
    # a field optional in any sub serializer must stay optional on the parent,
    # otherwise the parent would over-constrain that sub serializer via allOf
    common_required = common_fields.intersection(
        *(s.get('required', []) for s in sub_schemas)
    )
    common_properties = {}

    # substitute sub serializers' common fields with base class
    for r in resolved_sub_serializers:
        sub_properties = r.schema.get('properties', {})
        for cf in sorted(common_fields):
            # fields are matched by name only; the last sub serializer's definition wins
            common_properties[cf] = sub_properties.pop(cf)
        r.schema = {'allOf': [component.ref, r.schema]}

    # modify regular schema for field rollup
    component.schema.pop('oneOf', None)
    component.schema['properties'] = common_properties
    if common_required:
        component.schema['required'] = sorted(common_required)


class PolymorphicRollupSerializerExtension(RollupMixin, PolymorphicSerializerExtension):
    """rest_polymorphic serializer extension with the field rollup from ``RollupMixin``."""
    # presumably outranks the stock extension for the same target — confirm
    priority = 1


class PolymorphicProxyRollupSerializerExtension(RollupMixin, PolymorphicProxySerializerExtension):
    """PolymorphicProxySerializer extension with the field rollup from ``RollupMixin``."""
    # presumably outranks the stock extension for the same target — confirm
    priority = 1

RapiDoc

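RapiDoc is a documentation tool that can be used as an alternative to Redoc or Swagger UI. The view below renders the rapidoc.html template shown after it; place that template in one of your template directories.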

from rest_framework.renderers import TemplateHTMLRenderer
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.views import APIView

from drf_spectacular.plumbing import get_relative_url, set_query_parameters
from drf_spectacular.settings import spectacular_settings
from drf_spectacular.utils import extend_schema
from drf_spectacular.views import AUTHENTICATION_CLASSES


class SpectacularRapiDocView(APIView):
    """
    Serve an HTML page that renders the OpenAPI schema with RapiDoc.

    Attributes that subclasses (or ``as_view(**kwargs)``) may override:

    * ``url`` / ``url_name`` -- where the schema is fetched from; ``url`` wins
      when set, otherwise ``url_name`` is reversed.
    * ``template_name`` -- template rendered with ``title``, ``dist`` and
      ``schema_url`` in its context.
    * ``dist`` -- base URL of the RapiDoc distribution. The default points at the
      unpinned ``@latest`` CDN build; pin a version or point it at a self-hosted
      copy for reproducible pages.
    """
    renderer_classes = [TemplateHTMLRenderer]
    permission_classes = spectacular_settings.SERVE_PERMISSIONS
    authentication_classes = AUTHENTICATION_CLASSES
    url_name = 'schema'
    url = None
    template_name = 'rapidoc.html'
    title = spectacular_settings.TITLE
    dist = 'https://cdn.jsdelivr.net/npm/rapidoc@latest'

    @extend_schema(exclude=True)
    def get(self, request, *args, **kwargs):
        schema_url = self.url or get_relative_url(reverse(self.url_name, request=request))
        # forward the requested schema language and API version to the schema
        # endpoint; absent query params are passed as None, which
        # set_query_parameters is relied on to skip (as for `lang` before)
        schema_url = set_query_parameters(
            schema_url,
            lang=request.GET.get('lang'),
            version=request.GET.get('version'),
        )
        return Response(
            data={
                'title': self.title,
                'dist': self.dist,
                'schema_url': schema_url,
            },
            template_name=self.template_name,
        )
<!DOCTYPE html>
<html>
  <head>
    <!-- rapidoc.html: rendered by SpectacularRapiDocView with the context keys
         title, dist (base URL of the RapiDoc distribution) and schema_url. -->
    <title>{{ title|default:"RapiDoc" }}</title>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <script type="module" src="{{ dist }}/dist/rapidoc-min.js"></script>
  </head>
  <body>
    <rapi-doc spec-url="{{ schema_url }}"></rapi-doc>
  </body>
</html>

drf-rw-serializers

drf-rw-serializers provides generic views, viewsets and mixins that extend the Django REST Framework ones adding separated serializers for read and write operations.

drf-spectacular requires just a small AutoSchema augmentation to make it aware of drf-rw-serializers. Remember to point DEFAULT_SCHEMA_CLASS at this CustomAutoSchema instead of the default AutoSchema.

from drf_rw_serializers.generics import GenericAPIView as RWGenericAPIView

from drf_spectacular.openapi import AutoSchema


class CustomAutoSchema(AutoSchema):
    """
    AutoSchema aware of drf-rw-serializers' directional serializers.

    For views derived from drf-rw-serializers' ``GenericAPIView`` the request body
    is described by the write serializer and the response by the read serializer.
    Every other view falls back to the default serializer discovery.
    """

    def get_request_serializer(self):
        view = self.view
        if not isinstance(view, RWGenericAPIView):
            return self._get_serializer()
        write_serializer_class = view.get_write_serializer_class()
        return write_serializer_class()

    def get_response_serializers(self):
        view = self.view
        if not isinstance(view, RWGenericAPIView):
            return self._get_serializer()
        read_serializer_class = view.get_read_serializer_class()
        return read_serializer_class()

drf-extra-fields Base64FileField

drf-extra-fields provides a Base64FileField and Base64ImageField that automatically represent binary files as base64 encoded strings. This is a useful way to embed files within a larger JSON API and keep all data within the same tree and served with a single request or response.

Because requests to these fields require a base64 encoded string, while responses can be either a URI or base64 contents (if represent_in_base64=True), custom schema generation logic is required, as this differs from the default DRF FileField.

from drf_spectacular.extensions import OpenApiSerializerFieldExtension
from drf_spectacular.plumbing import build_basic_type
from drf_spectacular.types import OpenApiTypes


class Base64FileFieldSchema(OpenApiSerializerFieldExtension):
    """
    Schema for drf-extra-fields' ``Base64FileField``.

    Requests are always documented as base64 content. Responses are base64
    content when the field's ``represent_in_base64`` attribute is truthy and a
    URI otherwise. Any other direction yields no schema (``None``).
    """
    target_class = "drf_extra_fields.fields.Base64FileField"

    def map_serializer_field(self, auto_schema, direction):
        if direction == "request":
            payload_type = OpenApiTypes.BYTE
        elif direction == "response":
            payload_type = OpenApiTypes.BYTE if self.target.represent_in_base64 else OpenApiTypes.URI
        else:
            return None
        return build_basic_type(payload_type)


class Base64ImageFieldSchema(Base64FileFieldSchema):
    """Apply the ``Base64FileField`` schema logic to ``Base64ImageField`` as well."""
    target_class = "drf_extra_fields.fields.Base64ImageField"