Skip to content

Commit

Permalink
Fix unique_key definition. Use post-run adapter.
Browse files Browse the repository at this point in the history
Add some scopes to fixtures.
  • Loading branch information
gshank committed Feb 18, 2022
1 parent 447f977 commit 3dcea97
Show file tree
Hide file tree
Showing 21 changed files with 262 additions and 408 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ ignore =
W504
E203 # makes Flake8 work like black
E741
max-line-length = 99
max-line-length = 140
exclude = test
4 changes: 3 additions & 1 deletion core/dbt/contracts/graph/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,9 @@ class NodeConfig(NodeAndTestConfig):
metadata=MergeBehavior.Update.meta(),
)
full_refresh: Optional[bool] = None
unique_key: Optional[Union[str, List[str]]] = None
# 'unique_key' doesn't use 'Optional' because typing.get_type_hints was
# sometimes getting the Union order wrong, causing serialization failures.
unique_key: Union[str, List[str], None] = None
on_schema_change: Optional[str] = "ignore"

@classmethod
Expand Down
1 change: 1 addition & 0 deletions core/dbt/task/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ class TestTask(RunTask):
Read schema files + custom data tests and validate that
constraints are satisfied.
"""

__test__ = False

def raise_on_first_error(self):
Expand Down
52 changes: 29 additions & 23 deletions core/dbt/tests/fixtures/project.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import pytest
import pytest # type: ignore
import random
import time
from argparse import Namespace
from datetime import datetime
import dbt.flags as flags
Expand All @@ -16,8 +15,13 @@


@pytest.fixture
def unique_schema(request) -> str:
    """Generate a schema name unique to this test invocation.

    Combines a microseconds-since-epoch timestamp, a random 4-digit
    suffix, and the requesting test module's name so that concurrent
    test runs do not collide on a schema.
    """
    _randint = random.randint(0, 9999)
    # Microseconds since the Unix epoch. total_seconds() already includes
    # the fractional seconds, so the previous extra '+ .microseconds' term
    # double-counted the microsecond component; drop it.
    _runtime_timedelta = datetime.utcnow() - datetime(1970, 1, 1, 0, 0, 0)
    _runtime = int(_runtime_timedelta.total_seconds() * 1e6)
    test_file = request.module.__name__
    # NOTE(review): Postgres truncates identifiers at 63 chars; very long
    # module names could collide after truncation — TODO confirm acceptable.
    return f"test{_runtime}{_randint:04}_{test_file}"


@pytest.fixture
Expand All @@ -30,29 +34,31 @@ def profiles_root(tmpdir):
@pytest.fixture
def project_root(tmpdir):
    """Create and return a fresh 'project' directory under pytest's tmpdir."""
    # tmpdir docs - https://docs.pytest.org/en/6.2.x/tmpdir.html
    project_root = tmpdir.mkdir("project")
    # Echo the path so the on-disk project of a failing test can be located.
    print(f"\n=== Test project_root: {project_root}")
    return project_root


# This is for data used by multiple tests, in the 'tests/data' directory
@pytest.fixture(scope="session")
def shared_data_dir(request):
    # NOTE(review): 'scope' was previously declared as a function parameter,
    # which pytest would try to resolve as a fixture named 'scope'; the scope
    # belongs on the decorator.
    return os.path.join(request.config.rootdir, "tests", "data")


# This is for data for a specific test directory, i.e. tests/basic/data
@pytest.fixture(scope="module")
def test_data_dir(request):
    # NOTE(review): 'scope' was previously declared as a function parameter,
    # which pytest would try to resolve as a fixture; it belongs on the
    # decorator.
    return os.path.join(request.fspath.dirname, "data")


@pytest.fixture(scope="session")
def database_host():
    # NOTE(review): 'scope' was previously declared as a function parameter,
    # which pytest would try to resolve as a fixture; it belongs on the
    # decorator. Allows pointing tests at a Docker-hosted database via env var.
    return os.environ.get("DOCKER_TEST_DATABASE_HOST", "localhost")


@pytest.fixture
def dbt_profile_data(unique_schema, database_host):

dbname = os.getenv("POSTGRES_TEST_DATABASE", "dbt")
return {
"config": {"send_anonymous_usage_stats": False},
"test": {
Expand All @@ -64,7 +70,7 @@ def dbt_profile_data(unique_schema, database_host):
"port": int(os.getenv("POSTGRES_TEST_PORT", 5432)),
"user": os.getenv("POSTGRES_TEST_USER", "root"),
"pass": os.getenv("POSTGRES_TEST_PASS", "password"),
"dbname": os.getenv("POSTGRES_TEST_DATABASE", "dbt"),
"dbname": dbname,
"schema": unique_schema,
},
"other_schema": {
Expand All @@ -74,7 +80,7 @@ def dbt_profile_data(unique_schema, database_host):
"port": int(os.getenv("POSTGRES_TEST_PORT", 5432)),
"user": "noaccess",
"pass": "password",
"dbname": os.getenv("POSTGRES_TEST_DATABASE", "dbt"),
"dbname": dbname,
"schema": unique_schema + "_alt", # Should this be the same unique_schema?
},
},
Expand Down Expand Up @@ -106,7 +112,7 @@ def dbt_project_yml(project_root, project_config_update, logs_dir):
"name": "test",
"version": "0.1.0",
"profile": "test",
"log-path": logs_dir
"log-path": logs_dir,
}
if project_config_update:
project_config.update(project_config_update)
Expand Down Expand Up @@ -152,15 +158,15 @@ def schema(unique_schema, project_root, profiles_root):

register_adapter(runtime_config)
adapter = get_adapter(runtime_config)
execute(adapter, "drop schema if exists {} cascade".format(unique_schema))
# execute(adapter, "drop schema if exists {} cascade".format(unique_schema))
execute(adapter, "create schema {}".format(unique_schema))
yield adapter
adapter = get_adapter(runtime_config)
adapter.cleanup_connections()
# adapter.cleanup_connections()
execute(adapter, "drop schema if exists {} cascade".format(unique_schema))


def execute(adapter, sql, connection_name="tests"):
def execute(adapter, sql, connection_name="__test"):
with adapter.connection_named(connection_name):
conn = adapter.connections.get_thread_connection()
with conn.handle.cursor() as cursor:
Expand Down Expand Up @@ -238,14 +244,11 @@ def project_files(project_root, models, macros, snapshots, seeds, tests):
def logs_dir(request):
    """Build a log-directory path unique per test session.

    Returns <rootdir>/logs/test<epoch_microseconds><4-digit random>.
    """
    # create a directory name that will be unique per test session
    _randint = random.randint(0, 9999)
    # Microseconds since the Unix epoch; total_seconds() already carries the
    # fractional seconds, so the old extra '+ .microseconds' term
    # double-counted them.
    _runtime_timedelta = datetime.utcnow() - datetime(1970, 1, 1, 0, 0, 0)
    _runtime = int(_runtime_timedelta.total_seconds() * 1e6)
    prefix = f"test{_runtime}{_randint:04}"

    return os.path.join(request.config.rootdir, "logs", prefix)


class TestProjInfo:
Expand Down Expand Up @@ -287,10 +290,11 @@ def project(
logs_dir,
):
setup_event_logger(logs_dir)
orig_cwd = os.getcwd()
os.chdir(project_root)
# Return whatever is needed later in tests but can only come from fixtures, so we can keep
# the signatures in the test signature to a minimum.
return TestProjInfo(
project = TestProjInfo(
project_root=project_root,
profiles_dir=profiles_root,
adapter=schema,
Expand All @@ -301,3 +305,5 @@ def project(
# the following feels kind of fragile. TODO: better way of getting database
database=profiles_yml["test"]["outputs"]["default"]["dbname"],
)
yield project
os.chdir(orig_cwd)
23 changes: 15 additions & 8 deletions core/dbt/tests/tables.py
Original file line number Diff line number Diff line change
def __init__(self, adapter, unique_schema, database):
    """Capture the adapter, target schema, and database for comparisons."""
    self.adapter = adapter
    self.unique_schema = unique_schema
    self.default_database = database
    # TODO: We need to get this from somewhere reasonable (e.g. the
    # project/profile config) instead of keying quoting off a magic
    # database name.
    if database == "dbtMixedCase":
        self.quoting = {"database": True, "schema": True, "identifier": True}
    else:
        self.quoting = {"database": False, "schema": False, "identifier": False}

def _assert_tables_equal_sql(self, relation_a, relation_b, columns=None):
if columns is None:
Expand Down Expand Up @@ -129,7 +132,9 @@ def assert_many_relations_equal(self, relations, default_schema=None, default_da
sql = self._assert_tables_equal_sql(
first_relation, relation, columns=first_columns
)
result = run_sql(sql, self.unique_schema, fetch="one")
result = run_sql(
sql, self.unique_schema, database=self.default_database, fetch="one"
)

assert result[0] == 0, "row_count_difference nonzero: " + sql
assert result[1] == 0, "num_mismatched nonzero: " + sql
Expand Down Expand Up @@ -160,7 +165,9 @@ def assert_many_tables_equal(self, *args):
sql = self._assert_tables_equal_sql(
first_relation, other_relation, columns=base_result
)
result = run_sql(sql, self.unique_schema, fetch="one")
result = run_sql(
sql, self.unique_schema, database=self.default_database, fetch="one"
)

assert result[0] == 0, "row_count_difference nonzero: " + sql
assert result[1] == 0, "num_mismatched nonzero: " + sql
Expand All @@ -184,7 +191,7 @@ def _assert_table_row_counts_equal(self, relation_a, relation_b):
str(relation_a), str(relation_b)
)

res = run_sql(cmp_query, self.unique_schema, fetch="one")
res = run_sql(cmp_query, self.unique_schema, database=self.default_database, fetch="one")

msg = (
f"Row count of table {relation_a.identifier} doesn't match row count of "
Expand Down Expand Up @@ -283,7 +290,7 @@ def get_many_table_columns_information_schema(self, tables, schema, database=Non
db_string=db_string,
)

columns = run_sql(sql, self.unique_schema, fetch="all")
columns = run_sql(sql, self.unique_schema, database=self.default_database, fetch="all")
return list(map(self.filter_many_columns, columns))

def get_many_table_columns(self, tables, schema, database=None):
Expand Down Expand Up @@ -321,7 +328,7 @@ def _ilike(target, value):
return "{} ilike '{}'".format(target, value)


def get_tables_in_schema(schema):
def get_tables_in_schema(schema, database="dbt"):
sql = """
select table_name,
case when table_type = 'BASE TABLE' then 'table'
Expand All @@ -334,6 +341,6 @@ def get_tables_in_schema(schema):
"""

sql = sql.format(_ilike("table_schema", schema))
result = run_sql(sql, schema, fetch="all")
result = run_sql(sql, schema, database=database, fetch="all")

return {model_name: materialization for (model_name, materialization) in result}
12 changes: 6 additions & 6 deletions core/dbt/tests/util.py
Original file line number Diff line number Diff line change
def run_dbt(args: List[str] = None, expect_pass=True):
    """Run a dbt command in-process and return its results.

    args: dbt CLI argument list; defaults to ["run"].
    expect_pass: expected success state; raises AssertionError when the
        actual exit state differs.
    """
    if args is None:
        args = ["run"]

    print("\n\nInvoking dbt with {}".format(args))
    res, success = handle_and_check(args)
    # Re-enabled: leaving this assert commented out (debug leftover) makes
    # expect_pass a no-op and lets failing dbt invocations pass silently.
    assert success == expect_pass, "dbt exit state did not match expected"
    return res


Expand Down Expand Up @@ -48,29 +48,29 @@ def get_manifest(project_root):
return None


def run_sql_file(sql_path, unique_schema, database="dbt"):
    """Execute each ';'-separated statement in the file at sql_path.

    Schema and database placeholders in the SQL are substituted by run_sql.
    """
    # It would be nice not to have to pass the full path in, to
    # avoid having to use the 'request' fixture.
    # Could we use os.environ['PYTEST_CURRENT_TEST']?
    # Might be more fragile, if we want to reuse this code...
    with open(sql_path, "r") as f:
        statements = f.read().split(";")
    for statement in statements:
        run_sql(statement, unique_schema, database)


def adapter_type():
    """Name of the adapter plugin targeted by this functional test suite."""
    return "postgres"


def run_sql(sql, unique_schema, fetch=None):
def run_sql(sql, unique_schema, database="dbt", fetch=None):
if sql.strip() == "":
return
# substitute schema and database in sql
adapter = get_adapter_by_type(adapter_type())
kwargs = {
"schema": unique_schema,
"database": adapter.quote("dbt"),
"database": adapter.quote(database),
}
sql = sql.format(**kwargs)

Expand Down
6 changes: 1 addition & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
import pytest

# Import the functional fixtures as a plugin
# Note: fixtures with session scope need to be local

pytest_plugins = [
"dbt.tests.fixtures.project"
]
pytest_plugins = ["dbt.tests.fixtures.project"]
15 changes: 8 additions & 7 deletions tests/fixtures/jaffle_shop.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@
# models/staging/stg_payments.sql
staging_stg_payments_sql = """
with source as (
{#-
Normally we would select from the table here, but we are using seeds to load
our data in this project
Expand All @@ -370,6 +370,7 @@
select * from renamed
"""


@pytest.fixture
def models():
return {
Expand All @@ -383,7 +384,7 @@ def models():
"stg_customers.sql": staging_stg_customers_sql,
"stg_orders.sql": staging_stg_orders_sql,
"stg_payments.sql": staging_stg_payments_sql,
}
},
}


Expand All @@ -392,9 +393,9 @@ def seeds():
# Read seed file and return
seeds = {}
dir_path = os.path.dirname(os.path.realpath(__file__))
for file_name in ('raw_customers.csv', 'raw_orders.csv', 'raw_payments.csv'):
path = os.path.join(dir_path, 'jaffle_shop_data', file_name)
with open(path, 'rb') as fp:
for file_name in ("raw_customers.csv", "raw_orders.csv", "raw_payments.csv"):
path = os.path.join(dir_path, "jaffle_shop_data", file_name)
with open(path, "rb") as fp:
seeds[file_name] = fp.read()
return seeds

Expand All @@ -408,7 +409,7 @@ def project_config_update():
"materialized": "table",
"staging": {
"materialized": "view",
}
},
}
}
},
}
Loading

0 comments on commit 3dcea97

Please sign in to comment.