From ff34fa00f39f75aa9961e5c214294e3dd857926f Mon Sep 17 00:00:00 2001
From: ooo oo <106524776+ooooo-create@users.noreply.github.com>
Date: Fri, 19 Jan 2024 12:54:05 +0800
Subject: [PATCH] 【PPSCI Doc No.104-105】 (#759)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* update docstring

* remove spaces in args

* update

* remove extra blank line
---
 ppsci/utils/save_load.py | 22 ++++++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/ppsci/utils/save_load.py b/ppsci/utils/save_load.py
index d1616725a..b882c515d 100644
--- a/ppsci/utils/save_load.py
+++ b/ppsci/utils/save_load.py
@@ -80,13 +80,22 @@ def _load_pretrain_from_path(
 def load_pretrain(
     model: nn.Layer, path: str, equation: Optional[Dict[str, equation.PDE]] = None
 ):
-    """Load pretrained model from given path or url.
+    """
+    Load pretrained model from given path or url.
 
     Args:
         model (nn.Layer): Model with parameters.
         path (str): File path or url of pretrained model, i.e. `/path/to/model.pdparams`
             or `http://xxx.com/model.pdparams`.
         equation (Optional[Dict[str, equation.PDE]]): Equations. Defaults to None.
+
+    Examples:
+        >>> import ppsci
+        >>> from ppsci.utils import save_load
+        >>> model = ppsci.arch.MLP(("x", "y"), ("u", "v", "p"), 9, 50, "tanh")
+        >>> save_load.load_pretrain(
+        ...     model=model,
+        ...     path="path/to/pretrain_model") # doctest: +SKIP
     """
     if path.startswith("http"):
         path = download.get_weights_path_from_url(path)
@@ -159,7 +168,8 @@ def save_checkpoint(
     equation: Optional[Dict[str, equation.PDE]] = None,
     print_log: bool = True,
 ):
-    """Save checkpoint, including model params, optimizer params, metric information.
+    """
+    Save checkpoint, including model params, optimizer params, metric information.
 
     Args:
         model (nn.Layer): Model with parameters.
@@ -172,6 +182,14 @@
         print_log (bool, optional): Whether print saving log information, mainly for
             keeping log tidy without duplicate 'Finish saving checkpoint ...' log strings.
             Defaults to True.
+
+    Examples:
+        >>> import ppsci
+        >>> import paddle
+        >>> from ppsci.utils import save_load
+        >>> model = ppsci.arch.MLP(("x", "y", "z"), ("u", "v", "w"), 5, 64, "tanh")
+        >>> optimizer = ppsci.optimizer.Adam(0.001)(model)
+        >>> save_load.save_checkpoint(model, optimizer, {"RMSE": 0.1}, output_dir="path/to/output/dir") # doctest: +SKIP
     """
     if paddle.distributed.get_rank() != 0:
        return
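
The docstring examples above show each function on its own. As a companion illustration (not part of the diff itself), a minimal end-to-end sketch of saving a checkpoint and later loading pretrained weights back could look like the following; it reuses only the calls already shown in the docstrings, and the output directory, metric value, and pretrained-weights path are placeholders.

import ppsci
from ppsci.utils import save_load

# Build a small MLP and its optimizer, mirroring the docstring examples above.
model = ppsci.arch.MLP(("x", "y"), ("u", "v", "p"), 9, 50, "tanh")
optimizer = ppsci.optimizer.Adam(0.001)(model)

# Save model/optimizer state together with a metric dict;
# "path/to/output/dir" is a placeholder output directory.
save_load.save_checkpoint(model, optimizer, {"RMSE": 0.1}, output_dir="path/to/output/dir")

# Later (or in another run), load pretrained weights into a model with the same
# architecture; the path is a placeholder for the saved .pdparams file.
save_load.load_pretrain(model=model, path="path/to/pretrain_model")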