Skip to content

Commit

Permalink
[PROFILER] Theoretical roofline models
Browse files Browse the repository at this point in the history
`tvm.analysis.roofline_analysis` adds estimated roofline performance to a
profiling report. The roofline model measures how close an operator gets
to best possible memory bandwidth or FLOP/s depending on whether it is
memory or compute bound. This computation uses the runtime of the
operator along with two numbers extracted from the TIR code: bytes of
memory touched and number of floating point operations. Because these
numbers are extracted from TIR, they may not be 100% accurate. The best
possible memory bandwidth and FLOP/s are measured by running small
programs that are memory and compute bound respectively.

For now, this function only works with llvm cpu targets, but it should
be possible to extend to GPU targets.
  • Loading branch information
Tristan Konolige committed Apr 19, 2022
1 parent 8f664f5 commit 8a7c2d7
Show file tree
Hide file tree
Showing 7 changed files with 503 additions and 27 deletions.
15 changes: 15 additions & 0 deletions include/tvm/runtime/profiling.h
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,21 @@ class CountNode : public Object {
TVM_DECLARE_FINAL_OBJECT_INFO(CountNode, Object);
};

/* \brief A ratio of two things. */
class RatioNode : public Object {
public:
/* The ratio as a floating point number. */
double ratio;

/* \brief Construct a new ratio.
* \param a The ratio.
*/
explicit RatioNode(double a) : ratio(a) {}

static constexpr const char* _type_key = "runtime.profiling.Ratio";
TVM_DECLARE_FINAL_OBJECT_INFO(RatioNode, Object);
};

/*! \brief String representation of an array of NDArray shapes
* \param shapes Array of NDArrays to get the shapes of.
* \return A textual representation of the shapes. For example: `float32[2], int64[1, 2]`.
Expand Down
296 changes: 296 additions & 0 deletions python/tvm/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,296 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""High level analysis functions"""
import csv
import subprocess
from typing import Dict, Union, Optional
import numpy as np

from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
from ..target import Target
from ..runtime import profiler_vm, profiling, Device, num_threads
from ..script import tir as T


def _create_args(mod, dev, func_name="main"):
args = []
for arg in mod[func_name].params:
args.append(
nd.array(
np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
device=dev,
)
)
return args


def _estimated_features(mod, params, target):
comp = relay.vm.VMCompiler()
mod, params = comp.optimize(mod, params=params, target=target)
return {
prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
for name, prim in mod.functions.items()
if isinstance(prim, tir.PrimFunc)
}


def _vec_width_registers(target, vec_width, num_vector_registers):
if vec_width is None:
if target.device_name == "": # indicates x86
with target:
vec_width = topi.x86.utils.get_simd_32bit_lanes() # in number of float32s
else:
raise RuntimeError(f"Cannot determine vector width for target {target}")
if num_vector_registers is None:
if target.device_name == "": # indicates x86
with target:
num_vector_registers = (
16 # Assuming for all platforms, probably wrong on older ones
)
else:
raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
return vec_width, num_vector_registers


@T.prim_func
def peakflops_fma_tir(
a: T.handle,
vec_width: T.int32,
iters: T.int32,
num_vector_registers: T.int32,
threads: T.int32,
) -> None:
# pylint: disable=invalid-name, missing-function-docstring
N = T.var("int32")
A = T.match_buffer(a, [N], "float32")
assert (
N >= threads * num_vector_registers * vec_width
), "Input vectors must be >= num_vector_registers*vec_width"
for t in T.parallel(threads):
for _j in range(iters):
for l in T.unroll(num_vector_registers):
# We want to use as few registers as possible, so we perform
# all operations on the same element
for k in T.vectorized(vec_width):
A[t * vec_width * num_vector_registers + vec_width * l + k] = (
A[t * vec_width * num_vector_registers + vec_width * l + k]
* A[t * vec_width * num_vector_registers + vec_width * l + k]
+ A[t * vec_width * num_vector_registers + vec_width * l + k]
)


def estimate_peak_fma_flops(
target: Target,
dev: Device,
vec_width: Optional[int] = None,
num_vector_registers: Optional[int] = None,
) -> float:
"""
Estimate the maximum number of FLOP/s this target/device combo is capable
of reaching by running a test program. This assumes vectorized f32 FMA
(fused-multiply-add) instructions.
Parameters
----------
target : Target
Target to run on. This should be as specific to the actual hardware as
possible to make sure that LLVM generates the best vector code.
dev : Device
Device to run on.
vec_width : Optional[int]
Vector width of SIMD units on the underlying hardware. Will try to
infer if no value is provided.
num_vector_registers : Optional[int]
Number of vector registers on the underlying hardware. Will try to
infer if no value is provided.
Returns
-------
float
Approximate sustained FLOP/s of this target/device combo assuming
vectorized f32 FMA instructions.
"""
assert str(target.kind) == "llvm", "Only llvm targets are supported"
vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
iters = 100000
nthreads = num_threads()
specialized = peakflops_fma_tir.specialize(
{
peakflops_fma_tir.params[1]: vec_width,
peakflops_fma_tir.params[2]: iters,
peakflops_fma_tir.params[3]: num_vector_registers,
peakflops_fma_tir.params[4]: nthreads,
}
)
with transform.PassContext(opt_level=3):
f = build(specialized, target=target)
a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
flops = 2 * vec_width * num_vector_registers * nthreads * iters # fma is two flops
flop_s = flops / times.min
return flop_s


@T.prim_func
def peak_bandwidth_tir(a: T.handle, b: T.handle, vec_width: T.int32) -> None:
# pylint: disable=invalid-name, missing-function-docstring
N = T.var("int32")
A = T.match_buffer(a, [N], "float32")
B = T.match_buffer(b, [vec_width * 4], "float32")
# Parallelism is necessary to hit all cores/nodes
for i in T.parallel(N // (vec_width * 4)):
for l in T.unroll(4):
# vectorized load is necessary to hit peak bandwidth
for j in T.vectorized(vec_width):
B[l * vec_width + j] += A[i * vec_width * 4 + l * vec_width + j]


def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
"""Estimate peak memory bandwidth of a target/device combo.
Peak bandwidth is estimated by running a small experiment on the underlying
hardware. The peak bandwidth measurement assumes that vector instructions
are being used to load the data.
Parameters
----------
target : Target
Target to use for measurement. This target should be as specific to the
underlying hardware as possible.
dev : Device
Device to measure peak bandwidth on.
vec_width : Optional[int]
Vector unit width, determined from target if not supplied.
Returns
-------
float
Peak memory bandwidth in bytes/seconds.
"""
# Ideally we'd be able to use this code to measure peak bandwidth of the
# different cache levels. If we could just generate load commands, then we
# could use those in a tight loop. Instead we need some code that is
# limited on the cache bandwidth. With the L1 cache we need an operation
# that has a very low arithmetic intensity and we haven't come up with one
# yet.
vec_width, _ = _vec_width_registers(target, vec_width, 1)
specialized = peak_bandwidth_tir.specialize(
{
peak_bandwidth_tir.params[2]: vec_width,
}
)
with transform.PassContext(opt_level=3):
f = build(specialized, target=target)
size = 10**9 # 4 gigabytes, needs to be larger than last level of cache
a = nd.array(np.ones(size).astype("float32"), device=dev)
b = nd.array(np.ones(vec_width * 4).astype("float32"), device=dev)
times = f.time_evaluator(f.entry_name, dev, repeat=5, number=1)(a, b)
return size * 4 / times.min # 4 bytes per float32


def roofline_analysis(
mod: IRModule, params: Dict[str, nd.NDArray], target: Union[str, Target], dev: Device
) -> profiling.Report:
"""
Create a profiling report that contains roofline and other estimated
statistics from running a module on the VM.
These statistics are calculated by analyzing the lowered TIR of each
operator, so they are estimates of the true values. The statistics are:
- Bound: Is the operator memory or compute bound. This is computed by
assuming that the operator could perfectly cache all loads -- each byte
of memory is only loaded once.
- Percent of Theoretical Optimal: What percent of theoretical optimal for
the bound. i.e. percent of peak memory bandwidth if memory bound,
percent of peak FLOP/s if compute bound.
- Unique Loaded Bytes: estimation of the number of byte loaded not
counting multiple accesses to the same byte.
- Estimated Flops: estimated number of floating point operations.
- Arithmetic Intensity: ratio of FLOPs per byte of data.
- FLOP/s: floating point operations per second.
- Bandwidth: Number of bytes loaded per second.
Parameters
----------
mod : IRModule
Uncompiled input module>
params : Dict[str, nd.NDArray]
target : Union[str, Target]
Target to run on.
dev : Device
Device to run on.
Returns
-------
report : profiling.Report
Profiling report which includes the estimated statistics.
"""
if isinstance(target, str):
target = Target(target)
peak_bandwidth = estimate_peak_bandwidth(target, dev)
peak_flops = estimate_peak_fma_flops(target, dev)

ridge_point = peak_flops / peak_bandwidth

all_features = _estimated_features(mod, params, target)

lib = relay.vm.compile(mod, params=params, target=target)
vmexec = profiler_vm.VirtualMachineProfiler(lib, dev)

args = _create_args(mod, dev)
report = vmexec.profile(*args)
new_calls = []
for call in report.calls:
if "Hash" in call.keys():
_, features = all_features[call["Hash"]]

flops = np.sum(features["float_addsub"] + features["float_mul"] + features["float_mad"])
unique_loaded_bytes = 0.0
# assume no more than 100 buffers
for i in range(100):
# We could uses loaded bytes, but that accounts for for L1 cache.
# If we use unique_bytes, then we are looking at how close we come
# to the performance assuming all data is cached perfectly.
key = f"B{i}.unique_bytes"
if not key in features.keys():
break
unique_loaded_bytes += np.sum(features[key])
runtime = call["Duration (us)"].microseconds * 1e-6
arith_inten = flops / unique_loaded_bytes
call = dict(call)
call["Unique Loaded Bytes"] = profiling.Count(int(unique_loaded_bytes))
call["Estimated FLOPs"] = profiling.Count(int(flops))
call["Arithmetic Intensity"] = profiling.Ratio(arith_inten)
call["FLOP/s"] = profiling.Ratio(flops / runtime)
call["Bandwidth"] = profiling.Ratio(unique_loaded_bytes / runtime)
compute_bound = arith_inten > ridge_point
call["Bound"] = "compute" if compute_bound else "memory"
per_mem_bound = (unique_loaded_bytes / runtime) / peak_bandwidth * 100
per_compute_bound = flops / peak_flops * 100.0
# We use ratio here because the percentages should be averaged instead of summed.
call["Percent of Theoretical Optimal"] = profiling.Ratio(
per_compute_bound if compute_bound else per_mem_bound
)
new_calls.append(call)
else:
new_calls.append(call)
return profiling.Report(new_calls, report.device_metrics)
4 changes: 1 addition & 3 deletions python/tvm/relay/backend/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,7 @@ def lower(self, mod, target=None, target_host=None):
target, target_host, target_is_dict_key=False
)

tophub_context = self._tophub_context(target)
with tophub_context:
self._lower(mod, target, target_host)
self._lower(mod, target, target_host)

def codegen(self):
"""Generate the kernel library."""
Expand Down
47 changes: 47 additions & 0 deletions python/tvm/runtime/profiling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,21 @@ class Report(Object):
Per-device metrics collected over the entire run.
"""

def __init__(
self, calls: Sequence[Dict[str, Object]], device_metrics: Dict[str, Dict[str, Object]]
):
"""Construct a profiling report from a list of metrics and per-device metrics.
Parameters
----------
calls : Sequence[Dict[str, Object]]
Per function call metrics.
device_metrics : Dict[str, Dict[str, Object]]
Per device metrics.
"""
self.__init_handle_by_constructor__(_ffi_api.Report, calls, device_metrics)

def csv(self):
"""Convert this profiling report into CSV format.
Expand Down Expand Up @@ -150,6 +165,38 @@ def from_json(cls, s):
return _ffi_api.FromJSON(s)


@_ffi.register_object("runtime.profiling.Count")
class Count(Object):
"""A integer count of something"""

def __init__(self, count: int):
self.__init_handle_by_constructor__(_ffi_api.Count, count)


@_ffi.register_object("runtime.profiling.Duration")
class Duration(Object):
"""A duration of something"""

def __init__(self, duration: float):
self.__init_handle_by_constructor__(_ffi_api.Duration, duration)


@_ffi.register_object("runtime.profiling.Percent")
class Percent(Object):
"""A Percent of something"""

def __init__(self, percent: float):
self.__init_handle_by_constructor__(_ffi_api.Percent, percent)


@_ffi.register_object("runtime.profiling.Ratio")
class Ratio(Object):
"""A Ratio of two things"""

def __init__(self, ratio: float):
self.__init_handle_by_constructor__(_ffi_api.Ratio, ratio)


@_ffi.register_object("runtime.profiling.MetricCollector")
class MetricCollector(Object):
"""Interface for user defined profiling metric collection."""
Expand Down
Loading

0 comments on commit 8a7c2d7

Please sign in to comment.