Appended an _id suffix to all related field names in the rest api

Fixes #1597
This commit is contained in:
Oskar Hahn 2015-07-07 10:09:35 +02:00
parent 66f45ecd1f
commit 89a6d5b451
2 changed files with 86 additions and 6 deletions

View File

@ -1,4 +1,5 @@
import re
from collections import OrderedDict
from urllib.parse import urlparse
from rest_framework.decorators import detail_route, list_route # noqa
@ -6,14 +7,16 @@ from rest_framework.metadata import SimpleMetadata # noqa
from rest_framework.mixins import DestroyModelMixin, UpdateModelMixin # noqa
from rest_framework.response import Response # noqa
from rest_framework.routers import DefaultRouter
from rest_framework.serializers import ModelSerializer as _ModelSerializer
from rest_framework.serializers import ( # noqa
MANY_RELATION_KWARGS,
CharField,
DictField,
Field,
IntegerField,
ListField,
ListSerializer,
ModelSerializer,
ManyRelatedField,
PrimaryKeyRelatedField,
RelatedField,
SerializerMethodField,
@ -30,6 +33,83 @@ from .exceptions import OpenSlidesError
router = DefaultRouter()
class IdManyRelatedField(ManyRelatedField):
"""
ManyRelatedField that appends an suffix to the sub-fields.
Only works together with the IdPrimaryKeyRelatedField and our
ModelSerializer.
"""
field_name_suffix = '_id'
def bind(self, field_name, parent):
"""
Called when the field is bound to the serializer.
See IdPrimaryKeyRelatedField for more informations.
"""
self.source = field_name[:-len(self.field_name_suffix)]
super().bind(field_name, parent)
class IdPrimaryKeyRelatedField(PrimaryKeyRelatedField):
"""
Field, that renames the field name to FIELD_NAME_id.
Only works together the our ModelSerializer.
"""
field_name_suffix = '_id'
def bind(self, field_name, parent):
"""
Called when the field is bound to the serializer.
Changes the source so that the original field name is used (removes
the _id suffix).
"""
if field_name:
# field_name is an empty string when the field is created with the
# attribute many=True. In this case the suffix is added with the
# IdManyRelatedField class.
self.source = field_name[:-len(self.field_name_suffix)]
super().bind(field_name, parent)
@classmethod
def many_init(cls, *args, **kwargs):
"""
Method from rest_framework.relations.RelatedField That uses our
IdManyRelatedField class instead of
rest_framework.relations.ManyRelatedField class.
"""
list_kwargs = {'child_relation': cls(*args, **kwargs)}
for key in kwargs.keys():
if key in MANY_RELATION_KWARGS:
list_kwargs[key] = kwargs[key]
return IdManyRelatedField(**list_kwargs)
class ModelSerializer(_ModelSerializer):
"""
ModelSerializer that changes the field names of related fields to
FIELD_NAME_id.
"""
serializer_related_field = IdPrimaryKeyRelatedField
def get_fields(self):
"""
Returns all fields of the serializer.
"""
fields = OrderedDict()
for field_name, field in super().get_fields().items():
try:
field_name += field.field_name_suffix
except AttributeError:
pass
fields[field_name] = field
return fields
class PermissionMixin:
"""
Mixin for subclasses of APIView like GenericViewSet and ModelViewSet.

View File

@ -52,7 +52,7 @@ class CreateMotion(TestCase):
reverse('motion-list'),
{'title': 'test_title_Air0bahchaiph1ietoo2',
'text': 'test_text_chaeF9wosh8OowazaiVu',
'category': category.pk})
'category_id': category.pk})
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
motion = Motion.objects.get()
self.assertEqual(motion.category, category)
@ -69,7 +69,7 @@ class CreateMotion(TestCase):
reverse('motion-list'),
{'title': 'test_title_pha7moPh7quoth4paina',
'text': 'test_text_YooGhae6tiangung5Rie',
'submitters': [submitter_1.pk, submitter_2.pk]})
'submitters_id': [submitter_1.pk, submitter_2.pk]})
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
motion = Motion.objects.get()
self.assertEqual(motion.submitters.count(), 2)
@ -82,7 +82,7 @@ class CreateMotion(TestCase):
reverse('motion-list'),
{'title': 'test_title_Oecee4Da2Mu9EY6Ui4mu',
'text': 'test_text_FbhgnTFgkbjdmvcjbffg',
'supporters': [supporter.pk]})
'supporters_id': [supporter.pk]})
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
motion = Motion.objects.get()
self.assertEqual(motion.supporters.get().username, 'test_username_ahGhi4Quohyee7ohngie')
@ -93,7 +93,7 @@ class CreateMotion(TestCase):
reverse('motion-list'),
{'title': 'test_title_Hahke4loos4eiduNiid9',
'text': 'test_text_johcho0Ucaibiehieghe',
'tags': [tag.pk]})
'tags_id': [tag.pk]})
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
motion = Motion.objects.get()
self.assertEqual(motion.tags.get().name, 'test_tag_iRee3kiecoos4rorohth')
@ -147,7 +147,7 @@ class UpdateMotion(TestCase):
password='test_password_XaeTe3aesh8ohg6Cohwo')
response = self.client.patch(
reverse('motion-detail', args=[self.motion.pk]),
{'supporters': [supporter.pk]})
{'supporters_id': [supporter.pk]})
self.assertEqual(response.status_code, status.HTTP_200_OK)
motion = Motion.objects.get()
self.assertEqual(motion.title, 'test_title_aeng7ahChie3waiR8xoh')