Skip to content

Commit

Permalink
[Hexa] fix decimal type to arrow type(shit code), Rename execution pl…
Browse files Browse the repository at this point in the history
…an in fe
  • Loading branch information
suibianwanwank committed May 7, 2024
1 parent 3cc6c6a commit 56f47f2
Show file tree
Hide file tree
Showing 14 changed files with 150 additions and 40 deletions.
4 changes: 4 additions & 0 deletions be/src/core/src/datasource/common/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use datafusion::arrow::datatypes::DataType::{Decimal128};
use snafu::ResultExt;
use sqlx::Row;
use std::sync::Arc;
use tracing::info;

for_primitive_array_variants! {impl_sqlx_rows_to_primitive_array_data}

Expand Down Expand Up @@ -81,10 +82,13 @@ where
arr.append_null();
}
}
info!("Row transform Decimal{}, {}", p, s);
let arr = arr
.finish()
.with_precision_and_scale(p, s)
.context(RecordBatchCreateSnafu {})?;

info!("Row transform Decimal array:{}, {}", arr.precision(), arr.scale());
Ok(Arc::new(arr))
}
_ => Err(ArrayCreate {
Expand Down
6 changes: 4 additions & 2 deletions be/src/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use std::str::FromStr;
use std::sync::Arc;

use crate::protobuf::{
Expand Down Expand Up @@ -749,9 +750,10 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
null_type.try_into().map_err(Error::DataFusionError)?
}
Value::Decimal128Value(val) => {
let array = vec_to_array(val.value.clone());
// let array = vec_to_array(val.value.clone());
let value = String::from_utf8(val.value.clone()).unwrap();
Self::Decimal128(
Some(i128::from_be_bytes(array)),
Some(i128::from_str(value.as_str()).unwrap()),
val.p as u8,
val.s as i8,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@
import com.ccsu.meta.type.arrow.ArrowTypeEnum;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeFactoryImpl;
import org.apache.calcite.rel.type.RelDataTypeSystemImpl;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.checkerframework.checker.nullness.qual.Nullable;

public class ArrowDataTypeSystem extends RelDataTypeSystemImpl {

public static final int MAX_NUMERIC_PRECISION = 38;
public static final int MAX_NUMERIC_SCALE = 38;

@Override
public RelDataType deriveSumType(RelDataTypeFactory typeFactory, RelDataType argumentType) {
Expand Down Expand Up @@ -38,11 +42,75 @@ public RelDataType deriveSumType(RelDataTypeFactory typeFactory, RelDataType arg
case DECIMAL:
return typeFactory.createTypeWithNullability(((ArrowTypeFactory) typeFactory)
.createArrowType(ArrowTypeEnum.DECIMAL,
SqlTypeName.DECIMAL, MAX_NUMERIC_PRECISION, argumentType.getScale()),
SqlTypeName.DECIMAL, argumentType.getPrecision() + 10, argumentType.getScale()),
argumentType.isNullable());
}
return argumentType;
}
return super.deriveSumType(typeFactory, argumentType);
}

@Override
public @Nullable RelDataType deriveDecimalMultiplyType(RelDataTypeFactory typeFactory, RelDataType type1, RelDataType type2) {
if (SqlTypeUtil.isExactNumeric(type1)
&& SqlTypeUtil.isExactNumeric(type2)) {
if (SqlTypeUtil.isDecimal(type1)
|| SqlTypeUtil.isDecimal(type2)) {
// Java numeric will always have invalid precision/scale,
// use its default decimal precision/scale instead.
type1 = RelDataTypeFactoryImpl.isJavaType(type1)
? typeFactory.decimalOf(type1)
: type1;
type2 = RelDataTypeFactoryImpl.isJavaType(type2)
? typeFactory.decimalOf(type2)
: type2;
int p1 = type1.getPrecision();
int p2 = type2.getPrecision();
int s1 = type1.getScale();
int s2 = type2.getScale();

int scale = s1 + s2;
scale = Math.min(scale, getMaxPrecision(SqlTypeName.DECIMAL));
int precision = p1 + p2 + 1;
precision = Math.min(precision, getMaxScale(SqlTypeName.DECIMAL));

RelDataType ret;
ret = ((ArrowTypeFactory) typeFactory).createArrowType(ArrowTypeEnum.DECIMAL, SqlTypeName.DECIMAL, precision, scale);

return ret;
}
}

return null;
}

@Override
public RelDataType deriveAvgAggType(RelDataTypeFactory typeFactory, RelDataType argumentType) {
if (argumentType.getSqlTypeName() == SqlTypeName.DECIMAL) {
return ((ArrowTypeFactory) typeFactory).createArrowType(ArrowTypeEnum.DECIMAL,
SqlTypeName.DECIMAL, argumentType.getPrecision() + 4, argumentType.getScale() + 4);
}
return super.deriveAvgAggType(typeFactory, argumentType);
}

@Override
public int getMaxScale(SqlTypeName typeName) {
if (typeName == SqlTypeName.DECIMAL) {
return MAX_NUMERIC_SCALE;
}
return super.getMaxScale(typeName);
}

@Override
public int getMaxPrecision(SqlTypeName typeName) {
if (typeName == SqlTypeName.DECIMAL) {
return MAX_NUMERIC_PRECISION;
}
return super.getMaxPrecision(typeName);
}

@Override
public int getMaxNumericPrecision() {
return super.getMaxNumericPrecision();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import java.util.List;

import static program.physical.rel.PhysicalPlanTransformUtil.*;
import static program.physical.rel.ExecutionPlanToDataFusionPlanUtil.*;

public class AggregateExecutionPlan
extends EnumerableAggregateBase
Expand Down Expand Up @@ -66,7 +66,9 @@ public proto.datafusion.PhysicalPlanNode transformToDataFusionNode() {
aggregateNode.setMode(AggregateMode.SINGLE);
aggregateNode.setInputSchema(buildRelNodeSchema(inputFields));

return proto.datafusion.PhysicalPlanNode.newBuilder().setAggregate(aggregateNode).build();
return proto.datafusion.PhysicalPlanNode.newBuilder()
.setAggregate(aggregateNode)
.build();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

import java.util.Set;

import static program.physical.rel.PhysicalPlanTransformUtil.transformJoinType;
import static program.physical.rel.PhysicalPlanTransformUtil.transformRexNodeToJoinFilter;
import static program.physical.rel.ExecutionPlanToDataFusionPlanUtil.transformJoinType;
import static program.physical.rel.ExecutionPlanToDataFusionPlanUtil.transformRexNodeToJoinFilter;

public class BatchNestedLoopJoinExecutionPlan
extends EnumerableBatchNestedLoopJoin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,36 @@
import com.ccsu.meta.type.ArrowDataType;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.protobuf.ByteString;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.rel2sql.SqlImplementor;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.*;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlUtil;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.RangeSets;
import org.apache.calcite.util.Sarg;
import org.checkerframework.checker.nullness.qual.Nullable;
import proto.datafusion.InListNode;
import proto.datafusion.PhysicalExprNode;
import proto.datafusion.PhysicalInListNode;
import proto.datafusion.PhysicalIsNotNull;
import proto.datafusion.*;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static com.ccsu.error.CommonErrorCode.PLAN_TRANSFORM_ERROR;

public class PhysicalPlanTransformUtil {
/**
* Util for to datafusion physical plan.
*
* Most of the code is a piece of shit and will need a lot of rewriting in the future for this section.
*/
public class ExecutionPlanToDataFusionPlanUtil {

public static final proto.datafusion.EmptyMessage EMPTY_MESSAGE = proto.datafusion.EmptyMessage.getDefaultInstance();

private PhysicalPlanTransformUtil() {
private ExecutionPlanToDataFusionPlanUtil() {
}

public static proto.datafusion.JoinType transformJoinType(JoinRelType type) {
Expand Down Expand Up @@ -180,11 +177,45 @@ public static proto.datafusion.PhysicalExprNode transformRexNodeToExprNode(RexNo
public static proto.datafusion.PhysicalExprNode transformBinaryExpr(RexCall rexCall) {
proto.datafusion.PhysicalBinaryExprNode.Builder binary = proto.datafusion.PhysicalBinaryExprNode.newBuilder();

proto.datafusion.PhysicalExprNode l = transformRexNodeToExprNode(rexCall.getOperands().get(0));
proto.datafusion.PhysicalExprNode r = transformRexNodeToExprNode(rexCall.getOperands().get(1));

binary.setL(l);
binary.setR(r);
RexNode leftOp = rexCall.getOperands().get(0);
RexNode rightOp = rexCall.getOperands().get(1);

RelDataType leftOpType = leftOp.getType();
RelDataType rightOpType = rightOp.getType();

PhysicalExprNode leftExprNode = transformRexNodeToExprNode(leftOp);
PhysicalExprNode rightExprNode = transformRexNodeToExprNode(rightOp);


if (leftOpType.getSqlTypeName() == SqlTypeName.DECIMAL
&& rightOpType.getSqlTypeName() == SqlTypeName.DECIMAL
&& rexCall.getKind().belongsTo(SqlKind.BINARY_COMPARISON)) {
int leftPrecision = leftOpType.getPrecision();
int rightPrecision = rightOpType.getPrecision();

if (leftPrecision > rightPrecision) {

proto.datafusion.Decimal.Builder builder = proto.datafusion.Decimal.newBuilder();
rightExprNode = PhysicalExprNode.newBuilder().setCast(PhysicalCastNode.newBuilder()
.setExpr(rightExprNode)
.setArrowType(ArrowType.newBuilder().setDECIMAL(
builder.setPrecision(leftPrecision).setScale(leftOpType.getScale())
))).build();
}
if (rightPrecision > leftPrecision) {
proto.datafusion.Decimal.Builder builder = proto.datafusion.Decimal.newBuilder();
leftExprNode = PhysicalExprNode.newBuilder().setCast(PhysicalCastNode.newBuilder()
.setExpr(leftExprNode)
.setArrowType(ArrowType.newBuilder().setDECIMAL(
builder.setPrecision(rightPrecision).setScale(rightOpType.getScale())
))).build();
}

}

binary.setL(leftExprNode);
binary.setR(rightExprNode);

if (!BINARY_OP_MAP.containsKey(rexCall.getKind())) {
String errMsg = String.format("RexNode:%s can not be transformed", rexCall.getType());
Expand All @@ -210,7 +241,7 @@ public static proto.datafusion.PhysicalExprNode transformBinaryExpr(RexCall rexC
public static proto.datafusion.PhysicalExprNode transformAggFunction(AggregateCall call, List<RelDataTypeField> fields) {
proto.datafusion.PhysicalAggregateExprNode.Builder builder = proto.datafusion.PhysicalAggregateExprNode.newBuilder();
String aggName = call.getAggregation().getName();
proto.datafusion.AggregateFunction aggregateFunction;
proto.datafusion.AggregateFunction aggregateFunction = null;
try {
aggregateFunction = proto.datafusion.AggregateFunction.valueOf(aggName);
builder.setAggrFunction(aggregateFunction);
Expand All @@ -236,6 +267,14 @@ public static proto.datafusion.ScalarValue transformLiteral(RexLiteral literal)
Comparable comparable = Objects.requireNonNull(literal.getValue());
if (!(literal.getType() instanceof ArrowDataType)) {
String value = literal.getValue2().toString();
if (literal.getType().getSqlTypeName()
== SqlTypeName.DECIMAL) {
return builder.setDecimal128Value(Decimal128.newBuilder()
.setValue(ByteString.copyFrom(value.getBytes(StandardCharsets.UTF_8)))
.setP(literal.getType().getPrecision())
.setS(literal.getType().getScale()))
.build();
}
return builder.setUtf8Value(value)
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rex.RexNode;

import static program.physical.rel.PhysicalPlanTransformUtil.transformRexNodeToExprNode;
import static program.physical.rel.ExecutionPlanToDataFusionPlanUtil.transformRexNodeToExprNode;

public class FilterExecutionPlan
extends EnumerableFilter
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package program.physical.rel;

import com.ccsu.error.CommonException;
import org.apache.calcite.adapter.enumerable.EnumerableHashJoin;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptUtil;
Expand All @@ -10,20 +9,16 @@
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import program.util.RexUtils;
import proto.datafusion.HashJoinExecNode;
import proto.datafusion.JoinFilter;
import proto.datafusion.PhysicalExprNode;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;

import static com.ccsu.error.CommonErrorCode.PLAN_TRANSFORM_ERROR;
import static program.physical.rel.PhysicalPlanTransformUtil.*;
import static program.physical.rel.ExecutionPlanToDataFusionPlanUtil.*;

public class HashJoinExecutionPlan extends EnumerableHashJoin implements ExecutionPlan {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

import java.util.Set;

import static program.physical.rel.PhysicalPlanTransformUtil.transformJoinOn;
import static program.physical.rel.PhysicalPlanTransformUtil.transformJoinType;
import static program.physical.rel.ExecutionPlanToDataFusionPlanUtil.transformJoinOn;
import static program.physical.rel.ExecutionPlanToDataFusionPlanUtil.transformJoinType;

public class MergeJoinExecutionPlan
extends EnumerableMergeJoin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

import java.util.Set;

import static program.physical.rel.PhysicalPlanTransformUtil.transformJoinType;
import static program.physical.rel.PhysicalPlanTransformUtil.transformRexNodeToJoinFilter;
import static program.physical.rel.ExecutionPlanToDataFusionPlanUtil.transformJoinType;
import static program.physical.rel.ExecutionPlanToDataFusionPlanUtil.transformRexNodeToJoinFilter;

public class NestedLoopJoinExecutionPlan
extends EnumerableNestedLoopJoin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public proto.datafusion.PhysicalPlanNode transformToDataFusionNode() {
builder.setInput(input);

for (RexNode project : getProjects()) {
builder.addExpr(PhysicalPlanTransformUtil.transformRexNodeToExprNode(project));
builder.addExpr(ExecutionPlanToDataFusionPlanUtil.transformRexNodeToExprNode(project));
}

for (RelDataTypeField relDataTypeField : rowType.getFieldList()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import proto.datafusion.PhysicalExprNode;
import proto.datafusion.PhysicalSortExprNode;

import static program.physical.rel.PhysicalPlanTransformUtil.transformRexNodeToExprNode;
import static program.physical.rel.ExecutionPlanToDataFusionPlanUtil.transformRexNodeToExprNode;

public class SortExecutionPlan
extends EnumerableSort
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import java.util.List;

import static com.ccsu.pojo.DatasourceType.transformToProtoSourceType;
import static program.physical.rel.PhysicalPlanTransformUtil.buildRelNodeSchema;
import static program.physical.rel.ExecutionPlanToDataFusionPlanUtil.buildRelNodeSchema;

public class SourceScanExecutionPlan
extends EnumerableTableScan
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ public PhysicalPlanNode transformToDataFusionNode() {
throw new CommonException(CommonErrorCode.PLAN_TRANSFORM_ERROR, "Not support value tuples size > 1");
}
for (RexLiteral rexLiteral : tuples.get(0)) {
ScalarValue scalarValue = PhysicalPlanTransformUtil.transformLiteral(rexLiteral);
ScalarValue scalarValue = ExecutionPlanToDataFusionPlanUtil.transformLiteral(rexLiteral);
builder.addExpr(proto.datafusion.PhysicalExprNode.newBuilder()
.setLiteral(scalarValue));
}
Expand Down

0 comments on commit 56f47f2

Please sign in to comment.