OpenSlides/openslides/utils/cache.py
2017-09-03 18:34:55 +02:00

284 lines
9.2 KiB
Python

from collections import defaultdict
from typing import ( # noqa
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Set,
)
from channels import Group
from channels.sessions import session_for_reply_channel
from django.apps import apps
from django.core.cache import cache, caches
if TYPE_CHECKING:
# Dummy import Collection for mypy
from .collection import Collection # noqa
# Shape of the websocket user cache data: maps a user id to the set of
# channel names that belong to this user.
UserCacheDataType = Dict[int, Set[str]]
class BaseWebsocketUserCache:
    """
    Base class for caches that remember the reply channel names of all open
    websocket connections. The id of the user that opened the connection is
    used as reference.

    The methods add(), remove(), get_all() and save_data() have to be
    overridden by a subclass.
    """

    cache_key = 'current_websocket_users'

    def add(self, user_id: int, channel_name: str) -> None:
        """
        Registers a channel name for a user id.

        Has to be overridden.
        """
        raise NotImplementedError()

    def remove(self, user_id: int, channel_name: str) -> None:
        """
        Drops a single channel name of a user from the cache.

        Has to be overridden.
        """
        raise NotImplementedError()

    def get_all(self) -> UserCacheDataType:
        """
        Returns the complete data as a dict that maps each user id to a set
        of channel names.

        Has to be overridden.
        """
        raise NotImplementedError()

    def save_data(self, data: UserCacheDataType) -> None:
        """
        Writes a full data set (as produced by build_data()) to the cache.

        Has to be overridden.
        """
        raise NotImplementedError()

    def build_data(self) -> UserCacheDataType:
        """
        Builds the data from the channels of the group 'site', hands it to
        save_data() and returns it.

        Channels whose session holds no user id (or a falsy one) are filed
        under the key 0.
        """
        channels_by_user = defaultdict(set)  # type: UserCacheDataType
        site_channels = Group('site').channel_layer.group_channels('site')
        for reply_channel in site_channels:
            channel_session = session_for_reply_channel(reply_channel)
            owner_id = channel_session.get('user_id', None) or 0
            channels_by_user[owner_id].add(reply_channel)
        self.save_data(channels_by_user)
        return channels_by_user

    def get_cache_key(self) -> str:
        """
        Returns the key under which the data is stored.
        """
        return self.cache_key
class RedisWebsocketUserCache(BaseWebsocketUserCache):
    """
    Implementation of the WebsocketUserCache that uses redis.

    This uses one cache key to store all connected user ids in a set and
    for each user another set to save the channel names.
    """

    def add(self, user_id: int, channel_name: str) -> None:
        """
        Adds a channel name to an user id.

        Both set updates are sent to redis through one pipeline.
        """
        redis = get_redis_connection()
        pipe = redis.pipeline()
        pipe.sadd(self.get_cache_key(), user_id)
        pipe.sadd(self.get_user_cache_key(user_id), channel_name)
        pipe.execute()

    def remove(self, user_id: int, channel_name: str) -> None:
        """
        Removes one channel name from the cache.

        The user id itself stays in the set of user ids; get_all() skips
        users that have no channel names left.
        """
        redis = get_redis_connection()
        redis.srem(self.get_user_cache_key(user_id), channel_name)

    def get_all(self) -> UserCacheDataType:
        """
        Returns all data using a dict where the key is a user id and the value
        is a set of channel_names.

        If redis knows no user ids, the data is rebuilt with build_data().
        Users without any channel name are left out of the result.
        """
        redis = get_redis_connection()
        user_ids = redis.smembers(self.get_cache_key())
        if not user_ids:
            # smembers() answers with an empty set (never None) for a missing
            # key, so an empty result is the only sign of a missing cache.
            return self.build_data()

        websocket_user_ids = dict()  # type: UserCacheDataType
        for redis_user_id in user_ids:
            # Redis does not return the id as int. So we have to convert it.
            user_id = int(redis_user_id)
            channel_names = redis.smembers(self.get_user_cache_key(user_id))
            if channel_names:
                # If channel_names is empty, then we can assume, that the
                # user has no active connection. Such users are skipped.
                websocket_user_ids[user_id] = set(channel_names)
        return websocket_user_ids

    def save_data(self, data: UserCacheDataType) -> None:
        """
        Saves the full data set (like created with the method build_data()) to
        the cache.

        The old values of all touched keys are deleted first. Empty data and
        empty sets of channel names are accepted.
        """
        redis = get_redis_connection()
        pipe = redis.pipeline()

        # Save all user ids. SADD needs at least one member, so it is only
        # queued if there is something to add.
        pipe.delete(self.get_cache_key())
        if data:
            pipe.sadd(self.get_cache_key(), *data.keys())

        for user_id, channel_names in data.items():
            pipe.delete(self.get_user_cache_key(user_id))
            if channel_names:
                pipe.sadd(self.get_user_cache_key(user_id), *channel_names)
        pipe.execute()

    def get_cache_key(self) -> str:
        """
        Returns the cache key.

        The key is passed through cache.make_key() of the django cache.
        """
        return cache.make_key(self.cache_key)

    def get_user_cache_key(self, user_id: int) -> str:
        """
        Returns a cache key to save the channel names for a specific user.

        The key has the form '<cache_key>:<user_id>' and is passed through
        cache.make_key() of the django cache.
        """
        return cache.make_key('{}:{}'.format(self.cache_key, user_id))
class DjangoCacheWebsocketUserCache(BaseWebsocketUserCache):
    """
    WebsocketUserCache implementation on top of the django cache.

    If you use this with the inmemory cache, then you should only use one
    worker.

    All data lives under a single cache key: a dict that maps each user id
    to a set of channel names.
    """

    def add(self, user_id: int, channel_name: str) -> None:
        """
        Registers a channel name for a user in the django cache.
        """
        key = self.get_cache_key()
        stored = cache.get(key)
        connections = dict() if stored is None else stored
        # Create the set for a so far unknown user on the fly.
        connections.setdefault(user_id, set()).add(channel_name)
        cache.set(key, connections)

    def remove(self, user_id: int, channel_name: str) -> None:
        """
        Drops a single channel name of a user from the django cache.

        Nothing is written if the cache is empty or the user is unknown.
        """
        key = self.get_cache_key()
        connections = cache.get(key)
        if connections is None or user_id not in connections:
            return
        connections[user_id].discard(channel_name)
        cache.set(key, connections)

    def get_all(self) -> UserCacheDataType:
        """
        Returns the data from the django cache. If the cache holds nothing,
        the data is created with build_data().
        """
        connections = cache.get(self.get_cache_key())
        return self.build_data() if connections is None else connections

    def save_data(self, data: UserCacheDataType) -> None:
        """
        Writes the given data to the django cache.
        """
        cache.set(self.get_cache_key(), data)
class StartupCache:
    """
    Cache for all the data a client needs when it connects via websocket.
    """

    cache_key = "full_data_startup_cache"

    def build(self) -> Dict[str, List[str]]:
        """
        Creates the cache by asking every app for its startup elements.

        Returns a dict that maps each collection string to the list of the
        full_data of the elements of that collection. The same dict is written
        to the django cache with a timeout argument of 86400 (the number of
        seconds of one day).
        """
        startup_data = {}  # type: Dict[str, List[str]]
        for app_config in apps.get_app_configs():
            try:
                # An app takes part by providing get_startup_elements().
                # This method has to return an iterable of Collection objects.
                element_source = app_config.get_startup_elements  # type: Callable[[], Iterable[Collection]]
            except AttributeError:
                # The app does not implement get_startup_elements.
                continue

            for startup_collection in element_source():
                full_data_list = []
                for element in startup_collection.element_generator():
                    full_data_list.append(element.get_full_data())
                startup_data[startup_collection.collection_string] = full_data_list

        cache.set(self.cache_key, startup_data, 86400)
        return startup_data

    def clear(self) -> None:
        """
        Deletes the cached data.
        """
        cache.delete(self.cache_key)

    def get_collections(self) -> Generator['Collection', None, None]:
        """
        Generator that yields one Collection per cached collection string.

        The data comes from the cache. If the cache does not exist, it is
        built first.
        """
        from .collection import Collection  # noqa

        cached = cache.get(self.cache_key)
        # A missing cache is reported as None and has to be built.
        startup_data = self.build() if cached is None else cached
        for collection_string, full_data in startup_data.items():
            yield Collection(collection_string, full_data)
# Module-level StartupCache instance, created once when this module is
# imported.
startup_cache = StartupCache()
def use_redis_cache() -> bool:
    """
    Returns True if Redis is used as caching backend.

    This is the case if the 'default' cache is a RedisCache of django_redis.
    If django_redis can not be imported, False is returned.
    """
    try:
        from django_redis.cache import RedisCache
    except ImportError:
        redis_in_use = False
    else:
        redis_in_use = isinstance(caches['default'], RedisCache)
    return redis_in_use
def get_redis_connection() -> Any:
    """
    Returns the object that django_redis hands out for the "default" cache.
    It can be used to talk directly to redis.
    """
    # The import is aliased so that it does not hide this function's own name.
    from django_redis import get_redis_connection as connect_to_redis
    return connect_to_redis("default")
# Pick the websocket user cache implementation once at import time: the redis
# one if use_redis_cache() returns True, the django cache one otherwise.
websocket_user_cache = (
    RedisWebsocketUserCache()
    if use_redis_cache()
    else DjangoCacheWebsocketUserCache()
)  # type: BaseWebsocketUserCache