Skip to content

Commit

Permalink
introduce Anti-corruption Layer for yaml library (#913)
Browse files (browse the repository at this point in the history)
Signed-off-by: Ching Yi, Chan <qrtt1@infuseai.io>
  • Loading branch information
qrtt1 authored Oct 24, 2023
1 parent a5d77ec commit c35b2b7
Show file tree
Hide file tree
Showing 12 changed files with 120 additions and 94 deletions.
28 changes: 3 additions & 25 deletions piperider_cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from dateutil import tz
from rich.console import Console
from ruamel import yaml
from piperider_cli import yaml as pyml

PIPERIDER_USER_HOME = os.path.expanduser('~/.piperider')
if os.access(os.path.expanduser('~/'), os.W_OK) is False:
Expand Down Expand Up @@ -79,7 +79,7 @@ def get_sentry_dns():

def get_user_id():
with open(PIPERIDER_USER_PROFILE, 'r') as f:
user_profile = yaml.YAML().load(f)
user_profile = pyml.load(f)
return user_profile.get('user_id')


Expand Down Expand Up @@ -159,28 +159,6 @@ def raise_exception_when_directory_not_writable(output):
raise Exception(f'The path "{output}" is not writable')


def safe_load_yaml(file_path):
try:
with open(file_path, 'r') as f:
payload = yaml.safe_load(f)
except yaml.YAMLError as e:
print(e)
return None
except FileNotFoundError:
return None
return payload


def round_trip_load_yaml(file_path):
with open(file_path, 'r') as f:
try:
payload = yaml.round_trip_load(f)
except yaml.YAMLError as e:
print(e)
return None
return payload


def load_json(file_path):
with open(file_path, 'r') as f:
try:
Expand Down Expand Up @@ -264,7 +242,7 @@ def open_report_in_browser(report_path='', is_cloud_path=False):
protocol_prefix = "" if is_cloud_path else "file://"
try:
webbrowser.open(f"{protocol_prefix}{report_path}")
except yaml.YAMLError as e:
except pyml.YAMLError as e:
print(e)
return None

Expand Down
9 changes: 4 additions & 5 deletions piperider_cli/assertion_engine/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
from typing import List, Dict

from deepmerge import always_merger
from ruamel import yaml
from ruamel.yaml.comments import CommentedMap
from sqlalchemy import inspect
from sqlalchemy.engine import Engine

from piperider_cli import safe_load_yaml, round_trip_load_yaml
from piperider_cli import yaml as pyml
from piperider_cli.yaml import safe_load_yaml, round_trip_load_yaml
from piperider_cli.configuration import FileSystem
from piperider_cli.error import \
AssertionError, \
Expand Down Expand Up @@ -532,7 +531,7 @@ def _update_existing_recommended_assertions(self, recommended_assertions):
def merge_assertions(target: str, existed_items: List, new_generating_items: List):
if new_generating_items.get(target) is None:
# Column or table doesn't exist in the existing assertions
new_generating_items[target] = CommentedMap(existed_items[target])
new_generating_items[target] = pyml.CommentedMap(existed_items[target])
is_generated_by_us = False
for assertion in new_generating_items[target].get('tests', []):
is_generated_by_us = self._is_recommended_assertion(assertion)
Expand Down Expand Up @@ -605,7 +604,7 @@ def _dump_assertions_files(self, assertions, prefix=''):
if assertion.get('skip'): # skip if it already exists user-defined assertions
continue
with open(file_path, 'w') as f:
yaml.YAML().dump(assertion, f)
pyml.dump(assertion, f)
paths.append(file_path)
return paths

Expand Down
34 changes: 16 additions & 18 deletions piperider_cli/assertion_engine/recommender.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import inspect
from typing import Callable, List, Dict

from ruamel.yaml.comments import CommentedMap, CommentedSeq
from ruamel.yaml.error import CommentMark
from ruamel.yaml.tokens import CommentToken
from piperider_cli import yaml as pyml

from .recommended_rules import RecommendedAssertion, RecommendedRules

Expand All @@ -14,28 +12,28 @@

class AssertionRecommender:
def __init__(self):
self.assertions: Dict[CommentedMap] = {}
self.assertions: Dict[pyml.CommentedMap] = {}
self.recommended_rule_callbacks = []
self.load_recommended_rules()
self.generated_assertions: List[RecommendedAssertion] = []

def prepare_assertion_template(self, profiling_result):
for name, table in profiling_result.get('tables', {}).items():
# Generate template of assertions
table_assertions = CommentedSeq()
columns = CommentedMap()
table_assertions = pyml.CommentedSeq()
columns = pyml.CommentedMap()

# Generate assertions for columns
for col in table.get('columns', {}).keys():
column_name = str(col)
column_assertions = CommentedSeq()
columns[column_name] = CommentedMap({
column_assertions = pyml.CommentedSeq()
columns[column_name] = pyml.CommentedMap({
'tests': column_assertions,
})

# Generate assertions for table
recommended_assertion = CommentedMap({
name: CommentedMap({
recommended_assertion = pyml.CommentedMap({
name: pyml.CommentedMap({
'tests': table_assertions,
'columns': columns,
})})
Expand All @@ -54,20 +52,20 @@ def run(self, profiling_result) -> List[RecommendedAssertion]:
self.prepare_assertion_template(profiling_result)

for table, ta in self.assertions.items():
table_assertions: CommentedSeq = ta[table]['tests']
table_assertions: pyml.CommentedSeq = ta[table]['tests']
for callback in self.recommended_rule_callbacks:
assertion: RecommendedAssertion = callback(table, None, profiling_result)
if assertion:
if assertion.name:
table_assertions.append(CommentedMap({
table_assertions.append(pyml.CommentedMap({
'name': assertion.name,
'assert': CommentedMap(assertion.asserts),
'assert': pyml.CommentedMap(assertion.asserts),
'tags': [RECOMMENDED_ASSERTION_TAG]
}))
else:
table_assertions.append(CommentedMap({
table_assertions.append(pyml.CommentedMap({
'metric': assertion.metric,
'assert': CommentedMap(assertion.asserts),
'assert': pyml.CommentedMap(assertion.asserts),
'tags': [RECOMMENDED_ASSERTION_TAG]
}))
assertion.table = table
Expand All @@ -82,13 +80,13 @@ def run(self, profiling_result) -> List[RecommendedAssertion]:
assertion.table = table
assertion.column = column
if assertion.asserts:
column_assertions.append(CommentedMap({
column_assertions.append(pyml.CommentedMap({
'name': assertion.name,
'assert': CommentedMap(assertion.asserts),
'assert': pyml.CommentedMap(assertion.asserts),
'tags': [RECOMMENDED_ASSERTION_TAG]
}))
else:
column_assertions.append(CommentedMap({
column_assertions.append(pyml.CommentedMap({
'name': assertion.name,
'tags': [RECOMMENDED_ASSERTION_TAG]
}))
Expand Down
3 changes: 0 additions & 3 deletions piperider_cli/cloud/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import requests
from requests_toolbelt import MultipartEncoder, MultipartEncoderMonitor
from rich.progress import Progress, TextColumn, BarColumn, DownloadColumn, TimeElapsedColumn
from ruamel import yaml

from piperider_cli import __version__
from piperider_cli.configuration import Configuration
Expand All @@ -17,8 +16,6 @@
SERVICE_ENV_API_KEY = 'PIPERIDER_API_TOKEN'
SERVICE_ENV_SERVICE_KEY = 'PIPERIDER_API_SERVICE'

yml = yaml.YAML()


class PipeRiderProject(object):
def __init__(self, project: dict):
Expand Down
30 changes: 14 additions & 16 deletions piperider_cli/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@

import inquirer
from rich.console import Console
from ruamel import yaml
from ruamel.yaml import CommentedMap, CommentedSeq

from piperider_cli import raise_exception_when_directory_not_writable, round_trip_load_yaml, safe_load_yaml
from piperider_cli import raise_exception_when_directory_not_writable
from piperider_cli.yaml import round_trip_load_yaml, safe_load_yaml
from piperider_cli.cli_utils import DbtUtil
from piperider_cli.datasource import DATASOURCE_PROVIDERS, DataSource
from piperider_cli.datasource.unsupported import UnsupportedDataSource
Expand All @@ -24,6 +23,7 @@
PipeRiderInvalidDataSourceError, \
DbtProjectNotFoundError, \
DbtProfileNotFoundError
from piperider_cli import yaml as pyml

# ref: https://docs.getdbt.com/dbt-cli/configure-your-profile
DBT_PROFILES_DIR_DEFAULT = '~/.dbt/'
Expand Down Expand Up @@ -367,14 +367,12 @@ def activate_report_directory(self, report_dir: str = None) -> ReportDirectory:

@staticmethod
def update_config(key: str, update_values: Union[dict, str]):
_yml = yaml.YAML()

with open(FileSystem.PIPERIDER_CONFIG_PATH, 'r') as f:
config = _yml.load(f) or {}
config = pyml.load(f) or {}

config[key] = update_values
with open(FileSystem.PIPERIDER_CONFIG_PATH, 'w+', encoding='utf-8') as f:
_yml.dump(config, f)
pyml.dump(config, f)

@classmethod
def from_dbt_project(cls, dbt_project_path, dbt_profiles_dir=None, dbt_profile: str = None, dbt_target: str = None):
Expand Down Expand Up @@ -462,7 +460,7 @@ def instance(cls, piperider_config_path=None, dbt_profile: str = None, dbt_targe
if configuration_instance:
dbt = configuration_instance.dbt
if reload is False or dbt is None or (
dbt.get('profile') == dbt_profile and dbt.get('target') == dbt_target):
dbt.get('profile') == dbt_profile and dbt.get('target') == dbt_target):
return configuration_instance
configuration_instance = cls._load(piperider_config_path, dbt_profile=dbt_profile, dbt_target=dbt_target)
return configuration_instance
Expand Down Expand Up @@ -530,7 +528,7 @@ def _load(cls, piperider_config_path=None, dbt_profile: str = None, dbt_target:
else:
try:
with open(FileSystem.PIPERIDER_CREDENTIALS_PATH, 'r') as fd:
credentials = yaml.safe_load(fd)
credentials = pyml.safe_load(fd)
credential.update(credentials.get(ds.get('name'), {}))
except FileNotFoundError:
pass
Expand Down Expand Up @@ -623,9 +621,9 @@ def _generate_datasource_config(d):
# non-dbt project
if d.credential_source == 'config':
datasource.update(**d.credential)
return CommentedMap(datasource)
return pyml.CommentedMap(datasource)

flush_data_sources = CommentedSeq()
flush_data_sources = pyml.CommentedSeq()

for ds in self.dataSources:
exist_ds = _get_exist_datasource(ds, config)
Expand All @@ -641,7 +639,7 @@ def _generate_datasource_config(d):
config.yaml_set_comment_before_after_key('profiler', before='\n')

with open(path, 'w', encoding='utf-8') as fd:
yaml.YAML().dump(config, fd)
pyml.dump(config, fd)
pass

def dump(self, path):
Expand All @@ -660,7 +658,7 @@ def dump(self, path):
datasource = dict(name=d.name, type=d.type_name)
if d.args.get('dbt'):
# dbt project
config['dbt'] = CommentedMap(d.args.get('dbt'))
config['dbt'] = pyml.CommentedMap(d.args.get('dbt'))
else:
# non-dbt project
if d.credential_source == 'config':
Expand All @@ -671,10 +669,10 @@ def dump(self, path):
if self.cloud_config:
config['cloud_config'] = self.cloud_config

config_yaml = CommentedMap(config)
config_yaml = pyml.CommentedMap(config)

with open(path, 'w', encoding='utf-8') as fd:
yaml.YAML().dump(config_yaml, fd)
pyml.dump(config_yaml, fd)

def dump_credentials(self, path, after_init_config=False):
"""
Expand All @@ -693,7 +691,7 @@ def dump_credentials(self, path, after_init_config=False):

if creds:
with open(path, 'w', encoding='utf-8') as fd:
yaml.round_trip_dump(creds, fd)
pyml.round_trip_dump(creds, fd)

def to_sqlalchemy_config(self, datasource_name):
# TODO we will convert a data source to a sqlalchemy parameters
Expand Down
12 changes: 5 additions & 7 deletions piperider_cli/dbtutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from jinja2 import UndefinedError
from rich.console import Console
from rich.table import Table
from ruamel import yaml
from piperider_cli import yaml as pyml

from piperider_cli import load_jinja_template, load_jinja_string_template
from piperider_cli.dbt.list_task import load_manifest, list_resources_unique_id_from_manifest, load_full_manifest
Expand Down Expand Up @@ -689,9 +689,8 @@ def load_dbt_project(path: str):

with open(path, 'r') as fd:
try:
yml = yaml.YAML()
yml.allow_duplicate_keys = True
dbt_project = yml.load(fd)
loader = pyml.allow_duplicate_keys_loader()
dbt_project = loader(fd)

content = {}
for key, val in dbt_project.items():
Expand All @@ -710,9 +709,8 @@ def load_dbt_profile(path):
template = load_jinja_template(path)
profile = None
try:
yml = yaml.YAML()
yml.allow_duplicate_keys = True
profile = yml.load(template.render())
loader = pyml.allow_duplicate_keys_loader()
profile = loader(template.render())
except Exception as e:
raise DbtProfileInvalidError(path, e)
if profile is None:
Expand Down
13 changes: 6 additions & 7 deletions piperider_cli/event/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import sentry_sdk
from rich.console import Console
from ruamel import yaml
from piperider_cli import yaml as pyml

from piperider_cli import PIPERIDER_USER_HOME, PIPERIDER_USER_PROFILE, is_executed_manually
from piperider_cli.event.collector import Collector
Expand All @@ -15,7 +15,6 @@
PIPERIDER_FLUSH_EVENTS_WHITELIST = ['init', 'run', 'generate-report', 'compare-reports', 'compare']

_collector = Collector()
_yml = yaml.YAML()
user_profile_lock = threading.Lock()


Expand All @@ -35,7 +34,7 @@ def load_user_profile():
user_profile = _generate_user_profile()
else:
with open(PIPERIDER_USER_PROFILE, 'r') as f:
user_profile = _yml.load(f)
user_profile = pyml.load(f)
if user_profile.get('user_id') is None:
user_profile = _generate_user_profile()

Expand All @@ -46,7 +45,7 @@ def update_user_profile(update_values):
original = load_user_profile()
original.update(update_values)
with open(PIPERIDER_USER_PROFILE, 'w+') as f:
_yml.dump(original, f)
pyml.dump(original, f)
return original


Expand All @@ -55,7 +54,7 @@ def _get_api_key():
config_file = os.path.abspath(os.path.join(os.path.dirname(data.__file__), 'CONFIG'))
try:
with open(config_file) as fh:
config = _yml.load(fh)
config = pyml.load(fh)
return config.get('event_api_key')
except Exception:
return None
Expand All @@ -71,7 +70,7 @@ def _generate_user_profile():

user_id = uuid.uuid4().hex
with open(PIPERIDER_USER_PROFILE, 'w+') as f:
_yml.dump({'user_id': user_id, 'anonymous_tracking': True}, f)
pyml.dump({'user_id': user_id, 'anonymous_tracking': True}, f)
return dict(user_id=user_id, anonymous_tracking=True)


Expand Down Expand Up @@ -106,7 +105,7 @@ def flush_events(command=None):

def log_event(prop, event_type, **kwargs):
with open(PIPERIDER_USER_PROFILE, 'r') as f:
user_profile = _yml.load(f)
user_profile = pyml.load(f)
# TODO: default anonymous_tracking to false if field is not present
tracking = user_profile.get('anonymous_tracking', False)
tracking = tracking and isinstance(tracking, bool)
Expand Down
Loading

0 comments on commit c35b2b7

Please sign in to comment.