diff --git a/core/dbt/context/base.py b/core/dbt/context/base.py index c9112037676..25373787e86 100644 --- a/core/dbt/context/base.py +++ b/core/dbt/context/base.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import json import os -from typing import Any, Dict, NoReturn, Optional, Mapping, Iterable, Set, List +from typing import Any, Callable, Dict, NoReturn, Optional, Mapping, Iterable, Set, List import threading from dbt.flags import get_flags @@ -86,33 +88,29 @@ def get_context_modules() -> Dict[str, Dict[str, Any]]: class ContextMember: - def __init__(self, value, name=None): + def __init__(self, value: Any, name: Optional[str] = None) -> None: self.name = name self.inner = value - def key(self, default): + def key(self, default: str) -> str: if self.name is None: return default return self.name -def contextmember(value): - if isinstance(value, str): - return lambda v: ContextMember(v, name=value) - return ContextMember(value) +def contextmember(value: Optional[str] = None) -> Callable: + return lambda v: ContextMember(v, name=value) -def contextproperty(value): - if isinstance(value, str): - return lambda v: ContextMember(property(v), name=value) - return ContextMember(property(value)) +def contextproperty(value: Optional[str] = None) -> Callable: + return lambda v: ContextMember(property(v), name=value) class ContextMeta(type): - def __new__(mcls, name, bases, dct): - context_members = {} - context_attrs = {} - new_dct = {} + def __new__(mcls, name, bases, dct: Dict[str, Any]) -> ContextMeta: + context_members: Dict[str, Any] = {} + context_attrs: Dict[str, Any] = {} + new_dct: Dict[str, Any] = {} for base in bases: context_members.update(getattr(base, "_context_members_", {})) @@ -148,27 +146,28 @@ def _generate_merged(self) -> Mapping[str, Any]: return self._cli_vars @property - def node_name(self): + def node_name(self) -> str: if self._node is not None: return self._node.name else: return "" - def get_missing_var(self, var_name): - raise RequiredVarNotFoundError(var_name, self._merged, self._node) + def get_missing_var(self, var_name: str) -> NoReturn: + # TODO function name implies a non exception resolution + raise RequiredVarNotFoundError(var_name, dict(self._merged), self._node) - def has_var(self, var_name: str): + def has_var(self, var_name: str) -> bool: return var_name in self._merged - def get_rendered_var(self, var_name): + def get_rendered_var(self, var_name: str) -> Any: raw = self._merged[var_name] # if bool/int/float/etc are passed in, don't compile anything if not isinstance(raw, str): return raw - return get_rendered(raw, self._context) + return get_rendered(raw, dict(self._context)) - def __call__(self, var_name, default=_VAR_NOTSET): + def __call__(self, var_name: str, default: Any = _VAR_NOTSET) -> Any: if self.has_var(var_name): return self.get_rendered_var(var_name) elif default is not self._VAR_NOTSET: @@ -178,13 +177,17 @@ def __call__(self, var_name, default=_VAR_NOTSET): class BaseContext(metaclass=ContextMeta): + # Set by ContextMeta + _context_members_: Dict[str, Any] + _context_attrs_: Dict[str, Any] + # subclass is TargetContext - def __init__(self, cli_vars): - self._ctx = {} - self.cli_vars = cli_vars - self.env_vars = {} + def __init__(self, cli_vars: Dict[str, Any]) -> None: + self._ctx: Dict[str, Any] = {} + self.cli_vars: Dict[str, Any] = cli_vars + self.env_vars: Dict[str, Any] = {} - def generate_builtins(self): + def generate_builtins(self) -> Dict[str, Any]: builtins: Dict[str, Any] = {} for key, value in self._context_members_.items(): if hasattr(value, "__get__"): @@ -194,14 +197,14 @@ def generate_builtins(self): return builtins # no dbtClassMixin so this is not an actual override - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: self._ctx["context"] = self._ctx builtins = self.generate_builtins() self._ctx["builtins"] = builtins self._ctx.update(builtins) return self._ctx - @contextproperty + @contextproperty() def dbt_version(self) -> str: """The `dbt_version` variable returns the installed version of dbt that is currently running. It can be used for debugging or auditing @@ -221,7 +224,7 @@ def dbt_version(self) -> str: """ return dbt_version - @contextproperty + @contextproperty() def var(self) -> Var: """Variables can be passed from your `dbt_project.yml` file into models during compilation. These variables are useful for configuring packages @@ -290,7 +293,7 @@ def var(self) -> Var: """ return Var(self._ctx, self.cli_vars) - @contextmember + @contextmember() def env_var(self, var: str, default: Optional[str] = None) -> str: """The env_var() function. Return the environment variable named 'var'. If there is no such environment variable set, return the default. @@ -318,7 +321,7 @@ def env_var(self, var: str, default: Optional[str] = None) -> str: if os.environ.get("DBT_MACRO_DEBUGGING"): - @contextmember + @contextmember() @staticmethod def debug(): """Enter a debugger at this line in the compiled jinja code.""" @@ -357,7 +360,7 @@ def _return(data: Any) -> NoReturn: """ raise MacroReturn(data) - @contextmember + @contextmember() @staticmethod def fromjson(string: str, default: Any = None) -> Any: """The `fromjson` context method can be used to deserialize a json @@ -378,7 +381,7 @@ def fromjson(string: str, default: Any = None) -> Any: except ValueError: return default - @contextmember + @contextmember() @staticmethod def tojson(value: Any, default: Any = None, sort_keys: bool = False) -> Any: """The `tojson` context method can be used to serialize a Python @@ -401,7 +404,7 @@ def tojson(value: Any, default: Any = None, sort_keys: bool = False) -> Any: except ValueError: return default - @contextmember + @contextmember() @staticmethod def fromyaml(value: str, default: Any = None) -> Any: """The fromyaml context method can be used to deserialize a yaml string @@ -432,7 +435,7 @@ def fromyaml(value: str, default: Any = None) -> Any: # safe_dump defaults to sort_keys=True, but we act like json.dumps (the # opposite) - @contextmember + @contextmember() @staticmethod def toyaml( value: Any, default: Optional[str] = None, sort_keys: bool = False @@ -477,7 +480,7 @@ def _set(value: Iterable[Any], default: Any = None) -> Optional[Set[Any]]: except TypeError: return default - @contextmember + @contextmember() @staticmethod def set_strict(value: Iterable[Any]) -> Set[Any]: """The `set_strict` context method can be used to convert any iterable @@ -519,7 +522,7 @@ def _zip(*args: Iterable[Any], default: Any = None) -> Optional[Iterable[Any]]: except TypeError: return default - @contextmember + @contextmember() @staticmethod def zip_strict(*args: Iterable[Any]) -> Iterable[Any]: """The `zip_strict` context method can be used to used to return @@ -541,7 +544,7 @@ def zip_strict(*args: Iterable[Any]) -> Iterable[Any]: except TypeError as e: raise ZipStrictWrongTypeError(e) - @contextmember + @contextmember() @staticmethod def log(msg: str, info: bool = False) -> str: """Logs a line to either the log file or stdout. @@ -562,7 +565,7 @@ def log(msg: str, info: bool = False) -> str: fire_event(JinjaLogDebug(msg=msg, node_info=get_node_info())) return "" - @contextproperty + @contextproperty() def run_started_at(self) -> Optional[datetime.datetime]: """`run_started_at` outputs the timestamp that this run started, e.g. `2017-04-21 01:23:45.678`. The `run_started_at` variable is a Python @@ -590,19 +593,19 @@ def run_started_at(self) -> Optional[datetime.datetime]: else: return None - @contextproperty + @contextproperty() def invocation_id(self) -> Optional[str]: """invocation_id outputs a UUID generated for this dbt run (useful for auditing) """ return get_invocation_id() - @contextproperty + @contextproperty() def thread_id(self) -> str: """thread_id outputs an ID for the current thread (useful for auditing)""" return threading.current_thread().name - @contextproperty + @contextproperty() def modules(self) -> Dict[str, Any]: """The `modules` variable in the Jinja context contains useful Python modules for operating on data. @@ -627,7 +630,7 @@ def modules(self) -> Dict[str, Any]: """ # noqa return get_context_modules() - @contextproperty + @contextproperty() def flags(self) -> Any: """The `flags` variable contains true/false values for flags provided on the command line. @@ -644,7 +647,7 @@ def flags(self) -> Any: """ return flags_module.get_flag_obj() - @contextmember + @contextmember() @staticmethod def print(msg: str) -> str: """Prints a line to stdout. @@ -662,7 +665,7 @@ def print(msg: str) -> str: print(msg) return "" - @contextmember + @contextmember() @staticmethod def diff_of_two_dicts( dict_a: Dict[str, List[str]], dict_b: Dict[str, List[str]] @@ -691,7 +694,7 @@ def diff_of_two_dicts( dict_diff.update({k: dict_a[k]}) return dict_diff - @contextmember + @contextmember() @staticmethod def local_md5(value: str) -> str: """Calculates an MD5 hash of the given string. diff --git a/core/dbt/context/configured.py b/core/dbt/context/configured.py index bb292a19565..08f5bee1143 100644 --- a/core/dbt/context/configured.py +++ b/core/dbt/context/configured.py @@ -19,7 +19,7 @@ def __init__(self, config: AdapterRequiredConfig) -> None: super().__init__(config.to_target_dict(), config.cli_vars) self.config = config - @contextproperty + @contextproperty() def project_name(self) -> str: return self.config.project_name @@ -80,11 +80,11 @@ def __init__(self, config, project_name: str, schema_yaml_vars: Optional[SchemaY self._project_name = project_name self.schema_yaml_vars = schema_yaml_vars - @contextproperty + @contextproperty() def var(self) -> ConfiguredVar: return ConfiguredVar(self._ctx, self.config, self._project_name) - @contextmember + @contextmember() def env_var(self, var: str, default: Optional[str] = None) -> str: return_value = None if var.startswith(SECRET_ENV_PREFIX): @@ -113,7 +113,7 @@ class MacroResolvingContext(ConfiguredContext): def __init__(self, config): super().__init__(config) - @contextproperty + @contextproperty() def var(self) -> ConfiguredVar: return ConfiguredVar(self._ctx, self.config, self.config.project_name) diff --git a/core/dbt/context/docs.py b/core/dbt/context/docs.py index 3d5abf42e11..94f64709fc7 100644 --- a/core/dbt/context/docs.py +++ b/core/dbt/context/docs.py @@ -24,7 +24,7 @@ def __init__( self.node = node self.manifest = manifest - @contextmember + @contextmember() def doc(self, *args: str) -> str: """The `doc` function is used to reference docs blocks in schema.yml files. It is analogous to the `ref` function. For more information, diff --git a/core/dbt/context/manifest.py b/core/dbt/context/manifest.py index c6a39993d92..f2492612cc8 100644 --- a/core/dbt/context/manifest.py +++ b/core/dbt/context/manifest.py @@ -67,7 +67,7 @@ def to_dict(self): dct.update(self.namespace) return dct - @contextproperty + @contextproperty() def context_macro_stack(self): return self.macro_stack diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 6b981091682..ffc1f6d07b4 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -754,19 +754,19 @@ def _get_namespace_builder(self): self.model, ) - @contextproperty + @contextproperty() def dbt_metadata_envs(self) -> Dict[str, str]: return get_metadata_vars() - @contextproperty + @contextproperty() def invocation_args_dict(self): return args_to_dict(self.config.args) - @contextproperty + @contextproperty() def _sql_results(self) -> Dict[str, Optional[AttrDict]]: return self.sql_results - @contextmember + @contextmember() def load_result(self, name: str) -> Optional[AttrDict]: if name in self.sql_results: # handle the special case of "main" macro @@ -787,7 +787,7 @@ def load_result(self, name: str) -> Optional[AttrDict]: # Handle trying to load a result that was never stored return None - @contextmember + @contextmember() def store_result( self, name: str, response: Any, agate_table: Optional[agate.Table] = None ) -> str: @@ -803,7 +803,7 @@ def store_result( ) return "" - @contextmember + @contextmember() def store_raw_result( self, name: str, @@ -815,7 +815,7 @@ def store_raw_result( response = AdapterResponse(_message=message, code=code, rows_affected=rows_affected) return self.store_result(name, response, agate_table) - @contextproperty + @contextproperty() def validation(self): def validate_any(*args) -> Callable[[T], None]: def inner(value: T) -> None: @@ -836,7 +836,7 @@ def inner(value: T) -> None: } ) - @contextmember + @contextmember() def write(self, payload: str) -> str: # macros/source defs aren't 'writeable'. if isinstance(self.model, (Macro, SourceDefinition)): @@ -845,11 +845,11 @@ def write(self, payload: str) -> str: self.model.write_node(self.config.project_root, self.model.build_path, payload) return "" - @contextmember + @contextmember() def render(self, string: str) -> str: return get_rendered(string, self._ctx, self.model) - @contextmember + @contextmember() def try_or_compiler_error( self, message_if_exception: str, func: Callable, *args, **kwargs ) -> Any: @@ -858,7 +858,7 @@ def try_or_compiler_error( except Exception: raise CompilationError(message_if_exception, self.model) - @contextmember + @contextmember() def load_agate_table(self) -> agate.Table: if not isinstance(self.model, SeedNode): raise LoadAgateTableNotSeedError(self.model.resource_type, node=self.model) @@ -873,7 +873,7 @@ def load_agate_table(self) -> agate.Table: table.original_abspath = os.path.abspath(path) return table - @contextproperty + @contextproperty() def ref(self) -> Callable: """The most important function in dbt is `ref()`; it's impossible to build even moderately complex models without it. `ref()` is how you @@ -914,11 +914,11 @@ def ref(self) -> Callable: """ return self.provider.ref(self.db_wrapper, self.model, self.config, self.manifest) - @contextproperty + @contextproperty() def source(self) -> Callable: return self.provider.source(self.db_wrapper, self.model, self.config, self.manifest) - @contextproperty + @contextproperty() def metric(self) -> Callable: return self.provider.metric(self.db_wrapper, self.model, self.config, self.manifest) @@ -979,7 +979,7 @@ def ctx_config(self) -> Config: """ # noqa return self.provider.Config(self.model, self.context_config) - @contextproperty + @contextproperty() def execute(self) -> bool: """`execute` is a Jinja variable that returns True when dbt is in "execute" mode. @@ -1040,7 +1040,7 @@ def execute(self) -> bool: """ # noqa return self.provider.execute - @contextproperty + @contextproperty() def exceptions(self) -> Dict[str, Any]: """The exceptions namespace can be used to raise warnings and errors in dbt userspace. @@ -1078,15 +1078,15 @@ def exceptions(self) -> Dict[str, Any]: """ # noqa return wrapped_exports(self.model) - @contextproperty + @contextproperty() def database(self) -> str: return self.config.credentials.database - @contextproperty + @contextproperty() def schema(self) -> str: return self.config.credentials.schema - @contextproperty + @contextproperty() def var(self) -> ModelConfiguredVar: return self.provider.Var( context=self._ctx, @@ -1103,22 +1103,22 @@ def ctx_adapter(self) -> BaseDatabaseWrapper: """ return self.db_wrapper - @contextproperty + @contextproperty() def api(self) -> Dict[str, Any]: return { "Relation": self.db_wrapper.Relation, "Column": self.adapter.Column, } - @contextproperty + @contextproperty() def column(self) -> Type[Column]: return self.adapter.Column - @contextproperty + @contextproperty() def env(self) -> Dict[str, Any]: return self.target - @contextproperty + @contextproperty() def graph(self) -> Dict[str, Any]: """The `graph` context variable contains information about the nodes in your dbt project. Models, sources, tests, and snapshots are all @@ -1234,23 +1234,23 @@ def ctx_model(self) -> Dict[str, Any]: ret["compiled_sql"] = ret["compiled_code"] return ret - @contextproperty + @contextproperty() def pre_hooks(self) -> Optional[List[Dict[str, Any]]]: return None - @contextproperty + @contextproperty() def post_hooks(self) -> Optional[List[Dict[str, Any]]]: return None - @contextproperty + @contextproperty() def sql(self) -> Optional[str]: return None - @contextproperty + @contextproperty() def sql_now(self) -> str: return self.adapter.date_function() - @contextmember + @contextmember() def adapter_macro(self, name: str, *args, **kwargs): """This was deprecated in v0.18 in favor of adapter.dispatch""" msg = ( @@ -1262,7 +1262,7 @@ def adapter_macro(self, name: str, *args, **kwargs): ) raise CompilationError(msg) - @contextmember + @contextmember() def env_var(self, var: str, default: Optional[str] = None) -> str: """The env_var() function. Return the environment variable named 'var'. If there is no such environment variable set, return the default. @@ -1306,7 +1306,7 @@ def env_var(self, var: str, default: Optional[str] = None) -> str: else: raise EnvVarMissingError(var) - @contextproperty + @contextproperty() def selected_resources(self) -> List[str]: """The `selected_resources` variable contains a list of the resources selected based on the parameters provided to the dbt command. @@ -1315,7 +1315,7 @@ def selected_resources(self) -> List[str]: """ return selected_resources.SELECTED_RESOURCES - @contextmember + @contextmember() def submit_python_job(self, parsed_model: Dict, compiled_code: str) -> AdapterResponse: # Check macro_stack and that the unique id is for a materialization macro if not ( @@ -1358,7 +1358,7 @@ def __init__( class ModelContext(ProviderContext): model: ManifestNode - @contextproperty + @contextproperty() def pre_hooks(self) -> List[Dict[str, Any]]: if self.model.resource_type in [NodeType.Source, NodeType.Test]: return [] @@ -1367,7 +1367,7 @@ def pre_hooks(self) -> List[Dict[str, Any]]: h.to_dict(omit_none=True) for h in self.model.config.pre_hook # type: ignore[union-attr] # noqa ] - @contextproperty + @contextproperty() def post_hooks(self) -> List[Dict[str, Any]]: if self.model.resource_type in [NodeType.Source, NodeType.Test]: return [] @@ -1376,7 +1376,7 @@ def post_hooks(self) -> List[Dict[str, Any]]: h.to_dict(omit_none=True) for h in self.model.config.post_hook # type: ignore[union-attr] # noqa ] - @contextproperty + @contextproperty() def sql(self) -> Optional[str]: # only doing this in sql model for backward compatible if self.model.language == ModelLanguage.sql: # type: ignore[union-attr] @@ -1393,7 +1393,7 @@ def sql(self) -> Optional[str]: else: return None - @contextproperty + @contextproperty() def compiled_code(self) -> Optional[str]: if getattr(self.model, "defer_relation", None): # TODO https://github.com/dbt-labs/dbt-core/issues/7976 @@ -1404,15 +1404,15 @@ def compiled_code(self) -> Optional[str]: else: return None - @contextproperty + @contextproperty() def database(self) -> str: return getattr(self.model, "database", self.config.credentials.database) - @contextproperty + @contextproperty() def schema(self) -> str: return getattr(self.model, "schema", self.config.credentials.schema) - @contextproperty + @contextproperty() def this(self) -> Optional[RelationProxy]: """`this` makes available schema information about the currently executing model. It's is useful in any context in which you need to @@ -1447,7 +1447,7 @@ def this(self) -> Optional[RelationProxy]: return None return self.db_wrapper.Relation.create_from(self.config, self.model) - @contextproperty + @contextproperty() def defer_relation(self) -> Optional[RelationProxy]: """ For commands which add information about this node's corresponding @@ -1661,7 +1661,7 @@ def _build_test_namespace(self): ) self.namespace = macro_namespace - @contextmember + @contextmember() def env_var(self, var: str, default: Optional[str] = None) -> str: return_value = None if var.startswith(SECRET_ENV_PREFIX): diff --git a/core/dbt/context/secret.py b/core/dbt/context/secret.py index 4d8ff342aff..2c75546c42a 100644 --- a/core/dbt/context/secret.py +++ b/core/dbt/context/secret.py @@ -14,7 +14,7 @@ class SecretContext(BaseContext): """This context is used in profiles.yml + packages.yml. It can render secret env vars that aren't usable elsewhere""" - @contextmember + @contextmember() def env_var(self, var: str, default: Optional[str] = None) -> str: """The env_var() function. Return the environment variable named 'var'. If there is no such environment variable set, return the default. diff --git a/core/dbt/context/target.py b/core/dbt/context/target.py index a6d587269d5..39c5a30ee4e 100644 --- a/core/dbt/context/target.py +++ b/core/dbt/context/target.py @@ -9,7 +9,7 @@ def __init__(self, target_dict: Dict[str, Any], cli_vars: Dict[str, Any]): super().__init__(cli_vars=cli_vars) self.target_dict = target_dict - @contextproperty + @contextproperty() def target(self) -> Dict[str, Any]: """`target` contains information about your connection to the warehouse (specified in profiles.yml). Some configs are shared between all