From 2a76d98fc66ac20a264af373a9f4e5feaaa5b28f Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 10 Jan 2024 15:02:07 +0800 Subject: [PATCH] [Fea] Support derivative nodes fusing for any order (#745) * support parallel grad to optimize grad computation with reducing common pathes in backward graph * fix * clean code * update grad merge code for any common prefix numerator * clean code * update final code * refine docstrings and variable & function name * add graph_filename * change heat_exchanger to sympy equation --- docs/zh/examples/biharmonic2d.md | 26 +- docs/zh/examples/bracket.md | 15 +- examples/biharmonic2d/biharmonic2d.py | 3 +- ppsci/autodiff/ad.py | 42 +- ppsci/equation/pde/heat_exchanger.py | 47 +-- ppsci/equation/pde/linear_elasticity.py | 9 +- ppsci/solver/solver.py | 26 +- ppsci/utils/symbolic.py | 529 ++++++++++++++++++------ 8 files changed, 505 insertions(+), 192 deletions(-) diff --git a/docs/zh/examples/biharmonic2d.md b/docs/zh/examples/biharmonic2d.md index 86c5de74b..67489c3b0 100644 --- a/docs/zh/examples/biharmonic2d.md +++ b/docs/zh/examples/biharmonic2d.md @@ -114,7 +114,7 @@ examples/biharmonic2d/biharmonic2d.py:93:95 ``` py linenums="97" --8<-- -examples/biharmonic2d/biharmonic2d.py:97:107 +examples/biharmonic2d/biharmonic2d.py:97:108 --8<-- ``` @@ -122,9 +122,9 @@ examples/biharmonic2d/biharmonic2d.py:97:107 以作用在背板内部点的 `InteriorConstraint` 为例,代码如下: -``` py linenums="206" +``` py linenums="207" --8<-- -examples/biharmonic2d/biharmonic2d.py:206:215 +examples/biharmonic2d/biharmonic2d.py:207:216 --8<-- ``` @@ -160,17 +160,17 @@ examples/biharmonic2d/conf/biharmonic2d.yaml:60:62 如 [2 问题定义](#2) 中所述,$x=0$ 处的挠度 $w$ 为 0,有如下边界条件,其他 7 个边界条件也与之类似: -``` py linenums="110" +``` py linenums="111" --8<-- -examples/biharmonic2d/biharmonic2d.py:110:119 +examples/biharmonic2d/biharmonic2d.py:111:120 --8<-- ``` 在方程约束、边界约束构建完毕之后,以刚才的命名为关键字,封装到一个字典中,方便后续访问。 -``` py linenums="216" +``` py linenums="217" --8<-- -examples/biharmonic2d/biharmonic2d.py:216:227 +examples/biharmonic2d/biharmonic2d.py:217:228 --8<-- ``` @@ -178,9 +178,9 @@ examples/biharmonic2d/biharmonic2d.py:216:227 训练过程会调用优化器来更新模型参数,此处选择使用 `Adam` 先进行少量训练后,再使用 `LBFGS` 优化器精调。 -``` py linenums="80" +``` py linenums="81" --8<-- -examples/biharmonic2d/biharmonic2d.py:80:82 +examples/biharmonic2d/biharmonic2d.py:81:83 --8<-- ``` @@ -198,9 +198,9 @@ examples/biharmonic2d/conf/biharmonic2d.yaml:46:56 完成上述设置之后,只需要将上述实例化的对象按顺序传递给 `ppsci.solver.Solver`,然后启动训练,注意两个优化过程需要分别构建 `Solver`。 -``` py linenums="229" +``` py linenums="230" --8<-- -examples/biharmonic2d/biharmonic2d.py:229:268 +examples/biharmonic2d/biharmonic2d.py:230:269 --8<-- ``` @@ -208,9 +208,9 @@ examples/biharmonic2d/biharmonic2d.py:229:268 训练完成后,可以在 `eval` 模式中对训练好的模型进行评估和可视化。由于案例的特殊性,不需构建评估器和可视化器,而是使用自定义代码。 -``` py linenums="271" +``` py linenums="272" --8<-- -examples/biharmonic2d/biharmonic2d.py:271:351 +examples/biharmonic2d/biharmonic2d.py:272:352 --8<-- ``` diff --git a/docs/zh/examples/bracket.md b/docs/zh/examples/bracket.md index cfffc6552..14c6bba14 100644 --- a/docs/zh/examples/bracket.md +++ b/docs/zh/examples/bracket.md @@ -85,16 +85,11 @@ examples/bracket/bracket.py:15:19 Bracket 案例涉及到以下线弹性方程,使用 PaddleScience 内置的 `LinearElasticity` 即可。 -$$ -\begin{cases} - stress\_disp_{xx} = \lambda(\dfrac{\partial u}{\partial x} + \dfrac{\partial v}{\partial y} + \dfrac{\partial w}{\partial z}) + 2\mu \dfrac{\partial u}{\partial x} - \sigma_{xx} \\ - stress\_disp_{yy} = \lambda(\dfrac{\partial u}{\partial x} + \dfrac{\partial v}{\partial y} + \dfrac{\partial w}{\partial z}) + 2\mu \dfrac{\partial v}{\partial y} - \sigma_{yy} \\ - stress\_disp_{zz} = \lambda(\dfrac{\partial u}{\partial x} + \dfrac{\partial v}{\partial y} + \dfrac{\partial w}{\partial z}) + 2\mu \dfrac{\partial w}{\partial z} - \sigma_{zz} \\ - traction_{x} = n_x \sigma_{xx} + n_y \sigma_{xy} + n_z \sigma_{xz} \\ - traction_{y} = n_y \sigma_{yx} + n_y \sigma_{yy} + n_z \sigma_{yz} \\ - traction_{z} = n_z \sigma_{zx} + n_y \sigma_{zy} + n_z \sigma_{zz} \\ -\end{cases} -$$ +--8<-- +ppsci/equation/pde/linear_elasticity.py:30:42 +--8<-- + +对应的方程实例化代码如下: ``` py linenums="32" --8<-- diff --git a/examples/biharmonic2d/biharmonic2d.py b/examples/biharmonic2d/biharmonic2d.py index 3c80f9b32..3cccfd673 100644 --- a/examples/biharmonic2d/biharmonic2d.py +++ b/examples/biharmonic2d/biharmonic2d.py @@ -103,7 +103,8 @@ def train(cfg: DictConfig): "drop_last": True, "shuffle": True, }, - "num_workers": 1, + "num_workers": 0, + "auto_collation": False, } # set constraint diff --git a/ppsci/autodiff/ad.py b/ppsci/autodiff/ad.py index 14ce6b926..c31274433 100644 --- a/ppsci/autodiff/ad.py +++ b/ppsci/autodiff/ad.py @@ -19,7 +19,9 @@ from __future__ import annotations from typing import Dict +from typing import List from typing import Optional +from typing import Union import paddle @@ -36,14 +38,19 @@ class _Jacobian: xs (paddle.Tensor): Input Tensor of shape [batch_size, dim_x]. """ - def __init__(self, ys: "paddle.Tensor", xs: "paddle.Tensor"): + def __init__( + self, + ys: "paddle.Tensor", + xs: "paddle.Tensor", + J: Optional[Dict[int, paddle.Tensor]] = None, + ): self.ys = ys self.xs = xs self.dim_y = ys.shape[1] self.dim_x = xs.shape[1] - self.J: Dict[str, paddle.Tensor] = {} + self.J: Dict[int, paddle.Tensor] = {} if J is None else J def __call__( self, @@ -87,12 +94,12 @@ def __init__(self): def __call__( self, ys: "paddle.Tensor", - xs: "paddle.Tensor", + xs: Union["paddle.Tensor", List["paddle.Tensor"]], i: int = 0, j: Optional[int] = None, retain_graph: Optional[bool] = None, create_graph: bool = True, - ) -> "paddle.Tensor": + ) -> Union["paddle.Tensor", List["paddle.Tensor"]]: """Compute jacobians for given ys and xs. Args: @@ -121,10 +128,29 @@ def __call__( >>> y = x * x >>> dy_dx = ppsci.autodiff.jacobian(y, x) """ - key = (ys, xs) - if key not in self.Js: - self.Js[key] = _Jacobian(ys, xs) - return self.Js[key](i, j, retain_graph, create_graph) + if not isinstance(xs, (list, tuple)): + key = (ys, xs) + if key not in self.Js: + self.Js[key] = _Jacobian(ys, xs) + return self.Js[key](i, j, retain_graph, create_graph) + else: + grads = paddle.grad( + ys, + xs, + create_graph=create_graph, + retain_graph=retain_graph, + ) + Js_list = [] + for k, xs_ in enumerate(xs): + key = (ys, xs_) + assert xs_.shape[-1] == 1, ( + f"The last dim of each xs should be 1, but xs[{k}] has shape " + f"{xs_.shape}" + ) + if key not in self.Js: + self.Js[key] = _Jacobian(ys, xs_, {0: grads[k]}) + Js_list.append(self.Js[key](i, j, retain_graph, create_graph)) + return Js_list def _clear(self): """Clear cached Jacobians.""" diff --git a/ppsci/equation/pde/heat_exchanger.py b/ppsci/equation/pde/heat_exchanger.py index 959b0ad41..b107419aa 100644 --- a/ppsci/equation/pde/heat_exchanger.py +++ b/ppsci/equation/pde/heat_exchanger.py @@ -16,7 +16,6 @@ from typing import Union -from ppsci.autodiff import jacobian from ppsci.equation.pde import base @@ -69,37 +68,25 @@ def __init__( w_c: Union[float, str], ): super().__init__() + x, t, qm_h, qm_c, qm_w = self.create_symbols("x t qm_h qm_c qm_w") - def heat_boundary_fun(out): - x, t, qm_h = out["x"], out["t"], out["qm_h"] - T_h, T_w = out["T_h"], out["T_w"] - T_h_x = jacobian(T_h, x) - T_h_t = jacobian(T_h, t) + T_h = self.create_function("T_h", (x, t, qm_h)) + T_c = self.create_function("T_c", (x, t, qm_c)) + T_w = self.create_function("T_w", (x, t, qm_w)) - beta_h = (alpha_h * v_h) / qm_h - heat_boundary = T_h_t + v_h * T_h_x - beta_h * (T_w - T_h) - return heat_boundary + T_h_x = T_h.diff(x) + T_h_t = T_h.diff(t) + T_c_x = T_c.diff(x) + T_c_t = T_c.diff(t) + T_w_t = T_w.diff(t) - self.add_equation("heat_boundary", heat_boundary_fun) + beta_h = (alpha_h * v_h) / qm_h + beta_c = (alpha_c * v_c) / qm_c - def cold_boundary_fun(out): - x, t, qm_c = out["x"], out["t"], out["qm_c"] - T_c, T_w = out["T_c"], out["T_w"] - T_c_x = jacobian(T_c, x) - T_c_t = jacobian(T_c, t) + heat_boundary = T_h_t + v_h * T_h_x - beta_h * (T_w - T_h) + cold_boundary = T_c_t - v_c * T_c_x - beta_c * (T_w - T_c) + wall = T_w_t - w_h * (T_h - T_w) - w_c * (T_c - T_w) - beta_c = (alpha_c * v_c) / qm_c - cold_boundary = T_c_t - v_c * T_c_x - beta_c * (T_w - T_c) - return cold_boundary - - self.add_equation("cold_boundary", cold_boundary_fun) - - def wall_fun(out): - t = out["t"] - T_c, T_h, T_w = out["T_c"], out["T_h"], out["T_w"] - T_w_t = jacobian(T_w, t) - - wall = T_w_t - w_h * (T_h - T_w) - w_c * (T_c - T_w) - return wall - - self.add_equation("wall", wall_fun) + self.add_equation("heat_boundary", heat_boundary) + self.add_equation("cold_boundary", cold_boundary) + self.add_equation("wall", wall) diff --git a/ppsci/equation/pde/linear_elasticity.py b/ppsci/equation/pde/linear_elasticity.py index a788fa49c..9120c6d21 100644 --- a/ppsci/equation/pde/linear_elasticity.py +++ b/ppsci/equation/pde/linear_elasticity.py @@ -32,9 +32,12 @@ class LinearElasticity(base.PDE): stress\_disp_{xx} = \lambda(\dfrac{\partial u}{\partial x} + \dfrac{\partial v}{\partial y} + \dfrac{\partial w}{\partial z}) + 2\mu \dfrac{\partial u}{\partial x} - \sigma_{xx} \\ stress\_disp_{yy} = \lambda(\dfrac{\partial u}{\partial x} + \dfrac{\partial v}{\partial y} + \dfrac{\partial w}{\partial z}) + 2\mu \dfrac{\partial v}{\partial y} - \sigma_{yy} \\ stress\_disp_{zz} = \lambda(\dfrac{\partial u}{\partial x} + \dfrac{\partial v}{\partial y} + \dfrac{\partial w}{\partial z}) + 2\mu \dfrac{\partial w}{\partial z} - \sigma_{zz} \\ - traction_{x} = n_x \sigma_{xx} + n_y \sigma_{xy} + n_z \sigma_{xz} \\ - traction_{y} = n_y \sigma_{yx} + n_y \sigma_{yy} + n_z \sigma_{yz} \\ - traction_{z} = n_z \sigma_{zx} + n_y \sigma_{zy} + n_z \sigma_{zz} \\ + stress\_disp_{xy} = \mu(\dfrac{\partial u}{\partial y} + \dfrac{\partial v}{\partial x}) - \sigma_{xy} \\ + stress\_disp_{xz} = \mu(\dfrac{\partial u}{\partial z} + \dfrac{\partial w}{\partial x}) - \sigma_{xz} \\ + stress\_disp_{yz} = \mu(\dfrac{\partial v}{\partial z} + \dfrac{\partial w}{\partial y}) - \sigma_{yz} \\ + equilibrium_{x} = \rho \dfrac{\partial^2 u}{\partial t^2} - (\dfrac{\partial \sigma_{xx}}{\partial x} + \dfrac{\partial \sigma_{xy}}{\partial y} + \dfrac{\partial \sigma_{xz}}{\partial z}) \\ + equilibrium_{y} = \rho \dfrac{\partial^2 u}{\partial t^2} - (\dfrac{\partial \sigma_{xy}}{\partial x} + \dfrac{\partial \sigma_{yy}}{\partial y} + \dfrac{\partial \sigma_{yz}}{\partial z}) \\ + equilibrium_{z} = \rho \dfrac{\partial^2 u}{\partial t^2} - (\dfrac{\partial \sigma_{xz}}{\partial x} + \dfrac{\partial \sigma_{yz}}{\partial y} + \dfrac{\partial \sigma_{zz}}{\partial z}) \\ \end{cases} $$ diff --git a/ppsci/solver/solver.py b/ppsci/solver/solver.py index 3f3fbcd29..d73feb406 100644 --- a/ppsci/solver/solver.py +++ b/ppsci/solver/solver.py @@ -367,14 +367,24 @@ def convert_expr( ] ) -> None: for container in container_dict.values(): - for name, expr in container.output_expr.items(): - if isinstance(expr, sp.Basic): - container.output_expr[name] = ppsci.lambdify( - expr, - self.model, - extra_parameters, - # osp.join(self.output_dir, "symbolic_graph_visual", container.name, name), # HACK: Activate it for DEBUG. - ) + exprs = [ + expr + for expr in container.output_expr.values() + if isinstance(expr, sp.Basic) + ] + if len(exprs) > 0: + funcs = ppsci.lambdify( + exprs, + self.model, + extra_parameters=extra_parameters, + fuse_derivative=True, + # graph_filename=osp.join(self.output_dir, "symbolic_graph_visual") # HACK: Activate it for DEBUG. + ) + ind = 0 + for name in container.output_expr: + if isinstance(container.output_expr[name], sp.Basic): + container.output_expr[name] = funcs[ind] + ind += 1 if self.constraint: convert_expr(self.constraint) diff --git a/ppsci/utils/symbolic.py b/ppsci/utils/symbolic.py index f11b06a5d..88b9e38ea 100644 --- a/ppsci/utils/symbolic.py +++ b/ppsci/utils/symbolic.py @@ -36,6 +36,7 @@ from ppsci import equation from ppsci.autodiff import hessian from ppsci.autodiff import jacobian +from ppsci.utils import logger __all__ = [ "lambdify", @@ -106,6 +107,19 @@ } +def _numerator_of_derivative(expr: sp.Basic) -> sp.Basic: + if not isinstance(expr, sp.Derivative): + raise TypeError( + f"expr({expr}) should be of type sp.Derivative, but got {type(expr)}" + ) + if len(expr.args) <= 2: + if expr.args[1][1] == 1: + return expr.args[0] + return sp.Derivative(expr.args[0], (expr.args[1][0], expr.args[1][1] - 1)) + else: + return sp.Derivative(*expr.args[:-1]) + + def _cvt_to_key(expr: sp.Basic) -> str: """Convert sympy expression to a string key, mainly as retrieval key in dict. @@ -182,41 +196,21 @@ class OperatorNode(Node): Args: expr (SYMPY_BUILTIN_FUNC): Sympy expression. - create_graph (bool, optional): Whether to create the gradient graphs of - the computing process. When it is True, higher order derivatives are - supported to compute; when it is False, the gradient graphs of the - computing process would be discarded. Defaults to True. - retain_graph (Optional[bool]): Whether to retain the forward graph which - is used to calculate the gradient. When it is True, the graph would - be retained, in which way users can calculate backward twice for the - same graph. When it is False, the graph would be freed. Defaults to None, - which means it is equal to `create_graph`. """ def __init__( self, expr: SYMPY_BUILTIN_FUNC, - create_graph: bool = True, - retain_graph: Optional[bool] = None, ): super().__init__(expr) # preprocess children's key instead of processing at run-time in forward # which can reduce considerable overhead of time for calling "_cvt_to_key" - if self.expr.func == sp.Derivative: - self.childs = [_cvt_to_key(self.expr.args[0])] + [ - (_cvt_to_key(arg), int(order)) for (arg, order) in self.expr.args[1:] - ] - self.create_graph = create_graph - self.retain_graph = retain_graph - else: - self.childs = [_cvt_to_key(arg) for arg in self.expr.args] + self.childs = [_cvt_to_key(arg) for arg in self.expr.args] if self.expr.func == sp.Add: self._apply_func = self._add_operator_func elif self.expr.func == sp.Mul: self._apply_func = self._mul_operator_func - elif self.expr.func == sp.Derivative: - self._apply_func = self._derivate_operator_func elif self.expr.func == sp.Heaviside: self._apply_func = self._heaviside_operator_func self._auxiliary_func = SYMPY_TO_PADDLE[sp.Heaviside] @@ -247,31 +241,6 @@ def _mul_operator_func(self, data_dict: DATA_DICT) -> DATA_DICT: data_dict[self.key] *= data_dict[child] return data_dict - def _derivate_operator_func(self, data_dict: DATA_DICT) -> DATA_DICT: - # NOTE: Derivative of 'sdf' function will not be executed here, which is already - # generated in 'data_dict' during points sampling using discrete difference - # method(see also: ppsci/geometry/geometry.py: Geometry.sdf_derivatives), - # such as 'sdf__x', 'sdf__y'. - data_dict[self.key] = data_dict[self.childs[0]] - for child, order in self.childs[1:]: - if order & 1: - data_dict[self.key] = jacobian( - data_dict[self.key], - data_dict[child], - create_graph=self.create_graph, - retain_graph=self.retain_graph, - ) - order -= 1 - for _ in range(0, order, 2): - data_dict[self.key] = hessian( - data_dict[self.key], - data_dict[child], - create_graph=self.create_graph, - retain_graph=self.retain_graph, - ) - order -= 2 - return data_dict - def _heaviside_operator_func(self, data_dict: DATA_DICT) -> DATA_DICT: data_dict[self.key] = self._auxiliary_func(data_dict[self.childs[0]]) return data_dict @@ -305,6 +274,141 @@ def _vanilla_operator_func(self, data_dict: DATA_DICT) -> DATA_DICT: return data_dict +class DerivativeNode(Node): + """Class for operator node in converted expression tree. + + Args: + expr (sp.Derivative): Sympy derivative expression. + create_graph (bool, optional): Whether to create the gradient graphs of + the computing process. When it is True, higher order derivatives are + supported to compute; when it is False, the gradient graphs of the + computing process would be discarded. Defaults to True. + retain_graph (Optional[bool]): Whether to retain the forward graph which + is used to calculate the gradient. When it is True, the graph would + be retained, in which way users can calculate backward twice for the + same graph. When it is False, the graph would be freed. Defaults to None, + which means it is equal to `create_graph`. + """ + + def __init__( + self, + expr: sp.Derivative, + create_graph: bool = True, + retain_graph: Optional[bool] = None, + ): + super().__init__(expr) + # preprocess children's key instead of processing at run-time in forward + # which can reduce considerable overhead of time for calling "_cvt_to_key" + self.childs = [_cvt_to_key(self.expr.args[0])] + [ + (_cvt_to_key(arg), int(order)) for (arg, order) in self.expr.args[1:] + ] + self.create_graph = create_graph + self.retain_graph = retain_graph + self._apply_func = self._derivate_operator_func + + def forward(self, data_dict: DATA_DICT): + # use cache + if self.key in data_dict: + return data_dict + + return self._apply_func(data_dict) + + def _derivate_operator_func(self, data_dict: DATA_DICT) -> DATA_DICT: + # NOTE: Derivative of 'sdf' function will not be executed here, which is already + # generated in 'data_dict' during points sampling using discrete difference + # method(see also: ppsci/geometry/geometry.py: Geometry.sdf_derivatives), + # such as 'sdf__x', 'sdf__y'. + data_dict[self.key] = data_dict[self.childs[0]] + for child, order in self.childs[1:]: + if order & 1: + data_dict[self.key] = jacobian( + data_dict[self.key], + data_dict[child], + create_graph=self.create_graph, + retain_graph=self.retain_graph, + ) + order -= 1 + for _ in range(0, order, 2): + data_dict[self.key] = hessian( + data_dict[self.key], + data_dict[child], + create_graph=self.create_graph, + retain_graph=self.retain_graph, + ) + order -= 2 + return data_dict + + +class FusedDerivativeNode(nn.Layer): + """Class for fused DerivativeNode. + + Args: + f_x_tuples (List[Tuple[Union[sp.Function, sp.Derivative], sp.Symbol]]): + indicate all derivatives of a function in list of tuples. e.g. + [(func1, var1), (func1, var2), (func1, var3), ...]. + create_graph (bool, optional): Whether to create the gradient graphs of + the computing process. When it is True, higher order derivatives are + supported to compute; when it is False, the gradient graphs of the + computing process would be discarded. Defaults to True. + retain_graph (Optional[bool]): Whether to retain the forward graph which + is used to calculate the gradient. When it is True, the graph would + be retained, in which way users can calculate backward twice for the + same graph. When it is False, the graph would be freed. Defaults to None, + which means it is equal to `create_graph`. + """ + + def __init__( + self, + f_x_tuples: List[Tuple[Union[sp.Function, sp.Derivative], sp.Symbol]], + create_graph: bool = True, + retain_graph: Optional[bool] = None, + ): + super().__init__() + self.expr: List[sp.Derivative] = [f.diff(x) for f, x in f_x_tuples] + self.key: List[str] = [_cvt_to_key(expr) for expr in self.expr] + self.create_graph = create_graph + self.retain_graph = retain_graph + + # preprocess children's key instead of processing at run-time in forward + # which can reduce considerable overhead of time for calling "_cvt_to_key" + self.y_key: str = _cvt_to_key(f_x_tuples[0][0]) + self.childs: List[str] = [_cvt_to_key(x) for _, x in f_x_tuples] + self._apply_func = self._parallel_derivate_operator_func + + def forward(self, data_dict: DATA_DICT): + # use cache + if all([key in data_dict for key in self.key]): + return data_dict + + return self._apply_func(data_dict) + + def _parallel_derivate_operator_func(self, data_dict: DATA_DICT) -> DATA_DICT: + # NOTE: Derivative of 'sdf' function will not be executed here, which is already + # generated in 'data_dict' during points sampling using discrete difference + # method(see also: ppsci/geometry/geometry.py: Geometry.sdf_derivatives), + # such as 'sdf__x', 'sdf__y'. + y_data: paddle.Tensor = data_dict[self.y_key] + xs_data: List[paddle.Tensor] = [data_dict[x_key] for x_key in self.childs] + y_wrt_xs_grad: List[paddle.Tensor] = jacobian( + y_data, + xs_data, + create_graph=self.create_graph, + retain_graph=self.retain_graph, + ) + for i, key in enumerate(self.key): + data_dict[key] = y_wrt_xs_grad[i] + return data_dict + + def __str__(self): + return ( + f"{self.__class__.__name__}(expr: {self.expr}, " + f"expr_type: {type(self.expr)})" + ) + + def __repr__(self): + return f"{self.__class__.__name__}(expr: {self.expr})" + + class LayerNode(Node): """Class for layer node in converted expression tree. @@ -363,6 +467,12 @@ def forward(self, data_dict: DATA_DICT) -> DATA_DICT: data_dict[self.key] = self.expr return data_dict + def __str__(self): + return ( + f"{self.__class__.__name__}(expr: {float(self.expr)}, " + f"expr_type: {type(self.expr)})" + ) + class ParameterNode(Node): """Class for constant variable node in converted expression tree. @@ -388,11 +498,12 @@ class ComposedNode(nn.Layer): def __init__(self, callable_nodes: List[Node]): super().__init__() + assert len(callable_nodes) self.callable_nodes = callable_nodes def forward(self, data_dict: DATA_DICT) -> paddle.Tensor: # call all callable_nodes in order - for func in self.callable_nodes: + for i, func in enumerate(self.callable_nodes): data_dict = func(data_dict) # return result of last node(root node) for target @@ -510,8 +621,6 @@ def add_edge(u: str, v: str, u_color: str = C_DATA, v_color: str = C_DATA): add_edge(_cvt_to_key(arg[0]), str(node), v_color=C_FUNC) # export graph to image - from ppsci.utils import logger - graph.layout() image_path = f"{graph_filename}.png" dot_path = f"{graph_filename}.dot" @@ -524,18 +633,71 @@ def add_edge(u: str, v: str, u_color: str = C_DATA, v_color: str = C_DATA): ) +def _fuse_derivative_nodes( + derivative_exprs: List[sp.Derivative], +) -> List[FusedDerivativeNode]: + """Merge derivative nodes and return in list of FusedDerivativeNode after merger. + + Args: + derivative_exprs (List[sp.Derivative]): Derivatives sympy expression of same + function, e.g. [Derivative(u(x,y), x), Derivative(u(x,y), y)] + + Returns: + List[FusedDerivativeNode]: List of FusedDerivativeNode converting from mergable + derivatives. + """ + + class DerivativeTrie: + """Trie for unrolling derivative.""" + + def __init__(self, expr: sp.Basic): + self.expr: sp.Basic = expr + self.next: Dict["sp.Symbol", "DerivativeTrie"] = {} + + # unroll derivative expressions into a trie structure + trie_root = DerivativeTrie(derivative_exprs[0].args[0]) + for derivative_expr in derivative_exprs: + cur_node = trie_root + for (child, order) in derivative_expr.args[1:]: + for _ in range(order): + if child not in cur_node.next: + cur_node.next[child] = DerivativeTrie(cur_node.expr.diff(child)) + cur_node = cur_node.next[child] + + def dfs_trie( + node: DerivativeTrie, fused_derivative_nodes: List[FusedDerivativeNode] + ) -> None: + if node.next: + fused_derivative_nodes.append( + FusedDerivativeNode( + [(node.expr, name) for name in node.next], + ) + ) + for child in node.next: + dfs_trie(node.next[child], fused_derivative_nodes) + + # walk on derivative trie in pre-order and log fusable nodes + fused_derivative_nodes: List[FusedDerivativeNode] = [] + dfs_trie(trie_root, fused_derivative_nodes) + + return fused_derivative_nodes + + def lambdify( - expr: sp.Basic, + expr: Union[sp.Basic, List[sp.Basic]], models: Optional[Union[arch.Arch, Tuple[arch.Arch, ...]]] = None, extra_parameters: Optional[Sequence[paddle.Tensor]] = None, graph_filename: Optional[str] = None, create_graph: bool = True, retain_graph: Optional[bool] = None, -) -> ComposedNode: + fuse_derivative: bool = False, +) -> Union[ComposedNode, List[ComposedNode]]: """Convert sympy expression to callable function. Args: - expr (sp.Basic): Sympy expression to be converted. + expr (Union[sp.Basic, List[sp.Basic]]): Sympy expression(s) to be converted. + will return callable functions in list if multiple expressions are given. + else will return one single callable function. models (Optional[Union[arch.Arch, Tuple[arch.Arch, ...]]]): Model(s) for computing forward result in `LayerNode`. extra_parameters (Optional[nn.ParameterList]): Extra learnable parameters. @@ -552,6 +714,12 @@ def lambdify( be retained, in which way users can calculate backward twice for the same graph. When it is False, the graph would be freed. Defaults to None, which means it is equal to `create_graph`. + fuse_derivative (bool, optional): Whether to fuse the derivative nodes. + for example, if `expr` is 'Derivative(u, x) + Derivative(u, y)' + It will compute grad(u, x) + grad(u, y) if fuse_derivative=False, + else will compute sum(grad(u, [x, y])) if fuse_derivative=True as is more + efficient in backward-graph. Defaults to False, as it is experimental so not + enabled by default if used independently. Returns: ComposedNode: Callable object for computing expr with necessary input(s) data @@ -596,86 +764,209 @@ def lambdify( >>> paddle.allclose(z_tensor_manually, z_tensor_sympy).item() True """ - - # NOTE: Those simplify methods may complicate given expr instead, so not use here - # simplify expression to reduce nodes in tree - # expr = sp.nsimplify(expr) - # expr = sp.expand(expr) - # expr = sp.simplify(expr) - - # remove 1.0 from sympy expression tree - expr = expr.subs(1.0, 1) - - # convert sympy expression tree to list of nodes in post-order - sympy_nodes: List[sp.Basic] = [] - sympy_nodes = _post_traverse(expr, sympy_nodes) - - # remove unnecessary symbol nodes already in input dict(except for parameter symbol) if not extra_parameters: extra_parameters = () - _parameter_names = tuple(param.name for param in extra_parameters) - sympy_nodes = [ - node - for node in sympy_nodes - if (not node.is_Symbol) or (_cvt_to_key(node) in _parameter_names) - ] - - # remove duplicates with topological order kept - sympy_nodes = list(dict.fromkeys(sympy_nodes)) if isinstance(models, arch.ModelList): models = tuple(models.model_list[i] for i in range(len(models.model_list))) if not isinstance(models, (tuple, list)): models = (models,) - # convert sympy node to callable node - callable_nodes = [] - for i, node in enumerate(sympy_nodes): - if isinstance( - node, tuple(SYMPY_TO_PADDLE.keys()) + (sp.Add, sp.Mul, sp.Derivative) - ): - callable_nodes.append(OperatorNode(node, create_graph, retain_graph)) - elif isinstance(node, sp.Function): - if node.name == equation.DETACH_FUNC_NAME: - callable_nodes.append(DetachNode(node)) - else: - match_index = None - for j, model in enumerate(models): - if str(node.func.name) in model.output_keys: - callable_nodes.append( - LayerNode( - node, - model, + def _expr_to_callable_nodes( + single_expr: sp.Basic, graph_filename_: Optional[str] = None + ) -> List[Node]: + """Convert sympy expression to a sequence of nodes in topologic order. + + Args: + single_expr (sp.Basic): Single sympy expression, such as "a+b*c". + graph_filename_ (Optional[str]): Save computational graph to + `/path/to/graph_filename.png` for given `expr`, if `graph_filename` is not + None and a valid string, such as 'momentum_x'. Defaults to None. + + Returns: + List[Node]: Sequence of callable nodes. + """ + # NOTE: Those simplify methods may complicate given expr instead, so not use here + # simplify expression to reduce nodes in tree + # expr = sp.nsimplify(expr) + # expr = sp.expand(expr) + # expr = sp.simplify(expr) + + # remove 1.0 from sympy expression tree + single_expr = single_expr.subs(1.0, 1) + + # convert sympy expression tree to list of nodes in post-order + sympy_nodes: List[sp.Basic] = [] + sympy_nodes = _post_traverse(single_expr, sympy_nodes) + + # remove unnecessary symbol nodes already in input dict(except for parameter symbol) + _parameter_names = tuple(param.name for param in extra_parameters) + sympy_nodes = [ + node + for node in sympy_nodes + if (not node.is_Symbol) or (_cvt_to_key(node) in _parameter_names) + ] + + # remove duplicated node(s) with topological order kept + sympy_nodes = list(dict.fromkeys(sympy_nodes)) + + # convert sympy node to callable node + callable_nodes = [] + for i, node in enumerate(sympy_nodes): + if isinstance( + node, tuple(SYMPY_TO_PADDLE.keys()) + (sp.Add, sp.Mul, sp.Derivative) + ): + if isinstance(node, sp.Derivative): + callable_nodes.append( + DerivativeNode(node, create_graph, retain_graph) + ) + else: + callable_nodes.append(OperatorNode(node)) + elif isinstance(node, sp.Function): + if node.name == equation.DETACH_FUNC_NAME: + callable_nodes.append(DetachNode(node)) + else: + match_index = None + for j, model in enumerate(models): + if str(node.func.name) in model.output_keys: + callable_nodes.append( + LayerNode( + node, + model, + ) ) + if match_index is not None: + raise ValueError( + f"Name of function: '{node}' should be unique along given" + f" models, but got same output_key: '{node.func.name}' " + f"in given models[{match_index}] and models[{j}]." + ) + match_index = j + # NOTE: Skip 'sdf' function, which should be already generated in + # given data_dict + if match_index is None and node.name != "sdf": + raise ValueError( + f"Node {node} can not match any model in given model(s)." ) - if match_index is not None: - raise ValueError( - f"Name of function: '{node}' should be unique along given" - f" models, but got same output_key: '{node.func.name}' " - f"in given models[{match_index}] and models[{j}]." - ) - match_index = j - # NOTE: Skip 'sdf' function, which should be already generated in - # given data_dict - if match_index is None and node.name != "sdf": - raise ValueError( - f"Node {node} can not match any model in given model(s)." + elif node.is_Number or node.is_NumberSymbol: + callable_nodes.append(ConstantNode(node)) + elif isinstance(node, sp.Symbol): + callable_nodes.append( + ParameterNode( + node, + *[ + param + for param in extra_parameters + if param.name == node.name + ], ) - elif node.is_Number or node.is_NumberSymbol: - callable_nodes.append(ConstantNode(node)) - elif isinstance(node, sp.Symbol): - callable_nodes.append( - ParameterNode( - node, - *[param for param in extra_parameters if param.name == node.name], ) + else: + raise NotImplementedError( + f"The node {node} is not supported in lambdify." + ) + + # NOTE: visualize computational graph using 'pygraphviz' + if isinstance(graph_filename, str): + _visualize_graph(sympy_nodes, os.path.join(graph_filename, graph_filename_)) + + return callable_nodes + + if isinstance(expr, sp.Basic): + callable_nodes_group = [_expr_to_callable_nodes(expr, "expr")] + else: + callable_nodes_group = [ + _expr_to_callable_nodes(expr_i, f"expr_{i}") + for i, expr_i in enumerate(expr) + ] + + # [Optional] Fused derivatives nodes that with same function to be differentiated + while fuse_derivative: + candidate_pos: List[Tuple[int, int]] = [] # [(group_id, node_id), ...] + + # use 4-nested for-loop to find all potential mergable derivative nodes + for i in range(len(callable_nodes_group)): + for j in range(len(callable_nodes_group[i])): + # skip non-derivative node + if not isinstance(callable_nodes_group[i][j], DerivativeNode): + continue + # skip sdf function since it is always already given in data_dict + if callable_nodes_group[i][j].expr.args[0].name == "sdf": + continue + + candidate_pos = [[i, j]] + for ii in range(len(callable_nodes_group)): + for jj in range(len(callable_nodes_group[ii])): + # skip non-derivative node + if not isinstance(callable_nodes_group[ii][jj], DerivativeNode): + continue + + # skip same node + if i == ii and j == jj: + continue + + # has same function item + if ( + callable_nodes_group[i][j].expr.args[0] + == callable_nodes_group[ii][jj].expr.args[0] + ): + candidate_pos.append([ii, jj]) + + if len(candidate_pos) > 1: + break + if len(candidate_pos) > 1: + break + + # merge all candidate nodes into one or more FusedDerivativeNode node + if len(candidate_pos) > 1: + fused_node_seq = _fuse_derivative_nodes( + [callable_nodes_group[gid][nid].expr for gid, nid in candidate_pos] + ) + assert isinstance( + fused_node_seq, list + ), "'fused_node_seq' should be list of 'FusedDerivativeNode'" + gid0, nid0 = candidate_pos[0] + logger.debug( + f"Fused {len(candidate_pos)} derivatives nodes: " + f"{[callable_nodes_group[i][j].expr for i, j in candidate_pos]} into" + f" fuse node sequence: {fused_node_seq} at position: ([{gid0}][{nid0}])" ) - else: - raise NotImplementedError(f"The node {node} is not supported in lambdify.") - # NOTE: Visualize computational graph using 'pygraphviz' - if isinstance(graph_filename, str): - _visualize_graph(sympy_nodes, graph_filename) + # replace first mergable node with fused node sequence(packed in list) + # then mask the rest merged node to None(except [gid0, nid0]) + for i, (gid, nid) in enumerate(candidate_pos): + if i == 0: + callable_nodes_group[gid0][nid0] = fused_node_seq + else: + # keep the end node of each group to avoid generating empty callable + # node sequence, this will not effect performance since cache strategy + # in Node.forward + if nid != len(callable_nodes_group[gid]) - 1: + callable_nodes_group[gid][nid] = None + + # re-organize callable_nodes_group, remove None element and unpack list + for i in range(len(callable_nodes_group)): + tmp = [] + for j in range(len(callable_nodes_group[i])): + if isinstance( + callable_nodes_group[i][j], (Node, FusedDerivativeNode) + ): + tmp.append(callable_nodes_group[i][j]) + elif isinstance(callable_nodes_group[i][j], list) and isinstance( + callable_nodes_group[i][j][0], FusedDerivativeNode + ): + tmp.extend(callable_nodes_group[i][j]) + else: + assert ( + callable_nodes_group[i][j] is None + ), f"Unexpected element: {callable_nodes_group[i][j]}" + callable_nodes_group[i] = tmp + else: + # exit while loop if no more fused + break # Compose callable nodes into one callable object - return ComposedNode(callable_nodes) + if isinstance(expr, sp.Basic): + return ComposedNode(callable_nodes_group[0]) + else: + return [ComposedNode(callable_nodes) for callable_nodes in callable_nodes_group]