Skip to content

Commit

Permalink
wip reverse cache-read/write
Browse files Browse the repository at this point in the history
  • Loading branch information
yzh119 committed Sep 21, 2022
1 parent 69583bc commit 88b74d4
Show file tree
Hide file tree
Showing 9 changed files with 548 additions and 0 deletions.
25 changes: 25 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,31 @@ class ScheduleNode : public runtime::Object {
*/
virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) = 0;
/*!
* \brief Create a block that reads a buffer region into a read cache. It requires:
* 1) There is at most one block who writes the buffer in the scope.
* 2) The scope block have stage-pipeline property.
* Compared to cache read, the index mapping was performed at producer rather than consumer.
* \param block_rv The consumer block of the target buffer.
* \param read_buffer_index The index of the buffer in block's read region.
* \param storage_scope The target storage scope.
* \return The cache stage block.
*/
virtual BlockRV ReverseCacheRead(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) = 0;
/*!
* \brief Create a block that writes a buffer region into a write cache. It requires:
* 1) There is only one block who writes the target buffer.
* 2) The scope block have stage-pipeline property.
* Compared to cache write, the index mapping was performed at producer rather than consumer.
* \param block_rv The producer of the buffer
* \param write_buffer_index The index of the buffer in block's write region
* \param storage_scope The target storage scope
* \return The cache stage block.
*/
virtual BlockRV ReverseCacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) = 0;

/******** Schedule: Compute location ********/
/*!
* \brief Move a producer block under the specific loop, and regenerate the
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,6 +1083,18 @@ def after_cache_write(a: T.handle, b: T.handle) -> None:
return _ffi_api.ScheduleCacheWrite( # type: ignore # pylint: disable=no-member
self, block, write_buffer_index, storage_scope
)

@type_checked
def reverse_cache_read(self, block: BlockRV, read_buffer_index: int, storage_scope: str) -> BlockRV:
return _ffi_api.ScheduleReverseCacheRead( # type: ignore # pylint: disable=no-member
self, block, read_buffer_index, storage_scope
)

@type_checked
def reverse_cache_write(self, block: BlockRV, write_buffer_index: int, storage_scope: str) -> BlockRV:
return _ffi_api.ScheduleReverseCacheWrite( # type: ignore # pylint: disable=no-member
self, block, write_buffer_index, storage_scope
)

########## Schedule: Compute location ##########

Expand Down
21 changes: 21 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,27 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff
return CreateRV<BlockRV>(result);
}

BlockRV ConcreteScheduleNode::ReverseCacheRead(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) {
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
result = tir::ReverseCacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope);
TVM_TIR_SCHEDULE_END("reverse-cache-read", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<BlockRV>(result);
}

BlockRV ConcreteScheduleNode::ReverseCacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) {
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
result =
tir::ReverseCacheWrite(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope);
TVM_TIR_SCHEDULE_END("reverse-cache-write", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<BlockRV>(result);
}

/******** Schedule: Compute location ********/

void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
Expand Down
4 changes: 4 additions & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ class ConcreteScheduleNode : public ScheduleNode {
const String& storage_scope) override;
BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) override;
BlockRV ReverseCacheRead(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) override;
BlockRV ReverseCacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) override;
/******** Schedule: Compute location ********/
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) override;
void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
Expand Down
26 changes: 26 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,32 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r
*/
TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
const String& storage_scope);
/*!
* \brief Create a block that reads a buffer region into a read cache. It requires:
* 1) There is at most one block who writes the buffer in the scope.
* 2) The scope block have stage-pipeline property.
* Compared to cache read, the index mapping was performed at producer instead of consumer.
* \param self The state of the schedule
* \param block_sref The consumer block of the target buffer.
* \param read_buffer_index The index of the buffer in block's read region.
* \param storage_scope The target storage scope.
* \return The cache stage block.
*/
TVM_DLL StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref,
int read_buffer_index, const String& storage_scope);
/*!
* \brief Create a block that writes a buffer region into a write cache. It requires:
* 1) There is only one block that writes the target buffer.
* 2) The scope block have stage-pipeline property.
* Compared to cache write, the index mapping was performed at producer instead of consumer.
* \param self The state of the schedule
* \param block_sref The producer of the buffer
* \param write_buffer_index The index of the buffer in block's write region
* \param storage_scope The target storage scope
* \return The cache stage block.
*/
TVM_DLL StmtSRef ReverseCacheWrite(ScheduleState self, const StmtSRef& block_sref,
int write_buffer_index, const String& storage_scope);

/******** Schedule: Compute location ********/
/*!
Expand Down
105 changes: 105 additions & 0 deletions src/tir/schedule/primitive/reverse_cache_read_write.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include "../utils.h"

namespace tvm {
namespace tir {

/******** Error Classes ********/

/******** Helper Functions/Classes ********/

/*! \brief Mutator for ReverseCacheRead */
class ReverseCacheReadRewriter : public StmtExprMutator {};

/*! \brief Mutator for ReverseCacheWrite */
class ReverseCacheWriteRewriter : public StmtExprMutator {};

/******** Implementation ********/

StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index,
const String& storage_scope) {}

StmtSRef ReverseCacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
const String& storage_scope) {}

/******** Instruction Registration ********/

struct ReverseCacheReadTraits : public UnpackedInstTraits<ReverseCacheReadTraits> {
static constexpr const char* kName = "ReverseCacheRead";
static constexpr bool kIsPure = false;

private:
static constexpr size_t kNumInputs = 1;
static constexpr size_t kNumAttrs = 2;
static constexpr size_t kNumDecisions = 0;

static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer read_buffer_index,
String storage_scope) {
return sch->ReverseCacheRead(block, read_buffer_index->value, storage_scope);
}

static String UnpackedAsPython(Array<String> outputs, String block, Integer read_buffer_index,
String storage_scope) {
PythonAPICall py("reverse_cache_read");
py.Input("block", block);
py.Input("read_buffer_index", read_buffer_index->value);
py.Input("storage_scope", storage_scope);
py.SingleOutput(outputs);
return py.Str();
}

template <typename>
friend struct ::tvm::tir::UnpackedInstTraits;
};

struct ReverseCacheWriteTraits : public UnpackedInstTraits<ReverseCacheWriteTraits> {
static constexpr const char* kName = "ReverseCacheWrite";
static constexpr bool kIsPure = false;

private:
static constexpr size_t kNumInputs = 1;
static constexpr size_t kNumAttrs = 2;
static constexpr size_t kNumDecisions = 0;

static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer write_buffer_index,
String storage_scope) {
return sch->ReverseCacheWrite(block, write_buffer_index->value, storage_scope);
}

static String UnpackedAsPython(Array<String> outputs, String block, Integer write_buffer_index,
String storage_scope) {
PythonAPICall py("reverse_cache_write");
py.Input("block", block);
py.Input("write_buffer_index", write_buffer_index->value);
py.Input("storage_scope", storage_scope);
py.SingleOutput(outputs);
return py.Str();
}

template <typename>
friend struct ::tvm::tir::UnpackedInstTraits;
};

TVM_REGISTER_INST_KIND_TRAITS(ReverseCacheReadTraits);
TVM_REGISTER_INST_KIND_TRAITS(ReverseCacheWriteTraits);

} // namespace tir
} // namespace tvm
26 changes: 26 additions & 0 deletions src/tir/schedule/traced_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,32 @@ BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer
return result;
}

BlockRV TracedScheduleNode::ReverseCacheRead(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) {
BlockRV result =
ConcreteScheduleNode::ReverseCacheRead(block_rv, read_buffer_index, storage_scope);

static const InstructionKind& kind = InstructionKind::Get("ReverseCacheRead");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/{block_rv},
/*attrs=*/{Integer(read_buffer_index), storage_scope},
/*outputs=*/{result}));
return result;
}

BlockRV TracedScheduleNode::ReverseCacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) {
BlockRV result =
ConcreteScheduleNode::ReverseCacheWrite(block_rv, write_buffer_index, storage_scope);

static const InstructionKind& kind = InstructionKind::Get("ReverseCacheWrite");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/{block_rv},
/*attrs=*/{Integer(write_buffer_index), storage_scope},
/*outputs=*/{result}));
return result;
}

/******** Schedule: Compute location ********/

void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
Expand Down
4 changes: 4 additions & 0 deletions src/tir/schedule/traced_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ class TracedScheduleNode : public ConcreteScheduleNode {
const String& storage_scope) final;
BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) final;
BlockRV ReverseCacheRead(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) final;
BlockRV ReverseCacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) final;
/******** Schedule: Compute location ********/
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) final;
void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
Expand Down
Loading

0 comments on commit 88b74d4

Please sign in to comment.