[GraphBolt][CUDA] Other example cuda versions #6958

Open
wants to merge 9 commits into base: master
33 changes: 25 additions & 8 deletions examples/sampling/graphbolt/lightning/node_classification.py
@@ -58,8 +58,12 @@ def __init__(self, in_feats, n_hidden, n_classes):
         self.dropout = nn.Dropout(0.5)
         self.n_hidden = n_hidden
         self.n_classes = n_classes
-        self.train_acc = Accuracy(task="multiclass", num_classes=n_classes)
-        self.val_acc = Accuracy(task="multiclass", num_classes=n_classes)
+        self.train_acc = Accuracy(
+            task="multiclass", num_classes=n_classes, top_k=1
+        )
+        self.val_acc = Accuracy(
+            task="multiclass", num_classes=n_classes, top_k=1
+        )
 
     def forward(self, blocks, x):
         h = x
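
A note on the hunk above: top_k=1 is already the default for torchmetrics' multiclass Accuracy, so the change spells the default out rather than altering behavior. Below is a minimal, hypothetical sketch of what that metric computes; the tensors and the explicit average="micro" are illustrative and not part of this PR.

import torch
from torchmetrics import Accuracy

# top_k=1 scores only the argmax prediction for each sample.
acc = Accuracy(task="multiclass", num_classes=3, top_k=1, average="micro")
preds = torch.tensor([[0.1, 0.8, 0.1], [0.7, 0.2, 0.1]])  # illustrative scores
target = torch.tensor([1, 2])  # argmax hits the first target, misses the second
print(acc(preds, target))  # tensor(0.5000)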
@@ -133,13 +137,14 @@ def configure_optimizers(self):
 
 
 class DataModule(LightningDataModule):
-    def __init__(self, dataset, fanouts, batch_size, num_workers):
+    def __init__(self, dataset, fanouts, batch_size, num_workers, device):
         super().__init__()
         self.fanouts = fanouts
         self.batch_size = batch_size
         self.num_workers = num_workers
-        self.feature_store = dataset.feature
-        self.graph = dataset.graph
+        self.feature_store = dataset.feature.to(device)
+        self.graph = dataset.graph.to(device)
+        self.device = "cuda" if device != "cpu" else "cpu"
         self.train_set = dataset.tasks[0].train_set
         self.valid_set = dataset.tasks[0].validation_set
         self.num_classes = dataset.tasks[0].metadata["num_classes"]
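
The derived self.device above maps the storage choice to a compute device. A tiny sketch of that mapping, under the assumption that "pinned" keeps the dataset in page-locked host memory while computation still happens on the GPU:

# Hypothetical illustration of the mapping in DataModule.__init__ above.
for storage_device in ("cpu", "pinned", "cuda"):
    compute_device = "cuda" if storage_device != "cpu" else "cpu"
    print(storage_device, "->", compute_device)  # cpu->cpu, pinned->cuda, cuda->cuda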
@@ -148,9 +153,10 @@ def create_dataloader(self, node_set, is_train):
         datapipe = gb.ItemSampler(
             node_set,
             batch_size=self.batch_size,
-            shuffle=True,
-            drop_last=True,
+            shuffle=is_train,
+            drop_last=is_train,
         )
+        datapipe = datapipe.copy_to(self.device, ["seed_nodes"])
         sampler = (
             datapipe.sample_layer_neighbor
             if is_train
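
The new copy_to step sits before the sampler so that each mini-batch's seed node IDs reach the compute device before neighbor sampling runs. A hypothetical outline of the resulting pipeline order, not runnable as-is: node_set, graph, and feature_store are placeholders.

datapipe = gb.ItemSampler(node_set, batch_size=1024, shuffle=True, drop_last=True)
# Seed node IDs are moved first, so sampling itself can execute on the GPU
# when the graph lives in GPU or pinned memory.
datapipe = datapipe.copy_to("cuda", ["seed_nodes"])
datapipe = datapipe.sample_layer_neighbor(graph, [10, 10, 10])
datapipe = datapipe.fetch_feature(feature_store, node_feature_keys=["feat"])
dataloader = gb.DataLoader(datapipe, num_workers=0)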
@@ -203,14 +209,25 @@ def val_dataloader(self):
         default=0,
         help="number of workers (default: 0)",
     )
+    parser.add_argument(
+        "--storage_device",
+        default="pinned",
+        choices=["cpu", "pinned", "cuda"],
+        help="Moves the dataset into the selected storage",
+    )
     args = parser.parse_args()
 
+    if not torch.cuda.is_available():
+        args.num_gpus = 0
+        args.storage_device = "cpu"
+
     dataset = gb.BuiltinDataset("ogbn-products").load()
     datamodule = DataModule(
         dataset,
         [10, 10, 10],
         args.batch_size,
         args.num_workers,
+        args.storage_device,
     )
     in_size = dataset.feature.size("node", None, "feat")[0]
     model = SAGE(in_size, 256, datamodule.num_classes)
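
The default --storage_device of "pinned" deserves a word: pinned (page-locked) host memory allows asynchronous host-to-GPU copies and direct GPU reads in place. A small PyTorch-only sketch of the underlying mechanism; the assumption is that GraphBolt's .to("pinned") page-locks its underlying tensors similarly.

import torch

x = torch.randn(1024, 128).pin_memory()  # page-locked host tensor
print(x.is_pinned())  # True
if torch.cuda.is_available():
    # Pinned memory enables asynchronous (non-blocking) transfers, and CUDA
    # kernels can also read it in place, which is what makes "pinned" a good
    # middle ground between "cpu" and "cuda" storage.
    y = x.to("cuda", non_blocking=True)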
@@ -225,7 +242,7 @@ def val_dataloader(self):
     # https://lightning.ai/docs/pytorch/stable/common/trainer.html.
     ########################################################################
     trainer = Trainer(
-        accelerator="gpu",
+        accelerator="gpu" if args.num_gpus > 0 else "cpu",
         devices=args.num_gpus,
         max_epochs=args.epochs,
         callbacks=[checkpoint_callback, early_stopping_callback],
12 changes: 6 additions & 6 deletions examples/sampling/pyg/node_classification.py
@@ -92,9 +92,9 @@ def create_dataloader(dataset_set, graph, feature, device, is_train):
     # (HIGHLIGHT) Create a data loader for efficiently loading graph data.
     #
     # - 'ItemSampler' samples mini-batches of node IDs from the dataset.
-    # - 'CopyTo' copies the fetched data to the specified device.
     # - 'sample_neighbor' performs neighbor sampling on the graph.
     # - 'FeatureFetcher' fetches node features based on the sampled subgraph.
+    # - 'CopyTo' copies the fetched data to the specified device.
 
     #####################################################################
     # Create a datapipe for mini-batch sampling with a specific neighbor fanout.
@@ -108,12 +108,12 @@ def create_dataloader(dataset_set, graph, feature, device, is_train):
     datapipe = gb.ItemSampler(
         dataset_set, batch_size=1024, shuffle=is_train, drop_last=is_train
     )
-    # Copy the data to the specified device.
-    datapipe = datapipe.copy_to(device=device, extra_attrs=["seed_nodes"])
     # Sample neighbors for each node in the mini-batch.
     datapipe = datapipe.sample_neighbor(graph, [10, 10, 10])
     # Fetch node features for the sampled subgraph.
     datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
+    # Copy the data to the specified device.
+    datapipe = datapipe.copy_to(device=device)
     # Create and return a DataLoader to handle data loading.
     dataloader = gb.DataLoader(datapipe, num_workers=0)
 
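For orientation, a hypothetical consumption loop for the dataloader built above; the attribute names follow other GraphBolt examples and are an assumption here.

for step, minibatch in enumerate(dataloader):
    x = minibatch.node_features["feat"]  # fetched features, already on device
    y = minibatch.labels                 # labels of the seed nodes
    # ... run the model on the sampled subgraph ...
    if step == 0:
        print(x.shape, y.shape)
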
@@ -195,13 +195,13 @@ def main():
     args = parser.parse_args()
     dataset_name = args.dataset
     dataset = gb.BuiltinDataset(dataset_name).load()
-    graph = dataset.graph
-    feature = dataset.feature
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    graph = dataset.graph.to(device)
+    feature = dataset.feature.to(device)
     train_set = dataset.tasks[0].train_set
     valid_set = dataset.tasks[0].validation_set
     test_set = dataset.tasks[0].test_set
     num_classes = dataset.tasks[0].metadata["num_classes"]
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     train_dataloader = create_dataloader(
         train_set, graph, feature, device, is_train=True