Skip to content

Commit

Permalink
Merge pull request #1788 from fishtown-analytics/feature/warehouse-mo…
Browse files Browse the repository at this point in the history
…del-config

Add runtime per-model warehouse config on snowflake models (#1358)
  • Loading branch information
beckjake authored Sep 27, 2019
2 parents 286753b + 52f6243 commit 812c549
Show file tree
Hide file tree
Showing 10 changed files with 177 additions and 10 deletions.
28 changes: 27 additions & 1 deletion core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from contextlib import contextmanager
from datetime import datetime
from typing import (
Optional, Tuple, Callable, Container, FrozenSet, Type, Dict, Any, List
Optional, Tuple, Callable, Container, FrozenSet, Type, Dict, Any, List,
Mapping
)

import agate
Expand Down Expand Up @@ -1010,3 +1011,28 @@ def calculate_freshness(
'snapshotted_at': snapshotted_at,
'age': age,
}

def pre_model_hook(self, config: Mapping[str, Any]) -> Any:
"""A hook for running some operation before the model materialization
runs. The hook can assume it has a connection available.
The only parameter is a configuration dictionary (the same one
available in the materialization context). It should be considered
read-only.
The pre-model hook may return anything as a context, which will be
passed to the post-model hook.
"""
pass

def post_model_hook(self, config: Mapping[str, Any], context: Any) -> None:
"""A hook for running some operation after the model materialization
runs. The hook can assume it has a connection available.
The first parameter is a configuration dictionary (the same one
available in the materialization context). It should be considered
read-only.
The second parameter is the value returned by pre_mdoel_hook.
"""
pass
12 changes: 6 additions & 6 deletions core/dbt/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,12 +445,12 @@ def emit(self, record: logbook.LogRecord):


# we still need to use logging to suppress these or pytest captures them
logging.getLogger('botocore').setLevel(logging.INFO)
logging.getLogger('requests').setLevel(logging.INFO)
logging.getLogger('urllib3').setLevel(logging.INFO)
logging.getLogger('google').setLevel(logging.INFO)
logging.getLogger('snowflake.connector').setLevel(logging.INFO)
logging.getLogger('parsedatetime').setLevel(logging.INFO)
logging.getLogger('botocore').setLevel(logging.ERROR)
logging.getLogger('requests').setLevel(logging.ERROR)
logging.getLogger('urllib3').setLevel(logging.ERROR)
logging.getLogger('google').setLevel(logging.ERROR)
logging.getLogger('snowflake.connector').setLevel(logging.ERROR)
logging.getLogger('parsedatetime').setLevel(logging.ERROR)
# want to see werkzeug logs about errors
logging.getLogger('werkzeug').setLevel(logging.ERROR)

Expand Down
13 changes: 12 additions & 1 deletion core/dbt/node_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,18 @@ def execute(self, model, manifest):
if materialization_macro is None:
missing_materialization(model, self.adapter.type())

result = materialization_macro.generator(context)()
if 'config' not in context:
raise InternalException(
'Invalid materialization context generated, missing config: {}'
.format(context)
)
context_config = context['config']

hook_ctx = self.adapter.pre_model_hook(context_config)
try:
result = materialization_macro.generator(context)()
finally:
self.adapter.post_model_hook(context_config, hook_ctx)

for relation in self._materialization_relations(result, model):
self.adapter.cache_added(relation.incorporate(dbt_created=True))
Expand Down
1 change: 1 addition & 0 deletions plugins/snowflake/dbt/adapters/snowflake/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ def _rollback_handle(cls, connection):
"""On snowflake, rolling back the handle of an aborted session raises
an exception.
"""
logger.debug('initiating rollback')
try:
connection.handle.rollback()
except snowflake.connector.errors.ProgrammingError as e:
Expand Down
36 changes: 35 additions & 1 deletion plugins/snowflake/dbt/adapters/snowflake/impl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Mapping, Any, Optional

from dbt.adapters.sql import SQLAdapter
from dbt.adapters.snowflake import SnowflakeConnectionManager
from dbt.adapters.snowflake import SnowflakeRelation
from dbt.utils import filter_null_values
from dbt.exceptions import RuntimeException


class SnowflakeAdapter(SQLAdapter):
Expand All @@ -10,7 +13,7 @@ class SnowflakeAdapter(SQLAdapter):

AdapterSpecificConfigs = frozenset(
{"transient", "cluster_by", "automatic_clustering", "secure",
"copy_grants"}
"copy_grants", "warehouse"}
)

@classmethod
Expand Down Expand Up @@ -40,3 +43,34 @@ def _make_match_kwargs(self, database, schema, identifier):
return filter_null_values(
{"identifier": identifier, "schema": schema, "database": database}
)

def _get_warehouse(self) -> str:
_, table = self.execute(
'select current_warehouse() as warehouse',
fetch=True
)
if len(table) == 0 or len(table[0]) == 0:
# can this happen?
raise RuntimeException(
'Could not get current warehouse: no results'
)
return str(table[0][0])

def _use_warehouse(self, warehouse: str):
"""Use the given warehouse. Quotes are never applied."""
self.execute('use warehouse {}'.format(warehouse))

def pre_model_hook(self, config: Mapping[str, Any]) -> Optional[str]:
default_warehouse = self.config.credentials.warehouse
warehouse = config.get('warehouse', default_warehouse)
if warehouse == default_warehouse or warehouse is None:
return None
previous = self._get_warehouse()
self._use_warehouse(warehouse)
return previous

def post_model_hook(
self, config: Mapping[str, Any], context: Optional[str]
) -> None:
if context is not None:
self._use_warehouse(context)
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
{{ config(materialized='table') }}
select 'DBT_TEST_ALT' as warehouse
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
{{ config(warehouse='DBT_TEST_DOES_NOT_EXIST') }}
select current_warehouse() as warehouse
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
{{ config(warehouse='DBT_TEST_ALT', materialized='table') }}
select current_warehouse() as warehouse
28 changes: 28 additions & 0 deletions test/integration/050_warehouse_test/test_warehouses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from test.integration.base import DBTIntegrationTest, use_profile
import os


class TestDebug(DBTIntegrationTest):
@property
def schema(self):
return 'dbt_warehouse_050'

@staticmethod
def dir(value):
return os.path.normpath(value)

@property
def models(self):
return self.dir('models')

@use_profile('snowflake')
def test_snowflake_override_ok(self):
self.run_dbt([
'run',
'--models', 'override_warehouse', 'expected_warehouse',
])
self.assertManyRelationsEqual([['OVERRIDE_WAREHOUSE'], ['EXPECTED_WAREHOUSE']])

@use_profile('snowflake')
def test_snowflake_override_noexist(self):
self.run_dbt(['run', '--models', 'invalid_warehouse'], expect_pass=False)
63 changes: 62 additions & 1 deletion test/unit/test_snowflake_adapter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from contextlib import contextmanager
from unittest import mock

import dbt.flags as flags
Expand Down Expand Up @@ -46,7 +47,8 @@ def setUp(self):
self.cursor = self.handle.cursor.return_value
self.mock_execute = self.cursor.execute
self.patcher = mock.patch(
'dbt.adapters.snowflake.connections.snowflake.connector.connect')
'dbt.adapters.snowflake.connections.snowflake.connector.connect'
)
self.snowflake = self.patcher.start()

self.load_patch = mock.patch('dbt.loader.make_parse_result')
Expand Down Expand Up @@ -133,6 +135,65 @@ def test_quoting_on_rename(self):
)
])

@contextmanager
def current_warehouse(self, response):
# there is probably some elegant way built into mock.patch to do this
fetchall_return = self.cursor.fetchall.return_value
execute_side_effect = self.mock_execute.side_effect

def execute_effect(sql, *args, **kwargs):
if sql == 'select current_warehouse() as warehouse':
self.cursor.description = [['name']]
self.cursor.fetchall.return_value = [[response]]
else:
self.cursor.description = None
self.cursor.fetchall.return_value = fetchall_return
return self.mock_execute.return_value

self.mock_execute.side_effect = execute_effect
try:
yield
finally:
self.cursor.fetchall.return_value = fetchall_return
self.mock_execute.side_effect = execute_side_effect

def _strip_transactions(self):
result = []
for call_args in self.mock_execute.call_args_list:
args, kwargs = tuple(call_args)
is_transactional = (
len(kwargs) == 0 and
len(args) == 2 and
args[1] is None and
args[0] in {'BEGIN', 'COMMIT'}
)
if not is_transactional:
result.append(call_args)
return result

def test_pre_post_hooks_warehouse(self):
with self.current_warehouse('warehouse'):
config = {'warehouse': 'other_warehouse'}
result = self.adapter.pre_model_hook(config)
self.assertIsNotNone(result)
calls = [
mock.call('select current_warehouse() as warehouse', None),
mock.call('use warehouse other_warehouse', None)
]
self.mock_execute.assert_has_calls(calls)
self.adapter.post_model_hook(config, result)
calls.append(mock.call('use warehouse warehouse', None))
self.mock_execute.assert_has_calls(calls)

def test_pre_post_hooks_no_warehouse(self):
with self.current_warehouse('warehouse'):
config = {}
result = self.adapter.pre_model_hook(config)
self.assertIsNone(result)
self.mock_execute.assert_not_called()
self.adapter.post_model_hook(config, result)
self.mock_execute.assert_not_called()

def test_cancel_open_connections_empty(self):
self.assertEqual(len(list(self.adapter.cancel_open_connections())), 0)

Expand Down

0 comments on commit 812c549

Please sign in to comment.