Skip to content

Commit

Permalink
fix code block
Browse files Browse the repository at this point in the history
  • Loading branch information
SYangster committed Jul 30, 2024
1 parent 2ae136d commit 5c50363
Showing 1 changed file with 23 additions and 21 deletions.
44 changes: 23 additions & 21 deletions docs/programming_guide/controllers/model_controller.rst
Original file line number Diff line number Diff line change
Expand Up @@ -159,27 +159,29 @@ For example we can use PyTorch's save and load functions for the model parameter

.. code-block:: python
import torch
from nvflare.fuel.utils import fobs
def save_model(self, model, filepath=""):
params = model.params
# PyTorch save
torch.save(params, filepath)
# save FLModel metadata
model.params = {}
fobs.dumpf(model, filepath + ".metadata")
model.params = params
def load_model(self, filepath=""):
# PyTorch load
params = torch.load(filepath)
# load FLModel metadata
model = fobs.loadf(filepath + ".metadata")
model.params = params
return model
import torch
from nvflare.fuel.utils import fobs
class MyController(ModelController):
...
def save_model(self, model, filepath=""):
params = model.params
# PyTorch save
torch.save(params, filepath)
# save FLModel metadata
model.params = {}
fobs.dumpf(model, filepath + ".metadata")
model.params = params
def load_model(self, filepath=""):
# PyTorch load
params = torch.load(filepath)
# load FLModel metadata
model = fobs.loadf(filepath + ".metadata")
model.params = params
return model
Note: for non-primitive data types such as ``torch.nn.Module`` (used for the initial PyTorch model), we must configure a corresponding FOBS decomposer for serialization and deserialization.
Expand Down

0 comments on commit 5c50363

Please sign in to comment.