Commit 35cc0c3

update

ADream-ki committed Sep 4, 2024
1 parent f408904 commit 35cc0c3
Showing 3 changed files with 148 additions and 82 deletions.
82 changes: 58 additions & 24 deletions examples/geofno/catheter.py
@@ -1,10 +1,22 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.nn.initializer as Initializer

import paddle.optimizer
from ppsci.arch import base


################################################################
@@ -15,28 +27,36 @@ def __init__(self, in_channels, out_channels, modes1):
super(SpectralConv1d, self).__init__()

"""
1D Fourier layer. It does FFT, linear transform, and Inverse FFT.
"""

self.in_channels = in_channels
self.out_channels = out_channels
# Number of Fourier modes to multiply, at most floor(N/2) + 1
self.modes1 = modes1

self.scale = 1 / (in_channels * out_channels)

real = paddle.rand(shape=[in_channels, out_channels, modes1])
real.stop_gradient = False
img = paddle.rand(shape=[in_channels, out_channels, modes1])
img.stop_gradient = False
self.weights1_real = self.create_parameter(
[in_channels, out_channels, self.modes1],
attr=Initializer.Assign(self.scale * real),
)
self.weights1_imag = self.create_parameter(
[in_channels, out_channels, self.modes1],
attr=Initializer.Assign(self.scale * img),
)
self.weights1 = paddle.complex(self.weights1_real, self.weights1_imag)

tmp = paddle.ParamAttr(
initializer=Initializer.Normal(mean=0.0 + 0.0j, std=self.scale)
)
self.weights1 = self.create_parameter(
[in_channels, out_channels, self.modes1], dtype="complex64", attr=tmp
)

# Complex multiplication
def compl_mul1d(self, input, weights):
@@ -50,16 +70,19 @@ def forward(self, x, output_size=None):

# Multiply relevant Fourier modes
out_ft_real = paddle.zeros(
[batchsize, self.out_channels, x.shape[-1] // 2 + 1], dtype="float32"
)
out_ft_img = paddle.zeros(
[batchsize, self.out_channels, x.shape[-1] // 2 + 1], dtype="float32"
)
out_ft = paddle.complex(out_ft_real, out_ft_img)

out_ft[:, :, : self.modes1] = self.compl_mul1d(
x_ft[:, :, : self.modes1], self.weights1
)

# Return to physical space
if output_size is None:
x = paddle.fft.irfft(out_ft, n=x.shape[-1])
else:
x = paddle.fft.irfft(out_ft, n=output_size)
@@ -68,10 +91,17 @@ def forward(self, x, output_size=None):


class FNO1d(nn.Layer):
def __init__(
self,
input_key="input",
output_key="output",
modes=64,
width=64,
padding=100,
input_channel=2,
output_np=2001,
):
super().__init__(input_keys=input_key, output_keys=output_key)
"""
The overall network. It contains 4 layers of the Fourier layer.
1. Lift the input to the desired channel dimension by self.fc0 .
@@ -104,10 +134,15 @@ def __init__(self, input_key="input", output_key="output", modes=64, width=64, p
self.fc1 = nn.Linear(self.width, 128)
self.fc2 = nn.Linear(128, 1)

def _FUNCTIONAL_PAD(self, x, pad, mode="constant", value=0.0, data_format="NCL"):
if len(x.shape) * 2 == len(pad) and mode == "constant":
pad = (
paddle.to_tensor(pad, dtype="float32")
.reshape((-1, 2))
.flip([0])
.flatten()
.tolist()
)
return F.pad(x, pad, mode, value, data_format)

def forward(self, x):
@@ -138,10 +173,9 @@ def forward(self, x):
x = x1 + x2
x = F.gelu(x, approximate=False)

x = x[..., : -self.padding]
x1 = self.conv4(x, self.output_np)
x2 = F.interpolate(x, size=[self.output_np], mode="linear", align_corners=True)
x = x1 + x2
# x(batch, channel, 2001)
x = x.transpose(perm=[0, 2, 1])
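
Note: the body of compl_mul1d is collapsed in this view. A minimal sketch of the batched complex multiplication it is assumed to perform, following the canonical FNO contraction "bix,iox->box" (an assumption based on reference FNO implementations, not on this commit):

import paddle

# Hypothetical stand-in for the collapsed compl_mul1d body (assumption).
def compl_mul1d(input, weights):
    # (batch, in_channel, mode) x (in_channel, out_channel, mode)
    # -> (batch, out_channel, mode), computed on real/imag parts so
    # only real-valued einsums are required.
    a, b = paddle.real(input), paddle.imag(input)
    c, d = paddle.real(weights), paddle.imag(weights)
    real = paddle.einsum("bix,iox->box", a, c) - paddle.einsum("bix,iox->box", b, d)
    imag = paddle.einsum("bix,iox->box", a, d) + paddle.einsum("bix,iox->box", b, c)
    return paddle.complex(real, imag)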
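For orientation, a hedged usage sketch of the FNO1d above, following the dict-in/dict-out convention seen in evaluate() and the [None, 2001, 2] InputSpec declared in export() below; the exact output shape is an assumption:

import paddle
from catheter import FNO1d

# Usage sketch (assumption: the defaults shown in __init__ above are valid as-is).
model = FNO1d(input_key="input", output_key="output")
x = paddle.rand([4, 2001, 2])      # (batch, points, (x, y) channels), per export()'s InputSpec
y = model({"input": x})["output"]  # assumed shape (batch, 2001, 1): fc2 projects to one channel
print(y.shape)
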
103 changes: 60 additions & 43 deletions examples/geofno/geofno.py
@@ -12,21 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from os import path as osp

import hydra
import matplotlib.pyplot as plt
import numpy as np
import paddle
from catheter import FNO1d
from omegaconf import DictConfig
from utilities3 import LpLoss

import ppsci
from ppsci.optimizer import Adam
from ppsci.optimizer import lr_scheduler
from ppsci.utils import logger


# build data
def getdata(x_path, y_path, para_path, output_path, n_data, n, s, is_train=True):

# load data
inputX_raw = np.load(x_path)[:, 0:n_data]
inputY_raw = np.load(y_path)[:, 0:n_data]
@@ -37,17 +41,13 @@ def getdata(x_path, y_path, para_path, output_path, n_data, n, s, is_train=True)
inputX = inputX_raw[:, 0::3]
inputY = inputY_raw[:, 0::3]
inputPara = inputPara_raw[:, 0::3]
output = (output_raw[:, 0::3] + output_raw[:, 1::3] + output_raw[:, 2::3]) / 3.0

inputX = paddle.to_tensor(data=inputX, dtype="float32").transpose(perm=[1, 0])
inputY = paddle.to_tensor(data=inputY, dtype="float32").transpose(perm=[1, 0])
input = paddle.stack(x=[inputX, inputY], axis=-1)
output = paddle.to_tensor(data=output, dtype="float32").transpose(perm=[1, 0])
if is_train:
index = paddle.randperm(n=n)
index = index[:n]

@@ -65,13 +65,10 @@ def train(cfg: DictConfig):
# set random seed for reproducibility
ppsci.utils.misc.set_random_seed(cfg.seed)
# initialize logger
logger.init_logger("ppsci", osp.join(cfg.output_dir, f"{cfg.mode}.log"), "info")

# generate training dataset
inputs_train, labels_train, _ = getdata(**cfg.TRAIN_DATA, is_train=True)

# set constraints
sup_constraint = ppsci.constraint.SupervisedConstraint(
@@ -88,8 +85,7 @@
"shuffle": True,
},
},
ppsci.loss.FunctionalLoss(LpLoss(**cfg.TRAIN.LOSS, size_average=False)),
name="sup_constraint",
)
constraint = {sup_constraint.name: sup_constraint}
@@ -101,13 +97,12 @@
# set optimizer
ITERS_PER_EPOCH = int(cfg.TRAIN_DATA.n / cfg.TRAIN.batch_size)
scheduler = lr_scheduler.Step(
**cfg.TRAIN.lr_scheduler, iters_per_epoch=ITERS_PER_EPOCH
)
optimizer = Adam(scheduler, weight_decay=cfg.TRAIN.weight_decay)(model)

# generate test dataset
inputs_test, labels_test, _ = getdata(**cfg.TEST_DATA, is_train=False)

# set validator
l2rel_validator = ppsci.validate.SupervisedValidator(
@@ -149,29 +144,50 @@ def evaluate(cfg: DictConfig):
model.set_state_dict(paddle.load(cfg.TRAIN.model_path))

# set data
x_test, y_test, para = getdata(**cfg.TEST_DATA, is_train=False)
y_test = y_test.detach().cpu().numpy().flatten()

for sample_id in [0, 8]:
sample, uf, L_p, x1, x2, x3, h = para[:, sample_id]
mesh = x_test[sample_id, :, :]

y_test_pred = (
paddle.exp(
model({"input": x_test[sample_id : sample_id + 1, :, :]})["output"]
)
.detach()
.cpu()
.numpy()
.flatten()
)
print(
"rel. error is ",
np.linalg.norm(y_test_pred - y_test[sample_id, :].numpy())
/ np.linalg.norm(y_test[sample_id, :]),
)
xx = np.linspace(-500, 0, 2001)
plt.figure(figsize=(5, 4))

plt.plot(mesh[:, 0], mesh[:, 1], color="C1", label="Channel geometry")
plt.plot(mesh[:, 0], 100 - mesh[:, 1], color="C1")

plt.plot(
xx,
y_test[sample_id, :],
"--o",
color="red",
markevery=len(xx) // 10,
label="Reference",
)
plt.plot(
xx,
y_test_pred,
"--*",
color="C2",
fillstyle="none",
markevery=len(xx) // 10,
label="Predicted bacteria distribution",
)

plt.xlabel(r"x")

@@ -193,8 +209,10 @@ def export(cfg: DictConfig):
from paddle.static import InputSpec

input_spec = [
{
key: InputSpec([None, 2001, 2], "float32", name=key)
for key in model.input_keys
},
]
solver.export(input_spec, cfg.INFER.export_path)

@@ -209,8 +227,7 @@ def main(cfg: DictConfig):
export(cfg)
else:
raise ValueError(
f"cfg.mode should in ['train', 'eval', 'export'], but got '{
cfg.mode}'"
f"cfg.mode should in ['train', 'eval', 'export'], but got '{cfg.mode}'"
)


Expand Down
