Skip to content

Instantly share code, notes, and snippets.

@alyahmady
Created June 3, 2024 09:14
Show Gist options
  • Select an option

  • Save alyahmady/5496fd0f2843085306eef4f4aaf6d329 to your computer and use it in GitHub Desktop.

Select an option

Save alyahmady/5496fd0f2843085306eef4f4aaf6d329 to your computer and use it in GitHub Desktop.
OPTIONS method and metadata in DRF Spectacular
# apps/core/metadata.py
from collections import OrderedDict
from django.utils.encoding import force_str
from rest_framework import serializers
from rest_framework.metadata import SimpleMetadata
from rest_framework.request import clone_request
class CustomMetadata(SimpleMetadata):
    """``SimpleMetadata`` subclass that describes more of a view in OPTIONS responses.

    On top of the base metadata it:

    * adds the view's ``search_fields`` / ``ordering_fields`` when the view
      defines them;
    * reports serializer info for every allowed method among GET, POST, PUT,
      PATCH and DELETE (DELETE is reported as ``None``);
    * emits a compact per-field description, including nested ``child`` /
      ``children`` entries and ``choices``.
    """

    # Methods described under "actions", in a fixed order so the payload is
    # stable; iterating a set of strings varies between processes because of
    # hash randomization.
    action_methods = ("GET", "POST", "PUT", "PATCH", "DELETE")

    # Field attributes copied into each field description when they are set.
    field_info_attrs = (
        "read_only",
        "label",
        "help_text",
        "min_length",
        "max_length",
        "min_value",
        "max_value",
    )

    def determine_metadata(self, request, view) -> OrderedDict:
        """Return the base metadata plus the view's search/ordering fields, if defined."""
        metadata: OrderedDict = super().determine_metadata(request, view)
        if hasattr(view, "search_fields"):
            metadata["search_fields"] = view.search_fields
        if hasattr(view, "ordering_fields"):
            metadata["ordering_fields"] = view.ordering_fields
        return metadata

    def determine_actions(self, request, view):
        """Map each allowed method in ``action_methods`` to its serializer info.

        For every method the view allows, ``view.request`` is temporarily
        replaced with a clone of *request* using that method, so that
        ``get_serializer()`` sees the method-specific request. DELETE maps to
        ``None`` because no serializer info is produced for it.
        """
        actions = {}
        allowed = set(view.allowed_methods)
        for method in self.action_methods:
            if method not in allowed:
                continue
            view.request = clone_request(request, method)
            try:
                if method == "DELETE":
                    actions[method] = None
                else:
                    actions[method] = self.get_serializer_info(view.get_serializer())
            finally:
                # Restore the real request even if building the serializer
                # raised, so the view is never left holding the clone.
                view.request = request
        return actions

    def get_field_info(self, field):
        """Describe a single serializer field as an ordered dict.

        Always includes ``type`` and ``required``. Each attribute in
        ``field_info_attrs`` is included unless it is unset (``None``),
        ``False`` or an empty string. Numeric zero (e.g. ``min_value=0``) is
        kept, because it is a real constraint. Nested list children and nested
        serializers are described recursively. Choices are listed for
        non-relational fields that have them.
        """
        field_info = OrderedDict()
        field_info["type"] = self.label_lookup[field]
        field_info["required"] = getattr(field, "required", False)

        for attr in self.field_info_attrs:
            value = getattr(field, attr, None)
            # ``value is False`` (not ``not value``) so that 0 is not dropped.
            if value is None or value is False or value == "":
                continue
            field_info[attr] = force_str(value, strings_only=True)

        if getattr(field, "child", None):
            field_info["child"] = self.get_field_info(field.child)
        elif getattr(field, "fields", None):
            field_info["children"] = self.get_serializer_info(field)

        # Related fields expose ``choices`` backed by a queryset; skip them.
        if not isinstance(field, serializers.RelatedField | serializers.ManyRelatedField) and hasattr(
            field, "choices"
        ):
            field_info["choices"] = [
                {"value": choice_value, "display_name": force_str(choice_name, strings_only=True)}
                for choice_value, choice_name in field.choices.items()
            ]
        return field_info
# PROJECT_NAME/schema.py
from drf_spectacular.contrib.rest_framework_simplejwt import SimpleJWTScheme
from drf_spectacular.generators import EndpointEnumerator, SchemaGenerator
from drf_spectacular.openapi import AutoSchema
class CustomEndpointEnumerator(EndpointEnumerator):
    """Endpoint enumerator that keeps OPTIONS operations in the schema.

    Only HEAD, TRACE and CONNECT are filtered out of each endpoint's methods.
    As far as I know, drf-spectacular's default implementation also drops
    OPTIONS — TODO confirm against the installed version.
    """

    def get_allowed_methods(self, callback):
        """Return the upper-case HTTP methods served by a URL callback.

        For viewset callbacks (which carry ``actions``), the result is the
        intersection of the routed actions with ``http_method_names``. That
        attribute is taken from ``as_view(..., http_method_names=...)`` when
        given there, otherwise from the class. For plain views, the class is
        instantiated with the same initkwargs used by ``as_view`` and its
        ``allowed_methods`` are used.
        """
        initkwargs = getattr(callback, "initkwargs", None) or {}
        if hasattr(callback, "actions"):
            actions = set(callback.actions)
            http_method_names = set(
                initkwargs.get("http_method_names", callback.cls.http_method_names)
            )
            methods = [method.upper() for method in actions & http_method_names]
        else:
            # Build the view the way as_view() would, so per-route overrides
            # (e.g. http_method_names) are honoured.
            methods = callback.cls(**initkwargs).allowed_methods
        return [method for method in methods if method not in {"HEAD", "TRACE", "CONNECT"}]
class CustomSchemaGenerator(SchemaGenerator):
    """Schema generator that enumerates endpoints with ``CustomEndpointEnumerator``.

    Selected through ``SPECTACULAR_SETTINGS["DEFAULT_GENERATOR_CLASS"]``.
    """

    # Endpoint discovery is delegated to the OPTIONS-preserving enumerator.
    endpoint_inspector_cls = CustomEndpointEnumerator
class CustomSchema(AutoSchema):
    """``AutoSchema`` whose HTTP-method-to-action mapping also covers OPTIONS.

    OPTIONS is mapped to the ``"metadata"`` action name. The other methods
    keep their usual CRUD action names.
    """

    method_mapping = dict(
        get="retrieve",
        post="create",
        put="update",
        patch="partial_update",
        delete="destroy",
        options="metadata",
    )
# PROJECT_NAME/settings.py
# Settings fragment: merge these keys into your existing settings. The `...`
# entries stand for the rest of your configuration (they are not valid Python
# inside a dict literal as written), and PROJECT_NAME is your project package.
REST_FRAMEWORK = {
    ...
    # Metadata class used to build OPTIONS responses.
    "DEFAULT_METADATA_CLASS": "apps.core.metadata.CustomMetadata",
    # AutoSchema subclass whose method mapping includes "options" -> "metadata".
    "DEFAULT_SCHEMA_CLASS": "PROJECT_NAME.schema.CustomSchema",
    ...
}
SPECTACULAR_SETTINGS = {
    ...
    # Generator whose endpoint enumerator keeps OPTIONS operations.
    "DEFAULT_GENERATOR_CLASS": "PROJECT_NAME.schema.CustomSchemaGenerator",
    ...
}
# apps/core/utils.py
from typing import Literal
from rest_framework.generics import GenericAPIView
from rest_framework.viewsets import GenericViewSet
def get_view_action(
view: GenericViewSet | GenericAPIView,
) -> Literal["metadata", "list", "retrieve", "create", "update", "partial_update", "destroy"]:
try:
method = view.request.method.lower()
assert method
except (AttributeError, AssertionError):
view_allowed_methods = getattr(view, "allowed_methods", None) or [""]
method = view_allowed_methods[0].lower()
try:
action = view.action
assert action
assert action != "metadata" or method == "options"
except (AttributeError, AssertionError):
view_action_map = getattr(view, "action_map", None) or {}
action = view_action_map.get(method)
return action or ("metadata" if method == "options" else None)
# apps/core/views.py
from typing import TYPE_CHECKING
from rest_framework.decorators import action
from rest_framework.permissions import AllowAny
if TYPE_CHECKING:
from rest_framework.generics import GenericAPIView
from rest_framework.viewsets import GenericViewSet
from apps.core.utils import get_view_action
class AllowedOptionsMixin:
    """Expose an open OPTIONS ("metadata") action on a DRF view.

    Mix into an ``APIView`` / ``GenericAPIView`` or a ``GenericViewSet`` so that
    OPTIONS is available for every action (and therefore shows up in Swagger).
    When ``get_view_action`` resolves the current action to ``"metadata"``, the
    view uses no authenticators and allows any caller. Every other action keeps
    the base view's authenticators and permissions.
    """

    def get_authenticators(self):
        """Return no authenticators for the metadata action, else defer to the base view."""
        self: GenericViewSet | GenericAPIView | AllowedOptionsMixin
        is_metadata = get_view_action(self) == "metadata"
        return () if is_metadata else super().get_authenticators()

    def get_permissions(self):
        """Allow anyone to call the metadata action, else defer to the base view."""
        self: GenericViewSet | GenericAPIView | AllowedOptionsMixin
        is_metadata = get_view_action(self) == "metadata"
        return (AllowAny(),) if is_metadata else super().get_permissions()

    @action(detail=False, methods=["OPTIONS"])
    def metadata(self, request, *args, **kwargs):
        """Answer OPTIONS through the base view's ``options`` handler."""
        return super().options(request, *args, **kwargs)
# This is an example of ViewSet
class SampleViewSet(
    AllowedOptionsMixin,
    # UpdateModelMixin,
    GenericViewSet,
):
    """Example viewset that exposes the OPTIONS "metadata" action.

    When routing this viewset manually, include ``"options": "metadata"`` in
    the action map passed to ``as_view``.
    """

    ...

    # If you restrict ``http_method_names``, include "options" in it so the
    # OPTIONS method stays exposed. The trailing comma matters: ``("options")``
    # is just the string "options", and anything that iterates it (e.g.
    # ``set(http_method_names)``) would see single characters.
    http_method_names = (
        # "put",
        # "patch",
        "options",
    )
# This is an example of ViewSet in URL patterns
# (``path`` is imported from ``django.urls``.)
urlpatterns = [
    path(
        "...",
        SampleViewSet.as_view({
            # "put": "update",
            # "patch": "partial_update",
            "options": "metadata",
        }),
        name="...",
    ),
]
# This is an example of APIView
# NOTE(review): GenericAPIView appears to be imported only under TYPE_CHECKING
# above, and Response is not imported in this snippet; both need real runtime
# imports for this class to load.
class SampleAPIView(AllowedOptionsMixin, GenericAPIView):
    """Example generic API view that also answers OPTIONS via AllowedOptionsMixin."""

    ...

    # When you restrict ``http_method_names``, keep "options" in it; otherwise
    # the OPTIONS method is not exposed.
    http_method_names = ("get", "options")

    def get(self, request, *args, **kwargs):
        """Placeholder GET handler; replace ``...`` with the real logic."""
        ...
        return Response()
# This is an example of APIView in URL patterns
# (``path`` is imported from ``django.urls``.)
urlpatterns = [
    path(
        "...",
        # A plain APIView's as_view() takes no action map. OPTIONS requests are
        # dispatched to the view's own ``options`` handler, and
        # AllowedOptionsMixin recognises them as the "metadata" action.
        SampleAPIView.as_view(),
        name="...",
    ),
]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment