Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cherry pick 3.3 (1020-1021) #4772

Merged
merged 4 commits into from
Oct 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 34 additions & 14 deletions src/common/expression/ContainerExpression.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
#ifndef COMMON_EXPRESSION_CONTAINEREXPRESSION_H_
#define COMMON_EXPRESSION_CONTAINEREXPRESSION_H_

#include "common/base/ObjectPool.h"
#include "common/expression/ConstantExpression.h"
#include "common/expression/ContainerExpression.h"
#include "common/expression/Expression.h"

namespace nebula {
Expand Down Expand Up @@ -63,7 +66,16 @@ class MapItemList final {
std::vector<Pair> items_;
};

class ListExpression final : public Expression {
class ContainerExpression : public Expression {
public:
virtual const std::vector<Expression *> getKeys() const = 0;
virtual size_t size() const = 0;

protected:
ContainerExpression(ObjectPool *pool, Expression::Kind kind) : Expression(pool, kind) {}
};

class ListExpression final : public ContainerExpression {
public:
ListExpression &operator=(const ListExpression &rhs) = delete;
ListExpression &operator=(ListExpression &&) = delete;
Expand All @@ -84,15 +96,15 @@ class ListExpression final : public Expression {
items_[index] = item;
}

std::vector<Expression *> get() {
const std::vector<Expression *> getKeys() const override {
return items_;
}

void setItems(std::vector<Expression *> items) {
items_ = items;
}

size_t size() const {
size_t size() const override {
return items_.size();
}

Expand All @@ -116,9 +128,9 @@ class ListExpression final : public Expression {

private:
friend ObjectPool;
explicit ListExpression(ObjectPool *pool) : Expression(pool, Kind::kList) {}
explicit ListExpression(ObjectPool *pool) : ContainerExpression(pool, Kind::kList) {}

ListExpression(ObjectPool *pool, ExpressionList *items) : Expression(pool, Kind::kList) {
ListExpression(ObjectPool *pool, ExpressionList *items) : ContainerExpression(pool, Kind::kList) {
items_ = items->get();
}

Expand All @@ -131,7 +143,7 @@ class ListExpression final : public Expression {
Value result_;
};

class SetExpression final : public Expression {
class SetExpression final : public ContainerExpression {
public:
SetExpression &operator=(const SetExpression &rhs) = delete;
SetExpression &operator=(SetExpression &&) = delete;
Expand All @@ -152,15 +164,15 @@ class SetExpression final : public Expression {
items_[index] = item;
}

std::vector<Expression *> get() {
const std::vector<Expression *> getKeys() const override {
return items_;
}

void setItems(std::vector<Expression *> items) {
items_ = items;
}

size_t size() const {
size_t size() const override {
return items_.size();
}

Expand All @@ -184,9 +196,9 @@ class SetExpression final : public Expression {

private:
friend ObjectPool;
explicit SetExpression(ObjectPool *pool) : Expression(pool, Kind::kSet) {}
explicit SetExpression(ObjectPool *pool) : ContainerExpression(pool, Kind::kSet) {}

SetExpression(ObjectPool *pool, ExpressionList *items) : Expression(pool, Kind::kSet) {
SetExpression(ObjectPool *pool, ExpressionList *items) : ContainerExpression(pool, Kind::kSet) {
items_ = items->get();
}

Expand All @@ -199,7 +211,7 @@ class SetExpression final : public Expression {
Value result_;
};

class MapExpression final : public Expression {
class MapExpression final : public ContainerExpression {
public:
MapExpression &operator=(const MapExpression &rhs) = delete;
MapExpression &operator=(MapExpression &&) = delete;
Expand Down Expand Up @@ -230,7 +242,15 @@ class MapExpression final : public Expression {
return items_;
}

size_t size() const {
const std::vector<Expression *> getKeys() const override {
std::vector<Expression *> keys;
for (const auto &item : items_) {
keys.emplace_back(ConstantExpression::make(pool_, item.first));
}
return keys;
}

size_t size() const override {
return items_.size();
}

Expand All @@ -254,9 +274,9 @@ class MapExpression final : public Expression {

private:
friend ObjectPool;
explicit MapExpression(ObjectPool *pool) : Expression(pool, Kind::kMap) {}
explicit MapExpression(ObjectPool *pool) : ContainerExpression(pool, Kind::kMap) {}

MapExpression(ObjectPool *pool, MapItemList *items) : Expression(pool, Kind::kMap) {
MapExpression(ObjectPool *pool, MapItemList *items) : ContainerExpression(pool, Kind::kMap) {
items_ = items->get();
}

Expand Down
2 changes: 2 additions & 0 deletions src/graph/planner/ngql/GoPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ PlanNode* GoPlanner::lastStep(PlanNode* dep, PlanNode* join) {
gd->setInputVar(goCtx_->vidsVar);
gd->setColNames({goCtx_->dstIdColName});
auto* dedup = Dedup::make(qctx, gd);
dedup->setColNames(gd->colNames());
cur = dedup;

if (goCtx_->joinDst) {
Expand Down Expand Up @@ -486,6 +487,7 @@ SubPlan GoPlanner::nStepsPlan(SubPlan& startVidPlan) {
gd->setColNames({goCtx_->dstIdColName});
auto* dedup = Dedup::make(qctx, gd);
dedup->setOutputVar(goCtx_->vidsVar);
dedup->setColNames(gd->colNames());
getDst = dedup;

loopBody = getDst;
Expand Down
75 changes: 73 additions & 2 deletions src/graph/util/ExpressionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@

#include "graph/util/ExpressionUtils.h"

#include "ExpressionUtils.h"
#include "common/base/ObjectPool.h"
#include "common/expression/ArithmeticExpression.h"
#include "common/expression/ConstantExpression.h"
#include "common/expression/ContainerExpression.h"
#include "common/expression/Expression.h"
#include "common/expression/PropertyExpression.h"
#include "common/function/AggFunctionManager.h"
Expand Down Expand Up @@ -50,6 +53,30 @@ const Expression *ExpressionUtils::findAny(const Expression *self,
return nullptr;
}

bool ExpressionUtils::findEdgeDstExpr(const Expression *expr) {
auto finder = [](const Expression *e) -> bool {
if (e->kind() == Expression::Kind::kEdgeDst) {
return true;
} else {
auto name = e->toString();
std::transform(name.begin(), name.end(), name.begin(), ::tolower);
if (name == "id($$)") {
return true;
}
}
return false;
};
if (finder(expr)) {
return true;
}
FindVisitor visitor(finder);
const_cast<Expression *>(expr)->accept(&visitor);
if (!visitor.results().empty()) {
return true;
}
return false;
}

// Finds all expressions fit the exprected list
// Returns an empty vector if no expression found
std::vector<const Expression *> ExpressionUtils::collectAll(
Expand Down Expand Up @@ -177,6 +204,50 @@ Expression *ExpressionUtils::rewriteParameter(const Expression *expr, QueryConte
return graph::RewriteVisitor::transform(expr, matcher, rewriter);
}

Expression *ExpressionUtils::rewriteInnerInExpr(const Expression *expr) {
auto matcher = [](const Expression *e) -> bool {
if (e->kind() != Expression::Kind::kRelIn) {
return false;
}
auto rhs = static_cast<const RelationalExpression *>(e)->right();
if (rhs->kind() != Expression::Kind::kList && rhs->kind() != Expression::Kind::kSet) {
return false;
}
auto items = static_cast<const ContainerExpression *>(rhs)->getKeys();
for (const auto *item : items) {
if (!ExpressionUtils::isEvaluableExpr(item)) {
return false;
}
}
return true;
};
auto rewriter = [](const Expression *e) -> Expression * {
DCHECK_EQ(e->kind(), Expression::Kind::kRelIn);
const auto re = static_cast<const RelationalExpression *>(e);
auto lhs = re->left();
auto rhs = re->right();
DCHECK(rhs->kind() == Expression::Kind::kList || rhs->kind() == Expression::Kind::kSet);
auto ce = static_cast<const ContainerExpression *>(rhs);
auto pool = e->getObjPool();
auto *rewrittenExpr = LogicalExpression::makeOr(pool);
// Pointer to a single-level expression
Expression *singleExpr = nullptr;
auto items = ce->getKeys();
for (auto i = 0u; i < items.size(); ++i) {
auto *ee = RelationalExpression::makeEQ(pool, lhs->clone(), items[i]->clone());
rewrittenExpr->addOperand(ee);
if (i == 0) {
singleExpr = ee;
} else {
singleExpr = nullptr;
}
}
return singleExpr ? singleExpr : rewrittenExpr;
};

return graph::RewriteVisitor::transform(expr, matcher, rewriter);
}

Expression *ExpressionUtils::rewriteAgg2VarProp(const Expression *expr) {
ObjectPool *pool = expr->getObjPool();
auto matcher = [](const Expression *e) -> bool {
Expand Down Expand Up @@ -344,10 +415,10 @@ std::vector<Expression *> ExpressionUtils::getContainerExprOperands(const Expres
std::vector<Expression *> containerOperands;
switch (containerExpr->kind()) {
case Expression::Kind::kList:
containerOperands = static_cast<ListExpression *>(containerExpr)->get();
containerOperands = static_cast<ListExpression *>(containerExpr)->getKeys();
break;
case Expression::Kind::kSet: {
containerOperands = static_cast<SetExpression *>(containerExpr)->get();
containerOperands = static_cast<SetExpression *>(containerExpr)->getKeys();
break;
}
case Expression::Kind::kMap: {
Expand Down
6 changes: 6 additions & 0 deletions src/graph/util/ExpressionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ class ExpressionUtils {
// Rewrites ParameterExpression to ConstantExpression
static Expression* rewriteParameter(const Expression* expr, QueryContext* qctx);

// Rewrite RelInExpr with only one operand in expression tree
static Expression* rewriteInnerInExpr(const Expression* expr);

// Rewrites relational expression, gather all evaluable expressions in the left operands and move
// them to the right
static Expression* rewriteRelExpr(const Expression* expr);
Expand Down Expand Up @@ -193,6 +196,9 @@ class ExpressionUtils {
// Checks if expr contains function call expression that generate a random value
static bool findInnerRandFunction(const Expression* expr);

// Checks if expr contains function EdgeDst expr or id($$) expr
static bool findEdgeDstExpr(const Expression* expr);

// ** Loop node condition **
// Generates an expression that is used as the condition of loop node:
// ++loopSteps <= steps
Expand Down
11 changes: 10 additions & 1 deletion src/graph/validator/GoValidator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ Status GoValidator::validateWhere(WhereClause* where) {
return Status::SemanticError(ss.str());
}

goCtx_->filter = rewriteVertexEdge2EdgeProp(filter);
NG_RETURN_IF_ERROR(deduceProps(filter, goCtx_->exprProps, &tagIds_, &goCtx_->over.edgeTypes));
goCtx_->filter = filter;
return Status::OK();
}

Expand Down Expand Up @@ -311,6 +311,15 @@ bool GoValidator::isSimpleCase() {
}
if (exprProps.hasInputVarProperty()) return false;

// Check filter clause
// Because GetDstBySrc doesn't support filter push down,
// so we don't optimize such case.
if (goCtx_->filter) {
if (ExpressionUtils::findEdgeDstExpr(goCtx_->filter)) {
return false;
}
}

// Check yield clause
if (!goCtx_->distinct) return false;
bool atLeastOneDstId = false;
Expand Down
3 changes: 2 additions & 1 deletion src/graph/validator/MatchValidator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,8 @@ Status MatchValidator::buildEdgeInfo(const MatchPath *path,
// Rewrite expression to fit semantic, check type and check used aliases.
Status MatchValidator::validateFilter(const Expression *filter,
WhereClauseContext &whereClauseCtx) {
auto transformRes = ExpressionUtils::filterTransform(filter);
auto *newFilter = graph::ExpressionUtils::rewriteInnerInExpr(filter);
auto transformRes = ExpressionUtils::filterTransform(newFilter);
NG_RETURN_IF_ERROR(transformRes);
// rewrite Attribute to LabelTagProperty
whereClauseCtx.filter = ExpressionUtils::rewriteAttr2LabelTagProp(
Expand Down
2 changes: 1 addition & 1 deletion src/meta/processors/job/StorageJobExecutor.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class StorageJobExecutor : public JobExecutor {
}

JobDescription getJobDescription() override {
JobDescription ret;
JobDescription ret{space_, jobId_, cpp2::JobType::UNKNOWN};
return ret;
}

Expand Down
48 changes: 26 additions & 22 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,31 +50,36 @@ def pytest_runtest_logreport(report):

def pytest_addoption(parser):
for config in all_configs:
parser.addoption(config,
dest=all_configs[config][0],
default=all_configs[config][1],
help=all_configs[config][2])
parser.addoption(
config,
dest=all_configs[config][0],
default=all_configs[config][1],
help=all_configs[config][2],
)

parser.addoption("--build_dir",
dest="build_dir",
default=BUILD_DIR,
help="NebulaGraph CMake build directory")
parser.addoption("--src_dir",
dest="src_dir",
default=NEBULA_HOME,
help="NebulaGraph workspace")
parser.addoption(
"--build_dir",
dest="build_dir",
default=BUILD_DIR,
help="NebulaGraph CMake build directory",
)
parser.addoption(
"--src_dir", dest="src_dir", default=NEBULA_HOME, help="NebulaGraph workspace"
)


def pytest_bdd_step_error(request, feature, scenario, step, step_func, step_func_args):
logging.info("=== more error information ===")
logging.info("feature is {}".format(feature.filename))
logging.info("step line number is {}".format(step.line_number))
logging.info("step name is {}".format(step.name))
if step_func_args.get("graph_spaces") is not None:
logging.error("Location: {}:{}".format(feature.filename, step.line_number))
logging.error("Step: {}".format(step.name))
graph_spaces = None
if graph_spaces is None and step_func_args.get("graph_spaces") is not None:
graph_spaces = step_func_args.get("graph_spaces")
if graph_spaces.get("space_desc") is not None:
logging.info("error space is {}".format(
graph_spaces.get("space_desc")))

if graph_spaces is None and step_func_args.get("exec_ctx") is not None:
graph_spaces = step_func_args.get("exec_ctx")

if graph_spaces is not None and graph_spaces.get("space_desc") is not None:
logging.error("Space: {}".format(graph_spaces.get("space_desc")))


def pytest_configure(config):
Expand Down Expand Up @@ -125,8 +130,7 @@ def get_ssl_config_from_tmp():

@pytest.fixture(scope="class")
def class_fixture_variables():
"""save class scope fixture, used for session update.
"""
"""save class scope fixture, used for session update."""
# cluster is the instance of NebulaService
# current_session is the session currently using
# sessions is a list of all sessions in the cluster
Expand Down
Loading