279c2ba796
Prevents a corrupt cache if the restricted data cache is cleared at runtime.
473 lines
16 KiB
Python
473 lines
16 KiB
Python
import json
|
|
from collections import defaultdict
|
|
from typing import ( # noqa
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Callable,
|
|
Dict,
|
|
Generator,
|
|
Iterable,
|
|
List,
|
|
Optional,
|
|
Set,
|
|
Type,
|
|
Union,
|
|
)
|
|
|
|
from channels import Group
|
|
from channels.sessions import session_for_reply_channel
|
|
from django.conf import settings
|
|
from django.core.cache import cache, caches
|
|
|
|
if TYPE_CHECKING:
|
|
# Dummy import Collection for mypy
|
|
from .collection import Collection # noqa
|
|
|
|
# Shape of the websocket user cache: maps a user id to the set of reply
# channel names of that user's open websocket connections.
UserCacheDataType = Dict[int, Set[str]]
|
|
|
|
|
|
class BaseWebsocketUserCache:
    """
    Base cache for the reply channel names of all open websocket connections,
    grouped by the id of the user who opened each connection.

    Subclasses have to implement add(), remove(), get_all() and save_data();
    the versions here only raise NotImplementedError.
    """
    cache_key = 'current_websocket_users'

    def add(self, user_id: int, channel_name: str) -> None:
        """
        Registers channel_name as an open connection of the given user.
        """
        raise NotImplementedError()

    def remove(self, user_id: int, channel_name: str) -> None:
        """
        Forgets one channel name of the given user.
        """
        raise NotImplementedError()

    def get_all(self) -> UserCacheDataType:
        """
        Returns a dict mapping every user id to its set of channel names.
        """
        raise NotImplementedError()

    def save_data(self, data: UserCacheDataType) -> None:
        """
        Writes a complete data set (as produced by build_data()) to the cache.
        """
        raise NotImplementedError()

    def build_data(self) -> UserCacheDataType:
        """
        Collects all channels of the 'site' group, keyed by the 'user_id' found
        in each channel's session (0 when the session has no user id).

        The result is stored via save_data() and then returned.
        """
        data = defaultdict(set)  # type: UserCacheDataType
        site_channels = Group('site').channel_layer.group_channels('site')
        for reply_channel in site_channels:
            owner_id = session_for_reply_channel(reply_channel).get('user_id', None)
            # Anonymous connections are collected under the id 0.
            data[owner_id or 0].add(reply_channel)
        self.save_data(data)
        return data

    def get_cache_key(self) -> str:
        """
        Returns the key the data is stored under.
        """
        return self.cache_key
|
|
|
|
|
|
class RedisWebsocketUserCache(BaseWebsocketUserCache):
    """
    Implementation of the WebsocketUserCache that uses redis.

    This uses one cache key to store all connected user ids in a set and
    for each user another set to save the channel names.
    """

    @staticmethod
    def _decode(value: Union[bytes, str]) -> str:
        """
        Returns value as str. redis-py returns bytes by default, while
        UserCacheDataType expects str channel names.
        """
        return value.decode() if isinstance(value, bytes) else value

    def add(self, user_id: int, channel_name: str) -> None:
        """
        Adds a channel name to a user id.
        """
        redis = get_redis_connection()
        pipe = redis.pipeline()
        pipe.sadd(self.get_cache_key(), user_id)
        pipe.sadd(self.get_user_cache_key(user_id), channel_name)
        pipe.execute()

    def remove(self, user_id: int, channel_name: str) -> None:
        """
        Removes one channel name from the cache.

        The user id stays in the set of user ids; get_all() skips users
        without any channel name.
        """
        redis = get_redis_connection()
        redis.srem(self.get_user_cache_key(user_id), channel_name)

    def get_all(self) -> UserCacheDataType:
        """
        Returns all data using a dict where the key is a user id and the value
        is a set of channel_names.

        SMEMBERS never returns None: a missing key is reported as an empty
        set, and redis does not keep empty sets. So an empty result means the
        cache does not exist (for example because it was cleared at runtime)
        and the data is rebuilt.
        """
        redis = get_redis_connection()
        raw_user_ids = redis.smembers(self.get_cache_key())
        if not raw_user_ids:
            return self.build_data()

        # Redis returns the ids as bytes/strings. So we have to convert them.
        user_ids = [int(raw_user_id) for raw_user_id in raw_user_ids]

        # Fetch the channel names of all users in a single round trip.
        pipe = redis.pipeline()
        for user_id in user_ids:
            pipe.smembers(self.get_user_cache_key(user_id))
        channel_name_sets = pipe.execute()

        websocket_user_ids = {}  # type: UserCacheDataType
        for user_id, channel_names in zip(user_ids, channel_name_sets):
            # An empty set means that the user has no active connection.
            if channel_names:
                websocket_user_ids[user_id] = {
                    self._decode(channel_name) for channel_name in channel_names}
        return websocket_user_ids

    def save_data(self, data: UserCacheDataType) -> None:
        """
        Saves the full data set (like created with the method build_data()) to
        the cache.

        SADD requires at least one member, so empty collections are only
        deleted, never written.
        """
        redis = get_redis_connection()
        pipe = redis.pipeline()

        # Save all user ids
        pipe.delete(self.get_cache_key())
        if data:
            pipe.sadd(self.get_cache_key(), *data.keys())

        for user_id, channel_names in data.items():
            pipe.delete(self.get_user_cache_key(user_id))
            if channel_names:
                pipe.sadd(self.get_user_cache_key(user_id), *channel_names)
        pipe.execute()

    def get_cache_key(self) -> str:
        """
        Returns the cache key of the set of user ids.
        """
        return cache.make_key(self.cache_key)

    def get_user_cache_key(self, user_id: int) -> str:
        """
        Returns a cache key to save the channel names for a specific user.
        """
        return cache.make_key('{}:{}'.format(self.cache_key, user_id))
|
|
|
|
|
|
class DjangoCacheWebsocketUserCache(BaseWebsocketUserCache):
    """
    WebsocketUserCache backed by the django cache.

    The whole mapping (user id -> set of channel names) is stored as one dict
    under a single cache key. When used with the in-memory cache backend,
    only one worker should be run.
    """

    def add(self, user_id: int, channel_name: str) -> None:
        """
        Stores channel_name as a connection of the given user in the django cache.
        """
        key = self.get_cache_key()
        stored = cache.get(key)
        mapping = {} if stored is None else stored
        mapping.setdefault(user_id, set()).add(channel_name)
        cache.set(key, mapping)

    def remove(self, user_id: int, channel_name: str) -> None:
        """
        Drops one channel name of the given user from the django cache.

        Nothing is written if the cache or the user entry does not exist.
        """
        key = self.get_cache_key()
        stored = cache.get(key)
        if stored is None or user_id not in stored:
            return
        stored[user_id].discard(channel_name)
        cache.set(key, stored)

    def get_all(self) -> UserCacheDataType:
        """
        Reads the mapping from the django cache, rebuilding it when absent.
        """
        stored = cache.get(self.get_cache_key())
        return self.build_data() if stored is None else stored

    def save_data(self, data: UserCacheDataType) -> None:
        """
        Replaces the cached mapping with data.
        """
        cache.set(self.get_cache_key(), data)
|
|
|
|
|
|
class FullDataCache:
    """
    Caches all data as full data.

    Helps to get all data from one collection. Each collection is stored in
    one redis hash that maps the element id to its JSON encoded full data.
    """

    base_cache_key = 'full_data_cache'

    @staticmethod
    def _get_full_queryset(model: Any) -> Any:
        """
        Returns the queryset used to read all elements of the model.

        Uses model.objects.get_full_queryset() if the manager has that method,
        else falls back to the default manager model.objects. Only a missing
        method triggers the fallback; an AttributeError raised inside
        get_full_queryset() propagates instead of being silently hidden.
        """
        get_full_queryset = getattr(model.objects, 'get_full_queryset', None)
        if get_full_queryset is None:
            return model.objects
        return get_full_queryset()

    def build_for_collection(self, collection_string: str) -> None:
        """
        Builds the cache for a collection from its django model.

        Rebuilds the cache for that collection if it already exists. The delete
        and the new values are sent in one pipeline.
        """
        redis = get_redis_connection()
        pipe = redis.pipeline()

        # Clear the cache for the collection
        pipe.delete(self.get_cache_key(collection_string))

        # Save all elements
        from .collection import get_model_from_collection_string
        model = get_model_from_collection_string(collection_string)
        query = self._get_full_queryset(model)

        # Build a dict from the instance id to the JSON encoded full data
        mapping = {
            instance.pk: json.dumps(model.get_access_permissions().get_full_data(instance))
            for instance in query.all()}

        if mapping:
            # HMSET needs at least one field, so an empty collection is only deleted.
            pipe.hmset(
                self.get_cache_key(collection_string),
                mapping)

        pipe.execute()

    def add_element(self, collection_string: str, id: int, data: Dict[str, Any]) -> None:
        """
        Adds one element to the cache. If the cache does not exist for the
        collection, it is built first.
        """
        redis = get_redis_connection()

        # If the cache does not exist for the collection, then create it first.
        if not self.exists_for_collection(collection_string):
            self.build_for_collection(collection_string)

        redis.hset(
            self.get_cache_key(collection_string),
            id,
            json.dumps(data))

    def del_element(self, collection_string: str, id: int) -> None:
        """
        Removes one element from the cache.

        Does nothing if the cache does not exist.
        """
        redis = get_redis_connection()
        redis.hdel(
            self.get_cache_key(collection_string),
            id)

    def exists_for_collection(self, collection_string: str) -> bool:
        """
        Returns True if the cache for the collection exists, else False.

        build_for_collection() writes nothing for an empty collection, so this
        is False for a collection without elements.
        """
        redis = get_redis_connection()
        # redis-py may return the number of existing keys; normalize to bool.
        return bool(redis.exists(self.get_cache_key(collection_string)))

    def get_data(self, collection_string: str) -> List[Dict[str, Any]]:
        """
        Returns all data for the collection.
        """
        redis = get_redis_connection()
        return [json.loads(element.decode()) for element in redis.hvals(self.get_cache_key(collection_string))]

    def get_element(self, collection_string: str, id: int) -> Dict[str, Any]:
        """
        Returns one element from the collection.

        Raises model.DoesNotExist if the element is not in the cache.
        """
        redis = get_redis_connection()
        element = redis.hget(self.get_cache_key(collection_string), id)
        if element is None:
            from .collection import get_model_from_collection_string
            model = get_model_from_collection_string(collection_string)
            raise model.DoesNotExist(collection_string, id)
        return json.loads(element.decode())

    def get_cache_key(self, collection_string: str) -> str:
        """
        Returns the cache key for a collection.
        """
        return cache.make_key('{}:{}'.format(self.base_cache_key, collection_string))
|
|
|
|
|
|
class DummyFullDataCache:
    """
    Dummy FullDataCache that caches nothing.

    The write methods do nothing, exists_for_collection() is always False and
    the read methods query the database directly.
    """

    @staticmethod
    def _get_full_queryset(model: Any) -> Any:
        """
        Returns the queryset used to read all elements of the model.

        Uses model.objects.get_full_queryset() if the manager has that method,
        else falls back to the default manager model.objects. Only a missing
        method triggers the fallback; an AttributeError raised inside
        get_full_queryset() propagates instead of being silently hidden.
        """
        get_full_queryset = getattr(model.objects, 'get_full_queryset', None)
        if get_full_queryset is None:
            return model.objects
        return get_full_queryset()

    def build_for_collection(self, collection_string: str) -> None:
        pass

    def add_element(self, collection_string: str, id: int, data: Dict[str, Any]) -> None:
        pass

    def del_element(self, collection_string: str, id: int) -> None:
        pass

    def exists_for_collection(self, collection_string: str) -> bool:
        return False

    def get_data(self, collection_string: str) -> List[Dict[str, Any]]:
        """
        Returns the full data of all elements of the collection, read from the database.
        """
        from .collection import get_model_from_collection_string
        model = get_model_from_collection_string(collection_string)
        query = self._get_full_queryset(model)
        return [model.get_access_permissions().get_full_data(instance)
                for instance in query.all()]

    def get_element(self, collection_string: str, id: int) -> Dict[str, Any]:
        """
        Returns the full data of one element, read from the database.
        """
        from .collection import get_model_from_collection_string
        model = get_model_from_collection_string(collection_string)
        query = self._get_full_queryset(model)
        return model.get_access_permissions().get_full_data(query.get(pk=id))
|
|
|
|
|
|
class RestrictedDataCache:
    """
    Caches all data for specific users.

    Helps to get all data from all collections for a specific user. Each user
    has one redis hash whose fields have the form '<collection_string>/<id>'.

    The cached values are expected to be formatted for output via websocket.
    """

    base_cache_key = 'restricted_user_cache'

    # Lua script: write the hash field only if the hash (the user's cache)
    # already exists. Redis runs a script atomically, so the cache cannot be
    # cleared between the existence check and the write.
    _update_if_exists_script = """
if redis.call('EXISTS', KEYS[1]) == 1 then
    return redis.call('HSET', KEYS[1], ARGV[1], ARGV[2])
end
return 0
"""

    @staticmethod
    def _element_key(collection_string: str, id: int) -> str:
        """
        Returns the hash field name used for one element.
        """
        return "{}/{}".format(collection_string, id)

    def update_element(self, user_id: int, collection_string: str, id: int, data: object) -> None:
        """
        Adds one element to the cache only if the cache exists for the user.

        The check and the write run in one atomic Lua script (see also #3427).
        Doing them in two steps could recreate a cache that was cleared at
        runtime with only this single element. exists_for_user() would then
        return True for an incomplete cache.
        """
        redis = get_redis_connection()
        redis.eval(
            self._update_if_exists_script,
            1,
            self.get_cache_key(user_id),
            self._element_key(collection_string, id),
            json.dumps(data))

    def add_element(self, user_id: int, collection_string: str, id: int, data: object) -> None:
        """
        Adds one element to the cache. If the cache does not exist for the user,
        it is created.
        """
        redis = get_redis_connection()
        redis.hset(
            self.get_cache_key(user_id),
            self._element_key(collection_string, id),
            json.dumps(data))

    def del_element(self, user_id: int, collection_string: str, id: int) -> None:
        """
        Removes one element from the cache.

        Does nothing if the cache does not exist.
        """
        redis = get_redis_connection()
        redis.hdel(
            self.get_cache_key(user_id),
            self._element_key(collection_string, id))

    def exists_for_user(self, user_id: int) -> bool:
        """
        Returns True if the cache for the user exists, else False.
        """
        redis = get_redis_connection()
        # redis-py may return the number of existing keys; normalize to bool.
        return bool(redis.exists(self.get_cache_key(user_id)))

    def get_data(self, user_id: int) -> List[object]:
        """
        Returns all data for the user.

        The returned value is a list of the elements.
        """
        redis = get_redis_connection()
        return [json.loads(element.decode()) for element in redis.hvals(self.get_cache_key(user_id))]

    def get_cache_key(self, user_id: int) -> str:
        """
        Returns the cache key for a user.
        """
        return cache.make_key('{}:{}'.format(self.base_cache_key, user_id))
|
|
|
|
|
|
class DummyRestrictedDataCache:
    """
    Dummy RestrictedDataCache that caches nothing.

    The write methods do nothing, exists_for_user() is always False, and
    get_data() returns an empty list. An empty list is also what
    RestrictedDataCache.get_data() returns for a user without a cache. The
    declared return type of get_data() is List, so it must not return None.
    """

    def update_element(self, user_id: int, collection_string: str, id: int, data: object) -> None:
        pass

    def add_element(self, user_id: int, collection_string: str, id: int, data: object) -> None:
        pass

    def del_element(self, user_id: int, collection_string: str, id: int) -> None:
        pass

    def exists_for_user(self, user_id: int) -> bool:
        return False

    def get_data(self, user_id: int) -> List[object]:
        return []
|
|
|
|
|
|
def use_redis_cache() -> bool:
    """
    Tells whether django's 'default' cache is a django_redis RedisCache.

    Returns False when django_redis cannot be imported at all.
    """
    try:
        from django_redis.cache import RedisCache
    except ImportError:
        return False
    else:
        default_backend = caches['default']
        return isinstance(default_backend, RedisCache)
|
|
|
|
|
|
def get_redis_connection() -> Any:
    """
    Returns the connection object django_redis provides for the 'default'
    alias, used to send redis commands directly.
    """
    # Aliased so the imported function does not shadow this one's name locally.
    from django_redis import get_redis_connection as django_redis_connection
    return django_redis_connection("default")
|
|
|
|
|
|
# Choose the cache implementations. The redis backed caches are only used when
# django's 'default' cache is a django_redis RedisCache; otherwise websocket
# users are kept in the django cache and the data caches are dummies.
if not use_redis_cache():
    websocket_user_cache = DjangoCacheWebsocketUserCache()  # type: BaseWebsocketUserCache
    restricted_data_cache = DummyRestrictedDataCache()  # type: Union[RestrictedDataCache, DummyRestrictedDataCache]
    full_data_cache = DummyFullDataCache()  # type: Union[FullDataCache, DummyFullDataCache]
else:
    websocket_user_cache = RedisWebsocketUserCache()
    # settings.DISABLE_USER_CACHE turns off only the per-user restricted data cache.
    restricted_data_cache = RestrictedDataCache() if not settings.DISABLE_USER_CACHE else DummyRestrictedDataCache()
    full_data_cache = FullDataCache()
|