Merge pull request #4896 from FinnStutzenstein/cache

Major cache rewrite
Finn Stutzenstein, 2019-08-08 13:21:49 +02:00, committed by GitHub
commit 91238e83bb
30 changed files with 906 additions and 1173 deletions

View File

@@ -17,7 +17,7 @@ matrix:
       - pip freeze
     script:
       - mypy openslides/ tests/
-      - pytest --cov --cov-fail-under=76
+      - pytest --cov --cov-fail-under=75
   - language: python
     name: "Server: Tests Python 3.7"
@@ -36,7 +36,7 @@ matrix:
       - isort --check-only --diff --recursive openslides tests
       - black --check --diff --target-version py36 openslides tests
       - mypy openslides/ tests/
-      - pytest --cov --cov-fail-under=76
+      - pytest --cov --cov-fail-under=75
   - language: python
     name: "Server: Tests Startup Routine Python 3.7"

View File

@@ -24,8 +24,8 @@
     "po2json-tempfix": "./node_modules/.bin/po2json -f mf src/assets/i18n/de.po /dev/stdout | sed -f sed_replacements > src/assets/i18n/de.json && ./node_modules/.bin/po2json -f mf src/assets/i18n/cs.po /dev/stdout | sed -f sed_replacements > src/assets/i18n/cs.json",
     "prettify-check": "prettier --config ./.prettierrc --list-different \"src/{app,environments}/**/*{.ts,.js,.json,.css,.scss}\"",
     "prettify-write": "prettier --config ./.prettierrc --write \"src/{app,environments}/**/*{.ts,.js,.json,.css,.scss}\"",
-    "cleanup": "npm run lint-write; npm run prettify-write",
-    "cleanup-win": "npm run lint-write & npm run prettify-write"
+    "cleanup": "npm run prettify-write; npm run lint-write",
+    "cleanup-win": "npm run prettify-write & npm run lint-write"
   },
   "dependencies": {
     "@angular/animations": "^8.0.3",

View File

@@ -96,7 +96,10 @@ export class AutoupdateService {
         Object.keys(autoupdate.changed).forEach(collection => {
             elements = elements.concat(this.mapObjectsToBaseModels(collection, autoupdate.changed[collection]));
         });
+        const updateSlot = await this.DSUpdateManager.getNewUpdateSlot(this.DS);
         await this.DS.set(elements, autoupdate.to_change_id);
+        this.DSUpdateManager.commit(updateSlot);
     }

     /**
@@ -107,7 +110,7 @@
         const maxChangeId = this.DS.maxChangeId;
         if (autoupdate.from_change_id <= maxChangeId && autoupdate.to_change_id <= maxChangeId) {
-            console.log('ignore');
+            console.log(`Ignore. Clients change id: ${maxChangeId}`);
             return; // Ignore autoupdates, that lay full behind our changeid.
         }

View File

@@ -1,6 +1,7 @@
 import { Injectable } from '@angular/core';

-import { Observable, of, Subject } from 'rxjs';
+import { BehaviorSubject, Observable } from 'rxjs';
+import { filter } from 'rxjs/operators';

 import { WebsocketService } from './websocket.service';
@@ -28,17 +29,12 @@ export class ConstantsService {
     /**
      * The constants
      */
-    private constants: Constants;
+    private constants: Constants = {};

-    /**
-     * Flag, if constants are requested, but the server hasn't send them yet.
-     */
-    private pending = false;
-
     /**
      * Pending requests will be notified by these subjects, one per key.
      */
-    private pendingSubject: { [key: string]: Subject<any> } = {};
+    private subjects: { [key: string]: BehaviorSubject<any> } = {};

     /**
      * @param websocketService
@@ -47,29 +43,16 @@
         // The hook for recieving constants.
         websocketService.getOberservable<Constants>('constants').subscribe(constants => {
             this.constants = constants;
-            if (this.pending) {
-                // send constants to subscribers that await constants.
-                this.pending = false;
-                this.informSubjects();
-            }
+            Object.keys(this.subjects).forEach(key => {
+                this.subjects[key].next(this.constants[key]);
+            });
         });

         // We can request constants, if the websocket connection opens.
         // On retries, the `refresh()` method is called by the OpenSlidesService, so
         // here we do not need to take care about this.
         websocketService.noRetryConnectEvent.subscribe(() => {
-            if (this.pending) {
-                this.websocketService.send('constants', {});
-            }
+            this.refresh();
         });
     }

-    /**
-     * Inform subjects about changes.
-     */
-    private informSubjects(): void {
-        Object.keys(this.pendingSubject).forEach(key => {
-            this.pendingSubject[key].next(this.constants[key]);
-        });
-    }
@@ -78,32 +61,19 @@
      * @param key The constant to get.
      */
     public get<T>(key: string): Observable<T> {
-        if (this.constants) {
-            return of(this.constants[key]);
-        } else {
-            // we have to request constants.
-            if (!this.pending) {
-                this.pending = true;
-                // if the connection is open, we directly can send the request.
-                if (this.websocketService.isConnected) {
-                    this.websocketService.send('constants', {});
-                }
-            }
-            if (!this.pendingSubject[key]) {
-                this.pendingSubject[key] = new Subject<any>();
-            }
-            return this.pendingSubject[key].asObservable() as Observable<T>;
-        }
+        if (!this.subjects[key]) {
+            this.subjects[key] = new BehaviorSubject<any>(this.constants[key]);
+        }
+        return this.subjects[key].asObservable().pipe(filter(x => !!x));
     }

     /**
      * Refreshed the constants
      */
-    public async refresh(): Promise<void> {
+    public refresh(): Promise<void> {
         if (!this.websocketService.isConnected) {
             return;
         }
-        this.constants = await this.websocketService.sendAndGetResponse('constants', {});
-        this.informSubjects();
+        this.websocketService.send('constants', {});
     }
 }

View File

@@ -1,12 +1,23 @@
 import { Injectable } from '@angular/core';

-import { take } from 'rxjs/operators';

 import { AutoupdateService } from './autoupdate.service';
 import { ConstantsService } from './constants.service';
 import { StorageService } from './storage.service';

-const DB_SCHEMA_VERSION = 'DbSchemaVersion';
+interface SchemaVersion {
+    db: string;
+    config: number;
+    migration: number;
+}
+
+function isSchemaVersion(obj: any): obj is SchemaVersion {
+    if (!obj || typeof obj !== 'object') {
+        return false;
+    }
+    return obj.db !== undefined && obj.config !== undefined && obj.migration !== undefined;
+}
+
+const SCHEMA_VERSION = 'SchemaVersion';

 /**
  * Manages upgrading the DataStore, if the migration version from the server is higher than the current one.
@@ -25,24 +36,47 @@ export class DataStoreUpgradeService {
         private constantsService: ConstantsService,
         private storageService: StorageService
     ) {
-        this.checkForUpgrade();
+        // Prevent the schema version to be cleard. This is important
+        // after a reset from OpenSlides, because the complete data is
+        // queried from the server and we do not want also to trigger a reload
+        // by changing the schema from null -> <schema>.
+        this.storageService.addNoClearKey(SCHEMA_VERSION);
+        this.constantsService
+            .get<SchemaVersion>(SCHEMA_VERSION)
+            .subscribe(serverVersion => this.checkForUpgrade(serverVersion));
     }

-    public async checkForUpgrade(): Promise<boolean> {
-        const version = await this.constantsService
-            .get<string | number>(DB_SCHEMA_VERSION)
-            .pipe(take(1))
-            .toPromise();
-        console.log('DB schema version:', version);
-        const currentVersion = await this.storageService.get<string>(DB_SCHEMA_VERSION);
-        await this.storageService.set(DB_SCHEMA_VERSION, version);
-        const doUpgrade = version !== currentVersion;
-        if (doUpgrade) {
-            console.log(`DB schema version changed from ${currentVersion} to ${version}`);
-            await this.autoupdateService.doFullUpdate();
-        }
+    public async checkForUpgrade(serverVersion: SchemaVersion): Promise<boolean> {
+        console.log('Server schema version:', serverVersion);
+        const clientVersion = await this.storageService.get<SchemaVersion>(SCHEMA_VERSION);
+        await this.storageService.set(SCHEMA_VERSION, serverVersion);
+        let doUpgrade = false;
+        if (isSchemaVersion(clientVersion)) {
+            if (clientVersion.db !== serverVersion.db) {
+                console.log(`\tDB id changed from ${clientVersion.db} to ${serverVersion.db}`);
+                doUpgrade = true;
+            }
+            if (clientVersion.config !== serverVersion.config) {
+                console.log(`\tConfig changed from ${clientVersion.config} to ${serverVersion.config}`);
+                doUpgrade = true;
+            }
+            if (clientVersion.migration !== serverVersion.migration) {
+                console.log(`\tMigration changed from ${clientVersion.migration} to ${serverVersion.migration}`);
+                doUpgrade = true;
+            }
+        } else {
+            console.log('\tNo client schema version.');
+            doUpgrade = true;
+        }
+        if (doUpgrade) {
+            console.log('\t-> In result of a schema version change: Do full update.');
+            await this.autoupdateService.doFullUpdate();
+        } else {
+            console.log('\t-> No upgrade needed.');
+        }
         return doUpgrade;
     }
 }

View File

@@ -3,7 +3,6 @@ import { Router } from '@angular/router';

 import { AutoupdateService } from './autoupdate.service';
 import { ConstantsService } from './constants.service';
-import { DataStoreUpgradeService } from './data-store-upgrade.service';
 import { DataStoreService } from './data-store.service';
 import { OperatorService } from './operator.service';
 import { StorageService } from './storage.service';
@@ -47,8 +46,7 @@ export class OpenSlidesService {
         private router: Router,
         private autoupdateService: AutoupdateService,
         private DS: DataStoreService,
-        private constantsService: ConstantsService,
-        private dataStoreUpgradeService: DataStoreUpgradeService
+        private constantsService: ConstantsService
     ) {
         // Handler that gets called, if the websocket connection reconnects after a disconnection.
         // There might have changed something on the server, so we check the operator, if he changed.
@@ -162,6 +160,7 @@
         const response = await this.operator.whoAmI();
         // User logged off.
         if (!response.user && !response.guest_enabled) {
+            this.websocketService.cancelReconnectenRetry();
             await this.shutdown();
             this.redirectToLoginIfNotSubpage();
         } else {
@@ -174,24 +173,9 @@
                 await this.reboot();
             } else if (requestChanges) {
                 // User is still the same, but check for missed autoupdates.
-                await this.recoverAfterReconnect();
+                this.autoupdateService.requestChanges();
+                this.constantsService.refresh();
             }
         }
     }
-
-    /**
-     * The cache-refresh strategy, if there was an reconnect and the user didn't changed.
-     */
-    private async recoverAfterReconnect(): Promise<void> {
-        // Reload constants to get either new one (in general) and especially
-        // the "DbSchemaVersion" one, to check, if the DB has changed (e.g. due
-        // to an update)
-        await this.constantsService.refresh();
-        // If the DB schema version didn't change, request normal changes.
-        // If so, then a full update is implicit triggered, so we do not need to to anything.
-        if (!(await this.dataStoreUpgradeService.checkForUpgrade())) {
-            this.autoupdateService.requestChanges();
-        }
-    }
 }

View File

@@ -13,12 +13,18 @@ import { OpenSlidesStatusService } from './openslides-status.service';
     providedIn: 'root'
 })
 export class StorageService {
+    private noClearKeys: string[] = [];
+
     /**
      * Constructor to create the StorageService. Needs the localStorage service.
      * @param localStorage
      */
     public constructor(private localStorage: LocalStorage, private OSStatus: OpenSlidesStatusService) {}

+    public addNoClearKey(key: string): void {
+        this.noClearKeys.push(key);
+    }
+
     /**
      * Sets the item into the store asynchronously.
      * @param key
@@ -57,13 +63,20 @@
     }

     /**
-     * Clear the whole cache
+     * Clear the whole cache except for keys given in `addNoClearKey`.
      */
     public async clear(): Promise<void> {
         this.assertNotHistoryMode();
+        const savedData: { [key: string]: any } = {};
+        for (const key of this.noClearKeys) {
+            savedData[key] = await this.get(key);
+        }
         if (!(await this.localStorage.clear().toPromise())) {
             throw new Error('Could not clear the storage.');
         }
+        for (const key of this.noClearKeys) {
+            await this.set(key, savedData[key]);
+        }
     }

     /**
/** /**

View File

@@ -182,6 +182,11 @@ export class WebsocketService {
      */
     private retryCounter = 0;

+    /**
+     * The timeout in the onClose-handler for the next reconnect retry.
+     */
+    private retryTimeout: any = null;
+
     /**
      * Constructor that handles the router
      * @param matSnackBar
@@ -385,12 +390,20 @@
             // A random retry timeout between 2000 and 5000 ms.
             const timeout = Math.floor(Math.random() * 3000 + 2000);
-            setTimeout(() => {
+            this.retryTimeout = setTimeout(() => {
+                this.retryTimeout = null;
                 this.connect({ enableAutoupdates: true }, true);
             }, timeout);
         }
     }

+    public cancelReconnectenRetry(): void {
+        if (this.retryTimeout) {
+            clearTimeout(this.retryTimeout);
+            this.retryTimeout = null;
+        }
+    }
+
     private dismissConnectionErrorNotice(): void {
         if (this.connectionErrorNotice) {
             this.connectionErrorNotice.dismiss();

View File

@@ -1,4 +1,3 @@
-import hashlib
 import logging
 import os
 import sys
@@ -8,11 +7,9 @@ from typing import Any, Dict, List

 from django.apps import AppConfig
 from django.conf import settings
-from django.db.models import Max
 from django.db.models.signals import post_migrate, pre_delete

+from openslides.utils.schema_version import schema_version_handler

-logger = logging.getLogger("openslides.core")

 class CoreAppConfig(AppConfig):
@@ -179,18 +176,7 @@
             config_groups[-1]["subgroups"][-1]["items"].append(config_variable.data)
         constants["ConfigVariables"] = config_groups

-        # get max migration id -> the "version" of the DB
-        from django.db.migrations.recorder import MigrationRecorder
-
-        migration_version = MigrationRecorder.Migration.objects.aggregate(Max("id"))[
-            "id__max"
-        ]
-        config_version = config["config_version"]
-        hash = hashlib.sha1(
-            f"{migration_version}#{config_version}".encode()
-        ).hexdigest()
-        constants["DbSchemaVersion"] = hash
-        logger.info(f"DbSchemaVersion={hash}")
+        constants["SchemaVersion"] = schema_version_handler.get()

         return constants
@@ -209,6 +195,7 @@ def manage_config(**kwargs):
     altered = config.cleanup_old_config_values() or altered
     if altered:
         config.increment_version()
+        logging.getLogger(__name__).info("Updated config variables")


 def startup():
@@ -224,6 +211,6 @@ def startup():
     from openslides.utils.cache import element_cache
     from openslides.core.models import History

-    element_cache.ensure_cache()
+    element_cache.ensure_schema_version()
     set_constants(get_constants_from_apps())
     History.objects.build_history()
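
The `SchemaVersion` constant replaces the old `DbSchemaVersion` SHA1 hash with a structured value, so clients can tell which part changed: the database, the config or the migrations. The handler itself is not part of this diff; a minimal sketch of what `openslides.utils.schema_version` could look like, assuming it draws on the same sources the removed hashing code used (highest migration id, `config_version`) plus the new `db_id` config variable introduced below:

import logging
from typing import Optional

from typing_extensions import TypedDict


class SchemaVersion(TypedDict):
    db: str
    config: int
    migration: int


class SchemaVersionHandler:
    """
    Builds the schema version once per process and compares cached
    versions against it. A sketch, not the actual implementation.
    """

    def __init__(self) -> None:
        self._schema_version: Optional[SchemaVersion] = None

    def get(self) -> SchemaVersion:
        if self._schema_version is None:
            from django.db.migrations.recorder import MigrationRecorder
            from django.db.models import Max

            from openslides.core.config import config

            migration = MigrationRecorder.Migration.objects.aggregate(Max("id"))[
                "id__max"
            ]
            self._schema_version = SchemaVersion(
                db=config["db_id"],
                config=config["config_version"],
                migration=migration,
            )
        return self._schema_version

    def compare(self, other: Optional[SchemaVersion]) -> bool:
        # A missing or diverging cached version forces a cache rebuild.
        return other == self.get()

    def log_current(self) -> None:
        logging.getLogger(__name__).info(f"Schema version: {self.get()}")


schema_version_handler = SchemaVersionHandler()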

View File

@@ -49,7 +49,7 @@ class ConfigHandler:
         if not self.exists(key):
             raise ConfigNotFound(f"The config variable {key} was not found.")

-        return async_to_sync(element_cache.get_element_full_data)(
+        return async_to_sync(element_cache.get_element_data)(
             self.get_collection_string(), self.get_key_to_id()[key]
         )["value"]
@@ -85,7 +85,7 @@
         if self.key_to_id is not None:
             return

-        config_full_data = await element_cache.get_collection_full_data(
+        config_full_data = await element_cache.get_collection_data(
             self.get_collection_string()
         )
         elements = config_full_data.values()

View File

@@ -1,3 +1,5 @@
+import uuid
+
 from django.core.validators import MaxLengthValidator

 from openslides.core.config import ConfigVariable
@@ -394,7 +396,7 @@ def get_config_variables():
         group="Custom translations",
     )

-    # Config version
+    # Config version and DB id
     yield ConfigVariable(
         name="config_version",
         input_type="integer",
@@ -402,3 +404,10 @@ def get_config_variables():
         group="Version",
         hidden=True,
     )
+    yield ConfigVariable(
+        name="db_id",
+        input_type="string",
+        default_value=uuid.uuid4().hex,
+        group="Version",
+        hidden=True,
+    )

View File

@@ -287,7 +287,7 @@ class HistoryManager(models.Manager):
         instances = None
         if self.all().count() == 0:
             elements = []
-            all_full_data = async_to_sync(element_cache.get_all_full_data)()
+            all_full_data = async_to_sync(element_cache.get_all_data_list)()
             for collection_string, data in all_full_data.items():
                 for full_data in data:
                     elements.append(

View File

@@ -608,7 +608,7 @@ class HistoryDataView(utils_views.APIView):
         )
         missing_keys = all_current_config_keys - all_old_config_keys
         if missing_keys:
-            config_full_data = async_to_sync(element_cache.get_collection_full_data)(
+            config_full_data = async_to_sync(element_cache.get_collection_data)(
                 "core/config"
             )
             key_to_id = config.get_key_to_id()

View File

@@ -5,11 +5,9 @@ from ..utils.constants import get_constants
 from ..utils.projector import get_projector_data
 from ..utils.stats import WebsocketLatencyLogger
 from ..utils.websocket import (
-    WEBSOCKET_CHANGE_ID_TOO_HIGH,
     WEBSOCKET_NOT_AUTHORIZED,
     BaseWebsocketClientMessage,
     ProtocollAsyncJsonWebsocketConsumer,
-    get_element_data,
 )
@@ -116,18 +114,7 @@ class GetElementsWebsocketClientMessage(BaseWebsocketClientMessage):
         self, consumer: "ProtocollAsyncJsonWebsocketConsumer", content: Any, id: str
     ) -> None:
         requested_change_id = content.get("change_id", 0)
-        try:
-            element_data = await get_element_data(
-                consumer.scope["user"]["id"], requested_change_id
-            )
-        except ValueError as error:
-            await consumer.send_error(
-                code=WEBSOCKET_CHANGE_ID_TOO_HIGH, message=str(error), in_response=id
-            )
-        else:
-            await consumer.send_json(
-                type="autoupdate", content=element_data, in_response=id
-            )
+        await consumer.send_autoupdate(requested_change_id, in_response=id)


 class AutoupdateWebsocketClientMessage(BaseWebsocketClientMessage):

View File

@@ -130,8 +130,6 @@ class UserViewSet(ModelViewSet):
             if key not in ("username", "about_me"):
                 del request.data[key]
         response = super().update(request, *args, **kwargs)
-        # Maybe some group assignments have changed. Better delete the restricted user cache
-        async_to_sync(element_cache.del_user)(user.pk)
         return response

     def destroy(self, request, *args, **kwargs):
@@ -275,8 +273,6 @@ class UserViewSet(ModelViewSet):
             user.groups.add(*groups)
         else:
             user.groups.remove(*groups)
-        # Maybe some group assignments have changed. Better delete the restricted user cache
-        async_to_sync(element_cache.del_user)(user.pk)
         inform_changed_data(users)
         return Response()
@@ -570,13 +566,9 @@ class GroupViewSet(ModelViewSet):
         if not changed_permissions:
             return  # either None or empty list.

-        # Delete the user chaches of all affected users
-        for user in group.user_set.all():
-            async_to_sync(element_cache.del_user)(user.pk)
-
         elements: List[Element] = []
         signal_results = permission_change.send(None, permissions=changed_permissions)
-        all_full_data = async_to_sync(element_cache.get_all_full_data)()
+        all_full_data = async_to_sync(element_cache.get_all_data_list)()
         for _, signal_collections in signal_results:
             for cachable in signal_collections:
                 for full_data in all_full_data.get(
@@ -672,8 +664,8 @@ class WhoAmIDataView(APIView):
         guest_enabled = anonymous_is_enabled()

         if user_id:
-            user_data = async_to_sync(element_cache.get_element_restricted_data)(
-                user_id, self.request.user.get_collection_string(), user_id
+            user_data = async_to_sync(element_cache.get_element_data)(
+                self.request.user.get_collection_string(), user_id, user_id
             )
             group_ids = user_data["groups_id"] or [GROUP_DEFAULT_PK]
         else:
@@ -682,9 +674,7 @@

         # collect all permissions
         permissions: Set[str] = set()
-        group_all_data = async_to_sync(element_cache.get_collection_full_data)(
-            "users/group"
-        )
+        group_all_data = async_to_sync(element_cache.get_collection_data)("users/group")
         for group_id in group_ids:
             permissions.update(group_all_data[group_id]["permissions"])

View File

@@ -86,16 +86,14 @@ class RequiredUsers:
         user_ids: Set[int] = set()

         for collection_string in collection_strings:
-            collection_full_data = await element_cache.get_collection_full_data(
-                collection_string
-            )
+            collection_data = await element_cache.get_collection_data(collection_string)
             # Get the callable for the collection_string
             get_user_ids = self.callables.get(collection_string)
-            if not (get_user_ids and collection_full_data):
+            if not (get_user_ids and collection_data):
                 # if the collection_string is unknown or it has no data, do nothing
                 continue
-            for element in collection_full_data.values():
+            for element in collection_data.values():
                 user_ids.update(get_user_ids(element))

         return user_ids

View File

@@ -67,14 +67,14 @@ async def async_has_perm(user_id: int, perm: str) -> bool:
         has_perm = False
     elif not user_id:
         # Use the permissions from the default group.
-        default_group = await element_cache.get_element_full_data(
+        default_group = await element_cache.get_element_data(
             group_collection_string, GROUP_DEFAULT_PK
         )
         if default_group is None:
             raise RuntimeError("Default Group does not exist.")
         has_perm = perm in default_group["permissions"]
     else:
-        user_data = await element_cache.get_element_full_data(
+        user_data = await element_cache.get_element_data(
             user_collection_string, user_id
         )
         if user_data is None:
@@ -87,7 +87,7 @@ async def async_has_perm(user_id: int, perm: str) -> bool:
             # permission. If the user has no groups, then use the default group.
             group_ids = user_data["groups_id"] or [GROUP_DEFAULT_PK]
             for group_id in group_ids:
-                group = await element_cache.get_element_full_data(
+                group = await element_cache.get_element_data(
                     group_collection_string, group_id
                 )
                 if group is None:
@@ -131,7 +131,7 @@ async def async_in_some_groups(user_id: int, groups: List[int]) -> bool:
         # Use the permissions from the default group.
         in_some_groups = GROUP_DEFAULT_PK in groups
     else:
-        user_data = await element_cache.get_element_full_data(
+        user_data = await element_cache.get_element_data(
             user_collection_string, user_id
         )
         if user_data is None:
@@ -167,7 +167,7 @@ async def async_anonymous_is_enabled() -> bool:
     """
     from ..core.config import config

-    element = await element_cache.get_element_full_data(
+    element = await element_cache.get_element_data(
         config.get_collection_string(),
         (await config.async_get_key_to_id())["general_system_enable_anonymous"],
     )

View File

@@ -1,4 +1,3 @@
-import asyncio
 import json
 import logging
 from collections import defaultdict
@@ -7,42 +6,54 @@ from time import sleep
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type

 from asgiref.sync import async_to_sync
-from django.conf import settings
+from django.apps import apps

 from .cache_providers import (
     Cachable,
     ElementCacheProvider,
     MemmoryCacheProvider,
     RedisCacheProvider,
-    get_all_cachables,
 )
 from .redis import use_redis
+from .schema_version import SchemaVersion, schema_version_handler
 from .utils import get_element_id, split_element_id

 logger = logging.getLogger(__name__)

+
+def get_all_cachables() -> List[Cachable]:
+    """
+    Returns all element of OpenSlides.
+    """
+    out: List[Cachable] = []
+    for app in apps.get_app_configs():
+        try:
+            # Get the method get_startup_elements() from an app.
+            # This method has to return an iterable of Cachable objects.
+            get_startup_elements = app.get_startup_elements
+        except AttributeError:
+            # Skip apps that do not implement get_startup_elements.
+            continue
+        out.extend(get_startup_elements())
+    return out
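
`get_all_cachables` moved here from `cache_providers` and now discovers its elements through a hook on each Django app config. An app opts in roughly like this (a sketch with an assumed app and model; every yielded object has to satisfy the `Cachable` protocol, i.e. provide `get_collection_string`, `get_elements` and `restrict_elements`):

from django.apps import AppConfig


class TopicsAppConfig(AppConfig):
    name = "openslides.topics"

    def get_startup_elements(self):
        # The models act as Cachables in OpenSlides; yield every
        # collection this app contributes to the element cache.
        yield self.get_model("Topic")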

 class ElementCache:
     """
     Cache for the elements.

-    Saves the full_data and if enabled the restricted data.
+    Saves the full_data

-    There is one redis Hash (simular to python dict) for the full_data and one
-    Hash for every user.
+    There is one redis Hash (simular to python dict) for the full_data

     The key of the Hashes is COLLECTIONSTRING:ID where COLLECTIONSTRING is the
     collection_string of a collection and id the id of an element.

-    All elements have to be in the cache. If one element is missing, the cache
-    is invalid, but this can not be detected. When a plugin with a new
-    collection is added to OpenSlides, then the cache has to be rebuild manualy.
-
     There is an sorted set in redis with the change id as score. The values are
     COLLETIONSTRING:ID for the elements that have been changed with that change
-    id. With this key it is possible, to get all elements as full_data or as
-    restricted_data that are newer then a specific change id.
+    id. With this key it is possible, to get all elements as full_data
+    that are newer then a specific change id.

     All method of this class are async. You either have to call them with
     await in an async environment or use asgiref.sync.async_to_sync().
@@ -50,36 +61,34 @@ class ElementCache:
     def __init__(
         self,
-        use_restricted_data_cache: bool = False,
         cache_provider_class: Type[ElementCacheProvider] = RedisCacheProvider,
         cachable_provider: Callable[[], List[Cachable]] = get_all_cachables,
-        start_time: int = None,
+        default_change_id: Optional[int] = None,
     ) -> None:
         """
         Initializes the cache.
-
-        When restricted_data_cache is false, no restricted data is saved.
         """
-        self.use_restricted_data_cache = use_restricted_data_cache
-        self.cache_provider = cache_provider_class()
+        self.cache_provider = cache_provider_class(self.async_ensure_cache)
         self.cachable_provider = cachable_provider
         self._cachables: Optional[Dict[str, Cachable]] = None
+        self.set_default_change_id(default_change_id)

-        # Start time is used as first change_id if there is non in redis
-        if start_time is None:
+    def set_default_change_id(self, default_change_id: Optional[int] = None) -> None:
+        """
+        Sets the default change id for the cache. Needs to update, if the cache gets generated.
+        """
+        # The current time is used as the first change_id if there is non in redis
+        if default_change_id is None:
             # Use the miliseconds (rounted) since the 2016-02-29.
-            start_time = (
+            default_change_id = (
                 int((datetime.utcnow() - datetime(2016, 2, 29)).total_seconds()) * 1000
             )
-        self.start_time = start_time
-
-        # Tells if self.ensure_cache was called.
-        self.ensured = False
+        self.default_change_id = default_change_id

     @property
     def cachables(self) -> Dict[str, Cachable]:
         """
-        Returns all Cachables as a dict where the key is the collection_string of the cachable.
+        Returns all cachables as a dict where the key is the collection_string of the cachable.
         """
         # This method is neccessary to lazy load the cachables
         if self._cachables is None:
@@ -89,45 +98,71 @@ class ElementCache:
         }
         return self._cachables
-    def ensure_cache(self, reset: bool = False) -> None:
+    def ensure_cache(
+        self, reset: bool = False, default_change_id: Optional[int] = None
+    ) -> None:
         """
-        Makes sure that the cache exist.
-
-        Builds the cache if not. If reset is True, it will be reset in any case.
-
-        This method is sync, so it can be run when OpenSlides starts.
+        Ensures the existance of the cache; see async_ensure_cache for more info.
         """
-        cache_exists = async_to_sync(self.cache_provider.data_exists)()
-
-        if reset or not cache_exists:
-            lock_name = "ensure_cache"
-            # Set a lock so only one process builds the cache
-            if async_to_sync(self.cache_provider.set_lock)(lock_name):
-                logger.info("Building up the cache data...")
-                try:
-                    mapping = {}
-                    for collection_string, cachable in self.cachables.items():
-                        for element in cachable.get_elements():
-                            mapping.update(
-                                {
-                                    get_element_id(
-                                        collection_string, element["id"]
-                                    ): json.dumps(element)
-                                }
-                            )
-                    logger.info("Done building the cache data.")
-                    logger.info("Saving cache data into the cache...")
-                    async_to_sync(self.cache_provider.reset_full_cache)(mapping)
-                    logger.info("Done saving the cache data.")
-                finally:
-                    async_to_sync(self.cache_provider.del_lock)(lock_name)
-            else:
-                logger.info("Wait for another process to build up the cache...")
-                while async_to_sync(self.cache_provider.get_lock)(lock_name):
-                    sleep(0.01)
-                logger.info("Cache is ready (built by another process).")
-
-        self.ensured = True
+        async_to_sync(self.async_ensure_cache)(reset, default_change_id)
+
+    async def async_ensure_cache(
+        self, reset: bool = False, default_change_id: Optional[int] = None
+    ) -> None:
+        """
+        Makes sure that the cache exist. Builds the cache if not or reset is given as True.
+        """
+        cache_exists = await self.cache_provider.data_exists()
+
+        if reset or not cache_exists:
+            await self.build_cache(default_change_id)
+
+    def ensure_schema_version(self) -> None:
+        async_to_sync(self.async_ensure_schema_version)()
+
+    async def async_ensure_schema_version(self) -> None:
+        cache_schema_version = await self.cache_provider.get_schema_version()
+        schema_changed = not schema_version_handler.compare(cache_schema_version)
+        schema_version_handler.log_current()
+
+        cache_exists = await self.cache_provider.data_exists()
+        if schema_changed or not cache_exists:
+            await self.build_cache(schema_version=schema_version_handler.get())
+
+    async def build_cache(
+        self,
+        default_change_id: Optional[int] = None,
+        schema_version: Optional[SchemaVersion] = None,
+    ) -> None:
+        lock_name = "build_cache"
+        # Set a lock so only one process builds the cache
+        if await self.cache_provider.set_lock(lock_name):
+            logger.info("Building up the cache data...")
+            try:
+                mapping = {}
+                for collection_string, cachable in self.cachables.items():
+                    for element in cachable.get_elements():
+                        mapping.update(
+                            {
+                                get_element_id(
+                                    collection_string, element["id"]
+                                ): json.dumps(element)
+                            }
+                        )
+                logger.info("Done building the cache data.")
+                logger.info("Saving cache data into the cache...")
+                self.set_default_change_id(default_change_id=default_change_id)
+                await self.cache_provider.reset_full_cache(mapping)
+                if schema_version:
+                    await self.cache_provider.set_schema_version(schema_version)
+                logger.info("Done saving the cache data.")
+            finally:
+                await self.cache_provider.del_lock(lock_name)
+        else:
+            logger.info("Wait for another process to build up the cache...")
+            while await self.cache_provider.get_lock(lock_name):
+                sleep(0.01)
+            logger.info("Cache is ready (built by another process).")
     async def change_elements(
         self, elements: Dict[str, Optional[Dict[str, Any]]]
@@ -135,16 +170,12 @@ class ElementCache:
         """
         Changes elements in the cache.

-        elements is a list of the changed elements as dict. When the value is None,
-        it is interpreded as deleted. The key has to be an element_id.
+        elements is a dict with element_id <-> changed element. When the value is None,
+        it is interpreded as deleted.

         Returns the new generated change_id.
         """
-        if not self.ensured:
-            raise RuntimeError(
-                "Call element_cache.ensure_cache before changing elements."
-            )
-
+        # Split elements into changed and deleted.
         deleted_elements = []
         changed_elements = []
         for element_id, data in elements.items():
@@ -155,47 +186,90 @@
             else:
                 deleted_elements.append(element_id)

-        if changed_elements:
-            await self.cache_provider.add_elements(changed_elements)
-        if deleted_elements:
-            await self.cache_provider.del_elements(deleted_elements)
-
         return await self.cache_provider.add_changed_elements(
-            self.start_time + 1, elements.keys()
+            changed_elements, deleted_elements, self.default_change_id + 1
         )

-    async def get_all_full_data(self) -> Dict[str, List[Dict[str, Any]]]:
+    async def get_all_data_list(
+        self, user_id: Optional[int] = None
+    ) -> Dict[str, List[Dict[str, Any]]]:
         """
-        Returns all full_data.
-
-        The returned value is a dict where the key is the collection_string and
-        the value is a list of data.
+        Returns all data with a list per collection:
+        {
+            <collection>: [<element>, <element>, ...]
+        }
+        If the user id is given the data will be restricted for this user.
         """
-        all_data = await self.get_all_full_data_ordered()
-        out: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
-        for collection_string, collection_data in all_data.items():
-            for data in collection_data.values():
-                out[collection_string].append(data)
-        return dict(out)
+        all_data: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
+        for element_id, data in (await self.cache_provider.get_all_data()).items():
+            collection_string, _ = split_element_id(element_id)
+            all_data[collection_string].append(json.loads(data.decode()))
+
+        if user_id is not None:
+            for collection_string in all_data.keys():
+                restricter = self.cachables[collection_string].restrict_elements
+                all_data[collection_string] = await restricter(
+                    user_id, all_data[collection_string]
+                )
+        return dict(all_data)

-    async def get_all_full_data_ordered(self) -> Dict[str, Dict[int, Dict[str, Any]]]:
+    async def get_all_data_dict(self) -> Dict[str, Dict[int, Dict[str, Any]]]:
         """
-        Like get_all_full_data but orders the element of one collection by there
-        id.
+        Returns all data with a dict (id <-> element) per collection:
+        {
+            <collection>: {
+                <id>: <element>
+            }
+        }
         """
-        out: Dict[str, Dict[int, Dict[str, Any]]] = defaultdict(dict)
-        full_data = await self.cache_provider.get_all_data()
-        for element_id, data in full_data.items():
+        all_data: Dict[str, Dict[int, Dict[str, Any]]] = defaultdict(dict)
+        for element_id, data in (await self.cache_provider.get_all_data()).items():
             collection_string, id = split_element_id(element_id)
-            out[collection_string][id] = json.loads(data.decode())
-        return dict(out)
+            all_data[collection_string][id] = json.loads(data.decode())
+        return dict(all_data)
-    async def get_full_data(
-        self, change_id: int = 0, max_change_id: int = -1
+    async def get_collection_data(
+        self, collection_string: str
+    ) -> Dict[int, Dict[str, Any]]:
+        """
+        Returns the data for one collection as dict: {id: <element>}
+        """
+        encoded_collection_data = await self.cache_provider.get_collection_data(
+            collection_string
+        )
+        collection_data = {}
+        for id in encoded_collection_data.keys():
+            collection_data[id] = json.loads(encoded_collection_data[id].decode())
+        return collection_data
+
+    async def get_element_data(
+        self, collection_string: str, id: int, user_id: Optional[int] = None
+    ) -> Optional[Dict[str, Any]]:
+        """
+        Returns one element or None, if the element does not exist.
+        If the user id is given the data will be restricted for this user.
+        """
+        encoded_element = await self.cache_provider.get_element_data(
+            get_element_id(collection_string, id)
+        )
+        if encoded_element is None:
+            return None
+        element = json.loads(encoded_element.decode())  # type: ignore
+
+        if user_id is not None:
+            restricter = self.cachables[collection_string].restrict_elements
+            restricted_elements = await restricter(user_id, [element])
+            element = restricted_elements[0] if restricted_elements else None
+        return element
+
+    async def get_data_since(
+        self, user_id: Optional[int] = None, change_id: int = 0, max_change_id: int = -1
     ) -> Tuple[Dict[str, List[Dict[str, Any]]], List[str]]:
         """
-        Returns all full_data since change_id until max_change_id (including).
-        max_change_id -1 means the highest change_id.
+        Returns all data since change_id until max_change_id (included).
+        max_change_id -1 means the highest change_id. If the user id is given the
+        data will be restricted for this user.

         Returns two values inside a tuple. The first value is a dict where the
         key is the collection_string and the value is a list of data. The second
@@ -210,7 +284,7 @@
         that the cache does not know about.
         """
         if change_id == 0:
-            return (await self.get_all_full_data(), [])
+            return (await self.get_all_data_list(user_id), [])

         # This raises a Runtime Exception, if there is no change_id
         lowest_change_id = await self.get_lowest_change_id()
@@ -226,245 +300,39 @@
         raw_changed_elements, deleted_elements = await self.cache_provider.get_data_since(
             change_id, max_change_id=max_change_id
         )
-        return (
-            {
-                collection_string: [json.loads(value.decode()) for value in value_list]
-                for collection_string, value_list in raw_changed_elements.items()
-            },
-            deleted_elements,
-        )
+        changed_elements = {
+            collection_string: [json.loads(value.decode()) for value in value_list]
+            for collection_string, value_list in raw_changed_elements.items()
+        }
-    async def get_collection_full_data(
-        self, collection_string: str
-    ) -> Dict[int, Dict[str, Any]]:
-        full_data = await self.cache_provider.get_collection_data(collection_string)
-        out = {}
-        for element_id, data in full_data.items():
-            returned_collection_string, id = split_element_id(element_id)
-            if returned_collection_string == collection_string:
-                out[id] = json.loads(data.decode())
-        return out
-
-    async def get_element_full_data(
-        self, collection_string: str, id: int
-    ) -> Optional[Dict[str, Any]]:
-        """
-        Returns one element as full data.
-
-        Returns None if the element does not exist.
-        """
-        element = await self.cache_provider.get_element(
-            get_element_id(collection_string, id)
-        )
-
-        if element is None:
-            return None
-        return json.loads(element.decode())
-
-    async def exists_restricted_data(self, user_id: int) -> bool:
-        """
-        Returns True, if the restricted_data exists for the user.
-        """
-        if not self.use_restricted_data_cache:
-            return False
-
-        return await self.cache_provider.data_exists(user_id)
-
-    async def del_user(self, user_id: int) -> None:
-        """
-        Removes one user from the resticted_data_cache.
-        """
-        await self.cache_provider.del_restricted_data(user_id)
-
-    async def update_restricted_data(self, user_id: int) -> None:
-        """
-        Updates the restricted data for an user from the full_data_cache.
-        """
-        # TODO: When elements are changed at the same time then this method run
-        # this could make the cache invalid.
-        # This could be fixed when get_full_data would be used with a
-        # max change_id.
-        if not self.use_restricted_data_cache:
-            # If the restricted_data_cache is not used, there is nothing to do
-            return
-
-        if not self.ensured:
-            raise RuntimeError(
-                "Call element_cache.ensure_cache before updating restricted data."
-            )
-
-        # Try to write a special key.
-        # If this succeeds, there is noone else currently updating the cache.
-        # TODO: Make a timeout. Else this could block forever
-        lock_name = f"restricted_data_{user_id}"
-        if await self.cache_provider.set_lock(lock_name):
-            # Get change_id for this user
-            value = await self.cache_provider.get_change_id_user(user_id)
-            # If the change id is not in the cache yet, use -1 to get all data since 0
-            user_change_id = int(value) if value else -1
-            change_id = await self.get_current_change_id()
-            if change_id > user_change_id:
-                try:
-                    full_data_elements, deleted_elements = await self.get_full_data(
-                        user_change_id + 1
-                    )
-                except RuntimeError:
-                    # The user_change_id is lower then the lowest change_id in the cache.
-                    # The whole restricted_data for that user has to be recreated.
-                    full_data_elements = await self.get_all_full_data()
-                    deleted_elements = []
-                    await self.cache_provider.del_restricted_data(user_id)
-
-                mapping = {}
-                for collection_string, full_data in full_data_elements.items():
-                    restricter = self.cachables[collection_string].restrict_elements
-                    restricted_elements = await restricter(user_id, full_data)
-
-                    # find all elements the user can not see at all
-                    full_data_ids = set(element["id"] for element in full_data)
-                    restricted_data_ids = set(
-                        element["id"] for element in restricted_elements
-                    )
-                    for item_id in full_data_ids - restricted_data_ids:
-                        deleted_elements.append(
-                            get_element_id(collection_string, item_id)
-                        )
-
-                    for element in restricted_elements:
-                        # The user can see the element
-                        mapping.update(
-                            {
-                                get_element_id(
-                                    collection_string, element["id"]
-                                ): json.dumps(element)
-                            }
-                        )
-                mapping["_config:change_id"] = str(change_id)
-                await self.cache_provider.update_restricted_data(user_id, mapping)
-                # Remove deleted elements
-                if deleted_elements:
-                    await self.cache_provider.del_elements(deleted_elements, user_id)
-            # Unset the lock
-            await self.cache_provider.del_lock(lock_name)
-        else:
-            # Wait until the update if finshed
-            while await self.cache_provider.get_lock(lock_name):
-                await asyncio.sleep(0.01)
-
-    async def get_all_restricted_data(
-        self, user_id: int
-    ) -> Dict[str, List[Dict[str, Any]]]:
-        """
-        Like get_all_full_data but with restricted_data for an user.
-        """
-        if not self.use_restricted_data_cache:
-            all_restricted_data = {}
-            for collection_string, full_data in (
-                await self.get_all_full_data()
-            ).items():
-                restricter = self.cachables[collection_string].restrict_elements
-                elements = await restricter(user_id, full_data)
-                all_restricted_data[collection_string] = elements
-            return all_restricted_data
-
-        await self.update_restricted_data(user_id)
-
-        out: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
-        restricted_data = await self.cache_provider.get_all_data(user_id)
-        for element_id, data in restricted_data.items():
-            if element_id.decode().startswith("_config"):
-                continue
-            collection_string, __ = split_element_id(element_id)
-            out[collection_string].append(json.loads(data.decode()))
-        return dict(out)
-
-    async def get_restricted_data(
-        self, user_id: int, change_id: int = 0, max_change_id: int = -1
-    ) -> Tuple[Dict[str, List[Dict[str, Any]]], List[str]]:
-        """
-        Like get_full_data but with restricted_data for an user.
-        """
-        if change_id == 0:
-            # Return all data
-            return (await self.get_all_restricted_data(user_id), [])
-
-        if not self.use_restricted_data_cache:
-            changed_elements, deleted_elements = await self.get_full_data(
-                change_id, max_change_id
-            )
-            restricted_data = {}
-            for collection_string, full_data in changed_elements.items():
-                restricter = self.cachables[collection_string].restrict_elements
-                elements = await restricter(user_id, full_data)
-
-                # Add removed objects (through restricter) to deleted elements.
-                full_data_ids = set([data["id"] for data in full_data])
-                restricted_data_ids = set([data["id"] for data in elements])
-                for id in full_data_ids - restricted_data_ids:
-                    deleted_elements.append(get_element_id(collection_string, id))
-
-                if elements:
-                    restricted_data[collection_string] = elements
-            return restricted_data, deleted_elements
-
-        lowest_change_id = await self.get_lowest_change_id()
-        if change_id < lowest_change_id:
-            # When change_id is lower then the lowest change_id in redis, we can
-            # not inform the user about deleted elements.
-            raise RuntimeError(
-                f"change_id {change_id} is lower then the lowest change_id in redis {lowest_change_id}. "
-                "Catch this exception and rerun the method with change_id=0."
-            )
-
-        # If another coroutine or another daphne server also updates the restricted
-        # data, this waits until it is done.
-        await self.update_restricted_data(user_id)
-
-        raw_changed_elements, deleted_elements = await self.cache_provider.get_data_since(
-            change_id, user_id, max_change_id
-        )
-        return (
-            {
-                collection_string: [json.loads(value.decode()) for value in value_list]
-                for collection_string, value_list in raw_changed_elements.items()
-            },
-            deleted_elements,
-        )
-
-    async def get_element_restricted_data(
-        self, user_id: int, collection_string: str, id: int
-    ) -> Optional[Dict[str, Any]]:
-        """
-        Returns the restricted_data of one element.
-
-        Returns None, if the element does not exists or the user has no permission to see it.
-        """
-        if not self.use_restricted_data_cache:
-            full_data = await self.get_element_full_data(collection_string, id)
-            if full_data is None:
-                return None
-            restricter = self.cachables[collection_string].restrict_elements
-            restricted_data = await restricter(user_id, [full_data])
-            return restricted_data[0] if restricted_data else None
-
-        await self.update_restricted_data(user_id)
-
-        out = await self.cache_provider.get_element(
-            get_element_id(collection_string, id), user_id
-        )
-        return json.loads(out.decode()) if out else None
+        if user_id is not None:
+            for collection_string, elements in changed_elements.items():
+                restricter = self.cachables[collection_string].restrict_elements
+                restricted_elements = await restricter(user_id, elements)
+
+                # Add removed objects (through restricter) to deleted elements.
+                element_ids = set([element["id"] for element in elements])
+                restricted_element_ids = set(
+                    [element["id"] for element in restricted_elements]
+                )
+                for id in element_ids - restricted_element_ids:
+                    deleted_elements.append(get_element_id(collection_string, id))
+
+                if not restricted_elements:
+                    del changed_elements[collection_string]
+                else:
+                    changed_elements[collection_string] = restricted_elements
+
+        return (changed_elements, deleted_elements)
     async def get_current_change_id(self) -> int:
         """
         Returns the current change id.

-        Returns start_time if there is no change id yet.
+        Returns default_change_id if there is no change id yet.
         """
         value = await self.cache_provider.get_current_change_id()
-        if not value:
-            return self.start_time
-        # Return the score (second element) of the first (and only) element
-        return value[0][1]
+        return value if value is not None else self.default_change_id

     async def get_lowest_change_id(self) -> int:
         """
@@ -479,7 +347,7 @@
         return value


-def load_element_cache(restricted_data: bool = True) -> ElementCache:
+def load_element_cache() -> ElementCache:
     """
     Generates an element cache instance.
     """
@@ -488,12 +356,8 @@ def load_element_cache(restricted_data: bool = True) -> ElementCache:
     else:
         cache_provider_class = MemmoryCacheProvider

-    return ElementCache(
-        cache_provider_class=cache_provider_class,
-        use_restricted_data_cache=restricted_data,
-    )
+    return ElementCache(cache_provider_class=cache_provider_class)


 # Set the element_cache
-use_restricted_data = getattr(settings, "RESTRICTED_DATA_CACHE", True)
-element_cache = load_element_cache(restricted_data=use_restricted_data)
+element_cache = load_element_cache()
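
With the restricted-data cache gone, the public surface shrinks to a few `*_data` methods and restriction happens on the fly through an optional user_id parameter. A usage sketch (collection names, ids and the surrounding event loop setup are assumptions, not from this diff):

import asyncio

from openslides.utils.cache import element_cache


async def demo() -> None:
    # Full data, as before.
    motion = await element_cache.get_element_data("motions/motion", 1)

    # The same element, restricted for user 5 via the collection's
    # restrict_elements callable.
    restricted = await element_cache.get_element_data("motions/motion", 1, 5)

    # All changes since change id 42, restricted for user 5. Raises a
    # RuntimeError if 42 is below the lowest change id in the cache;
    # callers catch that and retry with change_id=0.
    changed, deleted = await element_cache.get_data_since(5, 42)
    print(motion, restricted, changed, deleted)


asyncio.run(demo())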

View File

@@ -1,17 +1,28 @@
+import functools
+import hashlib
+import logging
 from collections import defaultdict
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
+from textwrap import dedent
+from typing import Any, Callable, Coroutine, Dict, List, Optional, Set, Tuple

-from django.apps import apps
+from django.core.exceptions import ImproperlyConfigured
 from typing_extensions import Protocol

 from .redis import use_redis
+from .schema_version import SchemaVersion
 from .utils import split_element_id, str_dict_to_bytes

+logger = logging.getLogger(__name__)

 if use_redis:
     from .redis import get_connection, aioredis

+
+class CacheReset(Exception):
+    pass
+

 class ElementCacheProvider(Protocol):
     """
     Base class for cache provider.
@@ -19,49 +30,43 @@ class ElementCacheProvider(Protocol):
     See RedisCacheProvider as reverence implementation.
     """

+    def __init__(self, ensure_cache: Callable[[], Coroutine[Any, Any, None]]) -> None:
+        ...
+
+    async def ensure_cache(self) -> None:
+        ...
+
     async def clear_cache(self) -> None:
         ...

     async def reset_full_cache(self, data: Dict[str, str]) -> None:
         ...

-    async def data_exists(self, user_id: Optional[int] = None) -> bool:
+    async def data_exists(self) -> bool:
         ...

-    async def add_elements(self, elements: List[str]) -> None:
+    async def get_all_data(self) -> Dict[bytes, bytes]:
         ...

-    async def del_elements(
-        self, elements: List[str], user_id: Optional[int] = None
-    ) -> None:
+    async def get_collection_data(self, collection: str) -> Dict[int, bytes]:
+        ...
+
+    async def get_element_data(self, element_id: str) -> Optional[bytes]:
         ...

     async def add_changed_elements(
-        self, default_change_id: int, element_ids: Iterable[str]
+        self,
+        changed_elements: List[str],
+        deleted_element_ids: List[str],
+        default_change_id: int,
     ) -> int:
         ...

-    async def get_all_data(self, user_id: Optional[int] = None) -> Dict[bytes, bytes]:
-        ...
-
-    async def get_collection_data(
-        self, collection: str, user_id: Optional[int] = None
-    ) -> Dict[bytes, bytes]:
-        ...
-
     async def get_data_since(
-        self, change_id: int, user_id: Optional[int] = None, max_change_id: int = -1
+        self, change_id: int, max_change_id: int = -1
     ) -> Tuple[Dict[str, List[bytes]], List[str]]:
         ...

-    async def get_element(
-        self, element_id: str, user_id: Optional[int] = None
-    ) -> Optional[bytes]:
-        ...
-
-    async def del_restricted_data(self, user_id: int) -> None:
-        ...
-
     async def set_lock(self, lock_name: str) -> bool:
         ...
@@ -71,18 +76,48 @@ class ElementCacheProvider(Protocol):
     async def del_lock(self, lock_name: str) -> None:
         ...

-    async def get_change_id_user(self, user_id: int) -> Optional[int]:
-        ...
-
-    async def update_restricted_data(self, user_id: int, data: Dict[str, str]) -> None:
-        ...
-
-    async def get_current_change_id(self) -> List[Tuple[str, int]]:
+    async def get_current_change_id(self) -> Optional[int]:
         ...

     async def get_lowest_change_id(self) -> Optional[int]:
         ...

+    async def get_schema_version(self) -> Optional[SchemaVersion]:
+        ...
+
+    async def set_schema_version(self, schema_version: SchemaVersion) -> None:
+        ...
def ensure_cache_wrapper() -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""
Wraps a cache function to ensure that the cache is filled.
When the function raises a CacheReset error, the cache is rebuilt (via a call
to `ensure_cache`) and the method is called again. This is repeated until the
operation succeeds.
"""
def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
@functools.wraps(func)
async def wrapped(
cache_provider: ElementCacheProvider, *args: Any, **kwargs: Any
) -> Any:
success = False
while not success:
try:
result = await func(cache_provider, *args, **kwargs)
success = True
except CacheReset:
logger.warning(
f"Redis was flushed before method '{func.__name__}'. Ensuring cache now."
)
await cache_provider.ensure_cache()
return result
return wrapped
return wrapper
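To make the retry semantics concrete, here is a minimal, self-contained sketch of the same loop. Everything in it (DummyProvider, DummyCacheReset, read) is hypothetical and only mirrors the pattern above:

import asyncio

class DummyCacheReset(Exception):
    """Stands in for CacheReset in this sketch."""

class DummyProvider:
    def __init__(self) -> None:
        self.calls = 0

    async def ensure_cache(self) -> None:
        # In the real provider this would rebuild the full data cache.
        pass

    async def read(self) -> str:
        # Fails on the first call, succeeds afterwards.
        self.calls += 1
        if self.calls == 1:
            raise DummyCacheReset()
        return "data"

async def read_with_retry(provider: DummyProvider) -> str:
    # Same loop as in ensure_cache_wrapper: retry until no CacheReset is raised.
    while True:
        try:
            return await provider.read()
        except DummyCacheReset:
            await provider.ensure_cache()

print(asyncio.run(read_with_retry(DummyProvider())))  # -> data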
class RedisCacheProvider: class RedisCacheProvider:
""" """
@ -90,204 +125,251 @@ class RedisCacheProvider:
""" """
full_data_cache_key: str = "full_data" full_data_cache_key: str = "full_data"
restricted_user_cache_key: str = "restricted_data:{user_id}"
change_id_cache_key: str = "change_id" change_id_cache_key: str = "change_id"
schema_cache_key: str = "schema"
prefix: str = "element_cache_" prefix: str = "element_cache_"
# All lua-scripts used by this provider. Every entry is a Tuple (str, bool) with the
# script and an ensure_cache-indicator. If the indicator is True, a short ensure_cache-script
# will be prepended to the script which raises a CacheReset, if the full data cache is empty.
# This requires the full_data_cache_key to be the first key given in `keys`!
# All scripts are dedented and hashed for faster execution. Convention: The keys of this
# member are the methods that need these scripts.
scripts = {
"clear_cache": (
"return redis.call('del', 'fake_key', unpack(redis.call('keys', ARGV[1])))",
False,
),
"get_all_data": ("return redis.call('hgetall', KEYS[1])", True),
"get_collection_data": (
"""
local cursor = 0
local collection = {}
repeat
local result = redis.call('HSCAN', KEYS[1], cursor, 'MATCH', ARGV[1])
cursor = tonumber(result[1])
for _, v in pairs(result[2]) do
table.insert(collection, v)
end
until cursor == 0
return collection
""",
True,
),
"get_element_data": ("return redis.call('hget', KEYS[1], ARGV[1])", True),
"add_changed_elements": (
"""
-- Generate a new change_id
local tmp = redis.call('zrevrangebyscore', KEYS[2], '+inf', '-inf', 'WITHSCORES', 'LIMIT', 0, 1)
local change_id
if next(tmp) == nil then
-- The key does not exist
change_id = ARGV[1]
else
change_id = tmp[2] + 1
end
local nc = tonumber(ARGV[2])
local nd = tonumber(ARGV[3])
local i, max
-- Add changed_elements to the cache and sorted set (the first of the pairs)
if (nc > 0) then
max = 2 + nc
redis.call('hmset', KEYS[1], unpack(ARGV, 4, max + 1))
for i = 4, max, 2 do
redis.call('zadd', KEYS[2], change_id, ARGV[i])
end
end
-- Delete deleted_element_ids and add them to sorted set
if (nd > 0) then
max = 3 + nc + nd
redis.call('hdel', KEYS[1], unpack(ARGV, 4 + nc, max))
for i = 4 + nc, max, 2 do
redis.call('zadd', KEYS[2], change_id, ARGV[i])
end
end
-- Set lowest_change_id if it does not exist
redis.call('zadd', KEYS[2], 'NX', change_id, '_config:lowest_change_id')
return change_id
""",
True,
),
"get_data_since": (
"""
-- Get element_ids of changed elements
local element_ids = redis.call('zrangebyscore', KEYS[2], ARGV[1], ARGV[2])
-- Save elements in array. Rotate element_id and element_json
local elements = {}
for _, element_id in pairs(element_ids) do
table.insert(elements, element_id)
table.insert(elements, redis.call('hget', KEYS[1], element_id))
end
return elements
""",
True,
),
}
def __init__(self, ensure_cache: Callable[[], Coroutine[Any, Any, None]]) -> None:
self._ensure_cache = ensure_cache
# hash all scripts and remove indentation.
for key in self.scripts.keys():
script, add_ensure_cache = self.scripts[key]
script = dedent(script)
if add_ensure_cache:
script = (
dedent(
"""
local exist = redis.call('exists', KEYS[1])
if (exist == 0) then
redis.log(redis.LOG_WARNING, "empty: "..KEYS[1])
return redis.error_reply("cache_reset")
end
"""
)
+ script
)
self.scripts[key] = (script, add_ensure_cache)
self._script_hashes = {
key: hashlib.sha1(script.encode()).hexdigest()
for key, (script, _) in self.scripts.items()
}
async def ensure_cache(self) -> None:
await self._ensure_cache()
def get_full_data_cache_key(self) -> str: def get_full_data_cache_key(self) -> str:
return "".join((self.prefix, self.full_data_cache_key)) return "".join((self.prefix, self.full_data_cache_key))
def get_restricted_data_cache_key(self, user_id: int) -> str:
return "".join(
(self.prefix, self.restricted_user_cache_key.format(user_id=user_id))
)
def get_change_id_cache_key(self) -> str: def get_change_id_cache_key(self) -> str:
return "".join((self.prefix, self.change_id_cache_key)) return "".join((self.prefix, self.change_id_cache_key))
def get_schema_cache_key(self) -> str:
return "".join((self.prefix, self.schema_cache_key))
async def clear_cache(self) -> None: async def clear_cache(self) -> None:
""" """
Deletes all cache entries created with this element cache. Deletes all cache entries created with this element cache.
""" """
async with get_connection() as redis: await self.eval("clear_cache", keys=[], args=[f"{self.prefix}*"])
await redis.eval(
"return redis.call('del', 'fake_key', unpack(redis.call('keys', ARGV[1])))",
keys=[],
args=[f"{self.prefix}*"],
)
async def reset_full_cache(self, data: Dict[str, str]) -> None: async def reset_full_cache(self, data: Dict[str, str]) -> None:
""" """
Deletes the full_data_cache and writes new data into it. Deletes the full_data_cache and writes new data into it. Clears the change id key.
Does not clear locks.
Also deletes the restricted_data_cache and the change_id_cache.
""" """
async with get_connection() as redis: async with get_connection() as redis:
tr = redis.multi_exec() tr = redis.multi_exec()
# like clear_cache but does not delete a lock
tr.eval(
"return redis.call('del', 'fake_key', unpack(redis.call('keys', ARGV[1])))",
keys=[],
args=[f"{self.prefix}{self.restricted_user_cache_key}*"],
)
tr.delete(self.get_change_id_cache_key()) tr.delete(self.get_change_id_cache_key())
tr.delete(self.get_full_data_cache_key()) tr.delete(self.get_full_data_cache_key())
tr.hmset_dict(self.get_full_data_cache_key(), data) tr.hmset_dict(self.get_full_data_cache_key(), data)
await tr.execute() await tr.execute()
async def data_exists(self, user_id: Optional[int] = None) -> bool: async def data_exists(self) -> bool:
""" """
Returns True, when there is data in the cache. Returns True, when there is data in the cache.
If user_id is None, the method tests for full_data. If user_id is an int, it tests
for the restricted_data_cache for the user with the user_id. 0 is for anonymous.
""" """
async with get_connection() as redis: async with get_connection() as redis:
if user_id is None: return await redis.exists(self.get_full_data_cache_key())
cache_key = self.get_full_data_cache_key()
else:
cache_key = self.get_restricted_data_cache_key(user_id)
return await redis.exists(cache_key)
async def add_elements(self, elements: List[str]) -> None: @ensure_cache_wrapper()
async def get_all_data(self) -> Dict[bytes, bytes]:
""" """
Add or change elements to the cache. Returns all data from the full_data_cache in a mapping from element_id to the element.
elements is a list with an even len. the odd values are the element_ids and the even
values are the elements. The elements have to be encoded, for example with json.
""" """
async with get_connection() as redis: return await aioredis.util.wait_make_dict(
await redis.hmset(self.get_full_data_cache_key(), *elements) self.eval("get_all_data", [self.get_full_data_cache_key()])
)
async def del_elements( @ensure_cache_wrapper()
self, elements: List[str], user_id: Optional[int] = None async def get_collection_data(self, collection: str) -> Dict[int, bytes]:
) -> None:
""" """
Deletes elements from the cache. Returns all elements for a collection from the cache. The data is mapped
from element_id to the element.
elements has to be a list of element_ids.
If user_id is None, the elements are deleted from the full_data cache. If user_id is an
int, the elements are deleted from the restricted_data_cache. 0 is for anonymous.
""" """
async with get_connection() as redis: response = await self.eval(
if user_id is None: "get_collection_data", [self.get_full_data_cache_key()], [f"{collection}:*"]
cache_key = self.get_full_data_cache_key() )
else:
cache_key = self.get_restricted_data_cache_key(user_id)
await redis.hdel(cache_key, *elements)
collection_data = {}
for i in range(0, len(response), 2):
_, id = split_element_id(response[i])
collection_data[id] = response[i + 1]
return collection_data
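To clarify the pairing logic above, a small sketch under the assumption that the lua script returns a flat [field, value, field, value, ...] list (the response value and the helper here are made up):

# Hypothetical flat reply from the HSCAN-based lua script:
response = [b"app/collection1:1", b'{"id": 1}', b"app/collection1:2", b'{"id": 2}']

def pairs_to_dict(flat):
    # Fold [field, value, ...] into {numeric_id: value}, like the loop above.
    out = {}
    for i in range(0, len(flat), 2):
        numeric_id = int(flat[i].decode().rsplit(":", 1)[1])
        out[numeric_id] = flat[i + 1]
    return out

assert pairs_to_dict(response) == {1: b'{"id": 1}', 2: b'{"id": 2}'}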
@ensure_cache_wrapper()
async def get_element_data(self, element_id: str) -> Optional[bytes]:
"""
Returns one element from the cache. Returns None, when the element does not exist.
"""
try:
return await self.eval(
"get_element_data", [self.get_full_data_cache_key()], [element_id]
)
except aioredis.errors.ReplyError:
raise CacheReset()
@ensure_cache_wrapper()
async def add_changed_elements( async def add_changed_elements(
self, default_change_id: int, element_ids: Iterable[str] self,
changed_elements: List[str],
deleted_element_ids: List[str],
default_change_id: int,
) -> int: ) -> int:
""" """
Saves which elements are changed with a change_id. Modifies the full_data_cache to insert the changed_elements and remove the
deleted_element_ids (in this order). Generates a new change_id and inserts all
Generates and returns the change_id. element_ids (changed and deleted) with the change_id into the change_id_cache.
The newly generated change_id is returned.
""" """
async with get_connection() as redis: return int(
return int( await self.eval(
await redis.eval( "add_changed_elements",
lua_script_change_data, keys=[self.get_full_data_cache_key(), self.get_change_id_cache_key()],
keys=[self.get_change_id_cache_key()], args=[
args=[default_change_id, *element_ids], default_change_id,
) len(changed_elements),
len(deleted_element_ids),
*(changed_elements + deleted_element_ids),
],
) )
)
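A hypothetical call illustrating the expected argument layout (the provider instance and all values are assumptions, not from the commit); changed_elements is a flat list of alternating element_ids and json strings:

new_change_id = await provider.add_changed_elements(
    ["app/collection1:1", '{"id": 1, "value": "new"}'],  # changed: [id, json, ...]
    ["app/collection1:2"],                               # deleted element ids
    default_change_id=1,                                 # used if no change id exists yet
)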
async def get_all_data(self, user_id: Optional[int] = None) -> Dict[bytes, bytes]: @ensure_cache_wrapper()
"""
Returns all data from a cache.
if user_id is None, then the data is returned from the full_data_cache. If it is an
int, it is returned from a restricted_data_cache. 0 is for anonymous.
"""
if user_id is None:
cache_key = self.get_full_data_cache_key()
else:
cache_key = self.get_restricted_data_cache_key(user_id)
async with get_connection() as redis:
return await redis.hgetall(cache_key)
async def get_collection_data(
self, collection: str, user_id: Optional[int] = None
) -> Dict[bytes, bytes]:
"""
Returns all elements for a collection from the cache.
"""
if user_id is None:
cache_key = self.get_full_data_cache_key()
else:
cache_key = self.get_restricted_data_cache_key(user_id)
async with get_connection() as redis:
out = {}
async for k, v in redis.ihscan(cache_key, match=f"{collection}:*"):
out[k] = v
return out
async def get_element(
self, element_id: str, user_id: Optional[int] = None
) -> Optional[bytes]:
"""
Returns one element from the cache.
Returns None, when the element does not exist.
"""
if user_id is None:
cache_key = self.get_full_data_cache_key()
else:
cache_key = self.get_restricted_data_cache_key(user_id)
async with get_connection() as redis:
return await redis.hget(cache_key, element_id)
async def get_data_since( async def get_data_since(
self, change_id: int, user_id: Optional[int] = None, max_change_id: int = -1 self, change_id: int, max_change_id: int = -1
) -> Tuple[Dict[str, List[bytes]], List[str]]: ) -> Tuple[Dict[str, List[bytes]], List[str]]:
""" """
Returns all elements since a change_id. Returns all elements since a change_id (inclusive) up to the max_change_id (inclusive).
The returned value is a two-element tuple. The first value is a dict of the elements, where The returned value is a two-element tuple. The first value is a dict of the elements, where
the key is the collection_string and the value is a list of (json-) encoded elements. The the key is the collection_string and the value is a list of (json-) encoded elements. The
second element is a list of element_ids that have been deleted since the change_id. second element is a list of element_ids that have been deleted since the change_id.
if user_id is None, the full_data is returned. If user_id is an int, the restricted_data
for an user is used. 0 is for the anonymous user.
""" """
changed_elements: Dict[str, List[bytes]] = defaultdict(list) changed_elements: Dict[str, List[bytes]] = defaultdict(list)
deleted_elements: List[str] = [] deleted_elements: List[str] = []
if user_id is None:
cache_key = self.get_full_data_cache_key()
else:
cache_key = self.get_restricted_data_cache_key(user_id)
# Convert max_change_id to a string. If its negative, use the string '+inf' # Convert max_change_id to a string. If its negative, use the string '+inf'
redis_max_change_id = "+inf" if max_change_id < 0 else str(max_change_id) redis_max_change_id = "+inf" if max_change_id < 0 else str(max_change_id)
async with get_connection() as redis: # lua script that gets all element_ids from change_id_cache_key
# lua script that gets all element_ids from change_id_cache_key # and then uses each element_id on the full_data cache.
# and then uses each element_id on full_data or restricted_data. # It returns a list where the odd values are the element_ids and the
# It returns a list where the odd values are the element_ids and the # even values the elements as json. The function wait_make_dict creates
# even values the elements as json. The function wait_make_dict creates # a python dict from the returned list.
# a python dict from the returned list. elements: Dict[bytes, Optional[bytes]] = await aioredis.util.wait_make_dict(
elements: Dict[bytes, Optional[bytes]] = await aioredis.util.wait_make_dict( self.eval(
redis.eval( "get_data_since",
""" keys=[self.get_full_data_cache_key(), self.get_change_id_cache_key()],
-- Get change ids of changed elements args=[change_id, redis_max_change_id],
local element_ids = redis.call('zrangebyscore', KEYS[1], ARGV[1], ARGV[2])
-- Save elements in array. Rotate element_id and element_json
local elements = {}
for _, element_id in pairs(element_ids) do
table.insert(elements, element_id)
table.insert(elements, redis.call('hget', KEYS[2], element_id))
end
return elements
""",
keys=[self.get_change_id_cache_key(), cache_key],
args=[change_id, redis_max_change_id],
)
) )
)
for element_id, element_json in elements.items(): for element_id, element_json in elements.items():
if element_id.startswith(b"_config"): if element_id.startswith(b"_config"):
@ -301,20 +383,11 @@ class RedisCacheProvider:
changed_elements[collection_string].append(element_json) changed_elements[collection_string].append(element_json)
return changed_elements, deleted_elements return changed_elements, deleted_elements
async def del_restricted_data(self, user_id: int) -> None:
"""
Deletes all restricted_data for an user. 0 is for the anonymous user.
"""
async with get_connection() as redis:
await redis.delete(self.get_restricted_data_cache_key(user_id))
async def set_lock(self, lock_name: str) -> bool: async def set_lock(self, lock_name: str) -> bool:
""" """
Tries to set a lock. Tries to set a lock.
Returns True when the lock could be set. Returns True when the lock could be set and False if it was already set.
Returns False when the lock was already set.
""" """
# TODO: Improve lock. See: https://redis.io/topics/distlock # TODO: Improve lock. See: https://redis.io/topics/distlock
async with get_connection() as redis: async with get_connection() as redis:
@ -322,48 +395,28 @@ class RedisCacheProvider:
async def get_lock(self, lock_name: str) -> bool: async def get_lock(self, lock_name: str) -> bool:
""" """
Returns True, when the lock for the restricted_data of a user is set. Else False. Returns True when the lock is set, else False.
""" """
async with get_connection() as redis: async with get_connection() as redis:
return await redis.get(f"{self.prefix}lock_{lock_name}") return await redis.get(f"{self.prefix}lock_{lock_name}")
async def del_lock(self, lock_name: str) -> None: async def del_lock(self, lock_name: str) -> None:
""" """
Deletes the lock for the restricted_data of a user. Does nothing when the Deletes the lock. Does nothing when the lock is not set.
lock is not set.
""" """
async with get_connection() as redis: async with get_connection() as redis:
await redis.delete(f"{self.prefix}lock_{lock_name}") await redis.delete(f"{self.prefix}lock_{lock_name}")
async def get_change_id_user(self, user_id: int) -> Optional[int]: async def get_current_change_id(self) -> Optional[int]:
"""
Get the change_id for the restricted_data of an user.
This is the change_id where the restricted_data was last calculated.
"""
async with get_connection() as redis:
return await redis.hget(
self.get_restricted_data_cache_key(user_id), "_config:change_id"
)
async def update_restricted_data(self, user_id: int, data: Dict[str, str]) -> None:
"""
Updates the restricted_data for an user.
data has to be a dict where the key is an element_id and the value the (json-) encoded
element.
"""
async with get_connection() as redis:
await redis.hmset_dict(self.get_restricted_data_cache_key(user_id), data)
async def get_current_change_id(self) -> List[Tuple[str, int]]:
""" """
Get the highest change_id from redis. Get the highest change_id from redis.
""" """
async with get_connection() as redis: async with get_connection() as redis:
return await redis.zrevrangebyscore( value = await redis.zrevrangebyscore(
self.get_change_id_cache_key(), withscores=True, count=1, offset=0 self.get_change_id_cache_key(), withscores=True, count=1, offset=0
) )
# Return the score (second element) of the first (and only) element, if exists.
return value[0][1] if value else None
async def get_lowest_change_id(self) -> Optional[int]: async def get_lowest_change_id(self) -> Optional[int]:
""" """
@ -376,6 +429,53 @@ class RedisCacheProvider:
self.get_change_id_cache_key(), "_config:lowest_change_id" self.get_change_id_cache_key(), "_config:lowest_change_id"
) )
async def get_schema_version(self) -> Optional[SchemaVersion]:
""" Retrieves the schema version of the cache or None, if not existent """
async with get_connection() as redis:
schema_version = await redis.hgetall(self.get_schema_cache_key())
if not schema_version:
return None
return {
"migration": int(schema_version[b"migration"].decode()),
"config": int(schema_version[b"config"].decode()),
"db": schema_version[b"db"].decode(),
}
async def set_schema_version(self, schema_version: SchemaVersion) -> None:
""" Sets the schema version for this cache. """
async with get_connection() as redis:
await redis.hmset_dict(self.get_schema_cache_key(), schema_version)
async def eval(
self, script_name: str, keys: List[str] = [], args: List[Any] = []
) -> Any:
"""
Runs a lua script in redis. This wrapper around redis.eval tries to make
use of the redis script cache. First the hash is sent to the server and if
the script is not present there (NOSCRIPT error) the actual script will be
sent.
If the script uses the ensure_cache-prefix, the first key must be the full_data
cache key. This is checked here.
"""
hash = self._script_hashes[script_name]
if (
self.scripts[script_name][1]
and not keys[0] == self.get_full_data_cache_key()
):
raise ImproperlyConfigured(
"A script with a ensure_cache prefix must have the full_data cache key as its first key"
)
async with get_connection() as redis:
try:
return await redis.evalsha(hash, keys, args)
except aioredis.errors.ReplyError as e:
if str(e).startswith("NOSCRIPT"):
return await redis.eval(self.scripts[script_name][0], keys, args)
else:
raise e
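The evalsha-then-eval dance is a common redis pattern. A standalone sketch, assuming the aioredis 1.x API used in this file (all names in the sketch are illustrative):

import hashlib

SCRIPT = "return redis.call('hget', KEYS[1], ARGV[1])"
SCRIPT_SHA = hashlib.sha1(SCRIPT.encode()).hexdigest()

async def eval_cached(redis, keys, args):
    # Cheap path first: the server may already have the script cached by sha1.
    try:
        return await redis.evalsha(SCRIPT_SHA, keys, args)
    except Exception as e:  # aioredis raises a ReplyError starting with "NOSCRIPT"
        if str(e).startswith("NOSCRIPT"):
            # First use on this server: send the full script once;
            # redis caches it under its sha1 for future evalsha calls.
            return await redis.eval(SCRIPT, keys, args)
        raise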
class MemmoryCacheProvider: class MemmoryCacheProvider:
""" """
@ -385,112 +485,86 @@ class MemmoryCacheProvider:
This provider supports only one process. It saves the data in memory. This provider supports only one process. It saves the data in memory.
When you use different processes, they will use different data. When you use different processes, they will use different data.
For this reason, ensure_cache is not used and the schema version always
returns an invalid schema so the cache is always rebuilt.
""" """
def __init__(self) -> None: def __init__(self, ensure_cache: Callable[[], Coroutine[Any, Any, None]]) -> None:
self.set_data_dicts() self.set_data_dicts()
def set_data_dicts(self) -> None: def set_data_dicts(self) -> None:
self.full_data: Dict[str, str] = {} self.full_data: Dict[str, str] = {}
self.restricted_data: Dict[int, Dict[str, str]] = {}
self.change_id_data: Dict[int, Set[str]] = {} self.change_id_data: Dict[int, Set[str]] = {}
self.locks: Dict[str, str] = {} self.locks: Dict[str, str] = {}
async def ensure_cache(self) -> None:
pass
async def clear_cache(self) -> None: async def clear_cache(self) -> None:
self.set_data_dicts() self.set_data_dicts()
async def reset_full_cache(self, data: Dict[str, str]) -> None: async def reset_full_cache(self, data: Dict[str, str]) -> None:
self.change_id_data = {}
self.full_data = data self.full_data = data
async def data_exists(self, user_id: Optional[int] = None) -> bool: async def data_exists(self) -> bool:
if user_id is None: return bool(self.full_data)
cache_dict = self.full_data
else:
cache_dict = self.restricted_data.get(user_id, {})
return bool(cache_dict) async def get_all_data(self) -> Dict[bytes, bytes]:
return str_dict_to_bytes(self.full_data)
async def add_elements(self, elements: List[str]) -> None: async def get_collection_data(self, collection: str) -> Dict[int, bytes]:
if len(elements) % 2: out = {}
raise ValueError( query = f"{collection}:"
"The argument elements of add_elements has to be a list with an even number of elements." for element_id, value in self.full_data.items():
) if element_id.startswith(query):
_, id = split_element_id(element_id)
out[id] = value.encode()
return out
for i in range(0, len(elements), 2): async def get_element_data(self, element_id: str) -> Optional[bytes]:
self.full_data[elements[i]] = elements[i + 1] value = self.full_data.get(element_id, None)
return value.encode() if value is not None else None
async def del_elements(
self, elements: List[str], user_id: Optional[int] = None
) -> None:
if user_id is None:
cache_dict = self.full_data
else:
cache_dict = self.restricted_data.get(user_id, {})
for element in elements:
try:
del cache_dict[element]
except KeyError:
pass
async def add_changed_elements( async def add_changed_elements(
self, default_change_id: int, element_ids: Iterable[str] self,
changed_elements: List[str],
deleted_element_ids: List[str],
default_change_id: int,
) -> int: ) -> int:
element_ids = list(element_ids) current_change_id = await self.get_current_change_id()
try: if current_change_id is None:
change_id = (await self.get_current_change_id())[0][1] + 1
except IndexError:
change_id = default_change_id change_id = default_change_id
else:
change_id = current_change_id + 1
for i in range(0, len(changed_elements), 2):
element_id = changed_elements[i]
self.full_data[element_id] = changed_elements[i + 1]
for element_id in element_ids:
if change_id in self.change_id_data: if change_id in self.change_id_data:
self.change_id_data[change_id].add(element_id) self.change_id_data[change_id].add(element_id)
else: else:
self.change_id_data[change_id] = {element_id} self.change_id_data[change_id] = {element_id}
for element_id in deleted_element_ids:
try:
del self.full_data[element_id]
except KeyError:
pass
if change_id in self.change_id_data:
self.change_id_data[change_id].add(element_id)
else:
self.change_id_data[change_id] = {element_id}
return change_id return change_id
async def get_all_data(self, user_id: Optional[int] = None) -> Dict[bytes, bytes]:
if user_id is None:
cache_dict = self.full_data
else:
cache_dict = self.restricted_data.get(user_id, {})
return str_dict_to_bytes(cache_dict)
async def get_collection_data(
self, collection: str, user_id: Optional[int] = None
) -> Dict[bytes, bytes]:
if user_id is None:
cache_dict = self.full_data
else:
cache_dict = self.restricted_data.get(user_id, {})
out = {}
for key, value in cache_dict.items():
if key.startswith(f"{collection}:"):
out[key] = value
return str_dict_to_bytes(out)
async def get_element(
self, element_id: str, user_id: Optional[int] = None
) -> Optional[bytes]:
if user_id is None:
cache_dict = self.full_data
else:
cache_dict = self.restricted_data.get(user_id, {})
value = cache_dict.get(element_id, None)
return value.encode() if value is not None else None
async def get_data_since( async def get_data_since(
self, change_id: int, user_id: Optional[int] = None, max_change_id: int = -1 self, change_id: int, max_change_id: int = -1
) -> Tuple[Dict[str, List[bytes]], List[str]]: ) -> Tuple[Dict[str, List[bytes]], List[str]]:
changed_elements: Dict[str, List[bytes]] = defaultdict(list) changed_elements: Dict[str, List[bytes]] = defaultdict(list)
deleted_elements: List[str] = [] deleted_elements: List[str] = []
if user_id is None:
cache_dict = self.full_data
else:
cache_dict = self.restricted_data.get(user_id, {})
all_element_ids: Set[str] = set() all_element_ids: Set[str] = set()
for data_change_id, element_ids in self.change_id_data.items(): for data_change_id, element_ids in self.change_id_data.items():
@ -500,7 +574,7 @@ class MemmoryCacheProvider:
all_element_ids.update(element_ids) all_element_ids.update(element_ids)
for element_id in all_element_ids: for element_id in all_element_ids:
element_json = cache_dict.get(element_id, None) element_json = self.full_data.get(element_id, None)
if element_json is None: if element_json is None:
deleted_elements.append(element_id) deleted_elements.append(element_id)
else: else:
@ -508,12 +582,6 @@ class MemmoryCacheProvider:
changed_elements[collection_string].append(element_json.encode()) changed_elements[collection_string].append(element_json.encode())
return changed_elements, deleted_elements return changed_elements, deleted_elements
async def del_restricted_data(self, user_id: int) -> None:
try:
del self.restricted_data[user_id]
except KeyError:
pass
async def set_lock(self, lock_name: str) -> bool: async def set_lock(self, lock_name: str) -> bool:
if lock_name in self.locks: if lock_name in self.locks:
return False return False
@ -529,20 +597,11 @@ class MemmoryCacheProvider:
except KeyError: except KeyError:
pass pass
async def get_change_id_user(self, user_id: int) -> Optional[int]: async def get_current_change_id(self) -> Optional[int]:
data = self.restricted_data.get(user_id, {})
change_id = data.get("_config:change_id", None)
return int(change_id) if change_id is not None else None
async def update_restricted_data(self, user_id: int, data: Dict[str, str]) -> None:
redis_data = self.restricted_data.setdefault(user_id, {})
redis_data.update(data)
async def get_current_change_id(self) -> List[Tuple[str, int]]:
change_data = self.change_id_data change_data = self.change_id_data
if change_data: if change_data:
return [("no_usefull_value", max(change_data.keys()))] return max(change_data.keys())
return [] return None
async def get_lowest_change_id(self) -> Optional[int]: async def get_lowest_change_id(self) -> Optional[int]:
change_data = self.change_id_data change_data = self.change_id_data
@ -550,6 +609,12 @@ class MemmoryCacheProvider:
return min(change_data.keys()) return min(change_data.keys())
return None return None
async def get_schema_version(self) -> Optional[SchemaVersion]:
return None
async def set_schema_version(self, schema_version: SchemaVersion) -> None:
pass
class Cachable(Protocol): class Cachable(Protocol):
""" """
@ -577,45 +642,3 @@ class Cachable(Protocol):
elements can be an empty list, a list with some elements of the cachable or with all elements can be an empty list, a list with some elements of the cachable or with all
elements of the cachable. elements of the cachable.
""" """
def get_all_cachables() -> List[Cachable]:
"""
Returns all element of OpenSlides.
"""
out: List[Cachable] = []
for app in apps.get_app_configs():
try:
# Get the method get_startup_elements() from an app.
# This method has to return an iterable of Cachable objects.
get_startup_elements = app.get_startup_elements
except AttributeError:
# Skip apps that do not implement get_startup_elements.
continue
out.extend(get_startup_elements())
return out
lua_script_change_data = """
-- Generate a new change_id
local tmp = redis.call('zrevrangebyscore', KEYS[1], '+inf', '-inf', 'WITHSCORES', 'LIMIT', 0, 1)
local change_id
if next(tmp) == nil then
-- The key does not exist
change_id = ARGV[1]
else
change_id = tmp[2] + 1
end
-- Add elements to sorted set
local count = 2
while ARGV[count] do
redis.call('zadd', KEYS[1], change_id, ARGV[count])
count = count + 1
end
-- Set lowest_change_id if it does not exist
redis.call('zadd', KEYS[1], 'NX', change_id, '_config:lowest_change_id')
return change_id
"""

View File

@ -1,14 +1,15 @@
import logging import logging
import time import time
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, List from typing import Any, Dict, List, Optional
from urllib.parse import parse_qs from urllib.parse import parse_qs
from ..utils.websocket import WEBSOCKET_CHANGE_ID_TOO_HIGH
from .auth import async_anonymous_is_enabled from .auth import async_anonymous_is_enabled
from .autoupdate import AutoupdateFormat from .autoupdate import AutoupdateFormat
from .cache import element_cache, split_element_id from .cache import element_cache, split_element_id
from .utils import get_worker_id from .utils import get_worker_id
from .websocket import ProtocollAsyncJsonWebsocketConsumer, get_element_data from .websocket import ProtocollAsyncJsonWebsocketConsumer
logger = logging.getLogger("openslides.websocket") logger = logging.getLogger("openslides.websocket")
@ -70,13 +71,7 @@ class SiteConsumer(ProtocollAsyncJsonWebsocketConsumer):
if change_id is not None: if change_id is not None:
logger.debug(f"connect: change id {change_id} ({self._id})") logger.debug(f"connect: change id {change_id} ({self._id})")
try: await self.send_autoupdate(change_id)
data = await get_element_data(self.scope["user"]["id"], change_id)
except ValueError:
# When the change_id is to big, do nothing
pass
else:
await self.send_json(type="autoupdate", content=data)
else: else:
logger.debug(f"connect: no change id ({self._id})") logger.debug(f"connect: no change id ({self._id})")
@ -111,30 +106,69 @@ class SiteConsumer(ProtocollAsyncJsonWebsocketConsumer):
item["senderUserId"] = event["senderUserId"] item["senderUserId"] = event["senderUserId"]
await self.send_json(type="notify", content=item) await self.send_json(type="notify", content=item)
async def send_data(self, event: Dict[str, Any]) -> None: async def send_autoupdate(
self,
change_id: int,
max_change_id: Optional[int] = None,
in_response: Optional[str] = None,
) -> None:
""" """
Send changed or deleted elements to the user. Sends an autoupdate to the client from change_id to max_change_id.
If max_change_id is None, the current change id will be used.
""" """
change_id = event["change_id"] user_id = self.scope["user"]["id"]
changed_elements, deleted_elements_ids = await element_cache.get_restricted_data(
self.scope["user"]["id"], change_id, max_change_id=change_id if max_change_id is None:
) max_change_id = await element_cache.get_current_change_id()
if change_id == max_change_id + 1:
# The client is up-to-date, so nothing will be done
return
if change_id > max_change_id:
message = f"Requested change_id {change_id} is higher this highest change_id {max_change_id}."
await self.send_error(
code=WEBSOCKET_CHANGE_ID_TOO_HIGH,
message=message,
in_response=in_response,
)
return
try:
changed_elements, deleted_element_ids = await element_cache.get_data_since(
user_id, change_id, max_change_id
)
except RuntimeError:
# The change_id is lower than the lowest change_id in redis. Return all data
changed_elements = await element_cache.get_all_data_list(user_id)
all_data = True
deleted_elements: Dict[str, List[int]] = {}
else:
all_data = False
deleted_elements = defaultdict(list)
for element_id in deleted_element_ids:
collection_string, id = split_element_id(element_id)
deleted_elements[collection_string].append(id)
deleted_elements: Dict[str, List[int]] = defaultdict(list)
for element_id in deleted_elements_ids:
collection_string, id = split_element_id(element_id)
deleted_elements[collection_string].append(id)
await self.send_json( await self.send_json(
type="autoupdate", type="autoupdate",
content=AutoupdateFormat( content=AutoupdateFormat(
changed=changed_elements, changed=changed_elements,
deleted=deleted_elements, deleted=deleted_elements,
from_change_id=change_id, from_change_id=change_id,
to_change_id=change_id, to_change_id=max_change_id,
all_data=False, all_data=all_data,
), ),
in_response=in_response,
) )
async def send_data(self, event: Dict[str, Any]) -> None:
"""
Send changed or deleted elements to the user.
"""
change_id = event["change_id"]
await self.send_autoupdate(change_id, max_change_id=change_id)
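For orientation, an illustrative autoupdate payload as a client could receive it (field names taken from AutoupdateFormat above; all values are hypothetical):

autoupdate_message = {
    "type": "autoupdate",
    "content": {
        "changed": {"app/collection1": [{"id": 1, "value": "value1"}]},
        "deleted": {"app/collection2": [2]},
        "from_change_id": 2,
        "to_change_id": 5,   # current change id if max_change_id was None
        "all_data": False,   # True when the full dataset was sent instead
    },
}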
async def projector_changed(self, event: Dict[str, Any]) -> None: async def projector_changed(self, event: Dict[str, Any]) -> None:
""" """
The projector has changed. The projector has changed.

View File

@ -54,7 +54,7 @@ async def get_user(scope: Dict[str, Any]) -> Dict[str, Any]:
pass pass
else: else:
if backend_path in settings.AUTHENTICATION_BACKENDS: if backend_path in settings.AUTHENTICATION_BACKENDS:
user = await element_cache.get_element_full_data("users/user", user_id) user = await element_cache.get_element_data("users/user", user_id)
if user: if user:
# Verify the session # Verify the session
session_hash = session.get(HASH_SESSION_KEY) session_hash = session.get(HASH_SESSION_KEY)

View File

@ -67,7 +67,7 @@ async def get_projector_data(
if projector_ids is None: if projector_ids is None:
projector_ids = [] projector_ids = []
all_data = await element_cache.get_all_full_data_ordered() all_data = await element_cache.get_all_data_dict()
projector_data: Dict[int, List[Dict[str, Any]]] = {} projector_data: Dict[int, List[Dict[str, Any]]] = {}
for projector_id, projector in all_data.get("core/projector", {}).items(): for projector_id, projector in all_data.get("core/projector", {}).items():

View File

@ -253,10 +253,11 @@ class ListModelMixin(_ListModelMixin):
# The corresponding queryset does not support caching. # The corresponding queryset does not support caching.
response = super().list(request, *args, **kwargs) response = super().list(request, *args, **kwargs)
else: else:
# TODO
# This loads all data from the cache, not only the requested data. # This loads all data from the cache, not only the requested data.
# If we would use the rest api, we should add a method # If we would use the rest api, we should add a method
# element_cache.get_collection_restricted_data # element_cache.get_collection_restricted_data
all_restricted_data = async_to_sync(element_cache.get_all_restricted_data)( all_restricted_data = async_to_sync(element_cache.get_all_data_list)(
request.user.pk or 0 request.user.pk or 0
) )
response = Response(all_restricted_data.get(collection_string, [])) response = Response(all_restricted_data.get(collection_string, []))
@ -278,8 +279,8 @@ class RetrieveModelMixin(_RetrieveModelMixin):
else: else:
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
user_id = request.user.pk or 0 user_id = request.user.pk or 0
content = async_to_sync(element_cache.get_element_restricted_data)( content = async_to_sync(element_cache.get_element_data)(
user_id, collection_string, self.kwargs[lookup_url_kwarg] collection_string, self.kwargs[lookup_url_kwarg], user_id
) )
if content is None: if content is None:
raise Http404 raise Http404

View File

@ -0,0 +1,84 @@
import logging
from typing import Optional
from django.db.models import Max
from mypy_extensions import TypedDict
logger = logging.getLogger(__name__)
SchemaVersion = TypedDict("SchemaVersion", {"migration": int, "config": int, "db": str})
class SchemaVersionHandler:
"""
Handler for the schema version of this running OpenSlides instance.
What is a schema version? It is an indicator of the current schema of the data
in the database, config variables, and the database itself. E.g. with a migration,
new/changed config variables or with a database change, the schema of the data changes.
Detecting this is needed to reset the cache, so it does not hold any old data. This
affects the server cache, but the client also uses this technique to flush its cache.
Get the current schema with `get`. The schema version is built just once. After a change
in the schema, all workers need a restart!
"""
def __init__(self) -> None:
self._schema_version: Optional[SchemaVersion] = None
def get(self) -> SchemaVersion:
if self._schema_version is not None:
return self._schema_version
from django.db.migrations.recorder import MigrationRecorder
migration = MigrationRecorder.Migration.objects.aggregate(Max("id"))["id__max"]
from openslides.core.config import ConfigStore
try:
config = ConfigStore.objects.get(key="config_version").value
except ConfigStore.DoesNotExist:
config = 0
try:
db = ConfigStore.objects.get(key="db_id").value
except ConfigStore.DoesNotExist:
db = ""
self._schema_version = {"migration": migration, "config": config, "db": db}
return self._schema_version
def compare(self, other: Optional[SchemaVersion]) -> bool:
current = self.get()
if not other:
logger.info("No old schema version")
return False
equal = True
if current["db"] != other["db"]:
other_db = other["db"] or "<empty>"
logger.info(f"DB changed from {other_db} to {current['db']}")
equal = False
if current["config"] != other["config"]:
other_config = other["config"] or "<empty>"
logger.info(f"Config changed from {other_config} to {current['config']}")
equal = False
if current["migration"] != other["migration"]:
logger.info(
f"Migration changed from {other['migration']} to {current['migration']}"
)
equal = False
return equal
def log_current(self) -> None:
current = self.get()
logger.info(
f"""Schema version:
DB: {current["db"]}
migration: {current["migration"]}
config: {current["config"]}"""
)
schema_version_handler = SchemaVersionHandler()
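A hedged usage sketch for the handler: on startup, compare the stored cache schema with the current one and rebuild on mismatch. Only methods shown above are used; the refill step is an application-specific placeholder:

async def check_schema(cache_provider) -> None:
    old_version = await cache_provider.get_schema_version()
    if not schema_version_handler.compare(old_version):
        # Schema changed (db id, config version or migration): drop stale data.
        await cache_provider.clear_cache()
        # ... refill the full data cache here (application-specific) ...
        await cache_provider.set_schema_version(schema_version_handler.get())
    schema_version_handler.log_current()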

View File

@ -100,11 +100,6 @@ if use_redis:
# or a unix domain socket path string — "/path/to/redis.sock". # or a unix domain socket path string — "/path/to/redis.sock".
REDIS_ADDRESS = "redis://127.0.0.1" REDIS_ADDRESS = "redis://127.0.0.1"
# When use_redis is True, the restricted data cache caches the data individually
# for each user. This requires a lot of memory if there are a lot of active
# users.
RESTRICTED_DATA_CACHE = True
# Session backend # Session backend
# Redis configuration for django-redis-sessions. # Redis configuration for django-redis-sessions.

View File

@ -1,6 +1,5 @@
import json import json
from collections import defaultdict from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
import jsonschema import jsonschema
import lz4.frame import lz4.frame
@ -8,10 +7,7 @@ from channels.generic.websocket import AsyncWebsocketConsumer
from django.conf import settings from django.conf import settings
from websockets.exceptions import ConnectionClosed from websockets.exceptions import ConnectionClosed
from .autoupdate import AutoupdateFormat
from .cache import element_cache
from .stats import WebsocketThroughputLogger from .stats import WebsocketThroughputLogger
from .utils import split_element_id
# Custom Websocket error codes (not to be confused with the websocket *connection* # Custom Websocket error codes (not to be confused with the websocket *connection*
@ -25,7 +21,7 @@ WEBSOCKET_CHANGE_ID_TOO_HIGH = 101
# If data is requested and the given change id is higher than the highest change id # If data is requested and the given change id is higher than the highest change id
# from the element_cache. # from the element_cache.
WEBSOCKET_WRONG_FORMAT = 10 WEBSOCKET_WRONG_FORMAT = 102
# If the recieved data has not the expected format. # If the recieved data has not the expected format.
@ -232,37 +228,3 @@ def register_client_message(
message_schema["required"] = ["content"] message_schema["required"] = ["content"]
schema["anyOf"].append(message_schema) schema["anyOf"].append(message_schema)
async def get_element_data(user_id: int, change_id: int = 0) -> AutoupdateFormat:
"""
Returns all element data since a change_id.
"""
current_change_id = await element_cache.get_current_change_id()
if change_id > current_change_id:
raise ValueError(
f"Requested change_id {change_id} is higher this highest change_id {current_change_id}."
)
try:
changed_elements, deleted_element_ids = await element_cache.get_restricted_data(
user_id, change_id, current_change_id
)
except RuntimeError:
# The change_id is lower than the lowest change_id in redis. Return all data
changed_elements = await element_cache.get_all_restricted_data(user_id)
all_data = True
deleted_elements: Dict[str, List[int]] = {}
else:
all_data = False
deleted_elements = defaultdict(list)
for element_id in deleted_element_ids:
collection_string, id = split_element_id(element_id)
deleted_elements[collection_string].append(id)
return AutoupdateFormat(
changed=changed_elements,
deleted=deleted_elements,
from_change_id=change_id,
to_change_id=current_change_id,
all_data=all_data,
)

View File

@ -82,5 +82,5 @@ def reset_cache(request):
async_to_sync(element_cache.cache_provider.clear_cache)() async_to_sync(element_cache.cache_provider.clear_cache)()
element_cache.ensure_cache(reset=True) element_cache.ensure_cache(reset=True)
# Set constant start_time # Set constant default change_id
element_cache.start_time = 1 element_cache.set_default_change_id(1)

View File

@ -16,7 +16,10 @@ from openslides.utils.autoupdate import (
inform_deleted_data, inform_deleted_data,
) )
from openslides.utils.cache import element_cache from openslides.utils.cache import element_cache
from openslides.utils.websocket import WEBSOCKET_CHANGE_ID_TOO_HIGH from openslides.utils.websocket import (
WEBSOCKET_CHANGE_ID_TOO_HIGH,
WEBSOCKET_WRONG_FORMAT,
)
from ...unit.utils.cache_provider import Collection1, Collection2, get_cachable_provider from ...unit.utils.cache_provider import Collection1, Collection2, get_cachable_provider
from ..helpers import TConfig, TProjector, TUser from ..helpers import TConfig, TProjector, TUser
@ -36,7 +39,7 @@ async def prepare_element_cache(settings):
[Collection1(), Collection2(), TConfig(), TUser(), TProjector()] [Collection1(), Collection2(), TConfig(), TUser(), TProjector()]
) )
element_cache._cachables = None element_cache._cachables = None
await sync_to_async(element_cache.ensure_cache)() await element_cache.async_ensure_cache(default_change_id=1)
yield yield
# Reset the cachable_provider # Reset the cachable_provider
element_cache.cachable_provider = orig_cachable_provider element_cache.cachable_provider = orig_cachable_provider
@ -118,31 +121,6 @@ async def test_connection_with_change_id(get_communicator, set_config):
assert TUser().get_collection_string() in content["changed"] assert TUser().get_collection_string() in content["changed"]
@pytest.mark.asyncio
async def test_connection_with_change_id_get_restricted_data_with_restricted_data_cache(
get_communicator, set_config
):
"""
Test, that the returned data is the restricted_data when restricted_data_cache is activated
"""
try:
# Save the value of use_restricted_data_cache
original_use_restricted_data = element_cache.use_restricted_data_cache
element_cache.use_restricted_data_cache = True
await set_config("general_system_enable_anonymous", True)
communicator = get_communicator("change_id=0")
await communicator.connect()
response = await communicator.receive_json_from()
content = response.get("content")
assert content["changed"]["app/collection1"][0]["value"] == "restricted_value1"
finally:
# reset the value of use_restricted_data_cache
element_cache.use_restricted_data_cache = original_use_restricted_data
@pytest.mark.xfail # This will fail until a proper solution in #4009 @pytest.mark.xfail # This will fail until a proper solution in #4009
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_connection_with_invalid_change_id(get_communicator, set_config): async def test_connection_with_invalid_change_id(get_communicator, set_config):
@ -154,14 +132,14 @@ async def test_connection_with_invalid_change_id(get_communicator, set_config):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_connection_with_to_big_change_id(get_communicator, set_config): async def test_connection_with_too_big_change_id(get_communicator, set_config):
await set_config("general_system_enable_anonymous", True) await set_config("general_system_enable_anonymous", True)
communicator = get_communicator("change_id=100") communicator = get_communicator("change_id=100")
connected, __ = await communicator.connect() connected, __ = await communicator.connect()
assert connected is True assert connected is True
assert await communicator.receive_nothing() await communicator.assert_receive_error(code=WEBSOCKET_CHANGE_ID_TOO_HIGH)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -271,8 +249,7 @@ async def test_invalid_websocket_message_type(communicator, set_config):
await communicator.send_json_to([]) await communicator.send_json_to([])
response = await communicator.receive_json_from() await communicator.assert_receive_error(code=WEBSOCKET_WRONG_FORMAT)
assert response["type"] == "error"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -282,8 +259,7 @@ async def test_invalid_websocket_message_no_id(communicator, set_config):
await communicator.send_json_to({"type": "test", "content": "foobar"}) await communicator.send_json_to({"type": "test", "content": "foobar"})
response = await communicator.receive_json_from() await communicator.assert_receive_error(code=WEBSOCKET_WRONG_FORMAT)
assert response["type"] == "error"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -299,9 +275,9 @@ async def test_send_unknown_type(communicator, set_config):
} }
) )
response = await communicator.receive_json_from() await communicator.assert_receive_error(
assert response["type"] == "error" code=WEBSOCKET_WRONG_FORMAT, in_response="test_id"
assert response["in_response"] == "test_id" )
@pytest.mark.asyncio @pytest.mark.asyncio
@ -343,18 +319,16 @@ async def test_send_get_elements(communicator, set_config):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_get_elements_to_big_change_id(communicator, set_config): async def test_send_get_elements_too_big_change_id(communicator, set_config):
await set_config("general_system_enable_anonymous", True) await set_config("general_system_enable_anonymous", True)
await communicator.connect() await communicator.connect()
await communicator.send_json_to( await communicator.send_json_to(
{"type": "getElements", "content": {"change_id": 100}, "id": "test_id"} {"type": "getElements", "content": {"change_id": 100}, "id": "test_id"}
) )
response = await communicator.receive_json_from() await communicator.assert_receive_error(
code=WEBSOCKET_CHANGE_ID_TOO_HIGH, in_response="test_id"
type = response.get("type") )
assert type == "error"
assert response.get("in_response") == "test_id"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -374,10 +348,10 @@ async def test_send_get_elements_to_small_change_id(communicator, set_config):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_connect_twice_with_clear_change_id_cache(communicator, set_config): async def test_send_connect_up_to_date(communicator, set_config):
""" """
Test, that a second request with change_id+1 from the first request, returns Test, that a second request with change_id+1 from the first request does not
an error. send anything, because the client is up to date.
""" """
await set_config("general_system_enable_anonymous", True) await set_config("general_system_enable_anonymous", True)
element_cache.cache_provider.change_id_data = {} # type: ignore element_cache.cache_provider.change_id_data = {} # type: ignore
@ -395,13 +369,7 @@ async def test_send_connect_twice_with_clear_change_id_cache(communicator, set_c
"id": "test_id", "id": "test_id",
} }
) )
response2 = await communicator.receive_json_from() assert await communicator.receive_nothing()
assert response2["type"] == "error"
assert response2.get("content") == {
"code": WEBSOCKET_CHANGE_ID_TOO_HIGH,
"message": "Requested change_id 2 is higher this highest change_id 1.",
}
@pytest.mark.xfail # This test is broken @pytest.mark.xfail # This test is broken

View File

@ -23,3 +23,16 @@ class WebsocketCommunicator(ChannelsWebsocketCommunicator):
assert isinstance(text_data, str), "JSON data is not a text frame" assert isinstance(text_data, str), "JSON data is not a text frame"
return json.loads(text_data) return json.loads(text_data)
async def assert_receive_error(self, timeout=1, in_response=None, **kwargs):
response = await self.receive_json_from(timeout)
assert response["type"] == "error"
content = response.get("content")
if kwargs:
assert content
for key, value in kwargs.items():
assert content.get(key) == value
if in_response:
assert response["in_response"] == in_response

View File

@ -1,6 +1,5 @@
import json import json
from typing import Any, Dict, List from typing import Any, Dict, List
from unittest.mock import patch
import pytest import pytest
@ -32,9 +31,10 @@ def element_cache():
element_cache = ElementCache( element_cache = ElementCache(
cache_provider_class=TTestCacheProvider, cache_provider_class=TTestCacheProvider,
cachable_provider=get_cachable_provider(), cachable_provider=get_cachable_provider(),
start_time=0, default_change_id=0,
) )
element_cache.ensure_cache() element_cache.ensure_cache()
element_cache.set_default_change_id(0)
return element_cache return element_cache
@ -44,7 +44,7 @@ async def test_change_elements(element_cache):
"app/collection1:1": {"id": 1, "value": "updated"}, "app/collection1:1": {"id": 1, "value": "updated"},
"app/collection1:2": {"id": 2, "value": "new"}, "app/collection1:2": {"id": 2, "value": "new"},
"app/collection2:1": {"id": 1, "key": "updated"}, "app/collection2:1": {"id": 1, "key": "updated"},
"app/collection2:2": None, "app/collection2:2": None, # Deleted
} }
element_cache.cache_provider.full_data = { element_cache.cache_provider.full_data = {
@ -103,8 +103,8 @@ async def test_change_elements_with_no_data_in_redis(element_cache):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_all_full_data_from_db(element_cache): async def test_get_all_data_from_db(element_cache):
result = await element_cache.get_all_full_data() result = await element_cache.get_all_data_list()
assert result == example_data() assert result == example_data()
# Test that elements are written to redis # Test that elements are written to redis
@ -119,7 +119,7 @@ async def test_get_all_full_data_from_db(element_cache):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_all_full_data_from_redis(element_cache): async def test_get_all_data_from_redis(element_cache):
element_cache.cache_provider.full_data = { element_cache.cache_provider.full_data = {
"app/collection1:1": '{"id": 1, "value": "value1"}', "app/collection1:1": '{"id": 1, "value": "value1"}',
"app/collection1:2": '{"id": 2, "value": "value2"}', "app/collection1:2": '{"id": 2, "value": "value2"}',
@ -127,14 +127,14 @@ async def test_get_all_full_data_from_redis(element_cache):
"app/collection2:2": '{"id": 2, "key": "value2"}', "app/collection2:2": '{"id": 2, "key": "value2"}',
} }
result = await element_cache.get_all_full_data() result = await element_cache.get_all_data_list()
# The output from redis has to be the same as the db_data # The output from redis has to be the same as the db_data
assert sort_dict(result) == sort_dict(example_data()) assert sort_dict(result) == sort_dict(example_data())
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_full_data_change_id_0(element_cache): async def test_get_data_since_change_id_0(element_cache):
element_cache.cache_provider.full_data = { element_cache.cache_provider.full_data = {
"app/collection1:1": '{"id": 1, "value": "value1"}', "app/collection1:1": '{"id": 1, "value": "value1"}',
"app/collection1:2": '{"id": 2, "value": "value2"}', "app/collection1:2": '{"id": 2, "value": "value2"}',
@ -142,13 +142,13 @@ async def test_get_full_data_change_id_0(element_cache):
"app/collection2:2": '{"id": 2, "key": "value2"}', "app/collection2:2": '{"id": 2, "key": "value2"}',
} }
result = await element_cache.get_full_data(0) result = await element_cache.get_data_since(None, 0)
assert sort_dict(result[0]) == sort_dict(example_data()) assert sort_dict(result[0]) == sort_dict(example_data())
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_full_data_change_id_lower_then_in_redis(element_cache): async def test_get_data_since_change_id_lower_then_in_redis(element_cache):
element_cache.cache_provider.full_data = { element_cache.cache_provider.full_data = {
"app/collection1:1": '{"id": 1, "value": "value1"}', "app/collection1:1": '{"id": 1, "value": "value1"}',
"app/collection1:2": '{"id": 2, "value": "value2"}', "app/collection1:2": '{"id": 2, "value": "value2"}',
@ -157,11 +157,11 @@ async def test_get_full_data_change_id_lower_then_in_redis(element_cache):
} }
element_cache.cache_provider.change_id_data = {2: {"app/collection1:1"}} element_cache.cache_provider.change_id_data = {2: {"app/collection1:1"}}
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
await element_cache.get_full_data(1) await element_cache.get_data_since(None, 1)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_full_data_change_id_data_in_redis(element_cache): async def test_get_data_since_change_id_data_in_redis(element_cache):
element_cache.cache_provider.full_data = { element_cache.cache_provider.full_data = {
"app/collection1:1": '{"id": 1, "value": "value1"}', "app/collection1:1": '{"id": 1, "value": "value1"}',
"app/collection1:2": '{"id": 2, "value": "value2"}', "app/collection1:2": '{"id": 2, "value": "value2"}',
@ -172,7 +172,7 @@ async def test_get_full_data_change_id_data_in_redis(element_cache):
1: {"app/collection1:1", "app/collection1:3"} 1: {"app/collection1:1", "app/collection1:3"}
} }
result = await element_cache.get_full_data(1) result = await element_cache.get_data_since(None, 1)
assert result == ( assert result == (
{"app/collection1": [{"id": 1, "value": "value1"}]}, {"app/collection1": [{"id": 1, "value": "value1"}]},
@ -181,12 +181,12 @@ async def test_get_full_data_change_id_data_in_redis(element_cache):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_full_data_change_id_data_in_db(element_cache): async def test_get_data_since_change_id_data_in_db(element_cache):
element_cache.cache_provider.change_id_data = { element_cache.cache_provider.change_id_data = {
1: {"app/collection1:1", "app/collection1:3"} 1: {"app/collection1:1", "app/collection1:3"}
} }
result = await element_cache.get_full_data(1) result = await element_cache.get_data_since(None, 1)
assert result == ( assert result == (
{"app/collection1": [{"id": 1, "value": "value1"}]}, {"app/collection1": [{"id": 1, "value": "value1"}]},
@ -195,27 +195,27 @@ async def test_get_full_data_change_id_data_in_db(element_cache):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_full_data_change_id_data_in_db_empty_change_id(element_cache): async def test_get_data_since_change_id_data_in_db_empty_change_id(element_cache):
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
await element_cache.get_full_data(1) await element_cache.get_data_since(None, 1)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_element_full_data_empty_redis(element_cache): async def test_get_element_data_empty_redis(element_cache):
result = await element_cache.get_element_full_data("app/collection1", 1) result = await element_cache.get_element_data("app/collection1", 1)
assert result == {"id": 1, "value": "value1"} assert result == {"id": 1, "value": "value1"}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_element_full_data_empty_redis_does_not_exist(element_cache): async def test_get_element_data_empty_redis_does_not_exist(element_cache):
result = await element_cache.get_element_full_data("app/collection1", 3) result = await element_cache.get_element_data("app/collection1", 3)
assert result is None assert result is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_element_full_data_full_redis(element_cache): async def test_get_element_data_full_redis(element_cache):
element_cache.cache_provider.full_data = { element_cache.cache_provider.full_data = {
"app/collection1:1": '{"id": 1, "value": "value1"}', "app/collection1:1": '{"id": 1, "value": "value1"}',
"app/collection1:2": '{"id": 2, "value": "value2"}', "app/collection1:2": '{"id": 2, "value": "value2"}',
@ -223,208 +223,14 @@ async def test_get_element_full_data_full_redis(element_cache):
"app/collection2:2": '{"id": 2, "key": "value2"}', "app/collection2:2": '{"id": 2, "key": "value2"}',
} }
result = await element_cache.get_element_full_data("app/collection1", 1) result = await element_cache.get_element_data("app/collection1", 1)
assert result == {"id": 1, "value": "value1"} assert result == {"id": 1, "value": "value1"}
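# The renames above reflect the consolidated read API of this PR: the former
# get_full_data(change_id) and get_restricted_data(user_id, change_id) merge
# into a single get_data_since(user_id, change_id), and get_element_full_data
# becomes get_element_data. Judging from the calls in these tests, a user_id
# of None requests unrestricted full data. A minimal sketch of the assumed
# calling convention (hypothetical test, reusing the element_cache fixture):
@pytest.mark.asyncio
async def test_get_data_since_calling_convention_sketch(element_cache):
    # None -> full data; an int -> data restricted for that user id.
    full_changed, full_deleted = await element_cache.get_data_since(None, 0)
    user_changed, user_deleted = await element_cache.get_data_since(0, 0)
    # Both calls return a (changed_elements, deleted_element_ids) pair,
    # matching the tuple assertions in the surrounding tests.
    assert isinstance(full_changed, dict) and isinstance(user_changed, dict)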
@pytest.mark.asyncio
async def test_exists_restricted_data(element_cache):
element_cache.use_restricted_data_cache = True
element_cache.cache_provider.restricted_data = {
0: {
"app/collection1:1": '{"id": 1, "value": "value1"}',
"app/collection1:2": '{"id": 2, "value": "value2"}',
"app/collection2:1": '{"id": 1, "key": "value1"}',
"app/collection2:2": '{"id": 2, "key": "value2"}',
}
}
result = await element_cache.exists_restricted_data(0)
assert result
@pytest.mark.asyncio
async def test_exists_restricted_data_do_not_use_restricted_data(element_cache):
element_cache.use_restricted_data_cache = False
element_cache.cache_provider.restricted_data = {
0: {
"app/collection1:1": '{"id": 1, "value": "value1"}',
"app/collection1:2": '{"id": 2, "value": "value2"}',
"app/collection2:1": '{"id": 1, "key": "value1"}',
"app/collection2:2": '{"id": 2, "key": "value2"}',
}
}
result = await element_cache.exists_restricted_data(0)
assert not result
@pytest.mark.asyncio
async def test_del_user(element_cache):
element_cache.use_restricted_data_cache = True
element_cache.cache_provider.restricted_data = {
0: {
"app/collection1:1": '{"id": 1, "value": "value1"}',
"app/collection1:2": '{"id": 2, "value": "value2"}',
"app/collection2:1": '{"id": 1, "key": "value1"}',
"app/collection2:2": '{"id": 2, "key": "value2"}',
}
}
await element_cache.del_user(0)
assert not element_cache.cache_provider.restricted_data
@pytest.mark.asyncio
async def test_del_user_for_empty_user(element_cache):
element_cache.use_restricted_data_cache = True
await element_cache.del_user(0)
assert not element_cache.cache_provider.restricted_data
@pytest.mark.asyncio
async def test_update_restricted_data(element_cache):
element_cache.use_restricted_data_cache = True
await element_cache.update_restricted_data(0)
assert decode_dict(element_cache.cache_provider.restricted_data[0]) == decode_dict(
{
"app/collection1:1": '{"id": 1, "value": "restricted_value1"}',
"app/collection1:2": '{"id": 2, "value": "restricted_value2"}',
"app/collection2:1": '{"id": 1, "key": "restricted_value1"}',
"app/collection2:2": '{"id": 2, "key": "restricted_value2"}',
"_config:change_id": "0",
}
)
# Make sure the lock is deleted
assert not await element_cache.cache_provider.get_lock("restricted_data_0")
@pytest.mark.asyncio
async def test_update_restricted_data_full_restricted_elements(element_cache):
"""
Tests that elements in the restricted_data cache that are later hidden from
a user get deleted for this user.
"""
element_cache.use_restricted_data_cache = True
await element_cache.update_restricted_data(0)
element_cache.cache_provider.change_id_data = {
1: {"app/collection1:1", "app/collection1:3"}
}
with patch("tests.unit.utils.cache_provider.restrict_elements", lambda x: []):
await element_cache.update_restricted_data(0)
assert decode_dict(element_cache.cache_provider.restricted_data[0]) == decode_dict(
{"_config:change_id": "1"}
)
# Make sure the lock is deleted
assert not await element_cache.cache_provider.get_lock("restricted_data_0")
@pytest.mark.asyncio
async def test_update_restricted_data_disabled_restricted_data(element_cache):
element_cache.use_restricted_data_cache = False
await element_cache.update_restricted_data(0)
assert not element_cache.cache_provider.restricted_data
@pytest.mark.asyncio
async def test_update_restricted_data_to_low_change_id(element_cache):
element_cache.use_restricted_data_cache = True
element_cache.cache_provider.restricted_data[0] = {"_config:change_id": "1"}
element_cache.cache_provider.change_id_data = {3: {"app/collection1:1"}}
await element_cache.update_restricted_data(0)
assert decode_dict(element_cache.cache_provider.restricted_data[0]) == decode_dict(
{
"app/collection1:1": '{"id": 1, "value": "restricted_value1"}',
"app/collection1:2": '{"id": 2, "value": "restricted_value2"}',
"app/collection2:1": '{"id": 1, "key": "restricted_value1"}',
"app/collection2:2": '{"id": 2, "key": "restricted_value2"}',
"_config:change_id": "3",
}
)
@pytest.mark.asyncio
async def test_update_restricted_data_with_same_id(element_cache):
element_cache.use_restricted_data_cache = True
element_cache.cache_provider.restricted_data[0] = {"_config:change_id": "1"}
element_cache.cache_provider.change_id_data = {1: {"app/collection1:1"}}
await element_cache.update_restricted_data(0)
# The same change id means there is nothing to do
assert element_cache.cache_provider.restricted_data[0] == {"_config:change_id": "1"}
@pytest.mark.asyncio
async def test_update_restricted_data_with_deleted_elements(element_cache):
element_cache.use_restricted_data_cache = True
element_cache.cache_provider.restricted_data[0] = {
"app/collection1:3": '{"id": 1, "value": "restricted_value1"}',
"_config:change_id": "1",
}
element_cache.cache_provider.change_id_data = {2: {"app/collection1:3"}}
await element_cache.update_restricted_data(0)
assert element_cache.cache_provider.restricted_data[0] == {"_config:change_id": "2"}
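# Across the three tests above, "_config:change_id" acts as a per-user
# high-water mark: it records the change id up to which the restricted cache
# has been synchronized, and it only ever moves forward. A rough sketch of
# that bookkeeping (hypothetical helper, assumed behaviour):
def advance_change_id_marker(restricted: dict, change_id_data: dict) -> dict:
    # After applying all element changes, the marker jumps to the highest
    # known change id; an equal or lower id means there is nothing to do.
    highest = max(change_id_data, default=0)
    if highest > int(restricted.get("_config:change_id", "0")):
        restricted["_config:change_id"] = str(highest)
    return restricted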
@pytest.mark.asyncio
async def test_update_restricted_data_second_worker(element_cache):
"""
Test that nothing is done if another worker is already updating the data.
This test makes use of the redis lock key just as it would be used across different daphne servers.
"""
element_cache.use_restricted_data_cache = True
element_cache.cache_provider.restricted_data = {0: {}}
await element_cache.cache_provider.set_lock("restricted_data_0")
await element_cache.cache_provider.del_lock_after_wait("restricted_data_0")
await element_cache.update_restricted_data(0)
# Restricted data should not be set by the second worker
assert element_cache.cache_provider.restricted_data == {0: {}}
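# The guard exercised above is the classic "set if not exists" lock: the
# first worker to acquire it updates the restricted data, every other worker
# waits for the lock to vanish and returns without writing. A minimal,
# hypothetical sketch of that pattern (assumed semantics for set_lock and
# del_lock_after_wait, not this PR's exact code):
class SketchLockProvider:
    def __init__(self) -> None:
        self.locks: set = set()

    async def set_lock(self, name: str) -> bool:
        # Emulates Redis SETNX: False means another worker holds the lock.
        if name in self.locks:
            return False
        self.locks.add(name)
        return True

    async def del_lock_after_wait(self, name: str) -> None:
        # Simplified: release immediately instead of waiting for deletion.
        self.locks.discard(name)

async def update_restricted_data_once(provider: SketchLockProvider) -> bool:
    if not await provider.set_lock("restricted_data_0"):
        return False  # another worker is already updating: do nothing
    try:
        # ... recompute and write the user's restricted data here ...
        return True
    finally:
        await provider.del_lock_after_wait("restricted_data_0")

# Pre-acquiring the lock, as the test above does, makes the call return
# False without touching the cache.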
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_all_restricted_data(element_cache): async def test_get_all_restricted_data(element_cache):
element_cache.use_restricted_data_cache = True result = await element_cache.get_all_data_list(0)
result = await element_cache.get_all_restricted_data(0)
assert sort_dict(result) == sort_dict(
{
"app/collection1": [
{"id": 1, "value": "restricted_value1"},
{"id": 2, "value": "restricted_value2"},
],
"app/collection2": [
{"id": 1, "key": "restricted_value1"},
{"id": 2, "key": "restricted_value2"},
],
}
)
@pytest.mark.asyncio
async def test_get_all_restricted_data_disabled_restricted_data_cache(element_cache):
element_cache.use_restricted_data_cache = False
result = await element_cache.get_all_restricted_data(0)
assert sort_dict(result) == sort_dict( assert sort_dict(result) == sort_dict(
{ {
@ -442,9 +248,7 @@ async def test_get_all_restricted_data_disabled_restricted_data_cache(element_cache):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_restricted_data_change_id_0(element_cache): async def test_get_restricted_data_change_id_0(element_cache):
element_cache.use_restricted_data_cache = True result = await element_cache.get_data_since(0, 0)
result = await element_cache.get_restricted_data(0, 0)
assert sort_dict(result[0]) == sort_dict( assert sort_dict(result[0]) == sort_dict(
{ {
@ -461,13 +265,12 @@ async def test_get_restricted_data_change_id_0(element_cache):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_restricted_data_disabled_restricted_data_cache(element_cache): async def test_get_restricted_data_2(element_cache):
element_cache.use_restricted_data_cache = False
element_cache.cache_provider.change_id_data = { element_cache.cache_provider.change_id_data = {
1: {"app/collection1:1", "app/collection1:3"} 1: {"app/collection1:1", "app/collection1:3"}
} }
result = await element_cache.get_restricted_data(0, 1) result = await element_cache.get_data_since(0, 1)
assert result == ( assert result == (
{"app/collection1": [{"id": 1, "value": "restricted_value1"}]}, {"app/collection1": [{"id": 1, "value": "restricted_value1"}]},
@ -477,19 +280,17 @@ async def test_get_restricted_data_disabled_restricted_data_cache(element_cache):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_restricted_data_change_id_lower_then_in_redis(element_cache): async def test_get_restricted_data_change_id_lower_than_in_redis(element_cache):
element_cache.use_restricted_data_cache = True
element_cache.cache_provider.change_id_data = {2: {"app/collection1:1"}} element_cache.cache_provider.change_id_data = {2: {"app/collection1:1"}}
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
await element_cache.get_restricted_data(0, 1) await element_cache.get_data_since(0, 1)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_restricted_data_change_with_id(element_cache): async def test_get_restricted_data_change_with_id(element_cache):
element_cache.use_restricted_data_cache = True
element_cache.cache_provider.change_id_data = {2: {"app/collection1:1"}} element_cache.cache_provider.change_id_data = {2: {"app/collection1:1"}}
result = await element_cache.get_restricted_data(0, 2) result = await element_cache.get_data_since(0, 2)
assert result == ( assert result == (
{"app/collection1": [{"id": 1, "value": "restricted_value1"}]}, {"app/collection1": [{"id": 1, "value": "restricted_value1"}]},