import asyncio
import json
import logging
from collections import defaultdict
from datetime import datetime
from time import sleep  # noqa: F401  (no longer used here; kept for importers)
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

from asgiref.sync import async_to_sync
from django.apps import apps

from .cache_providers import (
    Cachable,
    ElementCacheProvider,
    MemoryCacheProvider,
    RedisCacheProvider,
)
from .redis import use_redis
from .schema_version import SchemaVersion, schema_version_handler
from .utils import get_element_id, split_element_id


logger = logging.getLogger(__name__)


def get_all_cachables() -> List[Cachable]:
    """
    Returns all elements of OpenSlides.

    Collects the result of get_startup_elements() from every installed app
    that implements it.
    """
    out: List[Cachable] = []
    for app in apps.get_app_configs():
        try:
            # Get the method get_startup_elements() from an app.
            # This method has to return an iterable of Cachable objects.
            get_startup_elements = app.get_startup_elements
        except AttributeError:
            # Skip apps that do not implement get_startup_elements.
            continue
        out.extend(get_startup_elements())
    return out


class ElementCache:
    """
    Cache for the elements.

    Saves the full_data.

    There is one redis Hash (similar to a python dict) for the full_data.

    The key of the Hash is COLLECTIONSTRING:ID where COLLECTIONSTRING is the
    collection_string of a collection and ID the id of an element.

    There is a sorted set in redis with the change id as score. The values are
    COLLECTIONSTRING:ID for the elements that have been changed with that
    change id. With this key it is possible to get all elements as full_data
    that are newer than a specific change id.

    All methods of this class are async (except the sync wrappers
    ensure_cache and ensure_schema_version). You either have to call them
    with await in an async environment or use asgiref.sync.async_to_sync().
    """

    def __init__(
        self,
        cache_provider_class: Type[ElementCacheProvider] = RedisCacheProvider,
        cachable_provider: Callable[[], List[Cachable]] = get_all_cachables,
        default_change_id: Optional[int] = None,
    ) -> None:
        """
        Initializes the cache.

        cache_provider_class is instantiated with async_ensure_cache as its
        only argument. cachable_provider is called lazily (see the cachables
        property). default_change_id, if given, is used by build_cache when
        no explicit change id is passed to it.
        """
        self.cache_provider = cache_provider_class(self.async_ensure_cache)
        self.cachable_provider = cachable_provider
        self._cachables: Optional[Dict[str, Cachable]] = None
        self.default_change_id: Optional[int] = default_change_id

    @property
    def cachables(self) -> Dict[str, Cachable]:
        """
        Returns all cachables as a dict where the key is the
        collection_string of the cachable.
        """
        # This property is necessary to lazy load the cachables.
        if self._cachables is None:
            self._cachables = {
                cachable.get_collection_string(): cachable
                for cachable in self.cachable_provider()
            }
        return self._cachables

    def ensure_cache(
        self, reset: bool = False, default_change_id: Optional[int] = None
    ) -> None:
        """
        Ensures the existence of the cache; see async_ensure_cache for more
        info. Synchronous wrapper around async_ensure_cache.
        """
        async_to_sync(self.async_ensure_cache)(reset, default_change_id)

    async def async_ensure_cache(
        self, reset: bool = False, default_change_id: Optional[int] = None
    ) -> None:
        """
        Makes sure that the cache exists.

        Builds the cache if it does not exist or if reset is given as True.
        """
        cache_exists = await self.cache_provider.data_exists()
        if reset or not cache_exists:
            await self.build_cache(default_change_id)

    def ensure_schema_version(self) -> None:
        """
        Synchronous wrapper around async_ensure_schema_version.
        """
        async_to_sync(self.async_ensure_schema_version)()

    async def async_ensure_schema_version(self) -> None:
        """
        Rebuilds the cache if the schema version stored in the cache does not
        match the current one, or if the cache does not exist at all.
        """
        cache_schema_version = await self.cache_provider.get_schema_version()
        schema_changed = not schema_version_handler.compare(cache_schema_version)
        schema_version_handler.log_current()

        cache_exists = await self.cache_provider.data_exists()
        if schema_changed or not cache_exists:
            await self.build_cache(schema_version=schema_version_handler.get())

    async def build_cache(
        self,
        default_change_id: Optional[int] = None,
        schema_version: Optional[SchemaVersion] = None,
    ) -> None:
        """
        Builds the cache from all cachables and stores it in the provider.

        Only the caller that acquires the "build_cache" lock builds the
        cache; every other caller waits until the lock is released.

        If default_change_id is None, self.default_change_id is used; if that
        is None as well, a value derived from the current time is used. If
        schema_version is given (truthy), it is stored in the cache after the
        data was saved.
        """
        lock_name = "build_cache"
        # Set a lock so only one process builds the cache
        if await self.cache_provider.set_lock(lock_name):
            logger.info("Building up the cache data...")
            try:
                mapping = {}
                for collection_string, cachable in self.cachables.items():
                    for element in cachable.get_elements():
                        element_id = get_element_id(collection_string, element["id"])
                        mapping[element_id] = json.dumps(element)
                logger.info("Done building the cache data.")
                logger.info("Saving cache data into the cache...")
                if default_change_id is None:
                    if self.default_change_id is not None:
                        default_change_id = self.default_change_id
                    else:
                        # Use the milliseconds since 2016-02-29, truncated to
                        # whole seconds before scaling.
                        default_change_id = int(
                            (datetime.utcnow() - datetime(2016, 2, 29)).total_seconds()
                        )
                        default_change_id *= 1000
                await self.cache_provider.reset_full_cache(mapping, default_change_id)
                if schema_version:
                    await self.cache_provider.set_schema_version(schema_version)
                logger.info("Done saving the cache data.")
            finally:
                # Always release the lock, even if building failed.
                await self.cache_provider.del_lock(lock_name)
        else:
            logger.info("Wait for another process to build up the cache...")
            while await self.cache_provider.get_lock(lock_name):
                # Must not block the event loop while polling the lock, so
                # use the awaitable sleep instead of time.sleep().
                await asyncio.sleep(0.01)
            logger.info("Cache is ready (built by another process).")

    async def change_elements(
        self, elements: Dict[str, Optional[Dict[str, Any]]]
    ) -> int:
        """
        Changes elements in the cache.

        elements is a dict with element_id <-> changed element. When the
        value is None (or otherwise falsy, e.g. an empty dict), it is
        interpreted as deleted.

        Returns the new generated change_id.
        """
        # Split elements into changed and deleted.
        deleted_elements = []
        changed_elements = []
        for element_id, data in elements.items():
            if data:
                # The arguments for redis.hset are pairs of key and value,
                # so both are appended to the same flat list.
                changed_elements.append(element_id)
                changed_elements.append(json.dumps(data))
            else:
                deleted_elements.append(element_id)

        return await self.cache_provider.add_changed_elements(
            changed_elements, deleted_elements
        )

    async def get_all_data_list(
        self, user_id: Optional[int] = None
    ) -> Dict[str, List[Dict[str, Any]]]:
        """
        Returns all data with a list per collection:
        {
            <collection_string>: [<element>, <element>, ...]
        }
        If the user id is given the data will be restricted for this user.
        """
        all_data: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
        for element_id, data in (await self.cache_provider.get_all_data()).items():
            collection_string, _ = split_element_id(element_id)
            all_data[collection_string].append(json.loads(data.decode()))

        if user_id is not None:
            # Only existing keys are reassigned, so the dict does not change
            # its size while being iterated.
            for collection_string in all_data:
                restricter = self.cachables[collection_string].restrict_elements
                all_data[collection_string] = await restricter(
                    user_id, all_data[collection_string]
                )
        return dict(all_data)

    async def get_all_data_dict(self) -> Dict[str, Dict[int, Dict[str, Any]]]:
        """
        Returns all data with a dict (id <-> element) per collection:
        {
            <collection_string>: {
                <id>: <element>
            }
        }
        """
        all_data: Dict[str, Dict[int, Dict[str, Any]]] = defaultdict(dict)
        for element_id, data in (await self.cache_provider.get_all_data()).items():
            collection_string, elem_id = split_element_id(element_id)
            all_data[collection_string][elem_id] = json.loads(data.decode())
        return dict(all_data)

    async def get_collection_data(
        self, collection_string: str
    ) -> Dict[int, Dict[str, Any]]:
        """
        Returns the data for one collection as dict: {id: <element>}
        """
        encoded_collection_data = await self.cache_provider.get_collection_data(
            collection_string
        )
        collection_data = {}
        for elem_id in encoded_collection_data.keys():
            collection_data[elem_id] = json.loads(
                encoded_collection_data[elem_id].decode()
            )
        return collection_data

    async def get_element_data(
        self, collection_string: str, id: int, user_id: Optional[int] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Returns one element or None, if the element does not exist.

        If the user id is given the data will be restricted for this user.
        In that case None is also returned when the restricter removes the
        element.
        """
        encoded_element = await self.cache_provider.get_element_data(
            get_element_id(collection_string, id)
        )

        if encoded_element is None:
            return None
        element = json.loads(encoded_element.decode())  # type: ignore

        if user_id is not None:
            restricter = self.cachables[collection_string].restrict_elements
            restricted_elements = await restricter(user_id, [element])
            element = restricted_elements[0] if restricted_elements else None
        return element

    async def get_data_since(
        self, user_id: Optional[int] = None, change_id: int = 0, max_change_id: int = -1
    ) -> Tuple[Dict[str, List[Dict[str, Any]]], List[str]]:
        """
        Returns all data since change_id until max_change_id (included).
        max_change_id -1 means the highest change_id.

        If the user id is given the data will be restricted for this user.

        Returns two values inside a tuple. The first value is a dict where
        the key is the collection_string and the value is a list of data.
        The second is a list of element_ids with deleted elements.

        Only returns elements with the change_id or newer. When change_id is
        0, all elements are returned.

        Raises a RuntimeError when the lowest change_id in redis is higher
        than the requested change_id. In this case the method has to be rerun
        with change_id=0. This is important because there could be deleted
        elements that the cache does not know about.
        """
        if change_id == 0:
            return (await self.get_all_data_list(user_id), [])

        # This raises a Runtime Exception, if there is no change_id
        lowest_change_id = await self.get_lowest_change_id()

        if change_id < lowest_change_id:
            # When change_id is lower than the lowest change_id in redis, we
            # can not inform the user about deleted elements.
            raise RuntimeError(
                f"change_id {change_id} is lower then the lowest change_id in redis {lowest_change_id}. "
                "Catch this exception and rerun the method with change_id=0."
            )

        (
            raw_changed_elements,
            deleted_elements,
        ) = await self.cache_provider.get_data_since(
            change_id, max_change_id=max_change_id
        )
        changed_elements = {
            collection_string: [json.loads(value.decode()) for value in value_list]
            for collection_string, value_list in raw_changed_elements.items()
        }

        if user_id is not None:
            # Iterate over a snapshot: collections may be deleted from
            # changed_elements inside the loop, which would otherwise raise
            # "dictionary changed size during iteration".
            for collection_string, elements in list(changed_elements.items()):
                restricter = self.cachables[collection_string].restrict_elements
                restricted_elements = await restricter(user_id, elements)

                # Add removed objects (through restricter) to deleted elements.
                element_ids = {element["id"] for element in elements}
                restricted_element_ids = {
                    element["id"] for element in restricted_elements
                }
                for removed_id in element_ids - restricted_element_ids:
                    deleted_elements.append(
                        get_element_id(collection_string, removed_id)
                    )

                if not restricted_elements:
                    del changed_elements[collection_string]
                else:
                    changed_elements[collection_string] = restricted_elements

        return (changed_elements, deleted_elements)

    async def get_current_change_id(self) -> int:
        """
        Returns the current change id.

        Returns default_change_id if there is no change id yet.
        """
        return await self.cache_provider.get_current_change_id()

    async def get_lowest_change_id(self) -> int:
        """
        Returns the lowest change id.

        Raises a RuntimeError if there is no change_id.
        """
        return await self.cache_provider.get_lowest_change_id()


def load_element_cache() -> ElementCache:
    """
    Generates an element cache instance.

    Uses the redis cache provider if use_redis is set, the in-memory
    provider otherwise.
    """
    if use_redis:
        cache_provider_class: Type[ElementCacheProvider] = RedisCacheProvider
    else:
        cache_provider_class = MemoryCacheProvider

    return ElementCache(cache_provider_class=cache_provider_class)


# Set the element_cache
element_cache = load_element_cache()