diff --git a/docs/zh/api/arch.md b/docs/zh/api/arch.md index 69ff3ac12..657f26f7d 100644 --- a/docs/zh/api/arch.md +++ b/docs/zh/api/arch.md @@ -18,6 +18,7 @@ - ModelList - AFNONet - PrecipNet + - PhyCRNet - UNetEx - USCNN - NowcastNet diff --git a/docs/zh/examples/phycrnet.md b/docs/zh/examples/phycrnet.md new file mode 100644 index 000000000..a7af1259f --- /dev/null +++ b/docs/zh/examples/phycrnet.md @@ -0,0 +1,190 @@ +# PhyCRNet + +AI Studio快速体验 + +=== "模型训练命令" + + ``` sh + # linux + wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyCRNet/burgers_1501x2x128x128.mat -P ./data/ + + # windows + # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyCRNet/burgers_1501x2x128x128.mat --output ./data/burgers_1501x2x128x128.mat + + python main.py DATA_PATH=./data/burgers_1501x2x128x128.mat + ``` + +=== "模型评估命令" + + ``` sh + # linux + wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyCRNet/burgers_1501x2x128x128.mat -P ./data/ + # windows + # curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/PhyCRNet/burgers_1501x2x128x128.mat --output ./data/burgers_1501x2x128x128.mat + + python main.py mode=eval DATA_PATH=./data/burgers_1501x2x128x128.mat EVAL.pretrained_model_path=https://paddle-org.bj.bcebos.com/paddlescience/models/phycrnet/phycrnet_burgers.pdparams + ``` +| 预训练模型 | 指标 | +|:--| :--| +| [phycrnet_burgers_pretrained.pdparams](https://paddle-org.bj.bcebos.com/paddlescience/models/phycrnet/phycrnet_burgers.pdparams) | a-RMSE: 3.20e-3 | + +## 1. 背景简介 + +复杂时空系统通常可以通过偏微分方程(PDE)来建模,它们在许多领域都十分常见,如应用数学、物理学、生物学、化学和工程学。求解PDE系统一直是科学计算领域的一个关键组成部分。 +本文的具体目标是为了提出一种新颖的、考虑物理信息的卷积-递归学习架构(PhyCRNet)及其轻量级变体(PhyCRNet-s),用于解决没有任何标签数据的多维时间空间PDEs。本项目主要目标是使用PaddleScience复现论文所提供的代码,并与代码的精度对齐。 + +该网络有以下优势: + +1、 使用ConvLSTM(enconder-decoder Convolutional Long Short-Term Memory network) 可以充分提取低维空间上的特征以及学习其时间上的变化。 + +2、使用一个全局的残差迭代从而可以严格地执行时间上的迭代过程。 + +3、使用基于高阶有限差分格式的滤波从而能够精确求解重要的偏微分方程导数值。 + +4、使用强制边界条件是的所求解的数值解可以满足原方程所要求的初值以及边界条件。 + +## 2. 问题定义 + +在本模型中,我们考虑的是含有时间和空间的PDE模型,此类模型在推理过程中会存在时间上的误差累积的问题,因此,本文通过设计循环卷积神经网络试图减轻每一步时间迭代的误差累积。而我们所求解的问题为以高斯分布随机得到的值为初值的二维Burgers' Equation: + +$$u_t+u\cdot \nabla u -\nu u =0$$ + +二维Burgers' Equation 刻画了复杂的非线性的反应扩散相互作用的问题,因此,经常被用来当作benchmark来比较各种科学计算算法。 + +## 3. 问题求解 + +### 3.1 模型构建 +在这一部分中,我们介绍 PhyCRNet 的架构,包括编码器-解码器模块、残差连接、自回归(AR)过程和基于过滤的微分。网络架构如图所示。编码器(黄色Encoder,包含3个卷积层),用于从输入状态变量 $u(t=i),i = 0,1,2,..,T-1$ 学习低维潜在特征,其中 $T$ 表示总时间步。我们应用 ReLU 作为卷积层的激活函数。然后,我们将ConvLSTM层的输出(Encoder得到的低分辨率),潜在特征的时间传播器(绿色部分),其中,输出的LSTM的记忆单元 $C_i$ 和LSTM的隐藏变量单元 $h_i$ 会作为下一个时间步的输入。这样做的好处是对低维变量的基本动态进行建模,能够准确地捕获时间依赖性,同时有助于减轻记忆负担。 使用 LSTM 的另一个优势来自输出状态的双曲正切函数,它可以保持平滑的梯度曲线,并将值控制在 -1 和 1 之间。在建立低分辨率LSTM卷积循环方案后,我们基于上采样操作Decoder(蓝色部分)直接将低分辨率潜在空间重建为高分辨率量。特别注明,应用了子像素卷积层(即像素shuffle),因为与反卷积相比,它具有更好的效率和重建精度,且伪像更少。 最后,我们添加另一个卷积层,用于将有界潜变量空间输出,缩放回原始的物理空间。该Decoder后面没有激活函数。 此外,值得一提的是,鉴于输入变量数量有限及其对超分辨率的缺陷,我们在 PhyCRNet 中没有考虑 batch normalization。 作为替代,我们使用 batch normalization 来训练网络,以实现训练加速和更好的收敛性。受到动力学中,Forward Eular Scheme 的启发,我们在输入状态变量 $u_i$ 和输出变量 $u_{i+1}$ 之间附加全局残差连接。具体网络结构如下图所示: + +![image](https://paddle-org.bj.bcebos.com/paddlescience/docs/phycrnet/PhyCRnet.png) + +接下来,剩下的挑战是,如何进行物理嵌入,来融合N-S方程带来的精度提升。我们应用无梯度卷积滤波器,来表示离散数值微分,以近似感兴趣的导数项。 例如,我们在本文中考虑的基于 Finite Difference 有限差分的滤波器是2阶和4阶中心差分格式,来计算时间和空间导数。 + +时间差分: + +$$K_t = [-1,0,1] \times \frac{1}{2 \delta t},$$ + +空间差分: + +$$K_s = \begin{bmatrix} + 0 & 0 & -1 & 0 & 0 \\ + 0 & 0 & 16 & 0 & 0 \\ + -1 & 16 & -60 & 16 & -1 \\ + 0 & 0 & 16 & 0 & 0 \\ + 0 & 0 & -1 & 0 & 0 \\ +\end{bmatrix} \times \frac{1}{12 (\delta x)^2},$$ + +其中 $\delta t$ 和 $\delta x$ 表示时间步长和空间步长。 + +此外需要注意无法直接计算边界上的导数,丢失边界差异信息的风险可以通过接下来引入的在传统有限差分中经常使用的鬼点填充机制来减轻,其主要核心是在矩阵外围填充一层或多层鬼点(层数取决于差分格式,即,过滤器的大小),以下图为例,在迪利克雷边界条件(Dirichlet BCs)下,我们只需要把常值鬼点在原矩阵外围填充即可;在诺伊曼边界条件(Neumann BCs)下,我们需要根据其边界条件导数值确定鬼点的值。 + +![image](https://paddle-org.bj.bcebos.com/paddlescience/docs/phycrnet/Hard_IC_BC.png) + +``` py linenums="43" +--8<-- +examples/phycrnet/main.py:43:45 +--8<-- +``` + +``` yaml linenums="34" +--8<-- +examples/phycrnet/conf/burgers_equations.yaml:34:42 +--8<-- +``` + +### 3.2 数据载入 +我们使用RK4或者谱方法生成的数据(初值为使用正态分布生成),需要从.mat文件中将其读入,: +``` py linenums="54" +--8<-- +examples/phycrnet/main.py:54:72 +--8<-- +``` + +### 3.3 约束构建 + +设置约束以及相关损失函数: + +``` py linenums="74" +--8<-- +examples/phycrnet/main.py:74:90 +--8<-- +``` + +### 3.4 评估器构建 + +设置评估数据集和相关损失函数: + +``` py linenums="92" +--8<-- +examples/phycrnet/main.py:92:109 +--8<-- +``` + +### 3.6 优化器构建 + +训练过程会调用优化器来更新模型参数,此处选择 `Adam` 优化器并设定 `learning_rate` + +``` py linenums="112" +--8<-- +examples/phycrnet/main.py:112:116 +--8<-- +``` + +### 3.7 模型训练与评估 + +为了评估所有基于神经网络的求解器产生的解决方案精度,我们分两个阶段评估了全场误差传播:训练和外推。在时刻 τ 的全场误差 $\epsilon_\tau$ 的定义为给定 b 的累积均方根误差 (a-RMSE)。 + +$$ +\epsilon_\tau=\sqrt{\frac{1}{N_\tau} \sum_{k=1}^{N_\tau} \frac{\left\|\mathbf{u}^*\left(\mathbf{x}, t_k\right)-\mathbf{u}^\theta\left(\mathbf{x}, t_k\right)\right\|_2^2}{m n}} +$$ + +这一步需要通过设置外界函数来进行,因此在训练过程中,我们使用`function.transform_out`来进行训练 +``` py linenums="47" +--8<-- +examples/phycrnet/main.py:47:51 +--8<-- +``` +而在评估过程中,我们使用`function.tranform_output_val`来进行评估,并生成累计均方根误差。 +``` py linenums="142" +--8<-- +examples/phycrnet/main.py:142:142 +--8<-- +``` +完成上述设置之后,只需要将上述实例化的对象按顺序传递给 `ppsci.solver.Solver`。 + +``` py linenums="117" +--8<-- +examples/phycrnet/main.py:117:129 +--8<-- +``` + +最后启动训练、评估即可: + +``` py linenums="132" +--8<-- +examples/phycrnet/main.py:132:140 +--8<-- +``` + +## 4. 完整代码 + +``` py linenums="1" title="phycrnet" +--8<-- +examples/phycrnet/main.py +--8<-- +``` + +## 5. 结果展示 + +本文通过对Burgers' Equation进行训练,所得结果如下,根据精度和扩展能力的对比我们可以得出,我们的模型在训练集(t=1.0,2.0)以及拓展集(t=3.0,4.0)上均有良好的表现效果。pred为使用网络预测的速度的第一分量u在定义域上的contour图,truth为真实的速度第一分量u在定义域上的contour图,Error为预测值与真实值之间在整个定义域差值。 + +![image](https://paddle-org.bj.bcebos.com/paddlescience/docs/NSFNet/PhyCRNet_Burgers.jpeg) + +## 6. 结果说明 + +求解偏微分方程是在科学计算中的一个基本问题,而神经网络求解偏微分方程在求解逆问题以及数据同化问题等在传统方法上具有挑战性的问题上具有显著效果,但是,现有神经网络求解方法受限制于可扩展性,误差传导以及泛化能力等问题。因此,本论文通过提出一个新的神经网络PhyCRNet,通过将传统有限差分的思路嵌入物理信息神经网络中,针对性地解决原神经网络缺少对长时间数据的推理能力、误差累积以及缺少泛化能力的问题。与此同时,本文通过类似于有限差分的边界处理方式,将原本边界条件的软限制转为硬限制,大大提高了神经网络的准确性。新提出的网络可以有效解决上述提到的数据同化问题以及逆问题。 + +## 7. 参考资料 + +- [PhyCRNet: Physics-informed Convolutional-Recurrent Network for Solving Spatiotemporal PDEs](https://arxiv.org/abs/2106.14103) +- diff --git a/examples/phycrnet/conf/burgers_equations.yaml b/examples/phycrnet/conf/burgers_equations.yaml new file mode 100644 index 000000000..1dee3a0f9 --- /dev/null +++ b/examples/phycrnet/conf/burgers_equations.yaml @@ -0,0 +1,65 @@ +hydra: + run: + # dynamic output directory according to running time and override name + dir: burgers/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} + job: + name: ${mode} # name of logfile + chdir: false # keep current working direcotry unchanged + config: + override_dirname: + exclude_keys: + - TRAIN.checkpoint_path + - TRAIN.pretrained_model_path + - EVAL.pretrained_model_path + - mode + - output_dir + - log_freq + sweep: + # output directory for multirun + dir: ${hydra.run.dir} + subdir: ./ + +# general settings +mode: train # running mode: train/eval +seed: 66 +output_dir: ${hydra:run.dir} +DATA_PATH: ./data/burgers_2001x2x128x128.mat +case_name: 2D Burgers' equation +num_convlstm: 1 +# set working condition +TIME_STEPS: 1001 +DT: 0.002 +DX: [1.0, 128] +TIME_BATCH_SIZE: 1000 + +# model settings +MODEL: + input_channels: 2 + hidden_channels: [8, 32, 128, 128] + input_kernel_size: [4, 4, 4, 3] + input_stride: [2, 2, 2, 1] + input_padding: [1, 1, 1, 1] + num_layers: [3, 1] + upscale_factor: 8 + +# training settings +TRAIN: + epochs: 24000 + iters_per_epoch: 1 + lr_scheduler: + epochs: ${TRAIN.epochs} + iters_per_epoch: ${TRAIN.iters_per_epoch} + step_size: 50 + gamma: 0.99 + learning_rate: 6.0e-4 + save_freq: 50 + eval_with_no_grad: true + pretrained_model_path: null + checkpoint_path: null + +# evaluation settings +EVAL: + pretrained_model_path: null + eval_with_no_grad: true + TIME_BATCH_SIZE: 2000 + TIME_STEPS: 2001 diff --git a/examples/phycrnet/conf/fitzhugh_nagumo_RD_equation.yaml b/examples/phycrnet/conf/fitzhugh_nagumo_RD_equation.yaml new file mode 100644 index 000000000..8ed99d860 --- /dev/null +++ b/examples/phycrnet/conf/fitzhugh_nagumo_RD_equation.yaml @@ -0,0 +1,63 @@ +hydra: + run: + # dynamic output directory according to running time and override name + dir: fitzhugh/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} + job: + name: ${mode} # name of logfile + chdir: false # keep current working direcotry unchanged + config: + override_dirname: + exclude_keys: + - TRAIN.checkpoint_path + - TRAIN.pretrained_model_path + - EVAL.pretrained_model_path + - mode + - output_dir + - log_freq + sweep: + # output directory for multirun + dir: ${hydra.run.dir} + subdir: ./ + +# general settings +mode: train # running mode: train/eval +seed: 66 +output_dir: ${hydra:run.dir} +DATA_PATH: ./data/FN_1001x2x128x128.mat +case_name: fitzhugh_nagumo +num_convlstm: 1 +# set working condition +TIME_STEPS: 751 +DT: 0.006 +DX: [128.0, 128] +TIME_BATCH_SIZE: 750 + +# model settings +MODEL: + input_channels: 2 + hidden_channels: [8, 32, 128, 128] + input_kernel_size: [4, 4, 4, 3] + input_stride: [2, 2, 2, 1] + input_padding: [1, 1, 1, 1] + num_layers: [3, 1] + upscale_factor: 8 + +# training settings +TRAIN: + epochs: 30000 + iters_per_epoch: 1 + lr_scheduler: + epochs: ${TRAIN.epochs} + iters_per_epoch: ${TRAIN.iters_per_epoch} + step_size: 50 + gamma: 0.995 + learning_rate: 5.0e-5 + save_freq: 50 + eval_with_no_grad: true + pretrained_model_path: null + checkpoint_path: null + +# evaluation settings +EVAL: + pretrained_model_path: null + eval_with_no_grad: true diff --git a/examples/phycrnet/conf/lambda_omega_RD_equation.yaml b/examples/phycrnet/conf/lambda_omega_RD_equation.yaml new file mode 100644 index 000000000..19c8eb771 --- /dev/null +++ b/examples/phycrnet/conf/lambda_omega_RD_equation.yaml @@ -0,0 +1,63 @@ +hydra: + run: + # dynamic output directory according to running time and override name + dir: lambda_omega/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} + job: + name: ${mode} # name of logfile + chdir: false # keep current working direcotry unchanged + config: + override_dirname: + exclude_keys: + - TRAIN.checkpoint_path + - TRAIN.pretrained_model_path + - EVAL.pretrained_model_path + - mode + - output_dir + - log_freq + sweep: + # output directory for multirun + dir: ${hydra.run.dir} + subdir: ./ + +# general settings +device: 'gpu:3' +mode: train # running mode: train/eval +seed: 66 +output_dir: ${hydra:run.dir} +DATA_PATH: ./data/reaction_diffusion.mat +case_name: lambda_omega +# set working condition +TIME_STEPS: 201 +DT: 0.025 +DX: [20.0, 512] +TIME_BATCH_SIZE: 200 +num_convlstm: 1 +# model settings +MODEL: + input_channels: 2 + hidden_channels: [8, 32, 128, 128] + input_kernel_size: [4, 4, 4, 3] + input_stride: [2, 2, 2, 1] + input_padding: [1, 1, 1, 1] + num_layers: [3, 1] + upscale_factor: 8 + +# training settings +TRAIN: + epochs: 24000 + iters_per_epoch: 1 + lr_scheduler: + epochs: ${TRAIN.epochs} + iters_per_epoch: ${TRAIN.iters_per_epoch} + step_size: 100 + gamma: 0.99 + learning_rate: 5.0e-4 + save_freq: 50 + eval_with_no_grad: true + pretrained_model_path: null + checkpoint_path: null + +# evaluation settings +EVAL: + pretrained_model_path: null + eval_with_no_grad: true diff --git a/examples/phycrnet/functions.py b/examples/phycrnet/functions.py new file mode 100644 index 000000000..80540e902 --- /dev/null +++ b/examples/phycrnet/functions.py @@ -0,0 +1,420 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Dict + +import matplotlib.pyplot as plt +import numpy as np +import paddle +import paddle.nn as nn + +from ppsci.arch import phycrnet + +dt = None +dx = None +num_time_batch = None +uv = None +time_steps = None + +# transform +def transform_in(input): + shape = input["initial_state_shape"][0] + input_transformed = { + "initial_state": input["initial_state"][0].reshape(shape.tolist()), + "input": input["input"][0], + } + return input_transformed + + +def transform_out(input, out, model): + # Stop the transform to avoid circulation + model.enable_transform = False + + loss_func = phycrnet.loss_generator(dt, dx) + batch_loss = 0 + state_detached = [] + prev_output = [] + for time_batch_id in range(num_time_batch): + # update the first input for each time batch + if time_batch_id == 0: + hidden_state = input["initial_state"] + u0 = input["input"] + else: + hidden_state = state_detached + u0 = prev_output[-2:-1].detach() # second last output + out = model({"initial_state": hidden_state, "input": u0}) + + # output is a list + output = out["outputs"] + second_last_state = out["second_last_state"] + + # [t, c, height (Y), width (X)] + output = paddle.concat(tuple(output), axis=0) + + # concatenate the initial state to the output for central diff + output = paddle.concat((u0.cuda(), output), axis=0) + + # get loss + loss = compute_loss(output, loss_func) + batch_loss += loss + + # update the state and output for next batch + prev_output = output + state_detached = [] + for i in range(len(second_last_state)): + (h, c) = second_last_state[i] + state_detached.append((h.detach(), c.detach())) # hidden state + + model.enable_transform = True + return {"loss": batch_loss} + + +def tranform_output_val(input, out, name="results.npz"): + output = out["outputs"] + input = input["input"] + + # shape: [t, c, h, w] + output = paddle.concat(tuple(output), axis=0) + output = paddle.concat((input.cuda(), output), axis=0) + + # Padding x and y axis due to periodic boundary condition + output = paddle.concat((output[:, :, :, -1:], output, output[:, :, :, 0:2]), axis=3) + output = paddle.concat((output[:, :, -1:, :], output, output[:, :, 0:2, :]), axis=2) + + # [t, c, h, w] + truth = uv[0:time_steps, :, :, :] + + # [101, 2, 131, 131] + truth = np.concatenate((truth[:, :, :, -1:], truth, truth[:, :, :, 0:2]), axis=3) + truth = np.concatenate((truth[:, :, -1:, :], truth, truth[:, :, 0:2, :]), axis=2) + truth = paddle.to_tensor(truth) + # post-process + ten_true = [] + ten_pred = [] + for i in range(0, 1001): + u_star, u_pred, v_star, v_pred = post_process( + output, + truth, + num=i, + ) + + ten_true.append(paddle.stack([u_star, v_star])) + ten_pred.append(paddle.stack([u_pred, v_pred])) + ten_true = paddle.stack(ten_true) + ten_pred = paddle.stack(ten_pred) + # compute the error + # a-RMSE + error = ( + paddle.sum((ten_pred - ten_true) ** 2, axis=(1, 2, 3)) + / ten_true.shape[2] + / ten_true.shape[3] + ) + N = error.shape[0] + M = 0 + for i in range(N): + M = M + np.eye(N, k=-i) + M = M.T / np.arange(N) + M[:, 0] = 0 + M[0, :] = 0 + M = paddle.to_tensor(M) + aRMSE = paddle.sqrt(M.T @ error) + np.savez( + name, + error=np.array(error), + ten_true=ten_true, + ten_pred=ten_pred, + aRMSE=np.array(aRMSE), + ) + error = paddle.linalg.norm(error) + return {"loss": paddle.to_tensor([error])} + + +def train_loss_func(result_dict, *args) -> paddle.Tensor: + """For model calculation of loss. + + Args: + result_dict (Dict[str, paddle.Tensor]): The result dict. + + Returns: + paddle.Tensor: Loss value. + """ + return result_dict["loss"] + + +def val_loss_func(result_dict, *args) -> paddle.Tensor: + return result_dict["loss"] + + +def metric_expr(output_dict, *args) -> Dict[str, paddle.Tensor]: + return {"dummy_loss": paddle.to_tensor(0.0)} + + +class GaussianRF(object): + def __init__(self, dim, size, alpha=2, tau=3, sigma=None, boundary="periodic"): + self.dim = dim + + if sigma is None: + sigma = tau ** (0.5 * (2 * alpha - self.dim)) + + k_max = size // 2 + + if dim == 1: + k = paddle.concat( + ( + paddle.arange(start=0, end=k_max, step=1), + paddle.arange(start=-k_max, end=0, step=1), + ), + 0, + ) + + self.sqrt_eig = ( + size + * math.sqrt(2.0) + * sigma + * ((4 * (math.pi**2) * (k**2) + tau**2) ** (-alpha / 2.0)) + ) + self.sqrt_eig[0] = 0.0 + + elif dim == 2: + wavenumers = paddle.concat( + ( + paddle.arange(start=0, end=k_max, step=1), + paddle.arange(start=-k_max, end=0, step=1), + ), + 0, + ).tile((size, 1)) + + perm = list(range(wavenumers.ndim)) + perm[1] = 0 + perm[0] = 1 + k_x = wavenumers.transpose(perm=perm) + k_y = wavenumers + + self.sqrt_eig = ( + (size**2) + * math.sqrt(2.0) + * sigma + * ( + (4 * (math.pi**2) * (k_x**2 + k_y**2) + tau**2) + ** (-alpha / 2.0) + ) + ) + self.sqrt_eig[0, 0] = 0.0 + + elif dim == 3: + wavenumers = paddle.concat( + ( + paddle.arange(start=0, end=k_max, step=1), + paddle.arange(start=-k_max, end=0, step=1), + ), + 0, + ).tile((size, size, 1)) + + perm = list(range(wavenumers.ndim)) + perm[1] = 2 + perm[2] = 1 + k_x = wavenumers.transpose(perm=perm) + k_y = wavenumers + + perm = list(range(wavenumers.ndim)) + perm[0] = 2 + perm[2] = 0 + k_z = wavenumers.transpose(perm=perm) + + self.sqrt_eig = ( + (size**3) + * math.sqrt(2.0) + * sigma + * ( + (4 * (math.pi**2) * (k_x**2 + k_y**2 + k_z**2) + tau**2) + ** (-alpha / 2.0) + ) + ) + self.sqrt_eig[0, 0, 0] = 0.0 + + self.size = [] + for j in range(self.dim): + self.size.append(size) + + self.size = tuple(self.size) + + def sample(self, N): + + coeff = paddle.randn((N, *self.size, 2)) + + coeff[..., 0] = self.sqrt_eig * coeff[..., 0] + coeff[..., 1] = self.sqrt_eig * coeff[..., 1] + + if self.dim == 2: + u = paddle.as_real(paddle.fft.ifft2(paddle.as_complex(coeff))) + else: + raise f"self.dim not in (2): {self.dim}" + + u = u[..., 0] + + return u + + +def compute_loss(output, loss_func): + """calculate the physics loss""" + + # Padding x axis due to periodic boundary condition + output = paddle.concat((output[:, :, :, -2:], output, output[:, :, :, 0:3]), axis=3) + + # Padding y axis due to periodic boundary condition + output = paddle.concat((output[:, :, -2:, :], output, output[:, :, 0:3, :]), axis=2) + + # get physics loss + mse_loss = nn.MSELoss() + f_u, f_v = loss_func.get_phy_Loss(output) + loss = mse_loss(f_u, paddle.zeros_like(f_u).cuda()) + mse_loss( + f_v, paddle.zeros_like(f_v).cuda() + ) + + return loss + + +def post_process(output, true, num): + """ + num: Number of time step + """ + u_star = true[num, 0, 1:-1, 1:-1] + u_pred = output[num, 0, 1:-1, 1:-1].detach() + + v_star = true[num, 1, 1:-1, 1:-1] + v_pred = output[num, 1, 1:-1, 1:-1].detach() + + return u_star, u_pred, v_star, v_pred + + +class Dataset: + def __init__(self, initial_state, input): + self.initial_state = initial_state + self.input = input + + def get(self, epochs=1): + input_dict_train = { + "initial_state": [], + "initial_state_shape": [], + "input": [], + } + label_dict_train = {"dummy_loss": []} + input_dict_val = { + "initial_state": [], + "initial_state_shape": [], + "input": [], + } + label_dict_val = {"dummy_loss": []} + for i in range(epochs): + shape = self.initial_state.shape + input_dict_train["initial_state"].append(self.initial_state.reshape((-1,))) + input_dict_train["initial_state_shape"].append(paddle.to_tensor(shape)) + input_dict_train["input"].append(self.input) + label_dict_train["dummy_loss"].append(paddle.to_tensor(0.0)) + + if i == epochs - 1: + shape = self.initial_state.shape + input_dict_val["initial_state"].append( + self.initial_state.reshape((-1,)) + ) + input_dict_val["initial_state_shape"].append(paddle.to_tensor(shape)) + input_dict_val["input"].append(self.input) + label_dict_val["dummy_loss"].append(paddle.to_tensor(0.0)) + + return input_dict_train, label_dict_train, input_dict_val, label_dict_val + + +def output_graph(model, input_dataset, fig_save_path, case_name): + output_dataset = model(input_dataset) + output = output_dataset["outputs"] + input = input_dataset["input"][0] + output = paddle.concat(tuple(output), axis=0) + output = paddle.concat((input.cuda(), output), axis=0) + + # Padding x and y axis due to periodic boundary condition + output = paddle.concat((output[:, :, :, -1:], output, output[:, :, :, 0:2]), axis=3) + output = paddle.concat((output[:, :, -1:, :], output, output[:, :, 0:2, :]), axis=2) + truth = uv[0:2001, :, :, :] + truth = np.concatenate((truth[:, :, :, -1:], truth, truth[:, :, :, 0:2]), axis=3) + truth = np.concatenate((truth[:, :, -1:, :], truth, truth[:, :, 0:2, :]), axis=2) + + # post-process + ten_true = [] + ten_pred = [] + + for i in range(0, 100): + u_star, u_pred, v_star, v_pred = post_process(output, truth, num=20 * i) + ten_true.append([u_star, v_star]) + ten_pred.append([u_pred, v_pred]) + + ten_true = np.stack(ten_true) + ten_pred = np.stack(ten_pred) + + # compute the error + # a-RMSE + error = ( + np.sum((ten_pred - ten_true) ** 2, axis=(1, 2, 3)) + / ten_true.shape[2] + / ten_true.shape[3] + ) + N = error.shape[0] + M = 0 + for i in range(N): + M = M + np.eye(N, k=-i) + M = M.T / np.arange(N) + M[:, 0] = 0 + M[0, :] = 0 + + M = paddle.to_tensor(M) + aRMSE = paddle.sqrt(M.T @ error) + t = np.linspace(0, 4, N) + plt.plot(t, aRMSE, color="r") + plt.yscale("log") + plt.xlabel("t") + plt.ylabel("a-RMSE") + plt.ylim((1e-4, 10)) + plt.xlim((0, 4)) + plt.legend( + [ + "PhyCRNet", + ], + loc="upper left", + ) + plt.title(case_name) + plt.savefig(fig_save_path + "/error.jpg") + + _, ax = plt.subplots(3, 4, figsize=(18, 12)) + ax[0, 0].contourf(ten_true[25, 0]) + ax[0, 0].set_title("t=1") + ax[0, 0].set_ylabel("truth") + ax[1, 0].contourf(ten_pred[25, 0]) + ax[1, 0].set_ylabel("pred") + ax[2, 0].contourf(ten_true[25, 0] - ten_pred[25, 0]) + ax[2, 0].set_ylabel("error") + ax[0, 1].contourf(ten_true[50, 0]) + ax[0, 1].set_title("t=2") + ax[1, 1].contourf(ten_pred[50, 0]) + ax[2, 1].contourf(ten_true[50, 0] - ten_pred[50, 0]) + ax[0, 2].contourf(ten_true[75, 0]) + ax[0, 2].set_title("t=3") + ax[1, 2].contourf(ten_pred[75, 0]) + ax[2, 2].contourf(ten_true[75, 0] - ten_pred[75, 0]) + ax[0, 3].contourf(ten_true[99, 0]) + ax[0, 3].set_title("t=4") + ax[1, 3].contourf(ten_pred[99, 0]) + ax[2, 3].contourf(ten_true[99, 0] - ten_pred[99, 0]) + plt.title(case_name) + plt.savefig(fig_save_path + "/Burgers.jpg") + plt.close() diff --git a/examples/phycrnet/main.py b/examples/phycrnet/main.py new file mode 100644 index 000000000..ef0f8fab8 --- /dev/null +++ b/examples/phycrnet/main.py @@ -0,0 +1,194 @@ +""" +PhyCRNet for solving spatiotemporal PDEs +Reference: https://github.com/isds-neu/PhyCRNet/ +""" +from os import path as osp + +import functions +import hydra +import paddle +import scipy.io as scio +from omegaconf import DictConfig + +import ppsci +from ppsci.utils import logger + + +def train(cfg: DictConfig): + # set random seed for reproducibility + ppsci.utils.misc.set_random_seed(cfg.seed) + # initialize logger + logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info") + + # set initial states for convlstm + NUM_CONVLSTM = cfg.num_convlstm + (h0, c0) = (paddle.randn((1, 128, 16, 16)), paddle.randn((1, 128, 16, 16))) + initial_state = [] + for _ in range(NUM_CONVLSTM): + initial_state.append((h0, c0)) + + # grid parameters + time_steps = cfg.TIME_STEPS + dx = cfg.DX[0] / cfg.DX[1] + + steps = cfg.TIME_BATCH_SIZE + 1 + effective_step = list(range(0, steps)) + num_time_batch = int(time_steps / cfg.TIME_BATCH_SIZE) + + functions.dt = cfg.DT + functions.dx = dx + functions.time_steps = cfg.TIME_STEPS + functions.num_time_batch = num_time_batch + model = ppsci.arch.PhyCRNet( + dt=cfg.DT, step=steps, effective_step=effective_step, **cfg.MODEL + ) + + def _transform_out(_in, _out): + return functions.transform_out(_in, _out, model) + + model.register_input_transform(functions.transform_in) + model.register_output_transform(_transform_out) + + # use Burgers_2d_solver_HighOrder.py to generate data + data = scio.loadmat(cfg.DATA_PATH) + uv = data["uv"] # [t,c,h,w] + functions.uv = uv + + # generate input data + ( + input_dict_train, + label_dict_train, + input_dict_val, + label_dict_val, + ) = functions.Dataset( + paddle.to_tensor(initial_state), + paddle.to_tensor( + uv[ + 0:1, + ], + dtype=paddle.get_default_dtype(), + ), + ).get() + + sup_constraint_pde = ppsci.constraint.SupervisedConstraint( + { + "dataset": { + "name": "NamedArrayDataset", + "input": input_dict_train, + "label": label_dict_train, + }, + "batch_size": 1, + "num_workers": 0, + }, + ppsci.loss.FunctionalLoss(functions.train_loss_func), + { + "loss": lambda out: out["loss"], + }, + name="sup_train", + ) + constraint_pde = {sup_constraint_pde.name: sup_constraint_pde} + + sup_validator_pde = ppsci.validate.SupervisedValidator( + { + "dataset": { + "name": "NamedArrayDataset", + "input": input_dict_val, + "label": label_dict_val, + }, + "batch_size": 1, + "num_workers": 0, + }, + ppsci.loss.FunctionalLoss(functions.val_loss_func), + { + "loss": lambda out: out["loss"], + }, + metric={"metric": ppsci.metric.FunctionalMetric(functions.metric_expr)}, + name="sup_valid", + ) + validator_pde = {sup_validator_pde.name: sup_validator_pde} + + # initialize solver + scheduler = ppsci.optimizer.lr_scheduler.Step(**cfg.TRAIN.lr_scheduler)() + + optimizer = ppsci.optimizer.Adam(scheduler)(model) + solver = ppsci.solver.Solver( + model, + constraint_pde, + cfg.output_dir, + optimizer, + scheduler, + cfg.TRAIN.epochs, + cfg.TRAIN.iters_per_epoch, + validator=validator_pde, + eval_with_no_grad=cfg.TRAIN.eval_with_no_grad, + pretrained_model_path=cfg.TRAIN.pretrained_model_path, + ) + + # train model + solver.train() + # evaluate after finished training + model.register_output_transform(functions.tranform_output_val) + solver.eval() + + +def evaluate(cfg: DictConfig): + # set random seed for reproducibility + ppsci.utils.misc.set_random_seed(cfg.seed) + # initialize logger + logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info") + + # set initial states for convlstm + NUM_CONVLSTM = cfg.num_convlstm + (h0, c0) = (paddle.randn((1, 128, 16, 16)), paddle.randn((1, 128, 16, 16))) + initial_state = [] + for _ in range(NUM_CONVLSTM): + initial_state.append((h0, c0)) + + # grid parameters + time_steps = cfg.TIME_STEPS + dx = cfg.DX[0] / cfg.DX[1] + + steps = cfg.EVAL.TIME_BATCH_SIZE + 1 + effective_step = list(range(0, steps)) + num_time_batch = int(time_steps / cfg.EVAL.TIME_BATCH_SIZE) + + functions.dt = cfg.DT + functions.dx = dx + functions.num_time_batch = num_time_batch + model = ppsci.arch.PhyCRNet( + dt=cfg.DT, step=steps, effective_step=effective_step, **cfg.MODEL + ) + + def _transform_out(_in, _out): + return functions.transform_out(_in, _out, model) + + model.register_input_transform(functions.transform_in) + model.register_output_transform(_transform_out) + + # use the generated data + data = scio.loadmat(cfg.DATA_PATH) + uv = data["uv"] # [t,c,h,w] + functions.uv = uv + _, _, input_dict_val, _ = functions.Dataset( + paddle.to_tensor(initial_state), + paddle.to_tensor(uv[0:1, ...], dtype=paddle.get_default_dtype()), + ).get() + ppsci.utils.load_pretrain(model, cfg.EVAL.pretrained_model_path) + model.register_output_transform(None) + functions.output_graph(model, input_dict_val, cfg.output_dir, cfg.case_name) + + +@hydra.main( + version_base=None, config_path="./conf", config_name="burgers_equations.yaml" +) +def main(cfg: DictConfig): + if cfg.mode == "train": + train(cfg) + elif cfg.mode == "eval": + evaluate(cfg) + else: + raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") + + +if __name__ == "__main__": + main() diff --git a/mkdocs.yml b/mkdocs.yml index 0e84072d8..0d274018d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -59,6 +59,7 @@ nav: - LDC2D_unsteady: zh/examples/ldc2d_unsteady.md - Labelfree_DNN_surrogate: zh/examples/labelfree_DNN_surrogate.md - NSFNets: zh/examples/nsfnet.md + - PhyCRNet: zh/examples/phycrnet.md - ShockWave: zh/examples/shock_wave.md - tempoGAN: zh/examples/tempoGAN.md - ViV: zh/examples/viv.md diff --git a/ppsci/arch/__init__.py b/ppsci/arch/__init__.py index 1130b35c3..7e96b2604 100644 --- a/ppsci/arch/__init__.py +++ b/ppsci/arch/__init__.py @@ -16,6 +16,8 @@ import copy +from ppsci.arch.phycrnet import PhyCRNet + from ppsci.arch.base import Arch # isort:skip from ppsci.arch.amgnet import AMGNet # isort:skip from ppsci.arch.mlp import MLP # isort:skip @@ -56,6 +58,7 @@ "PrecipNet", "UNetEx", "Epnn", + "PhyCRNet", "NowcastNet", "USCNN", "HEDeepONets", diff --git a/ppsci/arch/phycrnet.py b/ppsci/arch/phycrnet.py new file mode 100644 index 000000000..744a374c3 --- /dev/null +++ b/ppsci/arch/phycrnet.py @@ -0,0 +1,526 @@ +from typing import Tuple + +import numpy as np +import paddle +import paddle.nn as nn +from paddle.nn import utils + +from ppsci.arch import base + +# define the high-order finite difference kernels +LALP_OP = [ + [ + [ + [0, 0, -1 / 12, 0, 0], + [0, 0, 4 / 3, 0, 0], + [-1 / 12, 4 / 3, -5, 4 / 3, -1 / 12], + [0, 0, 4 / 3, 0, 0], + [0, 0, -1 / 12, 0, 0], + ] + ] +] + +PARTIAL_Y = [ + [ + [ + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1 / 12, -8 / 12, 0, 8 / 12, -1 / 12], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + ] + ] +] + +PARTIAL_X = [ + [ + [ + [0, 0, 1 / 12, 0, 0], + [0, 0, -8 / 12, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 8 / 12, 0, 0], + [0, 0, -1 / 12, 0, 0], + ] + ] +] + + +# specific parameters for burgers equation +def _initialize_weights(module): + if isinstance(module, nn.Conv2D): + c = 1.0 # 0.5 + initializer = nn.initializer.Uniform( + -c * np.sqrt(1 / (3 * 3 * 320)), c * np.sqrt(1 / (3 * 3 * 320)) + ) + initializer(module.weight) + elif isinstance(module, nn.Linear): + initializer = nn.initializer.Constant(0.0) + initializer(module.bias) + + +class PhyCRNet(base.Arch): + """Physics-informed convolutional-recurrent neural networks. + + Args: + input_channels (int): The input channels. + hidden_channels (Tuple[int]): The hidden channels. + input_kernel_size (Tuple[int]): The input kernel size. + input_stride (Tuple[int]): The input stride. + input_padding (Tuple[int]): The input padding. + dt (float): The dt parameter. + num_layers (Tuple[int]): The number of layers. + upscale_factor (int): The upscale factor. + step (int, optional): The step. Defaults to 1. + effective_step (Tuple[int], optional): The effective step. Defaults to (1). + + Examples: + >>> import ppsci + >>> model = ppsci.arch.PhyCRNet( + ... input_channels=2, + ... hidden_channels=[8, 32, 128, 128], + ... input_kernel_size=[4, 4, 4, 3], + ... input_stride=[2, 2, 2, 1], + ... input_padding=[1, 1, 1, 1], + ... dt=0.002, + ... num_layers=[3, 1], + ... upscale_factor=8 + ... ) + """ + + def __init__( + self, + input_channels: int, + hidden_channels: Tuple[int], + input_kernel_size: Tuple[int], + input_stride: Tuple[int], + input_padding: Tuple[int], + dt: float, + num_layers: Tuple[int], + upscale_factor: int, + step: int = 1, + effective_step: Tuple[int] = (1), + ): + super(PhyCRNet, self).__init__() + + # input channels of layer includes input_channels and hidden_channels of cells + self.input_channels = [input_channels] + hidden_channels + self.hidden_channels = hidden_channels + self.input_kernel_size = input_kernel_size + self.input_stride = input_stride + self.input_padding = input_padding + self.step = step + self.effective_step = effective_step + self._all_layers = [] + self.dt = dt + self.upscale_factor = upscale_factor + + # number of layers + self.num_encoder = num_layers[0] + self.num_convlstm = num_layers[1] + + # encoder - downsampling + self.encoder = paddle.nn.LayerList( + [ + encoder_block( + input_channels=self.input_channels[i], + hidden_channels=self.hidden_channels[i], + input_kernel_size=self.input_kernel_size[i], + input_stride=self.input_stride[i], + input_padding=self.input_padding[i], + ) + for i in range(self.num_encoder) + ] + ) + + # ConvLSTM + self.ConvLSTM = paddle.nn.LayerList( + [ + ConvLSTMCell( + input_channels=self.input_channels[i], + hidden_channels=self.hidden_channels[i], + input_kernel_size=self.input_kernel_size[i], + input_stride=self.input_stride[i], + input_padding=self.input_padding[i], + ) + for i in range(self.num_encoder, self.num_encoder + self.num_convlstm) + ] + ) + + # output layer + self.output_layer = nn.Conv2D( + 2, 2, kernel_size=5, stride=1, padding=2, padding_mode="circular" + ) + + # pixelshuffle - upscale + self.pixelshuffle = nn.PixelShuffle(self.upscale_factor) + + # initialize weights + self.apply(_initialize_weights) + initializer_0 = paddle.nn.initializer.Constant(0.0) + initializer_0(self.output_layer.bias) + self.enable_transform = True + + def forward(self, x): + if self.enable_transform: + if self._input_transform is not None: + x = self._input_transform(x) + output_x = x + + self.initial_state = x["initial_state"] + x = x["input"] + internal_state = [] + outputs = [] + second_last_state = [] + + for step in range(self.step): + xt = x + + # encoder + for encoder in self.encoder: + x = encoder(x) + + # convlstm + for i, LSTM in enumerate(self.ConvLSTM): + if step == 0: + (h, c) = LSTM.init_hidden_tensor( + prev_state=self.initial_state[i - self.num_encoder] + ) + internal_state.append((h, c)) + + # one-step forward + (h, c) = internal_state[i - self.num_encoder] + x, new_c = LSTM(x, h, c) + internal_state[i - self.num_encoder] = (x, new_c) + + # output + x = self.pixelshuffle(x) + x = self.output_layer(x) + + # residual connection + x = xt + self.dt * x + + if step == (self.step - 2): + second_last_state = internal_state.copy() + + if step in self.effective_step: + outputs.append(x) + + result_dict = {"outputs": outputs, "second_last_state": second_last_state} + if self.enable_transform: + if self._output_transform is not None: + result_dict = self._output_transform(output_x, result_dict) + return result_dict + + +class ConvLSTMCell(nn.Layer): + """Convolutional LSTM""" + + def __init__( + self, + input_channels, + hidden_channels, + input_kernel_size, + input_stride, + input_padding, + hidden_kernel_size=3, + num_features=4, + ): + super(ConvLSTMCell, self).__init__() + + self.input_channels = input_channels + self.hidden_channels = hidden_channels + self.hidden_kernel_size = hidden_kernel_size # Page 9, The convolutional operations in ConvLSTM have 3x3 kernels. + self.input_kernel_size = input_kernel_size + self.input_stride = input_stride + self.input_padding = input_padding + self.num_features = ( + num_features # Page 10, block of different dense layers {4, 3, 4} + ) + + # padding for hidden state + self.padding = int((self.hidden_kernel_size - 1) / 2) + + self.Wxi = nn.Conv2D( + self.input_channels, + self.hidden_channels, + self.input_kernel_size, + self.input_stride, + self.input_padding, + bias_attr=None, + padding_mode="circular", + ) + + self.Whi = nn.Conv2D( + self.hidden_channels, + self.hidden_channels, + self.hidden_kernel_size, + 1, + padding=1, + bias_attr=False, + padding_mode="circular", + ) + + self.Wxf = nn.Conv2D( + self.input_channels, + self.hidden_channels, + self.input_kernel_size, + self.input_stride, + self.input_padding, + bias_attr=None, + padding_mode="circular", + ) + + self.Whf = nn.Conv2D( + self.hidden_channels, + self.hidden_channels, + self.hidden_kernel_size, + 1, + padding=1, + bias_attr=False, + padding_mode="circular", + ) + + self.Wxc = nn.Conv2D( + self.input_channels, + self.hidden_channels, + self.input_kernel_size, + self.input_stride, + self.input_padding, + bias_attr=None, + padding_mode="circular", + ) + + self.Whc = nn.Conv2D( + self.hidden_channels, + self.hidden_channels, + self.hidden_kernel_size, + 1, + padding=1, + bias_attr=False, + padding_mode="circular", + ) + + self.Wxo = nn.Conv2D( + self.input_channels, + self.hidden_channels, + self.input_kernel_size, + self.input_stride, + self.input_padding, + bias_attr=None, + padding_mode="circular", + ) + + self.Who = nn.Conv2D( + self.hidden_channels, + self.hidden_channels, + self.hidden_kernel_size, + 1, + padding=1, + bias_attr=False, + padding_mode="circular", + ) + + initializer_0 = paddle.nn.initializer.Constant(0.0) + initializer_1 = paddle.nn.initializer.Constant(1.0) + + initializer_0(self.Wxi.bias) + initializer_0(self.Wxf.bias) + initializer_0(self.Wxc.bias) + initializer_1(self.Wxo.bias) + + def forward(self, x, h, c): + ci = paddle.nn.functional.sigmoid(self.Wxi(x) + self.Whi(h)) + cf = paddle.nn.functional.sigmoid(self.Wxf(x) + self.Whf(h)) + cc = cf * c + ci * paddle.tanh(self.Wxc(x) + self.Whc(h)) + co = paddle.nn.functional.sigmoid(self.Wxo(x) + self.Who(h)) + ch = co * paddle.tanh(cc) + return ch, cc + + def init_hidden_tensor(self, prev_state): + return ((prev_state[0]).cuda(), (prev_state[1]).cuda()) + + +class encoder_block(nn.Layer): + """encoder with CNN""" + + def __init__( + self, + input_channels, + hidden_channels, + input_kernel_size, + input_stride, + input_padding, + ): + super(encoder_block, self).__init__() + + self.input_channels = input_channels + self.hidden_channels = hidden_channels + self.input_kernel_size = input_kernel_size + self.input_stride = input_stride + self.input_padding = input_padding + + self.conv = utils.weight_norm( + nn.Conv2D( + self.input_channels, + self.hidden_channels, + self.input_kernel_size, + self.input_stride, + self.input_padding, + bias_attr=None, + padding_mode="circular", + ) + ) + + self.act = nn.ReLU() + + initializer_0 = paddle.nn.initializer.Constant(0.0) + initializer_0(self.conv.bias) + + def forward(self, x): + return self.act(self.conv(x)) + + +class Conv2DDerivative(nn.Layer): + def __init__(self, der_filter, resol, kernel_size=3, name=""): + super(Conv2DDerivative, self).__init__() + + self.resol = resol # constant in the finite difference + self.name = name + self.input_channels = 1 + self.output_channels = 1 + self.kernel_size = kernel_size + + self.padding = int((kernel_size - 1) / 2) + self.filter = nn.Conv2D( + self.input_channels, + self.output_channels, + self.kernel_size, + 1, + padding=0, + bias_attr=False, + ) + + # Fixed gradient operator + self.filter.weight = self.create_parameter( + shape=self.filter.weight.shape, + dtype=self.filter.weight.dtype, + default_initializer=paddle.nn.initializer.Assign( + paddle.to_tensor( + der_filter, dtype=paddle.get_default_dtype(), stop_gradient=True + ) + ), + ) + self.filter.weight.stop_gradient = True + + def forward(self, input): + derivative = self.filter(input) + return derivative / self.resol + + +class Conv1DDerivative(nn.Layer): + def __init__(self, der_filter, resol, kernel_size=3, name=""): + super(Conv1DDerivative, self).__init__() + + self.resol = resol # $\delta$*constant in the finite difference + self.name = name + self.input_channels = 1 + self.output_channels = 1 + self.kernel_size = kernel_size + + self.padding = int((kernel_size - 1) / 2) + self.filter = nn.Conv1D( + self.input_channels, + self.output_channels, + self.kernel_size, + 1, + padding=0, + bias_attr=False, + ) + + # Fixed gradient operator + self.filter.weight = self.create_parameter( + shape=self.filter.weight.shape, + dtype=self.filter.weight.dtype, + default_initializer=paddle.nn.initializer.Assign( + paddle.to_tensor( + der_filter, dtype=paddle.get_default_dtype(), stop_gradient=True + ) + ), + ) + self.filter.weight.stop_gradient = True + + def forward(self, input): + derivative = self.filter(input) + return derivative / self.resol + + +class loss_generator(nn.Layer): + """Loss generator for physics loss""" + + def __init__(self, dt, dx): + """Construct the derivatives, X = Width, Y = Height""" + super(loss_generator, self).__init__() + + # spatial derivative operator + self.laplace = Conv2DDerivative( + der_filter=LALP_OP, resol=(dx**2), kernel_size=5, name="laplace_operator" + ) + + self.dx = Conv2DDerivative( + der_filter=PARTIAL_X, resol=(dx * 1), kernel_size=5, name="dx_operator" + ) + + self.dy = Conv2DDerivative( + der_filter=PARTIAL_Y, resol=(dx * 1), kernel_size=5, name="dy_operator" + ) + + # temporal derivative operator + self.dt = Conv1DDerivative( + der_filter=[[[-1, 0, 1]]], resol=(dt * 2), kernel_size=3, name="partial_t" + ) + + def get_phy_Loss(self, output): + # spatial derivatives + laplace_u = self.laplace(output[1:-1, 0:1, :, :]) # [t,c,h,w] + laplace_v = self.laplace(output[1:-1, 1:2, :, :]) + + u_x = self.dx(output[1:-1, 0:1, :, :]) + u_y = self.dy(output[1:-1, 0:1, :, :]) + v_x = self.dx(output[1:-1, 1:2, :, :]) + v_y = self.dy(output[1:-1, 1:2, :, :]) + + # temporal derivative - u + u = output[:, 0:1, 2:-2, 2:-2] + lent = u.shape[0] + lenx = u.shape[3] + leny = u.shape[2] + u_conv1d = u.transpose((2, 3, 1, 0)) # [height(Y), width(X), c, step] + u_conv1d = u_conv1d.reshape((lenx * leny, 1, lent)) + u_t = self.dt(u_conv1d) # lent-2 due to no-padding + u_t = u_t.reshape((leny, lenx, 1, lent - 2)) + u_t = u_t.transpose((3, 2, 0, 1)) # [step-2, c, height(Y), width(X)] + + # temporal derivative - v + v = output[:, 1:2, 2:-2, 2:-2] + v_conv1d = v.transpose((2, 3, 1, 0)) # [height(Y), width(X), c, step] + v_conv1d = v_conv1d.reshape((lenx * leny, 1, lent)) + v_t = self.dt(v_conv1d) # lent-2 due to no-padding + v_t = v_t.reshape((leny, lenx, 1, lent - 2)) + v_t = v_t.transpose((3, 2, 0, 1)) # [step-2, c, height(Y), width(X)] + + u = output[1:-1, 0:1, 2:-2, 2:-2] # [t, c, height(Y), width(X)] + v = output[1:-1, 1:2, 2:-2, 2:-2] # [t, c, height(Y), width(X)] + + assert laplace_u.shape == u_t.shape + assert u_t.shape == v_t.shape + assert laplace_u.shape == u.shape + assert laplace_v.shape == v.shape + + # Reynolds number + R = 200.0 + + # 2D burgers eqn + f_u = u_t + u * u_x + v * u_y - (1 / R) * laplace_u + f_v = v_t + u * v_x + v * v_y - (1 / R) * laplace_v + + return f_u, f_v