diff --git a/ppsci/utils/save_load.py b/ppsci/utils/save_load.py index d1616725a..b882c515d 100644 --- a/ppsci/utils/save_load.py +++ b/ppsci/utils/save_load.py @@ -80,13 +80,22 @@ def _load_pretrain_from_path( def load_pretrain( model: nn.Layer, path: str, equation: Optional[Dict[str, equation.PDE]] = None ): - """Load pretrained model from given path or url. + """ + Load pretrained model from given path or url. Args: model (nn.Layer): Model with parameters. path (str): File path or url of pretrained model, i.e. `/path/to/model.pdparams` or `http://xxx.com/model.pdparams`. equation (Optional[Dict[str, equation.PDE]]): Equations. Defaults to None. + + Examples: + >>> import ppsci + >>> from ppsci.utils import save_load + >>> model = ppsci.arch.MLP(("x", "y"), ("u", "v", "p"), 9, 50, "tanh") + >>> save_load.load_pretrain( + ... model=model, + ... path="path/to/pretrain_model") # doctest: +SKIP """ if path.startswith("http"): path = download.get_weights_path_from_url(path) @@ -159,7 +168,8 @@ def save_checkpoint( equation: Optional[Dict[str, equation.PDE]] = None, print_log: bool = True, ): - """Save checkpoint, including model params, optimizer params, metric information. + """ + Save checkpoint, including model params, optimizer params, metric information. Args: model (nn.Layer): Model with parameters. @@ -172,6 +182,14 @@ def save_checkpoint( print_log (bool, optional): Whether print saving log information, mainly for keeping log tidy without duplicate 'Finish saving checkpoint ...' log strings. Defaults to True. + + Examples: + >>> import ppsci + >>> import paddle + >>> from ppsci.utils import save_load + >>> model = ppsci.arch.MLP(("x", "y", "z"), ("u", "v", "w"), 5, 64, "tanh") + >>> optimizer = ppsci.optimizer.Adam(0.001)(model) + >>> save_load.save_checkpoint(model, optimizer, {"RMSE": 0.1}, output_dir="path/to/output/dir") # doctest: +SKIP """ if paddle.distributed.get_rank() != 0: return