Merge pull request #3380 from ostcar/more_typings

More typings
This commit is contained in:
Oskar Hahn 2017-09-03 19:00:34 +02:00 committed by GitHub
commit f1d7f85be9
39 changed files with 387 additions and 324 deletions

View File

@ -3,6 +3,7 @@
import os
import subprocess
import sys
from typing import Dict # noqa
import django
from django.core.management import call_command, execute_from_command_line
@ -88,7 +89,7 @@ def get_parser():
dest='subcommand',
title='Available subcommands',
description="Type '%s <subcommand> --help' for help on a "
"specific subcommand." % parser.prog,
"specific subcommand." % parser.prog, # type: ignore
help='You can choose only one subcommand at once.',
metavar='')
@ -155,8 +156,8 @@ def get_parser():
('runserver', 'Starts the Tornado webserver.'),
)
for django_subcommand, help_text in django_subcommands:
subparsers._choices_actions.append(
subparsers._ChoicesPseudoAction(
subparsers._choices_actions.append( # type: ignore
subparsers._ChoicesPseudoAction( # type: ignore
django_subcommand,
(),
help_text))
@ -248,7 +249,7 @@ def createsettings(args):
"""
settings_path = args.settings_path
local_installation = is_local_installation()
context = {}
context = {} # type: Dict[str, str]
if local_installation:
if settings_path is None:

View File

@ -1,4 +1,9 @@
from ..utils.access_permissions import BaseAccessPermissions
from typing import Iterable # noqa
from ..utils.access_permissions import ( # noqa
BaseAccessPermissions,
RestrictedData,
)
from ..utils.auth import has_perm
from ..utils.collection import Collection
@ -55,17 +60,17 @@ class ItemAccessPermissions(BaseAccessPermissions):
# In hidden case managers and non managers see only some fields
# so that list of speakers is provided regardless.
blocked_keys_hidden_case = full_data[0].keys() - (
blocked_keys_hidden_case = set(full_data[0].keys()) - set((
'id',
'title',
'speakers',
'speaker_list_closed',
'content_object')
'content_object'))
# In non hidden case managers see everything and non managers see
# everything but comments.
if has_perm(user, 'agenda.can_manage'):
blocked_keys_non_hidden_case = []
blocked_keys_non_hidden_case = [] # type: Iterable[str]
else:
blocked_keys_non_hidden_case = ('comment',)
@ -81,7 +86,7 @@ class ItemAccessPermissions(BaseAccessPermissions):
# Reduce result to a single item or None if it was not a collection at
# the beginning of the method.
if isinstance(container, Collection):
restricted_data = data
restricted_data = data # type: RestrictedData
elif data:
restricted_data = data[0]
else:
@ -111,7 +116,7 @@ class ItemAccessPermissions(BaseAccessPermissions):
# Reduce result to a single item or None if it was not a collection at
# the beginning of the method.
if isinstance(container, Collection):
projector_data = data
projector_data = data # type: RestrictedData
elif data:
projector_data = data[0]
else:

View File

@ -1,4 +1,5 @@
from collections import defaultdict
from typing import Dict, List, Set # noqa
from django.conf import settings
from django.contrib.auth.models import AnonymousUser
@ -79,7 +80,7 @@ class ItemManager(models.Manager):
HIDDEN_ITEM and all of their children.
"""
queryset = self.order_by('weight')
item_children = defaultdict(list)
item_children = defaultdict(list) # type: Dict[int, List[Item]]
root_items = []
for item in queryset:
if only_agenda_items and item.type == item.HIDDEN_ITEM:
@ -135,7 +136,7 @@ class ItemManager(models.Manager):
yield (element['id'], parent, weight)
yield from walk_items(element.get('children', []), element['id'])
touched_items = set()
touched_items = set() # type: Set[int]
db_items = dict((item.pk, item) for item in Item.objects.all())
for item_id, parent_id, weight in walk_items(tree):
# Check that the item is only once in the tree to prevent invalid trees
@ -293,7 +294,7 @@ class Item(RESTModelMixin, models.Model):
skip_autoupdate=skip_autoupdate,
name='agenda/list-of-speakers',
id=self.pk)
return super().delete(skip_autoupdate=skip_autoupdate, *args, **kwargs)
return super().delete(skip_autoupdate=skip_autoupdate, *args, **kwargs) # type: ignore
@property
def title(self):

View File

@ -1,3 +1,5 @@
from typing import Set # noqa
from django.apps import apps
from django.contrib.contenttypes.models import ContentType
@ -62,7 +64,7 @@ def required_users(sender, request_user, **kwargs):
if request_user can see the agenda. This function may return an empty
set.
"""
speakers = set()
speakers = set() # type: Set[int]
if has_perm(request_user, 'agenda.can_see'):
for item_collection_element in Collection(Item.get_collection_string()).element_generator():
full_data = item_collection_element.get_full_data()

View File

@ -1,4 +1,7 @@
from ..utils.access_permissions import BaseAccessPermissions
from ..utils.access_permissions import ( # noqa
BaseAccessPermissions,
RestrictedData,
)
from ..utils.auth import has_perm
from ..utils.collection import Collection
@ -50,7 +53,7 @@ class AssignmentAccessPermissions(BaseAccessPermissions):
# Reduce result to a single item or None if it was not a collection at
# the beginning of the method.
if isinstance(container, Collection):
restricted_data = data
restricted_data = data # type: RestrictedData
elif data:
restricted_data = data[0]
else:
@ -76,7 +79,7 @@ class AssignmentAccessPermissions(BaseAccessPermissions):
# Reduce result to a single item or None if it was not a collection at
# the beginning of the method.
if isinstance(container, Collection):
projector_data = data
projector_data = data # type: RestrictedData
elif data:
projector_data = data[0]
else:

View File

@ -1,4 +1,7 @@
from typing import Dict, List, Union # noqa
from django.apps import AppConfig
from mypy_extensions import TypedDict
from ..utils.collection import Collection
@ -46,9 +49,11 @@ class AssignmentsAppConfig(AppConfig):
def get_angular_constants(self):
assignment = self.get_model('Assignment')
InnerItem = TypedDict('InnerItem', {'value': int, 'display_name': str})
Item = TypedDict('Item', {'name': str, 'value': List[InnerItem]}) # noqa
data = {
'name': 'AssignmentPhases',
'value': []}
'value': []} # type: Item
for phase in assignment.PHASES:
data['value'].append({
'value': phase[0],

View File

@ -1,4 +1,5 @@
from collections import OrderedDict
from typing import Any, Dict, List, Optional # noqa
from django.conf import settings
from django.contrib.contenttypes.fields import GenericRelation
@ -174,7 +175,7 @@ class Assignment(RESTModelMixin, models.Model):
skip_autoupdate=skip_autoupdate,
name='assignments/assignment',
id=self.pk)
return super().delete(skip_autoupdate=skip_autoupdate, *args, **kwargs)
return super().delete(skip_autoupdate=skip_autoupdate, *args, **kwargs) # type: ignore # TODO fix typing
@property
def candidates(self):
@ -300,14 +301,14 @@ class Assignment(RESTModelMixin, models.Model):
Returns a table represented as a list with all candidates from all
related polls and their vote results.
"""
vote_results_dict = OrderedDict()
vote_results_dict = OrderedDict() # type: Dict[Any, List[AssignmentVote]]
polls = self.polls.all()
if only_published:
polls = polls.filter(published=True)
# All PollOption-Objects related to this assignment
options = []
options = [] # type: List[AssignmentOption]
for poll in polls:
options += poll.get_options()
@ -317,7 +318,7 @@ class Assignment(RESTModelMixin, models.Model):
continue
vote_results_dict[candidate] = []
for poll in polls:
votes = {}
votes = {} # type: Any
try:
# candidate related to this poll
poll_option = poll.get_options().get(candidate=candidate)
@ -429,7 +430,7 @@ class AssignmentPoll(RESTModelMixin, CollectDefaultVotesMixin, # type: ignore
name='assignments/assignment',
id=self.assignment.pk,
poll=self.pk)
return super().delete(skip_autoupdate=skip_autoupdate, *args, **kwargs)
return super().delete(skip_autoupdate=skip_autoupdate, *args, **kwargs) # type: ignore # TODO: fix typing
def get_assignment(self):
return self.assignment

View File

@ -1,3 +1,5 @@
from typing import Any, Set # noqa
from django.apps import apps
from ..utils.auth import has_perm
@ -22,7 +24,7 @@ def required_users(sender, request_user, **kwargs):
options) in any assignment if request_user can see assignments. This
function may return an empty set.
"""
candidates = set()
candidates = set() # type: Set[Any] # TODO: Replace Any
if has_perm(request_user, 'assignments.can_see'):
for assignment_collection_element in Collection(Assignment.get_collection_string()).element_generator():
full_data = assignment_collection_element.get_full_data()

View File

@ -1,13 +1,4 @@
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
TypeVar,
Union,
)
from typing import Any, Callable, Dict, Iterable, Optional, TypeVar, Union
from django.core.exceptions import ValidationError as DjangoValidationError
from django.utils.translation import ugettext as _
@ -192,9 +183,9 @@ use x = config[...], to set it use config[...] = x.
T = TypeVar('T')
ChoiceType = Optional[List[Dict[str, str]]]
ChoiceType = Optional[Iterable[Dict[str, str]]]
ChoiceCallableType = Union[ChoiceType, Callable[[], ChoiceType]]
ValidatorsType = List[Callable[[T], None]]
ValidatorsType = Iterable[Callable[[T], None]]
OnChangeType = Callable[[], None]
ConfigVariableDict = TypedDict('ConfigVariableDict', {
'key': str,

View File

@ -51,7 +51,7 @@ class Command(BaseCommand):
response = urlopen(self.get_geiss_url()).read()
releases = json.loads(response.decode())
for release in releases:
version = distutils.version.StrictVersion(release['tag_name'])
version = distutils.version.StrictVersion(release['tag_name']) # type: ignore
if version < self.FIRST_NOT_SUPPORTED_VERSION:
break
else:

View File

@ -111,7 +111,7 @@ class Projector(RESTModelMixin, models.Model):
"""
# Get all elements from all apps.
elements = {}
for element in ProjectorElement.get_all():
for element in ProjectorElement.get_all(): # type: ignore
elements[element.name] = element
# Parse result
@ -138,7 +138,7 @@ class Projector(RESTModelMixin, models.Model):
"""
# Get all elements from all apps.
elements = {}
for element in ProjectorElement.get_all():
for element in ProjectorElement.get_all(): # type: ignore
elements[element.name] = element
# Generator
@ -169,7 +169,7 @@ class Projector(RESTModelMixin, models.Model):
elements = {}
# Build projector elements.
for element in ProjectorElement.get_all():
for element in ProjectorElement.get_all(): # type: ignore
elements[element.name] = element
# Iterate over all active projector elements.
@ -341,7 +341,7 @@ class ProjectorMessage(RESTModelMixin, models.Model):
skip_autoupdate=skip_autoupdate,
name='core/projector-message',
id=self.pk)
return super().delete(skip_autoupdate=skip_autoupdate, *args, **kwargs)
return super().delete(skip_autoupdate=skip_autoupdate, *args, **kwargs) # type: ignore
class Countdown(RESTModelMixin, models.Model):
@ -370,7 +370,7 @@ class Countdown(RESTModelMixin, models.Model):
skip_autoupdate=skip_autoupdate,
name='core/countdown',
id=self.pk)
return super().delete(skip_autoupdate=skip_autoupdate, *args, **kwargs)
return super().delete(skip_autoupdate=skip_autoupdate, *args, **kwargs) # type: ignore
def control(self, action):
if action not in ('start', 'stop', 'reset'):

View File

@ -3,6 +3,7 @@ import uuid
from collections import OrderedDict
from operator import attrgetter
from textwrap import dedent
from typing import Any, Dict, List # noqa
from django.apps import apps
from django.conf import settings
@ -11,6 +12,7 @@ from django.db.models import F
from django.http import Http404, HttpResponse
from django.utils.timezone import now
from django.utils.translation import ugettext as _
from mypy_extensions import TypedDict
from .. import __version__ as version
from ..utils import views as utils_views
@ -105,7 +107,7 @@ class WebclientJavaScriptView(utils_views.View):
"""
def get(self, *args, **kwargs):
angular_modules = []
js_files = []
js_files = [] # type: List[str]
realm = kwargs.get('realm') # Result is 'site' or 'projector'
for app_config in apps.get_app_configs():
# Add the angular app if the module has one.
@ -582,7 +584,7 @@ class ConfigMetadata(SimpleMetadata):
"""
def determine_metadata(self, request, view):
# Build tree.
config_groups = []
config_groups = [] # type: List[Any] # TODO: Replace Any by correct type
for config_variable in sorted(config.config_variables.values(), key=attrgetter('weight')):
if config_variable.is_hidden():
# Skip hidden config variables. Do not even check groups and subgroups.
@ -787,7 +789,8 @@ class VersionView(utils_views.APIView):
http_method_names = ['get']
def get_context_data(self, **context):
result = dict(openslides_version=version, plugins=[])
Result = TypedDict('Result', {'openslides_version': str, 'plugins': List[Dict[str, str]]}) # noqa
result = dict(openslides_version=version, plugins=[]) # type: Result
# Versions of plugins.
for plugin in settings.INSTALLED_PLUGINS:
result['plugins'].append({

View File

@ -1,4 +1,7 @@
from ..utils.access_permissions import BaseAccessPermissions
from ..utils.access_permissions import ( # noqa
BaseAccessPermissions,
RestrictedData,
)
from ..utils.auth import has_perm
from ..utils.collection import Collection
@ -41,7 +44,7 @@ class MediafileAccessPermissions(BaseAccessPermissions):
# Reduce result to a single item or None if it was not a collection at
# the beginning of the method.
if isinstance(container, Collection):
restricted_data = data
restricted_data = data # type: RestrictedData
elif data:
restricted_data = data[0]
else:

View File

@ -73,7 +73,7 @@ class Mediafile(RESTModelMixin, models.Model):
skip_autoupdate=skip_autoupdate,
name='mediafiles/mediafile',
id=self.pk)
return super().delete(skip_autoupdate=skip_autoupdate, *args, **kwargs)
return super().delete(skip_autoupdate=skip_autoupdate, *args, **kwargs) # type: ignore
def get_filesize(self):
"""

View File

@ -1,7 +1,10 @@
from copy import deepcopy
from ..core.config import config
from ..utils.access_permissions import BaseAccessPermissions
from ..utils.access_permissions import ( # noqa
BaseAccessPermissions,
RestrictedData,
)
from ..utils.auth import has_perm
from ..utils.collection import Collection, CollectionElement
@ -78,7 +81,7 @@ class MotionAccessPermissions(BaseAccessPermissions):
# Reduce result to a single item or None if it was not a collection at
# the beginning of the method.
if isinstance(container, Collection):
restricted_data = data
restricted_data = data # type: RestrictedData
elif data:
restricted_data = data[0]
else:
@ -114,7 +117,7 @@ class MotionAccessPermissions(BaseAccessPermissions):
# Reduce result to a single item or None if it was not a collection at
# the beginning of the method.
if isinstance(container, Collection):
projector_data = data
projector_data = data # type: RestrictedData
elif data:
projector_data = data[0]
else:

View File

@ -231,7 +231,7 @@ class Motion(RESTModelMixin, models.Model):
try:
# Always skip autoupdate. Maybe we run it later in this method.
with transaction.atomic():
super(Motion, self).save(skip_autoupdate=True, *args, **kwargs)
super(Motion, self).save(skip_autoupdate=True, *args, **kwargs) # type: ignore
except IntegrityError:
# Identifier is already used.
if hasattr(self, '_identifier_prefix'):
@ -309,7 +309,7 @@ class Motion(RESTModelMixin, models.Model):
skip_autoupdate=skip_autoupdate,
name='motions/motion',
id=self.pk)
return super().delete(skip_autoupdate=skip_autoupdate, *args, **kwargs)
return super().delete(skip_autoupdate=skip_autoupdate, *args, **kwargs) # type: ignore
def version_data_changed(self, version):
"""
@ -879,7 +879,7 @@ class MotionBlock(RESTModelMixin, models.Model):
skip_autoupdate=skip_autoupdate,
name='motions/motion-block',
id=self.pk)
return super().delete(skip_autoupdate=skip_autoupdate, *args, **kwargs)
return super().delete(skip_autoupdate=skip_autoupdate, *args, **kwargs) # type: ignore
@property
def agenda_item(self):

View File

@ -1,3 +1,5 @@
from typing import Dict # noqa
from django.db import transaction
from django.utils.translation import ugettext as _
@ -157,7 +159,7 @@ class MotionPollSerializer(ModelSerializer):
def __init__(self, *args, **kwargs):
# The following dictionary is just a cache for several votes.
self._votes_dicts = {}
self._votes_dicts = {} # type: Dict[int, Dict[int, int]]
return super().__init__(*args, **kwargs)
def get_yes(self, obj):

View File

@ -1,3 +1,5 @@
from typing import Set # noqa
from django.apps import apps
from django.utils.translation import ugettext_noop
@ -124,7 +126,7 @@ def required_users(sender, request_user, **kwargs):
any motion if request_user can see motions. This function may return an
empty set.
"""
submitters_supporters = set()
submitters_supporters = set() # type: Set[int]
if has_perm(request_user, 'motions.can_see'):
for motion_collection_element in Collection(Motion.get_collection_string()).element_generator():
full_data = motion_collection_element.get_full_data()

View File

@ -1,5 +1,6 @@
import base64
import re
from typing import Optional # noqa
from django.conf import settings
from django.contrib.staticfiles import finders
@ -92,7 +93,7 @@ class MotionViewSet(ModelViewSet):
try:
parent_motion = CollectionElement.from_values(
Motion.get_collection_string(),
request.data['parent_id'])
request.data['parent_id']) # type: Optional[CollectionElement]
except Motion.DoesNotExist:
raise ValidationError({'detail': _('The parent motion does not exist.')})
else:

View File

@ -52,7 +52,7 @@ class Topic(RESTModelMixin, models.Model):
skip_autoupdate=skip_autoupdate,
name='topics/topic',
id=self.pk)
return super().delete(skip_autoupdate=skip_autoupdate, *args, **kwargs)
return super().delete(skip_autoupdate=skip_autoupdate, *args, **kwargs) # type: ignore
@property
def agenda_item(self):

View File

@ -1,7 +1,12 @@
from typing import Any, Dict, List # noqa
from django.contrib.auth.models import AnonymousUser
from ..core.signals import user_data_required
from ..utils.access_permissions import BaseAccessPermissions
from ..utils.access_permissions import ( # noqa
BaseAccessPermissions,
RestrictedData,
)
from ..utils.auth import anonymous_is_enabled, has_perm
from ..utils.collection import Collection
@ -94,7 +99,7 @@ class UserAccessPermissions(BaseAccessPermissions):
# Reduce result to a single item or None if it was not a collection at
# the beginning of the method.
if isinstance(container, Collection):
restricted_data = data
restricted_data = data # type: RestrictedData
elif data:
restricted_data = data[0]
else:
@ -127,7 +132,7 @@ class UserAccessPermissions(BaseAccessPermissions):
# Reduce result to a single item or None if it was not a collection at
# the beginning of the method.
if isinstance(container, Collection):
projector_data = data
projector_data = data # type: RestrictedData
elif data:
projector_data = data[0]
else:
@ -187,7 +192,7 @@ class PersonalNoteAccessPermissions(BaseAccessPermissions):
# Parse data.
if user is None:
data = []
data = [] # type: List[Dict[str, Any]]
else:
for full in full_data:
if full['user_id'] == user.id:
@ -199,7 +204,7 @@ class PersonalNoteAccessPermissions(BaseAccessPermissions):
# Reduce result to a single item or None if it was not a collection at
# the beginning of the method.
if isinstance(container, Collection):
restricted_data = data
restricted_data = data # type: RestrictedData
elif data:
restricted_data = data[0]
else:

View File

@ -220,7 +220,7 @@ class User(RESTModelMixin, PermissionsMixin, AbstractBaseUser):
skip_autoupdate=skip_autoupdate,
name='users/user',
id=self.pk)
return super().delete(skip_autoupdate=skip_autoupdate, *args, **kwargs)
return super().delete(skip_autoupdate=skip_autoupdate, *args, **kwargs) # type: ignore
def has_perm(self, perm):
"""

View File

@ -1,46 +1,29 @@
from django.dispatch import Signal
from typing import Any, Dict, List, Optional, Union
from .collection import Collection
from .dispatch import SignalConnectMetaClass
from django.db.models import Model
from rest_framework.serializers import Serializer
from .collection import Collection, CollectionElement
Container = Union[CollectionElement, Collection]
RestrictedData = Union[List[Dict[str, Any]], Dict[str, Any], None]
class BaseAccessPermissions(object, metaclass=SignalConnectMetaClass):
class BaseAccessPermissions:
"""
Base access permissions container.
Every app which has autoupdate models has to create classes subclassing
from this base class for every autoupdate root model. Each subclass has
to have a globally unique name. The metaclass (SignalConnectMetaClass)
does the rest of the magic.
from this base class for every autoupdate root model.
"""
signal = Signal()
def __init__(self, **kwargs):
"""
Initializes the access permission instance. This is done when the
signal is sent.
Because of Django's signal API, we have to take wildcard keyword
arguments. But they are not used here.
"""
pass
@classmethod
def get_dispatch_uid(cls):
"""
Returns the classname as a unique string for each class. Returns None
for the base class so it will not be connected to the signal.
"""
if not cls.__name__ == 'BaseAccessPermissions':
return cls.__name__
def check_permissions(self, user):
def check_permissions(self, user: Optional[CollectionElement]) -> bool:
"""
Returns True if the user has read access to model instances.
"""
return False
def get_serializer_class(self, user=None):
def get_serializer_class(self, user: CollectionElement=None) -> Serializer:
"""
Returns different serializer classes according to users permissions.
@ -51,13 +34,13 @@ class BaseAccessPermissions(object, metaclass=SignalConnectMetaClass):
"You have to add the method 'get_serializer_class' to your "
"access permissions class.".format(self))
def get_full_data(self, instance):
def get_full_data(self, instance: Model) -> Dict[str, Any]:
"""
Returns all possible serialized data for the given instance.
"""
return self.get_serializer_class(user=None)(instance).data
def get_restricted_data(self, container, user):
def get_restricted_data(self, container: Container, user: Optional[CollectionElement]) -> RestrictedData:
"""
Returns the restricted serialized data for the instance prepared
for the user.
@ -82,7 +65,7 @@ class BaseAccessPermissions(object, metaclass=SignalConnectMetaClass):
data = None
return data
def get_projector_data(self, container):
def get_projector_data(self, container: Container) -> RestrictedData:
"""
Returns the serialized data for the projector. Returns None if the
user has no access to this specific data. Returns reduced data if

View File

@ -16,6 +16,7 @@ def has_perm(user: Optional[CollectionElement], perm: str) -> bool:
group_collection_string = 'users/group' # This is the hard coded collection string for openslides.users.models.Group
# Convert user to right type
# TODO: Remove this and make use, that user has always the right type
user = user_to_collection_user(user)
if user is None and not anonymous_is_enabled():
has_perm = False

View File

@ -2,12 +2,14 @@ import json
import time
import warnings
from collections import Iterable, defaultdict
from typing import Any, Dict, Iterable, List, cast # noqa
from channels import Channel, Group
from channels.asgi import get_channel_layer
from channels.auth import channel_session_user, channel_session_user_from_http
from django.core.exceptions import ObjectDoesNotExist
from django.db import transaction
from django.db.models import Model
from ..core.config import config
from ..core.models import Projector
@ -16,7 +18,7 @@ from .cache import startup_cache, websocket_user_cache
from .collection import Collection, CollectionElement, CollectionElementList
def send_or_wait(send_func, *args, **kwargs):
def send_or_wait(send_func: Any, *args: Any, **kwargs: Any) -> None:
"""
Wrapper for channels' send() method.
@ -41,7 +43,7 @@ def send_or_wait(send_func, *args, **kwargs):
)
def format_for_autoupdate(collection_string, id, action, data=None):
def format_for_autoupdate(collection_string: str, id: int, action: str, data: Dict[str, Any]=None) -> Dict[str, Any]:
"""
Returns a dict that can be used for autoupdate.
"""
@ -64,7 +66,7 @@ def format_for_autoupdate(collection_string, id, action, data=None):
@channel_session_user_from_http
def ws_add_site(message):
def ws_add_site(message: Any) -> None:
"""
Adds the websocket connection to a group specific to the connecting user.
@ -92,6 +94,9 @@ def ws_add_site(message):
access_permissions = collection.get_access_permissions()
restricted_data = access_permissions.get_restricted_data(collection, user)
# At this point restricted_data has to be a list. So we have to tell it mypy
restricted_data = cast(List[Dict[str, Any]], restricted_data)
for data in restricted_data:
if data is None:
# We do not want to send 'deleted' objects on startup.
@ -100,7 +105,7 @@ def ws_add_site(message):
output.append(
format_for_autoupdate(
collection_string=collection.collection_string,
id=data['id'],
id=int(data['id']),
action='changed',
data=data))
@ -110,7 +115,7 @@ def ws_add_site(message):
@channel_session_user
def ws_disconnect_site(message):
def ws_disconnect_site(message: Any) -> None:
"""
This function is called, when a client on the site disconnects.
"""
@ -119,7 +124,7 @@ def ws_disconnect_site(message):
@channel_session_user
def ws_receive_site(message):
def ws_receive_site(message: Any) -> None:
"""
This function is called if a message from a client comes in. The message
should be a list. Every item is broadcasted to the given users (or all
@ -137,8 +142,8 @@ def ws_receive_site(message):
else:
if isinstance(incomming, list):
# Parse all items
receivers_users = defaultdict(list)
receivers_reply_channels = defaultdict(list)
receivers_users = defaultdict(list) # type: Dict[int, List[Any]]
receivers_reply_channels = defaultdict(list) # type: Dict[str, List[Any]]
items_for_all = []
for item in incomming:
if item.get('collection') == 'notify':
@ -184,12 +189,12 @@ def ws_receive_site(message):
@channel_session_user_from_http
def ws_add_projector(message, projector_id):
def ws_add_projector(message: Any, projector_id: int) -> None:
"""
Adds the websocket connection to a group specific to the projector with the given id.
Also sends all data that are shown on the projector.
"""
user = message.user.id
user = user_to_collection_user(message.user.id)
if not has_perm(user, 'core.can_see_projector'):
send_or_wait(message.reply_channel.send, {'text': 'No permissions to see this projector.'})
@ -230,14 +235,14 @@ def ws_add_projector(message, projector_id):
send_or_wait(message.reply_channel.send, {'text': json.dumps(output)})
def ws_disconnect_projector(message, projector_id):
def ws_disconnect_projector(message: Any, projector_id: int) -> None:
"""
This function is called, when a client on the projector disconnects.
"""
Group('projector-{}'.format(projector_id)).discard(message.reply_channel)
def send_data(message):
def send_data(message: Any) -> None:
"""
Informs all site users and projector clients about changed data.
"""
@ -285,7 +290,7 @@ def send_data(message):
{'text': json.dumps(output)})
def inform_changed_data(instances, information=None):
def inform_changed_data(instances: Iterable[Model], information: Dict[str, Any]=None) -> None:
"""
Informs the autoupdate system and the caching system about the creation or
update of an element.
@ -317,7 +322,8 @@ def inform_changed_data(instances, information=None):
transaction.on_commit(lambda: send_autoupdate(collection_elements))
def inform_deleted_data(*args, information=None):
# TODO: Change the input argument to tuples
def inform_deleted_data(*args: Any, information: Dict[str, Any]=None) -> None:
"""
Informs the autoupdate system and the caching system about the deletion of
elements.
@ -351,7 +357,8 @@ def inform_deleted_data(*args, information=None):
transaction.on_commit(lambda: send_autoupdate(collection_elements))
def inform_data_collection_element_list(collection_elements, information=None):
def inform_data_collection_element_list(collection_elements: CollectionElementList,
information: Dict[str, Any]=None) -> None:
"""
Informs the autoupdate system about some collection elements. This is
used just to send some data to all users.
@ -363,7 +370,7 @@ def inform_data_collection_element_list(collection_elements, information=None):
transaction.on_commit(lambda: send_autoupdate(collection_elements))
def send_autoupdate(collection_elements):
def send_autoupdate(collection_elements: CollectionElementList) -> None:
"""
Helper function, that sends collection_elements through a channel to the
autoupdate system.

View File

@ -1,10 +1,27 @@
from collections import defaultdict
from typing import ( # noqa
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Set,
)
from channels import Group
from channels.sessions import session_for_reply_channel
from django.apps import apps
from django.core.cache import cache, caches
if TYPE_CHECKING:
# Dummy import Collection for mypy
from .collection import Collection # noqa
UserCacheDataType = Dict[int, Set[str]]
class BaseWebsocketUserCache:
"""
@ -15,36 +32,36 @@ class BaseWebsocketUserCache:
"""
cache_key = 'current_websocket_users'
def add(self, user_id, channel_name):
def add(self, user_id: int, channel_name: str) -> None:
"""
Adds a channel name to an user id.
"""
raise NotImplementedError()
def remove(self, user_id, channel_name):
def remove(self, user_id: int, channel_name: str) -> None:
"""
Removes one channel name from the cache.
"""
raise NotImplementedError()
def get_all(self):
def get_all(self) -> UserCacheDataType:
"""
Returns all data using a dict where the key is a user id and the value
is a set of channel_names.
"""
raise NotImplementedError()
def save_data(self, data):
def save_data(self, data: UserCacheDataType) -> None:
"""
Saves the full data set (like created with build_data) to the cache.
"""
raise NotImplementedError()
def build_data(self):
def build_data(self) -> UserCacheDataType:
"""
Creates all the data, saves it to the cache and returns it.
"""
websocket_user_ids = defaultdict(set)
websocket_user_ids = defaultdict(set) # type: UserCacheDataType
for channel_name in Group('site').channel_layer.group_channels('site'):
session = session_for_reply_channel(channel_name)
user_id = session.get('user_id', None)
@ -52,7 +69,7 @@ class BaseWebsocketUserCache:
self.save_data(websocket_user_ids)
return websocket_user_ids
def get_cache_key(self):
def get_cache_key(self) -> str:
"""
Returns the cache key.
"""
@ -67,7 +84,7 @@ class RedisWebsocketUserCache(BaseWebsocketUserCache):
for each user another set to save the channel names.
"""
def add(self, user_id, channel_name):
def add(self, user_id: int, channel_name: str) -> None:
"""
Adds a channel name to an user id.
"""
@ -77,35 +94,35 @@ class RedisWebsocketUserCache(BaseWebsocketUserCache):
pipe.sadd(self.get_user_cache_key(user_id), channel_name)
pipe.execute()
def remove(self, user_id, channel_name):
def remove(self, user_id: int, channel_name: str) -> None:
"""
Removes one channel name from the cache.
"""
redis = get_redis_connection()
redis.srem(self.get_user_cache_key(user_id), channel_name)
def get_all(self):
def get_all(self) -> UserCacheDataType:
"""
Returns all data using a dict where the key is a user id and the value
is a set of channel_names.
"""
redis = get_redis_connection()
user_ids = redis.smembers(self.get_cache_key())
user_ids = redis.smembers(self.get_cache_key()) # type: Optional[List[str]]
if user_ids is None:
websocket_user_ids = self.build_data()
else:
websocket_user_ids = dict()
for user_id in user_ids:
for redis_user_id in user_ids:
# Redis returns the id as string. So we have to convert it
user_id = int(user_id)
channel_names = redis.smembers(self.get_user_cache_key(user_id))
user_id = int(redis_user_id)
channel_names = redis.smembers(self.get_user_cache_key(user_id)) # type: Optional[List[str]]
if channel_names is not None:
# If channel name is empty, then we can assume, that the user
# has no active connection.
websocket_user_ids[user_id] = set(channel_names)
return websocket_user_ids
def save_data(self, data):
def save_data(self, data: UserCacheDataType) -> None:
"""
Saves the full data set (like created with the method build_data()) to
the cache.
@ -122,13 +139,13 @@ class RedisWebsocketUserCache(BaseWebsocketUserCache):
pipe.sadd(self.get_user_cache_key(user_id), *channel_names)
pipe.execute()
def get_cache_key(self):
def get_cache_key(self) -> str:
"""
Returns the cache key.
"""
return cache.make_key(self.cache_key)
def get_user_cache_key(self, user_id):
def get_user_cache_key(self, user_id: int) -> str:
"""
Returns a cache key to save the channel names for a specific user.
"""
@ -146,7 +163,7 @@ class DjangoCacheWebsocketUserCache(BaseWebsocketUserCache):
the value is a set of channel names.
"""
def add(self, user_id, channel_name):
def add(self, user_id: int, channel_name: str) -> None:
"""
Adds a channel name for a user using the django cache.
"""
@ -160,7 +177,7 @@ class DjangoCacheWebsocketUserCache(BaseWebsocketUserCache):
websocket_user_ids[user_id] = set([channel_name])
cache.set(self.get_cache_key(), websocket_user_ids)
def remove(self, user_id, channel_name):
def remove(self, user_id: int, channel_name: str) -> None:
"""
Removes one channel name from the django cache.
"""
@ -169,7 +186,7 @@ class DjangoCacheWebsocketUserCache(BaseWebsocketUserCache):
websocket_user_ids[user_id].discard(channel_name)
cache.set(self.get_cache_key(), websocket_user_ids)
def get_all(self):
def get_all(self) -> UserCacheDataType:
"""
Returns the data using the django cache.
"""
@ -178,7 +195,7 @@ class DjangoCacheWebsocketUserCache(BaseWebsocketUserCache):
return self.build_data()
return websocket_user_ids
def save_data(self, data):
def save_data(self, data: UserCacheDataType) -> None:
"""
Saves the data using the django cache.
"""
@ -191,18 +208,18 @@ class StartupCache:
"""
cache_key = "full_data_startup_cache"
def build(self):
def build(self) -> Dict[str, List[str]]:
"""
Generate the cache by going through all apps. Returns a dict where the
key is the collection string and the value a list of the full_data from
the collection elements.
"""
cache_data = {}
cache_data = {} # type: Dict[str, List[str]]
for app in apps.get_app_configs():
try:
# Get the method get_startup_elements() from an app.
# This method has to return an iterable of Collection objects.
get_startup_elements = app.get_startup_elements
get_startup_elements = app.get_startup_elements # type: Callable[[], Iterable[Collection]]
except AttributeError:
# Skip apps that do not implement get_startup_elements.
continue
@ -216,20 +233,20 @@ class StartupCache:
cache.set(self.cache_key, cache_data, 86400)
return cache_data
def clear(self):
def clear(self) -> None:
"""
Clears the cache.
"""
cache.delete(self.cache_key)
def get_collections(self):
def get_collections(self) -> Generator['Collection', None, None]:
"""
Generator that returns all cached Collections.
The data is read from the cache if it exists. It builds the cache if it
does not exists.
"""
from .collection import Collection
from .collection import Collection # noqa
data = cache.get(self.cache_key)
if data is None:
# The cache does not exist.
@ -241,7 +258,7 @@ class StartupCache:
startup_cache = StartupCache()
def use_redis_cache():
def use_redis_cache() -> bool:
"""
Returns True if Redis is used als caching backend.
"""
@ -252,7 +269,7 @@ def use_redis_cache():
return isinstance(caches['default'], RedisCache)
def get_redis_connection():
def get_redis_connection() -> Any:
"""
Returns an object that can be used to talk directly to redis.
"""

View File

@ -1,38 +1,33 @@
from typing import Mapping # noqa
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generator,
List,
Optional,
Set,
Tuple,
Type,
Union,
)
from django.apps import apps
from django.core.cache import cache
from django.db.models import Model
from .cache import get_redis_connection, use_redis_cache
if TYPE_CHECKING:
from .access_permissions import BaseAccessPermissions # noqa
# TODO: Try to import this type from access_permission
RestrictedData = Union[List[Dict[str, Any]], Dict[str, Any], None]
class CollectionElement:
@classmethod
def from_instance(cls, instance, deleted=False, information=None):
"""
Returns a collection element from a database instance.
This will also update the instance in the cache.
If deleted is set to True, the element is deleted from the cache.
"""
return cls(instance=instance, deleted=deleted, information=information)
@classmethod
def from_values(cls, collection_string, id, deleted=False, full_data=None, information=None):
"""
Returns a collection element from a collection_string and an id.
If deleted is set to True, the element is deleted from the cache.
With the argument full_data, the content of the CollectionElement can be set.
It has to be a dict in the format that is used be access_permission.get_full_data().
"""
return cls(collection_string=collection_string, id=id, deleted=deleted,
full_data=full_data, information=information)
def __init__(self, instance=None, deleted=False, collection_string=None, id=None,
full_data=None, information=None):
def __init__(self, instance: Model=None, deleted: bool=False, collection_string: str=None,
id: int=None, full_data: Dict[str, Any]=None, information: Dict[str, Any]=None) -> None:
"""
Do not use this. Use the methods from_instance() or from_values().
"""
@ -47,7 +42,7 @@ class CollectionElement:
elif collection_string is not None and id is not None:
# Collection element is created via values
self.collection_string = collection_string
self.id = int(id)
self.id = id
else:
raise RuntimeError(
'Invalid state. Use CollectionElement.from_instance() or '
@ -65,7 +60,32 @@ class CollectionElement:
# neither exist in the cache nor in the database.
self.get_full_data()
def __eq__(self, collection_element):
@classmethod
def from_instance(cls, instance: Model, deleted: bool=False, information: Dict[str, Any]=None) -> 'CollectionElement':
"""
Returns a collection element from a database instance.
This will also update the instance in the cache.
If deleted is set to True, the element is deleted from the cache.
"""
return cls(instance=instance, deleted=deleted, information=information)
@classmethod
def from_values(cls, collection_string: str, id: int, deleted: bool=False,
full_data: Dict[str, Any]=None, information: Dict[str, Any]=None) -> 'CollectionElement':
"""
Returns a collection element from a collection_string and an id.
If deleted is set to True, the element is deleted from the cache.
With the argument full_data, the content of the CollectionElement can be set.
It has to be a dict in the format that is used be access_permission.get_full_data().
"""
return cls(collection_string=collection_string, id=id, deleted=deleted,
full_data=full_data, information=information)
def __eq__(self, collection_element: 'CollectionElement') -> bool: # type: ignore
"""
Compares two collection_elements.
@ -75,7 +95,7 @@ class CollectionElement:
return (self.collection_string == collection_element.collection_string and
self.id == collection_element.id)
def as_channels_message(self):
def as_channels_message(self) -> Dict[str, Any]:
"""
Returns a dictonary that can be used to send the object through the
channels system.
@ -92,7 +112,7 @@ class CollectionElement:
channel_message['full_data'] = self.full_data
return channel_message
def as_autoupdate(self, method, *args):
def as_autoupdate(self, method: str, *args: Any) -> Dict[str, Any]:
"""
Only for internal use. Do not use it directly. Use as_autoupdate_for_user()
or as_autoupdate_for_projector().
@ -112,19 +132,16 @@ class CollectionElement:
action='deleted' if self.is_deleted() else 'changed',
data=data)
def as_autoupdate_for_user(self, user):
def as_autoupdate_for_user(self, user: Optional['CollectionElement']) -> Dict[str, Any]:
"""
Returns a dict that can be sent through the autoupdate system for a site
user.
The argument `user` can be anything, that is allowd as argument for
utils.auth.has_perm().
"""
return self.as_autoupdate(
'get_restricted_data',
user)
def as_autoupdate_for_projector(self):
def as_autoupdate_for_projector(self) -> Dict[str, Any]:
"""
Returns a dict that can be sent through the autoupdate system for the
projector.
@ -132,22 +149,19 @@ class CollectionElement:
return self.as_autoupdate(
'get_projector_data')
def as_dict_for_user(self, user):
def as_dict_for_user(self, user: Optional['CollectionElement']) -> 'RestrictedData':
"""
Returns a dict with the data for a user. Can be used for the rest api.
The argument `user` can be anything, that is allowd as argument for
utils.auth.has_perm().
"""
return self.get_access_permissions().get_restricted_data(self, user)
def get_model(self):
def get_model(self) -> Type[Model]:
"""
Returns the django model that is used for this collection.
"""
return get_model_from_collection_string(self.collection_string)
def get_instance(self):
def get_instance(self) -> Model:
"""
Returns the instance as django object.
@ -165,13 +179,13 @@ class CollectionElement:
self.instance = query.get(pk=self.id)
return self.instance
def get_access_permissions(self):
def get_access_permissions(self) -> 'BaseAccessPermissions':
"""
Returns the get_access_permissions object for the this collection element.
"""
return self.get_model().get_access_permissions()
def get_full_data(self):
def get_full_data(self) -> Any:
"""
Returns the full_data of this collection_element from with all other
dics can be generated.
@ -194,21 +208,21 @@ class CollectionElement:
self.save_to_cache()
return self.full_data
def is_deleted(self):
def is_deleted(self) -> bool:
"""
Returns Ture if the item is marked as deleted.
"""
return self.deleted
def get_cache_key(self):
def get_cache_key(self) -> str:
"""
Returns a string that is used as cache key for a single instance.
"""
return get_single_element_cache_key(self.collection_string, self.id)
def delete_from_cache(self):
def delete_from_cache(self) -> None:
"""
Delets an element from the cache.
Delets the element from the cache.
Does nothing if the element is not in the cache.
"""
@ -218,7 +232,7 @@ class CollectionElement:
# Delete the id of the instance of the instance list
Collection(self.collection_string).delete_id_from_cache(self.id)
def save_to_cache(self):
def save_to_cache(self) -> None:
"""
Add or update the element to the cache.
"""
@ -238,7 +252,7 @@ class CollectionElementList(list):
"""
@classmethod
def from_channels_message(cls, message):
def from_channels_message(cls, message: Dict[str, Any]) -> 'CollectionElementList':
"""
Creates a collection element list from a channel message.
"""
@ -247,16 +261,16 @@ class CollectionElementList(list):
self.append(CollectionElement.from_values(**values))
return self
def as_channels_message(self):
def as_channels_message(self) -> Dict[str, Any]:
"""
Returns a list of dicts that can be send through the channel system.
"""
message = {'elements': []}
message = {'elements': []} # type: Dict[str, Any]
for element in self:
message['elements'].append(element.as_channels_message())
return message
def as_autoupdate_for_user(self, user):
def as_autoupdate_for_user(self, user: Optional[CollectionElement]) -> List[Dict[str, Any]]:
"""
Returns a list of dicts, that can be send though the websocket to a user.
@ -274,7 +288,7 @@ class Collection:
Represents all elements of one collection.
"""
def __init__(self, collection_string, full_data=None):
def __init__(self, collection_string: str, full_data: List[Dict[str, Any]]=None) -> None:
"""
Initiates a Collection. A collection_string has to be given. If
full_data (a list of dictionaries) is not given the method
@ -284,7 +298,7 @@ class Collection:
self.collection_string = collection_string
self.full_data = full_data
def get_cache_key(self, raw=False):
def get_cache_key(self, raw: bool=False) -> str:
"""
Returns a string that is used as cache key for a collection.
@ -296,19 +310,19 @@ class Collection:
key = cache.make_key(key)
return key
def get_model(self):
def get_model(self) -> Type[Model]:
"""
Returns the django model that is used for this collection.
"""
return get_model_from_collection_string(self.collection_string)
def get_access_permissions(self):
def get_access_permissions(self) -> 'BaseAccessPermissions':
"""
Returns the get_access_permissions object for the this collection.
"""
return self.get_model().get_access_permissions()
def element_generator(self):
def element_generator(self) -> Generator[CollectionElement, None, None]:
"""
Generator that yields all collection_elements of this collection.
"""
@ -329,8 +343,10 @@ class Collection:
# Generate collection elements that where in the cache.
for cache_key, cached_full_data in cached_full_data_dict.items():
collection_string, id = get_collection_id_from_cache_key(cache_key)
yield CollectionElement.from_values(
*get_collection_id_from_cache_key(cache_key),
collection_string,
id,
full_data=cached_full_data)
# Generate collection element that where not in the cache.
@ -343,7 +359,7 @@ class Collection:
for instance in query.filter(pk__in=missing_ids):
yield CollectionElement.from_instance(instance)
def get_full_data(self):
def get_full_data(self) -> List[Dict[str, Any]]:
"""
Returns a list of dictionaries with full_data of all collection
elements.
@ -355,7 +371,7 @@ class Collection:
in self.element_generator()]
return self.full_data
def as_autoupdate_for_projector(self):
def as_autoupdate_for_projector(self) -> List[Dict[str, Any]]:
"""
Returns a list of dictonaries to send them to the projector.
"""
@ -368,12 +384,9 @@ class Collection:
output.append(content)
return output
def as_autoupdate_for_user(self, user):
def as_autoupdate_for_user(self, user: Optional[CollectionElement]) -> List[Dict[str, Any]]:
"""
Returns a list of dicts, that can be send though the websocket to a user.
The argument `user` can be anything, that is allowd as argument for
utils.auth.has_perm().
"""
# TODO: This method is not used. Remove it.
output = []
@ -383,22 +396,19 @@ class Collection:
output.append(content)
return output
def as_list_for_user(self, user):
def as_list_for_user(self, user: Optional[CollectionElement]) -> List['RestrictedData']:
"""
Returns a list of dictonaries to send them to a user, for example over
the rest api.
The argument `user` can be anything, that is allowd as argument for
utils.auth.has_perm().
"""
output = []
output = [] # type: List[RestrictedData]
for collection_element in self.element_generator():
content = collection_element.as_dict_for_user(user)
content = collection_element.as_dict_for_user(user) # type: RestrictedData
if content is not None:
output.append(content)
return output
def get_all_ids(self):
def get_all_ids(self) -> Set[int]:
"""
Returns a set of all ids of instances in this collection.
"""
@ -408,7 +418,7 @@ class Collection:
ids = self.get_all_ids_other()
return ids
def get_all_ids_redis(self):
def get_all_ids_redis(self) -> Set[int]:
redis = get_redis_connection()
ids = redis.smembers(self.get_cache_key(raw=True))
if not ids:
@ -419,7 +429,7 @@ class Collection:
ids = set(int(id) for id in ids)
return ids
def get_all_ids_other(self):
def get_all_ids_other(self) -> Set[int]:
ids = cache.get(self.get_cache_key())
if ids is None:
# If it is not in the cache then get it from the database.
@ -427,7 +437,7 @@ class Collection:
cache.set(self.get_cache_key(), ids)
return ids
def delete_id_from_cache(self, id):
def delete_id_from_cache(self, id: int) -> None:
"""
Delets a id from the cache.
"""
@ -436,11 +446,11 @@ class Collection:
else:
self.delete_id_from_cache_other(id)
def delete_id_from_cache_redis(self, id):
def delete_id_from_cache_redis(self, id: int) -> None:
redis = get_redis_connection()
redis.srem(self.get_cache_key(raw=True), id)
def delete_id_from_cache_other(self, id):
def delete_id_from_cache_other(self, id: int) -> None:
ids = cache.get(self.get_cache_key())
if ids is not None:
ids = set(ids)
@ -456,7 +466,7 @@ class Collection:
# Delete the key, if there are not ids left
cache.delete(self.get_cache_key())
def add_id_to_cache(self, id):
def add_id_to_cache(self, id: int) -> None:
"""
Adds a collection id to the list of collection ids in the cache.
"""
@ -465,13 +475,13 @@ class Collection:
else:
self.add_id_to_cache_other(id)
def add_id_to_cache_redis(self, id):
def add_id_to_cache_redis(self, id: int) -> None:
redis = get_redis_connection()
if redis.exists(self.get_cache_key(raw=True)):
# Only add the value if it is in the cache.
redis.sadd(self.get_cache_key(raw=True), id)
def add_id_to_cache_other(self, id):
def add_id_to_cache_other(self, id: int) -> None:
ids = cache.get(self.get_cache_key())
if ids is not None:
# Only change the value if it is in the cache.
@ -480,14 +490,14 @@ class Collection:
cache.set(self.get_cache_key(), ids)
_models_to_collection_string = {} # type: Mapping[str, object]
_models_to_collection_string = {} # type: Dict[str, Type[Model]]
def get_model_from_collection_string(collection_string):
def get_model_from_collection_string(collection_string: str) -> Type[Model]:
"""
Returns a model class which belongs to the argument collection_string.
"""
def model_generator():
def model_generator() -> Generator[Type[Model], None, None]:
"""
Yields all models of all apps.
"""
@ -512,7 +522,7 @@ def get_model_from_collection_string(collection_string):
return model
def get_single_element_cache_key(collection_string, id):
def get_single_element_cache_key(collection_string: str, id: int) -> str:
"""
Returns a string that is used as cache key for a single instance.
"""
@ -521,7 +531,7 @@ def get_single_element_cache_key(collection_string, id):
id=id)
def get_single_element_cache_key_prefix(collection_string):
def get_single_element_cache_key_prefix(collection_string: str) -> str:
"""
Returns the first part of the cache key for single elements, which is the
same for all cache keys of the same collection.
@ -529,14 +539,14 @@ def get_single_element_cache_key_prefix(collection_string):
return "{collection}:".format(collection=collection_string)
def get_element_list_cache_key(collection_string):
def get_element_list_cache_key(collection_string: str) -> str:
"""
Returns a string that is used as cache key for a collection.
"""
return "{collection}".format(collection=collection_string)
def get_collection_id_from_cache_key(cache_key):
def get_collection_id_from_cache_key(cache_key: str) -> Tuple[str, int]:
"""
Returns a tuble of the collection string and the id from a cache_key
created with get_instance_cache_key.

View File

@ -49,7 +49,7 @@ class SignalConnectMetaClass(type):
default attributes and methods.
"""
class_attributes['get_all'] = get_all
new_class = super(SignalConnectMetaClass, metaclass).__new__(
new_class = super().__new__(
metaclass, class_name, class_parents, class_attributes)
try:
dispatch_uid = new_class.get_dispatch_uid()

View File

@ -6,10 +6,12 @@ import tempfile
import threading
import time
import webbrowser
from typing import Dict, Optional
from django.conf import ENVIRONMENT_VARIABLE
from django.core.exceptions import ImproperlyConfigured
from django.utils.crypto import get_random_string
from mypy_extensions import NoReturn
DEVELOPMENT_VERSION = 'Development Version'
UNIX_VERSION = 'Unix Version'
@ -34,11 +36,11 @@ class UnknownCommand(Exception):
class ExceptionArgumentParser(argparse.ArgumentParser):
def error(self, message):
def error(self, message: str) -> NoReturn:
raise UnknownCommand(message)
def detect_openslides_type():
def detect_openslides_type() -> str:
"""
Returns the type of this OpenSlides version.
"""
@ -58,7 +60,7 @@ def detect_openslides_type():
return openslides_type
def get_default_settings_path(openslides_type=None):
def get_default_settings_path(openslides_type: str=None) -> str:
"""
Returns the default settings path according to the OpenSlides type.
@ -80,7 +82,7 @@ def get_default_settings_path(openslides_type=None):
return os.path.join(parent_directory, 'openslides', 'settings.py')
def get_local_settings_path():
def get_local_settings_path() -> str:
"""
Returns the path to a local settings.
@ -89,7 +91,7 @@ def get_local_settings_path():
return os.path.join('personal_data', 'var', 'settings.py')
def setup_django_settings_module(settings_path=None, local_installation=None):
def setup_django_settings_module(settings_path: str =None, local_installation: bool=False) -> None:
"""
Sets the environment variable ENVIRONMENT_VARIABLE, that means
'DJANGO_SETTINGS_MODULE', to the given settings.
@ -100,7 +102,7 @@ def setup_django_settings_module(settings_path=None, local_installation=None):
If the argument settings_path is set, then the environment variable is
always overwritten.
"""
if settings_path is None and os.environ.get(ENVIRONMENT_VARIABLE, None):
if settings_path is None and os.environ.get(ENVIRONMENT_VARIABLE, ""):
return
if settings_path is None:
@ -128,7 +130,7 @@ def setup_django_settings_module(settings_path=None, local_installation=None):
os.environ[ENVIRONMENT_VARIABLE] = settings_module_name
def get_default_settings_context(user_data_path=None):
def get_default_settings_context(user_data_path: str=None) -> Dict[str, str]:
"""
Returns the default context values for the settings template:
'openslides_user_data_path', 'import_function' and 'debug'.
@ -154,7 +156,7 @@ def get_default_settings_context(user_data_path=None):
return default_context
def get_default_user_data_path(openslides_type):
def get_default_user_data_path(openslides_type: str) -> str:
"""
Returns the default path for user specific data according to the OpenSlides
type.
@ -174,7 +176,7 @@ def get_default_user_data_path(openslides_type):
return default_user_data_path
def get_win32_app_data_path():
def get_win32_app_data_path() -> str:
"""
Returns the path to Windows' AppData directory.
"""
@ -197,7 +199,7 @@ def get_win32_app_data_path():
return buf.value
def get_win32_portable_path():
def get_win32_portable_path() -> str:
"""
Returns the path to the Windows portable version.
"""
@ -217,14 +219,14 @@ def get_win32_portable_path():
return portable_path
def get_win32_portable_user_data_path():
def get_win32_portable_user_data_path() -> str:
"""
Returns the user data path to the Windows portable version.
"""
return os.path.join(get_win32_portable_path(), 'openslides')
def write_settings(settings_path=None, template=None, **context):
def write_settings(settings_path: str=None, template: str=None, **context: str) -> str:
"""
Creates the settings file at the given path using the given values for the
file template.
@ -259,7 +261,7 @@ def write_settings(settings_path=None, template=None, **context):
return os.path.realpath(settings_path)
def open_browser(host, port):
def open_browser(host: str, port: int) -> None:
"""
Launches the default web browser at the given host and port and opens
the webinterface. Uses start_browser internally.
@ -271,7 +273,7 @@ def open_browser(host, port):
start_browser('http://%s:%s' % (host, port))
def start_browser(browser_url):
def start_browser(browser_url: str) -> None:
"""
Launches the default web browser at the given url and opens the
webinterface.
@ -282,7 +284,7 @@ def start_browser(browser_url):
print('Could not locate runnable browser: Skipping start')
else:
def function():
def function() -> None:
# TODO: Use a nonblocking sleep event here. Tornado has such features.
time.sleep(1)
browser.open(browser_url)
@ -291,7 +293,7 @@ def start_browser(browser_url):
thread.start()
def get_database_path_from_settings():
def get_database_path_from_settings() -> Optional[str]:
"""
Retrieves the database path out of the settings file. Returns None,
if it is not a SQLite3 database.
@ -313,7 +315,7 @@ def get_database_path_from_settings():
return database_path
def is_local_installation():
def is_local_installation() -> bool:
"""
Returns True if the command is called for a local installation
@ -322,7 +324,7 @@ def is_local_installation():
return True if '--local-installation' in sys.argv or 'manage.py' in sys.argv[0] else False
def get_geiss_path():
def get_geiss_path() -> str:
"""
Returns the path and file to the Geiss binary.
"""
@ -332,7 +334,7 @@ def get_geiss_path():
return os.path.join(download_path, bin_name)
def is_windows():
def is_windows() -> bool:
"""
Returns True if the current system is Windows. Returns False otherwise.
"""

View File

@ -1,3 +1,6 @@
from typing import Any, Dict
from django.core.exceptions import ImproperlyConfigured
from django.db import models
from .access_permissions import BaseAccessPermissions # noqa
@ -9,11 +12,11 @@ class MinMaxIntegerField(models.IntegerField):
IntegerField with options to set a min- and a max-value.
"""
def __init__(self, min_value=None, max_value=None, *args, **kwargs):
def __init__(self, min_value: int=None, max_value: int=None, *args: Any, **kwargs: Any) -> None:
self.min_value, self.max_value = min_value, max_value
super(MinMaxIntegerField, self).__init__(*args, **kwargs)
def formfield(self, **kwargs):
def formfield(self, **kwargs: Any) -> Any:
defaults = {'min_value': self.min_value, 'max_value': self.max_value}
defaults.update(kwargs)
return super(MinMaxIntegerField, self).formfield(**defaults)
@ -26,7 +29,7 @@ class RESTModelMixin:
access_permissions = None # type: BaseAccessPermissions
def get_root_rest_element(self):
def get_root_rest_element(self) -> models.Model:
"""
Returns the root rest instance.
@ -35,32 +38,36 @@ class RESTModelMixin:
return self
@classmethod
def get_access_permissions(cls):
def get_access_permissions(cls) -> BaseAccessPermissions:
"""
Returns a container to handle access permissions for this model and
its corresponding viewset.
"""
if cls.access_permissions is None:
raise ImproperlyConfigured("A RESTModel needs to have an access_permission.")
return cls.access_permissions
@classmethod
def get_collection_string(cls):
def get_collection_string(cls) -> str:
"""
Returns the string representing the name of the collection. Returns
None if this is not a so called root rest instance.
"""
# TODO Check if this is a root rest element class and return None if not.
app_label = cls._meta.app_label # type: ignore
object_name = cls._meta.object_name # type: ignore
return '/'.join(
(convert_camel_case_to_pseudo_snake_case(cls._meta.app_label),
convert_camel_case_to_pseudo_snake_case(cls._meta.object_name)))
(convert_camel_case_to_pseudo_snake_case(app_label),
convert_camel_case_to_pseudo_snake_case(object_name)))
def get_rest_pk(self):
def get_rest_pk(self) -> int:
"""
Returns the primary key used in the REST API. By default this is
the database pk.
"""
return self.pk
return self.pk # type: ignore
def save(self, skip_autoupdate=False, information=None, *args, **kwargs):
def save(self, skip_autoupdate: bool=False, information: Dict[str, str]=None, *args: Any, **kwargs: Any) -> Any:
"""
Calls Django's save() method and afterwards hits the autoupdate system.
@ -76,12 +83,12 @@ class RESTModelMixin:
"""
# We don't know how to fix this circular import
from .autoupdate import inform_changed_data
return_value = super().save(*args, **kwargs)
return_value = super().save(*args, **kwargs) # type: ignore
if not skip_autoupdate:
inform_changed_data(self.get_root_rest_element(), information=information)
return return_value
def delete(self, skip_autoupdate=False, information=None, *args, **kwargs):
def delete(self, skip_autoupdate: bool=False, information: Dict[str, str]=None, *args: Any, **kwargs: Any) -> Any:
"""
Calls Django's delete() method and afterwards hits the autoupdate system.
@ -101,8 +108,8 @@ class RESTModelMixin:
"""
# We don't know how to fix this circular import
from .autoupdate import inform_changed_data, inform_deleted_data
instance_pk = self.pk
return_value = super().delete(*args, **kwargs)
instance_pk = self.pk # type: ignore
return_value = super().delete(*args, **kwargs) # type: ignore
if not skip_autoupdate:
if self != self.get_root_rest_element():
# The deletion of a included element is a change of the root element.

View File

@ -1,6 +1,7 @@
import os
import pkgutil
import sys
from typing import Any, List, Tuple
from django.apps import apps
from django.conf import settings
@ -15,7 +16,7 @@ from openslides.utils.main import (
# Methods to collect plugins.
def collect_plugins_from_entry_points():
def collect_plugins_from_entry_points() -> Tuple[str, ...]:
"""
Collects all entry points in the group openslides_plugins from all
distributions in the default working set and returns their module names as
@ -24,7 +25,7 @@ def collect_plugins_from_entry_points():
return tuple(entry_point.module_name for entry_point in iter_entry_points('openslides_plugins'))
def collect_plugins_from_path(path):
def collect_plugins_from_path(path: str) -> Tuple[str, ...]:
"""
Collects all modules/packages in the given `path` and returns a tuple
of their names.
@ -32,7 +33,7 @@ def collect_plugins_from_path(path):
return tuple(x[1] for x in pkgutil.iter_modules([path]))
def collect_plugins():
def collect_plugins() -> Tuple[str, ...]:
"""
Collect all plugins that can be automatically discovered.
"""
@ -52,7 +53,7 @@ def collect_plugins():
# Methods to retrieve plugin metadata and urlpatterns.
def get_plugin_verbose_name(plugin):
def get_plugin_verbose_name(plugin: str) -> str:
"""
Returns the verbose name of a plugin. The plugin argument must be a python
dotted module path.
@ -60,7 +61,7 @@ def get_plugin_verbose_name(plugin):
return apps.get_app_config(plugin).verbose_name
def get_plugin_description(plugin):
def get_plugin_description(plugin: str) -> str:
"""
Returns the short descrption of a plugin. The plugin argument must be a
python dotted module path.
@ -76,7 +77,7 @@ def get_plugin_description(plugin):
return description
def get_plugin_version(plugin):
def get_plugin_version(plugin: str) -> str:
"""
Returns the version string of a plugin. The plugin argument must be a
python dotted module path.
@ -92,7 +93,7 @@ def get_plugin_version(plugin):
return version
def get_plugin_urlpatterns(plugin):
def get_plugin_urlpatterns(plugin: str) -> Any:
"""
Returns the urlpatterns object for a plugin. The plugin argument must be
a python dotted module path.
@ -108,12 +109,12 @@ def get_plugin_urlpatterns(plugin):
return urlpatterns
def get_all_plugin_urlpatterns():
def get_all_plugin_urlpatterns() -> List[Any]:
"""
Helper function to return all urlpatterns of all plugins listed in
settings.INSTALLED_PLUGINS.
"""
urlpatterns = []
urlpatterns = [] # type: List[Any]
for plugin in settings.INSTALLED_PLUGINS:
plugin_urlpatterns = get_plugin_urlpatterns(plugin)
if plugin_urlpatterns:

View File

@ -1,4 +1,4 @@
from typing import Optional # noqa
from typing import Any, Dict, Iterable, List, Optional # noqa
from django.dispatch import Signal
@ -18,7 +18,7 @@ class ProjectorElement(object, metaclass=SignalConnectMetaClass):
signal = Signal()
name = None # type: Optional[str]
def __init__(self, **kwargs):
def __init__(self, **kwargs: str) -> None:
"""
Initializes the projector element instance. This is done when the
signal is sent.
@ -29,15 +29,16 @@ class ProjectorElement(object, metaclass=SignalConnectMetaClass):
pass
@classmethod
def get_dispatch_uid(cls):
def get_dispatch_uid(cls) -> Optional[str]:
"""
Returns the classname as a unique string for each class. Returns None
for the base class so it will not be connected to the signal.
"""
if not cls.__name__ == 'ProjectorElement':
return cls.__name__
return None
def check_and_update_data(self, projector_object, config_entry):
def check_and_update_data(self, projector_object: Any, config_entry: Any) -> Any:
"""
Checks projector element data via self.check_data() and updates
them via self.update_data(). The projector object and the config
@ -50,7 +51,7 @@ class ProjectorElement(object, metaclass=SignalConnectMetaClass):
self.check_data()
return self.update_data() or {}
def check_data(self):
def check_data(self) -> None:
"""
Method can be overridden to validate projector element data. This
may raise ProjectorException in case of an error.
@ -59,7 +60,7 @@ class ProjectorElement(object, metaclass=SignalConnectMetaClass):
"""
pass
def update_data(self):
def update_data(self) -> Dict[Any, Any]:
"""
Method can be overridden to update the projector element data
output. This should return a dictonary. Use this for server
@ -69,21 +70,23 @@ class ProjectorElement(object, metaclass=SignalConnectMetaClass):
"""
pass
def get_requirements(self, config_entry):
def get_requirements(self, config_entry: Any) -> Iterable[Any]:
"""
Returns an iterable of instances that are required for this projector
element. The config_entry has to be given.
"""
return ()
def get_requirements_as_collection_elements(self, config_entry):
def get_requirements_as_collection_elements(self, config_entry: Any) -> Iterable[CollectionElement]:
"""
Returns an iterable of collection elements that are required for this
projector element. The config_entry has to be given.
"""
return (CollectionElement.from_instance(instance) for instance in self.get_requirements(config_entry))
def get_collection_elements_required_for_this(self, collection_element, config_entry):
def get_collection_elements_required_for_this(
self, collection_element: CollectionElement,
config_entry: Any) -> List[CollectionElement]:
"""
Returns a list of CollectionElements that have to be sent to every
projector that shows this projector element according to the given

View File

@ -1,5 +1,5 @@
from collections import OrderedDict
from typing import Optional # noqa
from typing import Any, Dict, Iterable, Optional, Type # noqa
from django.http import Http404
from rest_framework import status # noqa
@ -24,6 +24,7 @@ from rest_framework.serializers import ( # noqa
ManyRelatedField,
PrimaryKeyRelatedField,
RelatedField,
Serializer,
SerializerMethodField,
ValidationError,
)
@ -31,7 +32,7 @@ from rest_framework.viewsets import GenericViewSet as _GenericViewSet # noqa
from rest_framework.viewsets import ModelViewSet as _ModelViewSet # noqa
from rest_framework.viewsets import ViewSet as _ViewSet # noqa
from .access_permissions import BaseAccessPermissions # noqa
from .access_permissions import BaseAccessPermissions, RestrictedData # noqa
from .auth import user_to_collection_user
from .collection import Collection, CollectionElement
@ -47,7 +48,7 @@ class IdManyRelatedField(ManyRelatedField):
"""
field_name_suffix = '_id'
def bind(self, field_name, parent):
def bind(self, field_name: str, parent: Any) -> None:
"""
Called when the field is bound to the serializer.
@ -65,7 +66,7 @@ class IdPrimaryKeyRelatedField(PrimaryKeyRelatedField):
"""
field_name_suffix = '_id'
def bind(self, field_name, parent):
def bind(self, field_name: str, parent: Any) -> None:
"""
Called when the field is bound to the serializer.
@ -80,7 +81,7 @@ class IdPrimaryKeyRelatedField(PrimaryKeyRelatedField):
super().bind(field_name, parent)
@classmethod
def many_init(cls, *args, **kwargs):
def many_init(cls, *args: Any, **kwargs: Any) -> IdManyRelatedField:
"""
Method from rest_framework.relations.RelatedField That uses our
IdManyRelatedField class instead of
@ -106,7 +107,7 @@ class PermissionMixin:
"""
access_permissions = None # type: Optional[BaseAccessPermissions]
def get_permissions(self):
def get_permissions(self) -> Iterable[str]:
"""
Overridden method to check view permissions. Returns an empty
iterable so Django REST framework won't do any other permission
@ -114,10 +115,10 @@ class PermissionMixin:
and the request passes.
"""
if not self.check_view_permissions():
self.permission_denied(self.request)
self.permission_denied(self.request) # type: ignore
return ()
def check_view_permissions(self):
def check_view_permissions(self) -> bool:
"""
Override this and return True if the requesting user should be able to
get access to your view.
@ -127,22 +128,22 @@ class PermissionMixin:
"""
return False
def get_access_permissions(self):
def get_access_permissions(self) -> BaseAccessPermissions:
"""
Returns a container to handle access permissions for this viewset and
its corresponding model.
"""
return self.access_permissions
return self.access_permissions # type: ignore
def get_serializer_class(self):
def get_serializer_class(self) -> Type[Serializer]:
"""
Overridden method to return the serializer class given by the
access permissions container.
"""
if self.get_access_permissions() is not None:
serializer_class = self.get_access_permissions().get_serializer_class(self.request.user)
serializer_class = self.get_access_permissions().get_serializer_class(self.request.user) # type: ignore
else:
serializer_class = super().get_serializer_class()
serializer_class = super().get_serializer_class() # type: ignore
return serializer_class
@ -153,11 +154,11 @@ class ModelSerializer(_ModelSerializer):
"""
serializer_related_field = IdPrimaryKeyRelatedField
def get_fields(self):
def get_fields(self) -> Any:
"""
Returns all fields of the serializer.
"""
fields = OrderedDict()
fields = OrderedDict() # type: Dict[str, Field]
for field_name, field in super().get_fields().items():
try:
@ -177,7 +178,7 @@ class ListModelMixin(_ListModelMixin):
queryset = Model.objects.all()
"""
def list(self, request, *args, **kwargs):
def list(self, request: Any, *args: Any, **kwargs: Any) -> Response:
model = self.get_queryset().model
try:
collection_string = model.get_collection_string()
@ -200,7 +201,7 @@ class RetrieveModelMixin(_RetrieveModelMixin):
queryset = Model.objects.all()
"""
def retrieve(self, request, *args, **kwargs):
def retrieve(self, request: Any, *args: Any, **kwargs: Any) -> Response:
model = self.get_queryset().model
try:
collection_string = model.get_collection_string()
@ -213,7 +214,7 @@ class RetrieveModelMixin(_RetrieveModelMixin):
collection_string, self.kwargs[lookup_url_kwarg])
user = user_to_collection_user(request.user)
try:
content = collection_element.as_dict_for_user(user)
content = collection_element.as_dict_for_user(user) # type: RestrictedData
except collection_element.get_model().DoesNotExist:
raise Http404
if content is None:

View File

@ -1,4 +1,5 @@
from contextlib import ContextDecorator
from typing import Any
from unittest.mock import patch
from django.core.cache import caches
@ -9,7 +10,7 @@ from ..core.config import config
class OpenSlidesDiscoverRunner(DiscoverRunner):
def run_tests(self, test_labels, extra_tests=None, **kwargs):
def run_tests(self, test_labels, extra_tests=None, **kwargs): # type: ignore
"""
Test Runner which does not create a database, if only unittest are run.
"""
@ -35,7 +36,7 @@ class TestCase(_TestCase):
Resets the config object after each test.
"""
def tearDown(self):
def tearDown(self) -> None:
config.key_to_id = {}
@ -48,11 +49,11 @@ class use_cache(ContextDecorator):
The code inside the contextmananger starts with an empty cache.
"""
def __enter__(self):
def __enter__(self) -> None:
cache = caches['locmem']
cache.clear()
self.patch = patch('openslides.utils.collection.cache', cache)
self.patch.start()
def __exit__(self, *exc):
def __exit__(self, *exc: Any) -> None:
self.patch.stop()

View File

@ -6,7 +6,7 @@ CAMEL_CASE_TO_PSEUDO_SNAKE_CASE_CONVERSION_REGEX_1 = re.compile('(.)([A-Z][a-z]+
CAMEL_CASE_TO_PSEUDO_SNAKE_CASE_CONVERSION_REGEX_2 = re.compile('([a-z0-9])([A-Z])')
def convert_camel_case_to_pseudo_snake_case(text):
def convert_camel_case_to_pseudo_snake_case(text: str) -> str:
"""
Converts camel case to pseudo snake case using hyphen instead of
underscore.
@ -19,12 +19,13 @@ def convert_camel_case_to_pseudo_snake_case(text):
return CAMEL_CASE_TO_PSEUDO_SNAKE_CASE_CONVERSION_REGEX_2.sub(r'\1-\2', s1).lower()
def to_roman(number):
def to_roman(number: int) -> str:
"""
Converts an arabic number within range from 1 to 4999 to the
corresponding roman number. Returns None on error conditions.
corresponding roman number. Returns the input converted as string on error
conditions or higher numbers.
"""
try:
return roman.toRoman(number)
except (roman.NotIntegerError, roman.OutOfRangeError):
return None
return str(number)

View File

@ -20,16 +20,13 @@ allowed_styles = [
]
def validate_html(html):
def validate_html(html: str) -> str:
"""
This method takes a string and escapes all non-whitelisted html entries.
Every field of a model that is loaded trusted in the DOM should be validated.
"""
if isinstance(html, str):
return bleach.clean(
html,
tags=allowed_tags,
attributes=allowed_attributes,
styles=allowed_styles)
else:
return html
return bleach.clean(
html,
tags=allowed_tags,
attributes=allowed_attributes,
styles=allowed_styles)

View File

@ -1,12 +1,10 @@
from typing import List # noqa
from typing import Any, Dict, List # noqa
from django.views import generic as django_views
from django.views.decorators.csrf import ensure_csrf_cookie
from django.views.generic.base import View
from rest_framework.response import Response
from rest_framework.views import APIView as _APIView
View = django_views.View
class CSRFMixin:
"""
@ -14,8 +12,8 @@ class CSRFMixin:
"""
@classmethod
def as_view(cls, *args, **kwargs):
view = super().as_view(*args, **kwargs)
def as_view(cls, *args: Any, **kwargs: Any) -> View:
view = super().as_view(*args, **kwargs) # type: ignore
return ensure_csrf_cookie(view)
@ -32,13 +30,13 @@ class APIView(_APIView):
http_method_names = ['get', 'post', 'put', 'patch', 'delete', 'head', 'options', 'trace']
"""
def get_context_data(self, **context):
def get_context_data(self, **context: Any) -> Dict[str, Any]:
"""
Returns the context for the response.
"""
return context
def method_call(self, request, *args, **kwargs):
def method_call(self, request: Any, *args: Any, **kwargs: Any) -> Any:
"""
Http method that returns the response object with the context data.
"""

View File

@ -17,8 +17,12 @@ multi_line_output = 3
[mypy]
ignore_missing_imports = true
strict_optional = true
check_untyped_defs = true
[mypy-openslides.utils.auth]
[mypy-openslides.utils.dispatch]
ignore_errors = true
[mypy-openslides.utils.*]
disallow_any = unannotated
[mypy-openslides.core.config]

View File

@ -8,4 +8,4 @@ class ToRomanTest(TestCase):
self.assertEqual(utils.to_roman(3), 'III')
def test_to_roman_none(self):
self.assertTrue(utils.to_roman(-3) is None)
self.assertEqual(utils.to_roman(-3), '-3')