[RELAY] [VIRTUALDEVICE] Change syntax for device planning and store parameter virtual devices in virtual_device_ field #10352

Merged: 32 commits, Feb 25, 2022

Contributor

@electriclilies electriclilies commented Feb 22, 2022

Scope of this PR

Below, I outline a design for representing virtual devices in-tree. This PR implements only a small part of that design: it changes the representation of the virtual devices for function parameters, and is relatively straightforward. However, I do change the syntax through which a function parameter's virtual device is specified, and I designed the change to be consistent with what I want to do for let-bound variables in the future. I have included this document because it provides important context for this PR.


Background

Currently, users can write virtual devices into the program by specifying the function_result_virtual_device, and the function_param_virtual_devices in the function attributes. Users can also introduce annotations using the on_device op, mostly on function arguments, but sometimes on let-bound variables.

Device planning solves the constraints created by the user-specified virtual devices: for each sub-expression, it determines a unique and consistent virtual device. We’ll call the resulting fully annotated program the complete representation.

Motivation

Currently, the function parameter virtual devices and the function result virtual device are represented through attributes on the function. Since there is already a text format for attributes on functions, we get a text representation of these virtual devices automatically, just by putting them in the attributes.

However, once we move the function virtual devices out of the attributes, we no longer get the text representation for free. The most immediate challenge is that the unit tests for the device planner are written in RelayScript, so without rethinking the text format for virtual devices, we can’t run the unit tests.

A second motivation is that the current text representation is clunky: to specify the virtual device of a subexpr, the user needs to wrap that expression in an on_device op. In the current implementation, let-bound variables and arguments to functions must be wrapped in on_device. While I am rethinking the text format, I’d like to rethink this as well.

Goal

The goal for the text representation is a RelayScript program that preserves all virtual device information from device planning in a minimal representation. The minimal representation is the least amount of information (subexprs assigned virtual devices) we need to reconstruct the device planned program using simple lexical scoping rules.

Additionally, I want to be able to reconstruct the complete representation from the minimal representation without using device planning itself. This will let us express the expected result of running PlanDevices in RelayScript without running PlanDevices itself.

Proposed design

In general, I propose removing the on_device op from the RelayScript representation and simplifying the way function virtual devices are represented in text. I will:

  1. formalize the minimal representation

  2. introduce syntax for the critical virtual devices

  3. introduce a pass that uses simple, lexical scoping rules to expand the minimal representation into the complete representation

The minimal representation of device planning information

The current text format for representing the device planned program in RelayScript preserves the virtual devices for

  1. function parameters and the function result

  2. arguments to functions if the argument device is different from the function result device

  3. let-bound variables

  4. inputs to device_copy

For our minimal representation, we don’t need 4, since the virtual devices of the inputs of device_copy must agree with the virtual devices specified in the device_copy op itself.

In this example from the current text representation, the virtual device specified in %1 is the same as the source virtual device in %2. We can remove the on_device op and not lose any information.

def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32],
          param_virtual_devices=[meta[VirtualDevice][1], meta[VirtualDevice][1], meta[VirtualDevice][0]],
          result_virtual_device=meta[VirtualDevice][0]) {
     %0 = add(%x, %y);
     %1 = on_device(%0, virtual_device=meta[VirtualDevice][1], constrain_result=True);
     %2 = device_copy(%1, src_virtual_device=meta[VirtualDevice][1], dst_virtual_device=meta[VirtualDevice][0]);
     subtract(%2, %z)
}

So, we’ll define our minimal representation to consist of:

  1. The virtual devices of variable bindings (namely, function parameters and let-bound variables)
  2. The virtual device of function results

We’ll call these the critical virtual devices.

Syntax for the critical virtual devices

Piggy-backing off the current text representation, we’d like to represent the virtual devices for let-bound variables and function parameters directly after the variable definition in a structure that looks like attributes.

Here is an example:

def @main(%x {virtual_device=meta[VirtualDevice][0]}: Tensor[(3, 3, 4), float32],
          virtual_device=meta[VirtualDevice][1] /* result virtual device */) {
    %0 = split(%x, indices_or_sections=3);
    let %t {virtual_device=meta[VirtualDevice][0]} = %0;
    %1 = %t.0;
    %2 = %t.1;
    %3 = device_copy(%1, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
    %4 = device_copy(%2, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
    subtract(%3, %4)
}

The function result virtual device is represented directly in the function’s attributes, just like in the current implementation, and we still use the meta syntax to represent the value of the function’s virtual device. In the parser, we’ll “promote” the function’s virtual_device attribute to first-class by setting the virtual_device_ field of the function. We won’t put the virtual device in the attributes of the function.

For variable definitions, the virtual device will be represented directly after the variable name (before the type annotation, as in the example above) in text that looks like attributes. We won’t actually add attributes to variable definitions. However, by using the same syntax as attributes, we can reuse the parser’s attribute-parsing utilities to parse the virtual device. (If there are fields other than virtual_device in the fake attributes, the parser will fail.) An advantage of this approach is that if we do want to add attributes to bound variables in the future, we don’t need to change our syntax at all.
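To make the "fake attributes" rule concrete, here is a minimal Python sketch of how a parameter such as `%x {virtual_device=meta[VirtualDevice][0]}: Tensor[...]` could be split into its parts. It is a toy model, not the TVM parser: `parse_param` and the regular expression are hypothetical names introduced for illustration, the virtual device is kept as an uninterpreted string, and rejecting an empty `{}` block is an assumption of this sketch.

```python
import re
from typing import Optional, Tuple

# %name, then an optional {...} block, then an optional ": type".
_PARAM = re.compile(
    r"^%(?P<name>\w+)\s*(?:\{(?P<attrs>[^}]*)\})?\s*(?::\s*(?P<type>.+))?$"
)


def parse_param(text: str) -> Tuple[str, Optional[str], Optional[str]]:
    """Return (name, virtual_device, type) for one parameter (toy model)."""
    match = _PARAM.match(text.strip())
    if match is None:
        raise ValueError(f"not a parameter: {text!r}")

    virtual_device = None
    attrs = match.group("attrs")
    if attrs is not None:
        fields = {}
        for item in attrs.split(","):
            if not item.strip():
                continue
            key, sep, value = item.partition("=")
            if not sep:
                raise ValueError(f"malformed fake attribute: {item!r}")
            fields[key.strip()] = value.strip()
        # Only virtual_device is allowed; any other field is an error.
        # (Assumption of this sketch: an empty {} block is rejected too.)
        if set(fields) != {"virtual_device"}:
            raise ValueError(f"unexpected fake attributes: {sorted(fields)}")
        virtual_device = fields["virtual_device"]

    return match.group("name"), virtual_device, match.group("type")


print(parse_param(
    "%x {virtual_device=meta[VirtualDevice][0]}: Tensor[(3, 3, 4), float32]"
))
```

A parameter without the braces, such as `%y: Tensor[(5, 7), float32]`, parses with `virtual_device` left as `None`, which corresponds to an unannotated variable.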

Expansion of minimal representation / propagation of ‘critical’ virtual devices

We’ll introduce a new pass, called DPL (“device plan lite”), which propagates the ‘critical’ virtual devices and the device_copy virtual devices so that every subexpr’s virtual_device_ field is populated. This pass will follow simple lexical propagation rules; if it finds a ‘critical’ virtual device that is not set, it will fail.

Note that in the current implementation, the minimal representation is not expanded into the complete representation at all. Instead, during traversal in the DeviceAwareVisitExpr, the visitor keeps track of the current virtual device using simple lexical scoping rules, so every traversal of the program must recompute the virtual devices of all the subexpressions. With DPL, we only have to do the propagation once, and we can get rid of DeviceAwareVisitExpr completely.
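The propagation described here can be sketched on a toy AST. The following Python is only an illustration of the lexical rules under stated assumptions, not the proposed pass: every class and function name (`Var`, `Call`, `DeviceCopy`, `Let`, `Function`, `propagate`, `device_plan_lite`) is hypothetical, devices are plain strings, and it does not check that a variable's device matches the scope it is used in.

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class Var:
    name: str
    virtual_device: Optional[str] = None  # critical: must already be set


@dataclass
class Call:
    op: str
    args: List["Expr"]
    virtual_device: Optional[str] = None  # filled in by the pass


@dataclass
class DeviceCopy:
    arg: "Expr"
    src: str
    dst: str
    virtual_device: Optional[str] = None  # filled in by the pass (always dst)


@dataclass
class Let:
    var: Var
    value: "Expr"
    body: "Expr"
    virtual_device: Optional[str] = None  # filled in by the pass


@dataclass
class Function:
    params: List[Var]
    body: "Expr"
    virtual_device: Optional[str] = None  # critical: result virtual device


Expr = Union[Var, Call, DeviceCopy, Let]


def require(var: Var) -> str:
    if var.virtual_device is None:
        raise ValueError(f"critical virtual device missing on %{var.name}")
    return var.virtual_device


def propagate(expr: Expr, current: str) -> None:
    """Populate virtual_device on expr and its children, given the scope's device."""
    if isinstance(expr, Var):
        require(expr)  # a variable carries its own (critical) device
    elif isinstance(expr, DeviceCopy):
        expr.virtual_device = expr.dst
        propagate(expr.arg, expr.src)  # the input lives on the source device
    elif isinstance(expr, Let):
        expr.virtual_device = current
        propagate(expr.value, require(expr.var))
        propagate(expr.body, current)
    else:  # Call: arguments share the call's device unless copied
        expr.virtual_device = current
        for arg in expr.args:
            propagate(arg, current)


def device_plan_lite(func: Function) -> Function:
    if func.virtual_device is None:
        raise ValueError("critical virtual device missing on function result")
    for param in func.params:
        require(param)
    propagate(func.body, func.virtual_device)
    return func


# The add / device_copy / subtract example, with only critical devices given.
x, y, z = Var("x", "dev1"), Var("y", "dev1"), Var("z", "dev0")
add = Call("add", [x, y])
copied = DeviceCopy(add, src="dev1", dst="dev0")
main = device_plan_lite(Function([x, y, z], Call("subtract", [copied, z]), "dev0"))
print(add.virtual_device, copied.virtual_device, main.body.virtual_device)
```

In this toy, the `add` call ends up on `dev1` purely because it is the input of a `device_copy` whose source is `dev1`, which is the same reasoning that makes the `on_device` wrapper redundant in the text format.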

Implied virtual device rules

The DPL pass will rely on some implied rules to properly flow device planning information, most importantly

  1. The virtual device of an argument to a function is the same as the virtual device of the corresponding function parameter

  2. Call ops are either literals, globals, or possibly let-bound variables

Note that for number 1, the user will need to insert a device_copy op around any function argument whose virtual device is different from the corresponding function parameter’s virtual device. Also note that for number 2, if we have a call to something other than a literal, global, or let-bound variable, we will need to re-run device planning completely, since DPL won’t be able to reconstruct the complete representation.
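As an illustration of rule 1, the sketch below shows, purely at the text level, where a device_copy would need to be written around a call argument. It is a hypothetical helper, not part of TVM: `align_arguments` and `device` are names made up for this example, and virtual devices are identified by their index into the meta table.

```python
from typing import List, Tuple


def device(index: int) -> str:
    """Text form of a virtual device reference, as used in the examples above."""
    return f"meta[VirtualDevice][{index}]"


def align_arguments(args: List[Tuple[str, int]], param_devices: List[int]) -> List[str]:
    """Wrap each argument whose device differs from its parameter's in a device_copy.

    args holds (expression text, device index) pairs, one per call argument.
    """
    if len(args) != len(param_devices):
        raise ValueError("argument and parameter counts differ")
    aligned = []
    for (expr, arg_device), param_device in zip(args, param_devices):
        if arg_device == param_device:
            aligned.append(expr)  # rule 1 already holds
        else:
            aligned.append(
                f"device_copy({expr}, src_virtual_device={device(arg_device)}, "
                f"dst_virtual_device={device(param_device)})"
            )
    return aligned


# %a lives on device 0 but the first parameter expects device 1.
print(align_arguments([("%a", 0), ("%b", 1)], [1, 1]))
```

Only the mismatched argument is wrapped; an argument that already agrees with its parameter is passed through unchanged.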

Tests in device planning

The test cases will use the new RelayScript syntax and the DPL pass to test device planning.

Let input be the input program (in text format), containing on_device ops and device_copy ops, and expected be our expected output program (in text format), which has the virtual device information for every ‘critical’ virtual device.

Then, let complete_output = DP(parse(input)), where DP is full device planning (PlanDevices) and complete_output is a fully device planned program with all the virtual_device_ fields propagated (the complete representation).

Now, let minimal_expected = parse(expected) be the minimal representation of the fully device planned program. minimal_expected doesn’t have all the virtual_device_ fields propagated, but it does contain enough information to completely reconstruct the device planned program through simple lexical rules. We then use DPL (“device plan lite”) to reconstruct the complete representation from it.

Finally, we can check that complete_output == DPL(minimal_expected). It is also true that parse(print(complete_output)) == minimal_expected.

Note that the DPL pass will also be useful for reconstructing any virtual device information that is removed or not propagated correctly by some other Relay pass. We expect that this may occasionally happen. As long as the ‘critical’ virtual devices are preserved, we can run DPL to get the complete representation.

@electriclilies electriclilies marked this pull request as draft February 22, 2022 20:46
@electriclilies electriclilies marked this pull request as ready for review February 23, 2022 18:10
@electriclilies electriclilies changed the title [DRAFT] Change syntax for device planning and store parameter virtual devices in virtual_device_ field [RELAY] [VIRTUALDEVICE] Change syntax for device planning and store parameter virtual devices in virtual_device_ field Feb 23, 2022
Contributor

@mbs-octoml mbs-octoml left a comment

Thanks Lily. A few nits, and two issues we'll need to go around for.

VirtualDevice virtual_device;
if (WhenMatch(TokenType::kLCurly)) {
  Map<String, ObjectRef> fake_attrs = ParseAttrs();
  VLOG(1) << "Fake attributes for function parameter: " << fake_attrs;
Contributor

meganit: I've been using VLOG(9) for these super-duper verbose ones since I often just set TVM_LOG_DEBUG="DEFAULT=1" to get an overall debug trace that's still vaguely readable.

@@ -220,9 +220,13 @@ Doc RelayTextPrinter::AllocVar(const Var& var) {
}
Doc val = GetUniqueName("%" + name);
memo_[var] = val;
if (!var->virtual_device()->IsFullyUnconstrained()) {
Contributor

Any issue with this being used for both param- and let-bound vars even though we don't parse annots for let-bound vars?

Contributor Author

I guess for now, let-bound variables don't have their virtual devices set, so theoretically it won't be triggered. I haven't seen any issues in CI related to this, but I could split the function into two if you'd like.
Eventually we will annotate let-bound variables, and at that point we will have to parse the fake attrs for let-bound variables.

src/printer/relay_text_printer.cc (outdated, resolved)
VirtualDevice param_virtual_device = GetFunctionParamVirtualDevice(func.get(), i);
param_virtual_devices.push_back(param_virtual_device);
param_device_indexes.push_back(GetDeviceIndex(param_virtual_device));
param_device_indexes.push_back(GetDeviceIndex(func->params[i]->virtual_device()));
Contributor

we get some payoff at last!

Contributor Author

yeah!!

src/relay/backend/vm/lambda_lift.cc (resolved)
src/relay/ir/expr_functor.cc (resolved)
Contributor

@mbs-octoml mbs-octoml left a comment

That LGTM except:

  • we lost binding new params in Bind
  • suggest slight change to API and name for your new substitute method.

src/relay/ir/expr_functor.cc (resolved)
@@ -528,6 +521,31 @@ TVM_REGISTER_GLOBAL("relay.ir.Bind").set_body([](TVMArgs args, TVMRetValue* ret)
}
});

Expr SubstituteVars(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
Contributor

How about 'SubstituteBoundVars' and make it only accept a Function?

Contributor Author

done!

@@ -877,7 +876,7 @@ class DeviceDefaulter : public ExprVisitor {
};

/* =============== Phase 3 =============== */

// TODO(@electriclilies): rewrite this comment
Contributor

I think just function attributes bullet needs to be updated, the rest still stands.

Contributor Author

done!

@masahi masahi merged commit 308d320 into apache:main Feb 25, 2022
@electriclilies electriclilies deleted the new-syntax-virtual-device branch February 25, 2022 22:43
pfk-beta pushed a commit to pfk-beta/tvm that referenced this pull request Apr 11, 2022
…arameter virtual devices in virtual_device_ field (apache#10352)

* parent 33082e0
author electriclilies <lilyorthsmith@gmail.com> 1643141097 -0800
committer Lily Orth-Smith <lilyorthsmith@gmail.com> 1645560059 -0800

Store function param virtual devices in virtual_device_ field

Fix test_annotation.py and change result_virtual_device to virtual_device

* Change plan devices tests to use the new syntax for function parameters

* Fix free var problem

* Fix attribute parsing if there is virtual device; most device planning tests passgit status

* fixed lambda lifting

* Debugging high order functions -- right now FunctionOnDevice and Bind are mutually recursive. This needs to not be the case.

* tests pass wootgit status

* Remove FunctionOnDevice from device planner

* Don't use MaybeFunctionOnDevice in VM compiler

* Remove MaybeFunctionOnDevice from lambda lifter

* Delete FunctionOnDevice and MaybeFunctionOnDevice!

* Reomve GetFunctionResultVirtualDevice

* Remove GetFunctionParamVirtualDevice

* lint

* lint

* Python formatting

* Remove FunctionOnDevice python test

* Fix bug in binds & debug output

* Fix text printer

* lint

* Remove function on device from fold constant tests

* Mark nits

* Revert behavior of bind

* clean up debug

* Make ExprBinder public interface and use instead of Bind

* Fix lambda lift

* This is broken but not sure how to fix

* passes all device planning tests yay!

* Add substitution helper and use in device planner

* Remove unnecessary check

* Respond to comments

* Update comment