Skip to content

Commit

Permalink
fixup: make timing work for iterables
Browse files Browse the repository at this point in the history
- don't time quick functions that get called loads of times
- time the traversal of an iterable, not the time to a return a value
  • Loading branch information
MatMoore committed Aug 7, 2024
1 parent acb2d7b commit eb2645d
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 12 deletions.
5 changes: 2 additions & 3 deletions ingestion/create_cadet_databases_source/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
get_cadet_manifest,
validate_fqn,
)
from ingestion.utils import report_time
from ingestion.utils import report_time, report_time_of_iterable

logging.basicConfig(level=logging.DEBUG)

Expand All @@ -39,7 +39,7 @@ def create(cls, config_dict, ctx):
config = CreateCadetDatabasesConfig.parse_obj(config_dict)
return cls(config, ctx)

@report_time
@report_time_of_iterable
def get_workunits(self) -> Iterable[MetadataWorkUnit]:
manifest = get_cadet_manifest(self.source_config.manifest_s3_uri)

Expand Down Expand Up @@ -86,7 +86,6 @@ def get_workunits(self) -> Iterable[MetadataWorkUnit]:
extra_properties=None,
)

@report_time
def _get_domains(self, manifest) -> set[str]:
"""Only models are arranged by domain in CaDeT"""
return set(
Expand Down
2 changes: 2 additions & 0 deletions ingestion/ingestion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph

from ingestion.config import ENV, INSTANCE, PLATFORM
from ingestion.utils import report_time

logging.basicConfig()
logging.getLogger().setLevel(logging.DEBUG)


@report_time
def get_cadet_manifest(manifest_s3_uri: str) -> Dict:
try:
s3 = boto3.client("s3")
Expand Down
6 changes: 2 additions & 4 deletions ingestion/justice_data_source/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
)
from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent
from datahub.metadata.schema_classes import (
BrowsePathsV2Class,
ChangeTypeClass,
ChartInfoClass,
CorpGroupInfoClass,
Expand All @@ -39,7 +38,7 @@
)

from ingestion.ingestion_utils import list_datahub_domains
from ingestion.utils import report_time
from ingestion.utils import report_time_of_iterable

from .api_client import JusticeDataAPIClient
from .config import JusticeDataAPIConfig
Expand Down Expand Up @@ -76,7 +75,7 @@ def create(cls, config_dict, ctx):
config = JusticeDataAPIConfig.parse_obj(config_dict)
return cls(ctx, config)

@report_time
@report_time_of_iterable
def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
all_chart_data = self.client.list_all(self.config.exclude_id_list)

Expand Down Expand Up @@ -140,7 +139,6 @@ def _make_dashboard(self, chart_urns):
dashboard_mce = MetadataChangeEvent(proposedSnapshot=dashboard_snapshot)
return dashboard_mce

@report_time
def _make_chart(self, chart_data) -> MetadataChangeEvent:
chart_urn = builder.make_chart_urn(self.platform_name, chart_data["id"])
chart_snapshot = ChartSnapshot(
Expand Down
1 change: 1 addition & 0 deletions ingestion/transformers/assign_cadet_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def handle_end_of_stream(

return mcps

@report_time
def _get_table_database_mappings(self, manifest) -> Dict[str, str]:
mappings = {}
for node in manifest["nodes"]:
Expand Down
34 changes: 29 additions & 5 deletions ingestion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def report_time(func):
def wrapped_func(*args, **kwargs):
arg_types = [type(arg) for arg in args]
stopwatch = Stopwatch(
function={func.__name__}, arg_types=arg_types, kwargs=kwargs
function=func.__name__, arg_types=arg_types, kwargs=kwargs
)

stopwatch.start()
Expand All @@ -28,6 +28,30 @@ def wrapped_func(*args, **kwargs):
return wrapped_func


def report_time_of_iterable(func):
"""
Decorator to report the total time of an iterable
"""

def wrapped_func(*args, **kwargs):
arg_types = [type(arg) for arg in args]
stopwatch = Stopwatch(
function=func.__name__, arg_types=arg_types, kwargs=kwargs
)

stopwatch.start()

r = func(*args, **kwargs)
yield from r

stopwatch.stop()
stopwatch.report()

return r

return wrapped_func


class Stopwatch:
"""
Wrapper around the time module for timing code execution
Expand All @@ -38,7 +62,8 @@ def __init__(self, **meta):
self.start_time = None
self.stop_time = None
self.elapsed = 0
self.meta_string = ", ".join(f"{k} = {v}" for k, v in meta.items())
joined_meta = ", ".join(f"{k}={v}" for k, v in meta.items())
self.prefix = f"{joined_meta}, " if joined_meta else ""

def start(self):
self.time = time.time()
Expand All @@ -54,10 +79,9 @@ def stop(self):
self.stop_time = now
self.elapsed += elapsed

def report(self, **meta):
prefix = f"{self.meta_string}, " if self.meta_string else ""
def report(self):
logging.info(
f"{prefix}"
f"{self.prefix}"
f"start_time={time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime(self.start_time))}, "
f"end_time={time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime(self.stop_time))}, "
f"elapsed_time={str(timedelta(seconds=self.elapsed))}"
Expand Down
64 changes: 64 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import logging
import re

from ingestion.utils import Stopwatch, report_time, report_time_of_iterable

REPORT_REGEX = re.compile(
r"start_time=\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}, "
r"end_time=\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}, "
r"elapsed_time=0:00:\d\d",
)


def test_stopwatch_generates_a_report(caplog):
caplog.set_level(logging.INFO)
s = Stopwatch()
s.start()
s.stop()
s.report()

messages = [r.message for r in caplog.records]
assert len(messages) == 1
assert re.match(
REPORT_REGEX,
messages[0],
)


def test_report_time_generates_a_report(caplog):
caplog.set_level(logging.INFO)

@report_time
def foo():
return 1 + 1

assert foo() == 2

messages = [r.message for r in caplog.records]
assert len(messages) == 1
assert re.search(
REPORT_REGEX,
messages[0],
)
assert "function=foo, " in messages[0]


def test_report_time_of_iterator(caplog):
caplog.set_level(logging.INFO)

@report_time_of_iterable
def foo():
yield 1
yield 2

generator = foo()
values = list(generator)
assert values == [1, 2]

messages = [r.message for r in caplog.records]
assert len(messages) == 1
assert re.search(
REPORT_REGEX,
messages[0],
)
assert "function=foo, " in messages[0]

0 comments on commit eb2645d

Please sign in to comment.