Init TIR schedule and script
rebased

[TIR][Schedule] fix reorder/buffer_flatten & finish CPU demo (apache#59)

[CPU DEMO] Update cpu gemm demo and fix bug (apache#58)

* [TIR][Schedule] introduce parallel and fix bugs for cpu demo

* [TIR][Schedule] update cpu demo

* [TIR][Schedule] fix lint

* [TIR][Schedule] fix

rebased

[TIR][Schedule] introduce reduction block and CPU demo (apache#53)

* [TIR] reduction : split_reduction

* [TIR] reduction : split_reduction

* [TIR] reduction : fuse_reduction

* [TIR] reduction : cpu demo

* [TIR] reduction : fix

* [TIR] reduction : pattern detect remains

* [TIR] reduction : pattern detect remains

* [TIR] reduction : pattern match done

* [TIR] reduction : fix lint

* [TIR] reduction : fix

* [TIR] reduction : fix

* [TIR] reduction : fix

* [TIR] reduction : fix

* [TIR] reduction : rebased

* [TIR] reduction : rebased

[TIR][Schedule] introduce cache_read cache_write (apache#54)

* [TIR][Schedule] introduce cache_read cache_write

* [TIR][Schedule] add more comments

* [TIR][Schedule] fix problem and add comments

* [TIR][Schedule] address comments

[TIR] schedule: introduce vectorize, unroll, loop validation (apache#47)

* [TIR] vectorize : basically complete

* [TIR] vectorize&unroll : update comments&unroll

* [TIR] vectorize&unroll : rebased

* [TIR] vectorize, unroll, cpu_demo: done

* [TIR] vectorize, unroll, cpu_demo: simplify

* [TIR] vectorize, unroll, cpu_demo: fix

* [TIR] reduction : rebased

* [TIR] reduction : fix

[TIR][Schedule] fix sref and scopes problem during replace and compute_at (apache#50)

* [TIR][Schedule] fix sref and scopes problem during replace and compute_at

* [TIR][Schedule] fix

* [TIR][Schedule] fix

[TIR][Refactor] move function to ScheduleNode

[TIR] Schedule: introduce primitive compute_at (apache#36)

* [TIR] Schedule: introduce primitive compute_at

* [TIR] Schedule: address comments

* [TIR] Schedule: address comments

* [TIR] Schedule: address comments

* [TIR] Schedule: add check to compute_at

* [TIR] Schedule: address comments

* [TIR] Schedule: address comments

[TIR] Schedule: introduce primitive reorder (apache#37)

* [Schedule] debug

* [TIR] Schedule: reorder, loop type detect remains

* [TIR] reorder complete

* [TIR] reorder complete

* [TIR] fix

* [TIR] reorder : rebased complete

* [TIR] reorder : fix container.h

* [TIR] reorder : fix

* [TIR] reorder : fix

* [TIR] reorder : fix

* [TIR] reorder : simplify

* [TIR] reorder : simplify

* [TIR] reorder : simplify

* [TIR] reorder : fix

* [TIR] reorder : fix

* [TIR] reorder : rebased

* [TIR] reorder : rebased

rebase

[TIR] Schedule: introduce BlockRealize and Block SRef reuse(apache#39)

* [TIR] BlockRealize: schedule refactor

* [TIR] BlockRealize: debug

* [TIR] BlockRealize finish

* [TIR] BlockRealize finish

* [TIR] BlockRealize fix

* [TIR] BlockRealize update test

* [TIR] BlockRealize: add loop var reuse

* [TIR] BlockRealize: add loop var reuse

* [TIR] BlockRealize: fix

* [TIR] BlockRealize: fix

* [TIR] BlockRealize: fix

* [TIR] BlockRealize: fix

* [TIR] BlockRealize: fix

* [TIR] BlockRealize: fix

* [TIR] BlockRealize: fix

* [TIR] BlockRealize: fix

* [TIR] BlockRealize: fix

* [TIR] BlockRealize: fix

[TIR] compare for module (apache#38)

* [TIR] compare for module

* [TIR] fix

* [TIR] fix

* [TIR] fix

* [TIR] fix

* [TIR] fix

* [TIR] fix

[Hybrid] Module init

[Hybrid] Module print

[Hybrid] Module print with meta

[Hybrid] adjust

[Hybrid] finished but without lint and comment check

[Hybrid] fix lint

[Hybrid] comments

[Hybrid] fix script decoration API

[Hybrid] using IRModule

[Hybrid] fix

[Hybrid] adjust API

[Hybrid] fix

[Hybrid] fix

[Hybrid] fix

[Hybrid] fix symbol table, adjust API, introduce meta_mutator and resolve import issue

[Hybrid] fix lint

[TIR] introduce pass BufferFlatten (apache#32)

* [TIR] introduce pass BufferFlatten

* [Tir] add comments & remove old TeLower

* [TIR] split GatherRegion and BufferFlatten to two Visitor/Mutator

* [TIR] address comments: Only consider stmt scope

* [TIR] BufferFlatten: address comments

* [TIR] BufferFlatten: fold BlockFlattener into BufferFlattener

* [TIR] BufferFlatten: add asserts

* [TIR] BufferFlatten: use Equal in testcase

* [TIR] Equal Pass: Enhanced the pass

* [TIR] Equal Pass: add comments

[Hybrid] refactor using Doc, introduce annotation, enhance parser (apache#28)

* [Hybrid] refactor printer, enhance parser

* [Hybrid] refactor

* [Hybrid] fix

* [Hybrid] fix

* [Hybrid] fix namespace issue

* [Hybrid] compare using Equal

[TIR] rebased

[TE] fix replace again and add primitive fuse and split (apache#27)

* [TE] add: schedule primitive fuse

* [TE] add: schedule primitive split

* [TE] address comments: add IRSubstitueInScope and other minor fix

* [TE] address comments: Enhance Equal api and fix split by nparts

* [TE] address comments

[Hybrid] introduce printer (apache#25)

* [Hybrid] substitute Block with SeqStmt, change block() syntax

* [Hybrid] add printer, type declare intrin

* [Hybrid] refactor

* [Hybrid] meta

* [Hybrid] refactor

* [Hybrid] macro

[TE] fix replace (apache#23)

* [TE] fix replace

* [TE] fix replace: add more tests

* [TE] fix replace: add more tests

[TE] rebased

[Hybrid] python syntax parser (apache#20)

* [Hybrid] python syntax parser

* [Hybrid] add a testcase

* [Hybrid] improve comments and fix bugs

* [Hybrid] improve comments, refactor __internal_assert, add new testcases

* [Hybrid] improve error report message, refactor intrin

* [Hybrid] separate ScopeEmitter from parser

* [Hybrid] refactor type check

* [Hybrid] refactor intrin

* [Hybrid] refactor intrin, allow register external functions with argument type checking, add a testcase

* [Hybrid] address comments, fix a bug in te/ir.h

* [Hybrid] remove type check

* [Hybrid] python syntax parser

* [Hybrid] add a testcase

* [Hybrid] improve comments and fix bugs

* [Hybrid] improve comments, refactor __internal_assert, add new testcases

* [Hybrid] improve error report message, refactor intrin

* [Hybrid] separate ScopeEmitter from parser

* [Hybrid] refactor type check

* [Hybrid] refactor intrin

* [Hybrid] refactor intrin, allow register external functions with argument type checking, add a testcase

* [Hybrid] address comments, fix a bug in te/ir.h

* [Hybrid] remove type check

* [Hybrid] refactor intrin, scope_handler, special_stmt

* [Hybrid] address comments

* [Hybrid] clean code, improve error reporting & testcase

* [Hybrid] clean code

* [Hybrid] clean code

[IR] introduce dependency graph and write map

[TE] refactor and clean codebase

[TE] refactor IR

[TE] introduce schedule, dependency graph and support fuse and split (apache#17)

* fix lint

* introduce dependency graph

* enable create schedule

* support get axes

* fix lint

* revert Set

* add schedule primitive fuse

* address comment

* support split

[IR] Introduce SeqStmt

add TeLower pass and enable to run Te IR (apache#15)

* add function data structure
add TeLower pass to transform Te to current IR
enable to run Te IR

* address comments

* unify terminology

TensorIR data structure init (apache#14)

* init te data structure

* finish printer and enhanced ir_builder

* address the comments

Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Hzfengsy and spectrometerHBH committed Jul 4, 2021
1 parent d17f753 commit 74df784
Showing 53 changed files with 8,161 additions and 31 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -248,6 +248,8 @@ file(GLOB_RECURSE COMPILER_SRCS
src/driver/*.cc
src/parser/*.cc
src/printer/*.cc
src/api/*.cc
src/hybrid_te/*.cc
src/support/*.cc
)

4 changes: 3 additions & 1 deletion cmake/modules/contrib/HybridDump.cmake
@@ -16,5 +16,7 @@
# under the License.

message(STATUS "Build with contrib.hybriddump")
-file(GLOB HYBRID_CONTRIB_SRC src/contrib/hybrid/*.cc)
+file(GLOB HYBRID_CONTRIB_SRC
+src/contrib/hybrid/*.cc
+src/contrib/hybrid_tir/*.cc)
list(APPEND COMPILER_SRCS ${HYBRID_CONTRIB_SRC})
1 change: 1 addition & 0 deletions include/tvm/node/node.h
@@ -50,6 +50,7 @@
namespace tvm {

using runtime::Downcast;
using runtime::DowncastPtr;
using runtime::GetRef;
using runtime::make_object;
using runtime::Object;
27 changes: 22 additions & 5 deletions include/tvm/runtime/object.h
@@ -192,6 +192,18 @@ class TVM_DLL Object {
* \note We use stl style naming to be consistent with known API in shared_ptr.
*/
inline bool unique() const;
/*!
* \return The usage count of the cell.
* \note We use stl style naming to be consistent with known API in shared_ptr.
*/
inline int use_count() const;

/*!
* \return Whether the cell has only one reference
* \note We use stl style naming to be consistent with known API in shared_ptr.
*/
inline bool unique() const;

/*!
* \brief Get the type key of the corresponding index from runtime.
* \param tindex The type index.
@@ -301,11 +313,6 @@ class TVM_DLL Object {
inline void DecRef();

private:
-/*!
-* \return The usage count of the cell.
-* \note We use stl style naming to be consistent with known API in shared_ptr.
-*/
-inline int use_count() const;
/*!
* \brief Check if this object is derived from the parent.
* \param parent_tindex The parent type index.
@@ -802,6 +809,10 @@ inline void Object::DecRef() {

inline int Object::use_count() const { return ref_counter_.load(std::memory_order_relaxed); }

inline bool Object::unique() const {
return use_count() == 1;
}

#else

inline void Object::IncRef() { ++ref_counter_; }
@@ -893,6 +904,12 @@ inline SubRef Downcast(BaseRef ref) {
return SubRef(std::move(ref.data_));
}

template<typename SubType, typename BaseType>
const SubType* DowncastPtr(BaseType* node) {
if (node->template IsInstance<SubType>()) return static_cast<const SubType*>(node);
return nullptr;
}

} // namespace runtime
} // namespace tvm
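
The `DowncastPtr` helper added in this hunk returns `nullptr` when the runtime type does not match, where `Downcast` on references would fail a check. A minimal self-contained sketch of that pattern follows. `Node`, `LoopNode`, `BlockNode`, `ExtentOrMinusOne`, and `SelfCheck` are hypothetical stand-ins, and the RTTI-based `IsInstance` only mimics the shape of TVM's type-index check; it is not TVM's implementation.

```cpp
#include <cassert>

// Hypothetical stand-in for a TVM Object base class. Only the shape of
// IsInstance<T>() matters here; TVM uses type indices rather than RTTI.
struct Node {
  virtual ~Node() = default;
  template <typename T>
  bool IsInstance() const {
    return dynamic_cast<const T*>(this) != nullptr;
  }
};

// Hypothetical node types used only for illustration.
struct LoopNode : Node {
  int extent = 0;
};
struct BlockNode : Node {
  int tag = 0;
};

// Same shape as the helper in the diff: a checked pointer downcast that
// yields nullptr instead of failing when the type does not match.
template <typename SubType, typename BaseType>
const SubType* DowncastPtr(BaseType* node) {
  if (node->template IsInstance<SubType>()) return static_cast<const SubType*>(node);
  return nullptr;
}

// Typical call site: branch on the result instead of asserting the type.
inline int ExtentOrMinusOne(const Node* node) {
  if (const LoopNode* loop = DowncastPtr<LoopNode>(node)) return loop->extent;
  return -1;
}

// Small self-check that exercises both the matching and the mismatching case.
inline void SelfCheck() {
  LoopNode loop;
  loop.extent = 16;
  BlockNode block;
  assert(ExtentOrMinusOne(&loop) == 16);
  assert(ExtentOrMinusOne(&block) == -1);
}
```

The null-returning form suits code that walks a tree of mixed statement kinds, where a mismatch is an expected case and not an error.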

265 changes: 265 additions & 0 deletions include/tvm/tir/schedule.h
@@ -0,0 +1,265 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#ifndef TVM_TIR_SCHEDULE_H_
#define TVM_TIR_SCHEDULE_H_
#include <tvm/ir/attrs.h>
#include <tvm/te/tensor.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/scope.h>
#include <tvm/tir/stmt_sref.h>
#include <tvm/tir/function.h>

#include <string>
#include <unordered_map>
#include <utility>
#include <vector>


namespace tvm {
namespace tir {

class Schedule;
class ScheduleNode : public Object {
public:
/*! \brief The function to be scheduled */
PrimFunc func;
/*! \brief The root of schedulable reference tree */
StmtSRef root;
/*!
* \brief The mapping from stmt to its schedulable reference node
* \note This is a hint to improve mutation efficiency
* */
std::unordered_map<const StmtNode*, StmtSRef> stmt2ref;
/*! \brief The block scopes of each block */
std::unordered_map<StmtSRef, Scope, ObjectHash, ObjectEqual> scopes_;

void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func);
v->Visit("root", &root);
}

/*!
* \brief Create a new schedule
* \param function The function to be scheduled
* \return The schedule
*/
static Schedule Create(PrimFunc function);

/*!
* \brief replace part of AST with new stmt
* \param ref The schedulable reference of the old stmt
* \param target The new stmt
* \param block_sref_map The Sref remapping of blocks
*/
void Replace(StmtSRef ref, Stmt target,
Map<Block, Block> block_sref_map = NullValue<Map<Block, Block> >());

/*!
* \brief Get block from its tag
* \param tag The query tag
* \param scope The block scope
* \return the block schedulable reference list
*/
Array<StmtSRef> GetBlock(const std::string& tag, StmtSRef scope = StmtSRef()) const;

/*!
* \brief Get block from its output buffer
* \param buffer The query buffer
* \param scope The block scope
* \return the block schedulable reference list
*/
Array<StmtSRef> GetBlock(const Buffer& buffer, StmtSRef scope = StmtSRef()) const;

/*!
* \brief Get all blocks in the scope
* \param scope The block scope
* \return the block schedulable reference list
*/
Array<StmtSRef> Blocks(StmtSRef scope) const;

/*!
* \brief Get loops of the block
* \param block The query block
* \return the loop sref list
*/
Array<StmtSRef> GetLoopsInScope(const StmtSRef& block) const;

/*!
* \brief Get the scope of the schedulable reference
* \param node The queried node
* \return the block scope reference
*/
StmtSRef GetScope(StmtSRef node) const;

/*!
* \brief fuse two consecutive loops of one computation.
* \param outer The outer loop
* \param inner The inner loop
* \return the fused loop
*/
StmtSRef fuse(const StmtSRef& outer, const StmtSRef& inner);

/*!
* \brief split a specified loop into two loops by factor.
* \param node The loop to be split
* \param nparts The number of parts to split the loop into
* \param factor The split factor
* \return the loops after splitting
*/
Array<StmtSRef> split(const StmtSRef& node, const PrimExpr& nparts, const PrimExpr& factor);

/*!
* \brief Move the block under the loop and regenerate the
* loops to cover the producing region.
* \param block_sref The block to be moved
* \param loop_sref The target loop
* */
void compute_at(const StmtSRef& block_sref, const StmtSRef& loop_sref);

/*!
* \brief vectorize a loop
* \param node the loop to be vectorized
*/
void vectorize(const StmtSRef& node);

/*!
* \brief parallelize a loop
* \param node the loop to be parallelized
*/
void parallel(const StmtSRef& node);

/*!
* \brief unroll a loop
* \param node the loop to be unrolled
*/
void unroll(const StmtSRef& node);

/*!
* \brief reorder a list of loops
* \param order the order of loops
*/
void reorder(const Array<StmtSRef>& order);

/*!
* \brief Decompose a reduction block into init and update blocks
* \param block_sref the reduction block
* \param loop_sref the loop at which the init block will be placed
* \return the sref of init block
*/
StmtSRef decompose_reduction(const StmtSRef& block_sref, const StmtSRef& loop_sref);

/*!
* \brief Merge the init block and the update block into a single reduction block
* \param init_sref the init block
* \param update_sref the update block
*/
void merge_reduction(const StmtSRef& init_sref, const StmtSRef& update_sref);

/*!
* \brief Create a cache read of original tensor for readers.
* \param buffer The buffer
* \param storage_scope The storage scope
*/
StmtSRef cache_read(const Buffer& buffer, const std::string& storage_scope);

/*!
* \brief Create a cache write of original tensor, before storing into tensor.
* \param buffer The buffer
* \param storage_scope The storage scope
*/
StmtSRef cache_write(const Buffer& buffer, const std::string& storage_scope);

/*!
* \brief Register a reducer pattern
* \param comm_reducer the reducer pattern to be registered
*/
void register_reducer(const CommReducer& comm_reducer);

/*!
* \brief validate sref tree and scope information
*/
bool ValidateSRef() const;

static constexpr const char* _type_key = "tir.Schedule";
TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, Object);

private:
/*! \brief The reducer list for reduction pattern matching */
std::vector<CommReducer> reducers_;

/*!
* \brief Update the sref to make it point to new Block/Loop
* \param sref The outdated sref
* \param stmt The new stmt
*/
void UpdateSRef(StmtSRefNode* sref, const Stmt& stmt);
/*!
* \brief Check the region cover for the single consumer block
*/
bool CheckRegionCover(const StmtSRef& consumer) const;
/*!
* \brief Check whether a sub_tree satisfies the one-way fine-grained data flow check
* \details Suppose a loop tree has several blocks on the leaves.
* We can sort them by DFS order as B1, B2, ...., Bn.
* The subtree satisfies compact data flow if
* - All the blocks are complete
* - Bi doesn't read the buffers that Bi+1, Bi+2, ... Bn will write
* - Suppose Bi reads Bj's output buffer(j < i) and Loop k is the LCA of Bi and
* Bj, Bj's output region covers Bi's input under Loop k
* \note Conditions 2 and 3 are global conditions of a schedulable IR,
* so they are omitted from the check.
*/
bool IsCompactDataFlow(const StmtSRef& sub_tree) const;
/*!
* \brief Validate TIR. The ValidateLoops pass currently performs the following checks:
* 1) loop binding validation: a set of binding expressions is valid if and only if
* 1. vi=i, vj=j, vk=k ... (one loop_var binds exactly one block_var)
* 2. if f is a legal binding and g is the binding after applying `split` on f,
* then g is legal
* 3. if f is a legal binding and g is the binding after applying `fuse` on f,
* then g is legal
* 2) region cover check: Suppose B is a RAW predecessor of C, Loop k is the LCA of B and
* C, then B's output region covers C's input region under Loop k
* \param function the PrimFunc to be validated
*/
void ValidateLoops(PrimFunc function);

/*!
* \brief Helper function for checking and mutating loops to do parallel computation
* For now it is only used for vectorize, bind and parallel
* \param node the loop to be annotated
* \param annotation the annotation
*/
void ParallelCompute(const StmtSRef& node, const Annotation& annotation);
};

class Schedule : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Schedule, ObjectRef, ScheduleNode);

ScheduleNode* operator->() {
return static_cast<ScheduleNode*>(ObjectRef::get_mutable());
}
};

} // namespace tir
} // namespace tvm

#endif // TVM_TIR_SCHEDULE_H_
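
The `split` and `fuse` primitives declared above rewrite loop nests without changing which iterations run. The index arithmetic behind such rewrites can be sketched in plain C++ as below. This is generic loop algebra under the assumption that a split guards the tail when the factor does not divide the extent. `SplitOrder` and `FuseOrder` are hypothetical names, and nothing here is taken from the actual `ScheduleNode` implementation, which this diff does not show.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Generic loop algebra only; these helper names are hypothetical and do not
// come from tvm/tir/schedule.h.

// Split: replace `for i in [0, extent)` by an outer/inner pair with
// i = outer * factor + inner. The guard skips the padded tail when
// factor does not divide extent.
inline std::vector<int> SplitOrder(int extent, int factor) {
  assert(extent >= 0 && factor > 0);
  const int nparts = (extent + factor - 1) / factor;  // ceil(extent / factor)
  std::vector<int> visited;
  for (int outer = 0; outer < nparts; ++outer) {
    for (int inner = 0; inner < factor; ++inner) {
      const int i = outer * factor + inner;
      if (i < extent) visited.push_back(i);
    }
  }
  return visited;
}

// Fuse: replace two nested loops by one loop of extent
// outer_extent * inner_extent and recover the original index pair
// with a division and a modulo.
inline std::vector<std::pair<int, int>> FuseOrder(int outer_extent, int inner_extent) {
  assert(outer_extent >= 0 && inner_extent > 0);
  std::vector<std::pair<int, int>> visited;
  for (int fused = 0; fused < outer_extent * inner_extent; ++fused) {
    visited.emplace_back(fused / inner_extent, fused % inner_extent);
  }
  return visited;
}
```

In this sketch both rewrites visit the same iterations in the same order as the loops they replace, which is the property that lets a schedule apply them without changing the computed result.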