Skip to content

Commit

Permalink
fix imports for inprocess executor; update TF network
Browse files Browse the repository at this point in the history
  • Loading branch information
holgerroth committed May 22, 2024
1 parent 5648ecb commit eb40e32
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 32 deletions.
2 changes: 1 addition & 1 deletion examples/hello-world/job_api/pt/src/cifar10_fl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from src.net import Net
from net import Net

# (1) import nvflare client API
import nvflare.client as flare
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import torch
import torchvision
import torchvision.transforms as transforms
from lit_net import LitNet
from pytorch_lightning import LightningDataModule, Trainer, seed_everything
from src.lit_net import LitNet
from torch.utils.data import DataLoader, random_split

# (1) import nvflare lightning client API
Expand Down
23 changes: 22 additions & 1 deletion examples/hello-world/job_api/pt/src/lit_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,37 @@

from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning import LightningModule
from src.net import Net
from torchmetrics import Accuracy

NUM_CLASSES = 10
criterion = nn.CrossEntropyLoss()


class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = torch.flatten(x, 1) # flatten all dimensions except batch
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x


class LitNet(LightningModule):
def __init__(self):
super().__init__()
Expand Down
2 changes: 1 addition & 1 deletion examples/hello-world/job_api/pt/src/train_eval_submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from src.net import Net
from net import Net

# (1) import nvflare client API
import nvflare.client as flare
Expand Down
10 changes: 7 additions & 3 deletions examples/hello-world/job_api/tf/src/cifar10_tf_fl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from src.tf_net import TFNet
import tensorflow as tf
from tensorflow.keras import datasets
from tf_net import TFNet

# (1) import nvflare client API
import nvflare.client as flare

PATH = "./tf_model.ckpt"
PATH = "./tf_model.weights.h5"


def main():
Expand All @@ -27,7 +28,10 @@ def main():
# Normalize pixel values to be between 0 and 1
train_images, test_images = train_images / 255.0, test_images / 255.0

model = TFNet(input_shape=(None, 32, 32, 3))
model = TFNet()
model.compile(
optimizer="adam", loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"]
)
model.summary()

# (2) initializes NVFlare client API
Expand Down
35 changes: 12 additions & 23 deletions examples/hello-world/job_api/tf/src/tf_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from tensorflow.keras import Model, layers, losses
from tensorflow.keras import layers, models


class TFNet(Model):
def __init__(self, input_shape):
class TFNet(models.Sequential):
def __init__(self):
super().__init__()
self._input_shape = input_shape # Required to get constructor arguments in job config
self.conv1 = layers.Conv2D(6, 5, activation="relu")
self.pool = layers.MaxPooling2D((2, 2), 2)
self.conv2 = layers.Conv2D(16, 5, activation="relu")
self.flatten = layers.Flatten()
self.fc1 = layers.Dense(120, activation="relu")
self.fc2 = layers.Dense(84, activation="relu")
self.fc3 = layers.Dense(10)
loss_fn = losses.SparseCategoricalCrossentropy(from_logits=True)
self.compile(optimizer="sgd", loss=loss_fn, metrics=["accuracy"])
self.build(input_shape)

def call(self, x):
x = self.pool(self.conv1(x))
x = self.pool(self.conv2(x))
x = self.flatten(x)
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
return x
self.add(layers.Input(shape=(32, 32, 3)))
self.add(layers.Conv2D(32, (3, 3), activation="relu"))
self.add(layers.MaxPooling2D((2, 2)))
self.add(layers.Conv2D(64, (3, 3), activation="relu"))
self.add(layers.MaxPooling2D((2, 2)))
self.add(layers.Conv2D(64, (3, 3), activation="relu"))
self.add(layers.Flatten())
self.add(layers.Dense(64, activation="relu"))
self.add(layers.Dense(10))
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
job.to(controller, "server")

# Define the initial global model and send to server
job.to(TFNet(input_shape=(None, 32, 32, 3)), "server")
job.to(TFNet(), "server")

# Add clients
for i in range(n_clients):
Expand Down
3 changes: 2 additions & 1 deletion examples/hello-world/ml-to-fl/tf/code/tf_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
class TFNet(models.Sequential):
def __init__(self):
super().__init__()
self.add(layers.Conv2D(32, (3, 3), activation="relu", input_shape=(32, 32, 3)))
self.add(layers.Input(shape=(32, 32, 3)))
self.add(layers.Conv2D(32, (3, 3), activation="relu"))
self.add(layers.MaxPooling2D((2, 2)))
self.add(layers.Conv2D(64, (3, 3), activation="relu"))
self.add(layers.MaxPooling2D((2, 2)))
Expand Down

0 comments on commit eb40e32

Please sign in to comment.