Search for int16 accumulation. (apache#13)

ZihengJiang authored and hypercubestart committed Mar 12, 2021
1 parent 7834bbd commit 3578471

Showing 12 changed files with 266 additions and 30 deletions.
python/tvm/hago/__init__.py (3 additions, 1 deletion)

@@ -22,6 +22,8 @@
 from .base import qconfig, current_qconfig
 from .analysis import inspect_graph_statistic
 from .hardware import Hardware, OpDesc, create_accelerator_description
-from .search import generate_search_space, search_quantize_strategy, DefaultSetting, BatchedGreedySearchTuner, serialize, deserialize
+from .search import generate_search_space, search_quantize_strategy
+from .search import DefaultSetting, RandomSearchTuner, GreedySearchTuner, BatchedGreedySearchTuner
+from .search import serialize, deserialize
 from .quantize import CalibrationDataset, prerequisite_optimize, create_quantizer
 from .record import serialize, deserialize, load_from_file, pick_best
python/tvm/hago/_op_attrs.py (11 additions)

@@ -279,6 +279,17 @@ def realize_addition(node, in_types, out_types):
     rhs = relay.cast(rhs, DataType(dtype))
     return forward_op(node, [lhs, rhs])

+@register_realize("nn.dense")
+def realize_dense(node, in_types, out_types):
+    data, weight = node.args
+    fields = node.attrs.list_field_info()
+    attrs_dict = {}
+    for field in fields:
+        key = field.name
+        attrs_dict[str(key)] = getattr(node.attrs, key)
+    attrs_dict['out_dtype'] = DataType(out_types[0])
+    attrs = tvm.ir.make_node("relay.attrs.DenseAttrs", **attrs_dict)
+    return relay.Call(node.op, node.args, attrs, node.type_args)

 @register_realize("nn.conv2d")
 def realize_conv2d(node, in_types, out_types):
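The new realize_dense hook rebuilds the dense call with every original attribute preserved but out_dtype overridden to the searched accumulation type, so an int8 x int8 dense can accumulate into int16 or int32. At the Relay API level the effect is roughly the following sketch (shapes and dtypes here are illustrative, not taken from the commit):

    from tvm import relay

    # Roughly what realize_dense produces when the searched accumulation
    # type is int16: the same dense call, re-created with out_dtype set.
    data = relay.var("data", shape=(1, 64), dtype="int8")
    weight = relay.var("weight", shape=(32, 64), dtype="int8")
    dense = relay.nn.dense(data, weight, out_dtype="int16")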
python/tvm/hago/quantize.py (2 deletions)

@@ -591,8 +591,6 @@ def visit_call(self, node):
     def _realize_simulated_quantize(self, node):
         data, in_scale, out_scale, clip_min, clip_max = node.args
         attrs = node.attrs
-        # in_scale = to_scalar(in_scale)
-        # out_scale = to_scalar(out_scale)
         in_dtype = attrs.in_dtype
         out_dtype = attrs.out_dtype
         axis = attrs.axis
python/tvm/hago/record.py (1 addition, 1 deletion)

@@ -65,7 +65,7 @@ def compare_key(m):
         key = MeasureKind.enum_to_str(kind)
         attr = getattr(m.result, key)
         nbit = sum(m.strategy.bits)
-        return (attr, nbit)
+        return (attr, -nbit)
     return max(measures, key=compare_key)
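Negating nbit changes the tie-breaker, not the winner: max(measures, key=compare_key) still selects the measure with the highest value of the objective attribute, but among equal-accuracy measures the one using fewer total bits now wins, because Python compares tuples element by element. A minimal illustration with hypothetical values:

    # (accuracy, -total_bits) tuples compare accuracy first, then prefer
    # the smaller bit budget on ties.
    candidates = [(0.71, -160), (0.71, -128), (0.68, -96)]
    assert max(candidates) == (0.71, -128)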
python/tvm/hago/search.py (90 additions, 6 deletions)

@@ -371,6 +371,7 @@ def _write_to_file(self, fout, measures):
         fout.write('\n')

     def _update_best_measure(self, measures):
+        old_measure = self.best_measure
         if self.best_measure is None:
             self.best_measure = best_measure(measures, self.measure_kind)
         else:
@@ -383,11 +384,12 @@
             print(m)
         print('best_measure')
         print(self.best_measure)
-        return self.best_measure
+        updated = (self.best_measure != old_measure)
+        return updated, self.best_measure

     def _measure(self, bits_list):
         # support single sample measure and batched measure
-        # [bits] -> Measure(strategy, MeasureResult)
+        # [bits] -> [Measure(strategy, MeasureResult)]
         results = []
         if isinstance(bits_list, list):
             groups = _group_same_graph(self.graph, self.hardware, self.topology, bits_list)
@@ -407,17 +409,99 @@


 class DefaultSetting(Tuner):
-    def __init__(self, space, objective):
+    def __init__(self, space, objective, bits=None):
         super(DefaultSetting, self).__init__(space, objective, max_trials=1)
+        if bits is None:
+            self.bits = [choices[0] for choices in self.space]
+        else:
+            self.bits = bits

     def has_next(self):
         return True

     def next_trials(self):
-        return [[choices[0] for choices in self.space]]
+        return [self.bits]

     def update(self, measures):
-        ms = self._update_best_measure(measures)
+        self._update_best_measure(measures)
+
+
+class RandomSearchTuner(Tuner):
+    def __init__(self, space, objective, max_trials=None):
+        if max_trials is None:
+            max_trials = len(space)
+        super(RandomSearchTuner, self).__init__(space, objective, max_trials)
+
+    def has_next(self):
+        return True
+
+    def next_trials(self):
+        return [[random.choice(choices) for choices in self.space]]
+
+    def update(self, measures):
+        self._update_best_measure(measures)
+
+
+class GreedySearchTuner(Tuner):
+    def __init__(self, space, objective, max_trials=None):
+        super(GreedySearchTuner, self).__init__(space, objective, max_trials)
+        self.dim_idx = 0
+        self.bit_idx = 0
+        self.decided = []
+        self.default = [choices[0] for choices in space]
+
+    def has_next(self):
+        return self.dim_idx < len(self.space)
+
+    def next_trials(self):
+        choice = self.space[self.dim_idx][self.bit_idx]
+        trials = [self.decided + [choice] + self.default[self.dim_idx+1:]]
+        return trials
+
+    def update(self, measures):
+        updated, best_measure = self._update_best_measure(measures)
+        self.bit_idx += 1
+        if measures[0].result.accuracy < best_measure.result.accuracy or \
+                self.bit_idx >= len(self.space[self.dim_idx]):
+            # move to next dimension
+            best_bit = best_measure.strategy.bits[self.dim_idx]
+            self.decided.append(best_bit)
+            self.dim_idx += 1
+            self.bit_idx = 0
+
+    def _measure(self, bits_list):
+        assert len(bits_list) == 1
+        bits = bits_list[0]
+        thresholds = threshold_estimate(self.graph, self.topology, self.stats, bits)
+        quantizer = qtz.Quantizer(self.graph, self.hardware, self.topology, bits, thresholds)
+        sgraph = quantizer.simulate()
+        qgraph = quantizer.quantize()
+        # print('original graph')
+        # print(self.graph)
+        # print('simulated graph')
+        # print(sgraph)
+        # print('quantized graph')
+        # print(qgraph)
+        # lowered_qgraph = relay.qnn.transform.CanonicalizeOps()(tvm.IRModule.from_expr(qgraph))
+        # print('lowered quantized graph')
+        # print(lowered_qgraph)
+        # raise ValueError
+
+        runtime = relay.create_executor("graph", ctx=self.ctx, target=self.target).evaluate(qgraph)
+        input_keys = [str(param.name_hint) for param in qgraph.params]
+        outputs = []
+        for batch_id, batch in enumerate(self.dataset):
+            inputs = {}
+            for key in input_keys:
+                assert key in batch
+                inputs[key] = batch[key]
+            out = runtime(**inputs)
+            outputs.append(out)
+        measure_result = self.measure_func(self.graph, self.dataset, outputs, self.ctx, self.target)
+        strategy = Strategy(self.model_hash, self.topology, bits, thresholds)
+        result = Measure(strategy, measure_result)
+        print(result)
+        return [result]


 class BatchedGreedySearchTuner(Tuner):
@@ -436,7 +520,7 @@ def next_trials(self):
         return trials

     def update(self, measures):
-        ms = self._update_best_measure(measures)
+        updated, ms = self._update_best_measure(measures)
         best_bit = ms.strategy.bits[self.dim_idx]
         self.decided.append(best_bit)
         self.dim_idx += 1
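GreedySearchTuner fixes one dimension of the bit assignment at a time: next_trials proposes the current dimension's next bit choice with already-decided bits kept and later dimensions at their defaults, and update locks in the best bit once a trial's accuracy falls below the running best or the choices run out. A hypothetical driver loop, mirroring how search_quantize_strategy would consume a tuner (space and the measurement context of graph, dataset, and target are assumed to be set up elsewhere):

    # Sketch of a greedy search driver over a prepared search space.
    tuner = GreedySearchTuner(space, 'accuracy')
    while tuner.has_next():
        trials = tuner.next_trials()       # one candidate bit assignment
        measures = tuner._measure(trials)  # quantize, run, and score it
        tuner.update(measures)             # advance the greedy cursor
    print(tuner.best_measure)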
python/tvm/hago/threshold.py (1 addition, 1 deletion)

@@ -86,8 +86,8 @@ def threshold_estimate(graph, topology, stats, bits=None):
     else:
         raise ValueError

-    print('thresholds: {}'.format(thresholds))
     if cfg.round_scale_to_pot:
         thresholds = [_round2pot(x) for x in thresholds]

+    print('thresholds: {}'.format(thresholds))
    return thresholds
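Moving the print below the rounding means the logged thresholds are the ones actually returned. _round2pot's body is not part of this diff; a plausible sketch of rounding a threshold up to the next power of two, so that scales become exact binary shifts, would be:

    import math

    def _round2pot(x):
        # Hypothetical implementation; the real _round2pot is defined
        # elsewhere in threshold.py and may differ.
        return 2.0 ** math.ceil(math.log2(x)) if x > 0 else x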
python/tvm/hago/topology.py (2 additions)

@@ -92,6 +92,8 @@ def set_cond(node, cond):

     if isinstance(node, relay.Call):
         # print(node.op.name)
+        # print(hardware.list_integer_descs(node))
+        # print(hardware.list_float_descs(node))
         if not hardware.list_integer_descs(node):
             # current op does not support integer computation
             set_cond(node, False)
src/relay/qnn/op/dequantize.cc (4 additions, 3 deletions)

@@ -46,9 +46,10 @@ bool DequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   }

   const auto input_dtype = data->dtype;
-  ICHECK(input_dtype == DataType::Int(8) || input_dtype == DataType::UInt(8) ||
-         input_dtype == DataType::Int(32))
-      << "Input type should be one of the quantized types [uint8, int8, int32] but was "
+  CHECK(input_dtype == DataType::Int(8) || input_dtype == DataType::UInt(8) ||
+        input_dtype == DataType::Int(16) || input_dtype == DataType::UInt(16) ||
+        input_dtype == DataType::Int(32))
+      << "Input type should be one of the quantized types [uint8, int8, uint16, int16, int32] but was "
       << input_dtype;

   const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
src/relay/qnn/op/quantize.cc (4 additions, 3 deletions)

@@ -62,9 +62,10 @@ bool QuantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,

   const Array<tvm::PrimExpr> oshape = data->shape;
   const DataType out_dtype = quantize_attrs->out_dtype;
-  ICHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) ||
-         out_dtype == DataType::Int(32))
-      << "Output type should be one of [int8, uint8, int32] but was " << out_dtype;
+  CHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) ||
+        out_dtype == DataType::Int(16) || out_dtype == DataType::UInt(16) ||
+        out_dtype == DataType::Int(32))
+      << "Output type should be one of [int8, uint8, int16, uint16, int32] but was " << out_dtype;
   // assign output type
   reporter->Assign(types[3], TensorType(oshape, out_dtype));
   return true;
src/relay/qnn/op/requantize.cc (8 additions, 6 deletions)

@@ -276,9 +276,10 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     }
   }
   const auto in_dtype = data->dtype;
-  ICHECK(in_dtype == DataType::Int(8) || in_dtype == DataType::UInt(8) ||
-         in_dtype == DataType::Int(32))
-      << "Input type should be one of [int8, uint8, int32] but was " << in_dtype;
+  CHECK(in_dtype == DataType::Int(8) || in_dtype == DataType::UInt(8) ||
+        in_dtype == DataType::Int(16) || in_dtype == DataType::UInt(16) ||
+        in_dtype == DataType::Int(32))
+      << "Input type should be one of [int8, uint8, int16, uint16, int32] but was " << in_dtype;

   const RequantizeAttrs* requantize_attrs = attrs.as<RequantizeAttrs>();
   int axis = requantize_attrs->axis;

@@ -297,9 +298,10 @@
   const Array<tvm::PrimExpr> oshape = data->shape;
   // assign output type
   auto out_dtype = requantize_attrs->out_dtype;
-  ICHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) ||
-         out_dtype == DataType::Int(32))
-      << "Output type should be one of [int8, uint8, int32] but was " << out_dtype;
+  CHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) ||
+        out_dtype == DataType::Int(16) || out_dtype == DataType::UInt(16) ||
+        out_dtype == DataType::Int(32))
+      << "Output type should be one of [int8, uint8, int16, uint16, int32] but was " << out_dtype;
   reporter->Assign(types[5], TensorType(oshape, out_dtype));
   return true;
 }
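Together these relaxed checks let qnn.quantize, qnn.dequantize, and qnn.requantize carry 16-bit tensors, which is what the int16-accumulation search exercises. A small sketch of a graph that would have failed type inference before this change (shapes and scale values are illustrative):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 8), dtype="float32")
    scale = relay.const(0.05, "float32")
    zero_point = relay.const(0, "int32")
    # Quantizing straight to int16 now passes QuantizeRel.
    q = relay.qnn.op.quantize(x, scale, zero_point, out_dtype="int16")
    dq = relay.qnn.op.dequantize(q, scale, zero_point)
    mod = tvm.IRModule.from_expr(relay.Function([x], dq))
    mod = relay.transform.InferType()(mod)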
tests/python/nightly/quantization/common_utils.py (25 additions, 7 deletions)

@@ -67,29 +67,47 @@ def eval_acc(func, dataset, batch_fn, args, var_name, target='cuda', ctx=tvm.gpu(
 #################
 # Quantize helper
 #################
-def quantize_hago(mod, params, calib_dataset, qconfig=None):
+def quantize_hago(mod, params, calib_dataset,
+                  qconfig=None, hardware=None, tuner=None,
+                  target="llvm", ctx=tvm.cpu(), eval_only=False):
     if qconfig is None:
         qconfig = hago.qconfig(log_file='temp.log')
+    if hardware is None:
+        hardware = hago.create_accelerator_description()

     with qconfig:
         graph = hago.prerequisite_optimize(mod['main'], params=params)
         logging.debug('current quantize config')
         logging.debug(hago.current_qconfig())
-        hardware = hago.create_accelerator_description()
         space = hago.generate_search_space(graph, hardware)
-        # tuner = hago.BatchedGreedySearchTuner(space, 'accuracy')
-        tuner = hago.DefaultSetting(space, 'accuracy')
-        ctx = tvm.cpu()
-        strategy, result = hago.search_quantize_strategy(graph, hardware, calib_dataset, tuner, ctx,
-                                                         target='llvm')
+        if tuner is None:
+            tuner = hago.DefaultSetting(space, 'accuracy')
+        elif isinstance(tuner, list):
+            tuner = hago.DefaultSetting(space, 'accuracy', tuner)
+        elif tuner == 'greedy':
+            tuner = hago.GreedySearchTuner(space, "accuracy")
+        elif tuner == 'batched':
+            tuner = hago.BatchedGreedySearchTuner(space, "accuracy")
+
+        if eval_only:
+            record = hago.pick_best(qconfig.log_file, "accuracy")
+            print(record)
+            raise ValueError
+        else:
+            strategy, result = hago.search_quantize_strategy(graph, hardware, calib_dataset, tuner, ctx, target)
         print('strategy')
         print(strategy)

         quantizer = hago.create_quantizer(graph, hardware, strategy)
         simulated_graph = quantizer.simulate()
         quantized_graph = quantizer.quantize()
+        lowered_quantized_graph = relay.qnn.transform.CanonicalizeOps()(tvm.IRModule.from_expr(quantized_graph))
         logging.debug('simulated graph')
         logging.debug(simulated_graph.astext(show_meta_data=False))
         logging.debug('quantized graph')
         logging.debug(quantized_graph.astext(show_meta_data=False))
+        logging.debug('lowered quantized graph')
+        logging.debug(lowered_quantized_graph.astext(show_meta_data=False))
         # hago.inspect_graph_statistic(graph, hardware, strategy, dataset, ctx, target='llvm')
         return tvm.IRModule.from_expr(quantized_graph)
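With the new keyword arguments, callers pick the tuner by name (or pass an explicit bit assignment as a list) and choose the target and device. A hypothetical invocation using the greedy tuner; mod, params, and calib_dataset are assumed to be prepared by the surrounding test:

    # Sketch: quantize a prepared Relay module with the greedy bit search.
    qmod = quantize_hago(mod, params, calib_dataset,
                         tuner='greedy', target='llvm', ctx=tvm.cpu())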