Skip to content

Commit

Permalink
Let multiple CLI args update the same dict-valued option in the Rust …
Browse files Browse the repository at this point in the history
…options parser (#20735)

This works in Python, and now works in the Rust options parser too.

The issue was noticed by @huonw while reviewing #20698.
  • Loading branch information
benjyw authored Apr 2, 2024
1 parent 5374772 commit c90f798
Show file tree
Hide file tree
Showing 10 changed files with 130 additions and 56 deletions.
23 changes: 14 additions & 9 deletions src/rust/engine/options/src/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,17 +228,22 @@ impl OptionsSource for Args {
self.get_list::<String>(id)
}

fn get_dict(&self, id: &OptionId) -> Result<Option<DictEdit>, String> {
// We iterate in reverse so that the rightmost arg wins in case an option
// is specified multiple times.
for arg in self.args.iter().rev() {
fn get_dict(&self, id: &OptionId) -> Result<Option<Vec<DictEdit>>, String> {
let mut edits = vec![];
for arg in self.args.iter() {
if arg.matches(id) {
return expand_to_dict(arg.value.clone().ok_or_else(|| {
format!("Expected list option {} to have a value.", self.display(id))
})?)
.map_err(|e| e.render(&arg.flag));
let value = arg.value.clone().ok_or_else(|| {
format!("Expected dict option {} to have a value.", self.display(id))
})?;
if let Some(es) = expand_to_dict(value).map_err(|e| e.render(&arg.flag))? {
edits.extend(es);
}
}
}
Ok(None)
if edits.is_empty() {
Ok(None)
} else {
Ok(Some(edits))
}
}
}
64 changes: 52 additions & 12 deletions src/rust/engine/options/src/args_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,22 @@ fn test_list_fromfile() {
}],
"fromfile.txt",
);
do_test(
"+[-42]",
&[ListEdit {
action: ListEditAction::Add,
items: vec![-42],
}],
"fromfile.txt",
);
do_test(
"[-42]",
&[ListEdit {
action: ListEditAction::Replace,
items: vec![-42],
}],
"fromfile.txt",
);
do_test(
"[10, 12]",
&[ListEdit {
Expand All @@ -258,20 +274,31 @@ fn test_list_fromfile() {
#[test]
fn test_dict_fromfile() {
fn do_test(content: &str, filename: &str) {
let expected = DictEdit {
action: DictEditAction::Replace,
items: hashmap! {
"FOO".to_string() => Val::Dict(hashmap! {
"BAR".to_string() => Val::Float(3.14),
"BAZ".to_string() => Val::Dict(hashmap! {
"QUX".to_string() => Val::Bool(true),
"QUUX".to_string() => Val::List(vec![ Val::Int(1), Val::Int(2)])
})
}),},
};
let expected = vec![
DictEdit {
action: DictEditAction::Replace,
items: hashmap! {
"FOO".to_string() => Val::Dict(hashmap! {
"BAR".to_string() => Val::Float(3.14),
"BAZ".to_string() => Val::Dict(hashmap! {
"QUX".to_string() => Val::Bool(true),
"QUUX".to_string() => Val::List(vec![ Val::Int(1), Val::Int(2)])
})
}),},
},
DictEdit {
action: DictEditAction::Add,
items: hashmap! {
"KEY".to_string() => Val::String("VALUE".to_string()),
},
},
];

let (_tmpdir, fromfile_path) = write_fromfile(filename, content);
let args = Args::new(vec![format!("--foo=@{}", &fromfile_path.display())]);
let args = Args::new(vec![
format!("--foo=@{}", &fromfile_path.display()),
"--foo=+{'KEY':'VALUE'}".to_string(),
]);
let actual = args.get_dict(&option_id!("foo")).unwrap().unwrap();
assert_eq!(expected, actual)
}
Expand All @@ -296,6 +323,19 @@ fn test_dict_fromfile() {
"#,
"fromfile.yaml",
);

// Test adding, rather than replacing, from a raw text fromfile.
let expected_add = vec![DictEdit {
action: DictEditAction::Add,
items: hashmap! {"FOO".to_string() => Val::Int(42)},
}];

let (_tmpdir, fromfile_path) = write_fromfile("fromfile.txt", "+{'FOO':42}");
let args = Args::new(vec![format!("--foo=@{}", &fromfile_path.display())]);
assert_eq!(
expected_add,
args.get_dict(&option_id!("foo")).unwrap().unwrap()
)
}

#[test]
Expand Down
10 changes: 5 additions & 5 deletions src/rust/engine/options/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,24 +414,24 @@ impl OptionsSource for Config {
self.get_list::<String>(id)
}

fn get_dict(&self, id: &OptionId) -> Result<Option<DictEdit>, String> {
fn get_dict(&self, id: &OptionId) -> Result<Option<Vec<DictEdit>>, String> {
if let Some(table) = self.value.get(id.scope.name()) {
let option_name = Self::option_name(id);
if let Some(value) = table.get(&option_name) {
match value {
Value::Table(sub_table) => {
if let Some(add) = sub_table.get("add") {
if sub_table.len() == 1 && add.is_table() {
return Ok(Some(DictEdit {
return Ok(Some(vec![DictEdit {
action: DictEditAction::Add,
items: toml_table_to_dict(add),
}));
}]));
}
}
return Ok(Some(DictEdit {
return Ok(Some(vec![DictEdit {
action: DictEditAction::Replace,
items: toml_table_to_dict(value),
}));
}]));
}
Value::String(v) => {
return expand_to_dict(v.to_owned())
Expand Down
8 changes: 4 additions & 4 deletions src/rust/engine/options/src/config_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,13 @@ fn test_interpolate_config() {
);

assert_eq!(
DictEdit {
vec![DictEdit {
action: DictEditAction::Replace,
items: HashMap::from([
("fruit".to_string(), Val::String("strawberry".to_string())),
("spice".to_string(), Val::String("black pepper".to_string()))
])
},
}],
conf.get_dict(&option_id!(["groceries"], "inline_table"))
.unwrap()
.unwrap()
Expand Down Expand Up @@ -230,7 +230,7 @@ fn test_list_fromfile() {
#[test]
fn test_dict_fromfile() {
fn do_test(content: &str, filename: &str) {
let expected = DictEdit {
let expected = vec![DictEdit {
action: DictEditAction::Replace,
items: hashmap! {
"FOO".to_string() => Val::Dict(hashmap! {
Expand All @@ -240,7 +240,7 @@ fn test_dict_fromfile() {
"QUUX".to_string() => Val::List(vec![ Val::Int(1), Val::Int(2)])
})
}),},
};
}];

let (_tmpdir, fromfile_path) = write_fromfile(filename, content);
let conf = config(format!("[GLOBAL]\nfoo = '@{}'\n", fromfile_path.display()).as_str());
Expand Down
2 changes: 1 addition & 1 deletion src/rust/engine/options/src/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ impl OptionsSource for Env {
self.get_list::<String>(id)
}

fn get_dict(&self, id: &OptionId) -> Result<Option<DictEdit>, String> {
fn get_dict(&self, id: &OptionId) -> Result<Option<Vec<DictEdit>>, String> {
for env_var_name in &Self::env_var_names(id) {
if let Some(value) = self.env.get(env_var_name) {
return expand_to_dict(value.to_owned()).map_err(|e| e.render(self.display(id)));
Expand Down
4 changes: 2 additions & 2 deletions src/rust/engine/options/src/env_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ fn test_list_fromfile() {
#[test]
fn test_dict_fromfile() {
fn do_test(content: &str, filename: &str) {
let expected = DictEdit {
let expected = vec![DictEdit {
action: DictEditAction::Replace,
items: hashmap! {
"FOO".to_string() => Val::Dict(hashmap! {
Expand All @@ -276,7 +276,7 @@ fn test_dict_fromfile() {
"QUUX".to_string() => Val::List(vec![ Val::Int(1), Val::Int(2)])
})
}),},
};
}];

let (_tmpdir, fromfile_path) = write_fromfile(filename, content);
let env = env([(
Expand Down
22 changes: 12 additions & 10 deletions src/rust/engine/options/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ pub(crate) trait OptionsSource {
/// Get the dict option identified by `id` from this source.
/// Errors when this source has an option value for `id` but that value is not a dict.
///
fn get_dict(&self, id: &OptionId) -> Result<Option<DictEdit>, String>;
fn get_dict(&self, id: &OptionId) -> Result<Option<Vec<DictEdit>>, String>;
}

#[derive(Clone, Debug, Ord, PartialOrd, Eq, PartialEq)]
Expand Down Expand Up @@ -227,7 +227,7 @@ pub struct ListOptionValue<T> {

#[derive(Debug)]
pub struct DictOptionValue {
pub derivation: Option<Vec<(Source, DictEdit)>>,
pub derivation: Option<Vec<(Source, Vec<DictEdit>)>>,
// The highest-priority source that provided edits for this value.
pub source: Source,
pub value: HashMap<String, Val>,
Expand Down Expand Up @@ -582,25 +582,27 @@ impl OptionParser {
if self.include_derivation {
let mut derivations = vec![(
Source::Default,
DictEdit {
vec![DictEdit {
action: DictEditAction::Replace,
items: dict.clone(),
},
}],
)];
for (source_type, source) in self.sources.iter() {
if let Some(dict_edit) = source.get_dict(id)? {
derivations.push((source_type.clone(), dict_edit));
if let Some(dict_edits) = source.get_dict(id)? {
derivations.push((source_type.clone(), dict_edits));
}
}
derivation = Some(derivations);
}
let mut highest_priority_source = Source::Default;
for (source_type, source) in self.sources.iter() {
if let Some(dict_edit) = source.get_dict(id)? {
if let Some(dict_edits) = source.get_dict(id)? {
highest_priority_source = source_type.clone();
match dict_edit.action {
DictEditAction::Replace => dict = dict_edit.items,
DictEditAction::Add => dict.extend(dict_edit.items),
for dict_edit in dict_edits {
match dict_edit.action {
DictEditAction::Replace => dict = dict_edit.items,
DictEditAction::Add => dict.extend(dict_edit.items),
}
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/rust/engine/options/src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -435,16 +435,16 @@ pub(crate) fn expand_to_list<T: Parseable>(
}
}

pub(crate) fn expand_to_dict(value: String) -> Result<Option<DictEdit>, ParseError> {
pub(crate) fn expand_to_dict(value: String) -> Result<Option<Vec<DictEdit>>, ParseError> {
let (path_opt, value_opt) = maybe_expand(value)?;
if let Some(value) = value_opt {
if let Some(items) = try_deserialize(&value, path_opt)? {
Ok(Some(DictEdit {
Ok(Some(vec![DictEdit {
action: DictEditAction::Replace,
items,
}))
}]))
} else {
parse_dict(&value).map(Some)
parse_dict(&value).map(|x| Some(vec![x]))
}
} else {
Ok(None)
Expand Down
7 changes: 7 additions & 0 deletions src/rust/engine/options/src/parse_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,13 @@ fn test_expand_fromfile_to_dict() {
"{prefix}{}",
_tmpdir.path().join(filename).display()
))
.map(|x| {
if let Some(des) = x {
des.into_iter().next()
} else {
None
}
})
}

fn do_test(content: &str, expected: &DictEdit, filename: &str) {
Expand Down
38 changes: 29 additions & 9 deletions src/rust/engine/options/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ fn test_parse_dict_options() {

fn check(
expected: HashMap<&str, Val>,
expected_derivation: Vec<(Source, DictEdit)>,
expected_derivation: Vec<(Source, Vec<DictEdit>)>,
args: Vec<&'static str>,
env: Vec<(&'static str, &'static str)>,
config: &'static str,
Expand All @@ -359,18 +359,31 @@ fn test_parse_dict_options() {
});
}

fn replace(items: HashMap<&str, Val>) -> DictEdit {
DictEdit {
fn replace(items: HashMap<&str, Val>) -> Vec<DictEdit> {
vec![DictEdit {
action: DictEditAction::Replace,
items: with_owned_keys(items),
}
}]
}

fn add(items: HashMap<&str, Val>) -> DictEdit {
DictEdit {
fn add(items: HashMap<&str, Val>) -> Vec<DictEdit> {
vec![DictEdit {
action: DictEditAction::Add,
items: with_owned_keys(items),
}
}]
}

fn add2(items0: HashMap<&str, Val>, items1: HashMap<&str, Val>) -> Vec<DictEdit> {
vec![
DictEdit {
action: DictEditAction::Add,
items: with_owned_keys(items0),
},
DictEdit {
action: DictEditAction::Add,
items: with_owned_keys(items1),
},
]
}

let default_derivation = (
Expand All @@ -383,6 +396,7 @@ fn test_parse_dict_options() {
"key1" => Val::Int(1),
"key2" => Val::String("val2".to_string()),
"key3" => Val::Int(3),
"key3a" => Val::String("3a".to_string()),
"key4" => Val::Float(4.0),
"key5" => Val::Bool(true),
"key6" => Val::Int(6),
Expand All @@ -392,9 +406,15 @@ fn test_parse_dict_options() {
(config_source(), add(hashmap! {"key5" => Val::Bool(true)})),
(extra_config_source(), add(hashmap! {"key6" => Val::Int(6)})),
(Source::Env, add(hashmap! {"key4" => Val::Float(4.0)})),
(Source::Flag, add(hashmap! {"key3" => Val::Int(3)})),
(
Source::Flag,
add2(
hashmap! {"key3" => Val::Int(3)},
hashmap! {"key3a" => Val::String("3a".to_string())},
),
),
],
vec!["--scope-foo=+{'key3': 3}"],
vec!["--scope-foo=+{'key3': 3}", "--scope-foo=+{'key3a': '3a'}"],
vec![("PANTS_SCOPE_FOO", "+{'key4': 4.0}")],
"[scope]\nfoo = \"+{ 'key5': true }\"",
"[scope]\nfoo = \"+{ 'key6': 6 }\"",
Expand Down

0 comments on commit c90f798

Please sign in to comment.