diff --git a/docs/docs/ProgrammingGuide/pytorch.md b/docs/docs/ProgrammingGuide/pytorch.md index 8e2599bde45..0d9f2b79ea3 100644 --- a/docs/docs/ProgrammingGuide/pytorch.md +++ b/docs/docs/ProgrammingGuide/pytorch.md @@ -14,24 +14,24 @@ Two wrappers are defined in Analytics Zoo for Pytorch: 1. TorchModel: TorchModel is a wrapper class for Pytorch model. User may create a TorchModel by providing a Pytorch model, e.g. -```python + ```python from zoo.pipeline.api.torch import TorchModel TorchModel.from_pytorch(torchvision.models.resnet18(pretrained=True)) -``` + ``` The above line creates TorchModel wrapping a ResNet model, and user can use the TorchModel for training or inference with Analytics Zoo. 2. TorchLoss: TorchLoss is a wrapper for loss functions defined by Pytorch. User may create a TorchLoss from a Pytorch Criterion, -```python + ```python from torch import nn from zoo.pipeline.api.torch import TorchLoss - az_criterion = TorchLoss.from_pytorch(loss=nn.MSELoss()) -``` -or from a custom loss function, which takes input and label as parameters + az_criterion = TorchLoss.from_pytorch(nn.MSELoss()) + ``` + or from a custom loss function, which takes input and label as parameters -```python + ```python from torch import nn from zoo.pipeline.api.torch import TorchLoss @@ -44,8 +44,8 @@ or from a custom loss function, which takes input and label as parameters loss = loss1 + 0.4 * loss2 return loss - az_criterion = TorchLoss.from_pytorch(loss=lossFunc) -``` + az_criterion = TorchLoss.from_pytorch(lossFunc) + ``` # Examples Here we provide a simple end to end example, where we use TorchModel and TorchLoss to