From 2e074f9227c4fa3540ca7ad0a76974ba0019fc38 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C2=BB=20teej?= <157928223+titan-teej@users.noreply.github.com>
Date: Mon, 23 Sep 2024 19:57:33 -0700
Subject: [PATCH] drift after apply issues (#115)

* drift after apply issues

---------

Co-authored-by: TJ Murphy <1796+teej@users.noreply.github.com>
---
 .../fixtures/json/authentication_policy.json  |   2 +-
 tests/fixtures/json/view.json                 |   5 +-
 tests/fixtures/json/warehouse.json            |   8 +-
 tests/fixtures/sql/authentication_policy.sql  |   3 +-
 .../data_provider/test_fetch_resource.py      | 657 +-----------------
 .../test_fetch_resource_simple.py             | 372 ++++++++++
 .../data_provider/test_list_resource.py       |  31 +-
 tests/integration/test_blueprint.py           |  30 +-
 tests/integration/test_lifecycle.py           |  21 +-
 tests/test_blueprint.py                       |  70 +-
 tests/test_blueprint_ownership.py             |   8 +-
 tests/test_identities.py                      |  27 +-
 tests/test_plan.py                            |   3 +-
 tests/test_resource_containers.py             |   2 +-
 tests/test_resources.py                       |   8 +-
 titan/blueprint.py                            | 378 ++++++----
 titan/cli.py                                  |   5 +-
 titan/data_provider.py                        |  82 ++-
 titan/data_types.py                           |   4 +-
 titan/diff.py                                 | 106 ---
 titan/exceptions.py                           |  12 +
 titan/privs.py                                |   1 +
 titan/props.py                                |   2 +-
 titan/resources/database.py                   |   9 +-
 .../resources/external_access_integration.py  |  24 +-
 titan/resources/grant.py                      |   2 +-
 titan/resources/replication_group.py          |   2 +-
 titan/resources/resource.py                   | 123 ++--
 titan/resources/role.py                       |  22 +-
 titan/resources/schema.py                     |  12 +-
 titan/resources/secret.py                     |   2 +-
 titan/resources/user.py                       |  10 +-
 titan/resources/view.py                       |  57 +-
 titan/resources/warehouse.py                  |  48 +-
 tools/reset_test_account.py                   |   1 +
 tools/test_account.yml                        |  11 +-
 tools/test_account_enterprise.yml             |  10 -
 37 files changed, 1068 insertions(+), 1102 deletions(-)
 create mode 100644 tests/integration/data_provider/test_fetch_resource_simple.py
 delete mode 100644 titan/diff.py

diff --git a/tests/fixtures/json/authentication_policy.json b/tests/fixtures/json/authentication_policy.json
index f30a0bb..50bdd4c 100644
--- a/tests/fixtures/json/authentication_policy.json
+++ b/tests/fixtures/json/authentication_policy.json
@@ -3,7 +3,7 @@
   "client_types": [
     "SNOWFLAKE_UI"
   ],
-  "comment": "Auth policy that only allows access through the web interface",
+  "comment": null,
   "owner": "SECURITYADMIN",
   "mfa_enrollment": "OPTIONAL",
   "security_integrations": [
diff --git a/tests/fixtures/json/view.json b/tests/fixtures/json/view.json
index 3e6598b..f8caa85 100644
--- a/tests/fixtures/json/view.json
+++ b/tests/fixtures/json/view.json
@@ -6,7 +6,10 @@
   "change_tracking": false,
   "columns": [
     {
-      "name": "id"
+      "name": "id",
+      "data_type": null,
+      "comment": "this is a column comment",
+      "not_null": false
     }
   ],
   "comment": "This is a view",
diff --git a/tests/fixtures/json/warehouse.json b/tests/fixtures/json/warehouse.json
index 72779b1..ec91936 100644
--- a/tests/fixtures/json/warehouse.json
+++ b/tests/fixtures/json/warehouse.json
@@ -3,16 +3,16 @@
   "owner": "SYSADMIN",
   "warehouse_type": "STANDARD",
   "warehouse_size": "XSMALL",
-  "max_cluster_count": null,
-  "min_cluster_count": null,
-  "scaling_policy": null,
+  "max_cluster_count": 1,
+  "min_cluster_count": 1,
+  "scaling_policy": "STANDARD",
   "auto_suspend": 60,
   "auto_resume": false,
   "initially_suspended": true,
   "resource_monitor": null,
   "comment": "My XSMALL warehouse",
   "enable_query_acceleration": false,
-  "query_acceleration_max_scale_factor": null,
+  "query_acceleration_max_scale_factor": 8,
   "max_concurrency_level": 8,
   "statement_queued_timeout_in_seconds": 0,
   "statement_timeout_in_seconds": 172800
diff --git a/tests/fixtures/sql/authentication_policy.sql b/tests/fixtures/sql/authentication_policy.sql
index 9351672..8f4ba9a 100644
--- a/tests/fixtures/sql/authentication_policy.sql
+++ b/tests/fixtures/sql/authentication_policy.sql
@@ -1,3 +1,2 @@
 CREATE AUTHENTICATION POLICY restrict_client_types_policy
-    CLIENT_TYPES = ('SNOWFLAKE_UI')
-    COMMENT = 'Auth policy that only allows access through the web interface';
\ No newline at end of file
+    CLIENT_TYPES = ('SNOWFLAKE_UI');
\ No newline at end of file
diff --git a/tests/integration/data_provider/test_fetch_resource.py b/tests/integration/data_provider/test_fetch_resource.py
index e93a3b1..218cca0 100644
--- a/tests/integration/data_provider/test_fetch_resource.py
+++ b/tests/integration/data_provider/test_fetch_resource.py
@@ -5,15 +5,16 @@
 from tests.helpers import (
     assert_resource_dicts_eq_ignore_nulls,
     assert_resource_dicts_eq_ignore_nulls_and_unfetchable,
-    safe_fetch,
     clean_resource_data,
+    safe_fetch,
 )
 from titan import data_provider
 from titan import resources as res
 from titan.client import reset_cache
-from titan.enums import ResourceType
+from titan.enums import AccountEdition, ResourceType
 from titan.identifiers import URN, parse_FQN, parse_URN
 from titan.resource_name import ResourceName
+from titan.resources import Resource
 from titan.resources.resource import ResourcePointer
 
 pytestmark = pytest.mark.requires_snowflake
@@ -34,8 +35,10 @@ def email_address(cursor):
     return user["email"]
 
 
-def create(cursor, resource):
-    sql = resource.create_sql(if_not_exists=True)
+def create(cursor, resource: Resource):
+    session_ctx = data_provider.fetch_session(cursor.connection)
+    account_edition = AccountEdition.ENTERPRISE if session_ctx["tag_support"] else AccountEdition.STANDARD
+    sql = resource.create_sql(account_edition=account_edition, if_not_exists=True)
     try:
         cursor.execute(sql)
     except Exception as err:
@@ -123,20 +126,6 @@ def test_fetch_grant_on_account(cursor, suffix):
     cursor.execute(role.drop_sql(if_exists=True))
 
 
-def test_fetch_database(cursor, suffix, marked_for_cleanup):
-    database = res.Database(
-        name=f"SOMEDB_{suffix}",
-        owner=TEST_ROLE,
-        transient=True,
-    )
-    create(cursor, database)
-    marked_for_cleanup.append(database)
-
-    result = safe_fetch(cursor, database.urn)
-    assert result is not None
-    assert_resource_dicts_eq_ignore_nulls(result, database.to_dict())
-
-
 def test_fetch_grant_all_on_resource(cursor):
     cursor.execute("GRANT ALL ON WAREHOUSE STATIC_WAREHOUSE TO ROLE STATIC_ROLE")
     grant_all_urn = parse_URN("urn:::grant/STATIC_ROLE?priv=ALL&on=warehouse/STATIC_WAREHOUSE")
@@ -160,113 +149,6 @@ def test_fetch_grant_all_on_resource(cursor):
     cursor.execute("REVOKE ALL ON WAREHOUSE STATIC_WAREHOUSE FROM ROLE STATIC_ROLE")
 
 
-def test_fetch_external_stage(cursor, test_db, marked_for_cleanup):
-    external_stage = res.ExternalStage(
-        name="EXTERNAL_STAGE_EXAMPLE",
-        url="s3://titan-snowflake/",
-        owner=TEST_ROLE,
-        database=test_db,
-        schema="PUBLIC",
-    )
-    create(cursor, external_stage)
-    marked_for_cleanup.append(external_stage)
-
-    result = safe_fetch(cursor, external_stage.urn)
-    assert result is not None
-    assert_resource_dicts_eq_ignore_nulls(result, external_stage.to_dict())
-
-    external_stage = res.ExternalStage(
-        name="EXTERNAL_STAGE_EXAMPLE_WITH_DIRECTORY",
-        url="s3://titan-snowflake/",
-        owner=TEST_ROLE,
-        database=test_db,
-        schema="PUBLIC",
-        directory={"enable": True},
-    )
-    create(cursor, external_stage)
-    marked_for_cleanup.append(external_stage)
-
-    result = safe_fetch(cursor, external_stage.urn)
-    assert result is not None
-    assert_resource_dicts_eq_ignore_nulls(result, external_stage.to_dict())
-
-
-def test_fetch_internal_stage(cursor, test_db, marked_for_cleanup):
-    internal_stage = res.InternalStage(
-        name="INTERNAL_STAGE_EXAMPLE",
-        owner=TEST_ROLE,
-        database=test_db,
-        schema="PUBLIC",
-    )
-    create(cursor, internal_stage)
-    marked_for_cleanup.append(internal_stage)
-
-    result = safe_fetch(cursor, internal_stage.urn)
-    assert result is not None
-    assert_resource_dicts_eq_ignore_nulls(result, internal_stage.to_dict())
-
-    internal_stage = res.InternalStage(
-        name="INTERNAL_STAGE_EXAMPLE_WITH_DIRECTORY",
-        directory={"enable": True},
-        owner=TEST_ROLE,
-        database=test_db,
-        schema="PUBLIC",
-    )
-    create(cursor, internal_stage)
-    marked_for_cleanup.append(internal_stage)
-
-    result = safe_fetch(cursor, internal_stage.urn)
-    assert result is not None
-    assert_resource_dicts_eq_ignore_nulls(result, internal_stage.to_dict())
-
-
-def test_fetch_csv_file_format(cursor, test_db, marked_for_cleanup):
-    csv_file_format = res.CSVFileFormat(
-        name="CSV_FILE_FORMAT_EXAMPLE",
-        owner=TEST_ROLE,
-        field_delimiter="|",
-        skip_header=1,
-        null_if=["NULL", "null"],
-        empty_field_as_null=True,
-        compression="GZIP",
-        database=test_db,
-        schema="PUBLIC",
-    )
-    create(cursor, csv_file_format)
-    marked_for_cleanup.append(csv_file_format)
-
-    result = safe_fetch(cursor, csv_file_format.urn)
-    assert result is not None
-    assert_resource_dicts_eq_ignore_nulls(result, csv_file_format.to_dict())
-
-    csv_file_format = res.CSVFileFormat(
-        name="CSV_FILE_FORMAT_EXAMPLE_ALL_DEFAULTS",
-        owner=TEST_ROLE,
-        database=test_db,
-        schema="PUBLIC",
-    )
-    create(cursor, csv_file_format)
-    marked_for_cleanup.append(csv_file_format)
-
-    result = safe_fetch(cursor, csv_file_format.urn)
-    assert result is not None
-    assert_resource_dicts_eq_ignore_nulls(result, csv_file_format.to_dict())
-
-
-def test_fetch_resource_monitor(cursor, marked_for_cleanup):
-    resource_monitor = res.ResourceMonitor(
-        name="RESOURCE_MONITOR_EXAMPLE",
-        credit_quota=1000,
-        start_timestamp="2049-01-01 00:00",
-    )
-    create(cursor, resource_monitor)
-    marked_for_cleanup.append(resource_monitor)
-
-    result = safe_fetch(cursor, resource_monitor.urn)
-    assert result is not None
-    assert_resource_dicts_eq_ignore_nulls(result, resource_monitor.to_dict())
-
-
 def test_fetch_email_notification_integration(cursor, email_address, marked_for_cleanup):
     email_notification_integration = res.EmailNotificationIntegration(
@@ -284,25 +166,6 @@ def test_fetch_email_notification_integration(cursor, email_address, marked_for_
     assert result == data_provider.remove_none_values(email_notification_integration.to_dict())
 
 
-def test_fetch_event_table(cursor, test_db, marked_for_cleanup):
-    event_table = res.EventTable(
-        name="EVENT_TABLE_EXAMPLE",
-        change_tracking=True,
-        cluster_by=["START_TIMESTAMP"],
-        data_retention_time_in_days=1,
-        owner=TEST_ROLE,
-        database=test_db,
-        schema="PUBLIC",
-    )
-    create(cursor, event_table)
-    marked_for_cleanup.append(event_table)
-
-    result = safe_fetch(cursor, event_table.urn)
-    assert result is not None
-    result = data_provider.remove_none_values(result)
-    assert result == data_provider.remove_none_values(event_table.to_dict())
-
-
 def test_fetch_grant_with_fully_qualified_ref(cursor, test_db, suffix, marked_for_cleanup):
     cursor.execute(f"USE DATABASE {test_db}")
     cursor.execute(f"CREATE SCHEMA if not exists {test_db}.my_schema")
@@ -358,25 +221,6 @@ def test_fetch_pipe(cursor, test_db, marked_for_cleanup):
     assert result == data_provider.remove_none_values(pipe.to_dict())
 
 
-def test_fetch_view(cursor, test_db, marked_for_cleanup):
-    view = res.View(
-        name="VIEW_EXAMPLE",
-        as_="SELECT 1 as id FROM STATIC_DATABASE.PUBLIC.STATIC_TABLE",
-        columns=[{"name": "ID", "data_type": "NUMBER(1,0)", "not_null": False}],
-        comment="View for testing",
-        owner=TEST_ROLE,
-        database=test_db,
-        schema="PUBLIC",
-    )
-    create(cursor, view)
-    marked_for_cleanup.append(view)
-
-    result = safe_fetch(cursor, view.urn)
-    assert result is not None
-    result = data_provider.remove_none_values(result)
-    assert_resource_dicts_eq_ignore_nulls_and_unfetchable(res.View.spec, result, view.to_dict())
-
-
 @pytest.mark.enterprise
 def test_fetch_tag(cursor, test_db, marked_for_cleanup):
     tag = res.Tag(
@@ -411,16 +255,6 @@ def test_fetch_tag(cursor, test_db, marked_for_cleanup):
     assert result == data_provider.remove_none_values(tag.to_dict())
 
 
-def test_fetch_role(cursor, suffix, marked_for_cleanup):
-    role = res.Role(name=f"ANOTHER_ROLE_{suffix}", owner=TEST_ROLE)
-    create(cursor, role)
-    marked_for_cleanup.append(role)
-
-    result = safe_fetch(cursor, role.urn)
-    assert result is not None
-    assert_resource_dicts_eq_ignore_nulls(result, role.to_dict())
-
-
 def test_fetch_role_grant(cursor, suffix, marked_for_cleanup):
     parent = res.Role(name=f"PARENT_ROLE_{suffix}", owner=TEST_ROLE)
     child = res.Role(name=f"CHILD_ROLE_{suffix}", owner=TEST_ROLE)
@@ -449,309 +283,6 @@ def test_fetch_role_grant(cursor, suffix, marked_for_cleanup):
     assert_resource_dicts_eq_ignore_nulls(result, grant.to_dict())
 
 
-def test_fetch_user(cursor, suffix, marked_for_cleanup):
-    user = res.User(
-        name=f"SOME_USER_{suffix}@applytitan.com",
-        owner=TEST_ROLE,
-    )
-    create(cursor, user)
-    marked_for_cleanup.append(user)
-
-    result = safe_fetch(cursor, user.urn)
-    assert result is not None
-    result = clean_resource_data(res.User.spec, result)
-    data = clean_resource_data(res.User.spec, user.to_dict())
-    assert result == data
-
-    user = res.User(
-        name=f"SOME_USER_TYPE_PERSON_{suffix}@applytitan.com",
-        owner=TEST_ROLE,
-        type="PERSON",
-    )
-    create(cursor, user)
-    marked_for_cleanup.append(user)
-
-    result = safe_fetch(cursor, user.urn)
-    assert result is not None
-    result = clean_resource_data(res.User.spec, result)
-    data = clean_resource_data(res.User.spec, user.to_dict())
-    assert result == data
-
-
-def test_fetch_glue_catalog_integration(cursor, marked_for_cleanup):
-    catalog_integration = res.GlueCatalogIntegration(
-        name="some_catalog_integration",
-        table_format="ICEBERG",
-        glue_aws_role_arn="arn:aws:iam::123456789012:role/SnowflakeAccess",
-        glue_catalog_id="123456789012",
-        catalog_namespace="some_namespace",
-        enabled=True,
-        glue_region="us-west-2",
-        comment="Integration for AWS Glue with Snowflake.",
-        owner=TEST_ROLE,
-    )
-    create(cursor, catalog_integration)
-    marked_for_cleanup.append(catalog_integration)
-
-    result = safe_fetch(cursor, catalog_integration.urn)
-    assert result is not None
-    assert_resource_dicts_eq_ignore_nulls(result, catalog_integration.to_dict())
-
-
-def test_fetch_object_store_catalog_integration(cursor, marked_for_cleanup):
-    catalog_integration = res.ObjectStoreCatalogIntegration(
-        name="OBJECT_STORE_CATALOG_INTEGRATION_EXAMPLE",
-        catalog_source="OBJECT_STORE",
-        table_format="ICEBERG",
-        enabled=True,
-        comment="Catalog integration for testing",
-        owner=TEST_ROLE,
-    )
-    create(cursor, catalog_integration)
-    marked_for_cleanup.append(catalog_integration)
-
-    result = safe_fetch(cursor, catalog_integration.urn)
-    assert result is not None
-    assert_resource_dicts_eq_ignore_nulls(result, catalog_integration.to_dict())
-
-
-def test_fetch_share(cursor, suffix, marked_for_cleanup):
-    share = res.Share(
-        name=f"SHARE_EXAMPLE_{suffix}",
-        comment="Share for testing",
-        owner=TEST_ROLE,
-    )
-    create(cursor, share)
-    marked_for_cleanup.append(share)
-
-    result = safe_fetch(cursor, share.urn)
-    assert result is not None
-    assert_resource_dicts_eq_ignore_nulls(result, share.to_dict())
-
-
-def test_fetch_s3_storage_integration(cursor, suffix, marked_for_cleanup):
-    storage_integration = res.S3StorageIntegration(
-        name=f"S3_STORAGE_INTEGRATION_EXAMPLE_{suffix}",
-        storage_provider="S3",
-        storage_aws_role_arn="arn:aws:iam::001234567890:role/myrole",
-        enabled=True,
-        storage_allowed_locations=["s3://mybucket1/path1/", "s3://mybucket2/path2/"],
-        owner=TEST_ROLE,
-    )
-    create(cursor, storage_integration)
-    marked_for_cleanup.append(storage_integration)
-
-    result = safe_fetch(cursor, storage_integration.urn)
-    assert result is not None
-    assert_resource_dicts_eq_ignore_nulls(result, storage_integration.to_dict())
-
-
-def test_fetch_gcs_storage_integration(cursor, suffix, marked_for_cleanup):
-    storage_integration = res.GCSStorageIntegration(
-        name=f"GCS_STORAGE_INTEGRATION_EXAMPLE_{suffix}",
-        enabled=True,
-        storage_allowed_locations=["gcs://mybucket1/path1/", "gcs://mybucket2/path2/"],
-        owner=TEST_ROLE,
-    )
-    create(cursor, storage_integration)
-    marked_for_cleanup.append(storage_integration)
-
-    result = safe_fetch(cursor, storage_integration.urn)
-    assert result is not None
-    assert_resource_dicts_eq_ignore_nulls(result, storage_integration.to_dict())
-
-
-def test_fetch_azure_storage_integration(cursor, suffix, marked_for_cleanup):
-    storage_integration = res.AzureStorageIntegration(
-        name=f"AZURE_STORAGE_INTEGRATION_EXAMPLE_{suffix}",
-        enabled=True,
-        azure_tenant_id="a123b4c5-1234-123a-a12b-1a23b45678c9",
-        storage_allowed_locations=[
-            "azure://myaccount.blob.core.windows.net/mycontainer/path1/",
-            "azure://myaccount.blob.core.windows.net/mycontainer/path2/",
-        ],
-        owner=TEST_ROLE,
-    )
-    create(cursor, storage_integration)
-    marked_for_cleanup.append(storage_integration)
-
-    result = safe_fetch(cursor, storage_integration.urn)
-    assert result is not None
-    assert_resource_dicts_eq_ignore_nulls(result, storage_integration.to_dict())
-
-
-def test_fetch_alert(cursor, suffix, test_db, marked_for_cleanup):
-    alert = res.Alert(
-        name=f"SOMEALERT_{suffix}",
-        warehouse="STATIC_WAREHOUSE",
-        schedule="60 MINUTE",
-        condition="SELECT 1",
-        then="SELECT 1",
-        owner=TEST_ROLE,
-        database=test_db,
-        schema="PUBLIC",
-    )
-    create(cursor, alert)
-    marked_for_cleanup.append(alert)
-
-    result = safe_fetch(cursor, alert.urn)
-    assert result is not None
-    assert_resource_dicts_eq_ignore_nulls(result, alert.to_dict())
-
-
-def test_fetch_dynamic_table(cursor, test_db, marked_for_cleanup):
-    dynamic_table = res.DynamicTable(
-        name="PRODUCT",
-        columns=[{"name": "ID", "comment": "This is a comment"}],
-        target_lag="20 minutes",
-        warehouse="CI",
-        refresh_mode="AUTO",
-        initialize="ON_CREATE",
-        comment="this is a comment",
-        as_="SELECT id FROM STATIC_DATABASE.PUBLIC.STATIC_TABLE",
-        owner=TEST_ROLE,
-        database=test_db,
-        schema="PUBLIC",
-    )
-    create(cursor, dynamic_table)
-    marked_for_cleanup.append(dynamic_table)
-
-    result = safe_fetch(cursor, dynamic_table.urn)
-    assert result is not None
-    assert_resource_dicts_eq_ignore_nulls(result, dynamic_table.to_dict())
-
-
-def test_fetch_javascript_udf(cursor, test_db, marked_for_cleanup):
-    function = res.JavascriptUDF(
-        name="SOME_JAVASCRIPT_UDF",
-        args=[{"name": "INPUT_ARG", "data_type": "VARIANT"}],
-        returns="FLOAT",
-        volatility="VOLATILE",
-        as_="return 42;",
-        secure=False,
-        owner=TEST_ROLE,
-        database=test_db,
-        schema="PUBLIC",
-    )
-    create(cursor, function)
-    marked_for_cleanup.append(function)
-
-    result = safe_fetch(cursor, function.urn)
-    assert result is not None
-    result = clean_resource_data(res.JavascriptUDF.spec, result)
-    data = clean_resource_data(res.JavascriptUDF.spec, function.to_dict())
-    assert result == data
-
-
-def test_fetch_password_policy(cursor, test_db, marked_for_cleanup):
-    password_policy = res.PasswordPolicy(
-        name="SOME_PASSWORD_POLICY",
-        password_min_length=12,
-        password_max_length=24,
-        password_min_upper_case_chars=2,
-        password_min_lower_case_chars=2,
-        password_min_numeric_chars=2,
-        password_min_special_chars=2,
-        password_min_age_days=1,
-        password_max_age_days=30,
-        password_max_retries=3,
-        password_lockout_time_mins=30,
-        password_history=5,
-        # comment="production account password policy",  # Leaving this out until Snowflake fixes their bugs
-        owner=TEST_ROLE,
-        database=test_db,
-        schema="PUBLIC",
-    )
-    create(cursor, password_policy)
-    marked_for_cleanup.append(password_policy)
-
-    result = safe_fetch(cursor, password_policy.urn)
-    assert result is not None
-    assert result == password_policy.to_dict()
-
-
-def test_fetch_python_stored_procedure(cursor, suffix, test_db, marked_for_cleanup):
-    procedure = res.PythonStoredProcedure(
-        name=f"somesproc_{suffix}",
-        args=[{"name": "ARG1", "data_type": "VARCHAR"}],
-        returns="NUMBER",
-        packages=["snowflake-snowpark-python"],
-        runtime_version="3.9",
-        handler="main",
-        execute_as="OWNER",
-        comment="user-defined procedure",
-        imports=[],
-        null_handling="CALLED ON NULL INPUT",
-        secure=False,
-        owner=TEST_ROLE,
-        database=test_db,
-        schema="PUBLIC",
-        as_="def main(arg1): return 42",
-    )
-    cursor.execute(procedure.create_sql())
-    marked_for_cleanup.append(procedure)
-
-    result = safe_fetch(cursor, procedure.urn)
-    assert result is not None
-    result = clean_resource_data(res.PythonStoredProcedure.spec, result)
-    data = clean_resource_data(res.PythonStoredProcedure.spec, procedure.to_dict())
-    assert result == data
-
-
-def test_fetch_schema(cursor, test_db, marked_for_cleanup):
-    schema = res.Schema(
-        name="SOMESCH",
-        data_retention_time_in_days=1,
-        max_data_extension_time_in_days=3,
-        transient=False,
-        managed_access=False,
-        owner=TEST_ROLE,
-        database=test_db,
-    )
-    create(cursor, schema)
-    marked_for_cleanup.append(schema)
-
-    result = safe_fetch(cursor, schema.urn)
-    assert result is not None
-    assert_resource_dicts_eq_ignore_nulls(result, schema.to_dict())
-
-
-def test_fetch_sequence(cursor, suffix, test_db, marked_for_cleanup):
-    sequence = res.Sequence(
-        name=f"SOMESEQ_{suffix}",
-        start=1,
-        increment=2,
-        comment="+3",
-        owner=TEST_ROLE,
-        database=test_db,
-        schema="PUBLIC",
-    )
-    create(cursor, sequence)
-    marked_for_cleanup.append(sequence)
-
-    result = safe_fetch(cursor, sequence.urn)
-    assert result is not None
-    assert_resource_dicts_eq_ignore_nulls(result, sequence.to_dict())
-
-
-def test_fetch_task(cursor, suffix, test_db, marked_for_cleanup):
-    task = res.Task(
-        name=f"SOMETASK_{suffix}",
-        schedule="60 MINUTE",
-        state="SUSPENDED",
-        as_="SELECT 1",
-        owner=TEST_ROLE,
-        database=test_db,
-        schema="PUBLIC",
-    )
-    create(cursor, task)
-    marked_for_cleanup.append(task)
-
-    result = safe_fetch(cursor, task.urn)
-    assert result is not None
-    assert_resource_dicts_eq_ignore_nulls(result, task.to_dict())
-
-
 def test_fetch_network_rule(cursor, suffix, test_db, marked_for_cleanup):
     network_rule = res.NetworkRule(
         name=f"NETWORK_RULE_EXAMPLE_HOST_PORT_{suffix}",
@@ -831,36 +362,6 @@ def test_fetch_api_integration(cursor, suffix, marked_for_cleanup):
     assert result == data
 
 
-def test_fetch_database_role(cursor, suffix, test_db, marked_for_cleanup):
-    database_role = res.DatabaseRole(
-        name=f"DATABASE_ROLE_EXAMPLE_{suffix}",
-        database=test_db,
-        owner=TEST_ROLE,
-    )
-    create(cursor, database_role)
-    marked_for_cleanup.append(database_role)
-
-    result = safe_fetch(cursor, database_role.urn)
-    assert result is not None
-    assert_resource_dicts_eq_ignore_nulls(result, database_role.to_dict())
-
-
-def test_fetch_packages_policy(cursor, suffix, marked_for_cleanup):
-    packages_policy = res.PackagesPolicy(
-        name=f"PACKAGES_POLICY_EXAMPLE_{suffix}",
-        allowlist=["numpy", "pandas"],
-        blocklist=["os", "sys"],
-        comment="Example packages policy",
-        owner=TEST_ROLE,
-    )
-    create(cursor, packages_policy)
-    marked_for_cleanup.append(packages_policy)
-
-    result = safe_fetch(cursor, packages_policy.urn)
-    assert result is not None
-    assert_resource_dicts_eq_ignore_nulls(result, packages_policy.to_dict())
-
-
 @pytest.mark.enterprise
 def test_fetch_aggregation_policy(cursor, suffix, test_db, marked_for_cleanup):
     aggregation_policy = res.AggregationPolicy(
@@ -898,22 +399,6 @@ def test_fetch_compute_pool(cursor, suffix, marked_for_cleanup):
     assert result == data
 
 
-def test_fetch_warehouse(cursor, suffix, marked_for_cleanup):
-    warehouse = res.Warehouse(
-        name=f"SOME_WAREHOUSE_{suffix}",
-        warehouse_size="XSMALL",
-        auto_suspend=60,
-        auto_resume=True,
-        owner=TEST_ROLE,
-    )
-    create(cursor, warehouse)
-    marked_for_cleanup.append(warehouse)
-
-    result = safe_fetch(cursor, warehouse.urn)
-    assert result is not None
-    assert_resource_dicts_eq_ignore_nulls(result, warehouse.to_dict())
-
-
 def test_fetch_password_secret(cursor, suffix, marked_for_cleanup):
     secret = res.PasswordSecret(
         name=f"PASSWORD_SECRET_EXAMPLE_{suffix}",
@@ -1066,42 +551,6 @@ def test_fetch_stage_stream(cursor, suffix, marked_for_cleanup):
     assert_resource_dicts_eq_ignore_nulls_and_unfetchable(res.StageStream.spec, result, stream.to_dict())
 
 
-def test_fetch_authentication_policies(cursor, suffix, marked_for_cleanup):
-    policy = res.AuthenticationPolicy(
-        name=f"SOME_AUTHENTICATION_POLICY_{suffix}",
-        mfa_authentication_methods=["PASSWORD", "SAML"],
-        mfa_enrollment="REQUIRED",
-        client_types=["SNOWFLAKE_UI"],
-        comment="Authentication policy for testing",
-        owner=TEST_ROLE,
-    )
-    create(cursor, policy)
-    marked_for_cleanup.append(policy)
-
-    result = safe_fetch(cursor, policy.urn)
-    assert result is not None
-    result = clean_resource_data(res.AuthenticationPolicy.spec, result)
-    data = clean_resource_data(res.AuthenticationPolicy.spec, policy.to_dict())
-    assert result == data
-
-
-def test_fetch_external_access_integration(cursor, suffix, marked_for_cleanup):
-    integration = res.ExternalAccessIntegration(
-        name=f"EXTERNAL_ACCESS_INTEGRATION_{suffix}",
-        allowed_network_rules=["static_database.public.static_network_rule"],
-        comment="External access integration for testing",
-        owner=TEST_ROLE,
-    )
-    create(cursor, integration)
-    marked_for_cleanup.append(integration)
-
-    result = safe_fetch(cursor, integration.urn)
-    assert result is not None
-    result = clean_resource_data(res.ExternalAccessIntegration.spec, result)
-    data = clean_resource_data(res.ExternalAccessIntegration.spec, integration.to_dict())
-    assert result == data
-
-
 def test_fetch_parquet_file_format(cursor, suffix, marked_for_cleanup):
     file_format = res.ParquetFileFormat(
         name=f"SOME_PARQUET_FILE_FORMAT_{suffix}",
@@ -1118,38 +567,6 @@ def test_fetch_parquet_file_format(cursor, suffix, marked_for_cleanup):
     assert result == data
 
 
-def test_fetch_json_file_format(cursor, suffix, marked_for_cleanup):
-    file_format = res.JSONFileFormat(
-        name=f"SOME_JSON_FILE_FORMAT_{suffix}",
-        owner=TEST_ROLE,
-    )
-    create(cursor, file_format)
-    marked_for_cleanup.append(file_format)
-
-    result = safe_fetch(cursor, file_format.urn)
-    assert result is not None
-    result = clean_resource_data(res.JSONFileFormat.spec, result)
-    data = clean_resource_data(res.JSONFileFormat.spec, file_format.to_dict())
-    assert result == data
-
-
-def test_fetch_notebook(cursor, suffix, marked_for_cleanup):
-    notebook = res.Notebook(
-        name=f"SOME_NOTEBOOK_{suffix}",
-        query_warehouse="static_warehouse",
-        comment="This is a test notebook",
-        owner=TEST_ROLE,
-    )
-    create(cursor, notebook)
-    marked_for_cleanup.append(notebook)
-
-    result = safe_fetch(cursor, notebook.urn)
-    assert result is not None
-    result = clean_resource_data(res.Notebook.spec, result)
-    data = clean_resource_data(res.Notebook.spec, notebook.to_dict())
-    assert result == data
-
-
 def test_fetch_network_policy(cursor, suffix, marked_for_cleanup):
     policy = res.NetworkPolicy(
         name=f"SOME_NETWORK_POLICY_{suffix}",
@@ -1170,42 +587,6 @@ def test_fetch_network_policy(cursor, suffix, marked_for_cleanup):
     assert result == data
 
 
-def test_fetch_table(cursor, suffix, marked_for_cleanup):
-    table = res.Table(
-        name=f"SOME_TABLE_{suffix}",
-        columns=[
-            res.Column(name="ID", data_type="NUMBER(38,0)", not_null=True),
-            res.Column(name="NAME", data_type="VARCHAR(16777216)", not_null=False),
-        ],
-        database="STATIC_DATABASE",
-        schema="PUBLIC",
-        owner=TEST_ROLE,
-    )
-    create(cursor, table)
-    marked_for_cleanup.append(table)
-
-    result = safe_fetch(cursor, table.urn)
-    assert result is not None
-    result = clean_resource_data(res.Table.spec, result)
-    data = clean_resource_data(res.Table.spec, table.to_dict())
-    assert result == data
-
-
-def test_fetch_image_repository(cursor, suffix, marked_for_cleanup):
-    repository = res.ImageRepository(
-        name=f"SOME_IMAGE_REPOSITORY_{suffix}",
-        owner=TEST_ROLE,
-    )
-    create(cursor, repository)
-    marked_for_cleanup.append(repository)
-
-    result = safe_fetch(cursor, repository.urn)
-    assert result is not None
-    result = clean_resource_data(res.ImageRepository.spec, result)
-    data = clean_resource_data(res.ImageRepository.spec, repository.to_dict())
-    assert result == data
-
-
 def test_fetch_external_volume(cursor, suffix, marked_for_cleanup):
     from titan.resources.external_volume import ExternalVolumeStorageLocation
 
@@ -1237,27 +618,3 @@ def test_fetch_external_volume(cursor, suffix, marked_for_cleanup):
         ExternalVolumeStorageLocation.spec, data_storage_locations[0]
     )
     assert result == data
-
-
-def test_fetch_iceberg_table(cursor, suffix, marked_for_cleanup):
-    table = res.SnowflakeIcebergTable(
-        name=f"SOME_ICEBERG_TABLE_{suffix}",
-        columns=[
-            res.Column(name="ID", data_type="NUMBER(38,0)", not_null=True),
-            res.Column(name="NAME", data_type="VARCHAR(16777216)", not_null=False),
-        ],
-        database="STATIC_DATABASE",
-        schema="PUBLIC",
-        owner=TEST_ROLE,
-        catalog="SNOWFLAKE",
-        external_volume="static_external_volume",
-        base_location="some_prefix",
-    )
-    cursor.execute(table.create_sql(if_not_exists=True))
-    marked_for_cleanup.append(table)
-
-    result = safe_fetch(cursor, table.urn)
-    assert result is not None
-    result = clean_resource_data(res.SnowflakeIcebergTable.spec, result)
-    data = clean_resource_data(res.SnowflakeIcebergTable.spec, table.to_dict())
-    assert result == data
diff --git a/tests/integration/data_provider/test_fetch_resource_simple.py b/tests/integration/data_provider/test_fetch_resource_simple.py
new file mode 100644
index 0000000..6f56761
--- /dev/null
+++ b/tests/integration/data_provider/test_fetch_resource_simple.py
@@ -0,0 +1,372 @@
+import os
+
+import pytest
+
+from tests.helpers import safe_fetch
+from titan import data_provider
+from titan import resources as res
+from titan.enums import AccountEdition
+from titan.resources import Resource
+from titan.scope import AccountScope, DatabaseScope, SchemaScope
+
+pytestmark = pytest.mark.requires_snowflake
+
+TEST_ROLE = os.environ.get("TEST_SNOWFLAKE_ROLE", "ACCOUNTADMIN")
+TEST_USER = os.environ.get("TEST_SNOWFLAKE_USER")
+
+
+@pytest.fixture(scope="session")
+def account_edition(cursor):
+    session_ctx = data_provider.fetch_session(cursor.connection)
+    return AccountEdition.ENTERPRISE if session_ctx["tag_support"] else AccountEdition.STANDARD
+
+
+@pytest.fixture(scope="session")
+def email_address(cursor):
+    user = cursor.execute(f"SHOW TERSE USERS LIKE '{TEST_USER}'").fetchone()
+    return user["email"]
+
+
+def strip_unfetchable_fields(spec, data: dict) -> dict:
+    keys = set(data.keys())
+    for attr in keys:
+        attr_metadata = spec.get_metadata(attr)
+        if not attr_metadata.fetchable or attr_metadata.known_after_apply:
+            data.pop(attr, None)
+    return data
+
+
+def resource_fixtures() -> list:
+    return [
+        res.Alert(
+            name="TEST_FETCH_ALERT",
+            warehouse="STATIC_WAREHOUSE",
+            schedule="60 MINUTE",
+            condition="SELECT 1",
+            then="SELECT 1",
+            owner=TEST_ROLE,
+        ),
+        res.AuthenticationPolicy(
+            name="TEST_FETCH_AUTHENTICATION_POLICY",
+            mfa_authentication_methods=["PASSWORD", "SAML"],
+            mfa_enrollment="REQUIRED",
+            client_types=["SNOWFLAKE_UI"],
+            security_integrations=["STATIC_SECURITY_INTEGRATION"],
+            owner=TEST_ROLE,
+        ),
+        res.AzureStorageIntegration(
+            name="TEST_FETCH_AZURE_STORAGE_INTEGRATION",
+            enabled=True,
+            azure_tenant_id="a123b4c5-1234-123a-a12b-1a23b45678c9",
+            storage_allowed_locations=[
+                "azure://myaccount.blob.core.windows.net/mycontainer/path1/",
+                "azure://myaccount.blob.core.windows.net/mycontainer/path2/",
+            ],
+            owner=TEST_ROLE,
+        ),
+        res.CSVFileFormat(
+            name="TEST_FETCH_CSV_FILE_FORMAT",
+            owner=TEST_ROLE,
+            field_delimiter="|",
+            skip_header=1,
+            null_if=["NULL", "null"],
+            empty_field_as_null=True,
+            compression="GZIP",
+        ),
+        res.Database(
+            name="TEST_FETCH_DATABASE",
+            owner=TEST_ROLE,
+            transient=True,
+            data_retention_time_in_days=1,
+            max_data_extension_time_in_days=3,
+            comment="This is a test database",
+        ),
+        res.DynamicTable(
+            name="TEST_FETCH_DYNAMIC_TABLE",
+            columns=[{"name": "ID", "comment": "This is a comment"}],
+            target_lag="20 minutes",
+            warehouse="STATIC_WAREHOUSE",
+            refresh_mode="AUTO",
+            initialize="ON_CREATE",
+            comment="this is a comment",
+            as_="SELECT id FROM STATIC_DATABASE.PUBLIC.STATIC_TABLE",
+            owner=TEST_ROLE,
+        ),
+        res.EventTable(
+            name="TEST_FETCH_EVENT_TABLE",
+            change_tracking=True,
+            cluster_by=["START_TIMESTAMP"],
+            data_retention_time_in_days=1,
+            owner=TEST_ROLE,
+            comment="This is a test event table",
+        ),
+        res.ExternalAccessIntegration(
+            name="TEST_FETCH_EXTERNAL_ACCESS_INTEGRATION",
+            allowed_network_rules=["static_database.public.static_network_rule"],
+            allowed_authentication_secrets=["static_database.public.static_secret"],
+            enabled=True,
+            comment="External access integration for testing",
+            owner=TEST_ROLE,
+        ),
+        res.ExternalStage(
+            name="TEST_FETCH_EXTERNAL_STAGE",
+            url="s3://titan-snowflake/",
+            owner=TEST_ROLE,
+            directory={"enable": True},
+            comment="This is a test external stage",
+        ),
+        res.GCSStorageIntegration(
+            name="TEST_FETCH_GCS_STORAGE_INTEGRATION",
+            enabled=True,
+            storage_allowed_locations=["gcs://mybucket1/path1/", "gcs://mybucket2/path2/"],
+            owner=TEST_ROLE,
+        ),
+        res.GlueCatalogIntegration(
+            name="TEST_FETCH_GLUE_CATALOG_INTEGRATION",
+            table_format="ICEBERG",
+            glue_aws_role_arn="arn:aws:iam::123456789012:role/SnowflakeAccess",
+            glue_catalog_id="123456789012",
+            catalog_namespace="some_namespace",
+            enabled=True,
+            glue_region="us-west-2",
+            comment="Integration for AWS Glue with Snowflake.",
+            owner=TEST_ROLE,
+        ),
+        res.ImageRepository(
+            name="TEST_FETCH_IMAGE_REPOSITORY",
+            owner=TEST_ROLE,
+        ),
+        res.InternalStage(
+            name="TEST_FETCH_INTERNAL_STAGE",
+            directory={"enable": True},
+            owner=TEST_ROLE,
+            comment="This is a test internal stage",
+        ),
+        res.JavascriptUDF(
+            name="SOME_JAVASCRIPT_UDF",
+            args=[{"name": "INPUT_ARG", "data_type": "VARIANT"}],
+            returns="FLOAT",
+            volatility="VOLATILE",
+            as_="return 42;",
+            secure=False,
+            owner=TEST_ROLE,
+        ),
+        res.JSONFileFormat(
+            name="TEST_FETCH_JSON_FILE_FORMAT",
+            owner=TEST_ROLE,
+            compression="GZIP",
+            replace_invalid_characters=True,
+            comment="This is a test JSON file format",
+        ),
+        res.Notebook(
+            name="TEST_FETCH_NOTEBOOK",
+            query_warehouse="static_warehouse",
+            comment="This is a test notebook",
+            owner=TEST_ROLE,
+        ),
+        res.ObjectStoreCatalogIntegration(
+            name="TEST_FETCH_OBJECT_STORE_CATALOG_INTEGRATION",
+            catalog_source="OBJECT_STORE",
+            table_format="ICEBERG",
+            enabled=True,
+            comment="Catalog integration for testing",
+            owner=TEST_ROLE,
+        ),
+        res.PasswordPolicy(
+            name="TEST_FETCH_PASSWORD_POLICY",
+            password_min_length=12,
+            password_max_length=24,
+            password_min_upper_case_chars=2,
+            password_min_lower_case_chars=2,
+            password_min_numeric_chars=2,
+            password_min_special_chars=2,
+            password_min_age_days=1,
+            password_max_age_days=30,
+            password_max_retries=3,
+            password_lockout_time_mins=30,
+            password_history=5,
+            # comment="production account password policy",  # Leaving this out until Snowflake fixes their bugs
+            owner=TEST_ROLE,
+        ),
+        res.PackagesPolicy(
+            name="TEST_FETCH_PACKAGES_POLICY",
+            allowlist=["numpy", "pandas"],
+            blocklist=["os", "sys"],
+            additional_creation_blocklist=["numpy.random.randint"],
+            comment="Example packages policy",
+            owner=TEST_ROLE,
+        ),
+        res.PythonStoredProcedure(
+            name="TEST_FETCH_PYTHON_STORED_PROCEDURE",
+            args=[{"name": "ARG1", "data_type": "VARCHAR"}],
+            returns="NUMBER",
+            packages=["snowflake-snowpark-python"],
+            runtime_version="3.9",
+            handler="main",
+            execute_as="OWNER",
+            comment="user-defined procedure",
+            imports=[],
+            null_handling="CALLED ON NULL INPUT",
+            secure=False,
+            owner=TEST_ROLE,
+            as_="def main(arg1): return 42",
+        ),
+        res.ResourceMonitor(
+            name="TEST_FETCH_RESOURCE_MONITOR",
+            credit_quota=1000,
+            start_timestamp="2049-01-01 00:00",
+        ),
+        res.Role(
+            name="TEST_FETCH_ROLE",
+            owner=TEST_ROLE,
+        ),
+        res.S3StorageIntegration(
+            name="TEST_FETCH_S3_STORAGE_INTEGRATION",
+            storage_provider="S3",
+            storage_aws_role_arn="arn:aws:iam::001234567890:role/myrole",
+            enabled=True,
+            storage_allowed_locations=["s3://mybucket1/path1/", "s3://mybucket2/path2/"],
+            owner=TEST_ROLE,
+        ),
+        res.Schema(
+            name="TEST_FETCH_SCHEMA",
+            owner=TEST_ROLE,
+            transient=True,
+            managed_access=True,
+            comment="This is a test schema",
+        ),
+        res.Sequence(
+            name="TEST_FETCH_SEQUENCE",
+            start=1,
+            increment=2,
+            comment="+3",
+            owner=TEST_ROLE,
+        ),
+        res.Share(
+            name="TEST_FETCH_SHARE",
+            comment="Share for testing",
+            owner=TEST_ROLE,
+        ),
+        res.SnowflakeIcebergTable(
+            name="TEST_FETCH_SNOWFLAKE_ICEBERG_TABLE",
+            columns=[
+                res.Column(name="ID", data_type="NUMBER(38,0)", not_null=True),
+                res.Column(name="NAME", data_type="VARCHAR(16777216)", not_null=False),
+            ],
+            owner=TEST_ROLE,
+            catalog="SNOWFLAKE",
+            external_volume="static_external_volume",
+            base_location="some_prefix",
+        ),
+        res.Table(
+            name="TEST_FETCH_TABLE",
+            columns=[
+                res.Column(name="ID", data_type="NUMBER(38,0)", not_null=True),
+                res.Column(name="NAME", data_type="VARCHAR(16777216)", not_null=False),
+            ],
+            owner=TEST_ROLE,
+        ),
+        res.Task(
+            name="TEST_FETCH_TASK",
+            schedule="60 MINUTE",
+            state="SUSPENDED",
+            as_="SELECT 1",
+            owner=TEST_ROLE,
+        ),
+        res.User(
+            name="TEST_FETCH_USER@applytitan.com",
+            owner=TEST_ROLE,
+            type="PERSON",
+            password="hunter2",
+            must_change_password=True,
+            display_name="Test User",
+            first_name="Test",
+            middle_name="Q.",
+            last_name="User",
+            comment="This is a test user",
+            default_warehouse="a_default_warehouse",
+            days_to_expiry=30,
+        ),
+        res.View(
+            name="TEST_FETCH_VIEW",
+            as_="SELECT 1 as id FROM STATIC_DATABASE.PUBLIC.STATIC_TABLE",
+            columns=[{"name": "ID", "data_type": "NUMBER(1,0)", "not_null": False}],
+            comment="View for testing",
+            owner=TEST_ROLE,
+        ),
+        res.Warehouse(
+            name="TEST_FETCH_WAREHOUSE",
+            warehouse_size="XSMALL",
+            auto_suspend=60,
+            auto_resume=True,
+            owner=TEST_ROLE,
+            comment="This is a test warehouse",
+        ),
+    ]
+
+
+def create(cursor, resource: Resource, account_edition):
+    sql = resource.create_sql(account_edition=account_edition)
+    try:
+        cursor.execute(sql)
+    except Exception as err:
+        raise Exception(f"Error creating resource: \nQuery: {err.query}\nMsg: {err.msg}") from err
+    return resource
+
+
+@pytest.fixture(
+    params=resource_fixtures(),
+    ids=[resource.__class__.__name__ for resource in resource_fixtures()],
+    scope="function",
+)
+def resource_fixture(
+    request,
+    cursor,
+    test_database,
+    marked_for_cleanup,
+):
+    resource = request.param
+
+    if isinstance(resource.scope, DatabaseScope):
+        test_database.add(resource)
+    elif isinstance(resource.scope, SchemaScope):
+        test_database.public_schema.add(resource)
+    elif isinstance(resource.scope, AccountScope):
+        cursor.execute(resource.drop_sql(if_exists=True))
+
+    marked_for_cleanup.append(resource)
+    yield resource
+
+
+@pytest.fixture(scope="session")
+def test_database(cursor, suffix, marked_for_cleanup):
+    db = res.Database(name=f"fetch_resource_test_database_{suffix}")
+    cursor.execute(db.create_sql(if_not_exists=True))
+    marked_for_cleanup.append(db)
+    yield db
+
+
+def test_fetch(
+    cursor,
+    resource_fixture,
+    account_edition,
+):
+    if account_edition not in resource_fixture.edition:
+        pytest.skip(f"Skipping test for {resource_fixture.__class__.__name__} on {account_edition} edition")
+
+    create(cursor, resource_fixture, account_edition)
+
+    fetched = safe_fetch(cursor, resource_fixture.urn)
+    assert fetched is not None
+    fetched = resource_fixture.spec(**fetched).to_dict(account_edition)
+    fetched = strip_unfetchable_fields(resource_fixture.spec, fetched)
+    fixture = strip_unfetchable_fields(resource_fixture.spec, resource_fixture.to_dict(account_edition))
+
+    if "columns" in fetched:
+        fetched_columns = fetched["columns"]
+        fixture_columns = fixture["columns"]
+        assert len(fetched_columns) == len(fixture_columns)
+        for fetched_column, fixture_column in zip(fetched_columns, fixture_columns):
+            assert fetched_column == fixture_column
+
+    assert fetched == fixture
diff --git a/tests/integration/data_provider/test_list_resource.py b/tests/integration/data_provider/test_list_resource.py
index b7248f4..e7640af 100644
--- a/tests/integration/data_provider/test_list_resource.py
+++ b/tests/integration/data_provider/test_list_resource.py
@@ -6,9 +6,10 @@
 from tests.helpers import get_json_fixtures
 
 from titan import data_provider
-from titan.client import UNSUPPORTED_FEATURE
+from titan.client import UNSUPPORTED_FEATURE, reset_cache
+from titan.enums import AccountEdition
 from titan.identifiers import resource_label_for_type
-from titan.resources import Database
+from titan.resources import Database, Resource
 from titan.scope import DatabaseScope, SchemaScope
 
 pytestmark = pytest.mark.requires_snowflake
@@ -35,6 +36,17 @@ def resource(request, suffix):
     yield res
 
 
+def create(cursor, resource: Resource):
+    session_ctx = data_provider.fetch_session(cursor.connection)
+    account_edition = AccountEdition.ENTERPRISE if session_ctx["tag_support"] else AccountEdition.STANDARD
+    sql = resource.create_sql(account_edition=account_edition, if_not_exists=True)
+    try:
+        cursor.execute(sql)
+    except Exception as err:
+        raise Exception(f"Error creating resource: \nQuery: {err.query}\nMsg: {err.msg}") from err
+    return resource
+
+
 @pytest.fixture(scope="session")
 def list_resources_database(cursor, suffix, marked_for_cleanup):
     db = Database(name=f"list_resources_test_database_{suffix}")
@@ -44,6 +56,15 @@ def list_resources_database(cursor, suffix, marked_for_cleanup):
 
 
 def test_list_resource(cursor, list_resources_database, resource, marked_for_cleanup):
+
+    data_provider.fetch_session.cache_clear()
+    reset_cache()
+    session_ctx = data_provider.fetch_session(cursor.connection)
+    account_edition = AccountEdition.ENTERPRISE if session_ctx["tag_support"] else AccountEdition.STANDARD
+
+    if account_edition not in resource.edition:
+        pytest.skip(f"Skipping {resource.__class__.__name__}, not supported by account edition {account_edition}")
+
     if isinstance(resource.scope, DatabaseScope):
         list_resources_database.add(resource)
     elif isinstance(resource.scope, SchemaScope):
@@ -53,14 +74,14 @@ def test_list_resource(cursor, list_resources_database, resource, marked_for_cle
         pytest.skip(f"{resource.resource_type} is not supported")
 
     try:
-        create_sql = resource.create_sql(if_not_exists=True)
-        cursor.execute(create_sql)
+        create(cursor, resource)
+        marked_for_cleanup.append(resource)
     except snowflake.connector.errors.ProgrammingError as err:
         if err.errno == UNSUPPORTED_FEATURE:
             pytest.skip(f"{resource.resource_type} is not supported")
         else:
             raise
-    marked_for_cleanup.append(resource)
+
     list_resources = data_provider.list_resource(cursor, resource_label_for_type(resource.resource_type))
     assert len(list_resources) > 0
     assert resource.fqn in list_resources
diff --git a/tests/integration/test_blueprint.py b/tests/integration/test_blueprint.py
index 7fcdfdc..cc89554 100644
--- a/tests/integration/test_blueprint.py
+++ b/tests/integration/test_blueprint.py
@@ -13,6 +13,7 @@
     UpdateResource,
     compile_plan_to_sql,
 )
+from titan.resources.database import public_schema_urn
 from titan.client import reset_cache
 from titan.enums import ResourceType
 
@@ -231,7 +232,7 @@ def test_blueprint_missing_database_inferred_from_session_context(cursor):
     blueprint.plan(session)
 
 
-def test_blueprint_all_grant_forces_add(cursor, test_db, role):
+def test_blueprint_all_grant_triggers_create(cursor, test_db, role):
     cursor.execute(f"GRANT USAGE ON DATABASE {test_db} TO ROLE {role.name}")
     session = cursor.connection
     all_grant = res.Grant(priv="ALL", on_database=test_db, to=role, owner=TEST_ROLE)
@@ -432,3 +433,30 @@ def test_blueprint_create_resource_with_database_role_owner(cursor, suffix, test
     assert schema_data is not None
     assert schema_data["name"] == schema.name
     assert schema_data["owner"] == str(database_role.fqn)
+
+
+def test_blueprint_database_params_passed_to_public_schema(cursor, suffix):
+    session = cursor.connection
+
+    def _database():
+        return res.Database(
+            name=f"test_db_params_passed_to_public_schema_{suffix}",
+            data_retention_time_in_days=1,
+            max_data_extension_time_in_days=2,
+            default_ddl_collation="en_US",
+        )
+
+    database = _database()
+    blueprint = Blueprint(resources=[database])
+    plan = blueprint.plan(session)
+    assert len(plan) == 1
+    blueprint.apply(session, plan)
+    schema_data = safe_fetch(cursor, public_schema_urn(database.urn))
+    assert schema_data is not None
+    assert schema_data["data_retention_time_in_days"] == 1
+    assert schema_data["max_data_extension_time_in_days"] == 2
+    assert schema_data["default_ddl_collation"] == "en_US"
+    database = _database()
+    blueprint = Blueprint(resources=[database])
+    plan = blueprint.plan(session)
+    assert len(plan) == 0
diff --git a/tests/integration/test_lifecycle.py b/tests/integration/test_lifecycle.py
index 85b917f..d53d605 100644
--- a/tests/integration/test_lifecycle.py
+++ b/tests/integration/test_lifecycle.py
@@ -1,13 +1,14 @@
 import os
 
-import pytest
+import pytest
 import snowflake.connector.errors
 
 from tests.helpers import get_json_fixtures
-
 from titan import resources as res
-from titan.blueprint import Blueprint
+from titan.blueprint import Blueprint, CreateResource
 from titan.client import FEATURE_NOT_ENABLED_ERR, UNSUPPORTED_FEATURE
+from titan.data_provider import fetch_session
+from titan.enums import AccountEdition
 from titan.scope import DatabaseScope, SchemaScope
 
 JSON_FIXTURES = list(get_json_fixtures())
@@ -29,7 +30,7 @@ def resource(request):
 
 
 def test_create_drop_from_json(resource, cursor, suffix, marked_for_cleanup):
-    lifecycle_db = f"LIFECYCLE_DB_{suffix}"
+    lifecycle_db = f"LIFECYCLE_DB_{suffix}_{resource.__class__.__name__}"
     cursor.execute("USE ROLE SYSADMIN")
     cursor.execute(f"CREATE DATABASE IF NOT EXISTS {lifecycle_db}")
     cursor.execute(f"USE DATABASE {lifecycle_db}")
@@ -46,10 +47,19 @@ def test_create_drop_from_json(resource, cursor, suffix, marked_for_cleanup):
         res.Grant,
         res.RoleGrant,
         res.PasswordPolicy,
+        res.Pipe,
     ):
-        pytest.skip("Skipping Service")
+        pytest.skip("Skipping")
 
     try:
+        fetch_session.cache_clear()
+        session_ctx = fetch_session(cursor.connection)
+        account_edition = AccountEdition.ENTERPRISE if session_ctx["tag_support"] else AccountEdition.STANDARD
+
+        if account_edition not in resource.edition:
+            feature_enabled = False
+            pytest.skip(f"Skipping {resource.__class__.__name__}, not supported by account edition {account_edition}")
+
         if isinstance(resource.scope, DatabaseScope):
             database.add(resource)
         elif isinstance(resource.scope, SchemaScope):
@@ -59,6 +69,7 @@ def test_create_drop_from_json(resource, cursor, suffix, marked_for_cleanup):
         blueprint.add(resource)
         plan = blueprint.plan(cursor.connection)
         assert len(plan) == 1
+        assert isinstance(plan[0], CreateResource)
         blueprint.apply(cursor.connection, plan)
     except snowflake.connector.errors.ProgrammingError as err:
         if err.errno == FEATURE_NOT_ENABLED_ERR or err.errno == UNSUPPORTED_FEATURE:
diff --git a/tests/test_blueprint.py b/tests/test_blueprint.py
index 8a06a0e..a4d2778 100644
--- a/tests/test_blueprint.py
+++ b/tests/test_blueprint.py
@@ -1,4 +1,5 @@
 import json
+from copy import deepcopy
 
 import pytest
 
@@ -7,14 +8,19 @@
 from titan.blueprint import (
     Blueprint,
     CreateResource,
-    DuplicateResourceException,
     _merge_pointers,
     compile_plan_to_sql,
     dump_plan,
 )
 from titan.blueprint_config import BlueprintConfig
 from titan.enums import ResourceType, RunMode
-from titan.exceptions import InvalidResourceException, MissingVarException, DuplicateResourceException
+from titan.exceptions import (
+    DuplicateResourceException,
+    InvalidResourceException,
+    MissingVarException,
+    NonConformingPlanException,
+    WrongEditionException,
+)
 from titan.identifiers import FQN, URN, parse_URN
 from titan.privs import AccountPriv, GrantedPrivilege
 from titan.resource_name import ResourceName
@@ -45,6 +51,7 @@ def session_ctx() -> dict:
                 GrantedPrivilege(privilege=AccountPriv.CREATE_WAREHOUSE, on="ABCD123"),
             ],
         },
+        "tag_support": True,
     }
 
 
@@ -62,6 +69,7 @@ def resource_manifest():
         "account_locator": "ABCD123",
         "current_role": "SYSADMIN",
         "available_roles": ["SYSADMIN", "USERADMIN"],
+        "tag_support": True,
     }
     db = res.Database(name="DB")
     schema = res.Schema(name="SCHEMA", database=db)
@@ -86,7 +94,7 @@ def test_blueprint_with_database(resource_manifest):
 
     db_urn = parse_URN("urn::ABCD123:database/DB")
     assert db_urn in resource_manifest
-    assert resource_manifest[db_urn].to_dict() == {
+    assert resource_manifest[db_urn].data == {
         "name": "DB",
         "owner": "SYSADMIN",
         "comment": None,
@@ -102,7 +110,7 @@ def test_blueprint_with_database(resource_manifest):
 def test_blueprint_with_schema(resource_manifest):
     schema_urn = parse_URN("urn::ABCD123:schema/DB.SCHEMA")
     assert schema_urn in resource_manifest
-    assert resource_manifest[schema_urn].to_dict() == {
+    assert resource_manifest[schema_urn].data == {
         "comment": None,
         "data_retention_time_in_days": 1,
         "default_ddl_collation": None,
@@ -117,7 +125,7 @@ def test_blueprint_with_schema(resource_manifest):
 def test_blueprint_with_view(resource_manifest):
     view_urn = parse_URN("urn::ABCD123:view/DB.SCHEMA.VIEW")
     assert view_urn in resource_manifest
-    assert resource_manifest[view_urn].to_dict() == {
+    assert resource_manifest[view_urn].data == {
         "as_": "SELECT 1",
         "change_tracking": False,
         "columns": None,
@@ -134,7 +142,7 @@ def test_blueprint_with_view(resource_manifest):
 def test_blueprint_with_table(resource_manifest):
     table_urn = parse_URN("urn::ABCD123:table/DB.SCHEMA.TABLE")
     assert table_urn in resource_manifest
-    assert resource_manifest[table_urn].to_dict() == {
+    assert resource_manifest[table_urn].data == {
         "name": "TABLE",
         "owner": "SYSADMIN",
         "columns": [
@@ -177,7 +185,7 @@ def test_blueprint_with_udf(resource_manifest):
         account_locator="ABCD123",
     )
     assert udf_urn in resource_manifest
-    assert resource_manifest[udf_urn].to_dict() == {
+    assert resource_manifest[udf_urn].data == {
         "name": "SOMEUDF",
         "owner": "SYSADMIN",
         "returns": "VARCHAR",
@@ -565,7 +573,7 @@ def test_blueprint_vars(session_ctx):
         vars={"role_comment": "var role comment"},
     )
     manifest = blueprint.generate_manifest(session_ctx)
-    assert manifest.resources[1]._data.comment == "var role comment"
+    assert manifest.resources[1].data["comment"] == "var role comment"
 
     role = res.Role(name="role", comment="some comment {{ var.suffix }}")
     assert isinstance(role._data.comment, VarString)
@@ -574,7 +582,7 @@ def test_blueprint_vars(session_ctx):
         vars={"suffix": "1234"},
     )
     manifest = blueprint.generate_manifest(session_ctx)
-    assert manifest.resources[1]._data.comment == "some comment 1234"
+    assert manifest.resources[1].data["comment"] == "some comment 1234"
 
     role = res.Role(name=var.role_name)
     assert isinstance(role.name, VarString)
@@ -583,7 +591,7 @@ def test_blueprint_vars(session_ctx):
         vars={"role_name": "role123"},
     )
     manifest = blueprint.generate_manifest(session_ctx)
-    assert manifest.resources[1].name == "role123"
+    assert manifest.resources[1].data["name"] == "role123"
 
     role = res.Role(name="role_{{ var.suffix }}")
     assert isinstance(role.name, VarString)
@@ -592,7 +600,7 @@ def test_blueprint_vars(session_ctx):
         vars={"suffix": "5678"},
    )
     manifest = blueprint.generate_manifest(session_ctx)
-    assert manifest.resources[1].name == "role_5678"
+    assert manifest.resources[1].data["name"] == "role_5678"
 
 
 def test_blueprint_vars_spec(session_ctx):
@@ -608,7 +616,7 @@ def test_blueprint_vars_spec(session_ctx):
     )
     assert blueprint._config.vars == {"role_comment": "var role comment"}
     manifest = blueprint.generate_manifest(session_ctx)
-    assert manifest.resources[1]._data.comment == "var role comment"
+    assert manifest.resources[1].data["comment"] == "var role comment"
 
     with pytest.raises(MissingVarException):
         blueprint = Blueprint(
@@ -689,3 +697,41 @@ def test_merge_account_scoped_resources_fail():
     ]
     with pytest.raises(DuplicateResourceException):
         _merge_pointers(resources)
+
+
+def test_blueprint_edition_checks(session_ctx, remote_state):
+    session_ctx = deepcopy(session_ctx)
+    session_ctx["tag_support"] = False
+
+    blueprint = Blueprint(resources=[res.Database(name="DB1"), res.Tag(name="TAG1")])
+    manifest = blueprint.generate_manifest(session_ctx)
+    plan = blueprint._plan(remote_state, manifest)
+    with pytest.raises(NonConformingPlanException):
+        blueprint._raise_for_nonconforming_plan(session_ctx, plan)
+
+    blueprint = Blueprint(resources=[res.Warehouse(name="WH", min_cluster_count=2)])
+    with pytest.raises(WrongEditionException):
+        blueprint.generate_manifest(session_ctx)
+
+    blueprint = Blueprint(resources=[res.Warehouse(name="WH", min_cluster_count=1)])
+    assert blueprint.generate_manifest(session_ctx)
+
+    blueprint = Blueprint(resources=[res.Warehouse(name="WH")])
+    assert blueprint.generate_manifest(session_ctx)
+
+
+def test_blueprint_warehouse_scaling_policy_doesnt_render_in_standard_edition(session_ctx, remote_state):
+    session_ctx = deepcopy(session_ctx)
+    session_ctx["tag_support"] = False
+    wh = res.Warehouse(name="WH", warehouse_size="XSMALL")
+    blueprint = Blueprint(resources=[wh])
+    manifest = blueprint.generate_manifest(session_ctx)
+    plan = blueprint._plan(remote_state, manifest)
+    assert len(plan) == 1
+    assert isinstance(plan[0], CreateResource)
+    sql = compile_plan_to_sql(session_ctx, plan)
+    assert len(sql) == 3
+    assert sql[0] == "USE SECONDARY ROLES ALL"
+    assert sql[1] == "USE ROLE SYSADMIN"
+    assert sql[2].startswith("CREATE WAREHOUSE WH")
+    assert "scaling_policy" not in sql[2]
diff --git a/tests/test_blueprint_ownership.py b/tests/test_blueprint_ownership.py
index 6b83d3a..95cf785 100644
--- a/tests/test_blueprint_ownership.py
+++ b/tests/test_blueprint_ownership.py
@@ -28,6 +28,7 @@ def session_ctx() -> dict:
             "PUBLIC",
         ],
         "role_privileges": {},
+        "tag_support": True,
     }
 
 
@@ -89,7 +90,7 @@ def test_custom_role_owner(session_ctx, remote_state):
 def test_transfer_ownership(session_ctx, remote_state):
     remote_state = remote_state.copy()
     remote_state[parse_URN("urn::ABCD123:role/test_role")] = {
-        "name": "test_role",
+        "name": str(ResourceName("test_role")),
         "owner": "ACCOUNTADMIN",
         "comment": None,
     }
@@ -111,7 +112,7 @@ def test_transfer_ownership(session_ctx, remote_state):
 def test_transfer_ownership_with_changes(session_ctx, remote_state):
     remote_state = remote_state.copy()
     remote_state[parse_URN("urn::ABCD123:role/test_role")] = {
-        "name": "test_role",
+        "name": str(ResourceName("test_role")),
         "owner": "ACCOUNTADMIN",
         "comment": None,
     }
@@ -185,6 +186,7 @@ def test_resource_cant_be_created(remote_state):
             "TEST_ROLE",
         ],
         "role_privileges": {},
+        "tag_support": True,
     }
     warehouse = res.Warehouse(name="test_warehouse", owner="test_role")
     blueprint = Blueprint(resources=[warehouse])
@@ -211,6 +213,7 @@ def test_grant_with_grant_admin_custom_role(remote_state):
                 GrantedPrivilege(privilege=AccountPriv.MANAGE_GRANTS, on="ABCD123"),
             ]
        },
+        "tag_support": True,
     }
 
     grant = res.RoleGrant(role="GRANT_ADMIN", to_role="SYSADMIN")
@@ -236,6 +239,7 @@ def test_tag_reference_with_tag_admin_custom_role():
                 GrantedPrivilege(privilege=AccountPriv.APPLY_TAG, on="ABCD123"),
            ]
         },
+        "tag_support": True,
         "tags": ["tags.tags.cost_center"],
     }
 
diff --git a/tests/test_identities.py b/tests/test_identities.py
index f91a558..1e66222 100644
--- a/tests/test_identities.py
+++ b/tests/test_identities.py
@@ -7,6 +7,7 @@
 from tests.helpers import get_json_fixtures
 
 from titan.data_types import convert_to_canonical_data_type
+from titan.enums import AccountEdition
 from titan.resources import Resource
 from titan.resource_name import ResourceName
 from titan.role_ref import RoleRef
@@ -14,6 +15,18 @@
 JSON_FIXTURES = list(get_json_fixtures())
 
 
+def remove_none_values(d):
+    new_dict = {}
+    for k, v in d.items():
+        if isinstance(v, dict):
+            new_dict[k] = remove_none_values(v)
+        elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], dict):
+            new_dict[k] = [remove_none_values(item) for item in v if item is not None]
+        elif v is not None:
+            new_dict[k] = v
+    return new_dict
+
+
 @pytest.fixture(
     params=JSON_FIXTURES,
     ids=[resource_cls.__name__ for resource_cls, _ in JSON_FIXTURES],
@@ -27,6 +40,8 @@ def resource(request):
 def _field_type_is_serialized_as_resource_name(field):
     if field.type is RoleRef:
         return True
+    if field.type is ResourceName:
+        return True
     elif isinstance(field.type, str) and field.name == "owner" and field.type == "Role":
         return True
     elif issubclass(field.type, Resource):
@@ -63,11 +78,11 @@ def test_data_identity(resource):
         for lhs, rhs in zip(lhs_cols, rhs_cols):
             if "name" in lhs:
                 assert _resource_names_are_eq(lhs.pop("name"), rhs.pop("name"))
-            if "data_type" in lhs:
+            if "data_type" in lhs and "data_type" in rhs:
                 assert convert_to_canonical_data_type(lhs.pop("data_type")) == convert_to_canonical_data_type(
                     rhs.pop("data_type")
                 )
-            assert lhs == rhs
+            assert remove_none_values(lhs) == remove_none_values(rhs)
     if "args" in serialized:
         lhs_args = serialized.pop("args", []) or []
         rhs_args = data.pop("args", []) or []
@@ -90,13 +105,13 @@ def test_data_identity(resource):
     assert serialized == data
 
 
-def test_sql_identity(resource):
+def test_sql_identity(resource: tuple[type[Resource], dict]):
     resource_cls, data = resource
     instance = resource_cls(**data)
-    sql = instance.create_sql()
+    sql = instance.create_sql(AccountEdition.ENTERPRISE)
     new = resource_cls.from_sql(sql)
-    new_dict = new.to_dict()
-    instance_dict = instance.to_dict()
+    new_dict = new.to_dict(AccountEdition.ENTERPRISE)
+    instance_dict = instance.to_dict(AccountEdition.ENTERPRISE)
     if "name" in new_dict:
         assert ResourceName(new_dict.pop("name")) == ResourceName(instance_dict.pop("name"))
 
diff --git a/tests/test_plan.py b/tests/test_plan.py
index 4a08d9e..3a017c3 100644
--- a/tests/test_plan.py
+++ b/tests/test_plan.py
@@ -13,6 +13,7 @@ def session_ctx() -> dict:
         "account_locator": "ABCD123",
         "role": "SYSADMIN",
         "available_roles": ["SYSADMIN", "USERADMIN"],
+        "tag_support": True,
     }
 
 
@@ -90,4 +91,4 @@ def test_plan_no_removes_in_run_mode_create_or_update(session_ctx, remote_state)
     assert isinstance(change, DropResource)
     assert change.urn == parse_URN("urn::ABCD123:role/REMOVED_ROLE")
     with pytest.raises(NonConformingPlanException):
-        bp._raise_for_nonconforming_plan(plan)
+        bp._raise_for_nonconforming_plan(session_ctx, plan)
diff --git a/tests/test_resource_containers.py b/tests/test_resource_containers.py
index 3a04869..ce32206 100644
--- a/tests/test_resource_containers.py
+++ b/tests/test_resource_containers.py
@@ -2,7 +2,7 @@
 
 from titan import resources as res
 from titan.enums import ResourceType
-from titan.resources.resource import ResourceHasContainerException, WrongContainerException
+from titan.exceptions import ResourceHasContainerException, WrongContainerException
 
 
 def test_account_can_add_database():
diff --git a/tests/test_resources.py b/tests/test_resources.py
index c638491..79b71db 100644
--- a/tests/test_resources.py
+++ b/tests/test_resources.py
@@ -10,6 +10,7 @@
 from titan.resource_tags import ResourceTags
 from titan.resources.resource import ResourcePointer
 from titan.resources.user import UserType
+from titan.resources.view import ViewColumn
 
 SQL_FIXTURES = list(get_sql_fixtures())
 
@@ -40,8 +41,11 @@ def test_view_fails_with_empty_columns():
 
 
 def test_view_with_columns():
-    view = res.View.from_sql("CREATE VIEW MY_VIEW (COL1) AS SELECT 1")
-    assert view._data.columns == [{"name": "COL1"}]
+    view = res.View.from_sql("CREATE VIEW MY_VIEW (col1) AS SELECT 1")
+    assert isinstance(view._data.columns[0], ViewColumn)
+    assert view._data.columns[0].name == "COL1"
+    assert view._data.columns[0]._data.data_type is None
+    assert view._data.columns[0]._data.comment is None
 
 
 def test_enum_field_serialization():
diff --git a/titan/blueprint.py b/titan/blueprint.py
index eb9663c..5a9a600 100644
--- a/titan/blueprint.py
+++ b/titan/blueprint.py
@@ -17,12 +17,10 @@
     reset_cache,
 )
 from .data_provider import SessionContext
-from .diff import Action, diff
-from .enums import ResourceType, RunMode, resource_type_is_grant
+from .enums import AccountEdition, ResourceType, RunMode, resource_type_is_grant
 from .exceptions import (
     DuplicateResourceException,
     InvalidResourceException,
-    MarkedForReplacementException,
     MissingPrivilegeException,
     MissingResourceException,
     NonConformingPlanException,
@@ -39,13 +37,14 @@
 )
 from .resource_name import ResourceName
 from .resource_tags import ResourceTags
-from .resources import Database, DatabaseRole, Role, RoleGrant, Schema
+from .resources import Database, RoleGrant, Schema
 from .resources.database import public_schema_urn
 from .resources.resource import (
     RESOURCE_SCOPES,
     NamedResource,
     Resource,
     ResourceContainer,
+    ResourceLifecycleConfig,
     ResourcePointer,
     infer_role_type_from_name,
 )
@@ -54,15 +53,6 @@
 
 logger = logging.getLogger("titan")
 
-# SYNC_MODE_BLOCKLIST = [
-#     ResourceType.FUTURE_GRANT,
-#     ResourceType.GRANT,
-#     ResourceType.GRANT_ON_ALL,
-#     ResourceType.ROLE,
-#     ResourceType.USER,
-#     ResourceType.TABLE,
-# ]
-
 
 @dataclass
 class ResourceChange(ABC):
@@ -75,6 +65,7 @@ def to_dict(self) -> dict:
 
 @dataclass
 class CreateResource(ResourceChange):
+    resource_cls: type[Resource]
     after: dict
 
     def to_dict(self) -> dict:
@@ -99,6 +90,7 @@ def to_dict(self) -> dict:
 
 @dataclass
 class UpdateResource(ResourceChange):
+    resource_cls: type[Resource]
     before: dict
     after: dict
     delta: dict
@@ -115,6 +107,7 @@ def to_dict(self) -> dict:
 
 @dataclass
 class TransferOwnership(ResourceChange):
+    resource_cls: type[Resource]
     from_owner: str
     to_owner: str
 
@@ -131,55 +124,77 @@ def to_dict(self) -> dict:
 Plan = list[ResourceChange]
 
 
+@dataclass
+class ManifestResource:
+    urn: URN
+    resource_cls: type[Resource]
+    data: dict[str, Any]
+    implicit: bool
+    lifecycle: ResourceLifecycleConfig
+
+
 class Manifest:
     def __init__(self, account_locator: str = ""):
         self._account_locator = account_locator
-        self._data: dict[URN, Resource] = {}
+        self._resources: dict[URN, Union[ManifestResource, ResourcePointer]] = {}
         self._refs: list[tuple[URN, URN]] = []
 
     def __getitem__(self, key: URN):
         if isinstance(key, URN):
-            return self._data[key]
+            return self._resources[key]
         else:
             raise Exception("Manifest keys must be URNs")
 
-    def __contains__(self, key):
+    def __contains__(self, key: URN):
         if isinstance(key, URN):
-            return key in self._data
+            return key in self._resources
         else:
             raise Exception("Manifest keys must be URNs")
 
-    def add(self, resource: Resource):
+    def add(self, resource: Resource, account_edition: AccountEdition):
         urn = URN.from_resource(
             account_locator=self._account_locator,
             resource=resource,
         )
-        # resource_data = resource.to_dict()
-
-        if urn in self._data:
-            # if resource_data != self._data[urn]:
+        if urn in self._resources:
             if not isinstance(resource, ResourcePointer):
                 logger.warning(f"Duplicate resource {urn} with conflicting data, discarding {resource}")
             return
-        self._data[urn] = resource
+        if isinstance(resource, ResourcePointer):
+            self._resources[urn] = resource
+        else:
+            self._resources[urn] = ManifestResource(
+                urn,
+                resource.__class__,
+                resource.to_dict(account_edition),
+                resource.implicit,
+                resource.lifecycle,
+            )
         for ref in resource.refs:
             ref_urn = URN.from_resource(account_locator=self._account_locator, resource=ref)
             self._refs.append((urn, ref_urn))
 
     def get(self, key: URN, default=None):
         if isinstance(key, URN):
-            return self._data.get(key, default)
+            return self._resources.get(key, default)
         else:
             raise Exception("Manifest keys must be URNs")
 
-    def to_dict(self):
-        return {k: v.to_dict() for k, v in self._data.items()}
+    def items(self):
+        return self._resources.items()
+
+    def __repr__(self):
+        contents = ""
+        for urn, resource in self._resources.items():
+            contents += f"[{urn}] =>\n"
+            contents += f"  {resource}\n"
+        return f"Manifest({len(self._resources)} resources)\n{contents}"
 
     @property
     def urns(self) -> list[URN]:
-        return list(self._data.keys())
+        return list(self._resources.keys())
 
     @property
     def refs(self):
@@ -187,7 +202,7 @@
 
     @property
     def resources(self):
-        return list(self._data.values())
+        return list(self._resources.values())
 
 
 def dump_plan(plan: Plan, format: str = "json"):
@@ -455,35 +470,33 @@ def from_config(cls, config: BlueprintConfig):
         blueprint.add(config.resources or [])
         return blueprint
 
-    def _raise_for_nonconforming_plan(self, plan: Plan):
+    def _raise_for_nonconforming_plan(self, session_ctx: SessionContext, plan: Plan):
         exceptions = []
 
-        # Run Mode exceptions
-        if self._config.run_mode == RunMode.CREATE_OR_UPDATE:
-            for change in plan:
+        # If the account doesn't support tags, assume the account is standard edition or trial
+        account_is_standard_edition = session_ctx["tag_support"] is False
+
+        for change in plan:
+            # Run Mode exceptions
+            if self._config.run_mode == RunMode.CREATE_OR_UPDATE:
                 if isinstance(change, DropResource):
                     exceptions.append(
                         f"Create-or-update mode does not allow resources to be removed (ref: {change.urn})"
                     )
-                if isinstance(change, UpdateResource):
-                    if "name" in change.delta:
-                        exceptions.append(
-                            f"Create-or-update mode does not allow renaming resources (ref: {change.urn})"
-                        )
-
-        # if self._config.run_mode == RunMode.SYNC:
-        #     for change in plan:
-        #         if change.urn.resource_type in SYNC_MODE_BLOCKLIST:
-        #             exceptions.append(
-        #                 f"Sync mode does not allow changes to {change.urn.resource_type} (ref: {change.urn})"
-        #             )
-
-        # Valid Resource Types exceptions
-        if self._config.allowlist:
-            for change in plan:
+            if isinstance(change, UpdateResource):
+                if "name" in change.delta:
+                    exceptions.append(f"Create-or-update mode does not allow renaming resources (ref: {change.urn})")
+                if change.resource_cls.resource_type == ResourceType.GRANT:
+                    exceptions.append(f"Grants cannot be updated (ref: {change.urn})")
+            # Valid Resource Types exceptions
+            if self._config.allowlist:
                 if change.urn.resource_type not in self._config.allowlist:
                     exceptions.append(f"Resource type {change.urn.resource_type} not allowed in blueprint")
+            if account_is_standard_edition:
+                if isinstance(change, CreateResource) and AccountEdition.STANDARD not in change.resource_cls.edition:
+                    exceptions.append(f"Resource {change.urn} requires enterprise edition or higher")
+
         if exceptions:
             if len(exceptions) > 5:
                 exception_block = "\n".join(exceptions[0:5]) + f"\n... and {len(exceptions) - 5} more"
@@ -492,75 +505,14 @@ def _raise_for_nonconforming_plan(self, plan: Plan):
             raise NonConformingPlanException("Non-conforming actions found in plan:\n" + exception_block)
 
     def _plan(self, remote_state: State, manifest: Manifest) -> Plan:
-        manifest_dict = manifest.to_dict()
-
-        # changes: Plan = []
         additive_changes: list[ResourceChange] = []
         destructive_changes: list[ResourceChange] = []
-        marked_for_replacement = set()
-        for action, urn, delta in diff(remote_state, manifest_dict):
-            before = remote_state.get(urn, {})
-            after = manifest_dict.get(urn, {})
-
-            if action == Action.UPDATE:
-                if urn in marked_for_replacement:
-                    continue
-
-                resource = manifest[urn]
-
-                # TODO: if the attr is marked as must_replace, then instead we yield a rename, add, remove
-                attr = list(delta.keys())[0]
-                attr_metadata = resource.spec.get_metadata(attr)
-
-                change_requires_replacement = attr_metadata.triggers_replacement
-                change_forces_add = attr_metadata.forces_add
-                change_is_fetchable = attr_metadata.fetchable
-                change_is_known_after_apply = attr_metadata.known_after_apply
-                change_should_be_ignored = attr in resource.lifecycle.ignore_changes or attr_metadata.ignore_changes
-
-                if change_requires_replacement:
-                    raise MarkedForReplacementException(f"Resource {urn} is marked for replacement due to {attr}")
-                    marked_for_replacement.add(urn)
-                elif change_forces_add:
-                    additive_changes.append(CreateResource(urn, after))
-                    continue
-                elif not change_is_fetchable:
-                    # drift on fields that aren't fetchable should be ignored
-                    # TODO: throw a warning, or have a blueprint runmode that fails on this
-                    continue
-                elif change_is_known_after_apply:
-                    continue
-                elif change_should_be_ignored:
-                    continue
-                else:
-                    additive_changes.append(UpdateResource(urn,
before, after, delta)) - elif action == Action.CREATE: - additive_changes.append(CreateResource(urn, after)) - elif action == Action.DROP: - destructive_changes.append(DropResource(urn, before)) - elif action == Action.TRANSFER: - resource = manifest[urn] - - attr = list(delta.keys())[0] - attr_metadata = resource.spec.get_metadata(attr) - change_is_fetchable = attr_metadata.fetchable - change_should_be_ignored = attr in resource.lifecycle.ignore_changes or attr_metadata.ignore_changes - if not change_is_fetchable: - continue - if change_should_be_ignored: - continue - additive_changes.append( - TransferOwnership( - urn, - from_owner=before["owner"], - to_owner=after["owner"], - ) - ) - for urn in marked_for_replacement: - raise MarkedForReplacementException(f"Resource {urn} is marked for replacement") - # changes.append(ResourceChange(action=Action.REMOVE, urn=urn, before=before, after={}, delta={})) - # changes.append(ResourceChange(action=Action.ADD, urn=urn, before={}, after=after, delta=after)) + for resource_change in diff(remote_state, manifest): + if isinstance(resource_change, (CreateResource, UpdateResource, TransferOwnership)): + additive_changes.append(resource_change) + elif isinstance(resource_change, DropResource): + destructive_changes.append(resource_change) # Generate a list of all URNs resource_set = set(manifest.urns + list(remote_state.keys())) @@ -569,34 +521,15 @@ def _plan(self, remote_state: State, manifest: Manifest) -> Plan: resource_set.add(ref[1]) # Calculate a topological sort order for the URNs sort_order = topological_sort(resource_set, set(manifest.refs)) - plan = sorted(additive_changes, key=lambda change: sort_order[change.urn]) + sorted( - destructive_changes, key=lambda change: -1 * sort_order[change.urn] + plan = sorted(additive_changes, key=lambda change: sort_order[change.urn]) + _sort_destructive_changes( + destructive_changes, sort_order ) return plan def fetch_remote_state(self, session, manifest: Manifest) -> State: state: State = {} session_ctx = data_provider.fetch_session(session) - - def _normalize(urn: URN, data: dict) -> dict: - resource_cls = Resource.resolve_resource_cls(urn.resource_type, data) - if urn.resource_type == ResourceType.FUTURE_GRANT: - normalized = data - elif isinstance(data, list): - raise Exception(f"Fetching list of {urn.resource_type} is not supported yet") - else: - # There is an edge case here where the resource spec doesnt have defaults specified. - # Instead of throwing an error, dataclass will provide a dataclass._MISSINGFIELD object - # That is bad. - # The answer is not that defaults should be added. The root cause is that data_provider - # method return raw dicts that aren't type checked against their corresponding - # Resource spec. - # I have considered tightly coupling the data provider to the resource spec, but I don't think - # the complexity is worth it. - # Another solution would be to build in automatic tests to check that the data_provider - # returns data that matches the spec. 
- normalized = resource_cls.defaults() | data - return normalized + account_edition = AccountEdition.ENTERPRISE if session_ctx["tag_support"] else AccountEdition.STANDARD if self._config.run_mode == RunMode.SYNC: if self._config.allowlist: @@ -605,17 +538,21 @@ def _normalize(urn: URN, data: dict) -> dict: urn = URN(resource_type=resource_type, fqn=fqn, account_locator=session_ctx["account_locator"]) data = data_provider.fetch_resource(session, urn) if data is None: - raise Exception(f"Resource {urn} not found") - normalized_data = _normalize(urn, data) - state[urn] = normalized_data + raise MissingResourceException(f"Resource could not be found: {urn}") + resource_cls = Resource.resolve_resource_cls(urn.resource_type, data) + state[urn] = resource_cls.spec(**data).to_dict(account_edition) else: raise RuntimeError("Sync mode requires an allowlist") - for urn in manifest.urns: + for urn, manifest_item in manifest.items(): data = data_provider.fetch_resource(session, urn) if data is not None: - normalized_data = _normalize(urn, data) - state[urn] = normalized_data + if isinstance(manifest_item, ResourcePointer): + resource_cls = Resource.resolve_resource_cls(urn.resource_type, data) + else: + resource_cls = manifest_item.resource_cls + + state[urn] = resource_cls.spec(**data).to_dict(account_edition) # check for existence of resource refs for parent, reference in manifest.refs: @@ -632,7 +569,7 @@ def _normalize(urn: URN, data: dict) -> dict: data = None if data is None and not is_public_schema: - logger.error(manifest.to_dict()) + # logger.error(manifest.to_dict(session_ctx)) raise MissingResourceException( f"Resource {reference} required by {parent} not found or failed to fetch" ) @@ -807,6 +744,10 @@ def _create_grandparent_refs(self): if isinstance(resource.scope, SchemaScope): resource.requires(resource.container.container) + def _finalize_resources(self): + for resource in _walk(self._root): + resource._finalized = True + def _finalize(self, session_ctx): if self._finalized: raise RuntimeError("Blueprint already finalized") @@ -816,15 +757,15 @@ def _finalize(self, session_ctx): self._create_tag_references() self._create_ownership_refs(session_ctx) self._create_grandparent_refs() - for resource in _walk(self._root): - resource._finalized = True + self._finalize_resources() def generate_manifest(self, session_ctx: SessionContext) -> Manifest: manifest = Manifest(account_locator=session_ctx["account_locator"]) + account_edition = AccountEdition.ENTERPRISE if session_ctx["tag_support"] else AccountEdition.STANDARD self._finalize(session_ctx) for resource in _walk(self._root): if isinstance(resource, Resource): - manifest.add(resource) + manifest.add(resource, account_edition) else: raise RuntimeError(f"Unexpected object found in blueprint: {resource}") @@ -847,7 +788,7 @@ def plan(self, session) -> Plan: logger.error(manifest) raise e - self._raise_for_nonconforming_plan(finished_plan) + self._raise_for_nonconforming_plan(session_ctx, finished_plan) return finished_plan def apply(self, session, plan: Optional[Plan] = None): @@ -1048,7 +989,7 @@ def granted_priv_allows_change(granted_priv: GrantedPrivilege, change: ResourceC if granted_priv.privilege == AccountPriv.APPLY_TAG: return True - # If we own the resource container, we can always perform ADD + # If we own the resource container, we can always perform CreateResource if is_ownership_priv(granted_priv.privilege) and granted_priv.on == container_name: return True @@ -1124,8 +1065,8 @@ def sql_commands_for_change( 
         before_change_cmd.append(f"USE ROLE {execution_role}")
 
     if isinstance(change, CreateResource):
-        props = Resource.props_for_resource_type(change.urn.resource_type, change.after)
-        change_cmd = lifecycle.create_resource(change.urn, change.after, props)
+
+        change_cmd = lifecycle.create_resource(change.urn, change.after, change.resource_cls.props)
         if transfer_owner:
             after_change_cmd.append(
                 lifecycle.transfer_resource(
@@ -1151,6 +1092,15 @@
             props = Resource.props_for_resource_type(change.urn.resource_type, change.after)
             change_cmd = lifecycle.update_resource(change.urn, change.delta, props)
         elif isinstance(change, DropResource):
+            if transfer_owner:
+                before_change_cmd.append(
+                    lifecycle.transfer_resource(
+                        change.urn,
+                        owner=str(execution_role),
+                        owner_resource_type=infer_role_type_from_name(str(execution_role)),
+                        copy_current_grants=True,
+                    )
+                )
             change_cmd = lifecycle.drop_resource(
                 change.urn,
                 change.before,
@@ -1184,7 +1134,7 @@
         )
         sql_commands.extend(commands)
 
-        # Update state
+        # Update role privileges state for the next commands
         if isinstance(change, CreateResource):
             if change.urn.resource_type == ResourceType.ROLE_GRANT:
                 if change.after["to_role"] in available_roles:
@@ -1272,3 +1222,129 @@
     if len(nodes) != len(resource_set):
         raise Exception("Graph is not a DAG")
     return {value: index for index, value in enumerate(nodes)}
+
+
+def diff(remote_state: State, manifest: Manifest):
+
+    def _diff_resource_data(lhs: dict, rhs: dict) -> dict:
+
+        if not isinstance(lhs, dict) or not isinstance(rhs, dict):
+            raise TypeError("_diff_resource_data requires two dictionaries")
+
+        delta = {}
+        for field_name in lhs.keys():
+            lhs_value = lhs[field_name]
+            rhs_value = rhs[field_name]
+            if lhs_value != rhs_value:
+                delta[field_name] = rhs_value
+        return delta
+
+    state_urns = set(remote_state.keys())
+    manifest_urns = set(manifest.urns)
+
+    # Resources in remote state but not in the manifest should be removed
+    for urn in state_urns - manifest_urns:
+        yield DropResource(urn, remote_state[urn])
+
+    # Resources in the manifest but not in remote state should be added
+    for urn in manifest_urns - state_urns:
+        manifest_item = manifest[urn]
+        if isinstance(manifest_item, ResourcePointer):
+            raise MissingResourceException(
+                f"Blueprint has pointer to resource that doesn't exist or isn't visible in session: {urn}"
+            )
+        elif isinstance(manifest_item, ManifestResource):
+            # We don't create implicit resources
+            if manifest[urn].implicit:
+                continue
+            yield CreateResource(urn, manifest_item.resource_cls, manifest_item.data)
+        else:
+            raise Exception(f"Unknown type in manifest: {manifest_item}")
+
+    # Resources in both should be compared
+    for urn in state_urns & manifest_urns:
+        manifest_item = manifest[urn]
+
+        # We don't diff resource pointers
+        if isinstance(manifest_item, ResourcePointer):
+            continue
+
+        delta = _diff_resource_data(remote_state[urn], manifest_item.data)
+        owner_attr = delta.pop("owner", None)
+
+        # TODO: do we care about implicit resources?
+        replace_resource = False
+        create_resource = False
+        ignore_fields = set()
+        for attr in delta.keys():
+            attr_metadata = manifest_item.resource_cls.spec.get_metadata(attr)
+            change_requires_replacement = attr_metadata.triggers_replacement
+            change_triggers_create = attr_metadata.triggers_create
+            change_is_fetchable = attr_metadata.fetchable
+            change_is_known_after_apply = attr_metadata.known_after_apply
+            change_should_be_ignored = attr in manifest_item.lifecycle.ignore_changes or attr_metadata.ignore_changes
+            if change_requires_replacement:
+                replace_resource = True
+                break
+            elif change_triggers_create:
+                create_resource = True
+                break
+            elif not change_is_fetchable:
+                ignore_fields.add(attr)
+            elif change_is_known_after_apply:
+                ignore_fields.add(attr)
+            elif change_should_be_ignored:
+                ignore_fields.add(attr)
+
+        if replace_resource:
+            raise NotImplementedError("replace_resource")
+            # yield DropResource(urn, remote_state[urn])
+            # yield CreateResource(urn, manifest_item.resource_cls, manifest_item.data)
+            # continue
+
+        if create_resource:
+            yield CreateResource(urn, manifest_item.resource_cls, manifest_item.data)
+            continue
+
+        delta = {k: v for k, v in delta.items() if k not in ignore_fields}
+        if delta:
+            yield UpdateResource(
+                urn,
+                manifest_item.resource_cls,
+                remote_state[urn],
+                manifest_item.data,
+                delta,
+            )
+
+        # Force transfers to occur after all other attribute changes
+        if owner_attr:
+            owner_metadata = manifest_item.resource_cls.spec.get_metadata("owner")
+            owner_is_fetchable = owner_metadata.fetchable
+            owner_changes_should_be_ignored = (
+                "owner" in manifest_item.lifecycle.ignore_changes or owner_metadata.ignore_changes
+            )
+
+            if not owner_is_fetchable or owner_changes_should_be_ignored:
+                continue
+
+            yield TransferOwnership(
+                urn,
+                manifest_item.resource_cls,
+                from_owner=remote_state[urn]["owner"],
+                to_owner=manifest_item.data["owner"],
+            )
+
+
+def _sort_destructive_changes(
+    destructive_changes: list[ResourceChange], sort_order: dict[URN, int]
+) -> list[ResourceChange]:
+    # Heuristic: network policies first, account-scoped before database/schema-scoped, ties by reverse topological order. Not quite right but close enough for now.
+ def sort_key(change: ResourceChange) -> tuple: + return ( + change.urn.resource_type != ResourceType.NETWORK_POLICY, + change.urn.database is not None, + change.urn.schema is not None, + -1 * sort_order[change.urn], + ) + + return sorted(destructive_changes, key=sort_key) diff --git a/titan/cli.py b/titan/cli.py index 069d16f..71a0d1e 100644 --- a/titan/cli.py +++ b/titan/cli.py @@ -4,6 +4,7 @@ import yaml from titan.blueprint import dump_plan +from titan.enums import RunMode from titan.operations.blueprint import blueprint_apply, blueprint_plan from titan.operations.export import export_resources from titan.operations.connector import connect, get_env_vars @@ -59,7 +60,7 @@ def plan(config_file, json_output, output_file, vars: dict, run_mode): if vars: cli_config["vars"] = vars if run_mode: - cli_config["run_mode"] = run_mode + cli_config["run_mode"] = RunMode(run_mode) plan_obj = blueprint_plan(yaml_config, cli_config) output = None if json_output: @@ -98,7 +99,7 @@ def apply(config_file, plan_file, vars, run_mode, dry_run): if vars: cli_config["vars"] = vars if run_mode: - cli_config["run_mode"] = run_mode + cli_config["run_mode"] = RunMode(run_mode) if dry_run: cli_config["dry_run"] = dry_run blueprint_apply(yaml_config, cli_config) diff --git a/titan/data_provider.py b/titan/data_provider.py index e347909..2ac5c69 100644 --- a/titan/data_provider.py +++ b/titan/data_provider.py @@ -545,9 +545,28 @@ def fetch_session(session) -> SessionContext: raise available_roles = [ResourceName(role) for role in json.loads(session_obj["AVAILABLE_ROLES"])] + role_privileges = fetch_role_privileges(session, available_roles, cacheable=True) + return { + "account_locator": session_obj["ACCOUNT_LOCATOR"], + "account": session_obj["ACCOUNT"], + "available_roles": available_roles, + "database": session_obj["DATABASE"], + "role": session_obj["ROLE"], + "schemas": json.loads(session_obj["SCHEMAS"]), + "secondary_roles": json.loads(session_obj["SECONDARY_ROLES"]), + "tag_support": tag_support, + "tags": tags, + "user": session_obj["USER"], + "version": session_obj["VERSION"], + "warehouse": session_obj["WAREHOUSE"], + "role_privileges": role_privileges, + } + + +def fetch_role_privileges(session, roles: list, cacheable: bool = True) -> dict[ResourceName, list[GrantedPrivilege]]: role_privileges = {} - for role in available_roles: + for role in roles: # Adds 30+s of latency and we can infer what privs are available if role == "ACCOUNTADMIN" or role.startswith("SNOWFLAKE."): @@ -555,7 +574,7 @@ def fetch_session(session) -> SessionContext: role_privileges[role] = [] - grants = _show_grants_to_role(session, role, cacheable=True) + grants = _show_grants_to_role(session, role, cacheable=cacheable) for grant in grants: try: granted_priv = GrantedPrivilege.from_grant( @@ -567,22 +586,7 @@ def fetch_session(session) -> SessionContext: # If titan isnt aware of the privilege, ignore it except ValueError: continue - - return { - "account_locator": session_obj["ACCOUNT_LOCATOR"], - "account": session_obj["ACCOUNT"], - "available_roles": available_roles, - "database": session_obj["DATABASE"], - "role": session_obj["ROLE"], - "schemas": json.loads(session_obj["SCHEMAS"]), - "secondary_roles": json.loads(session_obj["SECONDARY_ROLES"]), - "tag_support": tag_support, - "tags": tags, - "user": session_obj["USER"], - "version": session_obj["VERSION"], - "warehouse": session_obj["WAREHOUSE"], - "role_privileges": role_privileges, - } + return role_privileges # ------------------------------ @@ -592,7 +596,10 @@ def 
fetch_session(session) -> SessionContext: def fetch_account(session, fqn: FQN): # raise NotImplementedError() - return {} + return { + "name": None, + "locator": None, + } def fetch_aggregation_policy(session, fqn: FQN): @@ -612,8 +619,7 @@ def fetch_aggregation_policy(session, fqn: FQN): def fetch_alert(session, fqn: FQN): - show_result = execute(session, "SHOW ALERTS", cacheable=True) - alerts = _filter_result(show_result, name=fqn.name) + alerts = _show_resources(session, "ALERTS", fqn) if len(alerts) == 0: return None if len(alerts) > 1: @@ -812,8 +818,7 @@ def fetch_database_role(session, fqn: FQN): def fetch_dynamic_table(session, fqn: FQN): - show_result = execute(session, f"SHOW DYNAMIC TABLES LIKE '{fqn.name}'") - + show_result = _show_resources(session, "DYNAMIC TABLES", fqn) if len(show_result) == 0: return None if len(show_result) > 1: @@ -998,9 +1003,9 @@ def fetch_function(session, fqn: FQN): raise Exception(f"Found multiple functions matching {fqn}") data = udfs[0] - inputs, output = data["arguments"].split(" RETURN ") + _, returns = data["arguments"].split(" RETURN ") try: - desc_result = execute(session, f"DESC FUNCTION {inputs}", cacheable=True) + desc_result = execute(session, f"DESC FUNCTION {fqn}", cacheable=True) except ProgrammingError as err: if err.errno == DOES_NOT_EXIST_ERR: return None @@ -1013,7 +1018,7 @@ def fetch_function(session, fqn: FQN): "name": _quote_snowflake_identifier(data["name"]), "secure": data["is_secure"] == "Y", "args": _parse_signature(properties["signature"]), - "returns": output, + "returns": returns, "language": data["language"], "comment": None if data["description"] == "user-defined function" else data["description"], "volatility": properties["volatility"], @@ -1025,7 +1030,7 @@ def fetch_function(session, fqn: FQN): "name": _quote_snowflake_identifier(data["name"]), "secure": data["is_secure"] == "Y", "args": _parse_signature(properties["signature"]), - "returns": output, + "returns": returns, "language": data["language"], "comment": None if data["description"] == "user-defined function" else data["description"], "volatility": properties["volatility"], @@ -2060,16 +2065,18 @@ def fetch_user(session, fqn: FQN) -> Optional[dict]: must_change_password = data["must_change_password"] == "true" rsa_public_key = properties["rsa_public_key"] if properties["rsa_public_key"] != "null" else None + middle_name = properties["middle_name"] if properties["middle_name"] != "null" else None return { "name": _quote_snowflake_identifier(data["name"]), "login_name": login_name, "display_name": display_name, "first_name": data["first_name"] or None, + "middle_name": middle_name, "last_name": data["last_name"] or None, "email": data["email"] or None, "mins_to_unlock": data["mins_to_unlock"] or None, - "days_to_expiry": data["days_to_expiry"] or None, + # "days_to_expiry": data["days_to_expiry"] or None, "comment": data["comment"] or None, "disabled": data["disabled"] == "true", "must_change_password": must_change_password, @@ -2131,29 +2138,36 @@ def fetch_warehouse(session, fqn: FQN): show_params_result = execute(session, f"SHOW PARAMETERS FOR WAREHOUSE {fqn}") params = params_result_to_dict(show_params_result) + resource_monitor = None if data["resource_monitor"] == "null" else data["resource_monitor"] + + # Enterprise edition features query_accel = data.get("enable_query_acceleration") if query_accel: query_accel = query_accel == "true" else: query_accel = False - return { + warehouse_dict = { "name": _quote_snowflake_identifier(data["name"]), "owner": 
_get_owner_identifier(data), "warehouse_type": data["type"], "warehouse_size": str(WarehouseSize(data["size"])), - # "max_cluster_count": data["max_cluster_count"], - # "min_cluster_count": data["min_cluster_count"], - # "scaling_policy": data["scaling_policy"], "auto_suspend": data["auto_suspend"], "auto_resume": data["auto_resume"] == "true", "comment": data["comment"] or None, + "resource_monitor": resource_monitor, "enable_query_acceleration": query_accel, + "query_acceleration_max_scale_factor": data.get("query_acceleration_max_scale_factor", None), + "max_cluster_count": data.get("max_cluster_count", None), + "min_cluster_count": data.get("min_cluster_count", None), + "scaling_policy": data.get("scaling_policy", None), "max_concurrency_level": params["max_concurrency_level"], "statement_queued_timeout_in_seconds": params["statement_queued_timeout_in_seconds"], "statement_timeout_in_seconds": params["statement_timeout_in_seconds"], } + return warehouse_dict + ################ List functions @@ -2292,13 +2306,13 @@ def list_functions(session) -> list[FQN]: def list_grants(session) -> list[FQN]: - roles = execute(session, "SHOW ROLES", cacheable=True) + roles = execute(session, "SHOW ROLES") grants = [] for role in roles: role_name = resource_name_from_snowflake_metadata(role["name"]) if role_name in SYSTEM_ROLES: continue - grant_data = _show_grants_to_role(session, role_name, cacheable=True) + grant_data = _show_grants_to_role(session, role_name, cacheable=False) for data in grant_data: if data["granted_on"] == "ROLE": # raise Exception(f"Role grants are not supported yet: {data}") diff --git a/titan/data_types.py b/titan/data_types.py index 85ed5f3..172d9ed 100644 --- a/titan/data_types.py +++ b/titan/data_types.py @@ -3,7 +3,9 @@ from .enums import DataType -def convert_to_canonical_data_type(data_type: Union[str, DataType]) -> str: +def convert_to_canonical_data_type(data_type: Union[str, DataType, None]) -> str: + if data_type is None: + return None if isinstance(data_type, DataType): data_type = str(data_type) data_type = data_type.upper() diff --git a/titan/diff.py b/titan/diff.py deleted file mode 100644 index 233e3cb..0000000 --- a/titan/diff.py +++ /dev/null @@ -1,106 +0,0 @@ -from enum import Enum - -from .resource_name import ResourceName, attribute_is_resource_name - - -class Action(Enum): - CREATE = "create" - DROP = "drop" - UPDATE = "update" - TRANSFER = "transfer" - - -def eq(lhs, rhs, key): - if lhs is None or rhs is None: - return lhs == rhs - - if attribute_is_resource_name(key): - return ResourceName(lhs) == ResourceName(rhs) - elif key == "args": - # Ignore arg defaults - def _scrub_defaults(args): - new_args = [] - for arg in args: - new_arg = arg.copy() - new_arg.pop("default", None) - new_args.append(new_arg) - return new_args - - lhs_copy = _scrub_defaults(lhs) - rhs_copy = _scrub_defaults(rhs) - return lhs_copy == rhs_copy - elif key == "columns": - if len(lhs) != len(rhs): - return False - for i, lhs_col in enumerate(lhs): - rhs_col = rhs[i] - for col_key, lhs_value in lhs_col.items(): - if col_key not in rhs_col: - # raise Exception(f"Column {col_key} not found in rhs {rhs}") - continue - rhs_value = rhs_col[col_key] - if not eq(lhs_value, rhs_value, col_key): - return False - return True - else: - return lhs == rhs - - -def dict_delta(original, new): - original_keys = set(original.keys()) - new_keys = set(new.keys()) - {"_pointer", "_implicit"} - - delta = {} - - for key in original_keys - new_keys: - delta[key] = None - - for key in original_keys & 
new_keys: - if not eq(original[key], new[key], key): - delta[key] = new[key] - - for key in new_keys - original_keys: - delta[key] = new[key] - - if "_implicit" in delta: - raise Exception(f"Unexpected implicit resource {delta}") - return delta - - -def diff(original, new): - original_keys = set(original.keys()) - new_keys = set(new.keys()) - - # Resources in remote state but not in the manifest should be removed - for key in original_keys - new_keys: - yield Action.DROP, key, original[key] - - # Resources in the manifest but not in remote state should be added - for key in new_keys - original_keys: - if new[key].get("_pointer", False): - raise Exception(f"Blueprint has pointer to resource that doesn't exist or isn't visible in session: {key}") - - # We don't create implicit resources - if new[key].get("_implicit", False): - continue - - yield Action.CREATE, key, new[key] - - # Resources in both should be compared - for key in original_keys & new_keys: - if not isinstance(original[key], dict): - raise RuntimeError(f"Unexpected type for resource {key}: {type(original[key])}") - - # We don't diff resource pointers - if new[key].get("_pointer", False): - continue - - delta = dict_delta(original[key], new[key]) - owner_attr = delta.pop("owner", None) - - for attr, value in delta.items(): - yield Action.UPDATE, key, {attr: value} - - # Force the transfer to happen after all other attribute changes - if owner_attr: - yield Action.TRANSFER, key, {"owner": owner_attr} diff --git a/titan/exceptions.py b/titan/exceptions.py index 77e652b..93c04f9 100644 --- a/titan/exceptions.py +++ b/titan/exceptions.py @@ -36,3 +36,15 @@ class InvalidOwnerException(Exception): class InvalidResourceException(Exception): pass + + +class WrongContainerException(Exception): + pass + + +class WrongEditionException(Exception): + pass + + +class ResourceHasContainerException(Exception): + pass diff --git a/titan/privs.py b/titan/privs.py index 4da4b5a..b3b057a 100644 --- a/titan/privs.py +++ b/titan/privs.py @@ -433,6 +433,7 @@ class WarehousePriv(Priv): ResourceType.MATERIALIZED_VIEW: SchemaPriv.CREATE_MATERIALIZED_VIEW, ResourceType.NETWORK_POLICY: AccountPriv.CREATE_NETWORK_POLICY, ResourceType.NETWORK_RULE: SchemaPriv.CREATE_NETWORK_RULE, + ResourceType.NOTEBOOK: SchemaPriv.CREATE_NOTEBOOK, ResourceType.PACKAGES_POLICY: SchemaPriv.CREATE_PACKAGES_POLICY, ResourceType.PASSWORD_POLICY: SchemaPriv.CREATE_PASSWORD_POLICY, ResourceType.PIPE: SchemaPriv.CREATE_PIPE, diff --git a/titan/props.py b/titan/props.py index 84827d3..18e285c 100644 --- a/titan/props.py +++ b/titan/props.py @@ -621,7 +621,7 @@ def render(self, values): columns = [] for column in values: name = column["name"] - comment = f" COMMENT '{column['comment']}'" if "comment" in column else "" + comment = f" COMMENT '{column['comment']}'" if "comment" in column and column["comment"] else "" column_str = f"{name}{comment}" columns.append(column_str) return f"({', '.join(columns)})" diff --git a/titan/resources/database.py b/titan/resources/database.py index 7ee7d79..37b428f 100644 --- a/titan/resources/database.py +++ b/titan/resources/database.py @@ -162,7 +162,14 @@ def __init__( comment=comment, ) if self._data.name != "SNOWFLAKE": - self._public_schema = Schema(name="PUBLIC", implicit=True, owner=owner) + self._public_schema = Schema( + name="PUBLIC", + implicit=True, + owner=owner, + data_retention_time_in_days=data_retention_time_in_days, + max_data_extension_time_in_days=max_data_extension_time_in_days, + default_ddl_collation=default_ddl_collation, + ) 
self.add(self._public_schema) self.set_tags(tags) diff --git a/titan/resources/external_access_integration.py b/titan/resources/external_access_integration.py index e890e92..13eb76a 100644 --- a/titan/resources/external_access_integration.py +++ b/titan/resources/external_access_integration.py @@ -1,9 +1,11 @@ from dataclasses import dataclass -from .resource import Resource, ResourceSpec, NamedResource +from .resource import Resource, ResourceSpec, NamedResource, ResourcePointer, convert_to_resource from ..resource_name import ResourceName from .network_rule import NetworkRule from .role import Role + +# from .secret import Secret from ..enums import ResourceType from ..scope import AccountScope @@ -33,11 +35,21 @@ def __post_init__(self): if len(self.allowed_authentication_secrets) < 1: raise ValueError("allowed_authentication_secrets must have at least one element if specified") - if "any" in self.allowed_authentication_secrets and len(self.allowed_authentication_secrets) > 1: - raise ValueError("allowed_authentication_secrets must not contain 'any' if there are other secrets") - - if "none" in self.allowed_authentication_secrets and len(self.allowed_authentication_secrets) > 1: - raise ValueError("allowed_authentication_secrets must not contain 'none' if there are other secrets") + if not ( + len(self.allowed_authentication_secrets) == 1 + and self.allowed_authentication_secrets[0] in ("all", "none") + ): + converted_secrets = [] + for secret in self.allowed_authentication_secrets: + if isinstance(secret, (str, ResourceName)): + converted_secrets.append(ResourcePointer(name=secret, resource_type=ResourceType.SECRET)) + elif isinstance(secret, ResourcePointer) and secret.resource_type == ResourceType.SECRET: + converted_secrets.append(secret) + elif isinstance(secret, Resource) and secret.resource_type == ResourceType.SECRET: + converted_secrets.append(secret) + else: + raise ValueError(f"Invalid secret type: {secret}") + self.allowed_authentication_secrets = converted_secrets class ExternalAccessIntegration(NamedResource, Resource): diff --git a/titan/resources/grant.py b/titan/resources/grant.py index 4ad6c99..22c74b7 100644 --- a/titan/resources/grant.py +++ b/titan/resources/grant.py @@ -28,7 +28,7 @@ class _Grant(ResourceSpec): to: Role grant_option: bool = False owner: Role = field(default=None, metadata={"fetchable": False}) - _privs: list[str] = field(default_factory=list, metadata={"forces_add": True}) + _privs: list[str] = field(default_factory=list, metadata={"triggers_create": True}) def __post_init__(self): super().__post_init__() diff --git a/titan/resources/replication_group.py b/titan/resources/replication_group.py index edd93d4..a565795 100644 --- a/titan/resources/replication_group.py +++ b/titan/resources/replication_group.py @@ -46,7 +46,7 @@ class _ReplicationGroup(ResourceSpec): allowed_integration_types: list[IntegrationType] = None ignore_edition_check: bool = None replication_schedule: str = None - owner: Role = "SYSADMIN" + owner: Role = "ACCOUNTADMIN" class ReplicationGroup(NamedResource, Resource): diff --git a/titan/resources/resource.py b/titan/resources/resource.py index 05ab6b4..e9a75fa 100644 --- a/titan/resources/resource.py +++ b/titan/resources/resource.py @@ -10,6 +10,7 @@ import pyparsing as pp from ..enums import AccountEdition, DataType, ParseableEnum, ResourceType +from ..exceptions import ResourceHasContainerException, WrongContainerException, WrongEditionException from ..identifiers import FQN, URN, parse_identifier, resource_label_for_type 
from ..lifecycle import create_resource, drop_resource from ..parse import _parse_create_header, _parse_props, resolve_resource_class @@ -28,14 +29,6 @@ from ..var import VarString, string_contains_var -class WrongContainerException(Exception): - pass - - -class ResourceHasContainerException(Exception): - pass - - def _suggest_correct_kwargs(expected_kwargs, passed_kwargs): suggestions = {} for passed_kwarg in passed_kwargs: @@ -64,7 +57,7 @@ class Returns(TypedDict): @dataclass -class LifecycleConfig: +class ResourceLifecycleConfig: ignore_changes: list[str] = field(default_factory=list) prevent_destroy: bool = False @@ -149,6 +142,13 @@ def _coerce_resource_field(field_value, field_type): raise TypeError else: return field_value + elif field_type is float: + if isinstance(field_value, float): + return field_value + elif isinstance(field_value, int): + return float(field_value) + else: + raise TypeError else: # Typecheck all other field types (str, int, etc.) if not isinstance(field_value, field_type): @@ -160,13 +160,59 @@ def _coerce_resource_field(field_value, field_type): class ResourceSpecMetadata: fetchable: bool = True triggers_replacement: bool = False - forces_add: bool = False + triggers_create: bool = False ignore_changes: bool = False known_after_apply: bool = False + edition: set[AccountEdition] = field( + default_factory=lambda: {AccountEdition.STANDARD, AccountEdition.ENTERPRISE, AccountEdition.BUSINESS_CRITICAL} + ) @dataclass class ResourceSpec: + + def to_dict(self, account_edition: AccountEdition): + dict_: dict[str, Any] = {} + + def _serialize_field(field, value): + if field.name == "owner": + return str(value.fqn) + elif isinstance(value, ResourcePointer): + return str(value.fqn) + elif isinstance(value, Resource): + if getattr(value, "serialize_inline", False): + return value.to_dict(account_edition) + elif isinstance(value, NamedResource): + return str(value.fqn) + else: + raise Exception(f"Cannot serialize {value}") + elif isinstance(value, ParseableEnum): + return str(value) + elif isinstance(value, list): + return [_serialize_field(field, v) for v in value] + elif isinstance(value, dict): + return {k: _serialize_field(field, v) for k, v in value.items()} + elif isinstance(value, ResourceName): + return str(value) + elif isinstance(value, ResourceTags): + return value.tags + else: + return value + + for f in fields(self): + value = getattr(self, f.name) + field_metadata = ResourceSpecMetadata(**f.metadata) + if account_edition not in field_metadata.edition: + if value != f.default and value is not None: + raise WrongEditionException( + f"Field {self.__class__.__name__}.{f.name} is not supported in edition {account_edition}. 
Supported editions: {field_metadata.edition}" + ) + else: + continue + dict_[f.name] = _serialize_field(f, value) + + return dict_ + def __post_init__(self): for f in fields(self): field_value = getattr(self, f.name) @@ -234,7 +280,7 @@ def __init__( self._data: ResourceSpec = None self._container: "ResourceContainer" = None self._finalized = False - self.lifecycle = LifecycleConfig(**lifecycle) if lifecycle else LifecycleConfig() + self.lifecycle = ResourceLifecycleConfig(**lifecycle) if lifecycle else ResourceLifecycleConfig() self.implicit = implicit self.refs: set[Resource] = set() @@ -338,52 +384,27 @@ def __eq__(self, other): def __hash__(self): return hash(URN.from_resource(self, "")) - def to_dict(self): - serialized: dict[str, Any] = {} - if self.implicit: - serialized["_implicit"] = True + def to_dict(self, account_edition: Optional[AccountEdition] = None): + return self._data.to_dict(account_edition or AccountEdition.ENTERPRISE) - def _serialize(field, value): - if field.name == "owner": - return str(value.fqn) - elif isinstance(value, ResourcePointer): - return str(value.fqn) - elif isinstance(value, Resource): - if getattr(value, "serialize_inline", False): - return value.to_dict() - elif isinstance(value, NamedResource): - return str(value.fqn) - else: - raise Exception(f"Cannot serialize {value}") - elif isinstance(value, ParseableEnum): - return str(value) - elif isinstance(value, list): - return [_serialize(field, v) for v in value] - elif isinstance(value, dict): - return {k: _serialize(field, v) for k, v in value.items()} - elif isinstance(value, ResourceName): - return str(value) - elif isinstance(value, ResourceTags): - return value.tags - else: - return value - - for f in fields(self._data): - value = getattr(self._data, f.name) - serialized[f.name] = _serialize(f, value) - - return serialized - - def create_sql(self, **kwargs): + def create_sql( + self, + account_edition: Optional[AccountEdition] = None, + **kwargs, + ): return create_resource( self.urn, - self.to_dict(), + self.to_dict(account_edition), self.props, **kwargs, ) - def drop_sql(self, if_exists: bool = False): - return drop_resource(self.urn, self.to_dict(), if_exists=if_exists) + def drop_sql( + self, + if_exists: bool = False, + account_edition: Optional[AccountEdition] = None, + ): + return drop_resource(self.urn, self.to_dict(account_edition), if_exists=if_exists) def _requires(self, resource: "Resource"): if self._finalized: @@ -616,7 +637,7 @@ def fqn(self): def resource_type(self): return self._resource_type - def to_dict(self): + def to_dict(self, _=None): return { "_pointer": True, "name": self.name, diff --git a/titan/resources/role.py b/titan/resources/role.py index b80abef..df4575e 100644 --- a/titan/resources/role.py +++ b/titan/resources/role.py @@ -78,6 +78,18 @@ def __init__( self.set_tags(tags) +@dataclass(unsafe_hash=True) +class _DatabaseRole(ResourceSpec): + name: ResourceName + database: ResourceName + owner: ResourceName = "USERADMIN" + comment: str = None + + def __post_init__(self): + super().__post_init__() + self.owner = ResourcePointer(self.owner, ResourceType.ROLE) + + class DatabaseRole(NamedResource, TaggableResource, Resource): """ Description: @@ -123,7 +135,7 @@ class DatabaseRole(NamedResource, TaggableResource, Resource): comment=StringProp("comment"), ) scope = DatabaseScope() - spec = _Role + spec = _DatabaseRole def __init__( self, @@ -135,14 +147,10 @@ def __init__( **kwargs, ): super().__init__(name, database=database, **kwargs) - self._data: _Role = _Role( + 
self._data: _DatabaseRole = _DatabaseRole( name=self._name, + database=self.container.name, owner=owner, comment=comment, ) self.set_tags(tags) - - def to_dict(self): - data = super().to_dict() - data["database"] = self.container.name - return data diff --git a/titan/resources/schema.py b/titan/resources/schema.py index b34bb32..31a9c75 100644 --- a/titan/resources/schema.py +++ b/titan/resources/schema.py @@ -1,12 +1,11 @@ from dataclasses import dataclass -from ..builtins import SYSTEM_SCHEMAS from ..enums import ResourceType from ..props import FlagProp, IntProp, Props, StringProp, TagsProp from ..resource_name import ResourceName +from ..role_ref import RoleRef from ..scope import DatabaseScope from .resource import NamedResource, Resource, ResourceContainer, ResourceSpec -from ..role_ref import RoleRef from .tag import TaggableResource @@ -21,13 +20,6 @@ class _Schema(ResourceSpec): owner: RoleRef = "SYSADMIN" comment: str = None - def __post_init__(self): - super().__post_init__() - if self.transient and self.data_retention_time_in_days is not None: - raise ValueError("Transient schema can't have data retention time") - elif not self.transient and self.data_retention_time_in_days is None: - self.data_retention_time_in_days = 1 - class Schema(NamedResource, TaggableResource, Resource, ResourceContainer): """ @@ -99,7 +91,7 @@ def __init__( name: str, transient: bool = False, managed_access: bool = False, - data_retention_time_in_days: int = None, + data_retention_time_in_days: int = 1, max_data_extension_time_in_days: int = 14, default_ddl_collation: str = None, tags: dict[str, str] = None, diff --git a/titan/resources/secret.py b/titan/resources/secret.py index 99ae960..c07e517 100644 --- a/titan/resources/secret.py +++ b/titan/resources/secret.py @@ -291,7 +291,7 @@ def __init__( owner=owner, ) - def to_dict(self): + def to_dict(self, _=None): data = super().to_dict() if data["oauth_scopes"]: data.pop("oauth_refresh_token") diff --git a/titan/resources/user.py b/titan/resources/user.py index 42d5adb..99356ef 100644 --- a/titan/resources/user.py +++ b/titan/resources/user.py @@ -1,5 +1,5 @@ import logging -from dataclasses import dataclass +from dataclasses import dataclass, field from ..enums import ParseableEnum, ResourceType from ..props import BoolProp, EnumProp, IntProp, Props, StringListProp, StringProp, TagsProp @@ -23,8 +23,8 @@ class UserType(ParseableEnum): class _User(ResourceSpec): name: ResourceName owner: Role = "USERADMIN" - password: str = None - login_name: str = None + password: str = field(default=None, metadata={"fetchable": False}) + login_name: ResourceName = None display_name: str = None first_name: str = None middle_name: str = None @@ -32,7 +32,7 @@ class _User(ResourceSpec): email: str = None must_change_password: bool = None disabled: bool = False - days_to_expiry: int = None + days_to_expiry: int = field(default=None, metadata={"known_after_apply": True}) mins_to_unlock: int = None default_warehouse: str = None default_namespace: str = None @@ -67,7 +67,7 @@ def __post_init__(self): else: if self.login_name is None or self.login_name == "": - self.login_name = self.name._name.upper() + self.login_name = ResourceName(str(self.name).upper()) if self.display_name is None: self.display_name = self.name._name if self.must_change_password is None: diff --git a/titan/resources/view.py b/titan/resources/view.py index f7f4581..4f4e6c1 100644 --- a/titan/resources/view.py +++ b/titan/resources/view.py @@ -12,11 +12,54 @@ ) from ..resource_name import ResourceName 
from ..role_ref import RoleRef -from ..scope import SchemaScope +from ..scope import SchemaScope, TableScope from .resource import NamedResource, Resource, ResourceSpec from .tag import TaggableResource +@dataclass(unsafe_hash=True) +class _ViewColumn(ResourceSpec): + name: ResourceName + comment: str = None + data_type: str = field(default=None, metadata={"known_after_apply": True}) + not_null: bool = False + default: str = None + constraint: str = None + collate: str = None + + +class ViewColumn(NamedResource, Resource): + resource_type = ResourceType.COLUMN + props = Props( + comment=StringProp("comment", eq=False), + ) + scope = TableScope() + spec = _ViewColumn + serialize_inline = True + + def __init__( + self, + name: str, + comment: str = None, + data_type: str = None, + not_null: bool = False, + default: str = None, + constraint: str = None, + collate: str = None, + **kwargs, + ): + super().__init__(name, **kwargs) + self._data: _ViewColumn = _ViewColumn( + name=self._name, + comment=comment, + data_type=data_type, + not_null=not_null, + default=default, + constraint=constraint, + collate=collate, + ) + + @dataclass(unsafe_hash=True) class _View(ResourceSpec): name: ResourceName @@ -24,12 +67,12 @@ class _View(ResourceSpec): secure: bool = False volatile: bool = None recursive: bool = None - columns: list[dict] = None + columns: list[ViewColumn] = None change_tracking: bool = False - copy_grants: bool = field(default_factory=False, metadata={"fetchable": False}) + copy_grants: bool = field(default=False, metadata={"fetchable": False}) comment: str = None # TODO: remove this if parsing is feasible - as_: str = field(default=None, metadata={"fetchable": False}) + as_: str = None # field(default=None, metadata={"fetchable": False}) def __post_init__(self): super().__post_init__() @@ -111,6 +154,12 @@ def __init__( as_: str = None, **kwargs, ): + if "lifecycle" not in kwargs: + lifecycle = { + "ignore_changes": "columns", + } + kwargs["lifecycle"] = lifecycle + super().__init__(name, **kwargs) self._data: _View = _View( name=self._name, diff --git a/titan/resources/warehouse.py b/titan/resources/warehouse.py index e57b2f4..acb7942 100644 --- a/titan/resources/warehouse.py +++ b/titan/resources/warehouse.py @@ -1,6 +1,7 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Optional, Union -from ..enums import ParseableEnum, ResourceType, WarehouseSize +from ..enums import AccountEdition, ParseableEnum, ResourceType, WarehouseSize from ..props import ( BoolProp, EnumProp, @@ -34,16 +35,31 @@ class _Warehouse(ResourceSpec): owner: Role = "SYSADMIN" warehouse_type: WarehouseType = WarehouseType.STANDARD warehouse_size: WarehouseSize = WarehouseSize.XSMALL - max_cluster_count: int = None - min_cluster_count: int = None - scaling_policy: WarehouseScalingPolicy = None + max_cluster_count: int = field( + default=1, + metadata={"edition": {AccountEdition.ENTERPRISE, AccountEdition.BUSINESS_CRITICAL}}, + ) + min_cluster_count: int = field( + default=1, + metadata={"edition": {AccountEdition.ENTERPRISE, AccountEdition.BUSINESS_CRITICAL}}, + ) + scaling_policy: WarehouseScalingPolicy = field( + default=WarehouseScalingPolicy.STANDARD, + metadata={"edition": {AccountEdition.ENTERPRISE, AccountEdition.BUSINESS_CRITICAL}}, + ) auto_suspend: int = 600 auto_resume: bool = True - initially_suspended: bool = None + initially_suspended: bool = field(default=False, metadata={"fetchable": False}) resource_monitor: ResourceMonitor = None comment: str = None - 
enable_query_acceleration: bool = False - query_acceleration_max_scale_factor: int = None + enable_query_acceleration: bool = field( + default=False, + metadata={"edition": {AccountEdition.ENTERPRISE, AccountEdition.BUSINESS_CRITICAL}}, + ) + query_acceleration_max_scale_factor: int = field( + default=8, + metadata={"edition": {AccountEdition.ENTERPRISE, AccountEdition.BUSINESS_CRITICAL}}, + ) max_concurrency_level: int = 8 statement_queued_timeout_in_seconds: int = 0 statement_timeout_in_seconds: int = 172800 @@ -155,18 +171,18 @@ def __init__( self, name: str, owner: str = "SYSADMIN", - warehouse_type: WarehouseType = "STANDARD", - warehouse_size: WarehouseSize = WarehouseSize.XSMALL, - max_cluster_count: int = None, - min_cluster_count: int = None, - scaling_policy: WarehouseScalingPolicy = None, + warehouse_type: str = "STANDARD", + warehouse_size: str = "XSMALL", + max_cluster_count: int = 1, + min_cluster_count: int = 1, + scaling_policy: str = "STANDARD", auto_suspend: int = 600, auto_resume: bool = True, - initially_suspended: bool = None, - resource_monitor: ResourceMonitor = None, + initially_suspended: bool = False, + resource_monitor: Union[ResourceMonitor, str, None] = None, comment: str = None, enable_query_acceleration: bool = False, - query_acceleration_max_scale_factor: int = None, + query_acceleration_max_scale_factor: int = 8, max_concurrency_level: int = 8, statement_queued_timeout_in_seconds: int = 0, statement_timeout_in_seconds: int = 172800, diff --git a/tools/reset_test_account.py b/tools/reset_test_account.py index 7d7a7e4..4758a00 100644 --- a/tools/reset_test_account.py +++ b/tools/reset_test_account.py @@ -114,6 +114,7 @@ def get_connection(env_vars): def configure_test_accounts(): for account in ["aws.standard", "aws.enterprise"]: + print(">>>>>>>>>>>>>>>>", account) env_vars = dotenv_values(f"env/.env.{account}") conn = get_connection(env_vars) try: diff --git a/tools/test_account.yml b/tools/test_account.yml index ea464de..6d1306b 100644 --- a/tools/test_account.yml +++ b/tools/test_account.yml @@ -14,6 +14,7 @@ allowlist: - "role grant" - "role" - "schema" + - "secret" - "security integration" - "share" - "stage" @@ -258,4 +259,12 @@ external_volumes: storage_provider: S3 storage_base_url: "{{ var.storage_base_url }}" storage_aws_role_arn: "{{ var.storage_role_arn }}" - storage_aws_external_id: iceberg_table_external_id \ No newline at end of file + storage_aws_external_id: iceberg_table_external_id + +secrets: + - name: static_secret + secret_type: PASSWORD + username: someuser + password: somepass + database: static_database + schema: public \ No newline at end of file diff --git a/tools/test_account_enterprise.yml b/tools/test_account_enterprise.yml index e979aaf..30ae7e9 100644 --- a/tools/test_account_enterprise.yml +++ b/tools/test_account_enterprise.yml @@ -1,7 +1,6 @@ name: reset-test-account run_mode: SYNC allowlist: - - "secret" - "tag" - "tag reference" @@ -19,12 +18,3 @@ tags: comment: This is a static tag allowed_values: - STATIC_TAG_VALUE - - -secrets: - - name: static_secret - secret_type: PASSWORD - username: someuser - password: somepass - database: static_database - schema: public \ No newline at end of file
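
Usage notes on the API changes above. First, serialization and SQL generation are now edition-aware: ResourceSpec.to_dict() takes an AccountEdition and raises WrongEditionException when an edition-gated field (per the field metadata added to _Warehouse above) carries a non-default value. A minimal sketch of the behavior this patch implements, mirroring the integration-test helper's use of create_sql(account_edition=...):

from titan import resources as res
from titan.enums import AccountEdition
from titan.exceptions import WrongEditionException

# max_cluster_count is gated to Enterprise/Business Critical in _Warehouse
wh = res.Warehouse(name="ANALYTICS_WH", max_cluster_count=3)

# Enterprise serialization keeps the gated field
assert wh.to_dict(AccountEdition.ENTERPRISE)["max_cluster_count"] == 3

# Standard serialization rejects a non-default value for a gated field
try:
    wh.to_dict(AccountEdition.STANDARD)
except WrongEditionException as err:
    print(err)

# create_sql() threads the edition through to_dict() before rendering props
print(wh.create_sql(account_edition=AccountEdition.ENTERPRISE, if_not_exists=True))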
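
Second, view columns are now modeled as ViewColumn resources instead of plain dicts, a column's data_type is marked known_after_apply, and View defaults its lifecycle to ignoring column drift unless the caller passes an explicit lifecycle. A small sketch mirroring the updated test in tests/test_resources.py:

from titan import resources as res
from titan.resources.view import ViewColumn

view = res.View.from_sql("CREATE VIEW MY_VIEW (col1) AS SELECT 1")
col = view._data.columns[0]

assert isinstance(col, ViewColumn)  # parsed into a resource, not a dict
assert col.name == "COL1"  # names normalize to uppercase
assert col._data.data_type is None  # only known after apply
assert "columns" in view.lifecycle.ignore_changes  # column drift ignored by default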
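
Finally, the edition inference repeated in blueprint.py and the integration tests relies on tags being an Enterprise feature, so fetch_session()'s tag_support field doubles as an edition probe. A hedged sketch of that pattern (the helper name here is illustrative, not part of the patch):

from titan import data_provider
from titan.enums import AccountEdition

def infer_account_edition(conn) -> AccountEdition:
    # No tag support implies Standard edition (or a trial); otherwise treat as Enterprise
    session_ctx = data_provider.fetch_session(conn)
    return AccountEdition.ENTERPRISE if session_ctx["tag_support"] else AccountEdition.STANDARD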