Skip to content

Commit

Permalink
Merge branch 'main' into check-for-memoryview-overflow
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh authored Sep 6, 2024
2 parents 17d09c0 + 8a1d161 commit 80c5f7e
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def load_test_data(data_path: str):


def run_server(port: int, world_size: int) -> None:
xgboost.federated.run_federated_server(port, world_size)
xgboost.federated.run_federated_server(n_workers=world_size, port=port)


def run_worker(port: int, world_size: int, rank: int, args) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
from nvflare.app_common.app_constant import ModelName

# (optional) set a fixed location so we don't need to download everytime
CIFAR10_ROOT = "~/data"
MODEL_SAVE_PATH_ROOT = "~/data"
CIFAR10_ROOT = "/tmp/nvflare/data"
MODEL_SAVE_PATH_ROOT = "/tmp/nvflare/data"

# (optional) We change to use GPU to speed things up.
# if you want to use CPU, change DEVICE="cpu"
Expand All @@ -41,7 +41,6 @@ def define_parser():
parser.add_argument("--batch_size", type=int, default=4, nargs="?")
parser.add_argument("--num_workers", type=int, default=1, nargs="?")
parser.add_argument("--local_epochs", type=int, default=2, nargs="?")
parser.add_argument("--model_path", type=str, default=f"{MODEL_SAVE_PATH_ROOT}/cifar_net.pth", nargs="?")
return parser.parse_args()


Expand All @@ -53,7 +52,6 @@ def main():
batch_size = args.batch_size
num_workers = args.num_workers
local_epochs = args.local_epochs
model_path = args.model_path

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root=dataset_path, train=True, download=True, transform=transform)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ tests:
"data": { "run_finished": True }
validators:
- path: tests.integration_test.src.validators.PTModelValidator
- path: tests.integration_test.src.validators.CrossValResultValidator
args: { server_model_names: [ "server" ] }
setup:
- python -c "from torchvision.datasets import CIFAR10; CIFAR10(root='~/data/', download=True)"
- mkdir -p /tmp/nvflare/data/site-1
- mkdir -p /tmp/nvflare/data/site-2
- python -c "from torchvision.datasets import CIFAR10; CIFAR10(root='/tmp/nvflare/data', download=True)"
teardown:
- rm -rf ~/data
- rm -rf /tmp/nvflare/data

0 comments on commit 80c5f7e

Please sign in to comment.