-
-
Notifications
You must be signed in to change notification settings - Fork 108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update to Splink v4.0 #3834
Update to Splink v4.0 #3834
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,10 +5,9 @@ | |
model. | ||
""" | ||
|
||
import splink.duckdb.comparison_level_library as cll | ||
import splink.duckdb.comparison_library as cl | ||
import splink.duckdb.comparison_template_library as ctl | ||
from splink.duckdb.blocking_rule_library import block_on | ||
import splink.comparison_level_library as cll | ||
import splink.comparison_library as cl | ||
from splink import block_on | ||
|
||
blocking_rule_1 = "l.report_year = r.report_year and substr(l.plant_name_mphone,1,3) = substr(r.plant_name_mphone,1,3)" | ||
blocking_rule_2 = "l.report_year = r.report_year and substr(l.utility_name_mphone,1,2) = substr(r.utility_name_mphone,1,2) and substr(l.plant_name_mphone,1,2) = substr(r.plant_name_mphone,1,2)" | ||
|
@@ -19,7 +18,7 @@ | |
blocking_rule_7 = "l.report_year = r.report_year and l.capacity_mw = r.capacity_mw and substr(l.plant_name_mphone,1,2) = substr(r.plant_name_mphone,1,2)" | ||
blocking_rule_8 = "l.report_year = r.report_year and l.installation_year = r.installation_year and substr(l.plant_name_mphone,1,2) = substr(r.plant_name_mphone,1,2)" | ||
blocking_rule_9 = "l.report_year = r.report_year and l.construction_year = r.construction_year and substr(l.plant_name_mphone,1,2) = substr(r.plant_name_mphone,1,2)" | ||
blocking_rule_10 = block_on(["report_year", "net_generation_mwh"]) | ||
blocking_rule_10 = block_on("report_year", "net_generation_mwh") | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The function signature changed: `block_on` now takes an arbitrary number of positional arguments instead of a single list. |
||
BLOCKING_RULES = [ | ||
blocking_rule_1, | ||
blocking_rule_2, | ||
|
@@ -33,59 +32,57 @@ | |
blocking_rule_10, | ||
] | ||
|
||
plant_name_comparison = ctl.name_comparison( | ||
plant_name_comparison = cl.NameComparison( | ||
Comment on lines
-36
to
+35
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Many things that were previously functions (e.g. `ctl.name_comparison`) have been replaced with classes (e.g. `cl.NameComparison`). |
||
"plant_name", | ||
damerau_levenshtein_thresholds=[], | ||
jaro_winkler_thresholds=[0.9, 0.8, 0.7], | ||
) | ||
utility_name_comparison = ctl.name_comparison( | ||
utility_name_comparison = cl.NameComparison( | ||
"utility_name", | ||
damerau_levenshtein_thresholds=[], | ||
jaro_winkler_thresholds=[0.9, 0.8, 0.7], | ||
term_frequency_adjustments=True, | ||
) | ||
fuel_type_code_pudl_comparison = cl.exact_match( | ||
"fuel_type_code_pudl", term_frequency_adjustments=True | ||
) | ||
utility_name_comparison.configure(term_frequency_adjustments=True) | ||
fuel_type_code_pudl_comparison = cl.ExactMatch("fuel_type_code_pudl") | ||
fuel_type_code_pudl_comparison.configure(term_frequency_adjustments=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This seems like a funny change to me, but the |
||
|
||
capacity_comparison = { | ||
"output_column_name": "capacity_mw", | ||
"comparison_levels": [ | ||
cll.null_level("capacity_mw"), | ||
cll.percentage_difference_level( | ||
cll.NullLevel("capacity_mw"), | ||
cll.PercentageDifferenceLevel( | ||
"capacity_mw", | ||
0.0 + 1e-4, | ||
), | ||
cll.percentage_difference_level("capacity_mw", 0.05), | ||
cll.percentage_difference_level("capacity_mw", 0.1), | ||
cll.percentage_difference_level("capacity_mw", 0.2), | ||
cll.else_level(), | ||
cll.PercentageDifferenceLevel("capacity_mw", 0.05), | ||
cll.PercentageDifferenceLevel("capacity_mw", 0.1), | ||
cll.PercentageDifferenceLevel("capacity_mw", 0.2), | ||
cll.ElseLevel(), | ||
], | ||
"comparison_description": "0% different vs. 5% different vs. 10% different vs. 20% different vs. anything else", | ||
} | ||
|
||
net_gen_comparison = { | ||
"output_column_name": "net_generation_mwh", | ||
"comparison_levels": [ | ||
cll.null_level("net_generation_mwh"), | ||
cll.percentage_difference_level( | ||
cll.NullLevel("net_generation_mwh"), | ||
cll.PercentageDifferenceLevel( | ||
"net_generation_mwh", 0.0 + 1e-4 | ||
), # could add an exact match level too | ||
cll.percentage_difference_level("net_generation_mwh", 0.01), | ||
cll.percentage_difference_level("net_generation_mwh", 0.1), | ||
cll.percentage_difference_level("net_generation_mwh", 0.2), | ||
cll.else_level(), | ||
cll.PercentageDifferenceLevel("net_generation_mwh", 0.01), | ||
cll.PercentageDifferenceLevel("net_generation_mwh", 0.1), | ||
cll.PercentageDifferenceLevel("net_generation_mwh", 0.2), | ||
cll.ElseLevel(), | ||
], | ||
"comparison_description": "0% different vs. 1% different vs. 10% different vs. 20% different vs. anything else", | ||
} | ||
|
||
|
||
def get_date_comparison(column_name): | ||
"""Get date comparison template for column.""" | ||
return ctl.date_comparison( | ||
return cl.DateOfBirthComparison( | ||
column_name, | ||
damerau_levenshtein_thresholds=[], | ||
datediff_thresholds=[1, 2], | ||
datediff_metrics=["year", "year"], | ||
input_is_string=False, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This |
||
datetime_thresholds=[1, 2], | ||
datetime_metrics=["year", "year"], | ||
Comment on lines
-84
to
+85
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Weirdly they went to calling this |
||
) | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
that links together several thousand EIA and FERC plant records. This trained model is | ||
used to predict matches on the full dataset (see :func:`get_model_predictions`) using a | ||
threshold match probability to predict if records are a match or not. | ||
|
||
The model can return multiple EIA match options for each FERC1 record, so we rank the | ||
matches and choose the one with the highest score. Any matches identified by the model | ||
which are in conflict with our training data are overwritten with the manually | ||
|
@@ -36,7 +37,7 @@ | |
import numpy as np | ||
import pandas as pd | ||
from dagster import Out, graph, op | ||
from splink.duckdb.linker import DuckDBLinker | ||
from splink import DuckDBAPI, Linker, SettingsCreator | ||
|
||
import pudl | ||
from pudl.analysis.ml_tools import experiment_tracking, models | ||
|
@@ -211,29 +212,30 @@ def get_training_data_df(inputs): | |
@op | ||
def get_model_predictions(eia_df, ferc_df, train_df, experiment_tracker): | ||
"""Train splink model and output predicted matches.""" | ||
settings_dict = { | ||
"link_type": "link_only", | ||
"unique_id_column_name": "record_id", | ||
"additional_columns_to_retain": ["plant_id_pudl", "utility_id_pudl"], | ||
"comparisons": COMPARISONS, | ||
"blocking_rules_to_generate_predictions": BLOCKING_RULES, | ||
"retain_matching_columns": True, | ||
"retain_intermediate_calculation_columns": True, | ||
"probability_two_random_records_match": 1 / len(eia_df), | ||
} | ||
linker = DuckDBLinker( | ||
settings = SettingsCreator( | ||
link_type="link_only", | ||
unique_id_column_name="record_id", | ||
additional_columns_to_retain=["plant_id_pudl", "utility_id_pudl"], | ||
comparisons=COMPARISONS, | ||
blocking_rules_to_generate_predictions=BLOCKING_RULES, | ||
retain_matching_columns=True, | ||
retain_intermediate_calculation_columns=True, | ||
probability_two_random_records_match=(1.0 / len(eia_df)), | ||
) | ||
linker = Linker( | ||
[eia_df, ferc_df], | ||
settings=settings, | ||
input_table_aliases=["eia_df", "ferc_df"], | ||
settings_dict=settings_dict, | ||
db_api=DuckDBAPI(), | ||
Comment on lines
+225
to
+229
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The |
||
) | ||
linker.register_table(train_df, "training_labels", overwrite=True) | ||
linker.estimate_u_using_random_sampling(max_pairs=1e7) | ||
linker.estimate_m_from_pairwise_labels("training_labels") | ||
linker.table_management.register_table(train_df, "training_labels", overwrite=True) | ||
linker.training.estimate_u_using_random_sampling(max_pairs=1e7) | ||
linker.training.estimate_m_from_pairwise_labels("training_labels") | ||
threshold_prob = 0.9 | ||
experiment_tracker.execute_logging( | ||
lambda: mlflow.log_params({"threshold match probability": threshold_prob}) | ||
) | ||
preds_df = linker.predict(threshold_match_probability=threshold_prob) | ||
preds_df = linker.inference.predict(threshold_match_probability=threshold_prob) | ||
return preds_df.as_pandas_dataframe() | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a big namespace reorganization in Splink v4: the same modules and functions still exist, but they are now imported from different places.