diff --git a/openslides/utils/rest_api.py b/openslides/utils/rest_api.py index 19dca25ac..0092ba493 100644 --- a/openslides/utils/rest_api.py +++ b/openslides/utils/rest_api.py @@ -1,4 +1,5 @@ import re +from collections import OrderedDict from urllib.parse import urlparse from rest_framework.decorators import detail_route, list_route # noqa @@ -6,14 +7,16 @@ from rest_framework.metadata import SimpleMetadata # noqa from rest_framework.mixins import DestroyModelMixin, UpdateModelMixin # noqa from rest_framework.response import Response # noqa from rest_framework.routers import DefaultRouter +from rest_framework.serializers import ModelSerializer as _ModelSerializer from rest_framework.serializers import ( # noqa + MANY_RELATION_KWARGS, CharField, DictField, Field, IntegerField, ListField, ListSerializer, - ModelSerializer, + ManyRelatedField, PrimaryKeyRelatedField, RelatedField, SerializerMethodField, @@ -30,6 +33,83 @@ from .exceptions import OpenSlidesError router = DefaultRouter() +class IdManyRelatedField(ManyRelatedField): + """ + ManyRelatedField that appends an suffix to the sub-fields. + + Only works together with the IdPrimaryKeyRelatedField and our + ModelSerializer. + """ + field_name_suffix = '_id' + + def bind(self, field_name, parent): + """ + Called when the field is bound to the serializer. + + See IdPrimaryKeyRelatedField for more informations. + """ + self.source = field_name[:-len(self.field_name_suffix)] + super().bind(field_name, parent) + + +class IdPrimaryKeyRelatedField(PrimaryKeyRelatedField): + """ + Field, that renames the field name to FIELD_NAME_id. + + Only works together the our ModelSerializer. + """ + field_name_suffix = '_id' + + def bind(self, field_name, parent): + """ + Called when the field is bound to the serializer. + + Changes the source so that the original field name is used (removes + the _id suffix). + """ + if field_name: + # field_name is an empty string when the field is created with the + # attribute many=True. In this case the suffix is added with the + # IdManyRelatedField class. + self.source = field_name[:-len(self.field_name_suffix)] + super().bind(field_name, parent) + + @classmethod + def many_init(cls, *args, **kwargs): + """ + Method from rest_framework.relations.RelatedField That uses our + IdManyRelatedField class instead of + rest_framework.relations.ManyRelatedField class. + """ + list_kwargs = {'child_relation': cls(*args, **kwargs)} + for key in kwargs.keys(): + if key in MANY_RELATION_KWARGS: + list_kwargs[key] = kwargs[key] + return IdManyRelatedField(**list_kwargs) + + +class ModelSerializer(_ModelSerializer): + """ + ModelSerializer that changes the field names of related fields to + FIELD_NAME_id. + """ + serializer_related_field = IdPrimaryKeyRelatedField + + def get_fields(self): + """ + Returns all fields of the serializer. + """ + fields = OrderedDict() + + for field_name, field in super().get_fields().items(): + try: + field_name += field.field_name_suffix + except AttributeError: + pass + fields[field_name] = field + return fields + + class PermissionMixin: """ Mixin for subclasses of APIView like GenericViewSet and ModelViewSet. diff --git a/tests/integration/motions/test_viewset.py b/tests/integration/motions/test_viewset.py index 1add5159b..9a9285cf3 100644 --- a/tests/integration/motions/test_viewset.py +++ b/tests/integration/motions/test_viewset.py @@ -52,7 +52,7 @@ class CreateMotion(TestCase): reverse('motion-list'), {'title': 'test_title_Air0bahchaiph1ietoo2', 'text': 'test_text_chaeF9wosh8OowazaiVu', - 'category': category.pk}) + 'category_id': category.pk}) self.assertEqual(response.status_code, status.HTTP_201_CREATED) motion = Motion.objects.get() self.assertEqual(motion.category, category) @@ -69,7 +69,7 @@ class CreateMotion(TestCase): reverse('motion-list'), {'title': 'test_title_pha7moPh7quoth4paina', 'text': 'test_text_YooGhae6tiangung5Rie', - 'submitters': [submitter_1.pk, submitter_2.pk]}) + 'submitters_id': [submitter_1.pk, submitter_2.pk]}) self.assertEqual(response.status_code, status.HTTP_201_CREATED) motion = Motion.objects.get() self.assertEqual(motion.submitters.count(), 2) @@ -82,7 +82,7 @@ class CreateMotion(TestCase): reverse('motion-list'), {'title': 'test_title_Oecee4Da2Mu9EY6Ui4mu', 'text': 'test_text_FbhgnTFgkbjdmvcjbffg', - 'supporters': [supporter.pk]}) + 'supporters_id': [supporter.pk]}) self.assertEqual(response.status_code, status.HTTP_201_CREATED) motion = Motion.objects.get() self.assertEqual(motion.supporters.get().username, 'test_username_ahGhi4Quohyee7ohngie') @@ -93,7 +93,7 @@ class CreateMotion(TestCase): reverse('motion-list'), {'title': 'test_title_Hahke4loos4eiduNiid9', 'text': 'test_text_johcho0Ucaibiehieghe', - 'tags': [tag.pk]}) + 'tags_id': [tag.pk]}) self.assertEqual(response.status_code, status.HTTP_201_CREATED) motion = Motion.objects.get() self.assertEqual(motion.tags.get().name, 'test_tag_iRee3kiecoos4rorohth') @@ -147,7 +147,7 @@ class UpdateMotion(TestCase): password='test_password_XaeTe3aesh8ohg6Cohwo') response = self.client.patch( reverse('motion-detail', args=[self.motion.pk]), - {'supporters': [supporter.pk]}) + {'supporters_id': [supporter.pk]}) self.assertEqual(response.status_code, status.HTTP_200_OK) motion = Motion.objects.get() self.assertEqual(motion.title, 'test_title_aeng7ahChie3waiR8xoh')