【PPSCI Doc No.104-105】 (#759)
* update docstring

* remove spaces in args

* update

* remove extra blank line
ooooo-create authored Jan 19, 2024
1 parent 26a0a85 commit ff34fa0
Showing 1 changed file with 20 additions and 2 deletions.
ppsci/utils/save_load.py
@@ -80,13 +80,22 @@ def _load_pretrain_from_path(
 def load_pretrain(
     model: nn.Layer, path: str, equation: Optional[Dict[str, equation.PDE]] = None
 ):
-    """Load pretrained model from given path or url.
+    """
+    Load pretrained model from given path or url.
 
     Args:
         model (nn.Layer): Model with parameters.
         path (str): File path or url of pretrained model, i.e. `/path/to/model.pdparams`
             or `http://xxx.com/model.pdparams`.
         equation (Optional[Dict[str, equation.PDE]]): Equations. Defaults to None.
+
+    Examples:
+        >>> import ppsci
+        >>> from ppsci.utils import save_load
+        >>> model = ppsci.arch.MLP(("x", "y"), ("u", "v", "p"), 9, 50, "tanh")
+        >>> save_load.load_pretrain(
+        ...     model=model,
+        ...     path="path/to/pretrain_model")  # doctest: +SKIP
     """
     if path.startswith("http"):
         path = download.get_weights_path_from_url(path)
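
To make the new doctest concrete, here is a standalone sketch of calling load_pretrain with a remote path. It assumes an installed ppsci; the URL is a hypothetical placeholder, and the URL branch additionally needs network access.

# A standalone sketch of the load_pretrain doctest above.
# Assumptions: ppsci is installed; "https://example.com/model.pdparams" is a
# hypothetical placeholder URL, so actually running this needs network access.
import ppsci
from ppsci.utils import save_load

model = ppsci.arch.MLP(("x", "y"), ("u", "v", "p"), 9, 50, "tanh")

# Because the path starts with "http", load_pretrain first resolves it via
# download.get_weights_path_from_url (see the branch at the end of the hunk
# above), then loads the downloaded weights into the model.
save_load.load_pretrain(
    model=model,
    path="https://example.com/model.pdparams",
)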
@@ -159,7 +168,8 @@ def save_checkpoint(
     equation: Optional[Dict[str, equation.PDE]] = None,
     print_log: bool = True,
 ):
-    """Save checkpoint, including model params, optimizer params, metric information.
+    """
+    Save checkpoint, including model params, optimizer params, metric information.
 
     Args:
         model (nn.Layer): Model with parameters.
@@ -172,6 +182,14 @@ print_log (bool, optional): Whether print saving log information, mainly for
         print_log (bool, optional): Whether print saving log information, mainly for
             keeping log tidy without duplicate 'Finish saving checkpoint ...' log strings.
             Defaults to True.
+
+    Examples:
+        >>> import ppsci
+        >>> import paddle
+        >>> from ppsci.utils import save_load
+        >>> model = ppsci.arch.MLP(("x", "y", "z"), ("u", "v", "w"), 5, 64, "tanh")
+        >>> optimizer = ppsci.optimizer.Adam(0.001)(model)
+        >>> save_load.save_checkpoint(model, optimizer, {"RMSE": 0.1}, output_dir="path/to/output/dir")  # doctest: +SKIP
     """
     if paddle.distributed.get_rank() != 0:
         return
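As a companion sketch, the two functions touched by this commit can be combined into a save-then-restore round trip. The output directory is a placeholder, and the checkpoint file name used for the reload is an assumption, since the on-disk layout written by save_checkpoint is not shown in this diff.

# A sketch combining save_checkpoint and load_pretrain into a round trip.
# Assumptions: ppsci is installed; "path/to/output/dir" is a placeholder, and
# the checkpoint file name in the reload step is assumed, not shown in this diff.
import ppsci
from ppsci.utils import save_load

model = ppsci.arch.MLP(("x", "y", "z"), ("u", "v", "w"), 5, 64, "tanh")
optimizer = ppsci.optimizer.Adam(0.001)(model)

# Save model params, optimizer params, and metric information. Per the guard
# in the hunk above, only rank 0 writes the checkpoint on distributed runs.
save_load.save_checkpoint(model, optimizer, {"RMSE": 0.1}, output_dir="path/to/output/dir")

# Later, restore the saved weights into a freshly built model of the same shape.
new_model = ppsci.arch.MLP(("x", "y", "z"), ("u", "v", "w"), 5, 64, "tanh")
save_load.load_pretrain(
    model=new_model,
    path="path/to/output/dir/checkpoints/latest.pdparams",  # assumed path layout
)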
