Skip to content

Commit

Permalink
Use smarter inversion for comparison checks
Browse files Browse the repository at this point in the history
  • Loading branch information
charliermarsh committed Feb 12, 2023
1 parent 8b35b05 commit 8193689
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 1 deletion.
16 changes: 16 additions & 0 deletions crates/ruff/resources/test/fixtures/flake8_simplify/SIM111.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,19 @@ def f():
if check(x):
return False
return True


def f():
# SIM111
for x in iterable:
if x not in y:
return False
return True


def f():
# SIM111
for x in iterable:
if x > y:
return False
return True
32 changes: 31 additions & 1 deletion crates/ruff/src/rules/flake8_simplify/rules/ast_for.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use ruff_macros::{define_violation, derive_message_formats};
use rustpython_parser::ast::{
Comprehension, Constant, Expr, ExprContext, ExprKind, Location, Stmt, StmtKind, Unaryop,
Cmpop, Comprehension, Constant, Expr, ExprContext, ExprKind, Location, Stmt, StmtKind, Unaryop,
};

use crate::ast::helpers::{create_expr, create_stmt, unparse_stmt};
Expand Down Expand Up @@ -260,6 +260,36 @@ pub fn convert_for_loop_to_any_all(checker: &mut Checker, stmt: &Stmt, sibling:
} = &loop_info.test.node
{
*operand.clone()
} else if let ExprKind::Compare {
left,
ops,
comparators,
} = &loop_info.test.node
{
if ops.len() == 1 && comparators.len() == 1 {
let op = match ops[0] {
Cmpop::Eq => Cmpop::NotEq,
Cmpop::NotEq => Cmpop::Eq,
Cmpop::Lt => Cmpop::GtE,
Cmpop::LtE => Cmpop::Gt,
Cmpop::Gt => Cmpop::LtE,
Cmpop::GtE => Cmpop::Lt,
Cmpop::Is => Cmpop::IsNot,
Cmpop::IsNot => Cmpop::Is,
Cmpop::In => Cmpop::NotIn,
Cmpop::NotIn => Cmpop::In,
};
create_expr(ExprKind::Compare {
left: left.clone(),
ops: vec![op],
comparators: vec![comparators[0].clone()],
})
} else {
create_expr(ExprKind::UnaryOp {
op: Unaryop::Not,
operand: Box::new(loop_info.test.clone()),
})
}
} else {
create_expr(ExprKind::UnaryOp {
op: Unaryop::Not,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
---
source: crates/ruff/src/rules/flake8_simplify/mod.rs
assertion_line: 47
expression: diagnostics
---
- kind:
ConvertLoopToAll:
all: return all(not check(x) for x in iterable)
location:
row: 25
column: 4
end_location:
row: 27
column: 24
fix:
content:
- return all(not check(x) for x in iterable)
location:
row: 25
column: 4
end_location:
row: 28
column: 15
parent: ~
- kind:
ConvertLoopToAll:
all: return all(x.is_empty() for x in iterable)
location:
row: 33
column: 4
end_location:
row: 35
column: 24
fix:
content:
- return all(x.is_empty() for x in iterable)
location:
row: 33
column: 4
end_location:
row: 36
column: 15
parent: ~
- kind:
ConvertLoopToAll:
all: return all(not check(x) for x in iterable)
location:
row: 64
column: 4
end_location:
row: 68
column: 19
fix:
content:
- return all(not check(x) for x in iterable)
location:
row: 64
column: 4
end_location:
row: 68
column: 19
parent: ~
- kind:
ConvertLoopToAll:
all: return all(not check(x) for x in iterable)
location:
row: 83
column: 4
end_location:
row: 87
column: 19
fix:
content:
- return all(not check(x) for x in iterable)
location:
row: 83
column: 4
end_location:
row: 87
column: 19
parent: ~
- kind:
ConvertLoopToAll:
all: return all(not check(x) for x in iterable)
location:
row: 134
column: 4
end_location:
row: 136
column: 24
fix: ~
parent: ~
- kind:
ConvertLoopToAll:
all: return all(not check(x) for x in iterable)
location:
row: 154
column: 4
end_location:
row: 156
column: 24
fix:
content:
- return all(not check(x) for x in iterable)
location:
row: 154
column: 4
end_location:
row: 157
column: 15
parent: ~
- kind:
ConvertLoopToAll:
all: return all(x in y for x in iterable)
location:
row: 162
column: 4
end_location:
row: 164
column: 24
fix:
content:
- return all(x in y for x in iterable)
location:
row: 162
column: 4
end_location:
row: 165
column: 15
parent: ~
- kind:
ConvertLoopToAll:
all: return all(x <= y for x in iterable)
location:
row: 170
column: 4
end_location:
row: 172
column: 24
fix:
content:
- return all(x <= y for x in iterable)
location:
row: 170
column: 4
end_location:
row: 173
column: 15
parent: ~

0 comments on commit 8193689

Please sign in to comment.