From 3369068d994f462354dd021adbcf5b0c4008f3ae Mon Sep 17 00:00:00 2001 From: Jan Nedbal Date: Tue, 25 Jun 2024 10:06:30 +0200 Subject: [PATCH] Proper aggregate function detection --- ...eryAggregateFunctionDetectorTreeWalker.php | 321 ++++++++++++++++++ .../Doctrine/Query/QueryResultTypeWalker.php | 30 +- .../Query/QueryResultTypeWalkerTest.php | 18 + 3 files changed, 342 insertions(+), 27 deletions(-) create mode 100644 src/Type/Doctrine/Query/QueryAggregateFunctionDetectorTreeWalker.php diff --git a/src/Type/Doctrine/Query/QueryAggregateFunctionDetectorTreeWalker.php b/src/Type/Doctrine/Query/QueryAggregateFunctionDetectorTreeWalker.php new file mode 100644 index 00000000..11af2086 --- /dev/null +++ b/src/Type/Doctrine/Query/QueryAggregateFunctionDetectorTreeWalker.php @@ -0,0 +1,321 @@ +doWalkSelectClause($selectStatement->selectClause); + } + + /** + * @param AST\SelectClause $selectClause + */ + public function doWalkSelectClause($selectClause): void + { + foreach ($selectClause->selectExpressions as $selectExpression) { + $this->doWalkSelectExpression($selectExpression); + } + } + + /** + * @param AST\SelectExpression $selectExpression + */ + public function doWalkSelectExpression($selectExpression): void + { + $this->doWalkNode($selectExpression->expression); + } + + /** + * @param mixed $expr + */ + private function doWalkNode($expr): void + { + if ($expr instanceof AST\AggregateExpression) { + $this->markAggregateFunctionFound(); + + } elseif ($expr instanceof AST\Functions\FunctionNode) { + if ($this->isAggregateFunction($expr)) { + $this->markAggregateFunctionFound(); + } + + } elseif ($expr instanceof AST\SimpleArithmeticExpression) { + foreach ($expr->arithmeticTerms as $term) { + $this->doWalkArithmeticTerm($term); + } + + } elseif ($expr instanceof AST\ArithmeticTerm) { + $this->doWalkArithmeticTerm($expr); + + } elseif ($expr instanceof AST\ArithmeticFactor) { + $this->doWalkArithmeticFactor($expr); + + } elseif ($expr instanceof AST\ParenthesisExpression) { + $this->doWalkArithmeticPrimary($expr->expression); + + } elseif ($expr instanceof AST\NullIfExpression) { + $this->doWalkNullIfExpression($expr); + + } elseif ($expr instanceof AST\CoalesceExpression) { + $this->doWalkCoalesceExpression($expr); + + } elseif ($expr instanceof AST\GeneralCaseExpression) { + $this->doWalkGeneralCaseExpression($expr); + + } elseif ($expr instanceof AST\SimpleCaseExpression) { + $this->doWalkSimpleCaseExpression($expr); + + } elseif ($expr instanceof AST\ArithmeticExpression) { + $this->doWalkArithmeticExpression($expr); + + } elseif ($expr instanceof AST\ComparisonExpression) { + $this->doWalkComparisonExpression($expr); + + } elseif ($expr instanceof AST\BetweenExpression) { + $this->doWalkBetweenExpression($expr); + } + } + + public function doWalkCoalesceExpression(AST\CoalesceExpression $coalesceExpression): void + { + foreach ($coalesceExpression->scalarExpressions as $scalarExpression) { + $this->doWalkSimpleArithmeticExpression($scalarExpression); + } + } + + public function doWalkNullIfExpression(AST\NullIfExpression $nullIfExpression): void + { + if (!is_string($nullIfExpression->firstExpression)) { + $this->doWalkSimpleArithmeticExpression($nullIfExpression->firstExpression); + } + + if (is_string($nullIfExpression->secondExpression)) { + return; + } + + $this->doWalkSimpleArithmeticExpression($nullIfExpression->secondExpression); + } + + public function doWalkGeneralCaseExpression(AST\GeneralCaseExpression $generalCaseExpression): void + { + foreach ($generalCaseExpression->whenClauses as $whenClause) { + $this->doWalkConditionalExpression($whenClause->caseConditionExpression); + $this->doWalkSimpleArithmeticExpression($whenClause->thenScalarExpression); + } + + $this->doWalkSimpleArithmeticExpression($generalCaseExpression->elseScalarExpression); + } + + public function doWalkSimpleCaseExpression(AST\SimpleCaseExpression $simpleCaseExpression): void + { + foreach ($simpleCaseExpression->simpleWhenClauses as $simpleWhenClause) { + $this->doWalkSimpleArithmeticExpression($simpleWhenClause->caseScalarExpression); + $this->doWalkSimpleArithmeticExpression($simpleWhenClause->thenScalarExpression); + } + + $this->doWalkSimpleArithmeticExpression($simpleCaseExpression->elseScalarExpression); + } + + /** + * @param AST\ConditionalExpression|AST\Phase2OptimizableConditional $condExpr + */ + public function doWalkConditionalExpression($condExpr): void + { + if (!$condExpr instanceof AST\ConditionalExpression) { + $this->doWalkConditionalTerm($condExpr); // @phpstan-ignore-line PHPStan do not read @psalm-inheritors of Phase2OptimizableConditional + return; + } + + foreach ($condExpr->conditionalTerms as $conditionalTerm) { + $this->doWalkConditionalTerm($conditionalTerm); + } + } + + /** + * @param AST\ConditionalTerm|AST\ConditionalPrimary|AST\ConditionalFactor $condTerm + */ + public function doWalkConditionalTerm($condTerm): void + { + if (!$condTerm instanceof AST\ConditionalTerm) { + $this->doWalkConditionalFactor($condTerm); + return; + } + + foreach ($condTerm->conditionalFactors as $conditionalFactor) { + $this->doWalkConditionalFactor($conditionalFactor); + } + } + + /** + * @param AST\ConditionalFactor|AST\ConditionalPrimary $factor + */ + public function doWalkConditionalFactor($factor): void + { + if (!$factor instanceof AST\ConditionalFactor) { + $this->doWalkConditionalPrimary($factor); + } else { + $this->doWalkConditionalPrimary($factor->conditionalPrimary); + } + } + + /** + * @param AST\ConditionalPrimary $primary + */ + public function doWalkConditionalPrimary($primary): void + { + if ($primary->isSimpleConditionalExpression()) { + if ($primary->simpleConditionalExpression instanceof AST\ComparisonExpression) { + $this->doWalkComparisonExpression($primary->simpleConditionalExpression); + return; + } + $this->doWalkNode($primary->simpleConditionalExpression); + } + + if (!$primary->isConditionalExpression()) { + return; + } + + if ($primary->conditionalExpression === null) { + return; + } + + $this->doWalkConditionalExpression($primary->conditionalExpression); + } + + /** + * @param AST\BetweenExpression $betweenExpr + */ + public function doWalkBetweenExpression($betweenExpr): void + { + $this->doWalkArithmeticExpression($betweenExpr->expression); + $this->doWalkArithmeticExpression($betweenExpr->leftBetweenExpression); + $this->doWalkArithmeticExpression($betweenExpr->rightBetweenExpression); + } + + /** + * @param AST\ComparisonExpression $compExpr + */ + public function doWalkComparisonExpression($compExpr): void + { + $leftExpr = $compExpr->leftExpression; + $rightExpr = $compExpr->rightExpression; + + if ($leftExpr instanceof AST\Node) { + $this->doWalkNode($leftExpr); + } + + if (!($rightExpr instanceof AST\Node)) { + return; + } + + $this->doWalkNode($rightExpr); + } + + /** + * @param AST\ArithmeticExpression $arithmeticExpr + */ + public function doWalkArithmeticExpression($arithmeticExpr): void + { + if (!$arithmeticExpr->isSimpleArithmeticExpression()) { + return; + } + + if ($arithmeticExpr->simpleArithmeticExpression === null) { + return; + } + + $this->doWalkSimpleArithmeticExpression($arithmeticExpr->simpleArithmeticExpression); + } + + /** + * @param AST\Node|string $simpleArithmeticExpr + */ + public function doWalkSimpleArithmeticExpression($simpleArithmeticExpr): void + { + if (!$simpleArithmeticExpr instanceof AST\SimpleArithmeticExpression) { + $this->doWalkArithmeticTerm($simpleArithmeticExpr); + return; + } + + foreach ($simpleArithmeticExpr->arithmeticTerms as $term) { + $this->doWalkArithmeticTerm($term); + } + } + + /** + * @param AST\Node|string $term + */ + public function doWalkArithmeticTerm($term): void + { + if (is_string($term)) { + return; + } + + if (!$term instanceof AST\ArithmeticTerm) { + $this->doWalkArithmeticFactor($term); + return; + } + + foreach ($term->arithmeticFactors as $factor) { + $this->doWalkArithmeticFactor($factor); + } + } + + /** + * @param AST\Node|string $factor + */ + public function doWalkArithmeticFactor($factor): void + { + if (is_string($factor)) { + return; + } + + if (!$factor instanceof AST\ArithmeticFactor) { + $this->doWalkArithmeticPrimary($factor); + return; + } + + $this->doWalkArithmeticPrimary($factor->arithmeticPrimary); + } + + /** + * @param AST\Node|string $primary + */ + public function doWalkArithmeticPrimary($primary): void + { + if ($primary instanceof AST\SimpleArithmeticExpression) { + $this->doWalkSimpleArithmeticExpression($primary); + return; + } + + if (!($primary instanceof AST\Node)) { + return; + } + + $this->doWalkNode($primary); + } + + private function isAggregateFunction(AST\Node $node): bool + { + return $node instanceof AST\Functions\AvgFunction + || $node instanceof AST\Functions\CountFunction + || $node instanceof AST\Functions\MaxFunction + || $node instanceof AST\Functions\MinFunction + || $node instanceof AST\Functions\SumFunction + || $node instanceof AST\AggregateExpression; + } + + private function markAggregateFunctionFound(): void + { + $this->_getQuery()->setHint(self::HINT_HAS_AGGREGATE_FUNCTION, true); + } + +} diff --git a/src/Type/Doctrine/Query/QueryResultTypeWalker.php b/src/Type/Doctrine/Query/QueryResultTypeWalker.php index 5e4bee91..4fb56be3 100644 --- a/src/Type/Doctrine/Query/QueryResultTypeWalker.php +++ b/src/Type/Doctrine/Query/QueryResultTypeWalker.php @@ -117,6 +117,7 @@ class QueryResultTypeWalker extends SqlWalker public static function walk(Query $query, QueryResultTypeBuilder $typeBuilder, DescriptorRegistry $descriptorRegistry): void { $query->setHint(Query::HINT_CUSTOM_OUTPUT_WALKER, self::class); + $query->setHint(Query::HINT_CUSTOM_TREE_WALKERS, [QueryAggregateFunctionDetectorTreeWalker::class]); $query->setHint(self::HINT_TYPE_MAPPING, $typeBuilder); $query->setHint(self::HINT_DESCRIPTOR_REGISTRY, $descriptorRegistry); @@ -137,7 +138,8 @@ public function __construct($query, $parserResult, array $queryComponents) $this->em = $query->getEntityManager(); $this->queryComponents = $queryComponents; $this->nullableQueryComponents = []; - $this->hasAggregateFunction = false; + $this->hasAggregateFunction = $query->hasHint(QueryAggregateFunctionDetectorTreeWalker::HINT_HAS_AGGREGATE_FUNCTION); + $this->hasGroupByClause = false; // The object is instantiated by Doctrine\ORM\Query\Parser, so receiving @@ -176,7 +178,6 @@ public function __construct($query, $parserResult, array $queryComponents) public function walkSelectStatement(AST\SelectStatement $AST): string { $this->typeBuilder->setSelectQuery(); - $this->hasAggregateFunction = $this->hasAggregateFunction($AST); $this->hasGroupByClause = $AST->groupByClause !== null; $this->walkFromClause($AST->fromClause); @@ -1432,29 +1433,4 @@ private function hasAggregateWithoutGroupBy(): bool return $this->hasAggregateFunction && !$this->hasGroupByClause; } - private function hasAggregateFunction(AST\SelectStatement $AST): bool - { - foreach ($AST->selectClause->selectExpressions as $selectExpression) { - if (!$selectExpression instanceof AST\SelectExpression) { - continue; - } - - $expression = $selectExpression->expression; - - switch (true) { - case $expression instanceof AST\Functions\AvgFunction: - case $expression instanceof AST\Functions\CountFunction: - case $expression instanceof AST\Functions\MaxFunction: - case $expression instanceof AST\Functions\MinFunction: - case $expression instanceof AST\Functions\SumFunction: - case $expression instanceof AST\AggregateExpression: - return true; - default: - break; - } - } - - return false; - } - } diff --git a/tests/Type/Doctrine/Query/QueryResultTypeWalkerTest.php b/tests/Type/Doctrine/Query/QueryResultTypeWalkerTest.php index b64099ff..a1ddb649 100644 --- a/tests/Type/Doctrine/Query/QueryResultTypeWalkerTest.php +++ b/tests/Type/Doctrine/Query/QueryResultTypeWalkerTest.php @@ -645,6 +645,24 @@ public function getTestData(): iterable ', ]; + yield 'aggregate deeper in AST' => [ + $this->constantArray([ + [ + new ConstantStringType('many'), + TypeCombinator::addNull(new ObjectType(Many::class)), + ], + [ + new ConstantStringType('max'), + $this->intStringified(), + ], + ]), + ' + SELECT m AS many, + COALESCE(MAX(m.intColumn), 0) as max + FROM QueryResult\Entities\Many m + ', + ]; + yield 'aggregate lowercase' => [ $this->constantArray([ [