From 1b75ee913596091293973cd9e0deda61c1e9f918 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Mon, 13 Nov 2023 12:33:42 -0500 Subject: [PATCH] Use a visitor --- crates/ruff_python_ast/src/visitor.rs | 5 +- .../src/visitor/transformer.rs | 732 ++++++++ .../ruff_python_formatter/tests/fixtures.rs | 15 +- .../tests/normalized_ast.rs | 1544 ----------------- .../ruff_python_formatter/tests/normalizer.rs | 83 + 5 files changed, 828 insertions(+), 1551 deletions(-) create mode 100644 crates/ruff_python_ast/src/visitor/transformer.rs delete mode 100644 crates/ruff_python_formatter/tests/normalized_ast.rs create mode 100644 crates/ruff_python_formatter/tests/normalizer.rs diff --git a/crates/ruff_python_ast/src/visitor.rs b/crates/ruff_python_ast/src/visitor.rs index 8084f030c8f55..f740044fb5ffc 100644 --- a/crates/ruff_python_ast/src/visitor.rs +++ b/crates/ruff_python_ast/src/visitor.rs @@ -1,6 +1,7 @@ //! AST visitor trait and walk functions. pub mod preorder; +pub mod transformer; use crate::{ self as ast, Alias, Arguments, BoolOp, CmpOp, Comprehension, Decorator, ElifElseClause, @@ -14,8 +15,10 @@ use crate::{ /// Prefer [`crate::statement_visitor::StatementVisitor`] for visitors that only need to visit /// statements. /// -/// Use the [`PreorderVisitor`](self::preorder::PreorderVisitor) if you want to visit the nodes +/// Use the [`PreorderVisitor`](preorder::PreorderVisitor) if you want to visit the nodes /// in pre-order rather than evaluation order. +/// +/// Use the [`Transformer`](transformer::Transformer) if you want to modify the nodes. 
pub trait Visitor<'a> { fn visit_stmt(&mut self, stmt: &'a Stmt) { walk_stmt(self, stmt); diff --git a/crates/ruff_python_ast/src/visitor/transformer.rs b/crates/ruff_python_ast/src/visitor/transformer.rs new file mode 100644 index 0000000000000..b90ab0a1b61e9 --- /dev/null +++ b/crates/ruff_python_ast/src/visitor/transformer.rs @@ -0,0 +1,732 @@ +use crate::{ + self as ast, Alias, Arguments, BoolOp, CmpOp, Comprehension, Decorator, ElifElseClause, + ExceptHandler, Expr, ExprContext, Keyword, MatchCase, Operator, Parameter, Parameters, Pattern, + PatternArguments, PatternKeyword, Stmt, TypeParam, TypeParamTypeVar, TypeParams, UnaryOp, + WithItem, +}; + +/// A trait for transforming ASTs. Visits all nodes in the AST recursively in evaluation-order. +pub trait Transformer { + fn visit_stmt(&self, stmt: &mut Stmt) { + walk_stmt(self, stmt); + } + fn visit_annotation(&self, expr: &mut Expr) { + walk_annotation(self, expr); + } + fn visit_decorator(&self, decorator: &mut Decorator) { + walk_decorator(self, decorator); + } + fn visit_expr(&self, expr: &mut Expr) { + walk_expr(self, expr); + } + fn visit_expr_context(&self, expr_context: &mut ExprContext) { + walk_expr_context(self, expr_context); + } + fn visit_bool_op(&self, bool_op: &mut BoolOp) { + walk_bool_op(self, bool_op); + } + fn visit_operator(&self, operator: &mut Operator) { + walk_operator(self, operator); + } + fn visit_unary_op(&self, unary_op: &mut UnaryOp) { + walk_unary_op(self, unary_op); + } + fn visit_cmp_op(&self, cmp_op: &mut CmpOp) { + walk_cmp_op(self, cmp_op); + } + fn visit_comprehension(&self, comprehension: &mut Comprehension) { + walk_comprehension(self, comprehension); + } + fn visit_except_handler(&self, except_handler: &mut ExceptHandler) { + walk_except_handler(self, except_handler); + } + fn visit_format_spec(&self, format_spec: &mut Expr) { + walk_format_spec(self, format_spec); + } + fn visit_arguments(&self, arguments: &mut Arguments) { + walk_arguments(self, arguments); + } + fn 
visit_parameters(&self, parameters: &mut Parameters) { + walk_parameters(self, parameters); + } + fn visit_parameter(&self, parameter: &mut Parameter) { + walk_parameter(self, parameter); + } + fn visit_keyword(&self, keyword: &mut Keyword) { + walk_keyword(self, keyword); + } + fn visit_alias(&self, alias: &mut Alias) { + walk_alias(self, alias); + } + fn visit_with_item(&self, with_item: &mut WithItem) { + walk_with_item(self, with_item); + } + fn visit_type_params(&self, type_params: &mut TypeParams) { + walk_type_params(self, type_params); + } + fn visit_type_param(&self, type_param: &mut TypeParam) { + walk_type_param(self, type_param); + } + fn visit_match_case(&self, match_case: &mut MatchCase) { + walk_match_case(self, match_case); + } + fn visit_pattern(&self, pattern: &mut Pattern) { + walk_pattern(self, pattern); + } + fn visit_pattern_arguments(&self, pattern_arguments: &mut PatternArguments) { + walk_pattern_arguments(self, pattern_arguments); + } + fn visit_pattern_keyword(&self, pattern_keyword: &mut PatternKeyword) { + walk_pattern_keyword(self, pattern_keyword); + } + fn visit_body(&self, body: &mut [Stmt]) { + walk_body(self, body); + } + fn visit_elif_else_clause(&self, elif_else_clause: &mut ElifElseClause) { + walk_elif_else_clause(self, elif_else_clause); + } +} + +pub fn walk_body(visitor: &V, body: &mut [Stmt]) { + for stmt in body { + visitor.visit_stmt(stmt); + } +} + +pub fn walk_elif_else_clause( + visitor: &V, + elif_else_clause: &mut ElifElseClause, +) { + if let Some(test) = &mut elif_else_clause.test { + visitor.visit_expr(test); + } + visitor.visit_body(&mut elif_else_clause.body); +} + +pub fn walk_stmt(visitor: &V, stmt: &mut Stmt) { + match stmt { + Stmt::FunctionDef(ast::StmtFunctionDef { + parameters, + body, + decorator_list, + returns, + type_params, + .. 
+ }) => { + for decorator in decorator_list { + visitor.visit_decorator(decorator); + } + if let Some(type_params) = type_params { + visitor.visit_type_params(type_params); + } + visitor.visit_parameters(parameters); + for expr in returns { + visitor.visit_annotation(expr); + } + visitor.visit_body(body); + } + Stmt::ClassDef(ast::StmtClassDef { + arguments, + body, + decorator_list, + type_params, + .. + }) => { + for decorator in decorator_list { + visitor.visit_decorator(decorator); + } + if let Some(type_params) = type_params { + visitor.visit_type_params(type_params); + } + if let Some(arguments) = arguments { + visitor.visit_arguments(arguments); + } + visitor.visit_body(body); + } + Stmt::Return(ast::StmtReturn { value, range: _ }) => { + if let Some(expr) = value { + visitor.visit_expr(expr); + } + } + Stmt::Delete(ast::StmtDelete { targets, range: _ }) => { + for expr in targets { + visitor.visit_expr(expr); + } + } + Stmt::TypeAlias(ast::StmtTypeAlias { + range: _, + name, + type_params, + value, + }) => { + visitor.visit_expr(value); + if let Some(type_params) = type_params { + visitor.visit_type_params(type_params); + } + visitor.visit_expr(name); + } + Stmt::Assign(ast::StmtAssign { targets, value, .. }) => { + visitor.visit_expr(value); + for expr in targets { + visitor.visit_expr(expr); + } + } + Stmt::AugAssign(ast::StmtAugAssign { + target, + op, + value, + range: _, + }) => { + visitor.visit_expr(value); + visitor.visit_operator(op); + visitor.visit_expr(target); + } + Stmt::AnnAssign(ast::StmtAnnAssign { + target, + annotation, + value, + .. + }) => { + if let Some(expr) = value { + visitor.visit_expr(expr); + } + visitor.visit_annotation(annotation); + visitor.visit_expr(target); + } + Stmt::For(ast::StmtFor { + target, + iter, + body, + orelse, + .. 
+ }) => { + visitor.visit_expr(iter); + visitor.visit_expr(target); + visitor.visit_body(body); + visitor.visit_body(orelse); + } + Stmt::While(ast::StmtWhile { + test, + body, + orelse, + range: _, + }) => { + visitor.visit_expr(test); + visitor.visit_body(body); + visitor.visit_body(orelse); + } + Stmt::If(ast::StmtIf { + test, + body, + elif_else_clauses, + range: _, + }) => { + visitor.visit_expr(test); + visitor.visit_body(body); + for clause in elif_else_clauses { + if let Some(test) = &mut clause.test { + visitor.visit_expr(test); + } + walk_elif_else_clause(visitor, clause); + } + } + Stmt::With(ast::StmtWith { items, body, .. }) => { + for with_item in items { + visitor.visit_with_item(with_item); + } + visitor.visit_body(body); + } + Stmt::Match(ast::StmtMatch { + subject, + cases, + range: _, + }) => { + visitor.visit_expr(subject); + for match_case in cases { + visitor.visit_match_case(match_case); + } + } + Stmt::Raise(ast::StmtRaise { + exc, + cause, + range: _, + }) => { + if let Some(expr) = exc { + visitor.visit_expr(expr); + }; + if let Some(expr) = cause { + visitor.visit_expr(expr); + }; + } + Stmt::Try(ast::StmtTry { + body, + handlers, + orelse, + finalbody, + is_star: _, + range: _, + }) => { + visitor.visit_body(body); + for except_handler in handlers { + visitor.visit_except_handler(except_handler); + } + visitor.visit_body(orelse); + visitor.visit_body(finalbody); + } + Stmt::Assert(ast::StmtAssert { + test, + msg, + range: _, + }) => { + visitor.visit_expr(test); + if let Some(expr) = msg { + visitor.visit_expr(expr); + } + } + Stmt::Import(ast::StmtImport { names, range: _ }) => { + for alias in names { + visitor.visit_alias(alias); + } + } + Stmt::ImportFrom(ast::StmtImportFrom { names, .. 
}) => { + for alias in names { + visitor.visit_alias(alias); + } + } + Stmt::Global(_) => {} + Stmt::Nonlocal(_) => {} + Stmt::Expr(ast::StmtExpr { value, range: _ }) => visitor.visit_expr(value), + Stmt::Pass(_) | Stmt::Break(_) | Stmt::Continue(_) | Stmt::IpyEscapeCommand(_) => {} + } +} + +pub fn walk_annotation(visitor: &V, expr: &mut Expr) { + visitor.visit_expr(expr); +} + +pub fn walk_decorator(visitor: &V, decorator: &mut Decorator) { + visitor.visit_expr(&mut decorator.expression); +} + +pub fn walk_expr(visitor: &V, expr: &mut Expr) { + match expr { + Expr::BoolOp(ast::ExprBoolOp { + op, + values, + range: _, + }) => { + visitor.visit_bool_op(op); + for expr in values { + visitor.visit_expr(expr); + } + } + Expr::NamedExpr(ast::ExprNamedExpr { + target, + value, + range: _, + }) => { + visitor.visit_expr(value); + visitor.visit_expr(target); + } + Expr::BinOp(ast::ExprBinOp { + left, + op, + right, + range: _, + }) => { + visitor.visit_expr(left); + visitor.visit_operator(op); + visitor.visit_expr(right); + } + Expr::UnaryOp(ast::ExprUnaryOp { + op, + operand, + range: _, + }) => { + visitor.visit_unary_op(op); + visitor.visit_expr(operand); + } + Expr::Lambda(ast::ExprLambda { + parameters, + body, + range: _, + }) => { + if let Some(parameters) = parameters { + visitor.visit_parameters(parameters); + } + visitor.visit_expr(body); + } + Expr::IfExp(ast::ExprIfExp { + test, + body, + orelse, + range: _, + }) => { + visitor.visit_expr(test); + visitor.visit_expr(body); + visitor.visit_expr(orelse); + } + Expr::Dict(ast::ExprDict { + keys, + values, + range: _, + }) => { + for expr in keys.iter_mut().flatten() { + visitor.visit_expr(expr); + } + for expr in values { + visitor.visit_expr(expr); + } + } + Expr::Set(ast::ExprSet { elts, range: _ }) => { + for expr in elts { + visitor.visit_expr(expr); + } + } + Expr::ListComp(ast::ExprListComp { + elt, + generators, + range: _, + }) => { + for comprehension in generators { + 
visitor.visit_comprehension(comprehension); + } + visitor.visit_expr(elt); + } + Expr::SetComp(ast::ExprSetComp { + elt, + generators, + range: _, + }) => { + for comprehension in generators { + visitor.visit_comprehension(comprehension); + } + visitor.visit_expr(elt); + } + Expr::DictComp(ast::ExprDictComp { + key, + value, + generators, + range: _, + }) => { + for comprehension in generators { + visitor.visit_comprehension(comprehension); + } + visitor.visit_expr(key); + visitor.visit_expr(value); + } + Expr::GeneratorExp(ast::ExprGeneratorExp { + elt, + generators, + range: _, + }) => { + for comprehension in generators { + visitor.visit_comprehension(comprehension); + } + visitor.visit_expr(elt); + } + Expr::Await(ast::ExprAwait { value, range: _ }) => visitor.visit_expr(value), + Expr::Yield(ast::ExprYield { value, range: _ }) => { + if let Some(expr) = value { + visitor.visit_expr(expr); + } + } + Expr::YieldFrom(ast::ExprYieldFrom { value, range: _ }) => visitor.visit_expr(value), + Expr::Compare(ast::ExprCompare { + left, + ops, + comparators, + range: _, + }) => { + visitor.visit_expr(left); + for cmp_op in ops { + visitor.visit_cmp_op(cmp_op); + } + for expr in comparators { + visitor.visit_expr(expr); + } + } + Expr::Call(ast::ExprCall { + func, + arguments, + range: _, + }) => { + visitor.visit_expr(func); + visitor.visit_arguments(arguments); + } + Expr::FormattedValue(ast::ExprFormattedValue { + value, format_spec, .. + }) => { + visitor.visit_expr(value); + if let Some(expr) = format_spec { + visitor.visit_format_spec(expr); + } + } + Expr::FString(ast::ExprFString { values, .. }) => { + for expr in values { + visitor.visit_expr(expr); + } + } + Expr::StringLiteral(_) + | Expr::BytesLiteral(_) + | Expr::NumberLiteral(_) + | Expr::BooleanLiteral(_) + | Expr::NoneLiteral(_) + | Expr::EllipsisLiteral(_) => {} + Expr::Attribute(ast::ExprAttribute { value, ctx, .. 
}) => { + visitor.visit_expr(value); + visitor.visit_expr_context(ctx); + } + Expr::Subscript(ast::ExprSubscript { + value, + slice, + ctx, + range: _, + }) => { + visitor.visit_expr(value); + visitor.visit_expr(slice); + visitor.visit_expr_context(ctx); + } + Expr::Starred(ast::ExprStarred { + value, + ctx, + range: _, + }) => { + visitor.visit_expr(value); + visitor.visit_expr_context(ctx); + } + Expr::Name(ast::ExprName { ctx, .. }) => { + visitor.visit_expr_context(ctx); + } + Expr::List(ast::ExprList { + elts, + ctx, + range: _, + }) => { + for expr in elts { + visitor.visit_expr(expr); + } + visitor.visit_expr_context(ctx); + } + Expr::Tuple(ast::ExprTuple { + elts, + ctx, + range: _, + }) => { + for expr in elts { + visitor.visit_expr(expr); + } + visitor.visit_expr_context(ctx); + } + Expr::Slice(ast::ExprSlice { + lower, + upper, + step, + range: _, + }) => { + if let Some(expr) = lower { + visitor.visit_expr(expr); + } + if let Some(expr) = upper { + visitor.visit_expr(expr); + } + if let Some(expr) = step { + visitor.visit_expr(expr); + } + } + Expr::IpyEscapeCommand(_) => {} + } +} + +pub fn walk_comprehension(visitor: &V, comprehension: &mut Comprehension) { + visitor.visit_expr(&mut comprehension.iter); + visitor.visit_expr(&mut comprehension.target); + for expr in &mut comprehension.ifs { + visitor.visit_expr(expr); + } +} + +pub fn walk_except_handler( + visitor: &V, + except_handler: &mut ExceptHandler, +) { + match except_handler { + ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler { type_, body, .. }) => { + if let Some(expr) = type_ { + visitor.visit_expr(expr); + } + visitor.visit_body(body); + } + } +} + +pub fn walk_format_spec(visitor: &V, format_spec: &mut Expr) { + visitor.visit_expr(format_spec); +} + +pub fn walk_arguments(visitor: &V, arguments: &mut Arguments) { + // Note that the there might be keywords before the last arg, e.g. 
in + // f(*args, a=2, *args2, **kwargs)`, but we follow Python in evaluating first `args` and then + // `keywords`. See also [Arguments::arguments_source_order`]. + for arg in &mut arguments.args { + visitor.visit_expr(arg); + } + for keyword in &mut arguments.keywords { + visitor.visit_keyword(keyword); + } +} + +pub fn walk_parameters(visitor: &V, parameters: &mut Parameters) { + // Defaults are evaluated before annotations. + for arg in &mut parameters.posonlyargs { + if let Some(default) = &mut arg.default { + visitor.visit_expr(default); + } + } + for arg in &mut parameters.args { + if let Some(default) = &mut arg.default { + visitor.visit_expr(default); + } + } + for arg in &mut parameters.kwonlyargs { + if let Some(default) = &mut arg.default { + visitor.visit_expr(default); + } + } + + for arg in &mut parameters.posonlyargs { + visitor.visit_parameter(&mut arg.parameter); + } + for arg in &mut parameters.args { + visitor.visit_parameter(&mut arg.parameter); + } + if let Some(arg) = &mut parameters.vararg { + visitor.visit_parameter(arg); + } + for arg in &mut parameters.kwonlyargs { + visitor.visit_parameter(&mut arg.parameter); + } + if let Some(arg) = &mut parameters.kwarg { + visitor.visit_parameter(arg); + } +} + +pub fn walk_parameter(visitor: &V, parameter: &mut Parameter) { + if let Some(expr) = &mut parameter.annotation { + visitor.visit_annotation(expr); + } +} + +pub fn walk_keyword(visitor: &V, keyword: &mut Keyword) { + visitor.visit_expr(&mut keyword.value); +} + +pub fn walk_with_item(visitor: &V, with_item: &mut WithItem) { + visitor.visit_expr(&mut with_item.context_expr); + if let Some(expr) = &mut with_item.optional_vars { + visitor.visit_expr(expr); + } +} + +pub fn walk_type_params(visitor: &V, type_params: &mut TypeParams) { + for type_param in &mut type_params.type_params { + visitor.visit_type_param(type_param); + } +} + +pub fn walk_type_param(visitor: &V, type_param: &mut TypeParam) { + match type_param { + 
TypeParam::TypeVar(TypeParamTypeVar { + bound, + name: _, + range: _, + }) => { + if let Some(expr) = bound { + visitor.visit_expr(expr); + } + } + TypeParam::TypeVarTuple(_) | TypeParam::ParamSpec(_) => {} + } +} + +pub fn walk_match_case(visitor: &V, match_case: &mut MatchCase) { + visitor.visit_pattern(&mut match_case.pattern); + if let Some(expr) = &mut match_case.guard { + visitor.visit_expr(expr); + } + visitor.visit_body(&mut match_case.body); +} + +pub fn walk_pattern(visitor: &V, pattern: &mut Pattern) { + match pattern { + Pattern::MatchValue(ast::PatternMatchValue { value, .. }) => { + visitor.visit_expr(value); + } + Pattern::MatchSingleton(_) => {} + Pattern::MatchSequence(ast::PatternMatchSequence { patterns, .. }) => { + for pattern in patterns { + visitor.visit_pattern(pattern); + } + } + Pattern::MatchMapping(ast::PatternMatchMapping { keys, patterns, .. }) => { + for expr in keys { + visitor.visit_expr(expr); + } + for pattern in patterns { + visitor.visit_pattern(pattern); + } + } + Pattern::MatchClass(ast::PatternMatchClass { cls, arguments, .. }) => { + visitor.visit_expr(cls); + visitor.visit_pattern_arguments(arguments); + } + Pattern::MatchStar(_) => {} + Pattern::MatchAs(ast::PatternMatchAs { pattern, .. }) => { + if let Some(pattern) = pattern { + visitor.visit_pattern(pattern); + } + } + Pattern::MatchOr(ast::PatternMatchOr { patterns, .. 
}) => { + for pattern in patterns { + visitor.visit_pattern(pattern); + } + } + } +} + +pub fn walk_pattern_arguments( + visitor: &V, + pattern_arguments: &mut PatternArguments, +) { + for pattern in &mut pattern_arguments.patterns { + visitor.visit_pattern(pattern); + } + for keyword in &mut pattern_arguments.keywords { + visitor.visit_pattern_keyword(keyword); + } +} + +pub fn walk_pattern_keyword( + visitor: &V, + pattern_keyword: &mut PatternKeyword, +) { + visitor.visit_pattern(&mut pattern_keyword.pattern); +} + +#[allow(unused_variables)] +pub fn walk_expr_context(visitor: &V, expr_context: &mut ExprContext) {} + +#[allow(unused_variables)] +pub fn walk_bool_op(visitor: &V, bool_op: &mut BoolOp) {} + +#[allow(unused_variables)] +pub fn walk_operator(visitor: &V, operator: &mut Operator) {} + +#[allow(unused_variables)] +pub fn walk_unary_op(visitor: &V, unary_op: &mut UnaryOp) {} + +#[allow(unused_variables)] +pub fn walk_cmp_op(visitor: &V, cmp_op: &mut CmpOp) {} + +#[allow(unused_variables)] +pub fn walk_alias(visitor: &V, alias: &mut Alias) {} diff --git a/crates/ruff_python_formatter/tests/fixtures.rs b/crates/ruff_python_formatter/tests/fixtures.rs index 5a6ec79c03cba..c3fd2b20707e8 100644 --- a/crates/ruff_python_formatter/tests/fixtures.rs +++ b/crates/ruff_python_formatter/tests/fixtures.rs @@ -5,12 +5,13 @@ use std::{fmt, fs}; use similar::TextDiff; -use normalized_ast::NormalizedMod; +use crate::normalizer::Normalizer; use ruff_formatter::FormatOptions; +use ruff_python_ast::comparable::ComparableMod; use ruff_python_formatter::{format_module_source, PreviewMode, PyFormatOptions}; use ruff_python_parser::{parse, AsMode}; -mod normalized_ast; +mod normalizer; #[test] fn black_compatibility() { @@ -253,22 +254,24 @@ fn ensure_unchanged_ast( let source_type = options.source_type(); // Parse the unformatted code. 
- let unformatted_ast = parse( + let mut unformatted_ast = parse( unformatted_code, source_type.as_mode(), &input_path.to_string_lossy(), ) .expect("Unformatted code to be valid syntax"); - let unformatted_ast = NormalizedMod::from(&unformatted_ast); + Normalizer.visit_module(&mut unformatted_ast); + let unformatted_ast = ComparableMod::from(&unformatted_ast); // Parse the formatted code. - let formatted_ast = parse( + let mut formatted_ast = parse( formatted_code, source_type.as_mode(), &input_path.to_string_lossy(), ) .expect("Formatted code to be valid syntax"); - let formatted_ast = NormalizedMod::from(&formatted_ast); + Normalizer.visit_module(&mut formatted_ast); + let formatted_ast = ComparableMod::from(&formatted_ast); if formatted_ast != unformatted_ast { let diff = TextDiff::from_lines( diff --git a/crates/ruff_python_formatter/tests/normalized_ast.rs b/crates/ruff_python_formatter/tests/normalized_ast.rs deleted file mode 100644 index 1176124c728fc..0000000000000 --- a/crates/ruff_python_formatter/tests/normalized_ast.rs +++ /dev/null @@ -1,1544 +0,0 @@ -//! An equivalent object hierarchy to the `RustPython` AST hierarchy, but with the -//! ability to compare nodes for equality after formatting. -//! -//! Vis-à-vis comparing ASTs, comparing these normalized representations does the following: -//! - Removes all locations from the AST. -//! - Ignores non-abstraction information that we've encoded into the AST, e.g., the difference -//! between `class C: ...` and `class C(): ...`, which is part of our AST but not `CPython`'s. -//! - Normalize strings. The formatter can re-indent docstrings, so we need to compare string -//! contents ignoring whitespace. (Black does the same.) -//! - Ignores nested tuples in deletions. (Black does the same.) 
- -use itertools::Either::{Left, Right}; - -use ruff_python_ast as ast; - -#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)] -enum NormalizedBoolOp { - And, - Or, -} - -impl From for NormalizedBoolOp { - fn from(op: ast::BoolOp) -> Self { - match op { - ast::BoolOp::And => Self::And, - ast::BoolOp::Or => Self::Or, - } - } -} - -#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)] -enum NormalizedOperator { - Add, - Sub, - Mult, - MatMult, - Div, - Mod, - Pow, - LShift, - RShift, - BitOr, - BitXor, - BitAnd, - FloorDiv, -} - -impl From for NormalizedOperator { - fn from(op: ast::Operator) -> Self { - match op { - ast::Operator::Add => Self::Add, - ast::Operator::Sub => Self::Sub, - ast::Operator::Mult => Self::Mult, - ast::Operator::MatMult => Self::MatMult, - ast::Operator::Div => Self::Div, - ast::Operator::Mod => Self::Mod, - ast::Operator::Pow => Self::Pow, - ast::Operator::LShift => Self::LShift, - ast::Operator::RShift => Self::RShift, - ast::Operator::BitOr => Self::BitOr, - ast::Operator::BitXor => Self::BitXor, - ast::Operator::BitAnd => Self::BitAnd, - ast::Operator::FloorDiv => Self::FloorDiv, - } - } -} - -#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)] -enum NormalizedUnaryOp { - Invert, - Not, - UAdd, - USub, -} - -impl From for NormalizedUnaryOp { - fn from(op: ast::UnaryOp) -> Self { - match op { - ast::UnaryOp::Invert => Self::Invert, - ast::UnaryOp::Not => Self::Not, - ast::UnaryOp::UAdd => Self::UAdd, - ast::UnaryOp::USub => Self::USub, - } - } -} - -#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)] -enum NormalizedCmpOp { - Eq, - NotEq, - Lt, - LtE, - Gt, - GtE, - Is, - IsNot, - In, - NotIn, -} - -impl From for NormalizedCmpOp { - fn from(op: ast::CmpOp) -> Self { - match op { - ast::CmpOp::Eq => Self::Eq, - ast::CmpOp::NotEq => Self::NotEq, - ast::CmpOp::Lt => Self::Lt, - ast::CmpOp::LtE => Self::LtE, - ast::CmpOp::Gt => Self::Gt, - ast::CmpOp::GtE => Self::GtE, - ast::CmpOp::Is => Self::Is, - ast::CmpOp::IsNot => Self::IsNot, - 
ast::CmpOp::In => Self::In, - ast::CmpOp::NotIn => Self::NotIn, - } - } -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct NormalizedAlias<'a> { - name: &'a str, - asname: Option<&'a str>, -} - -impl<'a> From<&'a ast::Alias> for NormalizedAlias<'a> { - fn from(alias: &'a ast::Alias) -> Self { - Self { - name: alias.name.as_str(), - asname: alias.asname.as_deref(), - } - } -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct NormalizedWithItem<'a> { - context_expr: NormalizedExpr<'a>, - optional_vars: Option>, -} - -impl<'a> From<&'a ast::WithItem> for NormalizedWithItem<'a> { - fn from(with_item: &'a ast::WithItem) -> Self { - Self { - context_expr: (&with_item.context_expr).into(), - optional_vars: with_item.optional_vars.as_ref().map(Into::into), - } - } -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct NormalizedPatternArguments<'a> { - patterns: Vec>, - keywords: Vec>, -} - -impl<'a> From<&'a ast::PatternArguments> for NormalizedPatternArguments<'a> { - fn from(parameters: &'a ast::PatternArguments) -> Self { - Self { - patterns: parameters.patterns.iter().map(Into::into).collect(), - keywords: parameters.keywords.iter().map(Into::into).collect(), - } - } -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct NormalizedPatternKeyword<'a> { - attr: &'a str, - pattern: NormalizedPattern<'a>, -} - -impl<'a> From<&'a ast::PatternKeyword> for NormalizedPatternKeyword<'a> { - fn from(keyword: &'a ast::PatternKeyword) -> Self { - Self { - attr: keyword.attr.as_str(), - pattern: (&keyword.pattern).into(), - } - } -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct PatternMatchValue<'a> { - value: NormalizedExpr<'a>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct PatternMatchSingleton { - value: NormalizedSingleton, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct PatternMatchSequence<'a> { - patterns: Vec>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct PatternMatchMapping<'a> { - keys: Vec>, - patterns: Vec>, - rest: Option<&'a str>, -} - -#[derive(Debug, 
PartialEq, Eq, Hash)] -struct PatternMatchClass<'a> { - cls: NormalizedExpr<'a>, - arguments: NormalizedPatternArguments<'a>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct PatternMatchStar<'a> { - name: Option<&'a str>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct PatternMatchAs<'a> { - pattern: Option>>, - name: Option<&'a str>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct PatternMatchOr<'a> { - patterns: Vec>, -} - -#[allow(clippy::enum_variant_names)] -#[derive(Debug, PartialEq, Eq, Hash)] -enum NormalizedPattern<'a> { - MatchValue(PatternMatchValue<'a>), - MatchSingleton(PatternMatchSingleton), - MatchSequence(PatternMatchSequence<'a>), - MatchMapping(PatternMatchMapping<'a>), - MatchClass(PatternMatchClass<'a>), - MatchStar(PatternMatchStar<'a>), - MatchAs(PatternMatchAs<'a>), - MatchOr(PatternMatchOr<'a>), -} - -impl<'a> From<&'a ast::Pattern> for NormalizedPattern<'a> { - fn from(pattern: &'a ast::Pattern) -> Self { - match pattern { - ast::Pattern::MatchValue(ast::PatternMatchValue { value, .. }) => { - Self::MatchValue(PatternMatchValue { - value: value.into(), - }) - } - ast::Pattern::MatchSingleton(ast::PatternMatchSingleton { value, .. }) => { - Self::MatchSingleton(PatternMatchSingleton { - value: value.into(), - }) - } - ast::Pattern::MatchSequence(ast::PatternMatchSequence { patterns, .. }) => { - Self::MatchSequence(PatternMatchSequence { - patterns: patterns.iter().map(Into::into).collect(), - }) - } - ast::Pattern::MatchMapping(ast::PatternMatchMapping { - keys, - patterns, - rest, - .. - }) => Self::MatchMapping(PatternMatchMapping { - keys: keys.iter().map(Into::into).collect(), - patterns: patterns.iter().map(Into::into).collect(), - rest: rest.as_deref(), - }), - ast::Pattern::MatchClass(ast::PatternMatchClass { cls, arguments, .. }) => { - Self::MatchClass(PatternMatchClass { - cls: cls.into(), - arguments: arguments.into(), - }) - } - ast::Pattern::MatchStar(ast::PatternMatchStar { name, .. 
}) => { - Self::MatchStar(PatternMatchStar { - name: name.as_deref(), - }) - } - ast::Pattern::MatchAs(ast::PatternMatchAs { pattern, name, .. }) => { - Self::MatchAs(PatternMatchAs { - pattern: pattern.as_ref().map(Into::into), - name: name.as_deref(), - }) - } - ast::Pattern::MatchOr(ast::PatternMatchOr { patterns, .. }) => { - Self::MatchOr(PatternMatchOr { - patterns: patterns.iter().map(Into::into).collect(), - }) - } - } - } -} - -impl<'a> From<&'a Box> for Box> { - fn from(pattern: &'a Box) -> Self { - Box::new((pattern.as_ref()).into()) - } -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct NormalizedMatchCase<'a> { - pattern: NormalizedPattern<'a>, - guard: Option>, - body: Vec>, -} - -impl<'a> From<&'a ast::MatchCase> for NormalizedMatchCase<'a> { - fn from(match_case: &'a ast::MatchCase) -> Self { - Self { - pattern: (&match_case.pattern).into(), - guard: match_case.guard.as_ref().map(Into::into), - body: match_case.body.iter().map(Into::into).collect(), - } - } -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct NormalizedDecorator<'a> { - expression: NormalizedExpr<'a>, -} - -impl<'a> From<&'a ast::Decorator> for NormalizedDecorator<'a> { - fn from(decorator: &'a ast::Decorator) -> Self { - Self { - expression: (&decorator.expression).into(), - } - } -} - -#[derive(Debug, PartialEq, Eq, Hash)] -enum NormalizedSingleton { - None, - True, - False, -} - -impl From<&ast::Singleton> for NormalizedSingleton { - fn from(singleton: &ast::Singleton) -> Self { - match singleton { - ast::Singleton::None => Self::None, - ast::Singleton::True => Self::True, - ast::Singleton::False => Self::False, - } - } -} - -#[derive(Debug, PartialEq, Eq, Hash)] -enum NormalizedNumber<'a> { - Int(&'a ast::Int), - Float(u64), - Complex { real: u64, imag: u64 }, -} - -impl<'a> From<&'a ast::Number> for NormalizedNumber<'a> { - fn from(number: &'a ast::Number) -> Self { - match number { - ast::Number::Int(value) => Self::Int(value), - ast::Number::Float(value) => 
Self::Float(value.to_bits()), - ast::Number::Complex { real, imag } => Self::Complex { - real: real.to_bits(), - imag: imag.to_bits(), - }, - } - } -} - -#[derive(Debug, PartialEq, Eq, Hash, Default)] -struct NormalizedArguments<'a> { - args: Vec>, - keywords: Vec>, -} - -impl<'a> From<&'a ast::Arguments> for NormalizedArguments<'a> { - fn from(arguments: &'a ast::Arguments) -> Self { - Self { - args: arguments.args.iter().map(Into::into).collect(), - keywords: arguments.keywords.iter().map(Into::into).collect(), - } - } -} - -impl<'a> From<&'a Box> for NormalizedArguments<'a> { - fn from(arguments: &'a Box) -> Self { - (arguments.as_ref()).into() - } -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct NormalizedParameters<'a> { - posonlyargs: Vec>, - args: Vec>, - vararg: Option>, - kwonlyargs: Vec>, - kwarg: Option>, -} - -impl<'a> From<&'a ast::Parameters> for NormalizedParameters<'a> { - fn from(parameters: &'a ast::Parameters) -> Self { - Self { - posonlyargs: parameters.posonlyargs.iter().map(Into::into).collect(), - args: parameters.args.iter().map(Into::into).collect(), - vararg: parameters.vararg.as_ref().map(Into::into), - kwonlyargs: parameters.kwonlyargs.iter().map(Into::into).collect(), - kwarg: parameters.kwarg.as_ref().map(Into::into), - } - } -} - -impl<'a> From<&'a Box> for NormalizedParameters<'a> { - fn from(parameters: &'a Box) -> Self { - (parameters.as_ref()).into() - } -} - -impl<'a> From<&'a Box> for NormalizedParameter<'a> { - fn from(arg: &'a Box) -> Self { - (arg.as_ref()).into() - } -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct NormalizedParameter<'a> { - arg: &'a str, - annotation: Option>>, -} - -impl<'a> From<&'a ast::Parameter> for NormalizedParameter<'a> { - fn from(arg: &'a ast::Parameter) -> Self { - Self { - arg: arg.name.as_str(), - annotation: arg.annotation.as_ref().map(Into::into), - } - } -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct NormalizedParameterWithDefault<'a> { - def: NormalizedParameter<'a>, - default: 
Option>, -} - -impl<'a> From<&'a ast::ParameterWithDefault> for NormalizedParameterWithDefault<'a> { - fn from(arg: &'a ast::ParameterWithDefault) -> Self { - Self { - def: (&arg.parameter).into(), - default: arg.default.as_ref().map(Into::into), - } - } -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct NormalizedKeyword<'a> { - arg: Option<&'a str>, - value: NormalizedExpr<'a>, -} - -impl<'a> From<&'a ast::Keyword> for NormalizedKeyword<'a> { - fn from(keyword: &'a ast::Keyword) -> Self { - Self { - arg: keyword.arg.as_ref().map(ast::Identifier::as_str), - value: (&keyword.value).into(), - } - } -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct NormalizedComprehension<'a> { - target: NormalizedExpr<'a>, - iter: NormalizedExpr<'a>, - ifs: Vec>, - is_async: bool, -} - -impl<'a> From<&'a ast::Comprehension> for NormalizedComprehension<'a> { - fn from(comprehension: &'a ast::Comprehension) -> Self { - Self { - target: (&comprehension.target).into(), - iter: (&comprehension.iter).into(), - ifs: comprehension.ifs.iter().map(Into::into).collect(), - is_async: comprehension.is_async, - } - } -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExceptHandlerExceptHandler<'a> { - type_: Option>>, - name: Option<&'a str>, - body: Vec>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -enum NormalizedExceptHandler<'a> { - ExceptHandler(ExceptHandlerExceptHandler<'a>), -} - -impl<'a> From<&'a ast::ExceptHandler> for NormalizedExceptHandler<'a> { - fn from(except_handler: &'a ast::ExceptHandler) -> Self { - let ast::ExceptHandler::ExceptHandler(ast::ExceptHandlerExceptHandler { - type_, - name, - body, - .. 
- }) = except_handler; - Self::ExceptHandler(ExceptHandlerExceptHandler { - type_: type_.as_ref().map(Into::into), - name: name.as_deref(), - body: body.iter().map(Into::into).collect(), - }) - } -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct NormalizedElifElseClause<'a> { - test: Option>, - body: Vec>, -} - -impl<'a> From<&'a ast::ElifElseClause> for NormalizedElifElseClause<'a> { - fn from(elif_else_clause: &'a ast::ElifElseClause) -> Self { - let ast::ElifElseClause { - range: _, - test, - body, - } = elif_else_clause; - Self { - test: test.as_ref().map(Into::into), - body: body.iter().map(Into::into).collect(), - } - } -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprBoolOp<'a> { - op: NormalizedBoolOp, - values: Vec>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprNamedExpr<'a> { - target: Box>, - value: Box>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprBinOp<'a> { - left: Box>, - op: NormalizedOperator, - right: Box>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprUnaryOp<'a> { - op: NormalizedUnaryOp, - operand: Box>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprLambda<'a> { - parameters: Option>, - body: Box>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprIfExp<'a> { - test: Box>, - body: Box>, - orelse: Box>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprDict<'a> { - keys: Vec>>, - values: Vec>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprSet<'a> { - elts: Vec>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprListComp<'a> { - elt: Box>, - generators: Vec>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprSetComp<'a> { - elt: Box>, - generators: Vec>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprDictComp<'a> { - key: Box>, - value: Box>, - generators: Vec>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprGeneratorExp<'a> { - elt: Box>, - generators: Vec>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprAwait<'a> { - value: Box>, -} 
- -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprYield<'a> { - value: Option>>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprYieldFrom<'a> { - value: Box>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprCompare<'a> { - left: Box>, - ops: Vec, - comparators: Vec>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprCall<'a> { - func: Box>, - arguments: NormalizedArguments<'a>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprFormattedValue<'a> { - value: Box>, - debug_text: Option<&'a ast::DebugText>, - conversion: ast::ConversionFlag, - format_spec: Option>>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprFString<'a> { - values: Vec>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -enum NormalizedLiteral<'a> { - None, - Ellipsis, - Bool(&'a bool), - Str(String), - Bytes(&'a [u8]), - Number(NormalizedNumber<'a>), -} - -impl<'a> From> for NormalizedLiteral<'a> { - fn from(literal: ast::LiteralExpressionRef<'a>) -> Self { - match literal { - ast::LiteralExpressionRef::NoneLiteral(_) => Self::None, - ast::LiteralExpressionRef::EllipsisLiteral(_) => Self::Ellipsis, - ast::LiteralExpressionRef::BooleanLiteral(ast::ExprBooleanLiteral { - value, .. - }) => Self::Bool(value), - ast::LiteralExpressionRef::StringLiteral(ast::ExprStringLiteral { value, .. }) => { - Self::Str(normalize(value)) - } - ast::LiteralExpressionRef::BytesLiteral(ast::ExprBytesLiteral { value, .. }) => { - Self::Bytes(value) - } - ast::LiteralExpressionRef::NumberLiteral(ast::ExprNumberLiteral { value, .. 
}) => { - Self::Number(value.into()) - } - } - } -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprStringLiteral { - value: String, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprBytesLiteral<'a> { - value: &'a [u8], -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprNumberLiteral<'a> { - value: NormalizedNumber<'a>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprBoolLiteral<'a> { - value: &'a bool, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprAttribute<'a> { - value: Box>, - attr: &'a str, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprSubscript<'a> { - value: Box>, - slice: Box>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprStarred<'a> { - value: Box>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprName<'a> { - id: &'a str, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprList<'a> { - elts: Vec>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprTuple<'a> { - elts: Vec>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprSlice<'a> { - lower: Option>>, - upper: Option>>, - step: Option>>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct ExprIpyEscapeCommand<'a> { - kind: ast::IpyEscapeKind, - value: &'a str, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -enum NormalizedExpr<'a> { - BoolOp(ExprBoolOp<'a>), - NamedExpr(ExprNamedExpr<'a>), - BinOp(ExprBinOp<'a>), - UnaryOp(ExprUnaryOp<'a>), - Lambda(ExprLambda<'a>), - IfExp(ExprIfExp<'a>), - Dict(ExprDict<'a>), - Set(ExprSet<'a>), - ListComp(ExprListComp<'a>), - SetComp(ExprSetComp<'a>), - DictComp(ExprDictComp<'a>), - GeneratorExp(ExprGeneratorExp<'a>), - Await(ExprAwait<'a>), - Yield(ExprYield<'a>), - YieldFrom(ExprYieldFrom<'a>), - Compare(ExprCompare<'a>), - Call(ExprCall<'a>), - NormalizedValue(ExprFormattedValue<'a>), - FString(ExprFString<'a>), - StringLiteral(ExprStringLiteral), - BytesLiteral(ExprBytesLiteral<'a>), - NumberLiteral(ExprNumberLiteral<'a>), - BoolLiteral(ExprBoolLiteral<'a>), - NoneLiteral, - 
EllispsisLiteral, - Attribute(ExprAttribute<'a>), - Subscript(ExprSubscript<'a>), - Starred(ExprStarred<'a>), - Name(ExprName<'a>), - List(ExprList<'a>), - Tuple(ExprTuple<'a>), - Slice(ExprSlice<'a>), - IpyEscapeCommand(ExprIpyEscapeCommand<'a>), -} - -impl<'a> From<&'a Box> for Box> { - fn from(expr: &'a Box) -> Self { - Box::new((expr.as_ref()).into()) - } -} - -impl<'a> From<&'a Box> for NormalizedExpr<'a> { - fn from(expr: &'a Box) -> Self { - (expr.as_ref()).into() - } -} - -impl<'a> From<&'a ast::Expr> for NormalizedExpr<'a> { - fn from(expr: &'a ast::Expr) -> Self { - match expr { - ast::Expr::BoolOp(ast::ExprBoolOp { - op, - values, - range: _, - }) => Self::BoolOp(ExprBoolOp { - op: (*op).into(), - values: values.iter().map(Into::into).collect(), - }), - ast::Expr::NamedExpr(ast::ExprNamedExpr { - target, - value, - range: _, - }) => Self::NamedExpr(ExprNamedExpr { - target: target.into(), - value: value.into(), - }), - ast::Expr::BinOp(ast::ExprBinOp { - left, - op, - right, - range: _, - }) => Self::BinOp(ExprBinOp { - left: left.into(), - op: (*op).into(), - right: right.into(), - }), - ast::Expr::UnaryOp(ast::ExprUnaryOp { - op, - operand, - range: _, - }) => Self::UnaryOp(ExprUnaryOp { - op: (*op).into(), - operand: operand.into(), - }), - ast::Expr::Lambda(ast::ExprLambda { - parameters, - body, - range: _, - }) => Self::Lambda(ExprLambda { - parameters: parameters.as_ref().map(Into::into), - body: body.into(), - }), - ast::Expr::IfExp(ast::ExprIfExp { - test, - body, - orelse, - range: _, - }) => Self::IfExp(ExprIfExp { - test: test.into(), - body: body.into(), - orelse: orelse.into(), - }), - ast::Expr::Dict(ast::ExprDict { - keys, - values, - range: _, - }) => Self::Dict(ExprDict { - keys: keys - .iter() - .map(|expr| expr.as_ref().map(Into::into)) - .collect(), - values: values.iter().map(Into::into).collect(), - }), - ast::Expr::Set(ast::ExprSet { elts, range: _ }) => Self::Set(ExprSet { - elts: elts.iter().map(Into::into).collect(), - }), - 
ast::Expr::ListComp(ast::ExprListComp { - elt, - generators, - range: _, - }) => Self::ListComp(ExprListComp { - elt: elt.into(), - generators: generators.iter().map(Into::into).collect(), - }), - ast::Expr::SetComp(ast::ExprSetComp { - elt, - generators, - range: _, - }) => Self::SetComp(ExprSetComp { - elt: elt.into(), - generators: generators.iter().map(Into::into).collect(), - }), - ast::Expr::DictComp(ast::ExprDictComp { - key, - value, - generators, - range: _, - }) => Self::DictComp(ExprDictComp { - key: key.into(), - value: value.into(), - generators: generators.iter().map(Into::into).collect(), - }), - ast::Expr::GeneratorExp(ast::ExprGeneratorExp { - elt, - generators, - range: _, - }) => Self::GeneratorExp(ExprGeneratorExp { - elt: elt.into(), - generators: generators.iter().map(Into::into).collect(), - }), - ast::Expr::Await(ast::ExprAwait { value, range: _ }) => Self::Await(ExprAwait { - value: value.into(), - }), - ast::Expr::Yield(ast::ExprYield { value, range: _ }) => Self::Yield(ExprYield { - value: value.as_ref().map(Into::into), - }), - ast::Expr::YieldFrom(ast::ExprYieldFrom { value, range: _ }) => { - Self::YieldFrom(ExprYieldFrom { - value: value.into(), - }) - } - ast::Expr::Compare(ast::ExprCompare { - left, - ops, - comparators, - range: _, - }) => Self::Compare(ExprCompare { - left: left.into(), - ops: ops.iter().copied().map(Into::into).collect(), - comparators: comparators.iter().map(Into::into).collect(), - }), - ast::Expr::Call(ast::ExprCall { - func, - arguments, - range: _, - }) => Self::Call(ExprCall { - func: func.into(), - arguments: arguments.into(), - }), - ast::Expr::FormattedValue(ast::ExprFormattedValue { - value, - conversion, - debug_text, - format_spec, - range: _, - }) => Self::NormalizedValue(ExprFormattedValue { - value: value.into(), - conversion: *conversion, - debug_text: debug_text.as_ref(), - format_spec: format_spec.as_ref().map(Into::into), - }), - ast::Expr::FString(ast::ExprFString { - values, - 
implicit_concatenated: _, - range: _, - }) => Self::FString(ExprFString { - values: values.iter().map(Into::into).collect(), - }), - ast::Expr::StringLiteral(ast::ExprStringLiteral { - value, - // Compare strings based on resolved value, not representation (i.e., ignore whether - // the string was implicitly concatenated). - implicit_concatenated: _, - unicode: _, - range: _, - }) => Self::StringLiteral(ExprStringLiteral { - value: normalize(value), - }), - ast::Expr::BytesLiteral(ast::ExprBytesLiteral { - value, - // Compare bytes based on resolved value, not representation (i.e., ignore whether - // the bytes was implicitly concatenated). - implicit_concatenated: _, - range: _, - }) => Self::BytesLiteral(ExprBytesLiteral { value }), - ast::Expr::NumberLiteral(ast::ExprNumberLiteral { value, range: _ }) => { - Self::NumberLiteral(ExprNumberLiteral { - value: value.into(), - }) - } - ast::Expr::BooleanLiteral(ast::ExprBooleanLiteral { value, range: _ }) => { - Self::BoolLiteral(ExprBoolLiteral { value }) - } - ast::Expr::NoneLiteral(_) => Self::NoneLiteral, - ast::Expr::EllipsisLiteral(_) => Self::EllispsisLiteral, - ast::Expr::Attribute(ast::ExprAttribute { - value, - attr, - ctx: _, - range: _, - }) => Self::Attribute(ExprAttribute { - value: value.into(), - attr: attr.as_str(), - }), - ast::Expr::Subscript(ast::ExprSubscript { - value, - slice, - ctx: _, - range: _, - }) => Self::Subscript(ExprSubscript { - value: value.into(), - slice: slice.into(), - }), - ast::Expr::Starred(ast::ExprStarred { - value, - ctx: _, - range: _, - }) => Self::Starred(ExprStarred { - value: value.into(), - }), - ast::Expr::Name(name) => name.into(), - ast::Expr::List(ast::ExprList { - elts, - ctx: _, - range: _, - }) => Self::List(ExprList { - elts: elts.iter().map(Into::into).collect(), - }), - ast::Expr::Tuple(ast::ExprTuple { - elts, - ctx: _, - range: _, - }) => Self::Tuple(ExprTuple { - elts: elts.iter().map(Into::into).collect(), - }), - ast::Expr::Slice(ast::ExprSlice { - 
lower, - upper, - step, - range: _, - }) => Self::Slice(ExprSlice { - lower: lower.as_ref().map(Into::into), - upper: upper.as_ref().map(Into::into), - step: step.as_ref().map(Into::into), - }), - ast::Expr::IpyEscapeCommand(ast::ExprIpyEscapeCommand { - kind, - value, - range: _, - }) => Self::IpyEscapeCommand(ExprIpyEscapeCommand { - kind: *kind, - value: value.as_str(), - }), - } - } -} - -impl<'a> From<&'a ast::ExprName> for NormalizedExpr<'a> { - fn from(expr: &'a ast::ExprName) -> Self { - Self::Name(ExprName { - id: expr.id.as_str(), - }) - } -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct StmtFunctionDef<'a> { - is_async: bool, - decorator_list: Vec>, - name: &'a str, - type_params: Option>, - parameters: NormalizedParameters<'a>, - returns: Option>, - body: Vec>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct StmtClassDef<'a> { - decorator_list: Vec>, - name: &'a str, - type_params: Option>, - arguments: NormalizedArguments<'a>, - body: Vec>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct StmtReturn<'a> { - value: Option>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct StmtDelete<'a> { - targets: Vec>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct StmtTypeAlias<'a> { - name: Box>, - type_params: Option>, - value: Box>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct NormalizedTypeParams<'a> { - type_params: Vec>, -} - -impl<'a> From<&'a ast::TypeParams> for NormalizedTypeParams<'a> { - fn from(type_params: &'a ast::TypeParams) -> Self { - Self { - type_params: type_params.iter().map(Into::into).collect(), - } - } -} - -impl<'a> From<&'a Box> for NormalizedTypeParams<'a> { - fn from(type_params: &'a Box) -> Self { - type_params.as_ref().into() - } -} - -#[derive(Debug, PartialEq, Eq, Hash)] -enum NormalizedTypeParam<'a> { - TypeVar(TypeParamTypeVar<'a>), - ParamSpec(TypeParamParamSpec<'a>), - TypeVarTuple(TypeParamTypeVarTuple<'a>), -} - -impl<'a> From<&'a ast::TypeParam> for NormalizedTypeParam<'a> { - fn from(type_param: &'a 
ast::TypeParam) -> Self { - match type_param { - ast::TypeParam::TypeVar(ast::TypeParamTypeVar { - name, - bound, - range: _, - }) => Self::TypeVar(TypeParamTypeVar { - name: name.as_str(), - bound: bound.as_ref().map(Into::into), - }), - ast::TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { name, range: _ }) => { - Self::TypeVarTuple(TypeParamTypeVarTuple { - name: name.as_str(), - }) - } - ast::TypeParam::ParamSpec(ast::TypeParamParamSpec { name, range: _ }) => { - Self::ParamSpec(TypeParamParamSpec { - name: name.as_str(), - }) - } - } - } -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct TypeParamTypeVar<'a> { - name: &'a str, - bound: Option>>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct TypeParamParamSpec<'a> { - name: &'a str, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct TypeParamTypeVarTuple<'a> { - name: &'a str, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct StmtAssign<'a> { - targets: Vec>, - value: NormalizedExpr<'a>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct StmtAugAssign<'a> { - target: NormalizedExpr<'a>, - op: NormalizedOperator, - value: NormalizedExpr<'a>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct StmtAnnAssign<'a> { - target: NormalizedExpr<'a>, - annotation: NormalizedExpr<'a>, - value: Option>, - simple: bool, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct StmtFor<'a> { - is_async: bool, - target: NormalizedExpr<'a>, - iter: NormalizedExpr<'a>, - body: Vec>, - orelse: Vec>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct StmtWhile<'a> { - test: NormalizedExpr<'a>, - body: Vec>, - orelse: Vec>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct StmtIf<'a> { - test: NormalizedExpr<'a>, - body: Vec>, - elif_else_clauses: Vec>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct StmtWith<'a> { - is_async: bool, - items: Vec>, - body: Vec>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct StmtMatch<'a> { - subject: NormalizedExpr<'a>, - cases: Vec>, -} - -#[derive(Debug, PartialEq, Eq, 
Hash)] -struct StmtRaise<'a> { - exc: Option>, - cause: Option>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct StmtTry<'a> { - body: Vec>, - handlers: Vec>, - orelse: Vec>, - finalbody: Vec>, - is_star: bool, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct StmtAssert<'a> { - test: NormalizedExpr<'a>, - msg: Option>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct StmtImport<'a> { - names: Vec>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct StmtImportFrom<'a> { - module: Option<&'a str>, - names: Vec>, - level: Option, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct StmtGlobal<'a> { - names: Vec<&'a str>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct StmtNonlocal<'a> { - names: Vec<&'a str>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct StmtExpr<'a> { - value: NormalizedExpr<'a>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -struct StmtIpyEscapeCommand<'a> { - kind: ast::IpyEscapeKind, - value: &'a str, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -enum NormalizedStmt<'a> { - FunctionDef(StmtFunctionDef<'a>), - ClassDef(StmtClassDef<'a>), - Return(StmtReturn<'a>), - Delete(StmtDelete<'a>), - Assign(StmtAssign<'a>), - AugAssign(StmtAugAssign<'a>), - AnnAssign(StmtAnnAssign<'a>), - For(StmtFor<'a>), - While(StmtWhile<'a>), - If(StmtIf<'a>), - With(StmtWith<'a>), - Match(StmtMatch<'a>), - Raise(StmtRaise<'a>), - Try(StmtTry<'a>), - TypeAlias(StmtTypeAlias<'a>), - Assert(StmtAssert<'a>), - Import(StmtImport<'a>), - ImportFrom(StmtImportFrom<'a>), - Global(StmtGlobal<'a>), - Nonlocal(StmtNonlocal<'a>), - IpyEscapeCommand(StmtIpyEscapeCommand<'a>), - Expr(StmtExpr<'a>), - Pass, - Break, - Continue, -} - -impl<'a> From<&'a ast::Stmt> for NormalizedStmt<'a> { - fn from(stmt: &'a ast::Stmt) -> Self { - match stmt { - ast::Stmt::FunctionDef(ast::StmtFunctionDef { - is_async, - name, - parameters, - body, - decorator_list, - returns, - type_params, - range: _, - }) => Self::FunctionDef(StmtFunctionDef { - is_async: *is_async, - name: 
name.as_str(), - parameters: parameters.into(), - body: body.iter().map(Into::into).collect(), - decorator_list: decorator_list.iter().map(Into::into).collect(), - returns: returns.as_ref().map(Into::into), - type_params: type_params.as_ref().map(Into::into), - }), - ast::Stmt::ClassDef(ast::StmtClassDef { - name, - arguments, - body, - decorator_list, - type_params, - range: _, - }) => Self::ClassDef(StmtClassDef { - name: name.as_str(), - arguments: arguments.as_ref().map(Into::into).unwrap_or_default(), - body: body.iter().map(Into::into).collect(), - decorator_list: decorator_list.iter().map(Into::into).collect(), - type_params: type_params.as_ref().map(Into::into), - }), - ast::Stmt::Return(ast::StmtReturn { value, range: _ }) => Self::Return(StmtReturn { - value: value.as_ref().map(Into::into), - }), - ast::Stmt::Delete(ast::StmtDelete { targets, range: _ }) => Self::Delete(StmtDelete { - // Like Black, flatten all tuples, as we may insert parentheses, which changes the - // AST but not the semantics. 
- targets: targets - .iter() - .flat_map(|target| { - if let ast::Expr::Tuple(tuple) = target { - Left(tuple.elts.iter()) - } else { - Right(std::iter::once(target)) - } - }) - .map(Into::into) - .collect(), - }), - ast::Stmt::TypeAlias(ast::StmtTypeAlias { - range: _, - name, - type_params, - value, - }) => Self::TypeAlias(StmtTypeAlias { - name: name.into(), - type_params: type_params.as_ref().map(Into::into), - value: value.into(), - }), - ast::Stmt::Assign(ast::StmtAssign { - targets, - value, - range: _, - }) => Self::Assign(StmtAssign { - targets: targets.iter().map(Into::into).collect(), - value: value.into(), - }), - ast::Stmt::AugAssign(ast::StmtAugAssign { - target, - op, - value, - range: _, - }) => Self::AugAssign(StmtAugAssign { - target: target.into(), - op: (*op).into(), - value: value.into(), - }), - ast::Stmt::AnnAssign(ast::StmtAnnAssign { - target, - annotation, - value, - simple, - range: _, - }) => Self::AnnAssign(StmtAnnAssign { - target: target.into(), - annotation: annotation.into(), - value: value.as_ref().map(Into::into), - simple: *simple, - }), - ast::Stmt::For(ast::StmtFor { - is_async, - target, - iter, - body, - orelse, - range: _, - }) => Self::For(StmtFor { - is_async: *is_async, - target: target.into(), - iter: iter.into(), - body: body.iter().map(Into::into).collect(), - orelse: orelse.iter().map(Into::into).collect(), - }), - ast::Stmt::While(ast::StmtWhile { - test, - body, - orelse, - range: _, - }) => Self::While(StmtWhile { - test: test.into(), - body: body.iter().map(Into::into).collect(), - orelse: orelse.iter().map(Into::into).collect(), - }), - ast::Stmt::If(ast::StmtIf { - test, - body, - elif_else_clauses, - range: _, - }) => Self::If(StmtIf { - test: test.into(), - body: body.iter().map(Into::into).collect(), - elif_else_clauses: elif_else_clauses.iter().map(Into::into).collect(), - }), - ast::Stmt::With(ast::StmtWith { - is_async, - items, - body, - range: _, - }) => Self::With(StmtWith { - is_async: *is_async, - 
items: items.iter().map(Into::into).collect(), - body: body.iter().map(Into::into).collect(), - }), - ast::Stmt::Match(ast::StmtMatch { - subject, - cases, - range: _, - }) => Self::Match(StmtMatch { - subject: subject.into(), - cases: cases.iter().map(Into::into).collect(), - }), - ast::Stmt::Raise(ast::StmtRaise { - exc, - cause, - range: _, - }) => Self::Raise(StmtRaise { - exc: exc.as_ref().map(Into::into), - cause: cause.as_ref().map(Into::into), - }), - ast::Stmt::Try(ast::StmtTry { - body, - handlers, - orelse, - finalbody, - is_star, - range: _, - }) => Self::Try(StmtTry { - body: body.iter().map(Into::into).collect(), - handlers: handlers.iter().map(Into::into).collect(), - orelse: orelse.iter().map(Into::into).collect(), - finalbody: finalbody.iter().map(Into::into).collect(), - is_star: *is_star, - }), - ast::Stmt::Assert(ast::StmtAssert { - test, - msg, - range: _, - }) => Self::Assert(StmtAssert { - test: test.into(), - msg: msg.as_ref().map(Into::into), - }), - ast::Stmt::Import(ast::StmtImport { names, range: _ }) => Self::Import(StmtImport { - names: names.iter().map(Into::into).collect(), - }), - ast::Stmt::ImportFrom(ast::StmtImportFrom { - module, - names, - level, - range: _, - }) => Self::ImportFrom(StmtImportFrom { - module: module.as_deref(), - names: names.iter().map(Into::into).collect(), - level: *level, - }), - ast::Stmt::Global(ast::StmtGlobal { names, range: _ }) => Self::Global(StmtGlobal { - names: names.iter().map(ast::Identifier::as_str).collect(), - }), - ast::Stmt::Nonlocal(ast::StmtNonlocal { names, range: _ }) => { - Self::Nonlocal(StmtNonlocal { - names: names.iter().map(ast::Identifier::as_str).collect(), - }) - } - ast::Stmt::IpyEscapeCommand(ast::StmtIpyEscapeCommand { - kind, - value, - range: _, - }) => Self::IpyEscapeCommand(StmtIpyEscapeCommand { - kind: *kind, - value: value.as_str(), - }), - ast::Stmt::Expr(ast::StmtExpr { value, range: _ }) => Self::Expr(StmtExpr { - value: value.into(), - }), - ast::Stmt::Pass(_) => 
Self::Pass, - ast::Stmt::Break(_) => Self::Break, - ast::Stmt::Continue(_) => Self::Continue, - } - } -} - -#[derive(Debug, PartialEq, Eq, Hash)] -pub(crate) enum NormalizedMod<'a> { - Module(NormalizedModModule<'a>), - Expression(NormalizedModExpression<'a>), -} - -#[derive(Debug, PartialEq, Eq, Hash)] -pub(crate) struct NormalizedModModule<'a> { - body: Vec>, -} - -#[derive(Debug, PartialEq, Eq, Hash)] -pub(crate) struct NormalizedModExpression<'a> { - body: Box>, -} - -impl<'a> From<&'a ast::Mod> for NormalizedMod<'a> { - fn from(mod_: &'a ast::Mod) -> Self { - match mod_ { - ast::Mod::Module(module) => Self::Module(module.into()), - ast::Mod::Expression(expr) => Self::Expression(expr.into()), - } - } -} - -impl<'a> From<&'a ast::ModModule> for NormalizedModModule<'a> { - fn from(module: &'a ast::ModModule) -> Self { - Self { - body: module.body.iter().map(Into::into).collect(), - } - } -} - -impl<'a> From<&'a ast::ModExpression> for NormalizedModExpression<'a> { - fn from(expr: &'a ast::ModExpression) -> Self { - Self { - body: (&expr.body).into(), - } - } -} - -/// Normalize a string by (1) stripping any leading and trailing space from each line, and -/// (2) removing any blank lines from the start and end of the string. -fn normalize(s: &str) -> String { - s.lines() - .map(str::trim) - .collect::>() - .join("\n") - .trim() - .to_owned() -} diff --git a/crates/ruff_python_formatter/tests/normalizer.rs b/crates/ruff_python_formatter/tests/normalizer.rs new file mode 100644 index 0000000000000..5aab798d69333 --- /dev/null +++ b/crates/ruff_python_formatter/tests/normalizer.rs @@ -0,0 +1,83 @@ +use itertools::Either::{Left, Right}; + +use ruff_python_ast::visitor::transformer; +use ruff_python_ast::visitor::transformer::Transformer; +use ruff_python_ast::{self as ast, Expr, Stmt}; + +/// A struct to normalize AST nodes for the purpose of comparing formatted representations for +/// semantic equivalence. 
+///
+/// Vis-à-vis comparing ASTs, comparing these normalized representations does the following:
+/// - Ignores non-abstraction information that we've encoded into the AST, e.g., the difference
+///   between `class C: ...` and `class C(): ...`, which is part of our AST but not `CPython`'s.
+/// - Normalizes strings. The formatter can re-indent docstrings, so we need to compare string
+///   contents ignoring whitespace. (Black does the same.)
+/// - Ignores nested tuples in deletions. (Black does the same.)
+pub(crate) struct Normalizer;
+
+impl Normalizer {
+    /// Transform an AST module into a normalized representation.
+    #[allow(dead_code)]
+    pub(crate) fn visit_module(&self, module: &mut ast::Mod) {
+        match module {
+            ast::Mod::Module(module) => {
+                self.visit_body(&mut module.body);
+            }
+            ast::Mod::Expression(expression) => {
+                self.visit_expr(&mut expression.body);
+            }
+        }
+    }
+}
+
+impl Transformer for Normalizer {
+    /// Normalizes `class` argument lists and `del` targets on `stmt`, then walks its children
+    /// via [`transformer::walk_stmt`].
+    fn visit_stmt(&self, stmt: &mut Stmt) {
+        match stmt {
+            Stmt::ClassDef(class_def) => {
+                // Treat `class C: ...` and `class C(): ...` equivalently.
+                if class_def
+                    .arguments
+                    .as_ref()
+                    .is_some_and(|arguments| arguments.is_empty())
+                {
+                    class_def.arguments = None;
+                }
+            }
+            Stmt::Delete(delete) => {
+                // Treat `del a, b` and `del (a, b)` equivalently. The targets are moved out of
+                // the node (rather than cloned) because the flattened list replaces them anyway.
+                delete.targets = std::mem::take(&mut delete.targets)
+                    .into_iter()
+                    .flat_map(|target| match target {
+                        Expr::Tuple(tuple) => Left(tuple.elts.into_iter()),
+                        other => Right(std::iter::once(other)),
+                    })
+                    .collect();
+            }
+            _ => {}
+        }
+
+        transformer::walk_stmt(self, stmt);
+    }
+
+    /// Normalizes the value of a string literal, then walks the children of `expr` via
+    /// [`transformer::walk_expr`].
+    fn visit_expr(&self, expr: &mut Expr) {
+        if let Expr::StringLiteral(string_literal) = expr {
+            // Normalize a string by (1) stripping any leading and trailing space from each
+            // line, and (2) removing any blank lines from the start and end of the string.
+            string_literal.value = string_literal
+                .value
+                .lines()
+                .map(str::trim)
+                .collect::<Vec<_>>()
+                .join("\n")
+                .trim()
+                .to_owned();
+        }
+
+        transformer::walk_expr(self, expr);
+    }
+}