Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle JSON in Zone POST requests #538

Merged
merged 4 commits into from
May 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions mreg/api/v1/views_zones.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List
import django.core.exceptions

from django.db import transaction
Expand All @@ -7,6 +8,7 @@
from rest_framework import (generics, renderers, status)
from rest_framework.decorators import (api_view, renderer_classes)
from rest_framework.exceptions import ParseError
from rest_framework.request import Request
from rest_framework.response import Response

from mreg.models.base import NameServer
Expand Down Expand Up @@ -55,6 +57,13 @@ def _validate_nameservers(names):
done.add(name)


def _get_request_nameservers(request: Request, field: str = "primary_ns") -> List[str]:
"""Extract nameservers from the request data."""
if request.content_type == "application/json":
return request.data.get(field, [])
return request.data.getlist(field, [])


class ZoneList(generics.ListCreateAPIView):
"""
get:
Expand All @@ -72,13 +81,13 @@ def get_queryset(self):
qs = super().get_queryset()
return self.filterset(data=self.request.GET, queryset=qs).qs

def post(self, request, *args, **kwargs):
def post(self, request: Request, *args, **kwargs):
qs = self.get_queryset()
if qs.filter(name=request.data["name"]).exists():
content = {'ERROR': 'Zone name already in use'}
return Response(content, status=status.HTTP_409_CONFLICT)
# A copy is required since the original is immutable
nameservers = request.data.getlist('primary_ns')
nameservers = _get_request_nameservers(request)
_validate_nameservers(nameservers)
data = request.data.copy()
data['primary_ns'] = nameservers[0]
Expand Down Expand Up @@ -121,13 +130,12 @@ def get_queryset(self):
self.queryset = self.parentzone.delegations.all().order_by('id')
return self.filterset(data=self.request.GET, queryset=self.queryset).qs

def post(self, request, *args, **kwargs):
def post(self, request: Request, *args, **kwargs):
qs = self.get_queryset()
if qs.filter(name=request.data[self.lookup_field]).exists():
content = {'ERROR': 'Zone name already in use'}
return Response(content, status=status.HTTP_409_CONFLICT)

nameservers = request.data.getlist('nameservers')
nameservers = _get_request_nameservers(request, "nameservers")
_validate_nameservers(nameservers)
data = request.data.copy()
data['zone'] = self.parentzone.pk
Expand Down Expand Up @@ -292,14 +300,14 @@ def get(self, request, *args, **kwargs):
zone = self.get_object()
return Response([ns.name for ns in zone.nameservers.all()], status=status.HTTP_200_OK)

def patch(self, request, *args, **kwargs):
def patch(self, request: Request, *args, **kwargs):
if 'primary_ns' not in request.data:
return Response({'ERROR': 'No nameserver found in body'}, status=status.HTTP_400_BAD_REQUEST)
zone = self.get_object()
nameservers = request.data.getlist('primary_ns')
nameservers = _get_request_nameservers(request)
_validate_nameservers(nameservers)
zone.update_nameservers(nameservers)
zone.primary_ns = request.data.getlist('primary_ns')[0]
zone.primary_ns = nameservers[0]
zone.updated = True
self.perform_update(zone)
return Response(status=status.HTTP_204_NO_CONTENT, headers={'Location': request.path})
Expand Down
Loading