From 367ea5e9d7c10f0cdf87f041f690a2d1c32778ec Mon Sep 17 00:00:00 2001 From: Liang Geng Date: Tue, 13 Oct 2020 11:41:56 +0800 Subject: [PATCH] add vertex data ctx (#4) * add vertex data ctx * update * app extends vertex data context * revert * bugfix * fix * fix * add finalize method to save result * disable asan * add context_type method * change type * add get data method * cleanup * cleanup * optimizations * fix * fix * fix * fix * fix * fix * fix * expose ctx data type * move set_fragment to Init * remove ctx_data_t * Revert "remove ctx_data_t" This reverts commit 70280df287c23a4120a4c25ad1b79b3e418e12e9. * cleanup * add dummy data_t * add dummy data_t * cleanup * fix * refactoring ctx * fix * fix * refine * refine * refine * refactor all * fix * fix * fix * fix * make changes according the review * change return type of context_type * cleanup Co-authored-by: guanyi.gl --- .github/workflows/c-cpp.yml | 2 +- .../analytical_apps/bfs/bfs_auto_context.h | 17 ++++-- examples/analytical_apps/bfs/bfs_context.h | 22 ++++--- .../analytical_apps/cdlp/cdlp_auto_context.h | 17 ++++-- examples/analytical_apps/cdlp/cdlp_context.h | 19 +++--- examples/analytical_apps/lcc/lcc.h | 16 ++++- examples/analytical_apps/lcc/lcc_auto.h | 14 +++++ .../analytical_apps/lcc/lcc_auto_context.h | 16 +++-- examples/analytical_apps/lcc/lcc_context.h | 16 +++-- examples/analytical_apps/pagerank/pagerank.h | 16 ++++- .../analytical_apps/pagerank/pagerank_auto.h | 8 +++ .../pagerank/pagerank_auto_context.h | 26 +++++---- .../pagerank/pagerank_context.h | 26 +++++---- .../pagerank/pagerank_local_context.h | 18 ++++-- .../pagerank/pagerank_local_parallel.h | 8 +++ .../pagerank_local_parallel_context.h | 22 +++---- .../pagerank/pagerank_parallel.h | 8 +++ .../pagerank/pagerank_parallel_context.h | 26 +++++---- .../analytical_apps/sssp/sssp_auto_context.h | 16 +++-- examples/analytical_apps/sssp/sssp_context.h | 19 +++--- examples/analytical_apps/wcc/wcc_auto.h | 3 +- .../analytical_apps/wcc/wcc_auto_context.h | 25 +++++--- examples/analytical_apps/wcc/wcc_context.h | 15 +++-- .../append_only_edgecut_fragment.h | 2 + examples/gnn_sampler/sampler.h | 13 +++++ examples/gnn_sampler/sampler_context.h | 18 ++++-- grape/app/context_base.h | 13 +++-- grape/app/vertex_data_context.h | 58 +++++++++++++++++++ grape/grape.h | 1 + grape/parallel/sync_buffer.h | 13 +++-- grape/utils/default_allocator.h | 6 +- grape/utils/vertex_array.h | 2 +- grape/worker/auto_worker.h | 27 +++++---- grape/worker/batch_shuffle_worker.h | 25 +++++--- grape/worker/parallel_worker.h | 26 ++++++--- 35 files changed, 413 insertions(+), 166 deletions(-) create mode 100644 grape/app/vertex_data_context.h diff --git a/.github/workflows/c-cpp.yml b/.github/workflows/c-cpp.yml index 45cadf57..e8c49701 100644 --- a/.github/workflows/c-cpp.yml +++ b/.github/workflows/c-cpp.yml @@ -34,7 +34,7 @@ jobs: run: | mkdir build cd build - cmake .. -DCMAKE_BUILD_TYPE=Debug -DWITH_ASAN=ON + cmake .. -DCMAKE_BUILD_TYPE=Debug make cpplint make - name: App Test diff --git a/examples/analytical_apps/bfs/bfs_auto_context.h b/examples/analytical_apps/bfs/bfs_auto_context.h index af2f27b3..efd25fba 100644 --- a/examples/analytical_apps/bfs/bfs_auto_context.h +++ b/examples/analytical_apps/bfs/bfs_auto_context.h @@ -27,16 +27,21 @@ namespace grape { * @tparam FRAG_T */ template -class BFSAutoContext : public ContextBase { +class BFSAutoContext : public VertexDataContext { public: using oid_t = typename FRAG_T::oid_t; using vid_t = typename FRAG_T::vid_t; - void Init(const FRAG_T& frag, AutoParallelMessageManager& messages, - oid_t src_id) { - source_id = src_id; + explicit BFSAutoContext(const FRAG_T& fragment) + : VertexDataContext(fragment), + partial_result(this->data()) {} + void Init(AutoParallelMessageManager& messages, + oid_t src_id) { + auto &frag = this->fragment(); auto vertices = frag.Vertices(); + + source_id = src_id; partial_result.Init(vertices, std::numeric_limits::max(), [](int64_t* lhs, int64_t rhs) { if (*lhs > rhs) { @@ -51,8 +56,10 @@ class BFSAutoContext : public ContextBase { MessageStrategy::kSyncOnOuterVertex); } - void Output(const FRAG_T& frag, std::ostream& os) { + void Output(std::ostream& os) override { + auto &frag = this->fragment(); auto inner_vertices = frag.InnerVertices(); + for (auto v : inner_vertices) { os << frag.GetId(v) << " " << partial_result[v] << std::endl; } diff --git a/examples/analytical_apps/bfs/bfs_context.h b/examples/analytical_apps/bfs/bfs_context.h index 7b4af0a2..bbc6172b 100644 --- a/examples/analytical_apps/bfs/bfs_context.h +++ b/examples/analytical_apps/bfs/bfs_context.h @@ -27,19 +27,22 @@ namespace grape { * @tparam FRAG_T */ template -class BFSContext : public ContextBase { +class BFSContext : public VertexDataContext { public: using depth_type = int64_t; using oid_t = typename FRAG_T::oid_t; using vid_t = typename FRAG_T::vid_t; - void Init(const FRAG_T& frag, ParallelMessageManager& messages, - oid_t src_id) { - source_id = src_id; + explicit BFSContext(const FRAG_T& fragment) + : VertexDataContext(fragment, true), + partial_result(this->data()) {} - auto vertices = frag.Vertices(); - partial_result.Init(vertices, std::numeric_limits::max()); + void Init(ParallelMessageManager& messages, + oid_t src_id) { + auto &frag = this->fragment(); + source_id = src_id; + partial_result.SetValue(std::numeric_limits::max()); avg_degree = static_cast(frag.GetEdgeNum()) / static_cast(frag.GetInnerVerticesNum()); @@ -50,12 +53,13 @@ class BFSContext : public ContextBase { #endif } - void Output(const FRAG_T& frag, std::ostream& os) { + void Output(std::ostream& os) override { + auto &frag = this->fragment(); auto inner_vertices = frag.InnerVertices(); + for (auto v : inner_vertices) { os << frag.GetId(v) << " " << partial_result[v] << std::endl; } - #ifdef PROFILING VLOG(2) << "preprocess_time: " << preprocess_time << "s."; VLOG(2) << "exec_time: " << exec_time << "s."; @@ -64,7 +68,7 @@ class BFSContext : public ContextBase { } oid_t source_id; - typename FRAG_T::template vertex_array_t partial_result; + typename FRAG_T::template vertex_array_t& partial_result; DenseVertexSet curr_inner_updated, next_inner_updated; depth_type current_depth = 0; diff --git a/examples/analytical_apps/cdlp/cdlp_auto_context.h b/examples/analytical_apps/cdlp/cdlp_auto_context.h index 0527a1ce..b157a4b8 100644 --- a/examples/analytical_apps/cdlp/cdlp_auto_context.h +++ b/examples/analytical_apps/cdlp/cdlp_auto_context.h @@ -25,7 +25,8 @@ namespace grape { * @tparam FRAG_T */ template -class CDLPAutoContext : public ContextBase { +class CDLPAutoContext + : public VertexDataContext { public: using oid_t = typename FRAG_T::oid_t; using vid_t = typename FRAG_T::vid_t; @@ -35,12 +36,16 @@ class CDLPAutoContext : public ContextBase { using label_t = oid_t; #endif - void Init(const FRAG_T& frag, AutoParallelMessageManager& messages, - int max_round) { - this->max_round = max_round; + explicit CDLPAutoContext(const FRAG_T& fragment) + : VertexDataContext(fragment, true), + labels(this->data()) {} + void Init(AutoParallelMessageManager& messages, int max_round) { + auto& frag = this->fragment(); auto vertices = frag.Vertices(); auto inner_vertices = frag.InnerVertices(); + + this->max_round = max_round; labels.Init(vertices, 0, [](label_t* lhs, label_t rhs) { *lhs = rhs; return true; @@ -72,8 +77,10 @@ class CDLPAutoContext : public ContextBase { step = 0; } - void Output(const FRAG_T& frag, std::ostream& os) { + void Output(std::ostream& os) override { + auto& frag = this->fragment(); auto inner_vertices = frag.InnerVertices(); + for (auto v : inner_vertices) { os << frag.GetId(v) << " " << labels[v] << std::endl; } diff --git a/examples/analytical_apps/cdlp/cdlp_context.h b/examples/analytical_apps/cdlp/cdlp_context.h index a2268fef..08047a73 100644 --- a/examples/analytical_apps/cdlp/cdlp_context.h +++ b/examples/analytical_apps/cdlp/cdlp_context.h @@ -27,7 +27,7 @@ namespace grape { * @tparam FRAG_T */ template -class CDLPContext : public ContextBase { +class CDLPContext : public VertexDataContext { public: using oid_t = typename FRAG_T::oid_t; using vid_t = typename FRAG_T::vid_t; @@ -37,13 +37,16 @@ class CDLPContext : public ContextBase { #else using label_t = oid_t; #endif - void Init(const FRAG_T& frag, ParallelMessageManager& messages, + explicit CDLPContext(const FRAG_T& fragment) + : VertexDataContext(fragment, true), + labels(this->data()) {} + + void Init(ParallelMessageManager& messages, int max_round) { - this->max_round = max_round; + auto &frag = this->fragment(); auto inner_vertices = frag.InnerVertices(); - auto vertices = frag.Vertices(); - labels.Init(vertices); + this->max_round = max_round; changed.Init(inner_vertices); #ifdef PROFILING @@ -54,14 +57,16 @@ class CDLPContext : public ContextBase { step = 0; } - void Output(const FRAG_T& frag, std::ostream& os) { + void Output(std::ostream& os) override { + auto &frag = this->fragment(); auto inner_vertices = frag.InnerVertices(); + for (auto v : inner_vertices) { os << frag.GetId(v) << " " << labels[v] << std::endl; } } - typename FRAG_T::template vertex_array_t labels; + typename FRAG_T::template vertex_array_t& labels; typename FRAG_T::template vertex_array_t changed; #ifdef PROFILING diff --git a/examples/analytical_apps/lcc/lcc.h b/examples/analytical_apps/lcc/lcc.h index 3edd28a2..a508b3f4 100644 --- a/examples/analytical_apps/lcc/lcc.h +++ b/examples/analytical_apps/lcc/lcc.h @@ -206,8 +206,20 @@ class LCC : public ParallelAppBase>, ctx.preprocess_time += GetCurrentTime(); #endif } else { - messages.ParallelProcess( - thread_num(), frag, [](int tid, vertex_t u, int) {}); + auto& global_degree = ctx.global_degree; + auto& tricnt = ctx.tricnt; + auto& ctx_data = ctx.data(); + + for (auto v : inner_vertices) { + if (global_degree[v] == 0 || global_degree[v] == 1) { + ctx_data[v] = 0; + } else { + double re = 2.0 * (tricnt[v]) / + (static_cast(global_degree[v]) * + (static_cast(global_degree[v]) - 1)); + ctx_data[v] = re; + } + } } } }; diff --git a/examples/analytical_apps/lcc/lcc_auto.h b/examples/analytical_apps/lcc/lcc_auto.h index b4ba1688..080f2e1d 100644 --- a/examples/analytical_apps/lcc/lcc_auto.h +++ b/examples/analytical_apps/lcc/lcc_auto.h @@ -125,6 +125,20 @@ class LCCAuto : public AutoAppBase> { } } else if (ctx.stage == 2) { ctx.stage = 3; + auto& global_degree = ctx.global_degree; + auto& tricnt = ctx.tricnt; + auto& ctx_data = ctx.data(); + + for (auto v : inner_vertices) { + if (global_degree[v] == 0 || global_degree[v] == 1) { + ctx_data[v] = 0; + } else { + double re = 2.0 * (tricnt[v]) / + (static_cast(global_degree[v]) * + (static_cast(global_degree[v]) - 1)); + ctx_data[v] = re; + } + } } } }; diff --git a/examples/analytical_apps/lcc/lcc_auto_context.h b/examples/analytical_apps/lcc/lcc_auto_context.h index 1985bdd9..4a21b566 100644 --- a/examples/analytical_apps/lcc/lcc_auto_context.h +++ b/examples/analytical_apps/lcc/lcc_auto_context.h @@ -28,13 +28,18 @@ namespace grape { * @tparam FRAG_T */ template -class LCCAutoContext : public ContextBase { +class LCCAutoContext : public VertexDataContext { public: using oid_t = typename FRAG_T::oid_t; using vid_t = typename FRAG_T::vid_t; - void Init(const FRAG_T& frag, AutoParallelMessageManager& messages) { + explicit LCCAutoContext(const FRAG_T& fragment) + : VertexDataContext(fragment) {} + + void Init(AutoParallelMessageManager& messages) { + auto &frag = this->fragment(); auto vertices = frag.Vertices(); + global_degree.Init(vertices, 0, [](int* lhs, int rhs) { *lhs = rhs; return true; @@ -59,7 +64,8 @@ class LCCAutoContext : public ContextBase { MessageStrategy::kSyncOnOuterVertex); } - void Output(const FRAG_T& frag, std::ostream& os) { + void Output(std::ostream& os) override { + auto& frag = this->fragment(); auto inner_vertices = frag.InnerVertices(); for (auto v : inner_vertices) { if (global_degree[v] == 0 || global_degree[v] == 1) { @@ -67,8 +73,8 @@ class LCCAutoContext : public ContextBase { << 0.0 << std::endl; } else { double re = 2.0 * (tricnt[v]) / - (static_cast(global_degree[v]) * - (static_cast(global_degree[v]) - 1)); + (static_cast(global_degree[v]) * + (static_cast(global_degree[v]) - 1)); os << frag.GetId(v) << " " << std::scientific << std::setprecision(15) << re << std::endl; } diff --git a/examples/analytical_apps/lcc/lcc_context.h b/examples/analytical_apps/lcc/lcc_context.h index 2720b849..5019b70c 100644 --- a/examples/analytical_apps/lcc/lcc_context.h +++ b/examples/analytical_apps/lcc/lcc_context.h @@ -29,20 +29,26 @@ namespace grape { * @tparam FRAG_T */ template -class LCCContext : public ContextBase { +class LCCContext : public VertexDataContext { public: using oid_t = typename FRAG_T::oid_t; using vid_t = typename FRAG_T::vid_t; using vertex_t = typename FRAG_T::vertex_t; - void Init(const FRAG_T& frag, ParallelMessageManager& messages) { + explicit LCCContext(const FRAG_T& fragment) + : VertexDataContext(fragment) {} + + void Init(ParallelMessageManager& messages) { + auto &frag = this->fragment(); auto vertices = frag.Vertices(); + global_degree.Init(vertices); complete_neighbor.Init(vertices); tricnt.Init(vertices, 0); } - void Output(const FRAG_T& frag, std::ostream& os) { + void Output(std::ostream& os) override { + auto& frag = this->fragment(); auto inner_vertices = frag.InnerVertices(); for (auto v : inner_vertices) { if (global_degree[v] == 0 || global_degree[v] == 1) { @@ -50,8 +56,8 @@ class LCCContext : public ContextBase { << 0.0 << std::endl; } else { double re = 2.0 * (tricnt[v]) / - (static_cast(global_degree[v]) * - (static_cast(global_degree[v]) - 1)); + (static_cast(global_degree[v]) * + (static_cast(global_degree[v]) - 1)); os << frag.GetId(v) << " " << std::scientific << std::setprecision(15) << re << std::endl; } diff --git a/examples/analytical_apps/pagerank/pagerank.h b/examples/analytical_apps/pagerank/pagerank.h index fcca2adc..015d65fa 100644 --- a/examples/analytical_apps/pagerank/pagerank.h +++ b/examples/analytical_apps/pagerank/pagerank.h @@ -192,12 +192,22 @@ class PageRank : public BatchShuffleAppBase>, #ifdef PROFILING ctx.postprocess_time -= GetCurrentTime(); #endif + ctx.result.Swap(ctx.next_result); + if (ctx.step != ctx.max_round) { - messages.SyncInnerVertices(frag, ctx.next_result, + messages.SyncInnerVertices(frag, ctx.result, thread_num()); - } + } else { + auto& degree = ctx.degree; + auto& result = ctx.result; - ctx.result.Swap(ctx.next_result); + for (auto v : inner_vertices) { + if (degree[v] != 0) { + result[v] *= degree[v]; + } + } + return; + } #ifdef PROFILING ctx.postprocess_time += GetCurrentTime(); #endif diff --git a/examples/analytical_apps/pagerank/pagerank_auto.h b/examples/analytical_apps/pagerank/pagerank_auto.h index 23639ce1..56e90246 100644 --- a/examples/analytical_apps/pagerank/pagerank_auto.h +++ b/examples/analytical_apps/pagerank/pagerank_auto.h @@ -83,6 +83,14 @@ class PageRankAuto : public AutoAppBase>, ++ctx.step; if (ctx.step > ctx.max_round) { + auto& degree = ctx.degree; + auto& results = ctx.results; + + for (auto v : inner_vertices) { + if (degree[v] != 0) { + results[v] *= degree[v]; + } + } return; } diff --git a/examples/analytical_apps/pagerank/pagerank_auto_context.h b/examples/analytical_apps/pagerank/pagerank_auto_context.h index 07b4081c..4a9196ab 100644 --- a/examples/analytical_apps/pagerank/pagerank_auto_context.h +++ b/examples/analytical_apps/pagerank/pagerank_auto_context.h @@ -27,17 +27,23 @@ namespace grape { * @tparam FRAG_T */ template -class PageRankAutoContext : public ContextBase { +class PageRankAutoContext : public VertexDataContext { public: using oid_t = typename FRAG_T::oid_t; using vid_t = typename FRAG_T::vid_t; - void Init(const FRAG_T& frag, AutoParallelMessageManager& messages, + explicit PageRankAutoContext(const FRAG_T& fragment) + : VertexDataContext(fragment, true), + results(this->data()) {} + + void Init(AutoParallelMessageManager& messages, double delta, int max_round) { - this->delta = delta; - this->max_round = max_round; + auto &frag = this->fragment(); auto inner_vertices = frag.InnerVertices(); auto vertices = frag.Vertices(); + + this->delta = delta; + this->max_round = max_round; degree.Init(inner_vertices, 0); results.Init(vertices, 0.0, [](double* lhs, double rhs) { *lhs = rhs; @@ -49,16 +55,12 @@ class PageRankAutoContext : public ContextBase { step = 0; } - void Output(const FRAG_T& frag, std::ostream& os) { + void Output(std::ostream& os) override { + auto& frag = this->fragment(); auto inner_vertices = frag.InnerVertices(); for (auto v : inner_vertices) { - if (degree[v] == 0) { - os << frag.GetId(v) << " " << std::scientific << std::setprecision(15) - << results[v] << std::endl; - } else { - os << frag.GetId(v) << " " << std::scientific << std::setprecision(15) - << results[v] * degree[v] << std::endl; - } + os << frag.GetId(v) << " " << std::scientific << std::setprecision(15) + << results[v] << std::endl; } } diff --git a/examples/analytical_apps/pagerank/pagerank_context.h b/examples/analytical_apps/pagerank/pagerank_context.h index 47257f23..418a7ca0 100644 --- a/examples/analytical_apps/pagerank/pagerank_context.h +++ b/examples/analytical_apps/pagerank/pagerank_context.h @@ -27,19 +27,25 @@ namespace grape { * @tparam FRAG_T */ template -class PageRankContext : public ContextBase { +class PageRankContext : public VertexDataContext { using oid_t = typename FRAG_T::oid_t; using vid_t = typename FRAG_T::vid_t; public: - void Init(const FRAG_T& frag, BatchShuffleMessageManager& messages, + explicit PageRankContext(const FRAG_T& fragment) + : VertexDataContext(fragment, true), + result(this->data()) {} + + void Init(BatchShuffleMessageManager& messages, double delta, int max_round) { + auto &frag = this->fragment(); auto inner_vertices = frag.InnerVertices(); auto vertices = frag.Vertices(); + this->delta = delta; this->max_round = max_round; degree.Init(inner_vertices, 0); - result.Init(vertices, 0.0); + result.SetValue(0.0); next_result.Init(vertices); step = 0; @@ -52,16 +58,12 @@ class PageRankContext : public ContextBase { #endif } - void Output(const FRAG_T& frag, std::ostream& os) { + void Output(std::ostream& os) override { + auto& frag = this->fragment(); auto inner_vertices = frag.InnerVertices(); for (auto v : inner_vertices) { - if (degree[v] == 0) { - os << frag.GetId(v) << " " << std::scientific << std::setprecision(15) - << result[v] << std::endl; - } else { - os << frag.GetId(v) << " " << std::scientific << std::setprecision(15) - << result[v] * degree[v] << std::endl; - } + os << frag.GetId(v) << " " << std::scientific << std::setprecision(15) + << result[v] << std::endl; } #ifdef PROFILING VLOG(2) << "preprocess_time: " << preprocess_time << "s."; @@ -71,7 +73,7 @@ class PageRankContext : public ContextBase { } typename FRAG_T::template vertex_array_t degree; - typename FRAG_T::template vertex_array_t result; + typename FRAG_T::template vertex_array_t& result; typename FRAG_T::template vertex_array_t next_result; #ifdef PROFILING diff --git a/examples/analytical_apps/pagerank/pagerank_local_context.h b/examples/analytical_apps/pagerank/pagerank_local_context.h index 352653a7..d337e64f 100644 --- a/examples/analytical_apps/pagerank/pagerank_local_context.h +++ b/examples/analytical_apps/pagerank/pagerank_local_context.h @@ -27,18 +27,23 @@ namespace grape { * @tparam FRAG_T */ template -class PageRankLocalContext : public ContextBase { +class PageRankLocalContext : public VertexDataContext { using oid_t = typename FRAG_T::oid_t; using vid_t = typename FRAG_T::vid_t; public: - void Init(const FRAG_T& frag, BatchShuffleMessageManager& messages, + explicit PageRankLocalContext(const FRAG_T& fragment) + : VertexDataContext(fragment, true), + result(this->data()) {} + + void Init(BatchShuffleMessageManager& messages, double delta, int max_round) { - auto inner_vertices = frag.InnerVertices(); + auto &frag = this->fragment(); auto vertices = frag.Vertices(); + this->delta = delta; this->max_round = max_round; - result.Init(vertices); + result.SetValue(0.0); next_result.Init(vertices); avg_degree = static_cast(frag.GetEdgeNum()) / @@ -46,7 +51,8 @@ class PageRankLocalContext : public ContextBase { step = 0; } - void Output(const FRAG_T& frag, std::ostream& os) { + void Output(std::ostream& os) override { + auto& frag = this->fragment(); auto inner_vertices = frag.InnerVertices(); for (auto v : inner_vertices) { os << frag.GetId(v) << " " << std::scientific << std::setprecision(15) @@ -59,7 +65,7 @@ class PageRankLocalContext : public ContextBase { #endif } - typename FRAG_T::template vertex_array_t result; + typename FRAG_T::template vertex_array_t& result; typename FRAG_T::template vertex_array_t next_result; #ifdef PROFILING diff --git a/examples/analytical_apps/pagerank/pagerank_local_parallel.h b/examples/analytical_apps/pagerank/pagerank_local_parallel.h index 432b62ab..80600b7a 100644 --- a/examples/analytical_apps/pagerank/pagerank_local_parallel.h +++ b/examples/analytical_apps/pagerank/pagerank_local_parallel.h @@ -125,6 +125,14 @@ class PageRankLocalParallel ctx.next_result[u] = 1 - ctx.delta + ctx.delta * cur; }); } else { + auto& degree = ctx.degree; + auto& result = ctx.result; + + for (auto v : inner_vertices) { + if (degree[v] != 0) { + result[v] *= degree[v]; + } + } return; } diff --git a/examples/analytical_apps/pagerank/pagerank_local_parallel_context.h b/examples/analytical_apps/pagerank/pagerank_local_parallel_context.h index 55058bed..4f4fb9a4 100644 --- a/examples/analytical_apps/pagerank/pagerank_local_parallel_context.h +++ b/examples/analytical_apps/pagerank/pagerank_local_parallel_context.h @@ -27,19 +27,25 @@ namespace grape { * @tparam FRAG_T */ template -class PageRankLocalParallelContext : public ContextBase { +class PageRankLocalParallelContext : public VertexDataContext { using oid_t = typename FRAG_T::oid_t; using vid_t = typename FRAG_T::vid_t; public: - void Init(const FRAG_T& frag, ParallelMessageManager& messages, double delta, + explicit PageRankLocalParallelContext(const FRAG_T& fragment) + : VertexDataContext(fragment, true), + result(this->data()) {} + + void Init(ParallelMessageManager& messages, double delta, int max_round) { + auto &frag = this->fragment(); auto inner_vertices = frag.InnerVertices(); auto vertices = frag.Vertices(); + this->delta = delta; this->max_round = max_round; degree.Init(inner_vertices, 0); - result.Init(vertices, 0.0); + result.SetValue(0.0); next_result.Init(vertices); step = 0; #ifdef PROFILING @@ -49,16 +55,12 @@ class PageRankLocalParallelContext : public ContextBase { #endif } - void Output(const FRAG_T& frag, std::ostream& os) { + void Output(std::ostream& os) override { + auto& frag = this->fragment(); auto inner_vertices = frag.InnerVertices(); for (auto v : inner_vertices) { - if (degree[v] == 0) { os << frag.GetId(v) << " " << std::scientific << std::setprecision(15) << result[v] << std::endl; - } else { - os << frag.GetId(v) << " " << std::scientific << std::setprecision(15) - << result[v] * degree[v] << std::endl; - } } #ifdef PROFILING VLOG(2) << "preprocess_time: " << preprocess_time << "s."; @@ -68,7 +70,7 @@ class PageRankLocalParallelContext : public ContextBase { } typename FRAG_T::template vertex_array_t degree; - typename FRAG_T::template vertex_array_t result; + typename FRAG_T::template vertex_array_t& result; typename FRAG_T::template vertex_array_t next_result; #ifdef PROFILING diff --git a/examples/analytical_apps/pagerank/pagerank_parallel.h b/examples/analytical_apps/pagerank/pagerank_parallel.h index 4a1e1cf0..7ce2c8f6 100644 --- a/examples/analytical_apps/pagerank/pagerank_parallel.h +++ b/examples/analytical_apps/pagerank/pagerank_parallel.h @@ -107,6 +107,14 @@ class PageRankParallel ++ctx.step; if (ctx.step > ctx.max_round) { + auto& degree = ctx.degree; + auto& result = ctx.result; + + for (auto v : inner_vertices) { + if (degree[v] != 0) { + result[v] *= degree[v]; + } + } return; } diff --git a/examples/analytical_apps/pagerank/pagerank_parallel_context.h b/examples/analytical_apps/pagerank/pagerank_parallel_context.h index 7b20d250..a1f8ac3d 100644 --- a/examples/analytical_apps/pagerank/pagerank_parallel_context.h +++ b/examples/analytical_apps/pagerank/pagerank_parallel_context.h @@ -27,19 +27,25 @@ namespace grape { * @tparam FRAG_T */ template -class PageRankParallelContext : public ContextBase { +class PageRankParallelContext : public VertexDataContext { using oid_t = typename FRAG_T::oid_t; using vid_t = typename FRAG_T::vid_t; public: - void Init(const FRAG_T& frag, ParallelMessageManager& messages, double delta, + explicit PageRankParallelContext(const FRAG_T& fragment) + : VertexDataContext(fragment, true), + result(this->data()) {} + + void Init(ParallelMessageManager& messages, double delta, int max_round) { + auto &frag = this->fragment(); auto inner_vertices = frag.InnerVertices(); auto vertices = frag.Vertices(); + this->delta = delta; this->max_round = max_round; degree.Init(inner_vertices, 0); - result.Init(vertices, 0.0); + result.SetValue(0.0); next_result.Init(vertices); step = 0; #ifdef PROFILING @@ -49,16 +55,12 @@ class PageRankParallelContext : public ContextBase { #endif } - void Output(const FRAG_T& frag, std::ostream& os) { + void Output(std::ostream& os) override { + auto& frag = this->fragment(); auto inner_vertices = frag.InnerVertices(); for (auto v : inner_vertices) { - if (degree[v] == 0) { - os << frag.GetId(v) << " " << std::scientific << std::setprecision(15) - << result[v] << std::endl; - } else { - os << frag.GetId(v) << " " << std::scientific << std::setprecision(15) - << result[v] * degree[v] << std::endl; - } + os << frag.GetId(v) << " " << std::scientific << std::setprecision(15) + << result[v] << std::endl; } #ifdef PROFILING VLOG(2) << "preprocess_time: " << preprocess_time << "s."; @@ -68,7 +70,7 @@ class PageRankParallelContext : public ContextBase { } typename FRAG_T::template vertex_array_t degree; - typename FRAG_T::template vertex_array_t result; + typename FRAG_T::template vertex_array_t& result; typename FRAG_T::template vertex_array_t next_result; #ifdef PROFILING diff --git a/examples/analytical_apps/sssp/sssp_auto_context.h b/examples/analytical_apps/sssp/sssp_auto_context.h index 8d18971b..5a747a6e 100644 --- a/examples/analytical_apps/sssp/sssp_auto_context.h +++ b/examples/analytical_apps/sssp/sssp_auto_context.h @@ -29,15 +29,20 @@ namespace grape { * @tparam FRAG_T */ template -class SSSPAutoContext : public ContextBase { +class SSSPAutoContext : public VertexDataContext { public: using oid_t = typename FRAG_T::oid_t; using vid_t = typename FRAG_T::vid_t; - void Init(const FRAG_T& frag, AutoParallelMessageManager& messages, - oid_t source_id) { - this->source_id = source_id; + explicit SSSPAutoContext(const FRAG_T& fragment) + : VertexDataContext(fragment), + partial_result(this->data()) {} + + void Init(AutoParallelMessageManager& messages, oid_t source_id) { + auto &frag = this->fragment(); auto vertices = frag.Vertices(); + + this->source_id = source_id; partial_result.Init(vertices, std::numeric_limits::max(), [](double* lhs, double rhs) { if (*lhs > rhs) { @@ -51,10 +56,11 @@ class SSSPAutoContext : public ContextBase { MessageStrategy::kSyncOnOuterVertex); } - void Output(const FRAG_T& frag, std::ostream& os) { + void Output(std::ostream& os) override { // If the distance is the max value for vertex_data_type // then the vertex is not connected to the source vertex. // According to specs, the output should be +inf + auto &frag = this->fragment(); auto inner_vertices = frag.InnerVertices(); for (auto v : inner_vertices) { double d = partial_result[v]; diff --git a/examples/analytical_apps/sssp/sssp_context.h b/examples/analytical_apps/sssp/sssp_context.h index db165da5..3cd0bd4d 100644 --- a/examples/analytical_apps/sssp/sssp_context.h +++ b/examples/analytical_apps/sssp/sssp_context.h @@ -30,17 +30,21 @@ namespace grape { * @tparam FRAG_T */ template -class SSSPContext : public ContextBase { +class SSSPContext : public VertexDataContext { public: using oid_t = typename FRAG_T::oid_t; using vid_t = typename FRAG_T::vid_t; - void Init(const FRAG_T& frag, ParallelMessageManager& messages, + explicit SSSPContext(const FRAG_T& fragment) + : VertexDataContext(fragment, true), + partial_result(this->data()) {} + + void Init(ParallelMessageManager& messages, oid_t source_id) { - this->source_id = source_id; - auto vertices = frag.Vertices(); - partial_result.Init(vertices, std::numeric_limits::max()); + auto &frag = this->fragment(); + this->source_id = source_id; + partial_result.SetValue(std::numeric_limits::max()); curr_modified.init(frag.GetVerticesNum()); next_modified.init(frag.GetVerticesNum()); @@ -51,10 +55,11 @@ class SSSPContext : public ContextBase { #endif } - void Output(const FRAG_T& frag, std::ostream& os) { + void Output(std::ostream& os) override { // If the distance is the max value for vertex_data_type // then the vertex is not connected to the source vertex. // According to specs, the output should be +inf + auto &frag = this->fragment(); auto inner_vertices = frag.InnerVertices(); for (auto v : inner_vertices) { double d = partial_result[v]; @@ -73,7 +78,7 @@ class SSSPContext : public ContextBase { } oid_t source_id; - typename FRAG_T::template vertex_array_t partial_result; + typename FRAG_T::template vertex_array_t& partial_result; Bitset curr_modified, next_modified; diff --git a/examples/analytical_apps/wcc/wcc_auto.h b/examples/analytical_apps/wcc/wcc_auto.h index 4da3c969..7b7e6f04 100644 --- a/examples/analytical_apps/wcc/wcc_auto.h +++ b/examples/analytical_apps/wcc/wcc_auto.h @@ -42,7 +42,6 @@ class WCCAuto : public AutoAppBase> { INSTALL_AUTO_WORKER(WCCAuto, WCCAutoContext, FRAG_T) using vertex_t = typename fragment_t::vertex_t; using vid_t = typename fragment_t::vid_t; - using cid_t = typename WCCAutoContext::cid_t; void PEval(const fragment_t& frag, context_t& ctx) { auto inner_vertices = frag.InnerVertices(); @@ -156,7 +155,7 @@ class WCCAuto : public AutoAppBase> { for (auto& v : inner_vertices) { if (ctx.global_cluster_id.IsUpdated(v)) { - cid_t tag = ctx.global_cluster_id.GetValue(v); + auto tag = ctx.global_cluster_id.GetValue(v); vid_t comp_id = ctx.local_comp_id[v]; if (ctx.global_comp_id[comp_id] > tag) { ctx.global_comp_id[comp_id] = tag; diff --git a/examples/analytical_apps/wcc/wcc_auto_context.h b/examples/analytical_apps/wcc/wcc_auto_context.h index 65c0b89c..30ee1ab1 100644 --- a/examples/analytical_apps/wcc/wcc_auto_context.h +++ b/examples/analytical_apps/wcc/wcc_auto_context.h @@ -23,25 +23,33 @@ limitations under the License. namespace grape { +#ifdef WCC_USE_GID +template +using WCCAutoContextType = VertexDataContext; +#else +template +using WCCAutoContextType = VertexDataContext; +#endif + /** * @brief Context for the auto-parallel version of WCCAuto. * * @tparam FRAG_T */ template -class WCCAutoContext : public ContextBase { +class WCCAutoContext : public WCCAutoContextType { using oid_t = typename FRAG_T::oid_t; using vid_t = typename FRAG_T::vid_t; using vertex_t = typename FRAG_T::vertex_t; + using cid_t = typename WCCAutoContextType::data_t; public: -#ifdef WCC_USE_GID - using cid_t = vid_t; -#else - using cid_t = oid_t; -#endif + explicit WCCAutoContext(const FRAG_T& fragment) + : WCCAutoContextType(fragment, true), + global_cluster_id(this->data()) {} - void Init(const FRAG_T& frag, AutoParallelMessageManager& messages) { + void Init(AutoParallelMessageManager& messages) { + auto& frag = this->fragment(); auto vertices = frag.Vertices(); auto inner_vertices = frag.InnerVertices(); @@ -59,7 +67,8 @@ class WCCAutoContext : public ContextBase { MessageStrategy::kSyncOnOuterVertex); } - void Output(const FRAG_T& frag, std::ostream& os) { + void Output(std::ostream& os) override { + auto& frag = this->fragment(); auto inner_vertices = frag.InnerVertices(); for (auto v : inner_vertices) { os << frag.GetId(v) << " " << global_cluster_id.GetValue(v) << std::endl; diff --git a/examples/analytical_apps/wcc/wcc_context.h b/examples/analytical_apps/wcc/wcc_context.h index 363c2f2d..bd928913 100644 --- a/examples/analytical_apps/wcc/wcc_context.h +++ b/examples/analytical_apps/wcc/wcc_context.h @@ -25,21 +25,24 @@ namespace grape { * @tparam FRAG_T */ template -class WCCContext : public ContextBase { +class WCCContext : public VertexDataContext { public: using oid_t = typename FRAG_T::oid_t; using vid_t = typename FRAG_T::vid_t; - void Init(const FRAG_T& frag, ParallelMessageManager& messages) { - auto vertices = frag.Vertices(); + explicit WCCContext(const FRAG_T& fragment) + : VertexDataContext(fragment, true), + comp_id(this->data()) {} - comp_id.Init(vertices); + void Init(ParallelMessageManager& messages) { + auto &frag = this->fragment(); curr_modified.init(frag.GetVerticesNum()); next_modified.init(frag.GetVerticesNum()); } - void Output(const FRAG_T& frag, std::ostream& os) { + void Output(std::ostream& os) override { + auto& frag = this->fragment(); auto inner_vertices = frag.InnerVertices(); for (auto v : inner_vertices) { os << frag.GetId(v) << " " << comp_id[v] << std::endl; @@ -51,7 +54,7 @@ class WCCContext : public ContextBase { #endif } - typename FRAG_T::template vertex_array_t comp_id; + typename FRAG_T::template vertex_array_t& comp_id; Bitset curr_modified, next_modified; diff --git a/examples/gnn_sampler/append_only_edgecut_fragment.h b/examples/gnn_sampler/append_only_edgecut_fragment.h index 7b5b4360..83e068cf 100644 --- a/examples/gnn_sampler/append_only_edgecut_fragment.h +++ b/examples/gnn_sampler/append_only_edgecut_fragment.h @@ -285,6 +285,8 @@ class AppendOnlyEdgecutFragment using oid_t = OID_T; using vdata_t = VDATA_T; using edata_t = EDATA_T; + template + using vertex_array_t = VertexArray; using nbr_space_iter_impl = NbrSpaceIterImpl; using nbr_mapspace_iter_impl = NbrMapSpaceIterImpl; diff --git a/examples/gnn_sampler/sampler.h b/examples/gnn_sampler/sampler.h index d3d7569b..0b57abf5 100644 --- a/examples/gnn_sampler/sampler.h +++ b/examples/gnn_sampler/sampler.h @@ -139,6 +139,19 @@ class Sampler : public ParallelAppBase>, if (cur_hop <= ctx.nums_of_hop.size()) { ctx.random_cache = for_caches; messages.ForceContinue(); + } else { + auto& random_result = ctx.random_result; + auto& ctx_data = ctx.data(); + vertex_t v; + + for (auto& it : random_result) { + CHECK(frag.Gid2Vertex(it.first, v)); + auto& oids = ctx_data[v]; + + for (auto gid : it.second) { + oids.push_back(frag.Gid2Oid(gid)); + } + } } #ifdef PROFILING ctx.time_inceval_gen_send_msg += GetCurrentTime(); diff --git a/examples/gnn_sampler/sampler_context.h b/examples/gnn_sampler/sampler_context.h index 2c2c6909..86698178 100644 --- a/examples/gnn_sampler/sampler_context.h +++ b/examples/gnn_sampler/sampler_context.h @@ -29,14 +29,21 @@ limitations under the License. namespace grape { template -class SamplerContext : public grape::ContextBase { +class SamplerContext + : public VertexDataContext> { using oid_t = typename FRAG_T::oid_t; using vid_t = typename FRAG_T::vid_t; + using vertex_t = typename FRAG_T::vertex_t; public: - void Init(const FRAG_T& frag, ParallelMessageManager& messages, - const std::string& strategy, const std::string& sampler_hop_and_num, + explicit SamplerContext(const FRAG_T& fragment) + : VertexDataContext>( + fragment) {} + + void Init(ParallelMessageManager& messages, const std::string& strategy, + const std::string& sampler_hop_and_num, const std::vector& queries) { + auto& frag = this->fragment(); #ifdef PROFILING time_init -= GetCurrentTime(); #endif @@ -74,8 +81,11 @@ class SamplerContext : public grape::ContextBase { #endif } - void Output(const FRAG_T& frag, std::ostream& os) { + void Output(std::ostream& os) override { + auto& frag = this->fragment(); auto t_begin = grape::GetCurrentTime(); + vertex_t v; + for (auto& it : random_result) { std::stringstream ss; ss << frag.Gid2Oid(it.first); diff --git a/grape/app/context_base.h b/grape/app/context_base.h index ff210e6c..a4359c82 100644 --- a/grape/app/context_base.h +++ b/grape/app/context_base.h @@ -16,8 +16,10 @@ limitations under the License. #ifndef GRAPE_APP_CONTEXT_BASE_H_ #define GRAPE_APP_CONTEXT_BASE_H_ -#include #include +#include + +#include "grape/types.h" namespace grape { @@ -27,11 +29,12 @@ namespace grape { * during supersteps. * */ -template class ContextBase { public: - ContextBase() {} - virtual ~ContextBase() {} + ContextBase() = default; + virtual ~ContextBase() = default; + + virtual const char* context_type() const = 0; /** * @brief Output function to implement for result output. @@ -42,7 +45,7 @@ class ContextBase { * @param frag * @param os */ - virtual void Output(const FRAG_T& frag, std::ostream& os) {} + virtual void Output(std::ostream& os) {} }; } // namespace grape diff --git a/grape/app/vertex_data_context.h b/grape/app/vertex_data_context.h new file mode 100644 index 00000000..4a818886 --- /dev/null +++ b/grape/app/vertex_data_context.h @@ -0,0 +1,58 @@ +/** Copyright 2020 Alibaba Group Holding Limited. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef GRAPE_APP_VERTEX_DATA_CONTEXT_H_ +#define GRAPE_APP_VERTEX_DATA_CONTEXT_H_ + +#include "grape/app/context_base.h" +#include "grape/utils/vertex_array.h" + +#define CONTEXT_TYPE_VERTEX_DATA "vertex_data" + +namespace grape { + +template +class VertexDataContext : public ContextBase { + using fragment_t = FRAG_T; + using vertex_t = typename fragment_t::vertex_t; + using vertex_array_t = typename fragment_t::template vertex_array_t; + + public: + using data_t = DATA_T; + + explicit VertexDataContext(const fragment_t& fragment, + bool including_outer = false) + : fragment_(fragment) { + if (including_outer) { + data_.Init(fragment.Vertices()); + } else { + data_.Init(fragment.InnerVertices()); + } + } + + const fragment_t& fragment() { return fragment_; } + + const char* context_type() const override { return CONTEXT_TYPE_VERTEX_DATA; } + + inline vertex_array_t& data() { return data_; } + + private: + const fragment_t& fragment_; + vertex_array_t data_; +}; + +} // namespace grape + +#endif // GRAPE_APP_VERTEX_DATA_CONTEXT_H_ diff --git a/grape/grape.h b/grape/grape.h index da9fb6df..4ef8ec0a 100644 --- a/grape/grape.h +++ b/grape/grape.h @@ -20,6 +20,7 @@ limitations under the License. #include "grape/app/batch_shuffle_app_base.h" #include "grape/app/context_base.h" #include "grape/app/parallel_app_base.h" +#include "grape/app/vertex_data_context.h" #include "grape/parallel/auto_parallel_message_manager.h" #include "grape/parallel/batch_shuffle_message_manager.h" #include "grape/parallel/default_message_manager.h" diff --git a/grape/parallel/sync_buffer.h b/grape/parallel/sync_buffer.h index 5a141a08..af847f7b 100644 --- a/grape/parallel/sync_buffer.h +++ b/grape/parallel/sync_buffer.h @@ -29,7 +29,7 @@ namespace grape { */ class ISyncBuffer { public: - virtual ~ISyncBuffer() {} + virtual ~ISyncBuffer() = default; virtual void* data() = 0; @@ -53,9 +53,10 @@ class ISyncBuffer { template class SyncBuffer : public ISyncBuffer { public: - SyncBuffer() {} - explicit SyncBuffer(VertexRange range) - : data_(range), updated_(range, false), range_(range) {} + SyncBuffer() : data_(internal_data_) {} + + explicit SyncBuffer(VertexArray& data) + : data_(data) {} bool updated(size_t begin, size_t length) const override { auto iter = updated_.begin() + begin; @@ -67,6 +68,7 @@ class SyncBuffer : public ISyncBuffer { } return false; } + void* data() override { return reinterpret_cast(&data_[range_.begin()]); } @@ -118,7 +120,8 @@ class SyncBuffer : public ISyncBuffer { } private: - VertexArray data_; + VertexArray internal_data_; + VertexArray& data_; VertexArray updated_; VertexRange range_; diff --git a/grape/utils/default_allocator.h b/grape/utils/default_allocator.h index e79b2d92..b1e4033d 100644 --- a/grape/utils/default_allocator.h +++ b/grape/utils/default_allocator.h @@ -17,6 +17,7 @@ limitations under the License. #define GRAPE_UTILS_DEFAULT_ALLOCATOR_H_ #include +#define ALLOC_ALIGNMENT 64 namespace grape { @@ -46,7 +47,10 @@ class DefaultAllocator { #ifdef __APPLE__ return static_cast(malloc(__n * sizeof(_Tp))); #else - return static_cast(aligned_alloc(64, __n * sizeof(_Tp))); + return static_cast(aligned_alloc( + ALLOC_ALIGNMENT, (__n * sizeof(_Tp) / ALLOC_ALIGNMENT + + (__n * sizeof(_Tp) % ALLOC_ALIGNMENT == 0 ? 0 : 1)) * + ALLOC_ALIGNMENT)); #endif } diff --git a/grape/utils/vertex_array.h b/grape/utils/vertex_array.h index a60220a8..82f2135c 100644 --- a/grape/utils/vertex_array.h +++ b/grape/utils/vertex_array.h @@ -169,7 +169,7 @@ class VertexArray : public Array> { fake_start_ = Base::data() - range_.begin().GetValue(); } - ~VertexArray() {} + ~VertexArray() = default; void Init(const VertexRange& range) { Base::clear(); diff --git a/grape/worker/auto_worker.h b/grape/worker/auto_worker.h index fa833f29..46ddb0f4 100644 --- a/grape/worker/auto_worker.h +++ b/grape/worker/auto_worker.h @@ -58,13 +58,15 @@ class AutoWorker { "The loaded graph is not valid for application"); AutoWorker(std::shared_ptr app, std::shared_ptr graph) - : app_(app), graph_(graph) {} - ~AutoWorker() {} + : app_(app), context_(std::make_shared(*graph)) {} + + ~AutoWorker() = default; void Init(const CommSpec& comm_spec, const ParallelEngineSpec& pe_spec = DefaultParallelEngineSpec()) { + auto& graph = const_cast(context_->fragment()); // prepare for the query - graph_->PrepareToRunApp(APP_T::message_strategy, APP_T::need_split_edges); + graph.PrepareToRunApp(APP_T::message_strategy, APP_T::need_split_edges); comm_spec_ = comm_spec; MPI_Barrier(comm_spec_.comm()); @@ -79,10 +81,11 @@ class AutoWorker { template void Query(Args&&... args) { + auto& graph = context_->fragment(); + MPI_Barrier(comm_spec_.comm()); - context_ = std::make_shared(); - context_->Init(*graph_, messages_, std::forward(args)...); + context_->Init(messages_, std::forward(args)...); int round = 0; @@ -90,7 +93,7 @@ class AutoWorker { messages_.StartARound(); - app_->PEval(*graph_, *context_); + app_->PEval(graph, *context_); messages_.FinishARound(); @@ -104,7 +107,7 @@ class AutoWorker { round++; messages_.StartARound(); - app_->IncEval(*graph_, *context_); + app_->IncEval(graph, *context_); messages_.FinishARound(); @@ -113,17 +116,21 @@ class AutoWorker { } ++step; } - MPI_Barrier(comm_spec_.comm()); messages_.Finalize(); } - void Output(std::ostream& os) { context_->Output(*graph_, os); } + std::shared_ptr GetContext() { + return context_; + } + + void Output(std::ostream& os) { + context_->Output(os); + } private: std::shared_ptr app_; - std::shared_ptr graph_; std::shared_ptr context_; message_manager_t messages_; diff --git a/grape/worker/batch_shuffle_worker.h b/grape/worker/batch_shuffle_worker.h index f4765a42..72966cb1 100644 --- a/grape/worker/batch_shuffle_worker.h +++ b/grape/worker/batch_shuffle_worker.h @@ -58,14 +58,15 @@ class BatchShuffleWorker { BatchShuffleWorker(std::shared_ptr app, std::shared_ptr graph) - : app_(app), graph_(graph) {} + : app_(app), context_(std::make_shared(*graph)) {} - virtual ~BatchShuffleWorker() {} + ~BatchShuffleWorker() = default; void Init(const CommSpec& comm_spec, const ParallelEngineSpec& pe_spec = DefaultParallelEngineSpec()) { + auto& graph = const_cast(context_->fragment()); // prepare for the query - graph_->PrepareToRunApp(APP_T::message_strategy, APP_T::need_split_edges); + graph.PrepareToRunApp(APP_T::message_strategy, APP_T::need_split_edges); comm_spec_ = comm_spec; MPI_Barrier(comm_spec_.comm()); @@ -80,10 +81,11 @@ class BatchShuffleWorker { template void Query(Args&&... args) { + auto& graph = context_->fragment(); + MPI_Barrier(comm_spec_.comm()); - context_ = std::make_shared(); - context_->Init(*graph_, messages_, std::forward(args)...); + context_->Init(messages_, std::forward(args)...); int round = 0; @@ -91,7 +93,7 @@ class BatchShuffleWorker { messages_.StartARound(); - app_->PEval(*graph_, *context_, messages_); + app_->PEval(graph, *context_, messages_); messages_.FinishARound(); @@ -105,7 +107,7 @@ class BatchShuffleWorker { round++; messages_.StartARound(); - app_->IncEval(*graph_, *context_, messages_); + app_->IncEval(graph, *context_, messages_); messages_.FinishARound(); @@ -120,11 +122,16 @@ class BatchShuffleWorker { messages_.Finalize(); } - void Output(std::ostream& os) { context_->Output(*graph_, os); } + std::shared_ptr GetContext() { + return context_; + } + + void Output(std::ostream& os) { + context_->Output(os); + } private: std::shared_ptr app_; - std::shared_ptr graph_; std::shared_ptr context_; message_manager_t messages_; diff --git a/grape/worker/parallel_worker.h b/grape/worker/parallel_worker.h index d88a73d3..190653fe 100644 --- a/grape/worker/parallel_worker.h +++ b/grape/worker/parallel_worker.h @@ -56,14 +56,15 @@ class ParallelWorker { "The loaded graph is not valid for application"); ParallelWorker(std::shared_ptr app, std::shared_ptr graph) - : app_(app), graph_(graph) {} + : app_(app), context_(std::make_shared(*graph)) {} - virtual ~ParallelWorker() {} + ~ParallelWorker() = default; void Init(const CommSpec& comm_spec, const ParallelEngineSpec& pe_spec = DefaultParallelEngineSpec()) { + auto& graph = const_cast(context_->fragment()); // prepare for the query - graph_->PrepareToRunApp(APP_T::message_strategy, APP_T::need_split_edges); + graph.PrepareToRunApp(APP_T::message_strategy, APP_T::need_split_edges); comm_spec_ = comm_spec; @@ -77,10 +78,11 @@ class ParallelWorker { template void Query(Args&&... args) { + auto& graph = context_->fragment(); + MPI_Barrier(comm_spec_.comm()); - context_ = std::make_shared(); - context_->Init(*graph_, messages_, std::forward(args)...); + context_->Init(messages_, std::forward(args)...); if (comm_spec_.worker_id() == kCoordinatorRank) { VLOG(1) << "[Coordinator]: Finished Init"; } @@ -91,7 +93,7 @@ class ParallelWorker { messages_.StartARound(); - app_->PEval(*graph_, *context_, messages_); + app_->PEval(graph, *context_, messages_); messages_.FinishARound(); @@ -105,7 +107,7 @@ class ParallelWorker { round++; messages_.StartARound(); - app_->IncEval(*graph_, *context_, messages_); + app_->IncEval(graph, *context_, messages_); messages_.FinishARound(); @@ -114,15 +116,21 @@ class ParallelWorker { } ++step; } + MPI_Barrier(comm_spec_.comm()); messages_.Finalize(); } - void Output(std::ostream& os) { context_->Output(*graph_, os); } + std::shared_ptr GetContext() { + return context_; + } + + void Output(std::ostream& os) { + context_->Output(os); + } private: std::shared_ptr app_; - std::shared_ptr graph_; std::shared_ptr context_; message_manager_t messages_;