diff --git a/CHANGELOG.md b/CHANGELOG.md index f36014b..b081a9d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,26 +1,32 @@ # Heksher Changelog -## Next (REQ alembic upgrade) +## 0.5.0 (REQ alembic upgrade) ### Removed -* old api endpoint POST /api/v1/rules/query has been removed and replaced with GET /api/v1/rules/query +* old api endpoint POST /api/v1/rules/query has been removed and replaced with GET /api/v1/query ### Changed * the rename api endpoint has been changed to PUT /api/v1//name. * the method of the endpoint /api/v1/rules/search has been changed to GET. * All setting now must have a default value. * Setting declarations are now versioned. +* `HEKSHER_STARTUP_CONTEXT_FEATURES` is now optional. +* The inputs value for add_rule, value for patch_rule, value for put rule metadata key, value for + put setting metadata key, are now required. ### Deprecated * The api endpoint PATCH /api/v1/rules/ to change a rule's value is now deprecated, new users should use PUT /api/v1/rules//value ### Added * declarations are now tolerant of subtypes (to account for previous type upgrade) * documentation +* Added endpoint PUT /api/v1/settings//configurable_features * The api endpoint PUT /api/v1/rules//value to change a rule's value -* The api endpoint GET /api/v1/rules/query to query rules (replaces the old query endpoint) +* The api endpoint GET /api/v1/query to query rules (replaces the old query endpoint) * POST /api/v1/rules now returns the rule location in the header ### Fixed * A bug where patching a context feature's index using "to_before" would use the incorrect target. ### Internal * a new script to test and correctly report coverage * tools/mk_revision.py to easily create alembic revisions +* all db logic refactored to avoid multiple connections +* Many more column are now strictly non-nullable ## 0.4.1 ### Removed * removed the alembic extra, it's now a requirement diff --git a/alembic/versions/83f0f272313f_added_setting_versioning_default_is_now_.py b/alembic/versions/83f0f272313f_added_setting_versioning_default_is_now_.py new file mode 100644 index 0000000..5586bee --- /dev/null +++ b/alembic/versions/83f0f272313f_added_setting_versioning_default_is_now_.py @@ -0,0 +1,86 @@ +""""added setting versioning + default is now required" + +Revision ID: 83f0f272313f +Revises: 7713520cb02b +Create Date: 2022-01-04 14:04:46.416415 + +""" +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = '83f0f272313f' +down_revision = '7713520cb02b' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('conditions', 'rule', + existing_type=sa.INTEGER(), + nullable=False) + op.alter_column('conditions', 'context_feature', + existing_type=sa.VARCHAR(), + nullable=False) + op.alter_column('configurable', 'setting', + existing_type=sa.VARCHAR(), + nullable=False) + op.alter_column('configurable', 'context_feature', + existing_type=sa.VARCHAR(), + nullable=False) + op.alter_column('rule_metadata', 'rule', + existing_type=sa.INTEGER(), + nullable=False) + op.alter_column('rules', 'setting', + existing_type=sa.VARCHAR(), + nullable=False) + op.alter_column('setting_aliases', 'setting', + existing_type=sa.VARCHAR(), + nullable=False) + op.alter_column('setting_metadata', 'setting', + existing_type=sa.VARCHAR(), + nullable=False) + op.add_column('settings', sa.Column('version', sa.String(), nullable=True)) + op.execute("UPDATE settings SET version = '1.0'") + op.alter_column('settings', 'default_value', + existing_type=sa.VARCHAR(), + nullable=False) + op.alter_column('settings', 'default_value', + existing_type=sa.VARCHAR(), + nullable=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('settings', 'default_value', + existing_type=sa.VARCHAR(), + nullable=True) + op.drop_column('settings', 'version') + op.alter_column('setting_metadata', 'setting', + existing_type=sa.VARCHAR(), + nullable=True) + op.alter_column('setting_aliases', 'setting', + existing_type=sa.VARCHAR(), + nullable=True) + op.alter_column('rules', 'setting', + existing_type=sa.VARCHAR(), + nullable=True) + op.alter_column('rule_metadata', 'rule', + existing_type=sa.INTEGER(), + nullable=True) + op.alter_column('configurable', 'context_feature', + existing_type=sa.VARCHAR(), + nullable=True) + op.alter_column('configurable', 'setting', + existing_type=sa.VARCHAR(), + nullable=True) + op.alter_column('conditions', 'context_feature', + existing_type=sa.VARCHAR(), + nullable=True) + op.alter_column('conditions', 'rule', + existing_type=sa.INTEGER(), + nullable=True) + # ### end Alembic commands ### diff --git a/docs/api.rst b/docs/api.rst index 45f9df4..73d1316 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -10,7 +10,7 @@ Unless otherwise noted, all responses have the status code 200. Since Heksher is a FastAPI service, the API can also be accessed via the redoc endpoint ``/redoc``. The most common endpoints for users are :ref:`setting declaration `, -and :ref:`rule querying ` +and :ref:`querying ` General ------- @@ -36,6 +36,69 @@ Heksher server. Rather than checking the database connection on every call, heksher performs an automatic health check every 5 seconds. Therefore, all health responses may be up to 5 seconds out of date. +GET /api/v1/query +************************** + +.. note:: + + This should be the primary endpoint that users call to get rule and setting default values. + +.. note:: + + This endpoint supports the + `If-None-Match `_ header. It also returns + an `ETag `_ header on successful responses. + +Query the rules in the service, filtering in only rules pertaining to specific settings and contexts. + +Query parameters: + +* **settings**: A comma-seperated list of the names of the settings to query. If specified only rules that apply to one + of the settings in this list will be returned. Example: ``../api/v1/query?settings=foo,bar`` +* **context_filters**: A comma-seperated list of filters to filter rules by their context. If any filters are specified + only rules all of whose exact-match context conditions match the relevant filters will be returned. Each filter is + is a colon-separated pair. The first element of the pair is the context feature name, the second element is either + the special character ``*`` to accept all values of the context feature, or a comma-seperated list of the values + in parentheses. Example: ``../api/v1/query?context_filters=foo:*,bar:(a,b)``. Alternatively, the context_filters + can be the special character ``*`` to accept all context features (this is the default behaviour). + + .. note:: Context Filter Example + + Assuming a setting has the context features ``X``, ``Y``, and ``Z``, and the following rules: + + .. csv-table:: + :header: "X", "Y", "Z", "**rule_id**" + + "x_0", "\*", "\*", "1" + "x_1", "\*", "\*", "2" + "x_0", "y_0", "\*", "3" + "x_0", "y_1", "\*", "4" + "x_2", "y_0", "\*", "5" + "\*", "\*", "z_0", "6" + "x_0", "\*", "z_0", "7" + + The the context filter: ``X:(x_0,x_1),Y:*`` will only allow the rules ``1``, ``2``, ``3``, and ``4``. Rule ``5`` will + be rejected because it's X condition is not in the X filter's list of values. Rules ``6`` and ``7`` will be rejected + because they have a Z condition and there is no Z filter. + +* **include_metadata** (default false): If true, then the metadata associated with each rule will be included in + the results. + +Response: + +* **settings**: A dictionary that maps setting names to query results of that setting and pass the filters in the + request. Each value is a dictionary with the following keys: + + * **rules**: A list of rule dictionaries, that contains all teh rules that met the query criteria. Each rule + dictionary has the following keys: + + * **value**: The value a setting should take if the rule is matched. + * **feature_values**: An array of 2-str-arrays of the context feature names and values that the rule applies to, in order + of the context features. + * **metadata**: A dictionary of metadata associated with the rule. Only present if include_metadata is true. + + * **default_value**: The default value of the setting. + Context Features ----------------- @@ -157,67 +220,6 @@ PATCH /api/v1/rules/ A deprecated route that is equivalent to `PUT /api/v1/rules//value`_. -GET /api/v1/rules/query -************************** - -.. note:: - - This should be the primary endpoint that users call to get rules. - -.. note:: - - This endpoint supports the - `If-None-Match `_ header. It also returns - an `ETag `_ header on successful responses. - -Query the rules in the service, filtering in only rules pertaining to specific settings and contexts. - -Query parameters: - -* **settings**: A comma-seperated list of the names of the settings to query. If specified only rules that apply to one - of the settings in this list will be returned. Example: ``../api/v1/rules/query?settings=foo,bar`` -* **context_filters**: A comma-seperated list of filters to filter rules by their context. If any filters are specified - only rules all of whose exact-match context conditions match the relevant filters will be returned. Each filter is - is a colon-separated pair. The first element of the pair is the context feature name, the second element is either - the special character ``*`` to accept all values of the context feature, or a comma-seperated list of the values - in parentheses. Example: ``../api/v1/rules/query?context_filters=foo:*,bar:(a,b)``. Alternatively, the context_filters - can be the special character ``*`` to accept all context features (this is the default behaviour). - - .. note:: Context Filter Example - - Assuming a setting has the context features ``X``, ``Y``, and ``Z``, and the following rules: - - .. csv-table:: - :header: "X", "Y", "Z", "**rule_id**" - - "x_0", "\*", "\*", "1" - "x_1", "\*", "\*", "2" - "x_0", "y_0", "\*", "3" - "x_0", "y_1", "\*", "4" - "x_2", "y_0", "\*", "5" - "\*", "\*", "z_0", "6" - "x_0", "\*", "z_0", "7" - - The the context filter: ``X:(x_0,x_1),Y:*`` will only allow the rules ``1``, ``2``, ``3``, and ``4``. Rule ``5`` will - be rejected because it's X condition is not in the X filter's list of values. Rules ``6`` and ``7`` will be rejected - because they have a Z condition and there is no Z filter. - -* **include_metadata** (default false): If true, then the metadata associated with each rule will be included in - the results. - -Response: - -* **settings**: A dictionary that maps setting names to query results of that setting and pass the filters in the - request. Each rule is a dictionary with the following keys: - - * **rules**: A list of rule dictionaries, that contains all teh rules that met the query criteria. Each rule - dictionary has the following keys: - - * **value**: The value a setting should take if the rule is matched. - * **feature_values**: An array of 2-str-arrays of the context feature names and values that the rule applies to, in order - of the context features. - * **metadata**: A dictionary of metadata associated with the rule. Only present if include_metadata is true. - GET /api/v1/rules/ *************************** @@ -300,7 +302,7 @@ POST /api/v1/settings/declare This is the primary endpoint that users call to create and assert the state of settings. Declare that a setting will be used by a service. This endpoint can be used to create new settings or change attributes -of existing settings (while retaining compatibility). +of existing settings (while retaining compatibility, see :ref:`setting_versions:Setting Versions`). Request: @@ -309,36 +311,49 @@ Request: * **type**: The type of the setting. (see :ref:`setting_types:Setting Types`) * **default_value** (optional): The default value of the setting. * **metadata** (optional): A dictionary of metadata associated with the setting. -* **alias** (optional): An alias of the setting. +* **alias** (optional): An alias of the setting. Must either be an existing alias of the setting, or a canonical name of + an existing setting. +* **version** (optional): The version of the setting declaration, defaults to "1.0". + Response: -* **created**: True if the setting was created, false if it already existed. -* **changed**: An array of strings that describe the attributes of the setting that changed due to the declaration. -* **incomplete**: An dictionary describes the attributes of the setting were declared in an incomplete manner. The - dictionary maps attribute names to their complete values. +* **outcome**: one of the following values: + * ``"created"``: The setting was newly created. + * ``"uptodate"``: The setting declaration matches the latest declaration. + * ``"upgraded"``: The setting's attributes were changed to reflect this new declaration. + * ``"outdated"``: This declaration is superseded by a newer declaration. It is up to the user whether to proceed. + * ``"rejected"``: The setting's attributes were not changed due to an incompatible difference with the newer + version. In this case, the response code will be 409. + * ``"mismatch"``: the setting's declaration is not compatible with the current version of the service. In this + case, the response code will be 409. +* **latest_version**: The latest version of the setting declaration. Only present for ``"outdated"`` outcomes. +* **previous_version**: The previous version of the setting declaration. Only present for ``"upgraded"`` and + ``"rejected"`` outcomes. +* **differences**: A list of differences between the request declaration and the latest declaration. Only present for + ``"outdated"``, ``"upgraded"``, ``"rejected"``, and ``"mismatch"`` outcomes. Each difference is a dictionary with the + following possible keys: + + * **level**: one of the following values: + + * ``"minor"``: The difference is fully backwards compatible with previous declarations (of the same major version). + * ``"major"``: The difference is incompatible with previous declarations. + * ``"mismatch"``: The difference cannot be implemented because it would break internal logic. + + * **attribute**: The name of the attribute that is different. Either this key or the "message" key exists. + * **latest_value**: The value of the attribute in the latest declaration. Either this key or the "message" key + exists. + * **message**: A human-readable description of the difference. + + .. note:: + If the outcome is "outdated", then all the differences will be in the sense of the differences that occurred + since that declaration. Meaning that if the declaration request has one more configurable feature than the + latest declaration, then the change will have a level of "minor". + If there is a difference between the setting's declared and actual values that cannot be consolidated, a 409 response will be returned. -Heksher will attempt to consolidate the following differences, if they exist: - -* If the declaration contains configurable_features that do not exist in the setting, they will be added to the setting. - - * If the declaration does not contains configurable_features that do exist in the setting, they will **not** be removed - from the setting, the complete value will be indicated in the response. - -* If the type declared is a supertype of the actual type, the actual type will be updated to the declared type. - - * If the type declared is a subtype of the actual type, the complete value will be indicated in the response. - -* If the default value declared is different from the actual default value, the actual default value will be updated to - the declared default value. -* If the metadata declared is different from the actual metadata, the actual metadata will be changed to the declared - metadata. -* If the alias refers to an existing setting, and the name is not an existing setting. Then the old setting (under - alias) will be renamed to the new name, and the old name will be added as an alias to it. - DELETE /api/v1/settings/ ****************************** @@ -359,6 +374,7 @@ Response: * **default_value**: The default value of the setting. * **metadata**: A dictionary of metadata associated with the setting. * **aliases**: A list aliases of the setting. +* **version**: The version of the latest setting declaration. GET /api/v1/settings ********************** @@ -375,10 +391,11 @@ Response: * **settings**: A list of dictionaries describing each setting. Each element of the list is of the schema: * **name**: The name of the setting. + * **type**: The type of the setting. + * **default_value**: The default value of the setting. + * **version**: The version of the latest setting declaration. * **configurable_features**: A list of context feature names that the setting will be configurable with. Only included if include_additional_data is true. - * **type**: The type of the setting. Only included if include_additional_data is true. - * **default_value**: The default value of the setting. Only included if include_additional_data is true. * **metadata**: A dictionary of metadata associated with the setting. Only included if include_additional_data is true. * **aliases**: A list aliases of the setting. Only included if include_additional_data is true. @@ -390,11 +407,12 @@ Change a setting's type in a way that is not necessarily backwards compatible. Request: * **type**: The new type of the setting. +* **version**: The version of the setting declaration. The type will only be changed if the default value of the setting and the values of a all the rules of the setting are compatible with the new type. If this the case, an empty 204 response will be returned. -Other wise, the 409 response will have the schema: +If there are type conflicts, the 409 response will have the schema: * **conflicts**: A list of strings describing the conflicts. @@ -406,11 +424,25 @@ Rename a setting. Request: * **name**: The new name of the setting. +* **version**: The version of the setting declaration. The name will only be changed if the name is not already in use. If this the case, the old name will be added as an alias to the setting and an empty 204 response will be returned. -If the new name is already in use, the 409 response will be returned. +If the new name is already in use, or if the version is incompatible with the latest declaration, a 409 response will +be returned. + +PUT /api/v1/settings/setting_name>/configurable_features +*********************************************************** + +Change the configurable features of a setting. + +Request: + +* **configurable_features**: A list of context feature names that the setting will be configurable with. +* **version**: The version of the setting declaration. + +Response is an empty 204 response. POST /api/v1/settings//metadata ************************************************ @@ -420,6 +452,7 @@ Update a setting's metadata. This will not delete existing keys, but might overw Request: * **metadata**: A dictionary of metadata to associate with the setting. +* **version**: The version of the setting declaration. Response is an empty 204 response. @@ -431,6 +464,7 @@ Set a setting's metadata. This will overwrite any existing metadata. Request: * **metadata**: A dictionary of metadata to associate with the setting. +* **version**: The version of the setting declaration. Response is an empty 204 response. @@ -440,6 +474,10 @@ DELETE /api/v1/settings//metadata Remove all metadata associated with a setting. This is equivalent to calling `PUT /api/v1/settings//metadata`_ with an empty dictionary. +Request: + +* **version**: The version of the setting declaration. + Response is an empty 204 response. @@ -460,6 +498,7 @@ Set the value of a key in a setting's metadata. Request: * **value**: The value to associate with the key. +* **version**: The version of the setting declaration. Response is an empty 204 response. @@ -468,4 +507,8 @@ DELETE /api/v1/settings//metadata/ Remove a key from a setting's metadata. +Request: + +* **version**: The version of the setting declaration. + Response is an empty 204 response. \ No newline at end of file diff --git a/docs/concepts.rst b/docs/concepts.rst index af01bd2..794c8cf 100644 --- a/docs/concepts.rst +++ b/docs/concepts.rst @@ -13,7 +13,8 @@ attributes: * type: the :ref:`types ` of the setting. * metadata: additional metadata about the setting. * aliases: a list of :ref:`aliases ` of the setting. -* default: the default value of the setting. Optional. +* default: the default value of the setting. +* version: the latest version of the setting declaration, see :ref:`setting_versioning:Setting Versioning`. Context Features ----------------------- diff --git a/docs/index.rst b/docs/index.rst index fa0d679..4e7de24 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -9,8 +9,8 @@ Welcome to the Heksher documentation! concepts running setting_types - api setting_versions + api setting_aliases cookbook libraries diff --git a/docs/running.rst b/docs/running.rst index b9f65f7..101b9a2 100644 --- a/docs/running.rst +++ b/docs/running.rst @@ -34,9 +34,9 @@ Additional Environment Variables ------------------------------------------- * **HEKSHER_DB_CONNECTION_STRING**: The database connection string (see `above `_). -* **HEKSHER_STARTUP_CONTEXT_FEATURES**: A semicolon-delimited list of :ref:`concepts:context features`, in order. - Heksher will adapt the database's existing context features to this list (or raise an error if it cannot). For - example: ``user;trust;theme`` +* **HEKSHER_STARTUP_CONTEXT_FEATURES** (optional): A semicolon-delimited list of :ref:`concepts:context features`, in + order. If present, Heksher will adapt the database's existing context features to this list (or raise an error if it + cannot). For example: ``user;trust;theme`` The following environment variables are optional for logging: diff --git a/docs/setting_versions.rst b/docs/setting_versions.rst index b90962f..0e99d76 100644 --- a/docs/setting_versions.rst +++ b/docs/setting_versions.rst @@ -50,4 +50,15 @@ the following changes: There are some changes that are never acceptable, as they would break the logic of the application. These are: * Changing a setting type to a value that does not accept the value of at least one rule of the setting. -* Removing configurable features that are matched by at least one rule of the setting. \ No newline at end of file +* Removing configurable features that are matched by at least one rule of the setting. + +Explicit Versioning +-------------------- + +For most use cases, upgrading a setting via the declaration API is sufficient. However, there are some changes that +might fail depending on the state of the ruleset of the service. If these conflicts are encountered with the declaration +API, our app might fail. To avoid this, these potentially conflicting changes can be made explicit with explicit API +calls. These API endpoints are: + +* :ref:`PUT /api/v1/settings/setting_name>/configurable_features` +* :ref:`PUT /api/v1/settings/setting_name>/type` \ No newline at end of file diff --git a/heksher/api/v1/__init__.py b/heksher/api/v1/__init__.py index 473cdaa..fa86712 100644 --- a/heksher/api/v1/__init__.py +++ b/heksher/api/v1/__init__.py @@ -1,4 +1,5 @@ import heksher.api.v1.context_features # noqa: F401 +import heksher.api.v1.query # noqa: F401 import heksher.api.v1.rules # noqa: F401 import heksher.api.v1.settings # noqa: F401 from heksher.api.v1.util import router diff --git a/heksher/api/v1/context_features.py b/heksher/api/v1/context_features.py index 4052f67..b5cbaf3 100644 --- a/heksher/api/v1/context_features.py +++ b/heksher/api/v1/context_features.py @@ -8,6 +8,10 @@ from heksher.api.v1.util import ORJSONModel, application, router as v1_router from heksher.api.v1.validation import ContextFeatureName from heksher.app import HeksherApp +from heksher.db_logic.context_feature import ( + db_add_context_feature_to_end, db_delete_context_feature, db_get_context_feature_index, db_get_context_features, + db_is_configurable_setting_from_context_features, db_move_after_context_feature +) router = APIRouter(prefix='/context_features') @@ -21,7 +25,9 @@ async def check_context_features(app: HeksherApp = application): """ Get a listing of all the context features, in their hierarchical order. """ - return GetContextFeaturesResponse(context_features=await app.db_logic.get_context_features()) + async with app.engine.connect() as conn: + cfs = await db_get_context_features(conn) + return GetContextFeaturesResponse(context_features=(name for (name, _) in cfs)) class GetContextFeatureResponse(ORJSONModel): @@ -33,7 +39,8 @@ async def get_context_feature(name: str, app: HeksherApp = application): """ Returns the index of the context feature; If it doesn't exists, returns status code 404. """ - index = await app.db_logic.get_context_feature_index(name) + async with app.engine.connect() as conn: + index = await db_get_context_feature_index(conn, name) if index is None: return Response(status_code=status.HTTP_404_NOT_FOUND) return GetContextFeatureResponse(index=index) @@ -50,13 +57,14 @@ async def delete_context_feature(name: str, app: HeksherApp = application): """ Deletes context feature. """ - if await app.db_logic.get_context_feature_index(name) is None: - return Response(status_code=status.HTTP_404_NOT_FOUND) - if await app.db_logic.is_configurable_setting_from_context_features(name): - # if there is setting configured to use the context feature, it can't be deleted - return PlainTextResponse("context feature can't be deleted, there is at least one setting configured by it", - status_code=status.HTTP_409_CONFLICT) - await app.db_logic.delete_context_feature(name) + async with app.engine.begin() as conn: + if await db_get_context_feature_index(conn, name) is None: + return Response(status_code=status.HTTP_404_NOT_FOUND) + if await db_is_configurable_setting_from_context_features(conn, name): + # if there is setting configured to use the context feature, it can't be deleted + return PlainTextResponse("context feature can't be deleted, there is at least one setting configured by it", + status_code=status.HTTP_409_CONFLICT) + await db_delete_context_feature(conn, name) class PatchAfterContextFeatureInput(ORJSONModel): @@ -83,15 +91,16 @@ async def patch_context_feature(name: str, input: Union[PatchAfterContextFeature """ Modify existing context feature's index """ - index_to_move = await app.db_logic.get_context_feature_index(name) - target_index = await app.db_logic.get_context_feature_index(input.target) - if index_to_move is None or target_index is None: - return Response(status_code=status.HTTP_404_NOT_FOUND) - if isinstance(input, PatchBeforeContextFeatureInput): - target_index -= 1 - if index_to_move == target_index: - return None - await app.db_logic.move_after_context_feature(index_to_move, target_index) + async with app.engine.begin() as conn: + index_to_move = await db_get_context_feature_index(conn, name) + target_index = await db_get_context_feature_index(conn, input.target) + if index_to_move is None or target_index is None: + return Response(status_code=status.HTTP_404_NOT_FOUND) + if isinstance(input, PatchBeforeContextFeatureInput): + target_index -= 1 + if index_to_move == target_index: + return None + await db_move_after_context_feature(conn, index_to_move, target_index) class AddContextFeatureInput(ORJSONModel): @@ -103,10 +112,11 @@ async def add_context_feature(input: AddContextFeatureInput, app: HeksherApp = a """ Add a context feature to the end of the context features. """ - existing_context_feature = await app.db_logic.get_context_feature_index(input.context_feature) - if existing_context_feature is not None: - return PlainTextResponse('context feature already exists', status_code=status.HTTP_409_CONFLICT) - await app.db_logic.add_context_feature(input.context_feature) + async with app.engine.begin() as conn: + existing_context_feature = await db_get_context_feature_index(conn, input.context_feature) + if existing_context_feature is not None: + return PlainTextResponse('context feature already exists', status_code=status.HTTP_409_CONFLICT) + await db_add_context_feature_to_end(conn, input.context_feature) v1_router.include_router(router) diff --git a/heksher/api/v1/query.py b/heksher/api/v1/query.py new file mode 100644 index 0000000..ee2ddbc --- /dev/null +++ b/heksher/api/v1/query.py @@ -0,0 +1,149 @@ +import re +from typing import Any, Dict, List, Optional, Tuple, Union + +from fastapi import Query +from pydantic import Field +from starlette import status +from starlette.requests import Request +from starlette.responses import PlainTextResponse + +from heksher.api.v1.util import ORJSONModel, PydanticResponse, application, handle_etag, router as v1_router +from heksher.api.v1.validation import MetadataKey +from heksher.app import HeksherApp +from heksher.db_logic.context_feature import db_get_not_found_context_features +from heksher.db_logic.rule import db_query_rules +from heksher.db_logic.setting import db_get_canonical_names, db_get_settings + + +# https://github.com/tiangolo/fastapi/issues/2724 +class QueryRulesOutput_Rule(ORJSONModel): + value: Any = Field(description="the value of the setting in contexts where the rule matches") + context_features: List[Tuple[str, str]] = Field( + description="a list of exact-match conditions for the rule, in hierarchical order" + ) + rule_id: int = Field(description="unique identifier of the rule.") + + +class QueryRulesOutput_Setting(ORJSONModel): + default_value: Any = Field(description="the default value of the setting") + rules: List[QueryRulesOutput_Rule] = Field(description="a list of rules for the setting") + + +class QueryRulesOutput(ORJSONModel): + settings: Dict[str, QueryRulesOutput_Setting] = Field(description="query results for each setting") + + +class QueryRulesOutputWithMetadata_Rule(QueryRulesOutput_Rule): + metadata: Dict[MetadataKey, Any] = Field(description="the metadata of the rule, if requested") + + +class QueryRulesOutputWithMetadata_Setting(ORJSONModel): + default_value: Any = Field(description="the default value of the setting") + rules: List[QueryRulesOutputWithMetadata_Rule] = Field(description="a list of rules for the setting") + + +class QueryRulesOutputWithMetadata(ORJSONModel): + settings: Dict[str, QueryRulesOutputWithMetadata_Setting] = Field(description="query results for each setting") + + +raw_context_feature_filters_pattern = r'''(?x) +( + \* # we allow an explicit global wildcard + | + [a-zA-Z_0-9]+: # context feature name + ( + \* # accept any value + | + \( + [a-zA-Z_0-9]+ # value + (,[a-zA-Z_0-9]+)* # additional values + \) + ) + ( + ,[a-zA-Z_0-9]+: (\* | \([a-zA-Z_0-9]+ (,[a-zA-Z_0-9]+)*\)) # additional filters + )* +)? # we also allow empty string to signify no rules could match +$ +''' + + +@v1_router.get('/query', response_model=Union[QueryRulesOutputWithMetadata, QueryRulesOutput]) # type: ignore +async def query_rules(request: Request, app: HeksherApp = application, + raw_settings: str = Query(None, alias='settings', + description="a comma-separated list of setting names", + regex='([a-zA-Z_0-9.]+(,[a-zA-Z_0-9.]+)*)?$'), + raw_context_filters: str = Query( + '*', alias='context_filters', + description="a comma-separated list of context feature filters", + regex=raw_context_feature_filters_pattern, example=["a:(X,Y),b:(Z),c:*", '*', 'a:*']), + include_metadata: bool = Query(False, description="whether to include rule metadata in the" + " response"), + ): + async with app.engine.connect() as conn: + if raw_settings is None: + all_settings = await db_get_settings(conn, include_configurable_features=False, + include_aliases=False, + include_metadata=False) + settings = list(all_settings.keys()) + defaults = {k: v.default_value for (k, v) in all_settings.items()} + elif not raw_settings: + settings = [] + defaults = {} + else: + names = raw_settings.split(',') + aliases = await db_get_canonical_names(conn, names) + not_settings = [k for k, v in aliases.items() if not v] + if not_settings: + return PlainTextResponse(f'the following are not setting names: {not_settings}', + status_code=status.HTTP_404_NOT_FOUND) + settings = list(aliases.values()) + settings_data = await db_get_settings(conn, include_configurable_features=False, + include_aliases=False, + include_metadata=False, setting_names=settings) + defaults = {k: v.default_value for (k, v) in settings_data.items()} + + if raw_context_filters == '*': + context_features_options: Optional[Dict[str, Optional[List[str]]]] = None + else: + context_filter_items = ((match['key'], (None if match['values'] is None else match['values'].split(','))) + for match in + re.finditer(r'(?P[a-z]+):(\((?P[^)]+)\)|\*)', raw_context_filters)) + context_features_options = {} + for k, v in context_filter_items: + if k in context_features_options: + return PlainTextResponse(f'context name repeated in context filter: {k}', + status_code=status.HTTP_400_BAD_REQUEST) + context_features_options[k] = v + + not_context_features = await db_get_not_found_context_features(conn, context_features_options) + if not_context_features: + return PlainTextResponse(f'the following are not valid context features: {not_context_features}', + status_code=status.HTTP_404_NOT_FOUND) + results = await db_query_rules(conn, settings, context_features_options, include_metadata) + if include_metadata: + ret: Union[QueryRulesOutputWithMetadata, QueryRulesOutput] = QueryRulesOutputWithMetadata( + settings={ + setting: + QueryRulesOutputWithMetadata_Setting(rules=[ + QueryRulesOutputWithMetadata_Rule( + value=rule.value, context_features=rule.feature_values, metadata=rule.metadata, + rule_id=rule.rule_id + ) + for rule in rules + ], default_value=defaults[setting]) for setting, rules in results.items() + }) + else: + ret = QueryRulesOutput( + settings={ + setting: + QueryRulesOutput_Setting(rules=[ + QueryRulesOutput_Rule( + value=rule.value, context_features=rule.feature_values, metadata=rule.metadata, + rule_id=rule.rule_id + ) + for rule in rules + ], default_value=defaults[setting]) for setting, rules in results.items() + }) + response = PydanticResponse(ret) + handle_etag(response, request) + return response diff --git a/heksher/api/v1/rules.py b/heksher/api/v1/rules.py index cf3327c..2ad16f7 100644 --- a/heksher/api/v1/rules.py +++ b/heksher/api/v1/rules.py @@ -1,17 +1,17 @@ -import re from logging import getLogger -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Tuple from fastapi import APIRouter, Query, Response from pydantic import Field, validator from starlette import status -from starlette.requests import Request from starlette.responses import PlainTextResponse from heksher.api.v1.rules_metadata import router as metadata_router -from heksher.api.v1.util import ORJSONModel, PydanticResponse, application, handle_etag, router as v1_router +from heksher.api.v1.util import ORJSONModel, PydanticResponse, application, router as v1_router from heksher.api.v1.validation import ContextFeatureName, ContextFeatureValue, MetadataKey, SettingName from heksher.app import HeksherApp +from heksher.db_logic.rule import db_add_rule, db_delete_rule, db_get_rule, db_get_rule_id, db_patch_rule +from heksher.db_logic.setting import db_get_setting router = APIRouter(prefix='/rules') logger = getLogger(__name__) @@ -22,12 +22,13 @@ async def delete_rule(rule_id: int, app: HeksherApp = application): """ Remove a rule. """ - rule_spec = await app.db_logic.get_rule(rule_id, include_metadata=False) + async with app.engine.begin() as conn: + rule_spec = await db_get_rule(conn, rule_id, include_metadata=False) - if not rule_spec: - return Response(status_code=status.HTTP_404_NOT_FOUND) + if not rule_spec: + return Response(status_code=status.HTTP_404_NOT_FOUND) - await app.db_logic.delete_rule(rule_id) + await db_delete_rule(conn, rule_id) class SearchRuleOutput(ORJSONModel): @@ -47,15 +48,17 @@ async def search_rule(app: HeksherApp = application, """ Get the ID of a rule with specific conditions. """ - canon_setting = await app.db_logic.get_setting(setting, include_metadata=False, include_configurable_features=False, - include_aliases=False) # for aliasing - if not canon_setting: - return Response(status_code=status.HTTP_404_NOT_FOUND) - feature_values_dict: Dict[str, str] = dict(part.split(':') for part in feature_values.split(',')) # type: ignore - rule_id = await app.db_logic.get_rule_id(canon_setting.name, feature_values_dict) - if not rule_id: - return Response(status_code=status.HTTP_404_NOT_FOUND) - return SearchRuleOutput(rule_id=rule_id) + async with app.engine.connect() as conn: + canon_setting = await db_get_setting(conn, setting, include_metadata=False, include_configurable_features=False, + include_aliases=False) # for aliasing + if not canon_setting: + return Response(status_code=status.HTTP_404_NOT_FOUND) + feature_values_dict: Dict[str, str] = dict( + part.split(':') for part in feature_values.split(',')) # type: ignore + rule_id = await db_get_rule_id(conn, canon_setting.name, feature_values_dict) + if not rule_id: + return Response(status_code=status.HTTP_404_NOT_FOUND) + return SearchRuleOutput(rule_id=rule_id) class AddRuleInput(ORJSONModel): @@ -82,25 +85,26 @@ async def add_rule(input: AddRuleInput, app: HeksherApp = application): """ Add a rule, and get its ID. """ - setting = await app.db_logic.get_setting(input.setting, include_metadata=False, include_configurable_features=True, - include_aliases=False) - if not setting: - return PlainTextResponse(f'setting not found with name {input.setting}', - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) - not_configurable = input.feature_values.keys() - setting.configurable_features - if not_configurable: - return PlainTextResponse(f'setting is not configurable at context features {not_configurable}', - status_code=status.HTTP_400_BAD_REQUEST) - - if not setting.type.validate(input.value): - return PlainTextResponse(f'rule value is incompatible with setting type {setting.type}', - status_code=status.HTTP_400_BAD_REQUEST) - - existing_rule = await app.db_logic.get_rule_id(setting.name, input.feature_values) - if existing_rule: - return PlainTextResponse('rule already exists', status_code=status.HTTP_409_CONFLICT) - - new_id = await app.db_logic.add_rule(setting.name, input.value, input.metadata, input.feature_values) + async with app.engine.begin() as conn: + setting = await db_get_setting(conn, input.setting, include_metadata=False, include_configurable_features=True, + include_aliases=False) + if not setting: + return PlainTextResponse(f'setting not found with name {input.setting}', + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) + not_configurable = input.feature_values.keys() - setting.configurable_features + if not_configurable: + return PlainTextResponse(f'setting is not configurable at context features {not_configurable}', + status_code=status.HTTP_400_BAD_REQUEST) + + if not setting.type.validate(input.value): + return PlainTextResponse(f'rule value is incompatible with setting type {setting.type}', + status_code=status.HTTP_400_BAD_REQUEST) + + existing_rule = await db_get_rule_id(conn, setting.name, input.feature_values) + if existing_rule: + return PlainTextResponse('rule already exists', status_code=status.HTTP_409_CONFLICT) + + new_id = await db_add_rule(conn, setting.name, input.value, input.metadata, input.feature_values) return PydanticResponse(AddRuleOutput(rule_id=new_id), headers={'Location': f'/{new_id}'}, status_code=status.HTTP_201_CREATED) @@ -116,143 +120,20 @@ async def patch_rule(rule_id: int, input: PatchRuleInput, app: HeksherApp = appl """ Modify existing rule's value """ - rule = await app.db_logic.get_rule(rule_id, include_metadata=False) - if not rule: - return PlainTextResponse(status_code=status.HTTP_404_NOT_FOUND) + async with app.engine.begin() as conn: + rule = await db_get_rule(conn, rule_id, include_metadata=False) + if not rule: + return PlainTextResponse(status_code=status.HTTP_404_NOT_FOUND) - setting = await app.db_logic.get_setting(rule.setting, include_metadata=False, include_aliases=False, - include_configurable_features=False) - assert setting + setting = await db_get_setting(conn, rule.setting, include_metadata=False, include_aliases=False, + include_configurable_features=False) + assert setting - if not setting.type.validate(input.value): - return PlainTextResponse(f'rule value is incompatible with setting type {setting.type}', - status_code=status.HTTP_400_BAD_REQUEST) + if not setting.type.validate(input.value): + return PlainTextResponse(f'rule value is incompatible with setting type {setting.type}', + status_code=status.HTTP_400_BAD_REQUEST) - await app.db_logic.patch_rule(rule_id, input.value) - - -# https://github.com/tiangolo/fastapi/issues/2724 -class QueryRulesOutput_Rule(ORJSONModel): - value: Any = Field(description="the value of the setting in contexts where the rule matches") - context_features: List[Tuple[str, str]] = Field( - description="a list of exact-match conditions for the rule, in hierarchical order" - ) - rule_id: int = Field(description="unique identifier of the rule.") - - -class QueryRulesOutput_Setting(ORJSONModel): - rules: List[QueryRulesOutput_Rule] = Field(description="a list of rules for the setting") - - -class QueryRulesOutput(ORJSONModel): - settings: Dict[str, QueryRulesOutput_Setting] = Field(description="query results for each setting") - - -class QueryRulesOutputWithMetadata_Rule(QueryRulesOutput_Rule): - metadata: Dict[MetadataKey, Any] = Field(description="the metadata of the rule, if requested") - - -class QueryRulesOutputWithMetadata_Setting(ORJSONModel): - rules: List[QueryRulesOutputWithMetadata_Rule] = Field(description="a list of rules for the setting") - - -class QueryRulesOutputWithMetadata(ORJSONModel): - settings: Dict[str, QueryRulesOutputWithMetadata_Setting] = Field(description="query results for each setting") - - -raw_context_feature_filters_pattern = r'''(?x) -( - \* # we allow an explicit global wildcard - | - [a-zA-Z_0-9]+: # context feature name - ( - \* # accept any value - | - \( - [a-zA-Z_0-9]+ # value - (,[a-zA-Z_0-9]+)* # additional values - \) - ) - ( - ,[a-zA-Z_0-9]+: (\* | \([a-zA-Z_0-9]+ (,[a-zA-Z_0-9]+)*\)) # additional filters - )* -)? # we also allow empty string to signify no rules could match -$ -''' - - -@router.get('/query', response_model=Union[QueryRulesOutputWithMetadata, QueryRulesOutput]) # type: ignore -async def query_rules(request: Request, app: HeksherApp = application, - raw_settings: str = Query(None, alias='settings', - description="a comma-separated list of setting names", - regex='([a-zA-Z_0-9.]+(,[a-zA-Z_0-9.]+)*)?$'), - raw_context_filters: str = Query( - '*', alias='context_filters', - description="a comma-separated list of context feature filters", - regex=raw_context_feature_filters_pattern, example=["a:(X,Y),b:(Z),c:*", '*', 'a:*']), - include_metadata: bool = Query(False, description="whether to include rule metadata in the" - " response"), - ): - if raw_settings is None: - settings = [spec.name for spec in await app.db_logic.get_all_settings(include_configurable_features=False, - include_aliases=False, - include_metadata=False)] - elif not raw_settings: - settings = [] - else: - names = raw_settings.split(',') - aliases = await app.db_logic.get_canonical_names(names) - not_settings = [k for k, v in aliases.items() if not v] - if not_settings: - return PlainTextResponse(f'the following are not setting names: {not_settings}', - status_code=status.HTTP_404_NOT_FOUND) - settings = list(aliases.values()) - - if raw_context_filters == '*': - context_features_options: Optional[Dict[str, Optional[List[str]]]] = None - else: - context_filter_items = ((match['key'], (None if match['values'] is None else match['values'].split(','))) - for match in - re.finditer(r'(?P[a-z]+):(\((?P[^)]+)\)|\*)', raw_context_filters)) - context_features_options = {} - for k, v in context_filter_items: - if k in context_features_options: - return PlainTextResponse(f'context name repeated in context filter: {k}', - status_code=status.HTTP_400_BAD_REQUEST) - context_features_options[k] = v - - not_context_features = await app.db_logic.get_not_found_context_features(context_features_options) - if not_context_features: - return PlainTextResponse(f'the following are not valid context features: {not_context_features}', - status_code=status.HTTP_404_NOT_FOUND) - results = await app.db_logic.query_rules(settings, context_features_options, include_metadata) - if include_metadata: - ret: Union[QueryRulesOutputWithMetadata, QueryRulesOutput] = QueryRulesOutputWithMetadata( - settings={ - setting: - QueryRulesOutputWithMetadata_Setting(rules=[ - QueryRulesOutputWithMetadata_Rule( - value=rule.value, context_features=rule.feature_values, metadata=rule.metadata, - rule_id=rule.rule_id - ) - for rule in rules - ]) for setting, rules in results.items() - }) - else: - ret = QueryRulesOutput( - settings={ - setting: - QueryRulesOutput_Setting(rules=[ - QueryRulesOutput_Rule( - value=rule.value, context_features=rule.feature_values, metadata=rule.metadata, - rule_id=rule.rule_id - ) - for rule in rules - ]) for setting, rules in results.items() - }) - response = PydanticResponse(ret) - handle_etag(response, request) - return response + await db_patch_rule(conn, rule_id, input.value) class GetRuleOutput(ORJSONModel): @@ -264,13 +145,14 @@ class GetRuleOutput(ORJSONModel): @router.get('/{rule_id}', response_model=GetRuleOutput) async def get_rule(rule_id: int, app: HeksherApp = application): - rule_spec = await app.db_logic.get_rule(rule_id, include_metadata=True) + async with app.engine.connect() as conn: + rule_spec = await db_get_rule(conn, rule_id, include_metadata=True) - if not rule_spec: - return Response(status_code=status.HTTP_404_NOT_FOUND) + if not rule_spec: + return Response(status_code=status.HTTP_404_NOT_FOUND) - return GetRuleOutput(setting=rule_spec.setting, value=rule_spec.value, - feature_values=rule_spec.feature_values, metadata=rule_spec.metadata) + return GetRuleOutput(setting=rule_spec.setting, value=rule_spec.value, + feature_values=rule_spec.feature_values, metadata=rule_spec.metadata) router.include_router(metadata_router) diff --git a/heksher/api/v1/rules_metadata.py b/heksher/api/v1/rules_metadata.py index 00f85b2..80458f9 100644 --- a/heksher/api/v1/rules_metadata.py +++ b/heksher/api/v1/rules_metadata.py @@ -8,6 +8,11 @@ from heksher.api.v1.util import ORJSONModel, application from heksher.api.v1.validation import MetadataKey from heksher.app import HeksherApp +from heksher.db_logic.rule import db_get_rule +from heksher.db_logic.rule_metadata import ( + db_delete_rule_metadata, db_delete_rule_metadata_key, db_replace_rule_metadata, db_update_rule_metadata, + db_update_rule_metadata_key +) router = APIRouter() logger = getLogger(__name__) @@ -24,9 +29,10 @@ async def update_rule_metadata(rule_id: int, input: InputRuleMetadata, app: Heks """ if not input.metadata: return None - if not await app.db_logic.get_rule(rule_id, include_metadata=False): - return Response(status_code=status.HTTP_404_NOT_FOUND) - await app.db_logic.update_rule_metadata(rule_id, input.metadata) + async with app.engine.begin() as conn: + if not await db_get_rule(conn, rule_id, include_metadata=False): + return Response(status_code=status.HTTP_404_NOT_FOUND) + await db_update_rule_metadata(conn, rule_id, input.metadata) @router.put('/{rule_id}/metadata', status_code=status.HTTP_204_NO_CONTENT, response_class=Response) @@ -34,13 +40,14 @@ async def replace_rule_metadata(rule_id: int, input: InputRuleMetadata, app: Hek """ Change the current metadata of the rule. """ - if not await app.db_logic.get_rule(rule_id, include_metadata=False): - return Response(status_code=status.HTTP_404_NOT_FOUND) - if not input.metadata: - # empty dictionary equals to deleting the metadata - await app.db_logic.delete_rule_metadata(rule_id) - else: - await app.db_logic.replace_rule_metadata(rule_id, input.metadata) + async with app.engine.begin() as conn: + if not await db_get_rule(conn, rule_id, include_metadata=False): + return Response(status_code=status.HTTP_404_NOT_FOUND) + if not input.metadata: + # empty dictionary equals to deleting the metadata + await db_delete_rule_metadata(conn, rule_id) + else: + await db_replace_rule_metadata(conn, rule_id, input.metadata) class PutRuleMetadataKey(ORJSONModel): @@ -53,9 +60,10 @@ async def update_rule_metadata_key(rule_id: int, key: MetadataKey, input: PutRul """ Updates the current metadata of the rule. Existing keys won't be deleted. """ - if not await app.db_logic.get_rule(rule_id, include_metadata=False): - return Response(status_code=status.HTTP_404_NOT_FOUND) - await app.db_logic.update_rule_metadata_key(rule_id, key, input.value) + async with app.engine.begin() as conn: + if not await db_get_rule(conn, rule_id, include_metadata=False): + return Response(status_code=status.HTTP_404_NOT_FOUND) + await db_update_rule_metadata_key(conn, rule_id, key, input.value) @router.delete('/{rule_id}/metadata', status_code=status.HTTP_204_NO_CONTENT, response_class=Response) @@ -63,9 +71,10 @@ async def delete_rule_metadata(rule_id: int, app: HeksherApp = application): """ Delete a rule's metadata. """ - if not await app.db_logic.get_rule(rule_id, include_metadata=False): - return Response(status_code=status.HTTP_404_NOT_FOUND) - await app.db_logic.delete_rule_metadata(rule_id) + async with app.engine.begin() as conn: + if not await db_get_rule(conn, rule_id, include_metadata=False): + return Response(status_code=status.HTTP_404_NOT_FOUND) + await db_delete_rule_metadata(conn, rule_id) @router.delete('/{rule_id}/metadata/{key}', status_code=status.HTTP_204_NO_CONTENT, response_class=Response) @@ -73,9 +82,10 @@ async def delete_rule_key_from_metadata(rule_id: int, key: MetadataKey, app: Hek """ Delete a specific key from the rule's metadata. """ - if not await app.db_logic.get_rule(rule_id, include_metadata=False): - return Response(status_code=status.HTTP_404_NOT_FOUND) - await app.db_logic.delete_rule_metadata_key(rule_id, key) + async with app.engine.begin() as conn: + if not await db_get_rule(conn, rule_id, include_metadata=False): + return Response(status_code=status.HTTP_404_NOT_FOUND) + await db_delete_rule_metadata_key(conn, rule_id, key) class GetRuleMetadataOutput(ORJSONModel): @@ -93,6 +103,7 @@ async def get_rule_metadata(rule_id: int, app: HeksherApp = application): """ Get metadata of a rule. """ - if not (rule := await app.db_logic.get_rule(rule_id, include_metadata=True)): - return Response(status_code=status.HTTP_404_NOT_FOUND) - return GetRuleMetadataOutput(metadata=rule.metadata) + async with app.engine.connect() as conn: + if not (rule := await db_get_rule(conn, rule_id, include_metadata=True)): + return Response(status_code=status.HTTP_404_NOT_FOUND) + return GetRuleMetadataOutput(metadata=rule.metadata) diff --git a/heksher/api/v1/setting_declaration.py b/heksher/api/v1/setting_declaration.py index 79da690..75d3ae0 100644 --- a/heksher/api/v1/setting_declaration.py +++ b/heksher/api/v1/setting_declaration.py @@ -1,4 +1,3 @@ -from asyncio import gather from dataclasses import dataclass from itertools import chain from logging import getLogger @@ -13,7 +12,11 @@ from heksher.api.v1.util import ORJSONModel, PydanticResponse, application from heksher.api.v1.validation import ContextFeatureName, MetadataKey, SettingName, SettingVersion from heksher.app import HeksherApp -from heksher.db_logic.setting import SettingSpec +from heksher.db_logic.context_feature import db_get_not_found_context_features +from heksher.db_logic.rule import db_get_actual_configurable_features, db_get_rules_for_setting +from heksher.db_logic.setting import ( + SettingSpec, db_add_setting, db_get_canonical_names, db_get_setting, db_update_setting +) from heksher.db_logic.util import parse_setting_version from heksher.setting_types import SettingType @@ -115,181 +118,183 @@ class MessageDifference: Optional[Dict[str, Any]]] -async def declare_setting_endpoint(input: DeclareSettingInput, app: HeksherApp = application): +async def declare_setting(input: DeclareSettingInput, app: HeksherApp = application): """ Ensure that a setting exists, creating it if necessary. """ - existing = await app.db_logic.get_setting(input.name, include_aliases=True, include_metadata=True, - include_configurable_features=True) - if input.alias: - alias_canonical_name = (await app.db_logic.get_canonical_names([input.alias]))[input.alias] - if alias_canonical_name is None: - raise HTTPException(status_code=404, detail=f'alias {input.alias} does not exist') - # we only accept two options: either the setting does not exist and the alias is a canonical name, or it is an - # existing alias of the setting - if not ( - (not existing and input.alias == alias_canonical_name) - or (existing and existing.name == alias_canonical_name) - ): - raise HTTPException(status_code=409, detail=f'alias {input.alias} is an alias of unrelated setting ' - f'{alias_canonical_name}') - else: - alias_canonical_name = None - - if existing is None: - if alias_canonical_name: - existing = await app.db_logic.get_setting(alias_canonical_name, include_aliases=True, - include_metadata=True, - include_configurable_features=True) + async with app.engine.connect() as conn: # note that we might upgrade the connection to a transaction + existing = await db_get_setting(conn, input.name, include_aliases=True, include_metadata=True, + include_configurable_features=True) + if input.alias: + alias_canonical_name = (await db_get_canonical_names(conn, [input.alias]))[input.alias] + if alias_canonical_name is None: + raise HTTPException(status_code=404, detail=f'alias {input.alias} does not exist') + # we only accept two options: either the setting does not exist and the alias is a canonical name, or it is + # an existing alias of the setting + if not ( + (not existing and input.alias == alias_canonical_name) + or (existing and existing.name == alias_canonical_name) + ): + raise HTTPException(status_code=409, detail=f'alias {input.alias} is an alias of unrelated setting ' + f'{alias_canonical_name}') else: - if input.version != '1.0': - return PlainTextResponse('newly created settings must have version 1.0', status_code=400) - not_cf = await app.db_logic.get_not_found_context_features(input.configurable_features) - if not_cf: - return PlainTextResponse(f'{not_cf} are not acceptable context features', - status_code=status.HTTP_404_NOT_FOUND) - logger.info('creating new setting', extra={'setting_name': input.name}) - aliases = [input.alias] if input.alias else None - raw_default_value = str(orjson.dumps(input.default_value), 'utf-8') - spec = SettingSpec(input.name, str(input.type), raw_default_value, input.metadata, - input.configurable_features, aliases, input.version) - await app.db_logic.add_setting(spec) - return PydanticResponse(UpToDateDeclareSettingOutput(outcome='created')) - - async def get_diffs(is_outdated: bool) -> Tuple[NewSettingAttributes, DifferencesDict]: - differences: DifferencesDict = {k: [] for k in ['minor', 'major', 'mismatch']} # type: ignore[misc] - - # all the functions below handle different attributes of the setting. They append the differences to the dict, - # and return True if there are any differences. - - async def handle_cf_diff(is_outdated: bool) -> bool: - existing_setting_cfs = frozenset(existing.configurable_features) - new_setting_cfs = frozenset(input.configurable_features) - if existing_setting_cfs == new_setting_cfs: - return False - if is_outdated: - # all changes are flipped in direction - existing_setting_cfs, new_setting_cfs = new_setting_cfs, existing_setting_cfs - removed_cfs = existing_setting_cfs - new_setting_cfs - if not is_outdated and removed_cfs: - actual_cfs_in_use = await app.db_logic.get_actual_configurable_features(existing.name) - removed_cfs_in_use = removed_cfs & actual_cfs_in_use.keys() - if removed_cfs_in_use: - rule_ids = list(chain.from_iterable(actual_cfs_in_use[cf] for cf in removed_cfs_in_use)) - differences['mismatch'].append(MessageDifference( - f'configurable features {sorted(removed_cfs)} are still in use by rules {rule_ids}')) - return True - if existing_setting_cfs > new_setting_cfs: - differences['minor'].append( - MessageDifference(f'removal of configurable features {sorted(removed_cfs)}')) - else: - if not is_outdated: - not_cf = await app.db_logic.get_not_found_context_features(new_setting_cfs) - if not_cf: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, - detail=f'{not_cf} are not acceptable context features') - differences['major'].append(AttrDifference('configurable_features', existing.configurable_features)) - return True - - async def handle_type_diff(is_outdated: bool) -> bool: - if existing.type == input.type: - return False - existing_type = existing.type - new_type = input.type - if is_outdated: - # all changes are flipped in direction - existing_type, new_type = new_type, existing_type - if new_type < existing_type: - differences['minor'].append(AttrDifference('type', str(existing.type))) + alias_canonical_name = None + + if existing is None: + if alias_canonical_name: + existing = await db_get_setting(conn, alias_canonical_name, include_aliases=True, + include_metadata=True, + include_configurable_features=True) else: - # only check the rules if we're not outdated - rules = await app.db_logic.get_rules_for_setting(input.name) if not is_outdated else [] - mismatched_rule_ids = [rule.rule_id for rule in rules if not new_type.validate(rule.value)] - if mismatched_rule_ids: - differences['mismatch'].append( - MessageDifference(f'setting type incompatible with values for rules: ' - f'{sorted(mismatched_rule_ids)}')) + if input.version != '1.0': + return PlainTextResponse('newly created settings must have version 1.0', status_code=400) + not_cf = await db_get_not_found_context_features(conn, input.configurable_features) + if not_cf: + return PlainTextResponse(f'{not_cf} are not acceptable context features', + status_code=status.HTTP_404_NOT_FOUND) + logger.info('creating new setting', extra={'setting_name': input.name}) + aliases = [input.alias] if input.alias else None + raw_default_value = str(orjson.dumps(input.default_value), 'utf-8') + spec = SettingSpec(input.name, str(input.type), raw_default_value, input.metadata, + input.configurable_features, aliases, input.version) + async with app.engine.begin() as conn: + await db_add_setting(conn, spec) + return PydanticResponse(UpToDateDeclareSettingOutput(outcome='created')) + + async def get_diffs(is_outdated: bool) -> Tuple[NewSettingAttributes, DifferencesDict]: + differences: DifferencesDict = {k: [] for k in ['minor', 'major', 'mismatch']} # type: ignore[misc] + + # all the functions below handle different attributes of the setting. They append the differences to the + # dict and return True if there are any differences. + + async def handle_cf_diff(is_outdated: bool) -> bool: + existing_setting_cfs = frozenset(existing.configurable_features) + new_setting_cfs = frozenset(input.configurable_features) + if existing_setting_cfs == new_setting_cfs: + return False + if is_outdated: + # all changes are flipped in direction + existing_setting_cfs, new_setting_cfs = new_setting_cfs, existing_setting_cfs + removed_cfs = existing_setting_cfs - new_setting_cfs + if not is_outdated and removed_cfs: + actual_cfs_in_use = await db_get_actual_configurable_features(conn, existing.name) + removed_cfs_in_use = removed_cfs & actual_cfs_in_use.keys() + if removed_cfs_in_use: + rule_ids = list(chain.from_iterable(actual_cfs_in_use[cf] for cf in removed_cfs_in_use)) + differences['mismatch'].append(MessageDifference( + f'configurable features {sorted(removed_cfs)} are still in use by rules {rule_ids}')) + return True + if existing_setting_cfs > new_setting_cfs: + differences['minor'].append( + MessageDifference(f'removal of configurable features {sorted(removed_cfs)}')) else: - differences['major'].append(AttrDifference('type', str(existing.type))) - return True - - def handle_rename(is_outdated: bool) -> bool: - if input.name == existing.name: - # the names are equal, and we have a guarantee that the alias is already an alias of the setting - return False - if is_outdated: - differences['minor'].append(AttrDifference('name', existing.name)) + if not is_outdated: + not_cf = await db_get_not_found_context_features(conn, new_setting_cfs) + if not_cf: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, + detail=f'{not_cf} are not acceptable context features') + differences['major'].append(AttrDifference('configurable_features', existing.configurable_features)) return True - assert input.alias == existing.name - differences['minor'].append(AttrDifference('name', existing.name)) - return True - - def handle_default_value_diff(is_outdated: bool) -> bool: - if existing.default_value == input.default_value: - return False - differences['minor'].append(AttrDifference('default_value', existing.default_value)) - return True - - def handle_metadata_diff(is_outdated: bool) -> bool: - if existing.metadata == input.metadata: - return False - differences['minor'].append(AttrDifference('metadata', existing.metadata)) - return True - - is_cf_diff, is_type_diff = await gather(handle_cf_diff(is_outdated), handle_type_diff(is_outdated)) - return (input.configurable_features if is_cf_diff else None, - input.type if is_type_diff else None, - input.name if handle_rename(is_outdated) else None, - input.default_value if handle_default_value_diff(is_outdated) else None, - input.metadata if handle_metadata_diff(is_outdated) else None), differences - - def diff_list(differences: DifferencesDict) -> List[Union[MessageDifferenceOutput, AttributeDifferenceOutput]]: - diff_list: List[Union[MessageDifferenceOutput, AttributeDifferenceOutput]] = [] - for level, diffs in differences.items(): - for diff in diffs: - if isinstance(diff, AttrDifference): - diff_list.append(AttributeDifferenceOutput(level=level, attribute=diff.attribute, - latest_value=diff.latest_value)) + + async def handle_type_diff(is_outdated: bool) -> bool: + if existing.type == input.type: + return False + existing_type = existing.type + new_type = input.type + if is_outdated: + # all changes are flipped in direction + existing_type, new_type = new_type, existing_type + if new_type < existing_type: + differences['minor'].append(AttrDifference('type', str(existing.type))) else: - diff_list.append(MessageDifferenceOutput(level=level, message=diff.message)) - return diff_list - - if input.version == existing.version: - _, diffs = await get_diffs(False) - if any(v for v in diffs.values()): - return PydanticResponse(MismatchDeclareSettingOutput(differences=diff_list(diffs)), - status_code=status.HTTP_409_CONFLICT) - return PydanticResponse(UpToDateDeclareSettingOutput(outcome='uptodate')) - - input_version = parse_setting_version(input.version) - existing_version = parse_setting_version(existing.version) - - if input_version < existing_version: - _, diffs = await get_diffs(True) - return PydanticResponse(OutdatedDeclareSettingOutput(latest_version=existing.version, - differences=diff_list(diffs))) - - # now the user is definitely attempting an upgrade - (new_cfs, new_type, new_name, new_default_value, new_metadata), differences = await get_diffs(False) - if differences['mismatch']: - accepted = False - elif input_version[0] > existing_version[0]: - # we perform a major upgrade without mismatches - accepted = True - else: - assert input_version[1] > existing_version[1] - # we perform a minor upgrade, so long as there are no major differences - accepted = not differences['major'] - - if not accepted: - return PydanticResponse( - UpgradeDeclareSettingOutput(outcome='rejected', previous_version=existing.version, - differences=diff_list(differences)), - status_code=status.HTTP_409_CONFLICT) - await app.db_logic.update_setting(existing.name, new_name, new_cfs, new_type, new_default_value, - new_metadata, input.version) - return PydanticResponse(UpgradeDeclareSettingOutput(outcome='upgraded', previous_version=existing.version, - differences=diff_list(differences))) + # only check the rules if we're not outdated + rules = await db_get_rules_for_setting(conn, input.name) if not is_outdated else [] + mismatched_rule_ids = [rule.rule_id for rule in rules if not new_type.validate(rule.value)] + if mismatched_rule_ids: + differences['mismatch'].append( + MessageDifference(f'setting type incompatible with values for rules: ' + f'{sorted(mismatched_rule_ids)}')) + else: + differences['major'].append(AttrDifference('type', str(existing.type))) + return True + + def handle_rename(is_outdated: bool) -> bool: + if input.name == existing.name: + # the names are equal, and we have a guarantee that the alias is already an alias of the setting + return False + if is_outdated: + differences['minor'].append(AttrDifference('name', existing.name)) + return True + assert input.alias == existing.name + differences['minor'].append(AttrDifference('name', existing.name)) + return True + + def handle_default_value_diff(is_outdated: bool) -> bool: + if existing.default_value == input.default_value: + return False + differences['minor'].append(AttrDifference('default_value', existing.default_value)) + return True + + def handle_metadata_diff(is_outdated: bool) -> bool: + if existing.metadata == input.metadata: + return False + differences['minor'].append(AttrDifference('metadata', existing.metadata)) + return True + + return (input.configurable_features if await handle_cf_diff(is_outdated) else None, + input.type if await handle_type_diff(is_outdated) else None, + input.name if handle_rename(is_outdated) else None, + input.default_value if handle_default_value_diff(is_outdated) else None, + input.metadata if handle_metadata_diff(is_outdated) else None), differences + + def diff_list(differences: DifferencesDict) -> List[Union[MessageDifferenceOutput, AttributeDifferenceOutput]]: + diff_list: List[Union[MessageDifferenceOutput, AttributeDifferenceOutput]] = [] + for level, diffs in differences.items(): + for diff in diffs: + if isinstance(diff, AttrDifference): + diff_list.append(AttributeDifferenceOutput(level=level, attribute=diff.attribute, + latest_value=diff.latest_value)) + else: + diff_list.append(MessageDifferenceOutput(level=level, message=diff.message)) + return diff_list + + if input.version == existing.version: + _, diffs = await get_diffs(False) + if any(v for v in diffs.values()): + return PydanticResponse(MismatchDeclareSettingOutput(differences=diff_list(diffs)), + status_code=status.HTTP_409_CONFLICT) + return PydanticResponse(UpToDateDeclareSettingOutput(outcome='uptodate')) + + input_version = parse_setting_version(input.version) + existing_version = parse_setting_version(existing.version) + + if input_version < existing_version: + _, diffs = await get_diffs(True) + return PydanticResponse(OutdatedDeclareSettingOutput(latest_version=existing.version, + differences=diff_list(diffs))) + + # now the user is definitely attempting an upgrade + (new_cfs, new_type, new_name, new_default_value, new_metadata), differences = await get_diffs(False) + if differences['mismatch']: + accepted = False + elif input_version[0] > existing_version[0]: + # we perform a major upgrade without mismatches + accepted = True + else: + assert input_version[1] > existing_version[1] + # we perform a minor upgrade, so long as there are no major differences + accepted = not differences['major'] + + if not accepted: + return PydanticResponse( + UpgradeDeclareSettingOutput(outcome='rejected', previous_version=existing.version, + differences=diff_list(differences)), + status_code=status.HTTP_409_CONFLICT) + async with app.engine.begin() as conn: + await db_update_setting(conn, existing.name, new_name, new_cfs, new_type, new_default_value, + new_metadata, input.version) + return PydanticResponse(UpgradeDeclareSettingOutput(outcome='upgraded', previous_version=existing.version, + differences=diff_list(differences))) declare_setting_enpoint_args: Dict[str, Any] = dict( diff --git a/heksher/api/v1/settings.py b/heksher/api/v1/settings.py index b92fe8c..7ac45e0 100644 --- a/heksher/api/v1/settings.py +++ b/heksher/api/v1/settings.py @@ -7,12 +7,19 @@ from starlette import status from starlette.responses import JSONResponse, PlainTextResponse -from heksher.api.v1.setting_declaration import declare_setting_endpoint, declare_setting_enpoint_args +from heksher.api.v1.setting_declaration import declare_setting, declare_setting_enpoint_args from heksher.api.v1.settings_metadata import router as metadata_router from heksher.api.v1.util import ORJSONModel, application, router as v1_router from heksher.api.v1.validation import MetadataKey, SettingName, SettingVersion from heksher.app import HeksherApp -from heksher.db_logic.setting_configurable_features import set_settings_configurable_features +from heksher.db_logic.rule import ( + db_get_actual_configurable_features, db_get_rules_feature_values, db_get_rules_for_setting +) +from heksher.db_logic.setting import ( + db_bump_setting_version, db_delete_setting, db_get_canonical_names, db_get_setting, db_get_settings, + db_rename_setting, db_set_setting_type +) +from heksher.db_logic.setting_configurable_features import db_set_settings_configurable_features from heksher.db_logic.util import parse_setting_version from heksher.setting_types import SettingType @@ -20,7 +27,7 @@ logger = getLogger(__name__) -router.add_api_route('/declare', declare_setting_endpoint, **declare_setting_enpoint_args) +router.add_api_route('/declare', declare_setting, **declare_setting_enpoint_args) @router.delete('/{name}', status_code=status.HTTP_204_NO_CONTENT, response_class=Response) @@ -28,13 +35,14 @@ async def delete_setting(name: str, app: HeksherApp = application): """ Delete a setting. """ - setting = await app.db_logic.get_setting(name, include_metadata=False, include_aliases=False, - include_configurable_features=False) # for aliasing - if not setting: - return PlainTextResponse('setting name not found', status_code=status.HTTP_404_NOT_FOUND) - deleted = await app.db_logic.delete_setting(setting.name) - if not deleted: - return PlainTextResponse('setting name not found', status_code=status.HTTP_404_NOT_FOUND) + async with app.engine.begin() as conn: + setting = await db_get_setting(conn, name, include_metadata=False, include_aliases=False, + include_configurable_features=False) # for aliasing + if not setting: + return PlainTextResponse('setting name not found', status_code=status.HTTP_404_NOT_FOUND) + deleted = await db_delete_setting(conn, setting.name) + if not deleted: + return PlainTextResponse('setting name not found', status_code=status.HTTP_404_NOT_FOUND) class GetSettingOutput(ORJSONModel): @@ -58,13 +66,14 @@ async def get_setting(name: str, app: HeksherApp = application): """ Get details on a setting. """ - setting = await app.db_logic.get_setting(name, include_metadata=True, include_aliases=True, - include_configurable_features=True) - if not setting: - return PlainTextResponse(f'the setting {name} does not exist', status_code=status.HTTP_404_NOT_FOUND) - return GetSettingOutput(name=setting.name, configurable_features=setting.configurable_features, - type=setting.raw_type, default_value=setting.default_value, metadata=setting.metadata, - aliases=setting.aliases, version=setting.version) + async with app.engine.connect() as conn: + setting = await db_get_setting(conn, name, include_metadata=True, include_aliases=True, + include_configurable_features=True) + if not setting: + return PlainTextResponse(f'the setting {name} does not exist', status_code=status.HTTP_404_NOT_FOUND) + return GetSettingOutput(name=setting.name, configurable_features=setting.configurable_features, + type=setting.raw_type, default_value=setting.default_value, metadata=setting.metadata, + aliases=setting.aliases, version=setting.version) # https://github.com/tiangolo/fastapi/issues/2724 @@ -96,9 +105,11 @@ async def get_settings(include_additional_data: bool = False, app: HeksherApp = """ List all the settings in the service """ + async with app.engine.connect() as conn: + results = (await db_get_settings(conn, include_metadata=include_additional_data, + include_aliases=include_additional_data, + include_configurable_features=include_additional_data)).values() if include_additional_data: - full_results = await app.db_logic.get_all_settings(include_metadata=True, include_aliases=True, - include_configurable_features=True) return GetSettingsOutputWithData(settings=[ GetSettingsOutputWithData_Setting( name=spec.name, @@ -108,11 +119,9 @@ async def get_settings(include_additional_data: bool = False, app: HeksherApp = metadata=spec.metadata, aliases=spec.aliases, version=spec.version, - ) for spec in full_results + ) for spec in results ]) else: - results = await app.db_logic.get_all_settings(include_metadata=False, include_aliases=False, - include_configurable_features=False) return GetSettingsOutput(settings=[ GetSettingsOutput_Setting( name=spec.name, @@ -145,38 +154,40 @@ async def set_setting_type(name: str, input: PutSettingTypeInput, app: HeksherAp """ Change The type of a setting """ - setting = await app.db_logic.get_setting(name, include_metadata=False, include_aliases=False, - include_configurable_features=False) - if not setting: - return PlainTextResponse(f'the setting {name} does not exist', status_code=status.HTTP_404_NOT_FOUND) - existing_version = parse_setting_version(setting.version) - new_version = parse_setting_version(input.version) - if existing_version == new_version and input.type == setting.type: - return Response(status_code=status.HTTP_204_NO_CONTENT) - if existing_version >= new_version: - return PlainTextResponse(f'the setting {name} is at a higher version than the request ({existing_version})', - status_code=status.HTTP_409_CONFLICT) - if input.type == setting.type: - # we only need to do a version bump - async with app.db_logic.db_engine.begin() as conn: - await app.db_logic.bump_setting_version(conn, name, input.version) - return Response(status_code=status.HTTP_204_NO_CONTENT) - conflicts = [] - if not input.type.validate(setting.default_value): - conflicts.append(f'the default value {setting.default_value!r} does not match the new type') - rules = await app.db_logic.get_rules_for_setting(setting.name) - bad_rules = {rule_id: rule_value for (rule_id, rule_value) in rules if not input.type.validate(rule_value)} - if bad_rules: - conditions = await app.db_logic.get_rules_feature_values(list(bad_rules.keys())) - for rule_id, value in bad_rules.items(): - conflicts.append(f'rule {rule_id} ({conditions[rule_id]}) has incompatible value {value}') - if not (input.type < setting.type) and existing_version[0] == new_version[0]: - conflicts.append(f'cannot change type to non-subtype {input.type} in a minor version bump') - if conflicts: - return JSONResponse(PutSettingTypeConflictOutput(conflicts=conflicts).dict(), - status_code=status.HTTP_409_CONFLICT) - await app.db_logic.set_setting_type(setting.name, input.type, input.version) - return None + async with app.engine.connect() as conn: # note that the connection may be upgraded to transaction + setting = await db_get_setting(conn, name, include_metadata=False, include_aliases=False, + include_configurable_features=False) + if not setting: + return PlainTextResponse(f'the setting {name} does not exist', status_code=status.HTTP_404_NOT_FOUND) + existing_version = parse_setting_version(setting.version) + new_version = parse_setting_version(input.version) + if existing_version == new_version and input.type == setting.type: + return Response(status_code=status.HTTP_204_NO_CONTENT) + if existing_version >= new_version: + return PlainTextResponse(f'the setting {name} is at a higher version than the request ({existing_version})', + status_code=status.HTTP_409_CONFLICT) + if input.type == setting.type: + # we only need to do a version bump + async with app.engine.begin() as conn: + await db_bump_setting_version(conn, name, input.version) + return Response(status_code=status.HTTP_204_NO_CONTENT) + conflicts = [] + if not input.type.validate(setting.default_value): + conflicts.append(f'the default value {setting.default_value!r} does not match the new type') + rules = await db_get_rules_for_setting(conn, setting.name) + bad_rules = {rule_id: rule_value for (rule_id, rule_value) in rules if not input.type.validate(rule_value)} + if bad_rules: + conditions = await db_get_rules_feature_values(conn, list(bad_rules.keys())) + for rule_id, value in bad_rules.items(): + conflicts.append(f'rule {rule_id} ({conditions[rule_id]}) has incompatible value {value}') + if not (input.type < setting.type) and existing_version[0] == new_version[0]: + conflicts.append(f'cannot change type to non-subtype {input.type} in a minor version bump') + if conflicts: + return JSONResponse(PutSettingTypeConflictOutput(conflicts=conflicts).dict(), + status_code=status.HTTP_409_CONFLICT) + async with app.engine.begin() as conn: + await db_set_setting_type(conn, setting.name, input.type, input.version) + return None class RenameSettingInput(ORJSONModel): @@ -197,39 +208,42 @@ async def rename_setting(name: str, input: RenameSettingInput, app: HeksherApp = """ Rename a setting, adding the previous name as an alias """ - # we try and validate the names we were given, and check they do not conflict with other settings - names_map = await app.db_logic.get_canonical_names((name, input.name)) - # the names map should contain 2 entries: - # the first entry: given original name/alias -> canonical name - # (could be the same if the given original name is the canonical one) - canonical_name = names_map[name] - # if the canonical name is None - this setting does not exist - if not canonical_name: - return PlainTextResponse('setting does not exist', status_code=status.HTTP_404_NOT_FOUND) - setting = await app.db_logic.get_setting(canonical_name, include_metadata=False, include_aliases=False, - include_configurable_features=False) - existing_version = parse_setting_version(setting.version) - new_version = parse_setting_version(input.version) - if existing_version == new_version and canonical_name == input.name: - return Response(status_code=status.HTTP_204_NO_CONTENT) - elif existing_version > new_version: - return PlainTextResponse(f'The setting {name} already has a newer version ({setting.version})', - status_code=status.HTTP_409_CONFLICT) - - if input.name == canonical_name: - # we just need to version bump - async with app.db_logic.db_engine.begin() as conn: - await app.db_logic.bump_setting_version(conn, canonical_name, input.version) - return Response(status_code=status.HTTP_204_NO_CONTENT) - # the second entry: given new name -> None - # if the value is not None - this name/alias already exists - if names_map[input.name] is not None: - # if this new name is an alias for the same setting, - # we can allow this operation to make the alias the canonical name - # otherwise - the operation cannot be done since the new name already exists as another setting's name or alias - if names_map[input.name] != canonical_name: - return PlainTextResponse('name already exists', status_code=status.HTTP_409_CONFLICT) - await app.db_logic.rename_setting(canonical_name, input.name, input.version) + async with app.engine.connect() as conn: # note that the connection may be upgraded to transaction + # we try and validate the names we were given, and check they do not conflict with other settings + names_map = await db_get_canonical_names(conn, (name, input.name)) + # the names map should contain 2 entries: + # the first entry: given original name/alias -> canonical name + # (could be the same if the given original name is the canonical one) + canonical_name = names_map[name] + # if the canonical name is None - this setting does not exist + if not canonical_name: + return PlainTextResponse('setting does not exist', status_code=status.HTTP_404_NOT_FOUND) + setting = await db_get_setting(conn, canonical_name, include_metadata=False, include_aliases=False, + include_configurable_features=False) + existing_version = parse_setting_version(setting.version) + new_version = parse_setting_version(input.version) + if existing_version == new_version and canonical_name == input.name: + return Response(status_code=status.HTTP_204_NO_CONTENT) + elif existing_version > new_version: + return PlainTextResponse(f'The setting {name} already has a newer version ({setting.version})', + status_code=status.HTTP_409_CONFLICT) + + if input.name == canonical_name: + # we just need to version bump + async with app.engine.begin() as conn: + await db_bump_setting_version(conn, canonical_name, input.version) + return Response(status_code=status.HTTP_204_NO_CONTENT) + # the second entry: given new name -> None + # if the value is not None - this name/alias already exists + if names_map[input.name] is not None: + # if this new name is an alias for the same setting, + # we can allow this operation to make the alias the canonical name + # otherwise - the operation cannot be done since the new name already exists as another setting's name or + # alias + if names_map[input.name] != canonical_name: + return PlainTextResponse('name already exists', status_code=status.HTTP_409_CONFLICT) + async with app.engine.begin() as conn: + await db_rename_setting(conn, canonical_name, input.name, input.version) return None @@ -250,39 +264,40 @@ class ConfigurableFeaturesInput(ORJSONModel): } }) async def set_configurable_features(name: str, input: ConfigurableFeaturesInput, app: HeksherApp = application): - setting = await app.db_logic.get_setting(name, include_metadata=False, include_aliases=False, - include_configurable_features=True) - if not setting: - return PlainTextResponse(f'the setting {name} does not exist', status_code=status.HTTP_404_NOT_FOUND) - existing_version = parse_setting_version(setting.version) - new_version = parse_setting_version(input.version) - existing_cfs = frozenset(setting.configurable_features) - new_cfs = frozenset(input.configurable_features) - if existing_version == new_version and new_cfs == existing_cfs: - return Response(status_code=status.HTTP_204_NO_CONTENT) - if existing_version >= new_version: - return PlainTextResponse(f'the setting {name} is at a higher version than the request ({existing_version})', - status_code=status.HTTP_409_CONFLICT) - if new_cfs == existing_cfs: - # we only need to do a version bump - async with app.db_logic.db_engine.begin() as conn: - await app.db_logic.bump_setting_version(conn, name, input.version) - return Response(status_code=status.HTTP_204_NO_CONTENT) - removed_cfs = existing_cfs - new_cfs - if removed_cfs: - # check if any rules are using the removed configurable features - actual_cfs_in_use = await app.db_logic.get_actual_configurable_features(setting.name) - removed_cfs_in_use = removed_cfs & actual_cfs_in_use.keys() - if removed_cfs_in_use: - rule_ids = list(chain.from_iterable(actual_cfs_in_use[cf] for cf in removed_cfs_in_use)) - return PlainTextResponse(f'Configurable features {removed_cfs_in_use} are in use by rules {rule_ids}', + async with app.engine.connect() as conn: + setting = await db_get_setting(conn, name, include_metadata=False, include_aliases=False, + include_configurable_features=True) + if not setting: + return PlainTextResponse(f'the setting {name} does not exist', status_code=status.HTTP_404_NOT_FOUND) + existing_version = parse_setting_version(setting.version) + new_version = parse_setting_version(input.version) + existing_cfs = frozenset(setting.configurable_features) + new_cfs = frozenset(input.configurable_features) + if existing_version == new_version and new_cfs == existing_cfs: + return Response(status_code=status.HTTP_204_NO_CONTENT) + if existing_version >= new_version: + return PlainTextResponse(f'the setting {name} is at a higher version than the request ({existing_version})', + status_code=status.HTTP_409_CONFLICT) + if new_cfs == existing_cfs: + # we only need to do a version bump + async with app.engine.begin() as conn: + await db_bump_setting_version(conn, name, input.version) + return Response(status_code=status.HTTP_204_NO_CONTENT) + removed_cfs = existing_cfs - new_cfs + if removed_cfs: + # check if any rules are using the removed configurable features + actual_cfs_in_use = await db_get_actual_configurable_features(conn, setting.name) + removed_cfs_in_use = removed_cfs & actual_cfs_in_use.keys() + if removed_cfs_in_use: + rule_ids = list(chain.from_iterable(actual_cfs_in_use[cf] for cf in removed_cfs_in_use)) + return PlainTextResponse(f'Configurable features {removed_cfs_in_use} are in use by rules {rule_ids}', + status_code=status.HTTP_409_CONFLICT) + if not (existing_cfs > new_cfs) and existing_version[0] == new_version[0]: + # can't add new cfs on the same major + return PlainTextResponse('Cannot add new configurable features on a minor version bump', status_code=status.HTTP_409_CONFLICT) - if not (existing_cfs > new_cfs) and existing_version[0] == new_version[0]: - # can't add new cfs on the same major - return PlainTextResponse('Cannot add new configurable features on a minor version bump', - status_code=status.HTTP_409_CONFLICT) - async with app.db_logic.db_engine.begin() as conn: - await set_settings_configurable_features(conn, setting.name, input.configurable_features, input.version) + async with app.engine.begin() as conn: + await db_set_settings_configurable_features(conn, setting.name, input.configurable_features, input.version) return Response(status_code=status.HTTP_204_NO_CONTENT) diff --git a/heksher/api/v1/settings_metadata.py b/heksher/api/v1/settings_metadata.py index 608f76d..05f2b20 100644 --- a/heksher/api/v1/settings_metadata.py +++ b/heksher/api/v1/settings_metadata.py @@ -9,6 +9,11 @@ from heksher.api.v1.util import ORJSONModel, application from heksher.api.v1.validation import MetadataKey, SettingVersion from heksher.app import HeksherApp +from heksher.db_logic.setting import db_get_setting +from heksher.db_logic.setting_metadata import ( + db_delete_setting_metadata, db_delete_setting_metadata_key, db_replace_setting_metadata, db_update_setting_metadata, + db_update_setting_metadata_key +) from heksher.db_logic.util import parse_setting_version router = APIRouter() @@ -25,16 +30,18 @@ async def update_setting_metadata(name: str, input: InputSettingMetadata, app: H """ Update the setting's metadata """ - setting = await app.db_logic.get_setting(name, include_metadata=False, include_aliases=False, - include_configurable_features=False) - if not setting: - return PlainTextResponse(f'the setting {name} does not exist', status_code=status.HTTP_404_NOT_FOUND) - existing_version = parse_setting_version(setting.version) - new_version = parse_setting_version(input.version) - if existing_version >= new_version: - return PlainTextResponse(f'The setting {name} already has a newer version ({setting.version})', - status_code=status.HTTP_409_CONFLICT) - await app.db_logic.update_setting_metadata(setting.name, input.metadata, input.version) + async with app.engine.connect() as conn: + setting = await db_get_setting(conn, name, include_metadata=False, include_aliases=False, + include_configurable_features=False) + if not setting: + return PlainTextResponse(f'the setting {name} does not exist', status_code=status.HTTP_404_NOT_FOUND) + existing_version = parse_setting_version(setting.version) + new_version = parse_setting_version(input.version) + if existing_version >= new_version: + return PlainTextResponse(f'The setting {name} already has a newer version ({setting.version})', + status_code=status.HTTP_409_CONFLICT) + async with app.engine.begin() as conn: + await db_update_setting_metadata(conn, setting.name, input.metadata, input.version) @router.put('/{name}/metadata', status_code=status.HTTP_204_NO_CONTENT, response_class=Response) @@ -42,16 +49,18 @@ async def replace_setting_metadata(name: str, input: InputSettingMetadata, app: """ Change the current metadata of the setting. """ - setting = await app.db_logic.get_setting(name, include_metadata=False, include_aliases=False, - include_configurable_features=False) - if not setting: - return PlainTextResponse(f'the setting {name} does not exist', status_code=status.HTTP_404_NOT_FOUND) - existing_version = parse_setting_version(setting.version) - new_version = parse_setting_version(input.version) - if existing_version >= new_version: - return PlainTextResponse(f'The setting {name} already has a newer version ({setting.version})', - status_code=status.HTTP_409_CONFLICT) - await app.db_logic.replace_setting_metadata(setting.name, input.metadata, input.version) + async with app.engine.begin() as conn: + setting = await db_get_setting(conn, name, include_metadata=False, include_aliases=False, + include_configurable_features=False) + if not setting: + return PlainTextResponse(f'the setting {name} does not exist', status_code=status.HTTP_404_NOT_FOUND) + existing_version = parse_setting_version(setting.version) + new_version = parse_setting_version(input.version) + if existing_version >= new_version: + return PlainTextResponse(f'The setting {name} already has a newer version ({setting.version})', + status_code=status.HTTP_409_CONFLICT) + async with app.engine.begin() as conn: + await db_replace_setting_metadata(conn, setting.name, input.metadata, input.version) class PutSettingMetadataKey(ORJSONModel): @@ -65,16 +74,18 @@ async def update_setting_metadata_key(name: str, key: MetadataKey, input: PutSet """ Updates the current metadata of the setting. Existing keys won't be deleted. """ - setting = await app.db_logic.get_setting(name, include_metadata=False, include_aliases=False, - include_configurable_features=False) - if not setting: - return PlainTextResponse(f'the setting {name} does not exist', status_code=status.HTTP_404_NOT_FOUND) - existing_version = parse_setting_version(setting.version) - new_version = parse_setting_version(input.version) - if existing_version >= new_version: - return PlainTextResponse(f'The setting {name} already has a newer version ({setting.version})', - status_code=status.HTTP_409_CONFLICT) - await app.db_logic.update_setting_metadata_key(setting.name, key, input.value, input.version) + async with app.engine.connect() as conn: + setting = await db_get_setting(conn, name, include_metadata=False, include_aliases=False, + include_configurable_features=False) + if not setting: + return PlainTextResponse(f'the setting {name} does not exist', status_code=status.HTTP_404_NOT_FOUND) + existing_version = parse_setting_version(setting.version) + new_version = parse_setting_version(input.version) + if existing_version >= new_version: + return PlainTextResponse(f'The setting {name} already has a newer version ({setting.version})', + status_code=status.HTTP_409_CONFLICT) + async with app.engine.begin() as conn: + await db_update_setting_metadata_key(conn, setting.name, key, input.value, input.version) class DeleteSettingMetadataInput(ORJSONModel): @@ -86,16 +97,18 @@ async def delete_setting_metadata(name: str, input: DeleteSettingMetadataInput, """ Delete a setting's metadata. """ - setting = await app.db_logic.get_setting(name, include_metadata=False, include_aliases=False, - include_configurable_features=False) - if not setting: - return PlainTextResponse(f'the setting {name} does not exist', status_code=status.HTTP_404_NOT_FOUND) - existing_version = parse_setting_version(setting.version) - new_version = parse_setting_version(input.version) - if existing_version >= new_version: - return PlainTextResponse(f'The setting {name} already has a newer version ({setting.version})', - status_code=status.HTTP_409_CONFLICT) - await app.db_logic.delete_setting_metadata(setting.name, input.version) + async with app.engine.connect() as conn: + setting = await db_get_setting(conn, name, include_metadata=False, include_aliases=False, + include_configurable_features=False) + if not setting: + return PlainTextResponse(f'the setting {name} does not exist', status_code=status.HTTP_404_NOT_FOUND) + existing_version = parse_setting_version(setting.version) + new_version = parse_setting_version(input.version) + if existing_version >= new_version: + return PlainTextResponse(f'The setting {name} already has a newer version ({setting.version})', + status_code=status.HTTP_409_CONFLICT) + async with app.engine.begin() as conn: + await db_delete_setting_metadata(conn, setting.name, input.version) @router.delete('/{name}/metadata/{key}', status_code=status.HTTP_204_NO_CONTENT, response_class=Response) @@ -104,16 +117,18 @@ async def delete_rule_key_from_metadata(name: str, key: MetadataKey, input: Dele """ Delete a specific key from the setting's metadata. """ - setting = await app.db_logic.get_setting(name, include_metadata=False, include_aliases=False, - include_configurable_features=False) - if not setting: - return PlainTextResponse(f'the setting {name} does not exist', status_code=status.HTTP_404_NOT_FOUND) - existing_version = parse_setting_version(setting.version) - new_version = parse_setting_version(input.version) - if existing_version >= new_version: - return PlainTextResponse(f'The setting {name} already has a newer version ({setting.version})', - status_code=status.HTTP_409_CONFLICT) - await app.db_logic.delete_setting_metadata_key(setting.name, key, input.version) + async with app.engine.connect() as conn: + setting = await db_get_setting(conn, name, include_metadata=False, include_aliases=False, + include_configurable_features=False) + if not setting: + return PlainTextResponse(f'the setting {name} does not exist', status_code=status.HTTP_404_NOT_FOUND) + existing_version = parse_setting_version(setting.version) + new_version = parse_setting_version(input.version) + if existing_version >= new_version: + return PlainTextResponse(f'The setting {name} already has a newer version ({setting.version})', + status_code=status.HTTP_409_CONFLICT) + async with app.engine.begin() as conn: + await db_delete_setting_metadata_key(conn, setting.name, key, input.version) class GetSettingMetadataOutput(ORJSONModel): @@ -131,7 +146,8 @@ async def get_setting_metadata(name: str, app: HeksherApp = application): """ Get metadata of a setting. """ - if not (setting := await app.db_logic.get_setting(name, include_metadata=True, include_aliases=False, - include_configurable_features=False)): - return PlainTextResponse(f'the setting {name} does not exist', status_code=status.HTTP_404_NOT_FOUND) + async with app.engine.connect() as conn: + if not (setting := await db_get_setting(conn, name, include_metadata=True, include_aliases=False, + include_configurable_features=False)): + return PlainTextResponse(f'the setting {name} does not exist', status_code=status.HTTP_404_NOT_FOUND) return GetSettingMetadataOutput(metadata=setting.metadata) diff --git a/heksher/app.py b/heksher/app.py index 8b1f6fe..c0409bf 100644 --- a/heksher/app.py +++ b/heksher/app.py @@ -1,6 +1,7 @@ import re from asyncio import wait_for from logging import INFO, getLogger +from typing import Sequence import orjson import sentry_sdk @@ -11,14 +12,15 @@ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine from heksher._version import __version__ -from heksher.db_logic import DBLogic +from heksher.db_logic.context_feature import db_add_context_features, db_get_context_features, db_move_context_features +from heksher.db_logic.util import supersequence_new_elements from heksher.health_monitor import HealthMonitor from heksher.util import db_url_with_async_driver logger = getLogger(__name__) connection_string = EnvVar('HEKSHER_DB_CONNECTION_STRING', type=db_url_with_async_driver) -startup_context_features = EnvVar('HEKSHER_STARTUP_CONTEXT_FEATURES', type=CollectionParser(';', str)) +startup_context_features = EnvVar('HEKSHER_STARTUP_CONTEXT_FEATURES', type=CollectionParser(';', str), default=None) class LogstashSettingSchema(Schema): @@ -38,9 +40,29 @@ class HeksherApp(FastAPI): The application class """ engine: AsyncEngine - db_logic: DBLogic health_monitor: HealthMonitor + async def ensure_context_features(self, expected_context_features: Sequence[str]): + async with self.engine.connect() as conn: + existing_features = await db_get_context_features(conn) + actual = dict(existing_features) + super_sequence = supersequence_new_elements(expected_context_features, actual) + if super_sequence is None: + raise RuntimeError(f'expected context features to be a subsequence of {expected_context_features}, ' + f'actual: {actual}') + expected = {cf: i for (i, cf) in enumerate(expected_context_features)} + misplaced_keys = [k for k, v in actual.items() if expected[k] != v] + if misplaced_keys: + logger.warning('fixing indexing for context features', extra={'misplaced_keys': misplaced_keys}) + async with self.engine.begin() as conn: + await db_move_context_features(conn, expected) + if super_sequence: + logger.info('adding new context features', extra={ + 'new_context_features': [element for (element, _) in super_sequence] + }) + async with self.engine.begin() as conn: + await db_add_context_features(conn, dict(super_sequence)) + async def startup(self): logstash_settings = logstash_settings_ev.get() if logstash_settings is not None: @@ -55,13 +77,13 @@ async def startup(self): self.engine = create_async_engine(db_connection_string, json_serializer=lambda obj: orjson.dumps(obj).decode(), - json_deserializer=orjson.loads + json_deserializer=orjson.loads, ) - self.db_logic = DBLogic(self.engine) # assert that the db logic holds up expected_context_features = startup_context_features.get() - await self.db_logic.ensure_context_features(expected_context_features) + if expected_context_features is not None: + await self.ensure_context_features(expected_context_features) self.health_monitor = HealthMonitor(self.engine) await self.health_monitor.start() diff --git a/heksher/db_logic/__init__.py b/heksher/db_logic/__init__.py index a847906..e69de29 100644 --- a/heksher/db_logic/__init__.py +++ b/heksher/db_logic/__init__.py @@ -1,16 +0,0 @@ -from sqlalchemy.ext.asyncio import AsyncEngine - -from heksher.db_logic.context_feature import ContextFeatureMixin -from heksher.db_logic.rule import RuleMixin -from heksher.db_logic.rule_metadata import RuleMetadataMixin -from heksher.db_logic.setting import SettingMixin -from heksher.db_logic.setting_metadata import SettingMetadataMixin - - -class DBLogic(ContextFeatureMixin, SettingMixin, RuleMixin, SettingMetadataMixin, RuleMetadataMixin): - """ - Class to handle all logic for interacting with the DB - """ - # note that all methods are implemented inside mixin classes - def __init__(self, engine: AsyncEngine): - self.db_engine = engine diff --git a/heksher/db_logic/context_feature.py b/heksher/db_logic/context_feature.py index dd1c3b3..33ead1e 100644 --- a/heksher/db_logic/context_feature.py +++ b/heksher/db_logic/context_feature.py @@ -1,156 +1,103 @@ from logging import getLogger -from typing import AbstractSet, Iterable, Optional, Sequence +from typing import AbstractSet, Iterable, Mapping, Optional, Sequence, Tuple -from sqlalchemy import and_, desc, select +from sqlalchemy import and_, case, func, select +from sqlalchemy.ext.asyncio import AsyncConnection -from heksher.db_logic.logic_base import DBLogicBase from heksher.db_logic.metadata import configurable, context_features -from heksher.db_logic.util import supersequence_new_elements logger = getLogger(__name__) -class ContextFeatureMixin(DBLogicBase): - async def ensure_context_features(self, expected_context_features: Sequence[str]): - """ - Ensure that the context features in the DB match those expected, or raise an error if that is not possible. - Args: - expected_context_features: The context features that should be present in the DB. - Raises: - raises a RuntimeError if the DB state cannot match the expected without deleting or reordering features. - """ - query = select([context_features.c.name, context_features.c.index]).order_by(context_features.c.index) - async with self.db_engine.connect() as conn: - records = (await conn.execute(query)).mappings().all() - expected = {cf: i for (i, cf) in enumerate(expected_context_features)} - actual = {row['name']: row['index'] for row in records} - super_sequence = supersequence_new_elements(expected_context_features, actual) - if super_sequence is None: - raise RuntimeError(f'expected context features to be a subsequence of {list(expected)}, ' - f'actual: {list(actual)}') - # get all context features that are out place with what we expect - misplaced_keys = [k for k, v in actual.items() if expected[k] != v] - if misplaced_keys: - logger.warning('fixing indexing for context features', extra={'misplaced_keys': misplaced_keys}) - async with self.db_engine.begin() as conn: - for k in misplaced_keys: - stmt = context_features.update().where(context_features.c.name == k).values(index=expected[k]) - await conn.execute(stmt) - if super_sequence: - logger.info('adding new context features', extra={ - 'new_context_features': [element for (element, _) in super_sequence] - }) - async with self.db_engine.begin() as conn: - await conn.execute( - context_features.insert().values( - [{'name': name, 'index': index} for (name, index) in super_sequence] - )) - - async def get_context_features(self) -> Sequence[str]: - """ - Returns: - A sequence of all the context features currently in the DB - """ - async with self.db_engine.connect() as conn: - rows = (await conn.execute( - select([context_features.c.name]).order_by(context_features.c.index), - )).scalars().all() - return rows - - async def get_not_found_context_features(self, candidates: Iterable[str]) -> AbstractSet[str]: - """ - Filter an iterable to only include strings that are not context features in the DB. - Args: - candidates: An iterable of potential context feature names - - Returns: - A set including only the candidates that are not context features - - """ - # todo improve? we expect both sets to be very small (<20 elements) - return set(candidates) - set(await self.get_context_features()) - - async def get_context_feature_index(self, context_feature: str) -> Optional[int]: - """ - Args: - context_feature: a potential context feature name. - - Returns: - The index of the context feature name (if it exists in the DB, else None) - """ - async with self.db_engine.connect() as conn: - index = (await conn.execute(select([context_features.c.index]) - .where(context_features.c.name == context_feature))).scalar_one_or_none() - return index - - async def is_configurable_setting_from_context_features(self, context_feature: str): - async with self.db_engine.connect() as conn: - setting_of_cf = ( - (await conn.execute(select([configurable.c.setting]) - .where(configurable.c.context_feature == context_feature))) - .scalar_one_or_none() - ) - return setting_of_cf is not None - - async def delete_context_feature(self, context_feature: str): - """ - Deletes the given context feature, and re-ordering the context features indexes - Args: - context_feature: a potential context feature name to be deleted. - """ - async with self.db_engine.begin() as conn: - index_deleted = (await conn.execute(context_features.delete() - .where(context_features.c.name == context_feature) - .returning(context_features.c.index))).scalar_one_or_none() - await conn.execute(context_features.update() - .where(context_features.c.index > index_deleted) - .values(index=context_features.c.index - 1)) - - async def move_after_context_feature(self, index_to_move: int, target_index: int): - """ - Changing context feature index to be after a different context feature. - Example: - {"a": 0, "b": 1, "c": 2} - when called, move_after_context_feature(0, 1) will result {"b": 0, "a": 1, "c": 2} - Args: - index_to_move: the index of the context feature to be moved after the target context feature. - target_index: the index of the target to be second to the given context feature. - """ - - async with self.db_engine.begin() as conn: - # first, change the index of the context feature to be moved to -1 so it won't be overridden - await conn.execute(context_features.update() - .where(context_features.c.index == index_to_move) - .values(index=-1)) - if index_to_move < target_index: - # move in between context features one step back - await conn.execute(context_features.update() - .where(and_(context_features.c.index <= target_index, - context_features.c.index > index_to_move)) - .values(index=context_features.c.index - 1)) - # update the index of the context feature to be moved to its correct position - await conn.execute(context_features.update() - .where(context_features.c.index == -1) - .values(index=target_index)) - else: - # move in between context features one step forward - await conn.execute(context_features.update() - .where(and_(context_features.c.index > target_index, - context_features.c.index < index_to_move)) - .values(index=context_features.c.index + 1)) - # update the index of the context feature to be moved to its correct position - await conn.execute(context_features.update() - .where(context_features.c.index == -1) - .values(index=target_index+1)) - - async def add_context_feature(self, context_feature: str): - """ - Adds context feature to end of the context_features table. - Args: - context_feature: context_feature to be inserted. - """ - async with self.db_engine.begin() as conn: - last_index = (await conn.execute( - select([context_features.c.index]).order_by(desc(context_features.c.index)).limit(1), - )).scalar_one() - await conn.execute(context_features.insert().values([{"name": context_feature, "index": last_index + 1}])) +async def db_move_context_features(conn: AsyncConnection, new_indices: Mapping[str, int]): + """ + Update the indices of all the context features in the database. The new indices are given by the mapping + new_indices. new_indices must include at least all the context features in the database. + """ + new_index = case([(context_features.c.name == name, index) for (name, index) in new_indices.items()]) + await conn.execute(context_features.update().values(index=new_index)) + + +async def db_get_context_features(conn: AsyncConnection) -> Sequence[Tuple[str, int]]: + return (await conn.execute( + select([context_features.c.name, context_features.c.index]).order_by(context_features.c.index), + )).all() + + +async def db_get_not_found_context_features(conn: AsyncConnection, candidates: Iterable[str]) -> AbstractSet[str]: + # todo improve? we expect both sets to be very small (<20 elements) + return set(candidates) - set(name for (name, _) in await db_get_context_features(conn)) + + +async def db_get_context_feature_index(conn: AsyncConnection, context_feature: str) -> Optional[int]: + return (await conn.execute(select([context_features.c.index]) + .where(context_features.c.name == context_feature))).scalar_one_or_none() + + +async def db_is_configurable_setting_from_context_features(conn: AsyncConnection, context_feature: str): + setting_of_cf = ( + (await conn.execute(select([configurable.c.setting]) + .where(configurable.c.context_feature == context_feature))) + .scalar_one_or_none() + ) + return setting_of_cf is not None + + +async def db_delete_context_feature(conn: AsyncConnection, context_feature: str): + index_deleted = (await conn.execute(context_features.delete() + .where(context_features.c.name == context_feature) + .returning(context_features.c.index))).scalar_one_or_none() + await conn.execute(context_features.update() + .where(context_features.c.index > index_deleted) + .values(index=context_features.c.index - 1)) + + +async def db_move_after_context_feature(conn: AsyncConnection, index_to_move: int, target_index: int): + """ + Changing context feature index to be after a different context feature. + Example: + {"a": 0, "b": 1, "c": 2} + when called, move_after_context_feature(0, 1) will result {"b": 0, "a": 1, "c": 2} + Args: + conn: the transaction connection to use + index_to_move: the index of the context feature to be moved after the target context feature. + target_index: the index of the target to be second to the given context feature. + """ + + # first, change the index of the context feature to be moved to -1 so it won't be overridden + await conn.execute(context_features.update() + .where(context_features.c.index == index_to_move) + .values(index=-1)) + if index_to_move < target_index: + # move in between context features one step back + await conn.execute(context_features.update() + .where(and_(context_features.c.index <= target_index, + context_features.c.index > index_to_move)) + .values(index=context_features.c.index - 1)) + # update the index of the context feature to be moved to its correct position + await conn.execute(context_features.update() + .where(context_features.c.index == -1) + .values(index=target_index)) + else: + # move in between context features one step forward + await conn.execute(context_features.update() + .where(and_(context_features.c.index > target_index, + context_features.c.index < index_to_move)) + .values(index=context_features.c.index + 1)) + # update the index of the context feature to be moved to its correct position + await conn.execute(context_features.update() + .where(context_features.c.index == -1) + .values(index=target_index + 1)) + + +async def db_add_context_feature_to_end(conn: AsyncConnection, context_feature: str): + last_index = (await conn.execute( + select([func.max(context_features.c.index)]), + )).scalar_one() + await db_add_context_features(conn, {context_feature: last_index + 1}) + + +async def db_add_context_features(conn: AsyncConnection, features: Mapping[str, int]): + await conn.execute(context_features.insert().values( + [{"name": context_feature, "index": index} for context_feature, index in features.items()]) + ) diff --git a/heksher/db_logic/logic_base.py b/heksher/db_logic/logic_base.py deleted file mode 100644 index 6520919..0000000 --- a/heksher/db_logic/logic_base.py +++ /dev/null @@ -1,11 +0,0 @@ -from sqlalchemy.ext.asyncio import AsyncEngine - - -class DBLogicBase: - """ - A base class for DBLogic mixins - """ - db_engine: AsyncEngine - - async def bump_setting_version(self, conn, setting_name: str, new_version: str) -> None: - pass diff --git a/heksher/db_logic/rule.py b/heksher/db_logic/rule.py index 51431a2..c0556c4 100644 --- a/heksher/db_logic/rule.py +++ b/heksher/db_logic/rule.py @@ -6,8 +6,8 @@ import orjson from sqlalchemy import and_, func, join, not_, select, tuple_ +from sqlalchemy.ext.asyncio import AsyncConnection -from heksher.db_logic.logic_base import DBLogicBase from heksher.db_logic.metadata import conditions, context_features, rule_metadata, rules @@ -33,307 +33,251 @@ class BareRuleSpec(NamedTuple): value: Any -class RuleMixin(DBLogicBase): - async def get_rule(self, id_: int, include_metadata: bool) -> Optional[RuleSpec]: - """ - Args: - id_: the id of a specific rule - include_metadata: whether to include rule metadata - - Returns: - A RuleSpec describing the rule with the id, or None if no such rule exists. - """ - async with self.db_engine.connect() as conn: - basic_results = (await conn.execute( - select([rules.c.setting, rules.c.value]).where(rules.c.id == id_).limit(1) - )).mappings().first() - if not basic_results: - # rule does not exist - return None - feature_values = (await conn.execute( - select([conditions.c.context_feature, conditions.c.feature_value]) - .select_from(join(conditions, context_features, - conditions.c.context_feature == context_features.c.name)) - .where(conditions.c.rule == id_) - .order_by(context_features.c.index) - )).mappings().all() - if include_metadata: - metadata_ = dict((await conn.execute( - select([rule_metadata.c.key, rule_metadata.c.value]) - .where(rule_metadata.c.rule == id_) - )).all()) - else: - metadata_ = None - - value_ = orjson.loads(basic_results['value']) - return RuleSpec( - basic_results['setting'], - value_, - [(f['context_feature'], f['feature_value']) for f in feature_values], - metadata_ +async def db_get_rule(conn: AsyncConnection, id_: int, include_metadata: bool) -> Optional[RuleSpec]: + basic_results = (await conn.execute( + select([rules.c.setting, rules.c.value]).where(rules.c.id == id_).limit(1) + )).mappings().first() + if not basic_results: + # rule does not exist + return None + feature_values = (await conn.execute( + select([conditions.c.context_feature, conditions.c.feature_value]) + .select_from(join(conditions, context_features, + conditions.c.context_feature == context_features.c.name)) + .where(conditions.c.rule == id_) + .order_by(context_features.c.index) + )).mappings().all() + if include_metadata: + metadata_ = dict((await conn.execute( + select([rule_metadata.c.key, rule_metadata.c.value]) + .where(rule_metadata.c.rule == id_) + )).all()) + else: + metadata_ = None + + value_ = orjson.loads(basic_results['value']) + return RuleSpec( + basic_results['setting'], + value_, + [(f['context_feature'], f['feature_value']) for f in feature_values], + metadata_ + ) + + +async def db_get_rules_feature_values(conn: AsyncConnection, ids: List[int]) -> Mapping[int, Sequence[Tuple[str, str]]]: + feature_values = (await conn.execute( + select([conditions.c.rule, conditions.c.context_feature, conditions.c.feature_value]) + .select_from(join(conditions, context_features, + conditions.c.context_feature == context_features.c.name)) + .where(conditions.c.rule.in_(ids)) + .order_by(context_features.c.index))).all() + + ret: Dict[int, List[Tuple[str, str]]] = {k: [] for k in ids} + for rule_id, feature, value in feature_values: + ret[rule_id].append((feature, value)) + return ret + + +async def db_get_rule_id(conn: AsyncConnection, setting: str, match_conditions: Dict[str, str]) -> Optional[int]: + condition_count = len(match_conditions) + condition_tuples = list(match_conditions.items()) + + setting_rules = rules.select().where(rules.c.setting == setting).subquery() + + stmt = select(setting_rules.c.id.distinct()) \ + .where( + # to make sure there is an exact-match to the given conditions, + # check the amount of conditions for the rule alongside testing there are no other conditions not + # specified by user + and_( + select(func.count()).select_from(conditions) + .where(conditions.c.rule == setting_rules.c.id).scalar_subquery() + == condition_count, # amount of conditions + not_(conditions.select() + .where( # for better performance (speed wise), do the negative check + and_(conditions.c.rule == setting_rules.c.id, + tuple_(conditions.c.context_feature, conditions.c.feature_value) + .not_in(condition_tuples))) + .exists()) ) + ) + resp = await conn.execute(stmt) + return resp.scalar_one_or_none() - async def get_rules_feature_values(self, ids: List[int]) -> Mapping[int, Sequence[Tuple[str, str]]]: - """ - Get all the context feature conditions of a list of rules - Args: - ids: the ids of the rules to get the conditions of - - Returns: - A mapping of rule ids to a list of context features and their values, ordered by their indices - - """ - async with self.db_engine.connect() as conn: - feature_values = (await conn.execute( - select([conditions.c.rule, conditions.c.context_feature, conditions.c.feature_value]) - .select_from(join(conditions, context_features, - conditions.c.context_feature == context_features.c.name)) - .where(conditions.c.rule.in_(ids)) - .order_by(context_features.c.index))).all() - - ret: Dict[int, List[Tuple[str, str]]] = {k: [] for k in ids} - for rule_id, feature, value in feature_values: - ret[rule_id].append((feature, value)) - return ret - - async def get_rule_id(self, setting: str, match_conditions: Dict[str, str]) -> Optional[int]: - """ - Lookup a rule by its settings and conditions, and retrieve its id - - Args: - setting: The name of the setting the rule pertains to - match_conditions: The exact-match conditions of the rule - - Returns: - The id of the rule, or None if it does not exist - - """ - condition_count = len(match_conditions) - condition_tuples = list(match_conditions.items()) - - setting_rules = rules.select().where(rules.c.setting == setting).subquery() - - async with self.db_engine.connect() as conn: - # - stmt = select(setting_rules.c.id.distinct()) \ - .where( - # to make sure there is an exact-match to the given conditions, - # check the amount of conditions for the rule alongside testing there are no other conditions not - # specified by user - and_( - select(func.count()).select_from(conditions) - .where(conditions.c.rule == setting_rules.c.id).scalar_subquery() - == condition_count, # amount of conditions - not_(conditions.select() - .where( # for better performance (speed wise), do the negative check - and_(conditions.c.rule == setting_rules.c.id, - tuple_(conditions.c.context_feature, conditions.c.feature_value) - .not_in(condition_tuples))) - .exists()) - ) - ) - resp = await conn.execute(stmt) - return resp.scalar_one_or_none() - - async def delete_rule(self, rule_id: int): - """ - Delete a rule from the DB - Args: - rule_id: the id of the rule to delete - """ - async with self.db_engine.begin() as conn: - await conn.execute(rules.delete().where(rules.c.id == rule_id)) - - async def add_rule(self, setting: str, value: Any, metadata: Dict[str, Any], - match_conditions: Dict[str, str]) -> int: - """ - Add a rule to the DB - Args: - setting: The setting the rule pertains to - value: The value of the setting where the rule matches - match_conditions: The exact-match conditions of the rule - metadata: additional metadata - - Returns: - The id of the newly-created rule - - Notes: - The caller must ensure that the rule does not exist prior - """ - value_ = str(orjson.dumps(value), 'utf-8') - async with self.db_engine.begin() as conn: - rule_id = (await conn.execute( - rules.insert() - .values(setting=setting, value=value_) - .returning(rules.c.id) - )).scalar_one() - await conn.execute( - conditions.insert().values( - [{'rule': rule_id, 'context_feature': k, 'feature_value': v} - for (k, v) in match_conditions.items()]) + +async def db_delete_rule(conn: AsyncConnection, rule_id: int): + """ + Delete a rule from the DB + Args: + rule_id: the id of the rule to delete + """ + await conn.execute(rules.delete().where(rules.c.id == rule_id)) + + +async def db_add_rule(conn: AsyncConnection, setting: str, value: Any, metadata: Dict[str, Any], + match_conditions: Dict[str, str]) -> int: + """ + Add a rule to the DB + + Returns: + The id of the newly-created rule + + Notes: + The caller must ensure that the rule does not exist prior + """ + value_ = str(orjson.dumps(value), 'utf-8') + rule_id = (await conn.execute( + rules.insert() + .values(setting=setting, value=value_) + .returning(rules.c.id) + )).scalar_one() + await conn.execute( + conditions.insert().values( + [{'rule': rule_id, 'context_feature': k, 'feature_value': v} + for (k, v) in match_conditions.items()]) + ) + if metadata: + await conn.execute( + rule_metadata.insert().values( + [{'rule': rule_id, 'key': k, 'value': v} for (k, v) in metadata.items()] ) - if metadata: - await conn.execute( - rule_metadata.insert().values( - [{'rule': rule_id, 'key': k, 'value': v} for (k, v) in metadata.items()] - ) - ) - return rule_id - - async def patch_rule(self, rule_id: int, value: Any) -> None: - """ - Patches existing rule in the database. Supports only changing value. - Args: - rule_id: Rule ID to patch - value: Value to change to - """ - encoded_value = str(orjson.dumps(value), 'utf-8') - async with self.db_engine.begin() as conn: - await conn.execute(rules.update().where(rules.c.id == rule_id).values(value=encoded_value)) - - async def query_rules(self, setting_names: List[str], - feature_value_options: Optional[Dict[str, Optional[List[str]]]], - include_metadata: bool) -> Dict[str, List[InnerRuleSpec]]: - """ - Search the rules of multiple settings - - Args: - setting_names: The names of the settings to query. - feature_value_options: The options for each context feature. Rules that cannot match with these options are - discounted. If None, all rules are counted - setting_touch_time_cutoff: If provided, will discount all rules pertaining to settings that have not been - updated since this time. - include_metadata: Whether to retrieve and include the metadata of each rule in the result. - - Returns: - A mapping of non-filtered settings to rules - - """ - applicable_rules: Dict[int, Tuple[Tuple[str, str], ...]] = {} - conditions_ = conditions.alias() - - if not setting_names: - # shortcut in case no settings are selected - rule_results = [] + ) + return rule_id + + +async def db_patch_rule(conn: AsyncConnection, rule_id: int, value: Any) -> None: + encoded_value = str(orjson.dumps(value), 'utf-8') + await conn.execute(rules.update().where(rules.c.id == rule_id).values(value=encoded_value)) + + +async def db_query_rules(conn: AsyncConnection, setting_names: List[str], + feature_value_options: Optional[Dict[str, Optional[List[str]]]], + include_metadata: bool) -> Dict[str, List[InnerRuleSpec]]: + """ + Search the rules of multiple settings + + Args: + conn: the DB connection + setting_names: The names of the settings to query. + feature_value_options: The options for each context feature. Rules that cannot match with these options are + discounted. If None, all rules are counted + include_metadata: Whether to retrieve and include the metadata of each rule in the result. + + Returns: + A mapping of non-filtered settings to rules + + """ + applicable_rules: Dict[int, Tuple[Tuple[str, str], ...]] = {} + conditions_ = conditions.alias() + + if not setting_names: + # shortcut in case no settings are selected + return {} + else: + # inv_match is a mixin condition, if an exact-match condition returns True for it, the rule associated with + # it will not be returned + if feature_value_options is None: + # match all + inv_match = False + elif not feature_value_options: + # match none + inv_match = True else: - # inv_match is a mixin condition, if an exact-match condition returns True for it, the rule associated with - # it will not be returned - if feature_value_options is None: - # match all - inv_match = False - elif not feature_value_options: - # match none - inv_match = True - else: - exact_tuple_conditions = [] - only_cf_conditions = [] - for k, v in feature_value_options.items(): - if v is None: - only_cf_conditions.append(k) - else: - for cf_value in v: - exact_tuple_conditions.append((k, cf_value)) - tuple_conditions = tuple_(conditions_.c.context_feature, conditions_.c.feature_value).not_in( - exact_tuple_conditions) - cf_conditions = conditions_.c.context_feature.not_in(only_cf_conditions) - if exact_tuple_conditions and only_cf_conditions: - inv_match = and_(tuple_conditions, cf_conditions) + exact_tuple_conditions = [] + only_cf_conditions = [] + for k, v in feature_value_options.items(): + if v is None: + only_cf_conditions.append(k) else: - inv_match = tuple_conditions if exact_tuple_conditions else cf_conditions - - clauses = [ - not_(conditions_.select() - .where( - and_( - conditions.c.rule == conditions_.c.rule, inv_match - )) - .exists()), - rules.c.setting.in_(setting_names) - ] - - query = (select() - .add_columns(rules.c.id, conditions.c.context_feature, conditions.c.feature_value) - .select_from(rules.outerjoin(conditions, rules.c.id == conditions.c.rule)) - .outerjoin(context_features, context_features.c.name == conditions.c.context_feature) - .where(*clauses) - .order_by(rules.c.id, context_features.c.index)) - - async with self.db_engine.connect() as conn: - conditions_results = (await conn.execute(query) - ).mappings().all() - # group all the conditions by rules - for rule_id, rows in groupby(conditions_results, key=itemgetter('id')): - rule_conditions = tuple((row['context_feature'], row['feature_value']) for row in rows) - if rule_conditions == ((None, None),): - # this will occur if a rule has no exact-match conditions (i.e. it is a wildcard on all features) - # though we don't support users entering rules without conditions, we nevertheless prepare against - # them existing in the DB - rule_conditions = () - applicable_rules[rule_id] = rule_conditions - - # finally, get all the actual data for each rule - rule_query = ( - select([rules.c.id, rules.c.setting, rules.c.value]) - .where(rules.c.id.in_(applicable_rules)) - .order_by(rules.c.setting, rules.c.id) - ) - async with self.db_engine.connect() as conn: - rule_results = (await conn.execute(rule_query)).mappings().all() - - if include_metadata: - metadata_results = (await conn.execute( - select([rule_metadata.c.rule, rule_metadata.c.key, rule_metadata.c.value]) - .where(rule_metadata.c.rule.in_(applicable_rules)) - .order_by(rule_metadata.c.rule) - )).all() - metadata = { - rule_id: {k: v for (_, k, v) in rows} - for (rule_id, rows) in groupby(metadata_results, key=itemgetter(0)) - } - else: - metadata = None - - ret: Dict[str, List[InnerRuleSpec]] = {setting: [] for setting in setting_names} - for setting, rows in groupby(rule_results, key=itemgetter('setting')): - rule_list = [ - InnerRuleSpec( - orjson.loads(row['value']), - applicable_rules[row['id']], - metadata.get(row['id'], {}) if metadata is not None else None, - row['id'] - ) - for row in rows - ] - ret[setting] = rule_list - return ret - - async def get_rules_for_setting(self, setting_name: str) -> Sequence[BareRuleSpec]: - """ - Get all the rules for a particular setting - - Args: - setting_name: the name of the setting to ge the rules from - - Returns: - A sequence of rules for the setting - - """ - async with self.db_engine.connect() as conn: - result = ((await conn.execute(select([rules.c.id, rules.c.value]).where(rules.c.setting == setting_name))) - .all()) - return [BareRuleSpec(id_, orjson.loads(v)) for id_, v in result] - - async def get_actual_configurable_features(self, setting_name: str) -> Dict[str, List[int]]: - """ - Get all the rules that are configured for each context feature for a setting - """ - async with self.db_engine.connect() as conn: - results = (await conn.execute( - select([conditions.c.context_feature, rules.c.id]) - .select_from(rules.join(conditions, rules.c.id == conditions.c.rule)) - .where(rules.c.setting == setting_name) - .order_by(conditions.c.context_feature) + for cf_value in v: + exact_tuple_conditions.append((k, cf_value)) + tuple_conditions = tuple_(conditions_.c.context_feature, conditions_.c.feature_value).not_in( + exact_tuple_conditions) + cf_conditions = conditions_.c.context_feature.not_in(only_cf_conditions) + if exact_tuple_conditions and only_cf_conditions: + inv_match = and_(tuple_conditions, cf_conditions) + else: + inv_match = tuple_conditions if exact_tuple_conditions else cf_conditions + + clauses = [ + not_(conditions_.select() + .where( + and_( + conditions.c.rule == conditions_.c.rule, inv_match + )) + .exists()), + rules.c.setting.in_(setting_names) + ] + + query = (select() + .add_columns(rules.c.id, conditions.c.context_feature, conditions.c.feature_value) + .select_from(rules.outerjoin(conditions, rules.c.id == conditions.c.rule)) + .outerjoin(context_features, context_features.c.name == conditions.c.context_feature) + .where(*clauses) + .order_by(rules.c.id, context_features.c.index)) + + conditions_results = (await conn.execute(query)).mappings().all() + # group all the conditions by rules + for rule_id, rows in groupby(conditions_results, key=itemgetter('id')): + rule_conditions = tuple((row['context_feature'], row['feature_value']) for row in rows) + if rule_conditions == ((None, None),): + # this will occur if a rule has no exact-match conditions (i.e. it is a wildcard on all features) + # though we don't support users entering rules without conditions, we nevertheless prepare against + # them existing in the DB + rule_conditions = () + applicable_rules[rule_id] = rule_conditions + + # finally, get all the actual data for each rule + rule_query = ( + select([rules.c.id, rules.c.setting, rules.c.value]) + .where(rules.c.id.in_(applicable_rules)) + .order_by(rules.c.setting, rules.c.id) + ) + rule_results = (await conn.execute(rule_query)).mappings().all() + + if include_metadata: + metadata_results = (await conn.execute( + select([rule_metadata.c.rule, rule_metadata.c.key, rule_metadata.c.value]) + .where(rule_metadata.c.rule.in_(applicable_rules)) + .order_by(rule_metadata.c.rule) )).all() + metadata = { + rule_id: {k: v for (_, k, v) in rows} + for (rule_id, rows) in groupby(metadata_results, key=itemgetter(0)) + } + else: + metadata = None + + ret: Dict[str, List[InnerRuleSpec]] = {setting: [] for setting in setting_names} + for setting, rows in groupby(rule_results, key=itemgetter('setting')): + rule_list = [ + InnerRuleSpec( + orjson.loads(row['value']), + applicable_rules[row['id']], + metadata.get(row['id'], {}) if metadata is not None else None, + row['id'] + ) + for row in rows + ] + ret[setting] = rule_list + return ret + - return {context_feature: [rule_id for (_, rule_id) in rows] - for (context_feature, rows) in groupby(results, key=itemgetter(0))} +async def db_get_rules_for_setting(conn: AsyncConnection, setting_name: str) -> Sequence[BareRuleSpec]: + result = (await conn.execute(select([rules.c.id, rules.c.value]).where(rules.c.setting == setting_name))).all() + return [BareRuleSpec(id_, orjson.loads(v)) for id_, v in result] + + +async def db_get_actual_configurable_features(conn: AsyncConnection, setting_name: str) -> Dict[str, List[int]]: + """ + Get all the rules that are configured for each context feature for a setting + """ + results = (await conn.execute( + select([conditions.c.context_feature, rules.c.id]) + .select_from(rules.join(conditions, rules.c.id == conditions.c.rule)) + .where(rules.c.setting == setting_name) + .order_by(conditions.c.context_feature) + )).all() + + return {context_feature: [rule_id for (_, rule_id) in rows] + for (context_feature, rows) in groupby(results, key=itemgetter(0))} diff --git a/heksher/db_logic/rule_metadata.py b/heksher/db_logic/rule_metadata.py index f59c824..247fb52 100644 --- a/heksher/db_logic/rule_metadata.py +++ b/heksher/db_logic/rule_metadata.py @@ -2,76 +2,53 @@ from sqlalchemy import and_ from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.ext.asyncio import AsyncConnection -from heksher.db_logic.logic_base import DBLogicBase from heksher.db_logic.metadata import rule_metadata -class RuleMetadataMixin(DBLogicBase): - async def update_rule_metadata(self, rule_id: int, metadata: Dict[str, Any]): - """ - Update the metadata of the given rule. Similar to the dict.update() method, meaning that for existing keys, - the value will be updated, and new keys will be added as well. - Args: - rule_id: the id of the rule to update it's metadata. - metadata: the metadata to update. - """ - async with self.db_engine.begin() as conn: - stmt = insert(rule_metadata).values( - [{'rule': rule_id, 'key': k, 'value': v} for (k, v) in metadata.items()] - ) - await conn.execute(stmt.on_conflict_do_update(index_elements=[rule_metadata.c.rule, rule_metadata.c.key], - set_={"value": stmt.excluded.value})) - - async def replace_rule_metadata(self, rule_id: int, new_metadata: Dict[str, Any]): - """ - Replace the metadata of the given rule with new metadata. - Args: - rule_id: the id of the rule to change its metadata. - new_metadata: the new metadata for the rule. - """ - async with self.db_engine.begin() as conn: - await conn.execute(rule_metadata.delete() - .where(rule_metadata.c.rule == rule_id) - ) - await conn.execute( - rule_metadata.insert().values( - [{'rule': rule_id, 'key': k, 'value': v} for (k, v) in new_metadata.items()] - ) - ) - - async def update_rule_metadata_key(self, rule_id: int, key: str, new_value: Any): - """ - Updates a specific key of the rule's metadata. - Args: - rule_id: the id of the rule to change its metadata. - key: the key to update. - new_value: the value to update for the given key. - """ - async with self.db_engine.begin() as conn: - await conn.execute(insert(rule_metadata) - .values([{'rule': rule_id, 'key': key, 'value': new_value}]) - .on_conflict_do_update(index_elements=[rule_metadata.c.rule, - rule_metadata.c.key], - set_={"value": new_value})) - - async def delete_rule_metadata(self, rule_id: int): - """ - Remove a rule's metadata from the DB - Args: - rule_id: the id of the rule to remove its metadata - """ - async with self.db_engine.begin() as conn: - await conn.execute(rule_metadata.delete().where(rule_metadata.c.rule == rule_id)) - - async def delete_rule_metadata_key(self, rule_id: int, key: str): - """ - Remove a specific key from the rule's metadata - Args: - rule_id: the name of the rule - key: the name of the key to be deleted from the rule metadata - """ - async with self.db_engine.begin() as conn: - await conn.execute(rule_metadata.delete() - .where(and_(rule_metadata.c.rule == rule_id, - rule_metadata.c.key == key))) +async def db_update_rule_metadata(conn: AsyncConnection, rule_id: int, metadata: Dict[str, Any]): + """ + Update the metadata of the given rule. Similar to the dict.update() method, meaning that for existing keys, + the value will be updated, and new keys will be added as well. + """ + stmt = insert(rule_metadata).values( + [{'rule': rule_id, 'key': k, 'value': v} for (k, v) in metadata.items()] + ) + await conn.execute(stmt.on_conflict_do_update(index_elements=[rule_metadata.c.rule, rule_metadata.c.key], + set_={"value": stmt.excluded.value})) + + +async def db_replace_rule_metadata(conn: AsyncConnection, rule_id: int, new_metadata: Dict[str, Any]): + """ + Replace the metadata of the given rule with new metadata.. + """ + await conn.execute(rule_metadata.delete() + .where(rule_metadata.c.rule == rule_id) + ) + await conn.execute( + rule_metadata.insert().values( + [{'rule': rule_id, 'key': k, 'value': v} for (k, v) in new_metadata.items()] + ) + ) + + +async def db_update_rule_metadata_key(conn: AsyncConnection, rule_id: int, key: str, new_value: Any): + """ + Updates a specific key of the rule's metadata. + """ + await conn.execute(insert(rule_metadata) + .values([{'rule': rule_id, 'key': key, 'value': new_value}]) + .on_conflict_do_update(index_elements=[rule_metadata.c.rule, + rule_metadata.c.key], + set_={"value": new_value})) + + +async def db_delete_rule_metadata(conn: AsyncConnection, rule_id: int): + await conn.execute(rule_metadata.delete().where(rule_metadata.c.rule == rule_id)) + + +async def db_delete_rule_metadata_key(conn: AsyncConnection, rule_id: int, key: str): + await conn.execute(rule_metadata.delete() + .where(and_(rule_metadata.c.rule == rule_id, + rule_metadata.c.key == key))) diff --git a/heksher/db_logic/setting.py b/heksher/db_logic/setting.py index 83d6e8b..94a4fde 100644 --- a/heksher/db_logic/setting.py +++ b/heksher/db_logic/setting.py @@ -9,7 +9,6 @@ from sqlalchemy.dialects.postgresql import insert from sqlalchemy.ext.asyncio import AsyncConnection -from heksher.db_logic.logic_base import DBLogicBase from heksher.db_logic.metadata import configurable, context_features, setting_aliases, setting_metadata, settings from heksher.setting_types import SettingType, setting_type @@ -33,286 +32,270 @@ def default_value(self): return orjson.loads(self.raw_default_value) -class SettingMixin(DBLogicBase): - async def get_canonical_names(self, names_or_aliases: Iterable[str]) -> Dict[str, str]: - """ - Args: - names_or_aliases: an iterable of potential setting names/aliases. - - Returns: - A dictionary of the given names/aliases to their canonical names. - Note: For settings that do no exist, the canonical name will be None. - """ - names_table = values(column('n', String), name='names').data([(name,) for name in names_or_aliases]) - - async with self.db_engine.connect() as conn: - stmt = ( - select([names_table.c.n, settings.c.name]) - .select_from( - join(names_table, - join(settings, - setting_aliases, - settings.c.name == setting_aliases.c.setting, - isouter=True), - or_(settings.c.name == names_table.c.n, setting_aliases.c.alias == names_table.c.n), - isouter=True) - ) - .distinct() +async def db_get_canonical_names(conn: AsyncConnection, names_or_aliases: Iterable[str]) -> Dict[str, str]: + """ + Args: + conn: the database connection + names_or_aliases: an iterable of potential setting names/aliases. + + Returns: + A dictionary of the given names/aliases to their canonical names. + Note: For settings that do no exist, the canonical name will be None. + """ + names_table = values(column('n', String), name='names').data([(name,) for name in names_or_aliases]) + + stmt = ( + select([names_table.c.n, settings.c.name]) + .select_from( + join(names_table, + join(settings, + setting_aliases, + settings.c.name == setting_aliases.c.setting, + isouter=True), + or_(settings.c.name == names_table.c.n, setting_aliases.c.alias == names_table.c.n), + isouter=True) ) - results = (await conn.execute(stmt)).all() - return dict(results) - - async def get_setting(self, name_or_alias: str, *, include_metadata: bool, include_aliases: bool, - include_configurable_features: bool) -> Optional[SettingSpec]: - """ - Args: - name_or_alias: The name/alias of a setting - include_metadata: whether to include setting metadata - Returns: - The setting object for the setting in the DB with the same name, or None if it does not exist - - """ - async with self.db_engine.connect() as conn: - stmt = ( - select([settings.c.name, settings.c.type, settings.c.default_value, settings.c.version]) - .select_from( - join(settings, setting_aliases, settings.c.name == setting_aliases.c.setting, isouter=True) - ) - .where(or_(settings.c.name == name_or_alias, setting_aliases.c.alias == name_or_alias)) - .limit(1) + .distinct() + ) + results = (await conn.execute(stmt)).all() + return dict(results) + + +async def db_get_setting(conn: AsyncConnection, name_or_alias: str, *, include_metadata: bool, include_aliases: bool, + include_configurable_features: bool) -> Optional[SettingSpec]: + stmt = ( + select([settings.c.name, settings.c.type, settings.c.default_value, settings.c.version]) + .select_from( + join(settings, setting_aliases, settings.c.name == setting_aliases.c.setting, isouter=True) + ) + .where(or_(settings.c.name == name_or_alias, setting_aliases.c.alias == name_or_alias)) + .limit(1) + ) + data_row = (await conn.execute(stmt)).mappings().first() + if data_row is None: + return None + setting_name = data_row['name'] + if include_aliases: + stmt = ( + select([setting_aliases.c.alias]) + .select_from( + join(settings, setting_aliases, settings.c.name == setting_aliases.c.setting) ) - data_row = (await conn.execute(stmt)).mappings().first() - - if data_row is None: - return None - - setting_name = data_row['name'] - if include_aliases: - stmt = ( - select([setting_aliases.c.alias]) - .select_from( - join(settings, setting_aliases, settings.c.name == setting_aliases.c.setting) - ) - .where(settings.c.name == setting_name) - .order_by(setting_aliases.c.alias) - ) - aliases = (await conn.execute(stmt)).scalars().all() - else: - aliases = None - - if include_configurable_features: - stmt = ( - select([configurable.c.context_feature]) - .select_from( - join(configurable, context_features, configurable.c.context_feature == context_features.c.name) - ) - .where(configurable.c.setting == setting_name) - .order_by(context_features.c.index) - ) - configurable_features = (await conn.execute(stmt)).scalars().all() - else: - configurable_features = None - - if include_metadata: - stmt = ( - select([setting_metadata.c.key, setting_metadata.c.value]) - .where(setting_metadata.c.setting == setting_name) - ) - metadata_ = dict((await conn.execute(stmt)).all()) - else: - metadata_ = None - - return SettingSpec( - setting_name, - data_row['type'], - data_row['default_value'], - metadata_, - configurable_features, - aliases, - data_row['version'] + .where(settings.c.name == setting_name) + .order_by(setting_aliases.c.alias) ) + aliases = (await conn.execute(stmt)).scalars().all() + else: + aliases = None - async def add_setting(self, setting: SettingSpec): - """ - Add a setting to the DB - Args: - setting: data of the new setting to add. - """ - async with self.db_engine.begin() as conn: - await conn.execute( - settings.insert().values( - name=setting.name, - type=str(setting.type), - default_value=str(orjson.dumps(setting.default_value), 'utf-8'), - version=setting.version - ) + if include_configurable_features: + stmt = ( + select([configurable.c.context_feature]) + .select_from( + join(configurable, context_features, configurable.c.context_feature == context_features.c.name) ) - await conn.execute( - configurable.insert().values( - [{'setting': setting.name, 'context_feature': cf} for cf in setting.configurable_features] - ) - ) - if setting.metadata: - await conn.execute( - setting_metadata.insert().values( - [{'setting': setting.name, 'key': k, 'value': v} for (k, v) in setting.metadata.items()] - ) - ) - assert not setting.aliases # newly added settings can't have aliases - - async def update_setting(self, old_name: str, new_name: Optional[str], configurable_features: Optional[List[str]], - type: Optional[SettingType], default_value: Optional[Any], - metadata: Optional[Dict[str, Any]], version: str): - async with self.db_engine.begin() as conn: - if configurable_features is not None: - await conn.execute( - configurable.delete().where(configurable.c.setting == old_name) - ) - await conn.execute( - configurable.insert().values( - [{'setting': old_name, 'context_feature': cf} for cf in configurable_features] - ) - ) - if metadata is not None: - await conn.execute( - setting_metadata.delete().where(setting_metadata.c.setting == old_name) - ) - await conn.execute( - setting_metadata.insert().values( - [{'setting': old_name, 'key': k, 'value': v} for (k, v) in metadata.items()] - ) - ) - - # we change the row last, so that the other tables can still refer to the setting by it's old name - row_changes = {'version': version} - if new_name: - row_changes['name'] = new_name - if type: - row_changes['type'] = str(type) - if default_value: - row_changes['default_value'] = str(orjson.dumps(default_value), 'utf-8') - await conn.execute( - settings.update().where(settings.c.name == old_name).values(**row_changes) + .where(configurable.c.setting == setting_name) + .order_by(context_features.c.index) + ) + configurable_features = (await conn.execute(stmt)).scalars().all() + else: + configurable_features = None + + if include_metadata: + stmt = ( + select([setting_metadata.c.key, setting_metadata.c.value]) + .where(setting_metadata.c.setting == setting_name) + ) + metadata_ = dict((await conn.execute(stmt)).all()) + else: + metadata_ = None + + return SettingSpec( + setting_name, + data_row['type'], + data_row['default_value'], + metadata_, + configurable_features, + aliases, + data_row['version'] + ) + + +async def db_add_setting(conn: AsyncConnection, setting: SettingSpec): + await conn.execute( + settings.insert().values( + name=setting.name, + type=str(setting.type), + default_value=str(orjson.dumps(setting.default_value), 'utf-8'), + version=setting.version + ) + ) + await conn.execute( + configurable.insert().values( + [{'setting': setting.name, 'context_feature': cf} for cf in setting.configurable_features] + ) + ) + if setting.metadata: + await conn.execute( + setting_metadata.insert().values( + [{'setting': setting.name, 'key': k, 'value': v} for (k, v) in setting.metadata.items()] ) + ) + assert not setting.aliases # newly added settings can't have aliases - if new_name: - # remove the new name from the aliases table, if it exists - await conn.execute( - setting_aliases.delete().where(setting_aliases.c.alias == new_name) - ) - await conn.execute( - insert(setting_aliases).values( - [{'setting': new_name, 'alias': old_name}] - ) - ) - - async def delete_setting(self, name: str) -> bool: - """ - Remove a setting from the DB - Args: - name: the name of the setting to remove - Returns: - Whether a setting with the name was found - """ - async with self.db_engine.begin() as conn: - resp = (await conn.execute(settings.delete().where(settings.c.name == name))).rowcount - return resp == 1 - - async def get_all_settings(self, include_configurable_features: bool, include_metadata: bool, - include_aliases: bool) -> List[SettingSpec]: - """ - Returns: - A list of all setting specs in the DB - """ - select_query = select([settings.c.name, settings.c.type, settings.c.default_value, settings.c.version]) \ - .order_by(settings.c.name) - - async with self.db_engine.connect() as conn: - records = (await conn.execute(select_query)).mappings().all() - if include_configurable_features: - configurable_rows = (await conn.execute( - select([configurable.c.setting, configurable.c.context_feature]) - .select_from(join(configurable, context_features, - configurable.c.context_feature == context_features.c.name)) - .order_by(configurable.c.setting, context_features.c.index) - )).mappings().all() - else: - configurable_rows = None - if include_metadata: - metadata_rows = await conn.execute( - setting_metadata.select().order_by(setting_metadata.c.setting) - ) - else: - metadata_rows = None - if include_aliases: - alias_rows = await conn.execute( - setting_aliases.select().order_by(setting_aliases.c.setting) - ) - else: - alias_rows = None - - if include_configurable_features: - configurable_features = { - setting: [row['context_feature'] for row in rows] - for (setting, rows) in groupby(configurable_rows, key=itemgetter('setting')) - } - else: - configurable_features = None - if include_metadata: - metadata = { - setting: {k: v for (_, k, v) in rows} for (setting, rows) in groupby(metadata_rows, key=itemgetter(0)) - } - else: - metadata = None - if include_aliases: - aliases = { - setting: [v for (_, v) in rows] for (setting, rows) in groupby(alias_rows, key=itemgetter(0)) - } - else: - aliases = None - - return [ - SettingSpec( - row['name'], - row['type'], - row['default_value'], - metadata.get(row['name'], {}) if include_metadata else None, - configurable_features[row['name']] if include_configurable_features else None, - aliases.get(row['name'], []) if include_aliases else None, - row['version'] - ) for row in records - ] - - async def set_setting_type(self, setting_name: str, new_type: SettingType, new_version: str): - """ - Change the type of a setting. Does not check validity. - Args: - setting_name: the name of the setting - new_type: the new type of the setting - """ - async with self.db_engine.begin() as conn: - await conn.execute(settings.update().where(settings.c.name == setting_name) - .values(type=str(new_type), version=new_version)) - - async def rename_setting(self, old_name: str, new_name: str, new_version: str): - async with self.db_engine.begin() as conn: - # this should cascade through all other tables - await conn.execute( - settings.update() - .where(settings.c.name == old_name) - .values({"name": new_name, 'version': new_version}) + +async def db_update_setting(conn: AsyncConnection, old_name: str, new_name: Optional[str], + configurable_features: Optional[List[str]], type: Optional[SettingType], + default_value: Optional[Any], metadata: Optional[Dict[str, Any]], version: str): + if configurable_features is not None: + await conn.execute( + configurable.delete().where(configurable.c.setting == old_name) + ) + await conn.execute( + configurable.insert().values( + [{'setting': old_name, 'context_feature': cf} for cf in configurable_features] ) - # add the old name as an alias of the new one - await conn.execute( - insert(setting_aliases).values( - [{'setting': new_name, 'alias': old_name}] - ) + ) + if metadata is not None: + await conn.execute( + setting_metadata.delete().where(setting_metadata.c.setting == old_name) + ) + await conn.execute( + setting_metadata.insert().values( + [{'setting': old_name, 'key': k, 'value': v} for (k, v) in metadata.items()] ) - # in case that the new name is an old alias, we remove the old alias from the aliases table - await conn.execute( - delete(setting_aliases) - .where(setting_aliases.c.setting == new_name, setting_aliases.c.alias == new_name) + ) + + # we change the row last, so that the other tables can still refer to the setting by it's old name + row_changes = {'version': version} + if new_name: + row_changes['name'] = new_name + if type: + row_changes['type'] = str(type) + if default_value: + row_changes['default_value'] = str(orjson.dumps(default_value), 'utf-8') + await conn.execute( + settings.update().where(settings.c.name == old_name).values(**row_changes) + ) + + if new_name: + # remove the new name from the aliases table, if it exists + await conn.execute( + setting_aliases.delete().where(setting_aliases.c.alias == new_name) + ) + await conn.execute( + insert(setting_aliases).values( + [{'setting': new_name, 'alias': old_name}] ) + ) + + +async def db_delete_setting(conn: AsyncConnection, name: str) -> bool: + """ + Remove a setting from the DB + Returns: + Whether a setting with the name was found + """ + resp = (await conn.execute(settings.delete().where(settings.c.name == name))).rowcount + return resp == 1 + + +async def db_get_settings(conn: AsyncConnection, include_configurable_features: bool, include_metadata: bool, + include_aliases: bool, setting_names: Optional[Iterable[str]] = None)\ + -> Dict[str, SettingSpec]: + """ + Returns: + A list of all setting specs in the DB + """ + select_query = select([settings.c.name, settings.c.type, settings.c.default_value, settings.c.version]) \ + .order_by(settings.c.name) + + if setting_names: + select_query = select_query.where(settings.c.name.in_(setting_names)) + + records = (await conn.execute(select_query)).all() + if include_configurable_features: + configurable_rows = (await conn.execute( + select([configurable.c.setting, configurable.c.context_feature]) + .select_from(join(configurable, context_features, + configurable.c.context_feature == context_features.c.name)) + .order_by(configurable.c.setting, context_features.c.index) + )).mappings().all() + else: + configurable_rows = None + if include_metadata: + metadata_rows = await conn.execute( + setting_metadata.select().order_by(setting_metadata.c.setting) + ) + else: + metadata_rows = None + if include_aliases: + alias_rows = await conn.execute( + setting_aliases.select().order_by(setting_aliases.c.setting) + ) + else: + alias_rows = None + + if include_configurable_features: + configurable_features = { + setting: [row['context_feature'] for row in rows] + for (setting, rows) in groupby(configurable_rows, key=itemgetter('setting')) + } + else: + configurable_features = None + if include_metadata: + metadata = { + setting: {k: v for (_, k, v) in rows} for (setting, rows) in groupby(metadata_rows, key=itemgetter(0)) + } + else: + metadata = None + if include_aliases: + aliases = { + setting: [v for (_, v) in rows] for (setting, rows) in groupby(alias_rows, key=itemgetter(0)) + } + else: + aliases = None + + return { + name: SettingSpec( + name, + raw_type, + default_value, + metadata.get(name, {}) if include_metadata else None, + configurable_features[name] if include_configurable_features else None, + aliases.get(name, []) if include_aliases else None, + version + ) for (name, raw_type, default_value, version) in records + } + + +async def db_set_setting_type(conn: AsyncConnection, setting_name: str, new_type: SettingType, new_version: str): + """ + Change the type of a setting. Does not check validity. + """ + await conn.execute(settings.update().where(settings.c.name == setting_name) + .values(type=str(new_type), version=new_version)) + + +async def db_rename_setting(conn: AsyncConnection, old_name: str, new_name: str, new_version: str): + # this should cascade through all other tables + await conn.execute( + settings.update() + .where(settings.c.name == old_name) + .values({"name": new_name, 'version': new_version}) + ) + # add the old name as an alias of the new one + await conn.execute( + insert(setting_aliases).values( + [{'setting': new_name, 'alias': old_name}] + ) + ) + # in case that the new name is an old alias, we remove the old alias from the aliases table + await conn.execute( + delete(setting_aliases) + .where(setting_aliases.c.setting == new_name, setting_aliases.c.alias == new_name) + ) + - async def bump_setting_version(self, conn: AsyncConnection, setting_name: str, new_version: str): - await conn.execute(settings.update().where(settings.c.name == setting_name).values(version=new_version)) +async def db_bump_setting_version(conn: AsyncConnection, setting_name: str, new_version: str): + await conn.execute(settings.update().where(settings.c.name == setting_name).values(version=new_version)) diff --git a/heksher/db_logic/setting_configurable_features.py b/heksher/db_logic/setting_configurable_features.py index 0f3be7f..c5e403e 100644 --- a/heksher/db_logic/setting_configurable_features.py +++ b/heksher/db_logic/setting_configurable_features.py @@ -5,8 +5,8 @@ from heksher.db_logic.metadata import configurable, settings -async def set_settings_configurable_features(conn: AsyncConnection, setting_name: str, configurable_features: List[str], - version: str): +async def db_set_settings_configurable_features(conn: AsyncConnection, setting_name: str, + configurable_features: List[str], version: str): await conn.execute(configurable.delete().where(configurable.c.setting == setting_name)) await conn.execute(configurable.insert().values([ {'setting': setting_name, 'context_feature': cf} for cf in configurable_features])) diff --git a/heksher/db_logic/setting_metadata.py b/heksher/db_logic/setting_metadata.py index 174e6d0..a64d86f 100644 --- a/heksher/db_logic/setting_metadata.py +++ b/heksher/db_logic/setting_metadata.py @@ -2,53 +2,52 @@ from sqlalchemy import and_ from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.ext.asyncio import AsyncConnection -from heksher.db_logic.logic_base import DBLogicBase from heksher.db_logic.metadata import setting_metadata +from heksher.db_logic.setting import db_bump_setting_version -class SettingMetadataMixin(DBLogicBase): - async def update_setting_metadata(self, name: str, metadata: Dict[str, Any], new_version: str): - async with self.db_engine.begin() as conn: - if metadata: - stmt = insert(setting_metadata).values( - [{'setting': name, 'key': k, 'value': v} for (k, v) in metadata.items()] - ) - await conn.execute(stmt.on_conflict_do_update(index_elements=[setting_metadata.c.setting, - setting_metadata.c.key], - set_={"value": stmt.excluded.value})) - await self.bump_setting_version(conn, name, new_version) - - async def replace_setting_metadata(self, name: str, new_metadata: Dict[str, Any], new_version: str): - async with self.db_engine.begin() as conn: - await conn.execute(setting_metadata.delete() - .where(setting_metadata.c.setting == name) - ) - if new_metadata: - await conn.execute( - setting_metadata.insert().values( - [{'setting': name, 'key': k, 'value': v} for (k, v) in new_metadata.items()] - ) - ) - await self.bump_setting_version(conn, name, new_version) - - async def update_setting_metadata_key(self, name: str, key: str, new_value: Any, new_version: str): - async with self.db_engine.begin() as conn: - await conn.execute(insert(setting_metadata) - .values([{'setting': name, 'key': key, 'value': new_value}]) - .on_conflict_do_update(index_elements=[setting_metadata.c.setting, +async def db_update_setting_metadata(conn: AsyncConnection, name: str, metadata: Dict[str, Any], new_version: str): + if metadata: + stmt = insert(setting_metadata).values( + [{'setting': name, 'key': k, 'value': v} for (k, v) in metadata.items()] + ) + await conn.execute(stmt.on_conflict_do_update(index_elements=[setting_metadata.c.setting, setting_metadata.c.key], - set_={"value": new_value})) - await self.bump_setting_version(conn, name, new_version) - - async def delete_setting_metadata(self, name: str, new_version: str): - async with self.db_engine.begin() as conn: - await conn.execute(setting_metadata.delete().where(setting_metadata.c.setting == name)) - await self.bump_setting_version(conn, name, new_version) - - async def delete_setting_metadata_key(self, name: str, key: str, new_version: str): - async with self.db_engine.begin() as conn: - await conn.execute(setting_metadata.delete() - .where(and_(setting_metadata.c.setting == name, - setting_metadata.c.key == key))) - await self.bump_setting_version(conn, name, new_version) + set_={"value": stmt.excluded.value})) + await db_bump_setting_version(conn, name, new_version) + + +async def db_replace_setting_metadata(conn: AsyncConnection, name: str, new_metadata: Dict[str, Any], new_version: str): + await conn.execute(setting_metadata.delete() + .where(setting_metadata.c.setting == name) + ) + if new_metadata: + await conn.execute( + setting_metadata.insert().values( + [{'setting': name, 'key': k, 'value': v} for (k, v) in new_metadata.items()] + ) + ) + await db_bump_setting_version(conn, name, new_version) + + +async def db_update_setting_metadata_key(conn: AsyncConnection, name: str, key: str, new_value: Any, new_version: str): + await conn.execute(insert(setting_metadata) + .values([{'setting': name, 'key': key, 'value': new_value}]) + .on_conflict_do_update(index_elements=[setting_metadata.c.setting, + setting_metadata.c.key], + set_={"value": new_value})) + await db_bump_setting_version(conn, name, new_version) + + +async def db_delete_setting_metadata(conn: AsyncConnection, name: str, new_version: str): + await conn.execute(setting_metadata.delete().where(setting_metadata.c.setting == name)) + await db_bump_setting_version(conn, name, new_version) + + +async def db_delete_setting_metadata_key(conn: AsyncConnection, name: str, key: str, new_version: str): + await conn.execute(setting_metadata.delete() + .where(and_(setting_metadata.c.setting == name, + setting_metadata.c.key == key))) + await db_bump_setting_version(conn, name, new_version) diff --git a/pyproject.toml b/pyproject.toml index e31b7f6..83bc0a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "Heksher" -version = "0.4.1" +version = "0.5.0" description = "Heksker" authors = ["Biocatch LTD "] diff --git a/tests/blackbox/app/test_v1api_aliases.py b/tests/blackbox/app/test_v1api_aliases.py index 30beb6f..8a1561b 100644 --- a/tests/blackbox/app/test_v1api_aliases.py +++ b/tests/blackbox/app/test_v1api_aliases.py @@ -71,7 +71,7 @@ async def search(setting: str, theme: str): })) async def query(*settings: str): - return _get_ok_data(await app_client.get('/api/v1/rules/query', query_string={ + return _get_ok_data(await app_client.get('/api/v1/query', query_string={ 'settings': ','.join(settings), 'context_filters': "*", })) @@ -100,8 +100,10 @@ async def query(*settings: str): await query("cat", "kelev", "yanshuf") assert (await query("cat", "kelev"))['settings'] == { 'hatul': {'rules': [{'value': 10, 'context_features': [['theme', 'bright']], 'rule_id': cat_rule}, - {'value': 10, 'context_features': [['theme', 'dark']], 'rule_id': hatul_rule}]}, - 'kelev': {'rules': [{'value': 10, 'context_features': [['theme', 'dracula']], 'rule_id': kelev_rule}]} + {'value': 10, 'context_features': [['theme', 'dark']], 'rule_id': hatul_rule}], + 'default_value': 5}, + 'kelev': {'rules': [{'value': 10, 'context_features': [['theme', 'dracula']], 'rule_id': kelev_rule}], + 'default_value': 5} } diff --git a/tests/blackbox/app/test_v1api_query.py b/tests/blackbox/app/test_v1api_query.py new file mode 100644 index 0000000..4ac8325 --- /dev/null +++ b/tests/blackbox/app/test_v1api_query.py @@ -0,0 +1,487 @@ +import json +from itertools import chain + +from pytest import fixture, mark + + +@fixture +def mk_setting(app_client): + async def mk_setting(name: str): + res = await app_client.post('/api/v1/settings/declare', data=json.dumps({ + 'name': name, + 'configurable_features': ['theme', 'trust', 'user'], + 'type': 'int', + 'default_value': 0, + })) + res.raise_for_status() + assert res.json() == {'outcome': 'created'} + + return mk_setting + + +@fixture +def mk_rule(app_client): + async def mk_rule(setting_name, features, val): + res = await app_client.post('/api/v1/rules', data=json.dumps({ + 'setting': setting_name, + 'feature_values': features, + 'value': val, + 'metadata': {'test': 'yes'} + })) + res.raise_for_status() + assert res.json().keys() == {'rule_id'} + + return mk_rule + + +@fixture +async def setup_rules(mk_setting, mk_rule): + await mk_setting('a') + await mk_setting('long_setting_name') + await mk_setting('b') + + await mk_rule('a', {'trust': 'full'}, 1) + await mk_rule('a', {'theme': 'black'}, 2) + await mk_rule('a', {'theme': 'black', 'trust': 'full'}, 3) + await mk_rule('long_setting_name', {'trust': 'none'}, 4) + await mk_rule('long_setting_name', {'trust': 'part'}, 5) + await mk_rule('b', {'trust': 'full'}, 6) + await mk_rule('a', {'theme': 'black', 'user': 'admin'}, 7) + + +def patch_rule_expectation_with_metadata(expected_rules): + for rule in chain.from_iterable(s['rules'] for s in expected_rules['settings'].values()): + rule['metadata'] = {'test': 'yes'} + + +@mark.asyncio +@mark.parametrize('metadata', [False, True]) +async def test_query_rules(metadata: bool, app_client, setup_rules): + res = await app_client.get('/api/v1/query', query_string={ + 'settings': 'a,long_setting_name', + 'context_filters': 'trust:(full,part),theme:(black)', + 'include_metadata': str(metadata) + }) + + expected = { + 'settings': { + 'a': {'rules': [ + {'context_features': [['trust', 'full']], 'value': 1, 'rule_id': 1}, + {'context_features': [['theme', 'black']], 'value': 2, 'rule_id': 2}, + {'context_features': [['trust', 'full'], ['theme', 'black']], 'value': 3, 'rule_id': 3} + ], 'default_value': 0}, + 'long_setting_name': {'rules': [ + {'context_features': [['trust', 'part']], 'value': 5, 'rule_id': 5} + ], 'default_value': 0} + } + } + if metadata: + patch_rule_expectation_with_metadata(expected) + + assert res.json() == expected + + +@mark.asyncio +@mark.parametrize('metadata', [False, True]) +async def test_query_rules_with_empty(metadata: bool, app_client, setup_rules, sql_service): + with sql_service.connection() as connection: + connection.execute(""" + INSERT INTO rules (setting, value) VALUES ('long_setting_name', '10') + """) + connection.execute(""" + INSERT INTO rule_metadata (rule, key, value) VALUES (8, 'test', '"yes"') + """) + + res = await app_client.get('/api/v1/query', query_string={ + 'settings': 'a,long_setting_name', + 'context_filters': 'trust:(full,part),theme:(black)', + 'include_metadata': str(metadata) + }) + + expected = { + 'settings': { + 'a': {'rules': [ + {'context_features': [['trust', 'full']], 'rule_id': 1, 'value': 1}, + {'context_features': [['theme', 'black']], 'rule_id': 2, 'value': 2}, + {'context_features': [['trust', 'full'], ['theme', 'black']], 'rule_id': 3, 'value': 3} + ], 'default_value': 0}, + 'long_setting_name': {'rules': [ + {'context_features': [['trust', 'part']], 'rule_id': 5, 'value': 5}, + {'context_features': [], 'rule_id': 8, 'value': 10} + ], 'default_value': 0} + } + } + if metadata: + patch_rule_expectation_with_metadata(expected) + + assert res.json() == expected + + +@mark.asyncio +@mark.parametrize('metadata', [False, True]) +async def test_query_rules_nooptions(metadata: bool, app_client, setup_rules): + res = await app_client.get('/api/v1/query', query_string={ + 'settings': 'a', + 'context_filters': '', + 'include_metadata': metadata + }) + res.raise_for_status() + + expected = { + 'settings': { + 'a': {'rules': [], 'default_value': 0} + } + } + + assert res.json() == expected + + +@mark.asyncio +@mark.parametrize('metadata', [False, True]) +async def test_query_rules_nooptions_with_matchall(metadata: bool, app_client, setup_rules, sql_service): + with sql_service.connection() as connection: + connection.execute(""" + INSERT INTO rules (setting, value) VALUES ('long_setting_name', '10') + """) + connection.execute(""" + INSERT INTO rule_metadata (rule, key, value) VALUES (8, 'test', '"yes"') + """) + + res = await app_client.get('/api/v1/query', query_string={ + 'settings': 'a,long_setting_name', + 'context_filters': '', + 'include_metadata': str(metadata) + }) + + expected = { + 'settings': { + 'a': {'rules': [], 'default_value': 0}, + 'long_setting_name': {'rules': [ + {'context_features': [], 'value': 10, 'rule_id': 8}, + ], 'default_value': 0} + } + } + + if metadata: + patch_rule_expectation_with_metadata(expected) + + assert res.json() == expected + + +@mark.asyncio +@mark.parametrize('metadata', [False, True]) +async def test_query_rules_matchall(metadata: bool, app_client, setup_rules): + res = await app_client.get('/api/v1/query', query_string={ + 'settings': 'a,long_setting_name', + 'context_filters': '*', + 'include_metadata': str(metadata) + }) + + expected = { + 'settings': { + 'a': {'rules': [ + {'context_features': [['trust', 'full']], 'value': 1, 'rule_id': 1}, + {'context_features': [['theme', 'black']], 'value': 2, 'rule_id': 2}, + {'context_features': [['trust', 'full'], ['theme', 'black']], 'rule_id': 3, 'value': 3}, + {'context_features': [['user', 'admin'], ['theme', 'black']], 'rule_id': 7, 'value': 7} + ], 'default_value': 0}, + 'long_setting_name': {'rules': [ + {'context_features': [['trust', 'none']], 'rule_id': 4, 'value': 4}, + {'context_features': [['trust', 'part']], 'rule_id': 5, 'value': 5} + ], 'default_value': 0} + } + } + if metadata: + patch_rule_expectation_with_metadata(expected) + + assert res.json() == expected + + +@mark.asyncio +@mark.parametrize('metadata', [False, True]) +async def test_query_rules_wildcard_some(metadata: bool, app_client, setup_rules): + res = await app_client.get('/api/v1/query', query_string={ + 'settings': 'a,long_setting_name', + 'context_filters': 'theme:*,trust:(full,none)', + 'include_metadata': str(metadata) + }) + + expected = { + 'settings': { + 'a': {'rules': [ + {'context_features': [['trust', 'full']], 'value': 1, 'rule_id': 1}, + {'context_features': [['theme', 'black']], 'value': 2, 'rule_id': 2}, + {'context_features': [['trust', 'full'], ['theme', 'black']], 'value': 3, 'rule_id': 3}, + ], 'default_value': 0}, + 'long_setting_name': {'rules': [ + {'context_features': [['trust', 'none']], 'value': 4, 'rule_id': 4}, + ], 'default_value': 0} + } + } + if metadata: + patch_rule_expectation_with_metadata(expected) + + assert res.json() == expected + + +@mark.asyncio +@mark.parametrize('metadata', [False, True]) +async def test_query_rules_wildcard_only(metadata: bool, app_client, setup_rules): + res = await app_client.get('/api/v1/query', query_string={ + 'settings': 'a,long_setting_name', + 'context_filters': 'theme:*', + 'include_metadata': str(metadata) + }) + + expected = { + 'settings': { + 'a': {'rules': [ + {'context_features': [['theme', 'black']], 'value': 2, 'rule_id': 2}, + ], 'default_value': 0}, + 'long_setting_name': {'rules': [], 'default_value': 0} + } + } + if metadata: + patch_rule_expectation_with_metadata(expected) + + assert res.json() == expected + + +@mark.asyncio +async def test_query_rules_bad_contexts(app_client, setup_rules): + res = await app_client.get('/api/v1/query', query_string={ + 'settings': 'a,long_setting_name', + 'context_filters': 'theme:(black),trust:(full,part),love:(overflowing)' + }) + assert res.status_code == 404 + + +@mark.asyncio +async def test_query_rules_empty_contexts(app_client, setup_rules): + res = await app_client.get('/api/v1/query', query_string={ + 'settings': 'a,long_setting_name', + 'context_filters': 'trust:()' + }) + assert res.status_code == 422 + + +@mark.asyncio +async def test_query_rules_bad_settings(app_client, setup_rules): + res = await app_client.get('/api/v1/query', query_string={ + 'settings': 'a,d', + 'context_filters': 'trust:(full,part),theme:(black)' + }) + assert res.status_code == 404 + + +@mark.asyncio +@mark.parametrize('options', ['null', '**', 'wildcard']) +async def test_query_rules_bad_options(app_client, setup_rules, options): + res = await app_client.get('/api/v1/query', query_string={ + 'settings': 'a', + 'context_filters': options, + }) + assert res.status_code == 422, res.content + + +@mark.asyncio +@mark.parametrize('metadata', [False, True]) +async def test_query_rules_nosettings(metadata: bool, app_client, setup_rules): + res = await app_client.get('/api/v1/query', query_string={ + 'settings': '', + 'include_metadata': str(metadata) + }) + res.raise_for_status() + + expected = {'settings': {}} + + assert res.json() == expected + + +@mark.asyncio +@mark.parametrize('metadata', [False, True]) +async def test_query_rules_allsettings_no_filter(metadata: bool, app_client, setup_rules): + res = await app_client.get('/api/v1/query', query_string={ + 'include_metadata': str(metadata) + }) + res.raise_for_status() + + expected = {'settings': { + 'a': {'rules': [ + {'context_features': [['trust', 'full']], 'value': 1, 'rule_id': 1}, + {'context_features': [['theme', 'black']], 'value': 2, 'rule_id': 2}, + {'context_features': [['trust', 'full'], ['theme', 'black']], 'value': 3, 'rule_id': 3}, + {'context_features': [['user', 'admin'], ['theme', 'black']], 'value': 7, 'rule_id': 7}, + ], 'default_value': 0}, + 'b': {'rules': [ + {'context_features': [['trust', 'full']], 'value': 6, 'rule_id': 6}, + ], 'default_value': 0}, + 'long_setting_name': {'rules': [ + {'context_features': [['trust', 'none']], 'value': 4, 'rule_id': 4}, + {'context_features': [['trust', 'part']], 'value': 5, 'rule_id': 5}, + ], 'default_value': 0} + }} + + if metadata: + patch_rule_expectation_with_metadata(expected) + + assert res.json() == expected + + +@mark.asyncio +@mark.parametrize('metadata', [False, True]) +async def test_query_rules_allsettings_with_filter(metadata: bool, app_client, setup_rules): + res = await app_client.get('/api/v1/query', query_string={ + 'context_filters': 'user:*,trust:(full),theme:(blue)', + 'include_metadata': str(metadata) + }) + res.raise_for_status() + + expected = {'settings': { + 'a': {'rules': [ + {'context_features': [['trust', 'full']], 'value': 1, 'rule_id': 1}, + ], 'default_value': 0}, + 'b': {'rules': [ + {'context_features': [['trust', 'full']], 'value': 6, 'rule_id': 6}, + ], 'default_value': 0}, + 'long_setting_name': {'rules': [], 'default_value': 0} + }} + + if metadata: + patch_rule_expectation_with_metadata(expected) + + assert res.json() == expected + + +@mark.asyncio +@mark.parametrize('metadata', [False, True]) +async def test_query_rules_change_default(metadata: bool, app_client, setup_rules): + res = await app_client.post('/api/v1/settings/declare', json={ + 'name': 'a', + 'default_value': 1, + 'configurable_features': ['theme', 'trust', 'user'], + 'type': 'int', + 'version': '1.1' + }) + res.raise_for_status() + + res = await app_client.get('/api/v1/query', query_string={ + 'settings': 'a,long_setting_name', + 'context_filters': 'trust:(full,part),theme:(black)', + 'include_metadata': str(metadata) + }) + + expected = { + 'settings': { + 'a': {'rules': [ + {'context_features': [['trust', 'full']], 'value': 1, 'rule_id': 1}, + {'context_features': [['theme', 'black']], 'value': 2, 'rule_id': 2}, + {'context_features': [['trust', 'full'], ['theme', 'black']], 'value': 3, 'rule_id': 3} + ], 'default_value': 1}, + 'long_setting_name': {'rules': [ + {'context_features': [['trust', 'part']], 'value': 5, 'rule_id': 5} + ], 'default_value': 0} + } + } + if metadata: + patch_rule_expectation_with_metadata(expected) + + assert res.json() == expected + + +@mark.asyncio +@mark.parametrize('metadata', [False, True]) +async def test_query_rules_change_default_all_settings(metadata: bool, app_client, setup_rules): + res = await app_client.post('/api/v1/settings/declare', json={ + 'name': 'a', + 'default_value': 1, + 'configurable_features': ['theme', 'trust', 'user'], + 'type': 'int', + 'version': '1.1' + }) + res.raise_for_status() + + res = await app_client.get('/api/v1/query', query_string={ + 'context_filters': 'trust:(full,part),theme:(black)', + 'include_metadata': str(metadata) + }) + + expected = { + 'settings': { + 'a': {'rules': [ + {'context_features': [['trust', 'full']], 'value': 1, 'rule_id': 1}, + {'context_features': [['theme', 'black']], 'value': 2, 'rule_id': 2}, + {'context_features': [['trust', 'full'], ['theme', 'black']], 'value': 3, 'rule_id': 3} + ], 'default_value': 1}, + 'b': {'rules': [ + {'context_features': [['trust', 'full']], 'value': 6, 'rule_id': 6}, + ], 'default_value': 0}, + 'long_setting_name': {'rules': [ + {'context_features': [['trust', 'part']], 'value': 5, 'rule_id': 5} + ], 'default_value': 0} + } + } + if metadata: + patch_rule_expectation_with_metadata(expected) + + assert res.json() == expected + + +@mark.asyncio +@mark.parametrize('metadata', [False, True]) +async def test_query_rules_change_default_all_settings_no_rules(metadata: bool, app_client, setup_rules): + res = await app_client.post('/api/v1/settings/declare', json={ + 'name': 'a', + 'default_value': 1, + 'configurable_features': ['theme', 'trust', 'user'], + 'type': 'int', + 'version': '1.1' + }) + res.raise_for_status() + + res = await app_client.get('/api/v1/query', query_string={ + 'context_filters': 'trust:(fleeting)', + 'include_metadata': str(metadata) + }) + + expected = { + 'settings': { + 'a': {'rules': [], 'default_value': 1}, + 'b': {'rules': [], 'default_value': 0}, + 'long_setting_name': {'rules': [], 'default_value': 0} + } + } + if metadata: + patch_rule_expectation_with_metadata(expected) + + assert res.json() == expected + + +@mark.asyncio +@mark.parametrize('metadata', [False, True]) +async def test_query_rules_change_default_no_rules(metadata: bool, app_client, setup_rules): + res = await app_client.post('/api/v1/settings/declare', json={ + 'name': 'a', + 'default_value': 1, + 'configurable_features': ['theme', 'trust', 'user'], + 'type': 'int', + 'version': '1.1' + }) + res.raise_for_status() + + res = await app_client.get('/api/v1/query', query_string={ + 'settings': 'a,long_setting_name', + 'context_filters': 'trust:(fleeting)', + 'include_metadata': str(metadata) + }) + + expected = { + 'settings': { + 'a': {'rules': [], 'default_value': 1}, + 'long_setting_name': {'rules': [], 'default_value': 0} + } + } + if metadata: + patch_rule_expectation_with_metadata(expected) + + assert res.json() == expected diff --git a/tests/blackbox/app/test_v1api_rules.py b/tests/blackbox/app/test_v1api_rules.py index 6cea505..286f038 100644 --- a/tests/blackbox/app/test_v1api_rules.py +++ b/tests/blackbox/app/test_v1api_rules.py @@ -1,5 +1,4 @@ import json -from itertools import chain from pytest import fixture, mark @@ -190,310 +189,6 @@ async def setup_rules(mk_setting, mk_rule): await mk_rule('a', {'theme': 'black', 'user': 'admin'}, 7) -def patch_rule_expectation_with_metadata(expected_rules): - for rule in chain.from_iterable(s['rules'] for s in expected_rules['settings'].values()): - rule['metadata'] = {'test': 'yes'} - - -@mark.asyncio -@mark.parametrize('metadata', [False, True]) -async def test_query_rules(metadata: bool, app_client, setup_rules): - res = await app_client.get('/api/v1/rules/query', query_string={ - 'settings': 'a,long_setting_name', - 'context_filters': 'trust:(full,part),theme:(black)', - 'include_metadata': str(metadata) - }) - - expected = { - 'settings': { - 'a': {'rules': [ - {'context_features': [['trust', 'full']], 'value': 1, 'rule_id': 1}, - {'context_features': [['theme', 'black']], 'value': 2, 'rule_id': 2}, - {'context_features': [['trust', 'full'], ['theme', 'black']], 'value': 3, 'rule_id': 3} - ]}, - 'long_setting_name': {'rules': [ - {'context_features': [['trust', 'part']], 'value': 5, 'rule_id': 5} - ]} - } - } - if metadata: - patch_rule_expectation_with_metadata(expected) - - assert res.json() == expected - - -@mark.asyncio -@mark.parametrize('metadata', [False, True]) -async def test_query_rules_with_empty(metadata: bool, app_client, setup_rules, sql_service): - with sql_service.connection() as connection: - connection.execute(""" - INSERT INTO rules (setting, value) VALUES ('long_setting_name', '10') - """) - connection.execute(""" - INSERT INTO rule_metadata (rule, key, value) VALUES (8, 'test', '"yes"') - """) - - res = await app_client.get('/api/v1/rules/query', query_string={ - 'settings': 'a,long_setting_name', - 'context_filters': 'trust:(full,part),theme:(black)', - 'include_metadata': str(metadata) - }) - - expected = { - 'settings': { - 'a': {'rules': [ - {'context_features': [['trust', 'full']], 'rule_id': 1, 'value': 1}, - {'context_features': [['theme', 'black']], 'rule_id': 2, 'value': 2}, - {'context_features': [['trust', 'full'], ['theme', 'black']], 'rule_id': 3, 'value': 3} - ]}, - 'long_setting_name': {'rules': [ - {'context_features': [['trust', 'part']], 'rule_id': 5, 'value': 5}, - {'context_features': [], 'rule_id': 8, 'value': 10} - ]} - } - } - if metadata: - patch_rule_expectation_with_metadata(expected) - - assert res.json() == expected - - -@mark.asyncio -@mark.parametrize('metadata', [False, True]) -async def test_query_rules_nooptions(metadata: bool, app_client, setup_rules): - res = await app_client.get('/api/v1/rules/query', query_string={ - 'settings': 'a', - 'context_filters': '', - 'include_metadata': metadata - }) - res.raise_for_status() - - expected = { - 'settings': { - 'a': {'rules': []} - } - } - - assert res.json() == expected - - -@mark.asyncio -@mark.parametrize('metadata', [False, True]) -async def test_query_rules_nooptions_with_matchall(metadata: bool, app_client, setup_rules, sql_service): - with sql_service.connection() as connection: - connection.execute(""" - INSERT INTO rules (setting, value) VALUES ('long_setting_name', '10') - """) - connection.execute(""" - INSERT INTO rule_metadata (rule, key, value) VALUES (8, 'test', '"yes"') - """) - - res = await app_client.get('/api/v1/rules/query', query_string={ - 'settings': 'a,long_setting_name', - 'context_filters': '', - 'include_metadata': str(metadata) - }) - - expected = { - 'settings': { - 'a': {'rules': []}, - 'long_setting_name': {'rules': [ - {'context_features': [], 'value': 10, 'rule_id': 8}, - ]} - } - } - - if metadata: - patch_rule_expectation_with_metadata(expected) - - assert res.json() == expected - - -@mark.asyncio -@mark.parametrize('metadata', [False, True]) -async def test_query_rules_matchall(metadata: bool, app_client, setup_rules): - res = await app_client.get('/api/v1/rules/query', query_string={ - 'settings': 'a,long_setting_name', - 'context_filters': '*', - 'include_metadata': str(metadata) - }) - - expected = { - 'settings': { - 'a': {'rules': [ - {'context_features': [['trust', 'full']], 'value': 1, 'rule_id': 1}, - {'context_features': [['theme', 'black']], 'value': 2, 'rule_id': 2}, - {'context_features': [['trust', 'full'], ['theme', 'black']], 'rule_id': 3, 'value': 3}, - {'context_features': [['user', 'admin'], ['theme', 'black']], 'rule_id': 7, 'value': 7} - ]}, - 'long_setting_name': {'rules': [ - {'context_features': [['trust', 'none']], 'rule_id': 4, 'value': 4}, - {'context_features': [['trust', 'part']], 'rule_id': 5, 'value': 5} - ]} - } - } - if metadata: - patch_rule_expectation_with_metadata(expected) - - assert res.json() == expected - - -@mark.asyncio -@mark.parametrize('metadata', [False, True]) -async def test_query_rules_wildcard_some(metadata: bool, app_client, setup_rules): - res = await app_client.get('/api/v1/rules/query', query_string={ - 'settings': 'a,long_setting_name', - 'context_filters': 'theme:*,trust:(full,none)', - 'include_metadata': str(metadata) - }) - - expected = { - 'settings': { - 'a': {'rules': [ - {'context_features': [['trust', 'full']], 'value': 1, 'rule_id': 1}, - {'context_features': [['theme', 'black']], 'value': 2, 'rule_id': 2}, - {'context_features': [['trust', 'full'], ['theme', 'black']], 'value': 3, 'rule_id': 3}, - ]}, - 'long_setting_name': {'rules': [ - {'context_features': [['trust', 'none']], 'value': 4, 'rule_id': 4}, - ]} - } - } - if metadata: - patch_rule_expectation_with_metadata(expected) - - assert res.json() == expected - - -@mark.asyncio -@mark.parametrize('metadata', [False, True]) -async def test_query_rules_wildcard_only(metadata: bool, app_client, setup_rules): - res = await app_client.get('/api/v1/rules/query', query_string={ - 'settings': 'a,long_setting_name', - 'context_filters': 'theme:*', - 'include_metadata': str(metadata) - }) - - expected = { - 'settings': { - 'a': {'rules': [ - {'context_features': [['theme', 'black']], 'value': 2, 'rule_id': 2}, - ]}, - 'long_setting_name': {'rules': []} - } - } - if metadata: - patch_rule_expectation_with_metadata(expected) - - assert res.json() == expected - - -@mark.asyncio -async def test_query_rules_bad_contexts(app_client, setup_rules): - res = await app_client.get('/api/v1/rules/query', query_string={ - 'settings': 'a,long_setting_name', - 'context_filters': 'theme:(black),trust:(full,part),love:(overflowing)' - }) - assert res.status_code == 404 - - -@mark.asyncio -async def test_query_rules_empty_contexts(app_client, setup_rules): - res = await app_client.get('/api/v1/rules/query', query_string={ - 'settings': 'a,long_setting_name', - 'context_filters': 'trust:()' - }) - assert res.status_code == 422 - - -@mark.asyncio -async def test_query_rules_bad_settings(app_client, setup_rules): - res = await app_client.get('/api/v1/rules/query', query_string={ - 'settings': 'a,d', - 'context_filters': 'trust:(full,part),theme:(black)' - }) - assert res.status_code == 404 - - -@mark.asyncio -@mark.parametrize('options', ['null', '**', 'wildcard']) -async def test_query_rules_bad_options(app_client, setup_rules, options): - res = await app_client.get('/api/v1/rules/query', query_string={ - 'settings': 'a', - 'context_filters': options, - }) - assert res.status_code == 422, res.content - - -@mark.asyncio -@mark.parametrize('metadata', [False, True]) -async def test_query_rules_nosettings(metadata: bool, app_client, setup_rules): - res = await app_client.get('/api/v1/rules/query', query_string={ - 'settings': '', - 'include_metadata': str(metadata) - }) - res.raise_for_status() - - expected = {'settings': {}} - - assert res.json() == expected - - -@mark.asyncio -@mark.parametrize('metadata', [False, True]) -async def test_query_rules_allsettings_no_filter(metadata: bool, app_client, setup_rules): - res = await app_client.get('/api/v1/rules/query', query_string={ - 'include_metadata': str(metadata) - }) - res.raise_for_status() - - expected = {'settings': { - 'a': {'rules': [ - {'context_features': [['trust', 'full']], 'value': 1, 'rule_id': 1}, - {'context_features': [['theme', 'black']], 'value': 2, 'rule_id': 2}, - {'context_features': [['trust', 'full'], ['theme', 'black']], 'value': 3, 'rule_id': 3}, - {'context_features': [['user', 'admin'], ['theme', 'black']], 'value': 7, 'rule_id': 7}, - ]}, - 'b': {'rules': [ - {'context_features': [['trust', 'full']], 'value': 6, 'rule_id': 6}, - ]}, - 'long_setting_name': {'rules': [ - {'context_features': [['trust', 'none']], 'value': 4, 'rule_id': 4}, - {'context_features': [['trust', 'part']], 'value': 5, 'rule_id': 5}, - ]} - }} - - if metadata: - patch_rule_expectation_with_metadata(expected) - - assert res.json() == expected - - -@mark.asyncio -@mark.parametrize('metadata', [False, True]) -async def test_query_rules_allsettings_with_filter(metadata: bool, app_client, setup_rules): - res = await app_client.get('/api/v1/rules/query', query_string={ - 'context_filters': 'user:*,trust:(full),theme:(blue)', - 'include_metadata': str(metadata) - }) - res.raise_for_status() - - expected = {'settings': { - 'a': {'rules': [ - {'context_features': [['trust', 'full']], 'value': 1, 'rule_id': 1}, - ]}, - 'b': {'rules': [ - {'context_features': [['trust', 'full']], 'value': 6, 'rule_id': 6}, - ]}, - 'long_setting_name': {'rules': []} - }} - - if metadata: - patch_rule_expectation_with_metadata(expected) - - assert res.json() == expected - - @fixture(params=['deprecated', 'new']) def patch_callback(request): def callback(client, rule_name, value): @@ -535,7 +230,7 @@ async def test_patch_rule_bad_data(patch_callback, example_rule, app_client): @mark.asyncio @mark.parametrize('metadata', [False, True]) async def test_query_etag(app_client, setup_rules, metadata): - res = await app_client.get('/api/v1/rules/query', query_string={ + res = await app_client.get('/api/v1/query', query_string={ 'settings': 'a,long_setting_name', 'context_filters': 'theme:*,trust:(full,none)', 'include_metadata': str(metadata) @@ -544,7 +239,7 @@ async def test_query_etag(app_client, setup_rules, metadata): etag = res.headers['ETag'] - repeat_resp = await app_client.get('/api/v1/rules/query', query_string={ + repeat_resp = await app_client.get('/api/v1/query', query_string={ 'settings': 'a,long_setting_name', 'context_filters': 'theme:*,trust:(full,none)', 'include_metadata': str(metadata) @@ -558,7 +253,7 @@ async def test_query_etag(app_client, setup_rules, metadata): @mark.asyncio @mark.parametrize('metadata', [False, True]) async def test_query_wrong_etag(app_client, setup_rules, metadata): - res = await app_client.get('/api/v1/rules/query', query_string={ + res = await app_client.get('/api/v1/query', query_string={ 'settings': 'a,long_setting_name', 'context_filters': 'theme:*,trust:(full,none)', 'include_metadata': str(metadata) @@ -568,7 +263,7 @@ async def test_query_wrong_etag(app_client, setup_rules, metadata): etag = res.headers['ETag'] wrong_etag = etag[:5] + '%' + etag[5:] - repeat_resp = await app_client.get('/api/v1/rules/query', query_string={ + repeat_resp = await app_client.get('/api/v1/query', query_string={ 'settings': 'a,long_setting_name', 'context_filters': 'theme:*,trust:(full,none)', 'include_metadata': str(metadata) @@ -582,7 +277,7 @@ async def test_query_wrong_etag(app_client, setup_rules, metadata): @mark.asyncio @mark.parametrize('metadata', [False, True]) async def test_query_etag_wildcard(app_client, setup_rules, metadata): - repeat_resp = await app_client.get('/api/v1/rules/query', query_string={ + repeat_resp = await app_client.get('/api/v1/query', query_string={ 'settings': 'a,long_setting_name', 'context_filters': 'theme:*,trust:(full,none)', 'include_metadata': str(metadata) @@ -595,7 +290,7 @@ async def test_query_etag_wildcard(app_client, setup_rules, metadata): @mark.asyncio @mark.parametrize('metadata', [False, True]) async def test_query_repeat_filter(app_client, setup_rules, metadata): - res = await app_client.get('/api/v1/rules/query', query_string={ + res = await app_client.get('/api/v1/query', query_string={ 'settings': 'a,long_setting_name', 'context_filters': 'theme:*,trust:(full,none),theme:*', 'include_metadata': str(metadata) diff --git a/tests/unittests/conftest.py b/tests/unittests/conftest.py index d4013ec..17cb31b 100644 --- a/tests/unittests/conftest.py +++ b/tests/unittests/conftest.py @@ -5,7 +5,6 @@ from pytest import fixture import heksher.app as app_mod -from heksher.db_logic import DBLogic from heksher.main import app @@ -46,10 +45,8 @@ def mock_engine(): @fixture async def app_client(monkeypatch, mock_engine): monkeypatch.setenv('HEKSHER_DB_CONNECTION_STRING', 'postgresql://dbuser:swordfish@pghost10/') - monkeypatch.setenv('HEKSHER_STARTUP_CONTEXT_FEATURES', '["A","B","C"]') monkeypatch.setattr(app_mod, 'create_async_engine', mock_engine) - monkeypatch.setattr(app_mod, 'DBLogic', lambda *a: AsyncMock(DBLogic)) async with TestClient(app) as app_client: yield app_client