Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to Splink v4.0 #3834

Merged
merged 3 commits into from
Sep 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions environments/conda-linux-64.lock.yml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

60 changes: 30 additions & 30 deletions environments/conda-lock.yml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions environments/conda-osx-64.lock.yml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions environments/conda-osx-arm64.lock.yml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ dependencies = [
"scikit-learn>=1.5",
"scipy>=1.14",
"Shapely>=2",
"splink>=3.9.14,<4", # Need to update PUDL to use new Splink v4 API. See issue #3735
"splink>=4",
"sphinx>=7.4.4",
"sphinx-autoapi>=3",
"sphinx-issues>=1.2",
Expand Down
55 changes: 26 additions & 29 deletions src/pudl/analysis/record_linkage/eia_ferc1_model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
model.
"""

import splink.duckdb.comparison_level_library as cll
import splink.duckdb.comparison_library as cl
import splink.duckdb.comparison_template_library as ctl
from splink.duckdb.blocking_rule_library import block_on
import splink.comparison_level_library as cll
import splink.comparison_library as cl
from splink import block_on
Comment on lines -8 to +10
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There was a big namespace reorganization in v4. Same modules and functions, but imported from different places.


blocking_rule_1 = "l.report_year = r.report_year and substr(l.plant_name_mphone,1,3) = substr(r.plant_name_mphone,1,3)"
blocking_rule_2 = "l.report_year = r.report_year and substr(l.utility_name_mphone,1,2) = substr(r.utility_name_mphone,1,2) and substr(l.plant_name_mphone,1,2) = substr(r.plant_name_mphone,1,2)"
Expand All @@ -19,7 +18,7 @@
blocking_rule_7 = "l.report_year = r.report_year and l.capacity_mw = r.capacity_mw and substr(l.plant_name_mphone,1,2) = substr(r.plant_name_mphone,1,2)"
blocking_rule_8 = "l.report_year = r.report_year and l.installation_year = r.installation_year and substr(l.plant_name_mphone,1,2) = substr(r.plant_name_mphone,1,2)"
blocking_rule_9 = "l.report_year = r.report_year and l.construction_year = r.construction_year and substr(l.plant_name_mphone,1,2) = substr(r.plant_name_mphone,1,2)"
# Splink v4 signature: block_on() takes column names as positional args
# (the v3 form took a single list argument).
blocking_rule_10 = block_on("report_year", "net_generation_mwh")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change in the function signature. Takes an arbitrary number of positional args now.

BLOCKING_RULES = [
blocking_rule_1,
blocking_rule_2,
Expand All @@ -33,59 +32,57 @@
blocking_rule_10,
]

# Fuzzy comparison of plant names: graded Jaro-Winkler similarity levels,
# no Damerau-Levenshtein levels. (Splink v4 replaced the v3
# ctl.name_comparison() function with the cl.NameComparison class.)
plant_name_comparison = cl.NameComparison(
    "plant_name",
    damerau_levenshtein_thresholds=[],
    jaro_winkler_thresholds=[0.9, 0.8, 0.7],
)
# Same graded fuzzy comparison for utility names.
utility_name_comparison = cl.NameComparison(
    "utility_name",
    damerau_levenshtein_thresholds=[],
    jaro_winkler_thresholds=[0.9, 0.8, 0.7],
)
# In v4, options common to all comparison types (like term frequency
# adjustments) are set via .configure() rather than a constructor kwarg.
utility_name_comparison.configure(term_frequency_adjustments=True)
fuel_type_code_pudl_comparison = cl.ExactMatch("fuel_type_code_pudl")
fuel_type_code_pudl_comparison.configure(term_frequency_adjustments=True)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like a funny change to me, but the .configure() method is used to set attributes which are common to all the different kinds of match classes now.


# Custom graded comparison of plant capacity, from near-exact (0.01%) out to
# 20% relative difference. Uses the Splink v4 comparison-level classes
# (NullLevel / PercentageDifferenceLevel / ElseLevel) in place of the v3
# snake_case factory functions.
capacity_comparison = {
    "output_column_name": "capacity_mw",
    "comparison_levels": [
        cll.NullLevel("capacity_mw"),
        # 1e-4 tolerance stands in for an exact match on floating point values.
        cll.PercentageDifferenceLevel(
            "capacity_mw",
            0.0 + 1e-4,
        ),
        cll.PercentageDifferenceLevel("capacity_mw", 0.05),
        cll.PercentageDifferenceLevel("capacity_mw", 0.1),
        cll.PercentageDifferenceLevel("capacity_mw", 0.2),
        cll.ElseLevel(),
    ],
    "comparison_description": "0% different vs. 5% different vs. 10% different vs. 20% different vs. anything else",
}

# Custom graded comparison of net generation, with a tighter 1% level than the
# capacity comparison since generation values vary more. Splink v4 class-based
# comparison levels.
net_gen_comparison = {
    "output_column_name": "net_generation_mwh",
    "comparison_levels": [
        cll.NullLevel("net_generation_mwh"),
        cll.PercentageDifferenceLevel(
            "net_generation_mwh", 0.0 + 1e-4
        ),  # could add an exact match level too
        cll.PercentageDifferenceLevel("net_generation_mwh", 0.01),
        cll.PercentageDifferenceLevel("net_generation_mwh", 0.1),
        cll.PercentageDifferenceLevel("net_generation_mwh", 0.2),
        cll.ElseLevel(),
    ],
    "comparison_description": "0% different vs. 1% different vs. 10% different vs. 20% different vs. anything else",
}


def get_date_comparison(column_name):
    """Get a graded date comparison for the named column.

    Matches dates at 1-year and 2-year difference levels. Splink v4 replaced
    the v3 ctl.date_comparison() function with the DateOfBirthComparison
    class (despite the name, it is the general-purpose date comparison).

    Args:
        column_name: name of the date column to compare.

    Returns:
        A Splink v4 comparison object for use in a settings dictionary.
    """
    return cl.DateOfBirthComparison(
        column_name,
        # The compared columns hold Date/DateTime values, not date strings.
        # NOTE(review): per the PR discussion this was only validated by a
        # successful ETL run — confirm against the actual column dtype.
        input_is_string=False,
        datetime_thresholds=[1, 2],
        datetime_metrics=["year", "year"],
    )


Expand Down
36 changes: 19 additions & 17 deletions src/pudl/analysis/record_linkage/eia_ferc1_record_linkage.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
that links together several thousand EIA and FERC plant records. This trained model is
used to predict matches on the full dataset (see :func:`get_model_predictions`) using a
threshold match probability to predict if records are a match or not.

The model can return multiple EIA match options for each FERC1 record, so we rank the
matches and choose the one with the highest score. Any matches identified by the model
which are in conflict with our training data are overwritten with the manually
Expand All @@ -36,7 +37,7 @@
import numpy as np
import pandas as pd
from dagster import Out, graph, op
from splink.duckdb.linker import DuckDBLinker
from splink import DuckDBAPI, Linker, SettingsCreator

import pudl
from pudl.analysis.ml_tools import experiment_tracking, models
Expand Down Expand Up @@ -211,29 +212,30 @@ def get_training_data_df(inputs):
@op
def get_model_predictions(eia_df, ferc_df, train_df, experiment_tracker):
    """Train a Splink model and output predicted matches.

    Builds a Splink v4 ``SettingsCreator``/``Linker`` (the v3 ``DuckDBLinker``
    with a settings dict was removed in v4), estimates the u parameters by
    random sampling and the m parameters from the manually labeled training
    pairs, then predicts EIA<->FERC1 record matches above a 0.9 match
    probability threshold.

    Args:
        eia_df: EIA input records, one row per record_id.
        ferc_df: FERC1 input records, one row per record_id.
        train_df: manually labeled pairwise training matches.
        experiment_tracker: mlflow experiment tracker used to log parameters.

    Returns:
        Pandas dataframe of predicted record pairs and their match scores.
    """
    settings = SettingsCreator(
        link_type="link_only",
        unique_id_column_name="record_id",
        additional_columns_to_retain=["plant_id_pudl", "utility_id_pudl"],
        comparisons=COMPARISONS,
        blocking_rules_to_generate_predictions=BLOCKING_RULES,
        retain_matching_columns=True,
        retain_intermediate_calculation_columns=True,
        # Prior that a random EIA/FERC pair matches: each FERC record has
        # roughly one true counterpart among the EIA records.
        probability_two_random_records_match=(1.0 / len(eia_df)),
    )
    # v4 Linker is backend-agnostic; the DuckDB backend is injected via db_api.
    linker = Linker(
        [eia_df, ferc_df],
        settings=settings,
        input_table_aliases=["eia_df", "ferc_df"],
        db_api=DuckDBAPI(),
    )
    # v4 moved these methods onto namespaced accessors
    # (table_management / training / inference).
    linker.table_management.register_table(train_df, "training_labels", overwrite=True)
    linker.training.estimate_u_using_random_sampling(max_pairs=1e7)
    linker.training.estimate_m_from_pairwise_labels("training_labels")
    threshold_prob = 0.9
    experiment_tracker.execute_logging(
        lambda: mlflow.log_params({"threshold match probability": threshold_prob})
    )
    preds_df = linker.inference.predict(threshold_match_probability=threshold_prob)
    return preds_df.as_pandas_dataframe()


Expand Down
Loading