diff --git a/nds/nds_validate.py b/nds/nds_validate.py index ea478d0..e78152d 100644 --- a/nds/nds_validate.py +++ b/nds/nds_validate.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- # -# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -43,7 +43,7 @@ from pyspark.sql.types import * from pyspark.sql.functions import col -from nds_power import gen_sql_from_stream +from nds_power import gen_sql_from_stream, get_query_subset def compare_results(spark_session: SparkSession, input1: str, @@ -335,8 +335,17 @@ def update_summary(prefix, unmatch_queries): ' checks when the input data is float for some queries.') parser.add_argument('--json_summary_folder', help='path of a folder that contains json summary file for each query.') + parser.add_argument('--sub_queries', + type=lambda s: [x.strip() for x in s.split(',')], + help='comma separated list of queries to compare. If not specified, all queries ' + + 'in the stream file will be compared. e.g. "query1,query2,query3". Note, use ' + + '"_part1" and "_part2" suffix for the following query names: ' + + 'query14, query23, query24, query39. e.g. query14_part1, query39_part2') args = parser.parse_args() query_dict = gen_sql_from_stream(args.query_stream_file) + # if set sub_queries, only compare the specified queries + if args.sub_queries: + query_dict = get_query_subset(query_dict, args.sub_queries) session_builder = SparkSession.builder.appName("Validate Query Output").getOrCreate() unmatch_queries = iterate_queries(session_builder, args.input1,