From 4d2ee5bf986c62e47ac5c70a9ce145e4ef8c0138 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Tue, 12 Dec 2023 20:07:33 -0500 Subject: [PATCH] Add named expression handling to `find_assigned_value` (#9109) --- .../test/fixtures/flake8_trio/TRIO115.py | 13 +- ...lake8_trio__tests__TRIO115_TRIO115.py.snap | 141 +++++++++--------- .../src/analyze/typing.rs | 82 ++++++---- crates/ruff_python_semantic/src/model.rs | 17 +++ 4 files changed, 147 insertions(+), 106 deletions(-) diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_trio/TRIO115.py b/crates/ruff_linter/resources/test/fixtures/flake8_trio/TRIO115.py index d7466beb0f5d3..764b5c1d6e9f5 100644 --- a/crates/ruff_linter/resources/test/fixtures/flake8_trio/TRIO115.py +++ b/crates/ruff_linter/resources/test/fixtures/flake8_trio/TRIO115.py @@ -29,8 +29,8 @@ async def func(): trio.sleep(e) # TRIO115 m_x, m_y = 0 - trio.sleep(m_y) # TRIO115 - trio.sleep(m_x) # TRIO115 + trio.sleep(m_y) # OK + trio.sleep(m_x) # OK m_a = m_b = 0 trio.sleep(m_a) # TRIO115 @@ -43,6 +43,8 @@ async def func(): def func(): + import trio + trio.run(trio.sleep(0)) # TRIO115 @@ -55,3 +57,10 @@ def func(): async def func(): await sleep(seconds=0) # TRIO115 + + +def func(): + import trio + + if (walrus := 0) == 0: + trio.sleep(walrus) # TRIO115 diff --git a/crates/ruff_linter/src/rules/flake8_trio/snapshots/ruff_linter__rules__flake8_trio__tests__TRIO115_TRIO115.py.snap b/crates/ruff_linter/src/rules/flake8_trio/snapshots/ruff_linter__rules__flake8_trio__tests__TRIO115_TRIO115.py.snap index 1ade9f757bbaa..7710be928504a 100644 --- a/crates/ruff_linter/src/rules/flake8_trio/snapshots/ruff_linter__rules__flake8_trio__tests__TRIO115_TRIO115.py.snap +++ b/crates/ruff_linter/src/rules/flake8_trio/snapshots/ruff_linter__rules__flake8_trio__tests__TRIO115_TRIO115.py.snap @@ -143,47 +143,7 @@ TRIO115.py:29:5: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.s 29 |+ trio.lowlevel.checkpoint() # TRIO115 30 30 | 31 31 | m_x, m_y = 0 -32 32 | trio.sleep(m_y) # TRIO115 - -TRIO115.py:32:5: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)` - | -31 | m_x, m_y = 0 -32 | trio.sleep(m_y) # TRIO115 - | ^^^^^^^^^^^^^^^ TRIO115 -33 | trio.sleep(m_x) # TRIO115 - | - = help: Replace with `trio.lowlevel.checkpoint()` - -ℹ Safe fix -29 29 | trio.sleep(e) # TRIO115 -30 30 | -31 31 | m_x, m_y = 0 -32 |- trio.sleep(m_y) # TRIO115 - 32 |+ trio.lowlevel.checkpoint() # TRIO115 -33 33 | trio.sleep(m_x) # TRIO115 -34 34 | -35 35 | m_a = m_b = 0 - -TRIO115.py:33:5: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)` - | -31 | m_x, m_y = 0 -32 | trio.sleep(m_y) # TRIO115 -33 | trio.sleep(m_x) # TRIO115 - | ^^^^^^^^^^^^^^^ TRIO115 -34 | -35 | m_a = m_b = 0 - | - = help: Replace with `trio.lowlevel.checkpoint()` - -ℹ Safe fix -30 30 | -31 31 | m_x, m_y = 0 -32 32 | trio.sleep(m_y) # TRIO115 -33 |- trio.sleep(m_x) # TRIO115 - 33 |+ trio.lowlevel.checkpoint() # TRIO115 -34 34 | -35 35 | m_a = m_b = 0 -36 36 | trio.sleep(m_a) # TRIO115 +32 32 | trio.sleep(m_y) # OK TRIO115.py:36:5: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)` | @@ -195,7 +155,7 @@ TRIO115.py:36:5: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.s = help: Replace with `trio.lowlevel.checkpoint()` ℹ Safe fix -33 33 | trio.sleep(m_x) # TRIO115 +33 33 | trio.sleep(m_x) # OK 34 34 | 35 35 | m_a = m_b = 0 36 |- trio.sleep(m_a) # TRIO115 @@ -264,51 +224,88 @@ TRIO115.py:42:5: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.s 44 44 | 45 45 | def func(): -TRIO115.py:53:5: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)` +TRIO115.py:48:14: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)` | -52 | def func(): -53 | sleep(0) # TRIO115 - | ^^^^^^^^ TRIO115 +46 | import trio +47 | +48 | trio.run(trio.sleep(0)) # TRIO115 + | ^^^^^^^^^^^^^ TRIO115 | = help: Replace with `trio.lowlevel.checkpoint()` ℹ Safe fix -46 46 | trio.run(trio.sleep(0)) # TRIO115 +45 45 | def func(): +46 46 | import trio 47 47 | -48 48 | -49 |-from trio import Event, sleep - 49 |+from trio import Event, sleep, lowlevel +48 |- trio.run(trio.sleep(0)) # TRIO115 + 48 |+ trio.run(trio.lowlevel.checkpoint()) # TRIO115 +49 49 | +50 50 | +51 51 | from trio import Event, sleep + +TRIO115.py:55:5: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)` + | +54 | def func(): +55 | sleep(0) # TRIO115 + | ^^^^^^^^ TRIO115 + | + = help: Replace with `trio.lowlevel.checkpoint()` + +ℹ Safe fix +48 48 | trio.run(trio.sleep(0)) # TRIO115 +49 49 | 50 50 | -51 51 | -52 52 | def func(): -53 |- sleep(0) # TRIO115 - 53 |+ lowlevel.checkpoint() # TRIO115 -54 54 | -55 55 | -56 56 | async def func(): +51 |-from trio import Event, sleep + 51 |+from trio import Event, sleep, lowlevel +52 52 | +53 53 | +54 54 | def func(): +55 |- sleep(0) # TRIO115 + 55 |+ lowlevel.checkpoint() # TRIO115 +56 56 | +57 57 | +58 58 | async def func(): -TRIO115.py:57:11: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)` +TRIO115.py:59:11: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)` | -56 | async def func(): -57 | await sleep(seconds=0) # TRIO115 +58 | async def func(): +59 | await sleep(seconds=0) # TRIO115 | ^^^^^^^^^^^^^^^^ TRIO115 | = help: Replace with `trio.lowlevel.checkpoint()` ℹ Safe fix -46 46 | trio.run(trio.sleep(0)) # TRIO115 -47 47 | -48 48 | -49 |-from trio import Event, sleep - 49 |+from trio import Event, sleep, lowlevel +48 48 | trio.run(trio.sleep(0)) # TRIO115 +49 49 | 50 50 | -51 51 | -52 52 | def func(): +51 |-from trio import Event, sleep + 51 |+from trio import Event, sleep, lowlevel +52 52 | +53 53 | +54 54 | def func(): -------------------------------------------------------------------------------- -54 54 | -55 55 | -56 56 | async def func(): -57 |- await sleep(seconds=0) # TRIO115 - 57 |+ await lowlevel.checkpoint() # TRIO115 +56 56 | +57 57 | +58 58 | async def func(): +59 |- await sleep(seconds=0) # TRIO115 + 59 |+ await lowlevel.checkpoint() # TRIO115 +60 60 | +61 61 | +62 62 | def func(): + +TRIO115.py:66:9: TRIO115 [*] Use `trio.lowlevel.checkpoint()` instead of `trio.sleep(0)` + | +65 | if (walrus := 0) == 0: +66 | trio.sleep(walrus) # TRIO115 + | ^^^^^^^^^^^^^^^^^^ TRIO115 + | + = help: Replace with `trio.lowlevel.checkpoint()` + +ℹ Safe fix +63 63 | import trio +64 64 | +65 65 | if (walrus := 0) == 0: +66 |- trio.sleep(walrus) # TRIO115 + 66 |+ trio.lowlevel.checkpoint() # TRIO115 diff --git a/crates/ruff_python_semantic/src/analyze/typing.rs b/crates/ruff_python_semantic/src/analyze/typing.rs index 4ff2e27e3221c..2dd7f1003e398 100644 --- a/crates/ruff_python_semantic/src/analyze/typing.rs +++ b/crates/ruff_python_semantic/src/analyze/typing.rs @@ -582,42 +582,64 @@ pub fn resolve_assignment<'a>( pub fn find_assigned_value<'a>(symbol: &str, semantic: &'a SemanticModel<'a>) -> Option<&'a Expr> { let binding_id = semantic.lookup_symbol(symbol)?; let binding = semantic.binding(binding_id); - if binding.kind.is_assignment() || binding.kind.is_named_expr_assignment() { - let parent_id = binding.source?; - let parent = semantic.statement(parent_id); - match parent { - Stmt::Assign(ast::StmtAssign { value, targets, .. }) => match value.as_ref() { - Expr::Tuple(ast::ExprTuple { elts, .. }) - | Expr::List(ast::ExprList { elts, .. }) => { + match binding.kind { + // Ex) `x := 1` + BindingKind::NamedExprAssignment => { + let parent_id = binding.source?; + let parent = semantic + .expressions(parent_id) + .find_map(|expr| expr.as_named_expr_expr()); + if let Some(ast::ExprNamedExpr { target, value, .. }) = parent { + return match_value(symbol, target.as_ref(), value.as_ref()); + } + } + // Ex) `x = 1` + BindingKind::Assignment => { + let parent_id = binding.source?; + let parent = semantic.statement(parent_id); + match parent { + Stmt::Assign(ast::StmtAssign { value, targets, .. }) => { if let Some(target) = targets.iter().find(|target| defines(symbol, target)) { - return match target { - Expr::Tuple(ast::ExprTuple { - elts: target_elts, .. - }) - | Expr::List(ast::ExprList { - elts: target_elts, .. - }) - | Expr::Set(ast::ExprSet { - elts: target_elts, .. - }) => get_value_by_id(symbol, target_elts, elts), - _ => Some(value.as_ref()), - }; + return match_value(symbol, target, value.as_ref()); } } - _ => return Some(value.as_ref()), - }, - Stmt::AnnAssign(ast::StmtAnnAssign { - value: Some(value), .. - }) => { - return Some(value.as_ref()); + Stmt::AnnAssign(ast::StmtAnnAssign { + value: Some(value), + target, + .. + }) => { + return match_value(symbol, target, value.as_ref()); + } + _ => {} } - Stmt::AugAssign(_) => return None, - _ => return None, } + _ => {} } None } +/// Given a target and value, find the value that's assigned to the given symbol. +fn match_value<'a>(symbol: &str, target: &Expr, value: &'a Expr) -> Option<&'a Expr> { + match target { + Expr::Name(ast::ExprName { id, .. }) if id.as_str() == symbol => Some(value), + Expr::Tuple(ast::ExprTuple { elts, .. }) | Expr::List(ast::ExprList { elts, .. }) => { + match value { + Expr::Tuple(ast::ExprTuple { + elts: value_elts, .. + }) + | Expr::List(ast::ExprList { + elts: value_elts, .. + }) + | Expr::Set(ast::ExprSet { + elts: value_elts, .. + }) => get_value_by_id(symbol, elts, value_elts), + _ => None, + } + } + _ => None, + } +} + /// Returns `true` if the [`Expr`] defines the symbol. fn defines(symbol: &str, expr: &Expr) -> bool { match expr { @@ -629,11 +651,7 @@ fn defines(symbol: &str, expr: &Expr) -> bool { } } -fn get_value_by_id<'a>( - target_id: &str, - targets: &'a [Expr], - values: &'a [Expr], -) -> Option<&'a Expr> { +fn get_value_by_id<'a>(target_id: &str, targets: &[Expr], values: &'a [Expr]) -> Option<&'a Expr> { for (target, value) in targets.iter().zip(values.iter()) { match target { Expr::Tuple(ast::ExprTuple { diff --git a/crates/ruff_python_semantic/src/model.rs b/crates/ruff_python_semantic/src/model.rs index 9cb3cebaa07ff..82221f0f85dc8 100644 --- a/crates/ruff_python_semantic/src/model.rs +++ b/crates/ruff_python_semantic/src/model.rs @@ -1005,6 +1005,23 @@ impl<'a> SemanticModel<'a> { .nth(1) } + /// Return the [`Expr`] corresponding to the given [`NodeId`]. + #[inline] + pub fn expression(&self, node_id: NodeId) -> &'a Expr { + self.nodes + .ancestor_ids(node_id) + .find_map(|id| self.nodes[id].as_expression()) + .expect("No expression found") + } + + /// Returns an [`Iterator`] over the expressions, starting from the given [`NodeId`]. + /// through to any parents. + pub fn expressions(&self, node_id: NodeId) -> impl Iterator + '_ { + self.nodes + .ancestor_ids(node_id) + .filter_map(move |id| self.nodes[id].as_expression()) + } + /// Set the [`Globals`] for the current [`Scope`]. pub fn set_globals(&mut self, globals: Globals<'a>) { // If any global bindings don't already exist in the global scope, add them.