Lock poll to prevent race conditions

Add migrations
This commit is contained in:
Joshua Sangmeister 2021-04-12 13:57:24 +02:00 committed by Emanuel Schütze
parent 79d9781a1b
commit ee31c1e633
4 changed files with 97 additions and 9 deletions

View File

@ -0,0 +1,22 @@
# Generated by jsangmeister on 2021-04-15 08:01
from django.db import migrations
from ...poll.migrations.poll_migration_helper import fix_wrongly_calculated_vote_fields
class Migration(migrations.Migration):
dependencies = [
("assignments", "0024_assignmentpoll_entitled_users_remove_duplicates"),
]
operations = [
migrations.RunPython(
fix_wrongly_calculated_vote_fields(
"assignments",
"AssignmentPoll",
lambda poll: not poll.is_pseudoanonymized,
)
),
]

View File

@ -0,0 +1,18 @@
# Generated by jsangmeister on 2021-04-12 13:27
from django.db import migrations
from ...poll.migrations.poll_migration_helper import fix_wrongly_calculated_vote_fields
class Migration(migrations.Migration):
dependencies = [
("motions", "0043_motionpoll_entitled_users_remove_duplicates"),
]
operations = [
migrations.RunPython(
fix_wrongly_calculated_vote_fields("motions", "MotionPoll")
),
]

View File

@ -56,3 +56,43 @@ def remove_entitled_users_duplicates(poll_model_collection, poll_model_name):
poll.save(skip_autoupdate=True) poll.save(skip_autoupdate=True)
return _remove_entitled_users_duplicates return _remove_entitled_users_duplicates
def fix_wrongly_calculated_vote_fields(
poll_model_collection, poll_model_name, filter=None
):
"""
Takes all polls of the given model and corrects votes* fields if:
- vote weight is disabled (checked in config and by asserting that votesvalid==votescast)
- poll type and state must be correct
- calculated value must be bigger than db value (should be the case anyway, but if
it's not, we don't want to break even more things by changing it)
"""
def _fix_wrongly_calculated_vote_fields(apps, schema_editor):
ConfigStore = apps.get_model("core", "ConfigStore")
try:
config = ConfigStore.objects.get(key="users_activate_vote_weight")
value = config.value
except (ConfigStore.DoesNotExist, KeyError):
value = False
if not value:
PollModel = apps.get_model(poll_model_collection, poll_model_name)
for poll in PollModel.objects.all():
if (
poll.type != BasePoll.TYPE_ANALOG
and (not filter or filter(poll))
and poll.state
in (BasePoll.STATE_FINISHED, BasePoll.STATE_PUBLISHED)
and poll.votesvalid == poll.votescast
):
all_vote_tokens = set(
vote.user_token
for option in poll.options.all()
for vote in option.votes.all()
)
if len(all_vote_tokens) > poll.votesvalid:
poll.votesvalid = poll.votescast = len(all_vote_tokens)
poll.save(skip_autoupdate=True)
return _fix_wrongly_calculated_vote_fields

View File

@ -38,6 +38,14 @@ class BasePollViewSet(ModelViewSet):
else: else:
return self.has_manage_permissions() return self.has_manage_permissions()
def get_locked_object(self):
"""
Enhance get_object to make sure to lock the underlying object to prevent
race conditions.
"""
poll = self.get_object()
return self.queryset.select_for_update().get(pk=poll.pk)
@transaction.atomic @transaction.atomic
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data) serializer = self.get_serializer(data=request.data)
@ -66,7 +74,7 @@ class BasePollViewSet(ModelViewSet):
""" """
Customized view endpoint to update a poll. Customized view endpoint to update a poll.
""" """
poll = self.get_object() poll = self.get_locked_object()
partial = kwargs.get("partial", False) partial = kwargs.get("partial", False)
serializer = self.get_serializer(poll, data=request.data, partial=partial) serializer = self.get_serializer(poll, data=request.data, partial=partial)
@ -122,7 +130,7 @@ class BasePollViewSet(ModelViewSet):
@action(detail=True, methods=["POST"]) @action(detail=True, methods=["POST"])
@transaction.atomic @transaction.atomic
def start(self, request, pk): def start(self, request, pk):
poll = self.get_object() poll = self.get_locked_object()
if poll.state != BasePoll.STATE_CREATED: if poll.state != BasePoll.STATE_CREATED:
raise ValidationError({"detail": "Wrong poll state"}) raise ValidationError({"detail": "Wrong poll state"})
poll.state = BasePoll.STATE_STARTED poll.state = BasePoll.STATE_STARTED
@ -135,8 +143,8 @@ class BasePollViewSet(ModelViewSet):
@action(detail=True, methods=["POST"]) @action(detail=True, methods=["POST"])
@transaction.atomic @transaction.atomic
def stop(self, request, pk): def stop(self, request, pk):
poll = self.get_object() poll = self.get_locked_object()
# Analog polls could not be stopped; they are stopped when # Analog polls cannot be stopped; they are stopped when
# the results are entered. # the results are entered.
if poll.type == BasePoll.TYPE_ANALOG: if poll.type == BasePoll.TYPE_ANALOG:
raise ValidationError( raise ValidationError(
@ -155,7 +163,7 @@ class BasePollViewSet(ModelViewSet):
@action(detail=True, methods=["POST"]) @action(detail=True, methods=["POST"])
@transaction.atomic @transaction.atomic
def publish(self, request, pk): def publish(self, request, pk):
poll = self.get_object() poll = self.get_locked_object()
if poll.state != BasePoll.STATE_FINISHED: if poll.state != BasePoll.STATE_FINISHED:
raise ValidationError({"detail": "Wrong poll state"}) raise ValidationError({"detail": "Wrong poll state"})
@ -175,7 +183,7 @@ class BasePollViewSet(ModelViewSet):
@action(detail=True, methods=["POST"]) @action(detail=True, methods=["POST"])
@transaction.atomic @transaction.atomic
def pseudoanonymize(self, request, pk): def pseudoanonymize(self, request, pk):
poll = self.get_object() poll = self.get_locked_object()
if poll.state not in (BasePoll.STATE_FINISHED, BasePoll.STATE_PUBLISHED): if poll.state not in (BasePoll.STATE_FINISHED, BasePoll.STATE_PUBLISHED):
raise ValidationError( raise ValidationError(
@ -191,7 +199,7 @@ class BasePollViewSet(ModelViewSet):
@action(detail=True, methods=["POST"]) @action(detail=True, methods=["POST"])
@transaction.atomic @transaction.atomic
def reset(self, request, pk): def reset(self, request, pk):
poll = self.get_object() poll = self.get_locked_object()
poll.reset() poll.reset()
self.extend_history_information(["Voting reset"]) self.extend_history_information(["Voting reset"])
return Response() return Response()
@ -202,7 +210,7 @@ class BasePollViewSet(ModelViewSet):
""" """
For motion polls: Just "Y", "N" or "A" (if pollmethod is "YNA") For motion polls: Just "Y", "N" or "A" (if pollmethod is "YNA")
""" """
poll = self.get_object() poll = self.get_locked_object()
# Disable history for these requests # Disable history for these requests
disable_history() disable_history()
@ -257,7 +265,7 @@ class BasePollViewSet(ModelViewSet):
@action(detail=True, methods=["POST"]) @action(detail=True, methods=["POST"])
@transaction.atomic @transaction.atomic
def refresh(self, request, pk): def refresh(self, request, pk):
poll = self.get_object() poll = self.get_locked_object()
inform_changed_data(poll) inform_changed_data(poll)
inform_changed_data(poll.get_options()) inform_changed_data(poll.get_options())
inform_changed_data(poll.get_votes()) inform_changed_data(poll.get_votes())