Skip to content

Commit

Permalink
[GraphBolt] Reuse CachePolicy::Query partition result. (#7608)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin authored Jul 29, 2024
1 parent def2a1b commit f989f9d
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 63 deletions.
92 changes: 61 additions & 31 deletions graphbolt/src/partitioned_cache_policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,14 @@ PartitionedCachePolicy::PartitionedCachePolicy(
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
PartitionedCachePolicy::Partition(torch::Tensor keys) {
const int64_t num_parts = policies_.size();
torch::Tensor offsets = torch::zeros(
torch::Tensor offsets = torch::empty(
num_parts * num_parts + 1, keys.options().dtype(torch::kInt64));
auto offsets_ptr = offsets.data_ptr<int64_t>();
std::memset(offsets_ptr, 0, offsets.size(0) * offsets.element_size());
auto indices = torch::empty_like(keys, keys.options().dtype(torch::kInt64));
auto part_id = torch::empty_like(keys, keys.options().dtype(torch::kInt32));
const auto num_keys = keys.size(0);
auto part_id_ptr = part_id.data_ptr<int32_t>();
auto offsets_ptr = offsets.data_ptr<int64_t>();
AT_DISPATCH_INDEX_TYPES(
keys.scalar_type(), "PartitionedCachePolicy::partition", ([&] {
auto keys_ptr = keys.data_ptr<index_t>();
Expand Down Expand Up @@ -123,18 +124,26 @@ PartitionedCachePolicy::Partition(torch::Tensor keys) {
}

std::tuple<
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
torch::Tensor>
PartitionedCachePolicy::Query(torch::Tensor keys) {
if (policies_.size() == 1) {
std::lock_guard lock(mtx_);
auto [positions, output_indices, missing_keys, found_pointers] =
policies_[0]->Query(keys);
auto found_offsets = torch::empty(2, found_pointers.options());
auto found_offsets_ptr = found_offsets.data_ptr<int64_t>();
found_offsets_ptr[0] = 0;
found_offsets_ptr[1] = found_pointers.size(0);
return {
positions, output_indices, missing_keys, found_pointers, found_offsets};
auto found_and_missing_offsets = torch::empty(4, found_pointers.options());
auto found_and_missing_offsets_ptr =
found_and_missing_offsets.data_ptr<int64_t>();
// Found offsets part.
found_and_missing_offsets_ptr[0] = 0;
found_and_missing_offsets_ptr[1] = found_pointers.size(0);
// Missing offsets part.
found_and_missing_offsets_ptr[2] = 0;
found_and_missing_offsets_ptr[3] = missing_keys.size(0);
auto found_offsets = found_and_missing_offsets.slice(0, 0, 2);
auto missing_offsets = found_and_missing_offsets.slice(0, 2);
return {positions, output_indices, missing_keys,
found_pointers, found_offsets, missing_offsets};
};
torch::Tensor offsets, indices, permuted_keys;
std::tie(offsets, indices, permuted_keys) = Partition(keys);
Expand Down Expand Up @@ -176,7 +185,11 @@ PartitionedCachePolicy::Query(torch::Tensor keys) {
torch::Tensor found_pointers = torch::empty(
positions.size(0),
std::get<3>(results[0]).options().pinned_memory(utils::is_pinned(keys)));
auto missing_offsets =
torch::empty(policies_.size() + 1, result_offsets_tensor.options());
auto output_indices_ptr = output_indices.data_ptr<int64_t>();
auto missing_offsets_ptr = missing_offsets.data_ptr<int64_t>();
missing_offsets_ptr[0] = 0;
gb::parallel_for(0, policies_.size(), 1, [&](int64_t begin, int64_t end) {
if (begin == end) return;
const auto tid = begin;
Expand All @@ -200,6 +213,7 @@ PartitionedCachePolicy::Query(torch::Tensor keys) {
num_selected * found_pointers.element_size());
begin = result_offsets[policies_.size() + tid];
end = result_offsets[policies_.size() + tid + 1];
missing_offsets[tid + 1] = end - result_offsets[policies_.size()];
const auto num_missing = end - begin;
for (int64_t i = 0; i < num_missing; i++) {
output_indices_ptr[begin + i] =
Expand All @@ -211,35 +225,44 @@ PartitionedCachePolicy::Query(torch::Tensor keys) {
std::get<2>(results[tid]).data_ptr(),
num_missing * missing_keys.element_size());
});
auto found_offsets = result_offsets_tensor.slice(0, 0, policies_.size() + 1);
return std::make_tuple(
positions, output_indices, missing_keys, found_pointers,
result_offsets_tensor.slice(0, 0, policies_.size() + 1));
positions, output_indices, missing_keys, found_pointers, found_offsets,
missing_offsets);
}

c10::intrusive_ptr<Future<std::vector<torch::Tensor>>>
PartitionedCachePolicy::QueryAsync(torch::Tensor keys) {
return async([=] {
auto
[positions, output_indices, missing_keys, found_pointers,
found_offsets] = Query(keys);
return std::vector{
positions, output_indices, missing_keys, found_pointers, found_offsets};
[positions, output_indices, missing_keys, found_pointers, found_offsets,
missing_offsets] = Query(keys);
return std::vector{positions, output_indices, missing_keys,
found_pointers, found_offsets, missing_offsets};
});
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
PartitionedCachePolicy::Replace(torch::Tensor keys) {
PartitionedCachePolicy::Replace(
torch::Tensor keys, torch::optional<torch::Tensor> offsets) {
if (policies_.size() == 1) {
std::lock_guard lock(mtx_);
auto [positions, pointers] = policies_[0]->Replace(keys);
auto offsets = torch::empty(2, pointers.options());
auto offsets_ptr = offsets.data_ptr<int64_t>();
offsets_ptr[0] = 0;
offsets_ptr[1] = pointers.size(0);
return {positions, pointers, offsets};
if (!offsets.has_value()) {
offsets = torch::empty(2, pointers.options());
auto offsets_ptr = offsets->data_ptr<int64_t>();
offsets_ptr[0] = 0;
offsets_ptr[1] = pointers.size(0);
}
return {positions, pointers, *offsets};
}
const auto offsets_provided = offsets.has_value();
torch::Tensor indices, permuted_keys;
if (!offsets_provided) {
std::tie(offsets, indices, permuted_keys) = Partition(keys);
} else {
permuted_keys = keys;
}
torch::Tensor offsets, indices, permuted_keys;
std::tie(offsets, indices, permuted_keys) = Partition(keys);
auto output_positions = torch::empty_like(
keys, keys.options()
.dtype(torch::kInt64)
Expand All @@ -248,8 +271,8 @@ PartitionedCachePolicy::Replace(torch::Tensor keys) {
keys, keys.options()
.dtype(torch::kInt64)
.pinned_memory(utils::is_pinned(keys)));
auto offsets_ptr = offsets.data_ptr<int64_t>();
auto indices_ptr = indices.data_ptr<int64_t>();
auto offsets_ptr = offsets->data_ptr<int64_t>();
auto indices_ptr = offsets_provided ? nullptr : indices.data_ptr<int64_t>();
auto output_positions_ptr = output_positions.data_ptr<int64_t>();
auto output_pointers_ptr = output_pointers.data_ptr<int64_t>();
namespace gb = graphbolt;
Expand All @@ -269,22 +292,29 @@ PartitionedCachePolicy::Replace(torch::Tensor keys) {
}
auto positions_ptr = positions.data_ptr<int64_t>();
const auto off = tid * capacity_ / policies_.size();
for (int64_t i = 0; i < positions.size(0); i++) {
output_positions_ptr[indices_ptr[begin + i]] = positions_ptr[i] + off;
if (indices_ptr) {
for (int64_t i = 0; i < positions.size(0); i++) {
output_positions_ptr[indices_ptr[begin + i]] = positions_ptr[i] + off;
}
} else {
std::transform(
positions_ptr, positions_ptr + positions.size(0),
output_positions_ptr + begin, [off](auto x) { return x + off; });
}
auto pointers_ptr = pointers.data_ptr<int64_t>();
std::copy(
pointers_ptr, pointers_ptr + pointers.size(0),
output_pointers_ptr + begin);
});
return {output_positions, output_pointers, offsets};
return {output_positions, output_pointers, *offsets};
}

c10::intrusive_ptr<Future<std::vector<torch::Tensor>>>
PartitionedCachePolicy::ReplaceAsync(torch::Tensor keys) {
PartitionedCachePolicy::ReplaceAsync(
torch::Tensor keys, torch::optional<torch::Tensor> offsets) {
return async([=] {
auto [positions, pointers, offsets] = Replace(keys);
return std::vector{positions, pointers, offsets};
auto [positions, pointers, offsets_out] = Replace(keys, offsets);
return std::vector{positions, pointers, offsets_out};
});
}

Expand Down
23 changes: 13 additions & 10 deletions graphbolt/src/partitioned_cache_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,18 @@ class PartitionedCachePolicy : public torch::CustomClassHolder {
* @brief The policy query function.
* @param keys The keys to query the cache.
*
* @return (positions, indices, missing_keys, found_ptrs, found_offsets),
* where positions has the locations of the keys which were found in the
* cache, missing_keys has the keys that were not found and indices is defined
* such that keys[indices[:positions.size(0)]] gives us the keys for the found
* pointers and keys[indices[positions.size(0):]] is identical to
* missing_keys. The found_offsets tensor holds the partition offsets for the
* found pointers.
* @return (positions, indices, missing_keys, found_ptrs, found_offsets,
* missing_offsets), where positions has the locations of the keys which were
* found in the cache, missing_keys has the keys that were not found and
* indices is defined such that keys[indices[:positions.size(0)]] gives us the
* keys for the found pointers and keys[indices[positions.size(0):]] is
* identical to missing_keys. The found_offsets tensor holds the partition
* offsets for the found pointers. The missing_offsets holds the partition
* offsets for the missing_keys.
*/
std::tuple<
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
torch::Tensor>
Query(torch::Tensor keys);

c10::intrusive_ptr<Future<std::vector<torch::Tensor>>> QueryAsync(
Expand All @@ -75,16 +77,17 @@ class PartitionedCachePolicy : public torch::CustomClassHolder {
/**
* @brief The policy replace function.
* @param keys The keys to insert into the cache.
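* @param offsets The optional partition offsets for the keys. If provided, keys
* is treated as already grouped by partition; otherwise keys is partitioned
* inside this function.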
*
* @return (positions, pointers, offsets), where positions holds the locations
* of the replaced entries in the cache, pointers holds the CacheKey pointers
* for the inserted keys and offsets holds the partition offsets for pointers.
*/
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> Replace(
torch::Tensor keys);
torch::Tensor keys, torch::optional<torch::Tensor> offsets);

c10::intrusive_ptr<Future<std::vector<torch::Tensor>>> ReplaceAsync(
torch::Tensor keys);
torch::Tensor keys, torch::optional<torch::Tensor> offsets);

template <bool write>
void ReadingWritingCompletedImpl(
Expand Down
24 changes: 19 additions & 5 deletions python/dgl/graphbolt/impl/cpu_cached_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,15 @@ def read(self, ids: torch.Tensor = None):
"""
if ids is None:
return self._fallback_feature.read()
values, missing_index, missing_keys = self._feature.query(ids)
(
values,
missing_index,
missing_keys,
missing_offsets,
) = self._feature.query(ids)
missing_values = self._fallback_feature.read(missing_keys)
values[missing_index] = missing_values
self._feature.replace(missing_keys, missing_values)
self._feature.replace(missing_keys, missing_values, missing_offsets)
return values

def read_async(self, ids: torch.Tensor):
Expand Down Expand Up @@ -133,6 +138,7 @@ def read_async(self, ids: torch.Tensor):
missing_keys,
found_pointers,
found_offsets,
missing_offsets,
) = policy_future.wait()
self._feature.total_queries += ids.shape[0]
self._feature.total_miss += missing_keys.shape[0]
Expand All @@ -144,7 +150,9 @@ def read_async(self, ids: torch.Tensor):
values_from_cpu_copy_event = torch.cuda.Event()
values_from_cpu_copy_event.record()

positions_future = policy.replace_async(missing_keys)
positions_future = policy.replace_async(
missing_keys, missing_offsets
)

fallback_reader = self._fallback_feature.read_async(missing_keys)
for _ in range(
Expand Down Expand Up @@ -239,12 +247,15 @@ def wait(self):
missing_keys,
found_pointers,
found_offsets,
missing_offsets,
) = policy_future.wait()
self._feature.total_queries += ids.shape[0]
self._feature.total_miss += missing_keys.shape[0]
values_future = cache.query_async(positions, index, ids.shape[0])

positions_future = policy.replace_async(missing_keys)
positions_future = policy.replace_async(
missing_keys, missing_offsets
)

fallback_reader = self._fallback_feature.read_async(missing_keys)
for _ in range(
Expand Down Expand Up @@ -310,12 +321,15 @@ def wait(self):
missing_keys,
found_pointers,
found_offsets,
missing_offsets,
) = policy_future.wait()
self._feature.total_queries += ids.shape[0]
self._feature.total_miss += missing_keys.shape[0]
values_future = cache.query_async(positions, index, ids.shape[0])

positions_future = policy.replace_async(missing_keys)
positions_future = policy.replace_async(
missing_keys, missing_offsets
)

fallback_reader = self._fallback_feature.read_async(missing_keys)
for _ in range(
Expand Down
21 changes: 13 additions & 8 deletions python/dgl/graphbolt/impl/feature_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,13 @@ def query(self, keys):
Returns
-------
tuple(Tensor, Tensor, Tensor)
A tuple containing (values, missing_indices, missing_keys) where
tuple(Tensor, Tensor, Tensor, Tensor)
A tuple containing
(values, missing_indices, missing_keys, missing_offsets) where
values[missing_indices] corresponds to cache misses that should be
filled by querying another source with missing_keys. If keys is
pinned, then the returned values tensor is pinned as well.
pinned, then the returned values tensor is pinned as well. The
missing_offsets tensor has the partition offsets of missing_keys.
"""
self.total_queries += keys.shape[0]
(
Expand All @@ -82,25 +84,28 @@ def query(self, keys):
missing_keys,
found_pointers,
found_offsets,
missing_offsets,
) = self._policy.query(keys)
values = self._cache.query(positions, index, keys.shape[0])
self._policy.reading_completed(found_pointers, found_offsets)
self.total_miss += missing_keys.shape[0]
missing_index = index[positions.size(0) :]
return values, missing_index, missing_keys
return values, missing_index, missing_keys, missing_offsets

def replace(self, keys, values):
def replace(self, keys, values, offsets=None):
"""Inserts key-value pairs into the cache using the selected caching
policy algorithm to remove old key-value pairs if it is full.
Parameters
----------
keys: Tensor
keys : Tensor
The keys to insert to the cache.
values: Tensor
values : Tensor
The values to insert to the cache.
offsets : Tensor, optional
The partition offsets of the keys.
"""
positions, pointers, offsets = self._policy.replace(keys)
positions, pointers, offsets = self._policy.replace(keys, offsets)
self._cache.replace(positions, values)
self._policy.writing_completed(pointers, offsets)

Expand Down
Loading

0 comments on commit f989f9d

Please sign in to comment.