Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Follow split #1

Closed
wants to merge 15 commits into from
Closed

Conversation

jiuqi-yang
Copy link

Thanks for contributing to TVM! Please refer to guideline https://docs.tvm.ai/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from Reviewers by @ them in the pull request thread.

jingbang.yjb added 2 commits July 22, 2020 17:03
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>

Conflicts:
	src/auto_scheduler/compute_dag.cc
	src/auto_scheduler/transform_step.cc
	src/auto_scheduler/transform_step.h
	tests/python/unittest/test_auto_scheduler_loop_state.py
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
Copy link
Owner

@jcf94 jcf94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had a brief look, fix these comments first. 😃

python/tvm/auto_scheduler/loop_state.py Outdated Show resolved Hide resolved
python/tvm/auto_scheduler/loop_state.py Outdated Show resolved Hide resolved
src/auto_scheduler/loop_state.cc Outdated Show resolved Hide resolved
src/auto_scheduler/loop_state.h Outdated Show resolved Hide resolved
src/auto_scheduler/loop_state.h Show resolved Hide resolved
src/auto_scheduler/transform_step.cc Outdated Show resolved Hide resolved
jcf94 added a commit that referenced this pull request Jul 23, 2020
…generating (apache#5962)

* Code migration Start (#1)

* Init commit: Code migration Start

* Add loop_state.cc/h

* Add ComputeDAG basic test

* Split transform_step out & Update more UTs (apache#3)

* Split transform_step out

* Update GetProducers & GetConsumers

* Update UTs

* Add UT for CacheReadWrite & Some bug fix

* Add search_task, measure and serialization (apache#4)

* Add FollowSplit & FollowFusedSplit tests

* Update dag.InferBound & its UT

* Add search_task, measure and serialization

* Update Serialization UT

* Add MetaTileRewritePolicy (apache#5)

* Add feature

* Add cost_model, meta_tile_rewrite_policy

* Add MetaTileRewritePolicy basic UT

* Basic Python API for State (apache#6)

* Add Basic Python API for State

* Add UTs for State

* Add Python API: Measure & Task (apache#7)

* Update the return value of state operation

* Add task

* Copy measure.py & utils.py

* Fix LocalBuilder

* Fix LocalRunner

* Add ansor.auto_schedule() API; First AutoSchedule working version(apache#8)

* Add basic Python support for ansor.auto_schedule

* Update AutoSchedule API

* Bug fix for get the attach point of a fused iter

* Update UT after infer bug fix

* Bug fix & Add python serialization API (apache#10)

* Delete C++ UT hack since Python is ready

* Add ndarray.non_empty

* Update Serialization python API

* Improve code style, python wrapper and test cases (apache#11)

* Update c++ code style and unit test

* Update python State wrapper and test cases

* fix unit tests

* Add RPCRunner & OpenCL/CUDA test (apache#12)

* Add RPCRunner & OpenCL search test

* Add CUDA search test

* Add RPCRunner test

* rebase to upstream/master

* Add Ansor basic tutorial (apache#13)

* Add basic tutorial

* migrate feature extraction (apache#14)

* Add XGBModel & RPCRunnerWarpper (apache#15)

* Add XGBModel & RPCRunnerWarpper

* Revert "Add Parallel Granularity Mutation"

* Migrate workload_registry.py (apache#16)

* add workload registry

* update

* update

* add task scheduler (apache#17)

* Add conv2d cuda tutorial with workload registry (apache#18)

* add tune_test.py (the old tune_wkl.py) (apache#19)

* add tune_test.py (the old tune_wkl.py)

* update

* fix measure

* fix for gpu

* Code refine for tune_test.py & Add a pre load callback (apache#20)

* Bug fix for tutorials

* Add PreLoadMeasuredStates

* Add search_callback support for task tuner

* Code refine for tune_test.py

* Update

* Update

* Update

* Update

* Bug fix

* Add python custom sketch rule (apache#21)

* Add custom sketch rule

* Bug fix

* Ansor Relay Integration (without layout rewrite) (apache#22)

* relay integration

* Add tune_op_subgraph.py & Some code clean for tune_network.py (apache#23)

* Add single op tune scripts

* Add tune subgraph support

* Merge all op & all subgraph to one file

* Rename file

* add explicit_unroll_max_extent (apache#25)

* Add Index simplification & API update (apache#26)

* Add vectorized cooperative_fetching test

* Update math simplify for vectorized CF

* File rename

* Update tune_network

* API update

* Update PreLoadMeasuredStates & Some bug fix (apache#27)

* Add a threading wrapper to fix the test bug

* Set default TVM_USE_AUTO_SCHEDULER to false

* Update PreLoadMeasuredStates callback

* Add tensorize step for loop_state (apache#31)

* Add tensorize step

* State python api update (apache#33)

* Start to update api

* Add compute_dag to state

* API update

* kernel layout rewrite (apache#28)

* kernel layout rewrite

* remove some hacks

* add defuse_ops pass and move kernel_layout_rewrite pass after fuse_ops pass

* set TVM_RELAY_DISABLE_BUILD_CACHE for task extraction and prepare_layout_rewrite

* [cache flush] port cache flush to ansor (apache#32)

* Improve relay integration (apache#34)

* tmp checkpoint

* Improve relay integration

* Improve relay integration

* Fix xgb error & Simplify dispatcher (apache#35)

* Rename "MetaTileRewritePolicy" to "SketchPolicy". (apache#36)

* Rename "MetaTileRewritePolicy" to "SketchPolicy".

* Add a new class for auto_unroll_max_step, storage_offset in StageNode

* fix tune_op_subgraph.py

* rebase

* Migrate all node::make to noderef's construct function (apache#37)

* Start to move xxxnode::make to noderef()

* Update

* Update

* Finish transform_step

* Finish comute dag & auto schedule

* Update

* Update

* Update

* Update

* Update

* Code refine

* Code refine

* Code refine

* Update

* Update

* Some lint fix & Recover the double constructor of tvm::PrimExpr (apache#39)

* lint fix

* clang-format-fix

* pylint fix

* Update

* Recover the double constructor of tvm::PrimExpr

* Fix pylint

* pylint fix

* pylint fix

* Add MutateComputeLocation and MutateParallel in evolutionary search (apache#40)

* Add MutateComputeLocation and MutateParallel in evolutionary search

* fix lint

* Improve loop state python API (stage_tensors -> stage_ops) (apache#41)

* improve loop state python API (stage_tensors -> stage_ops)

* fix

* ComputeDAG bug fix & Add Custom TensorCore Matmul Example (apache#42)

* Bug Fix

* Sample example of Custom TensorCore Matmul

* Rever Commits, Start to build minimum Ansor system

* Code clean for minimum Ansor system

* Bug fix & Delete AccessAnalyzer

* Delete attachmap & Code clean

* Doc update

Update statenode::stages from vector to Array

* Headfile update & Python doc update

* clang-format fix

* pylint fix

* Update

* Doc update

* Update

* Bug fix after code merge to the new master

* clang-format fix

* Update

* Update

* Update std::vector to Array; Update verbosity setting; Some commemts
addressed

* std::vector->Array & std::string->String

* Add init_state to ComputeDAG

* Update

* Update some unordered_map to Map

* clang-format fix

* Comments addressed
Delete ReplayAndInferBound
Delete ReplaySteps & InferBoundCommon

* Lint fix

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Rename ansor namespace to auto_schedule

* Update

* Rename ThreadPool to ParallelFor

* Add parallel_for

* Remove ThreadPool

* Update python/tvm/auto_schedule/auto_schedule.py

* trigger CI

Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: Minmin Sun (孙敏敏) <minmin.smm@alibaba-inc.com>
Co-authored-by: Zhao Wu <zhaowu@apache.org>
jingbang.yjb added 2 commits July 23, 2020 11:57
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
* (i.e. Follow another split step) */
class FollowSplitStepNode : public StepNode {
public:
int iter_id; // The id of the iter to split
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move these comments to /*! \brief ... */, also for other class members.


void WriteToRecord(dmlc::JSONWriter* writer) const final;

void ExtractSplitLengths(const Array<Step>& transform_steps,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add /*! \brief ... \param ... */ doc string for those member functions.


String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
const Array<Step>& transform_steps) const;
static constexpr const char* record_prefix_str = "FSP";
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add blank lines in front and back of this line, keep the same as other class do.

static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep";
TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object);
};

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same for the next three classes.

@@ -239,6 +239,27 @@ Iterator State::vectorize(int stage_id, const Iterator& it) {
return step->ApplyToState(this);
}

Array<Iterator> State::follow_split(int stage_id, const Iterator& it, int src_step_id,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put these two functions right behind the State::split().
And be care of the order in other positions, they should all be put right behind the split step. (e.g. TVM_REGISTER_GLOBAL in this file, follow_split/follow_fused_split in loop_state.h, classes in transform_steps.h)

Comment on lines 307 to 325
"""
Schedule primitive corresponds to te.follow_split.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be split, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to split.
src_step_id : Int
The index of the split step to follow in the history.
n_split : Int
The number of split level.

Returns
-------
res_its : List[Iterator]
The splitted new Iterators.
"""
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""
Schedule primitive corresponds to te.follow_split.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be split, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to split.
src_step_id : Int
The index of the split step to follow in the history.
n_split : Int
The number of split level.
Returns
-------
res_its : List[Iterator]
The splitted new Iterators.
"""
""" Schedule primitive extends to split step.
This step is used to follow a former SplitStep, keeps their iterator structures to be same.
Example cases:
With subgraph: Dense -> Relu
Some tiling structures are used in Relu stage and we intend to compute the Dense
stage at Relu.
The follow_split is used here to keep their outer most few iterators the same for
applying compute at.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be split, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to split.
src_step_id : int
The index of the split step to follow in the history.
n_split : int
The number of split level.
Returns
-------
res_its : List[Iterator]
The splitted new Iterators.
"""

Comment on lines 335 to 353
"""
Schedule primitive corresponds to te.follow_fused_split.
Parameters
----------
iterator : Iterator
The iterator to split.
src_step_ids : List[int]
The indices of the split steps to follow in the history.
level : Int
Use the length in this split level.
factor_or_nparts : Bool
True to use `factor` for split from inner to outer,
False to use `nparts` for split from outer to inner.

Returns
-------
res_its : List[Iterator]
The splitted new Iterators.
"""
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""
Schedule primitive corresponds to te.follow_fused_split.
Parameters
----------
iterator : Iterator
The iterator to split.
src_step_ids : List[int]
The indices of the split steps to follow in the history.
level : Int
Use the length in this split level.
factor_or_nparts : Bool
True to use `factor` for split from inner to outer,
False to use `nparts` for split from outer to inner.
Returns
-------
res_its : List[Iterator]
The splitted new Iterators.
"""
""" Schedule primitive extends to split step.
This step is used to follow several former SplitSteps and FuseSteps.
Example cases:
With subgraph in GPU schedule: Input -> Dense
for i.0@j.0 = ... : Bind to blockIdx.x
for i.1@j.1 = ... : Bind to threadIdx.x
for i.2@j.2 = ...
Input_shared = Input ...
for k = ...
Dense = ...
We intend to apply cooperative fetching with the Input stage, while the threadIdx.x
axis is binded to a iterator generated by split & fuse step.
The follow_fused_step is used here to figure out the final extent of the threadIdx.x
binded iterator.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be split, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to split.
src_step_ids : List[int]
The indices of the split steps to follow in the history.
level : int
Use the length in this split level.
factor_or_nparts : bool
True to use `factor` for split from inner to outer,
False to use `nparts` for split from outer to inner.
Returns
-------
res_its : List[Iterator]
The splitted new Iterators.
"""

jingbang.yjb added 11 commits July 23, 2020 17:34
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
1. delete a comment
2. add "fuse" between follow_split and follow_fused_split

Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
@jcf94
Copy link
Owner

jcf94 commented Jul 29, 2020

Thanks.

@jcf94 jcf94 closed this Jul 29, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants