Skip to content

Commit

Permalink
Fix untyped functions in core/dbt/context/base.py (#8525)
Browse files Browse the repository at this point in the history
* Improve typing of `ContextMember` functions

* Improve typing of `Var` functions

* Improve typing of `ContextMeta.__new__`

* Improve typing `BaseContext` and functions

In addition to just adding parameter typing and return typing to
`BaseContext` functions. We also declared `_context_members_` and
`_context_attrs_` as properites of `BaseContext` this was necessary
because they're being accessed in the classes functions. However,
because they were being indirectly instantiated by the metaclass
`ContextMeta`, the properties weren't actually known to exist. By
adding declaring the properties on the `BaseContext`, we let mypy
know they exist.

* Remove bare `invocations` of `@contextmember` and `@contextproperty`, and add typing to them

Previously `contextmember` and `contextproperty` were 2-in-1 decorators.
This meant they could be invoked either as `@contextmember` or
`@contextmember('some_string')`. This was fine until we wanted to return
typing to the functions. In the instance where the bare decorator was used
(i.e. no `(...)` were present) an object was expected to be returned. However
in the instance where parameters were passed on the invocation, a callable
was expected to be returned. Putting a union of both in the return type
made the invocations complain about each others' return type. To get around this
we've dropped the bare invocation as acceptable. The parenthesis are now always
required, but passing a string in them is optional.
  • Loading branch information
QMalcolm authored Aug 31, 2023
1 parent d8e8a78 commit e5e1a27
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 95 deletions.
97 changes: 50 additions & 47 deletions core/dbt/context/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import json
import os
from typing import Any, Dict, NoReturn, Optional, Mapping, Iterable, Set, List
from typing import Any, Callable, Dict, NoReturn, Optional, Mapping, Iterable, Set, List
import threading

from dbt.flags import get_flags
Expand Down Expand Up @@ -86,33 +88,29 @@ def get_context_modules() -> Dict[str, Dict[str, Any]]:


class ContextMember:
def __init__(self, value, name=None):
def __init__(self, value: Any, name: Optional[str] = None) -> None:
self.name = name
self.inner = value

def key(self, default):
def key(self, default: str) -> str:
if self.name is None:
return default
return self.name


def contextmember(value):
if isinstance(value, str):
return lambda v: ContextMember(v, name=value)
return ContextMember(value)
def contextmember(value: Optional[str] = None) -> Callable:
return lambda v: ContextMember(v, name=value)


def contextproperty(value):
if isinstance(value, str):
return lambda v: ContextMember(property(v), name=value)
return ContextMember(property(value))
def contextproperty(value: Optional[str] = None) -> Callable:
return lambda v: ContextMember(property(v), name=value)


class ContextMeta(type):
def __new__(mcls, name, bases, dct):
context_members = {}
context_attrs = {}
new_dct = {}
def __new__(mcls, name, bases, dct: Dict[str, Any]) -> ContextMeta:
context_members: Dict[str, Any] = {}
context_attrs: Dict[str, Any] = {}
new_dct: Dict[str, Any] = {}

for base in bases:
context_members.update(getattr(base, "_context_members_", {}))
Expand Down Expand Up @@ -148,27 +146,28 @@ def _generate_merged(self) -> Mapping[str, Any]:
return self._cli_vars

@property
def node_name(self):
def node_name(self) -> str:
if self._node is not None:
return self._node.name
else:
return "<Configuration>"

def get_missing_var(self, var_name):
raise RequiredVarNotFoundError(var_name, self._merged, self._node)
def get_missing_var(self, var_name: str) -> NoReturn:
# TODO function name implies a non exception resolution
raise RequiredVarNotFoundError(var_name, dict(self._merged), self._node)

def has_var(self, var_name: str):
def has_var(self, var_name: str) -> bool:
return var_name in self._merged

def get_rendered_var(self, var_name):
def get_rendered_var(self, var_name: str) -> Any:
raw = self._merged[var_name]
# if bool/int/float/etc are passed in, don't compile anything
if not isinstance(raw, str):
return raw

return get_rendered(raw, self._context)
return get_rendered(raw, dict(self._context))

def __call__(self, var_name, default=_VAR_NOTSET):
def __call__(self, var_name: str, default: Any = _VAR_NOTSET) -> Any:
if self.has_var(var_name):
return self.get_rendered_var(var_name)
elif default is not self._VAR_NOTSET:
Expand All @@ -178,13 +177,17 @@ def __call__(self, var_name, default=_VAR_NOTSET):


class BaseContext(metaclass=ContextMeta):
# Set by ContextMeta
_context_members_: Dict[str, Any]
_context_attrs_: Dict[str, Any]

# subclass is TargetContext
def __init__(self, cli_vars):
self._ctx = {}
self.cli_vars = cli_vars
self.env_vars = {}
def __init__(self, cli_vars: Dict[str, Any]) -> None:
self._ctx: Dict[str, Any] = {}
self.cli_vars: Dict[str, Any] = cli_vars
self.env_vars: Dict[str, Any] = {}

def generate_builtins(self):
def generate_builtins(self) -> Dict[str, Any]:
builtins: Dict[str, Any] = {}
for key, value in self._context_members_.items():
if hasattr(value, "__get__"):
Expand All @@ -194,14 +197,14 @@ def generate_builtins(self):
return builtins

# no dbtClassMixin so this is not an actual override
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
self._ctx["context"] = self._ctx
builtins = self.generate_builtins()
self._ctx["builtins"] = builtins
self._ctx.update(builtins)
return self._ctx

@contextproperty
@contextproperty()
def dbt_version(self) -> str:
"""The `dbt_version` variable returns the installed version of dbt that
is currently running. It can be used for debugging or auditing
Expand All @@ -221,7 +224,7 @@ def dbt_version(self) -> str:
"""
return dbt_version

@contextproperty
@contextproperty()
def var(self) -> Var:
"""Variables can be passed from your `dbt_project.yml` file into models
during compilation. These variables are useful for configuring packages
Expand Down Expand Up @@ -290,7 +293,7 @@ def var(self) -> Var:
"""
return Var(self._ctx, self.cli_vars)

@contextmember
@contextmember()
def env_var(self, var: str, default: Optional[str] = None) -> str:
"""The env_var() function. Return the environment variable named 'var'.
If there is no such environment variable set, return the default.
Expand Down Expand Up @@ -318,7 +321,7 @@ def env_var(self, var: str, default: Optional[str] = None) -> str:

if os.environ.get("DBT_MACRO_DEBUGGING"):

@contextmember
@contextmember()
@staticmethod
def debug():
"""Enter a debugger at this line in the compiled jinja code."""
Expand Down Expand Up @@ -357,7 +360,7 @@ def _return(data: Any) -> NoReturn:
"""
raise MacroReturn(data)

@contextmember
@contextmember()
@staticmethod
def fromjson(string: str, default: Any = None) -> Any:
"""The `fromjson` context method can be used to deserialize a json
Expand All @@ -378,7 +381,7 @@ def fromjson(string: str, default: Any = None) -> Any:
except ValueError:
return default

@contextmember
@contextmember()
@staticmethod
def tojson(value: Any, default: Any = None, sort_keys: bool = False) -> Any:
"""The `tojson` context method can be used to serialize a Python
Expand All @@ -401,7 +404,7 @@ def tojson(value: Any, default: Any = None, sort_keys: bool = False) -> Any:
except ValueError:
return default

@contextmember
@contextmember()
@staticmethod
def fromyaml(value: str, default: Any = None) -> Any:
"""The fromyaml context method can be used to deserialize a yaml string
Expand Down Expand Up @@ -432,7 +435,7 @@ def fromyaml(value: str, default: Any = None) -> Any:

# safe_dump defaults to sort_keys=True, but we act like json.dumps (the
# opposite)
@contextmember
@contextmember()
@staticmethod
def toyaml(
value: Any, default: Optional[str] = None, sort_keys: bool = False
Expand Down Expand Up @@ -477,7 +480,7 @@ def _set(value: Iterable[Any], default: Any = None) -> Optional[Set[Any]]:
except TypeError:
return default

@contextmember
@contextmember()
@staticmethod
def set_strict(value: Iterable[Any]) -> Set[Any]:
"""The `set_strict` context method can be used to convert any iterable
Expand Down Expand Up @@ -519,7 +522,7 @@ def _zip(*args: Iterable[Any], default: Any = None) -> Optional[Iterable[Any]]:
except TypeError:
return default

@contextmember
@contextmember()
@staticmethod
def zip_strict(*args: Iterable[Any]) -> Iterable[Any]:
"""The `zip_strict` context method can be used to used to return
Expand All @@ -541,7 +544,7 @@ def zip_strict(*args: Iterable[Any]) -> Iterable[Any]:
except TypeError as e:
raise ZipStrictWrongTypeError(e)

@contextmember
@contextmember()
@staticmethod
def log(msg: str, info: bool = False) -> str:
"""Logs a line to either the log file or stdout.
Expand All @@ -562,7 +565,7 @@ def log(msg: str, info: bool = False) -> str:
fire_event(JinjaLogDebug(msg=msg, node_info=get_node_info()))
return ""

@contextproperty
@contextproperty()
def run_started_at(self) -> Optional[datetime.datetime]:
"""`run_started_at` outputs the timestamp that this run started, e.g.
`2017-04-21 01:23:45.678`. The `run_started_at` variable is a Python
Expand Down Expand Up @@ -590,19 +593,19 @@ def run_started_at(self) -> Optional[datetime.datetime]:
else:
return None

@contextproperty
@contextproperty()
def invocation_id(self) -> Optional[str]:
"""invocation_id outputs a UUID generated for this dbt run (useful for
auditing)
"""
return get_invocation_id()

@contextproperty
@contextproperty()
def thread_id(self) -> str:
"""thread_id outputs an ID for the current thread (useful for auditing)"""
return threading.current_thread().name

@contextproperty
@contextproperty()
def modules(self) -> Dict[str, Any]:
"""The `modules` variable in the Jinja context contains useful Python
modules for operating on data.
Expand All @@ -627,7 +630,7 @@ def modules(self) -> Dict[str, Any]:
""" # noqa
return get_context_modules()

@contextproperty
@contextproperty()
def flags(self) -> Any:
"""The `flags` variable contains true/false values for flags provided
on the command line.
Expand All @@ -644,7 +647,7 @@ def flags(self) -> Any:
"""
return flags_module.get_flag_obj()

@contextmember
@contextmember()
@staticmethod
def print(msg: str) -> str:
"""Prints a line to stdout.
Expand All @@ -662,7 +665,7 @@ def print(msg: str) -> str:
print(msg)
return ""

@contextmember
@contextmember()
@staticmethod
def diff_of_two_dicts(
dict_a: Dict[str, List[str]], dict_b: Dict[str, List[str]]
Expand Down Expand Up @@ -691,7 +694,7 @@ def diff_of_two_dicts(
dict_diff.update({k: dict_a[k]})
return dict_diff

@contextmember
@contextmember()
@staticmethod
def local_md5(value: str) -> str:
"""Calculates an MD5 hash of the given string.
Expand Down
8 changes: 4 additions & 4 deletions core/dbt/context/configured.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, config: AdapterRequiredConfig) -> None:
super().__init__(config.to_target_dict(), config.cli_vars)
self.config = config

@contextproperty
@contextproperty()
def project_name(self) -> str:
return self.config.project_name

Expand Down Expand Up @@ -80,11 +80,11 @@ def __init__(self, config, project_name: str, schema_yaml_vars: Optional[SchemaY
self._project_name = project_name
self.schema_yaml_vars = schema_yaml_vars

@contextproperty
@contextproperty()
def var(self) -> ConfiguredVar:
return ConfiguredVar(self._ctx, self.config, self._project_name)

@contextmember
@contextmember()
def env_var(self, var: str, default: Optional[str] = None) -> str:
return_value = None
if var.startswith(SECRET_ENV_PREFIX):
Expand Down Expand Up @@ -113,7 +113,7 @@ class MacroResolvingContext(ConfiguredContext):
def __init__(self, config):
super().__init__(config)

@contextproperty
@contextproperty()
def var(self) -> ConfiguredVar:
return ConfiguredVar(self._ctx, self.config, self.config.project_name)

Expand Down
2 changes: 1 addition & 1 deletion core/dbt/context/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
self.node = node
self.manifest = manifest

@contextmember
@contextmember()
def doc(self, *args: str) -> str:
"""The `doc` function is used to reference docs blocks in schema.yml
files. It is analogous to the `ref` function. For more information,
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/context/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def to_dict(self):
dct.update(self.namespace)
return dct

@contextproperty
@contextproperty()
def context_macro_stack(self):
return self.macro_stack

Expand Down
Loading

0 comments on commit e5e1a27

Please sign in to comment.