Skip to content

Commit

Permalink
Merge pull request apache#33 from heliqi/paddle
Browse files Browse the repository at this point in the history
Add the clip, where, strided_slice, add_n, and argsort ops
  • Loading branch information
jiangjiajun committed Sep 14, 2021
2 parents 00f837d + 9647a89 commit 9c536e6
Show file tree
Hide file tree
Showing 2 changed files with 266 additions and 75 deletions.
182 changes: 150 additions & 32 deletions python/tvm/relay/frontend/paddlepaddle.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,16 @@ def convert_addmm(g, op, block):
g.add_node(op.output("Out")[0], out)


def convert_addn(g, op, block):
    """Operator converter for sum (Paddle's add_n).

    Folds every tensor listed under input "X" into a single Relay
    elementwise-add chain and registers it as the op's output.
    """

    input_names = op.input("X")
    # Seed with the first operand, then accumulate the rest.
    result = g.get_node(input_names[0])
    for extra_name in input_names[1:]:
        result = result + g.get_node(extra_name)
    g.add_node(op.output("Out")[0], result)


def convert_arg_max(g, op, block):
"""Operator converter for arg_max."""

Expand Down Expand Up @@ -180,6 +190,16 @@ def convert_arg_min(g, op, block):
g.add_node(op.output("Out")[0], out)


def convert_argsort(g, op, block):
    """Operator converter for argsort.

    Maps Paddle's ``argsort`` to Relay's ``argsort``. Paddle's
    ``descending`` attribute is the inverse of Relay's ``is_ascend``
    flag, hence the ``not descending`` below. Only the "Indices"
    output is produced here.
    """

    # Fix: the accessor is `op.input(...)` — `op.inputs` does not exist
    # and every other converter in this file uses `op.input`.
    x = g.get_node(op.input("X")[0])
    axis = op.attr("axis")
    descending = op.attr("descending")
    out = _op.argsort(x, axis, not descending, dtype="int64")
    g.add_node(op.output("Indices")[0], out)


def convert_assign(g, op, block):
"""Operator converter for assign."""

Expand Down Expand Up @@ -339,6 +359,45 @@ def convert_cast(g, op, block):
g.add_node(op.output("Out")[0], out)


def convert_clip(g, op, block):
    """Operator converter for clip.

    The min/max bounds may come either from tensor inputs ("Min"/"Max",
    which take precedence) or from the scalar attributes. If either
    bound cannot be folded to a constant, the clip is lowered to a
    maximum/minimum pair on dynamic expressions; otherwise the fused
    `_op.clip` is emitted with scalar bounds.
    """

    x = g.get_node(op.input("X")[0])
    dtype = infer_type(x).checked_type.dtype

    # Resolve the lower bound: tensor input wins over the attribute.
    if op.input("Min"):
        min_value = _infer_value(g.get_node(op.input("Min")[0]), g.get_params())
        if not isinstance(min_value, _expr.Expr):
            # Constant-folded to a numpy array; take the scalar.
            min_value = min_value[0]
    else:
        min_value = op.attr("min")

    # Resolve the upper bound the same way.
    if op.input("Max"):
        max_value = _infer_value(g.get_node(op.input("Max")[0]), g.get_params())
        if not isinstance(max_value, _expr.Expr):
            max_value = max_value[0]
    else:
        max_value = op.attr("max")

    is_dynamic = isinstance(min_value, _expr.Expr) or isinstance(max_value, _expr.Expr)
    if is_dynamic:
        # At least one bound is a runtime expression: promote any
        # scalar bound to a const and clamp with maximum/minimum.
        if not isinstance(min_value, _expr.Expr):
            min_value = _op.const(min_value, dtype)
        if not isinstance(max_value, _expr.Expr):
            max_value = _op.const(max_value, dtype)
        out = _op.minimum(_op.maximum(x, min_value), max_value)
    else:
        out = _op.clip(x, min_value, max_value)
    g.add_node(op.output("Out")[0], out)


def convert_concat(g, op, block):
"""Operator converter for concat."""

Expand Down Expand Up @@ -1418,60 +1477,104 @@ def convert_slice(g, op, block):

data = g.get_node(op.input("Input")[0])
dims = len(infer_shape(data))
dtype = "int64"

axes = op.attr("axes")
axes = _op.const(axes)
indices = _expr.const(axes, dtype="int64")

decrease_axis = op.attr("decrease_axis")
if isinstance(decrease_axis, int):
decrease_axis = [decrease_axis]

starts = op.input("StartsTensor")
if starts:
starts = g.get_node(starts[0])
if op.input("StartsTensor"):
starts = g.get_node(op.input("StartsTensor")[0])
starts = _infer_value(starts, g.get_params())
elif op.input("StartsTensorList"):
starts = []
for start_index in op.input("StartsTensorList"):
start_index = g.get_node(start_index)
if not isinstance(start_index, _expr.Expr):
start_index = _expr.const(start_index, dtype=dtype)
else:
start_index = start_index.astype(dtype)
start_index = g.get_node(start_index).astype("int64")
starts.append(start_index)
starts = _op.concatenate(starts, axis=0)
starts = _infer_value(starts, g.get_params())
else:
starts = op.attr("starts")
starts = _expr.const(starts)
start_dtype = infer_type(starts).checked_type.dtype
if isinstance(starts, _expr.Expr):
starts = _op.scatter(
_op.const([0] * dims, dtype=start_dtype),
axes,
starts,
axis=0,
)

ends = op.input("EndsTensor")
if ends:
ends = g.get_node(ends[0])
if len(axes) < dims:
if isinstance(starts, _expr.Expr):
starts = _op.scatter(
_op.const([0] * dims, dtype=infer_type(starts).checked_type.dtype),
indices,
starts,
axis=0,
)
else:
base = [0] * dims
for i, axis in enumerate(axes):
base[axis] = starts[i]
starts = base

if op.input("EndsTensor"):
ends = g.get_node(op.input("EndsTensor")[0])
ends = _infer_value(ends, g.get_params())
elif op.input("EndsTensorList"):
ends = []
for end_index in op.input("EndsTensorList"):
end_index = g.get_node(end_index)
if not isinstance(end_index, _expr.Expr):
end_index = _expr.const(end_index, dtype=dtype)
else:
end_index = end_index.astype(dtype)
end_index = g.get_node(end_index).astype("int64")
ends.append(end_index)
ends = _op.concatenate(ends, axis=0)
ends = _infer_value(ends, g.get_params())
else:
ends = op.attr("ends")
ends = _expr.const(ends)
if isinstance(ends, _expr.Expr):
data_shape = shape_of(data, infer_type(ends).checked_type.dtype)
ends = _op.scatter(data_shape, axes, ends, axis=0)

strides = _op.const([1] * dims, dtype=start_dtype)
if len(axes) < dims:
if isinstance(ends, _expr.Expr):
ends = _op.scatter(
_expr.const(
np.array([np.iinfo(np.int32).max] * dims),
dtype=infer_type(ends).checked_type.dtype,
),
indices,
ends,
axis=0,
)
else:
base = [np.iinfo(np.int32).max] * dims
for i, axis in enumerate(axes):
base[axis] = ends[i]
ends = base

strides = None
if "StridesTensor" in op.input_names and op.input("StridesTensor"):
strides = g.get_node(op.input("StridesTensor")[0])
strides = _infer_value(strides, g.get_params())
elif "StridesTensorList" in op.input_names and op.input("StridesTensorList"):
strides = []
for strides_index in op.input("StridesTensorList"):
strides_index = g.get_node(strides_index).astype("int64")
strides.append(strides_index)
strides = _op.concatenate(strides, axis=0)
strides = _infer_value(strides, g.get_params())
elif op.has_attr("strides"):
strides = op.attr("strides")

if len(axes) < dims:
if isinstance(strides, _expr.Expr):
strides = _op.scatter(
_expr.const(
np.array([1] * dims),
dtype=infer_type(strides).checked_type.dtype,
),
indices,
strides,
axis=0,
)
elif strides:
base = [1] * dims
for i, axis in enumerate(axes):
base[axis] = strides[i]
strides = base
if not strides:
strides = _op.const([1] * dims, dtype="int64")

out = _op.strided_slice(data, begin=starts, end=ends, strides=strides)
if decrease_axis:
out = _op.squeeze(out, axis=decrease_axis)
Expand Down Expand Up @@ -1623,12 +1726,23 @@ def convert_unsqueeze(g, op, block):
g.add_node(op.output("Out")[0], x)


def convert_where(g, op, block):
    """Operator converter for where.

    Elementwise select: picks from "X" where "Condition" is true,
    otherwise from "Y".
    """

    cond = g.get_node(op.input("Condition")[0])
    true_branch = g.get_node(op.input("X")[0])
    false_branch = g.get_node(op.input("Y")[0])
    g.add_node(op.output("Out")[0], _op.where(cond, true_branch, false_branch))


_convert_map = {
"abs": convert_unary_op,
"acos": convert_unary_op,
"addmm": convert_addmm,
"arg_max": convert_arg_max,
"arg_min": convert_arg_min,
"argsort": convert_argsort,
"asin": convert_unary_op,
"assign": convert_assign,
"assign_value": convert_assign_value,
Expand All @@ -1639,6 +1753,7 @@ def convert_unsqueeze(g, op, block):
"bmm": convert_bmm,
"cast": convert_cast,
"ceil": convert_unary_op,
"clip": convert_clip,
"concat": convert_concat,
"conv2d": convert_conv2d,
"conv2d_transpose": convert_conv2d_transpose,
Expand Down Expand Up @@ -1728,12 +1843,15 @@ def convert_unsqueeze(g, op, block):
"square": convert_square,
"squeeze2": convert_squeeze,
"stack": convert_stack,
"strided_slice": convert_slice,
"sum": convert_addn,
"tan": convert_unary_op,
"tanh": convert_unary_op,
"top_k_v2": convert_topk,
"tile": convert_tile,
"transpose2": convert_transpose,
"unsqueeze2": convert_unsqueeze,
"where": convert_where,
"where_index": convert_nonzero,
}

Expand Down
Loading

0 comments on commit 9c536e6

Please sign in to comment.