Source code for crossauth_backend.storageimpl.inmemorystorage

# Copyright (c) 2024 Matthew Baker.  All rights reserved.  Licenced under the Apache Licence 2.0.  See LICENSE file
from crossauth_backend.storage import KeyStorage, KeyDataEntry, UserStorage, \
    UserStorageOptions, UserStorageGetOptions, UserAndSecrets, \
    OAuthClientStorage, OAuthClientStorageOptions, \
        OAuthAuthorizationStorage, OAuthAuthorizationStorageOptions
from crossauth_backend.common.interfaces import Key, PartialKey, \
    User, UserInputFields, UserSecrets, UserSecretsInputFields, UserState, \
    PartialUser, PartialUserSecrets, \
    OAuthClient, PartialOAuthClient, \
    optional_not_equals
from crossauth_backend.common.error import CrossauthError, ErrorCode
from crossauth_backend.common.logger import CrossauthLogger, j

import json
from typing import Dict, List, Optional, Union, Mapping, Any, cast
from datetime import datetime
from nulltype import Null, NullType

###########################
# KeyStorage

[docs] class InMemoryKeyStorage(KeyStorage): """ Implementation of :class:`KeyStorage` where keys stored in memory. Intended for testing. """ def __init__(self): super().__init__() self.__keys: Dict[str, Key] = {} self.__keys_by_user_id: Dict[str|int, List[Key]] = {} self.__non_user_keys: List[Key] = []
[docs] def print(self): print("Key storage", self.__keys, self.__keys_by_user_id)
[docs] async def get_key(self, key: str) -> Key: if key in self.__keys: return self.__keys[key] CrossauthLogger.logger().debug(j({"msg": "Key does not exist in key storage"})) err = CrossauthError(ErrorCode.InvalidKey) CrossauthLogger.logger().debug(j({"err": str(err)})) raise err
[docs] async def save_key(self, userid: Optional[Union[str, int]], value: str, date_created: datetime, expires: Optional[datetime] = None, data: Optional[str] = None, extra_fields: Optional[Mapping[str, Any]] = None) -> None: key : Key = { "value" : value, "created": date_created, "expires": expires or Null, } if (userid is not None): key["userid"] = userid if (data is not None): key["data"] = data if (extra_fields is not None): for name in extra_fields: key[name] = extra_fields[name] self.__keys[value] = key if userid is not None: if userid not in self.__keys_by_user_id: self.__keys_by_user_id[userid] = [key] else: self.__keys_by_user_id[userid].append(key) else: self.__non_user_keys.append(key)
[docs] async def delete_key(self, value: str) -> None: if value in self.__keys: key = self.__keys[value] if "userid" in key: userid = key["userid"] if (userid != Null): del self.__keys_by_user_id[userid] # type: ignore else: self.__non_user_keys = [v for v in self.__non_user_keys if v["value"] != value] del self.__keys[value]
[docs] async def delete_all_for_user(self, userid: str|int|None, prefix: str, except_key: Optional[str] = None) -> None: self.__keys = {k: v for k, v in self.__keys.items() if ("userid" in v and v["userid"] != userid) or (except_key and k == except_key) or not k.startswith(prefix)} if userid: if userid in self.__keys_by_user_id: for key in self.__keys_by_user_id[userid]: new_keys : List[Key] = [] if (not key["value"].startswith(prefix)): new_keys.append(key) self.__keys_by_user_id[userid] = new_keys else: self.__non_user_keys = []
[docs] async def get_all_for_user(self, userid: str|int|None = None) -> List[Key]: if not userid: return self.__non_user_keys return self.__keys_by_user_id.get(userid, [])
[docs] async def delete_matching(self, key: PartialKey) -> None: delete_from : List[Key] = self.__non_user_keys if ("userid" in key and key["userid"] != None): if (key["userid"] not in self.__keys_by_user_id): return keyvalue = key["userid"] if (type(keyvalue) == NullType): return else: delete_from = self.__keys_by_user_id[keyvalue] # type: ignore matches : list[Key] = [] for entry in delete_from: is_a_match = True for field, value in key.items(): if (field not in entry or entry[field] != value): is_a_match = False if (is_a_match): matches.append(entry) for match in matches: delete_from.remove(match) del self.__keys[match["value"]]
[docs] async def update_key(self, key: PartialKey) -> None: if 'value' in key and key['value'] in self.__keys: for field, value in key.items(): setattr(self.__keys[key['value']], field, value)
[docs] async def update_data(self, key_name: str, data_name: str, value: Any|None) -> None: await self.update_many_data(key_name, [{"data_name": data_name, "value": value}])
[docs] async def update_many_data(self, key_name: str, data_array: List[KeyDataEntry]) -> None: key = await self.get_key(key_name) data : Dict[str, Any] = {} if ("data" in key and key["data"] != ""): data = json.loads(key["data"]) for item in data_array: if ("value" in item): if self._update_data_internal(data, item["data_name"], item["value"]): key["data"] = json.dumps(data) else: raise CrossauthError(ErrorCode.BadRequest, f"parents of {item['data_name']} not found in key data") else: self._delete_data_internal(data, item["data_name"])
[docs] async def delete_data(self, key_name: str, data_name: str) -> None: key = await self.get_key(key_name) if "data" not in key or key["data"] == "": return data = json.loads(key["data"]) if self._delete_data_internal(data, data_name): key["data"] = json.dumps(data)
########################### # UserStorage
[docs] class InMemoryUserStorageOptions(UserStorageOptions): pass
[docs] class InMemoryUserStorage(UserStorage): """ Implementation of :class:`KeyStorage` where keys stored in memory. Intended for testing. """ def __init__(self, options : InMemoryUserStorageOptions = {}): super().__init__(options) self.__users_by_username: Dict[str, User] = {} self.__users_by_email: Dict[str, User] = {} self.__secrets_by_username: Dict[str, UserSecrets] = {} self.__secrets_by_email: Dict[str, UserSecrets] = {}
[docs] async def get_user_by(self, field: str, value: str, options: UserStorageGetOptions = {}) -> UserAndSecrets: field_value_normalized = UserStorage.normalize(value) if self._normalize_username else value users_by_field = self.__users_by_username secrets_by_field = self.__secrets_by_username if (field == "email"): users_by_field = self.__users_by_email secrets_by_field = self.__secrets_by_email if field_value_normalized in users_by_field: user = users_by_field[field_value_normalized] secrets = secrets_by_field[field_value_normalized] if (optional_not_equals(options, "skip_active_check", True) and user["state"]==UserState.password_change_needed): CrossauthLogger.logger().debug(j({"msg": "Password change required"})) raise CrossauthError(ErrorCode.PasswordChangeNeeded) if (optional_not_equals(options, "skip_active_check", True) and (user["state"]==UserState.password_reset_needed or user["state"]==UserState.password_and_factor2_reset_needed)): CrossauthLogger.logger().debug(j({"msg": "Password reset required"})) raise CrossauthError(ErrorCode.PasswordResetNeeded) if (optional_not_equals(options, "skip_active_check", True) and user["state"]==UserState.factor2_reset_needed): CrossauthLogger.logger().debug(j({"msg": "2FA reset required"})) raise CrossauthError(ErrorCode.Factor2ResetNeeded) if (optional_not_equals(options, "skip_active_check", True) and user["state"]==UserState.awaiting_two_factor_setup): CrossauthLogger.logger().debug(j({"msg": "2FA setup is not complete"})) raise CrossauthError(ErrorCode.TwoFactorIncomplete); if (optional_not_equals(options, "skip_email_verified_check", True) and user['state'] == UserState.awaiting_email_verification): CrossauthLogger.logger().debug(j({"msg": "User email not verified"})) raise CrossauthError(ErrorCode.EmailNotVerified) if (optional_not_equals(options, "skip_active_check", True) and user['state'] == UserState.disabled): CrossauthLogger.logger().debug(j({"msg": "User is deactivated"})) raise CrossauthError(ErrorCode.UserNotActive) else: raise CrossauthError(ErrorCode.UserNotExist) return { "user": user, "secrets": secrets }
[docs] async def get_user_by_username(self, username: str, options: UserStorageGetOptions = {}) -> UserAndSecrets: return await self.get_user_by("username", username, options)
[docs] async def get_user_by_id(self, id: str|int, options: UserStorageGetOptions = {}) -> UserAndSecrets: if (type(id) == str): return await self.get_user_by("username", id, options) else: return await self.get_user_by("username", str(id), options)
[docs] async def get_user_by_email(self, email: str, options: UserStorageGetOptions = {}) -> UserAndSecrets: return await self.get_user_by("email", email, options)
[docs] async def create_user(self, user: UserInputFields, secrets: Optional[UserSecretsInputFields] = None) -> User: new_user : User = { **user, "id": user["username"], } if (self._normalize_username): new_user["username_normalized"] = UserStorage.normalize(user["username"]) new_secrets : UserSecrets = { "userid": new_user["id"] } if (secrets is not None): new_secrets : UserSecrets = { **secrets, "userid": new_user["id"] } if (self._normalize_username and "username_normalized" in new_user and new_user["username_normalized"] in self.__users_by_username): raise CrossauthError(ErrorCode.UserExists) if (not self._normalize_username and new_user["username"] in self.__users_by_username): raise CrossauthError(ErrorCode.UserExists) if ("email" in new_user and self._normalize_email): new_user["email_normalized"] = UserStorage.normalize(new_user["email"]) if (new_user["email_normalized"] in self.__users_by_email): raise CrossauthError(ErrorCode.UserExists) username_field = "username_normalized" if self._normalize_username else "username" email_field = "email_normalized" if self._normalize_email else "email" if (username_field not in new_user): raise CrossauthError(ErrorCode.Configuration, username_field + " not in user") self.__users_by_username[new_user[username_field]] = new_user # type: ignore self.__secrets_by_username[new_user[username_field]] = new_secrets # type: ignore if ("email" in new_user and email_field not in new_user): raise CrossauthError(ErrorCode.Configuration, email_field + " not in user") if (email_field in new_user): self.__users_by_email[new_user[email_field]] = new_user # type: ignore self.__secrets_by_email[new_user[email_field]] = new_secrets # type: ignore return new_user
[docs] async def delete_user_by_username(self, username: str) -> None: username_normalized = UserStorage.normalize(username) if self._normalize_username else username username_field = "username_normalized" if self._normalize_username else "username" email_field = "email_normalized" if self._normalize_username else "email" if username_normalized in self.__users_by_username: user = self.__users_by_username[username_normalized] if (username_field in user): email = user[email_field] # type: ignore del self.__users_by_email[email] del self.__secrets_by_email[email] del self.__users_by_username[username_normalized] del self.__secrets_by_username[username_normalized]
[docs] async def delete_user_by_id(self, id: str|int) -> None: if (type(id) == str): await self.delete_user_by_username(id) else: await self.delete_user_by_username(str(id))
[docs] async def get_users(self, skip: Optional[int] = None, take: Optional[int] = None) -> List[User]: keys : List[str] = list(self.__users_by_username.keys()) keys.sort() users : List[User] = [] if (skip is None): skip = 0 last = take if (take is not None) else len(keys) if (last >= len(keys)-skip): last = len(keys)-skip for i in range(skip, last): users.append(self.__users_by_username[keys[i]]) return users
[docs] async def update_user(self, user: PartialUser, secrets: Optional[PartialUserSecrets] = None) -> None: new_user : PartialUser = {**user} username_field = "username_normalized" if self._normalize_username else "username" if ("username" in new_user and self._normalize_username): new_user["username_normalized"] = UserStorage.normalize(new_user["username"]) elif ("id" in new_user): id = new_user["id"] idstr = id if (type(id) == str) else str(id) if (self._normalize_username): new_user["username_normalized"] = UserStorage.normalize(idstr) if ("email" in new_user and self._normalize_email): new_user["email_normalized"] = UserStorage.normalize(new_user["email"]) if (username_field in new_user and new_user[username_field] in self.__users_by_username): # type: ignore stored_user = self.__users_by_username[new_user[username_field]] # type: ignore for field in user: stored_user[field] = user[field] if (secrets is not None and username_field in new_user and new_user[username_field] in self.__secrets_by_username): # type: ignore stored_secrets = self.__secrets_by_username[new_user[username_field]] # type: ignore for field in secrets: stored_secrets[field] = secrets[field]
########################### # OAuthClientStorage
[docs] class InMemoryOAuthClientStorage(OAuthClientStorage): """ Implementation of :class:`KeyStorage` where keys stored in memory. Intended for testing. """ def __init__(self, options : OAuthClientStorageOptions = {}): super().__init__(options) self._clients_by_id : Dict[str, OAuthClient] = {} self._clients_by_name : Dict[str, List[OAuthClient]] = {}
[docs] async def get_client_by_id(self, client_id: str) -> OAuthClient: return self._clients_by_id[client_id]
[docs] async def get_client_by_name(self, name: str, userid: str|int|None|NullType = None) -> List[OAuthClient]: if (name not in self._clients_by_name): return [] all_by_name = self._clients_by_name[name] ret : List[OAuthClient] = [] for client in all_by_name: if (userid is None): ret.append(client) elif (userid == Null): if ("userid" in client and client["userid"] is None): ret.append(client) else: if ("userid" in client and client["userid"] == userid): ret.append(client) return ret
[docs] async def get_clients(self, skip: Optional[int] = None, take: Optional[int] = None, userid: str|int|None|NullType = None) -> List[OAuthClient]: client_ids = list(self._clients_by_id) client_ids.sort() if skip is None: skip = 0 if (take is None): take = 20 ret : List[OAuthClient] = [] skipped = 0 matches = False for client_id in client_ids: client = self._clients_by_id[client_id] matches = False if (userid is None): matches = True elif (userid == Null and "userid" in client and client["userid"] == None): matches = True elif ("userid" in client and client["userid"] == userid): matches = True if (matches): if (skipped >= skip): ret.append(client) if (len(ret) >= take): break else: skipped += 1 return ret
[docs] async def create_client(self, client: OAuthClient) -> OAuthClient: new_client : OAuthClient = cast(OAuthClient, { "client_secret": None, "userid": None, **client, }) self._clients_by_id[client["client_id"]] = new_client if (client["client_name"] not in self._clients_by_name): self._clients_by_name[client["client_name"]] = [new_client] else: l = self._clients_by_name[client["client_name"]] l.append(new_client) return new_client
[docs] async def update_client(self, client: PartialOAuthClient) -> None: if ("client_id" not in client): raise CrossauthError(ErrorCode.InvalidClientId, "Must give client_id when updating client") existing_client = self._clients_by_id[client["client_id"]] for key in client.keys(): existing_client[key] = client[key]
[docs] async def delete_client(self, client_id: str) -> None: if (client_id not in self._clients_by_id): return client = self._clients_by_id[client_id] if (client["client_name"] in self._clients_by_name): client_list = self._clients_by_name[client["client_name"]] for i in range(len(client_list)): if client_list[i]["client_id"] == client_id: del client_list[i] break del self._clients_by_id[client_id]
########################### # OAuthAuthorizationStorage
[docs] class InMemoryOAuthAuthorizationStorage(OAuthAuthorizationStorage): """ Implementation of :class:`KeyStorage` where keys stored in memory. Intended for testing. """ def __init__(self, options : OAuthAuthorizationStorageOptions = {}): super().__init__(options) self._scopes_by_client_id : Dict[str, Dict[str, List[str|None]]] = {}
[docs] async def get_authorizations(self, client_id: str, userid: str|int|None = None) -> List[Optional[str]]: if (client_id not in self._scopes_by_client_id): return [] all_scopes = self._scopes_by_client_id[client_id] ret : List[str|None] = [] if (userid is None): userid = "" else: userid = str(userid) if (userid in all_scopes): scopes = all_scopes[userid] ret = scopes return ret
[docs] async def update_authorizations(self, client_id: str, userid: str|int|None, authorizations: List[str|None]) -> None: if (userid is None): userid = "" by_client_id : Dict[str, List[str|None]] = {} if (client_id not in self._scopes_by_client_id): self._scopes_by_client_id[client_id] = by_client_id else: by_client_id = self._scopes_by_client_id[client_id] by_client_id[str(userid)] = authorizations