Skip to content

Commit

Permalink
Discard timing outliers for faster, more accurate benchmarks
Browse files Browse the repository at this point in the history
Instead of attempting to use JIT stats to restart measurement runs (which never seems to stabilise anyway), and then only sampling for a couple of seconds, discard outliers (more than three standard deviations above the mean), and continue to sample until the 99% confidence interval of the estimated mean is within 1% of it.

Outliers can be discarded retroactively if more than 3 standard deviations above the mean of the next 20 samples, as the worst outliers tend to come at the start of the run, when there are no statistics to go on. Outliers are discarded continuously to ensure the sample error is not biased upwards by them.

Mean and standard deviation are calculated off of exponentially weighted moving averages to increase the ability to spot larger initial outliers without setting the threshold too low and discarding accurate samples. The mean and s.d. used for sample error estimation and display still weights all (non-outlier) samples equally, however.

Benchmarks now typically take around 5s rather than 12s, and include many more samples.
  • Loading branch information
alicederyn committed Apr 2, 2018
1 parent 04f0fd9 commit 595a96e
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 65 deletions.
188 changes: 123 additions & 65 deletions src/test/java/org/alicep/collect/benchmark/BenchmarkRunner.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package org.alicep.collect.benchmark;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.Iterables.getOnlyElement;
import static java.lang.Math.sqrt;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;
import static org.alicep.collect.benchmark.Bytes.bytes;
Expand Down Expand Up @@ -189,13 +191,13 @@ private String title() {

private static class Flavour extends Runner {

private static final double CONFIDENCE_INTERVAL_99_PERCENT = 2.58;
private static Duration MIN_HOT_LOOP_TIME = Duration.ofMillis(50);
private static int MIN_WARMUP_ITERATIONS = 5;
private static Duration MIN_WARMUP_TIME = Duration.ofSeconds(1);
private static Duration MAX_WARMUP_TIME = Duration.ofSeconds(10);
private static int MIN_MEASUREMENT_ITERATIONS = 5;
private static Duration MIN_MEASUREMENT_TIME = Duration.ofSeconds(1);
private static final double TARGET_ERROR = 0.01;

private static final double OUTLIER_EWMAV_WEIGHT = 0.1;
private static final int OUTLIER_WINDOW = 20;
private static final double CONFIDENCE_INTERVAL_99_PERCENT = 2.58;

private final Description description;
private final Supplier<LongUnaryOperator> hotLoopFactory;
Expand Down Expand Up @@ -242,25 +244,18 @@ public void run(RunNotifier notifier) {
System.out.print(config() + ": ");
System.out.flush();

// Start in warmup phase, then move to timing once things have stabilised enough to trust the data.
boolean timing = false;

// Number of times to run the hot loop for
long hotLoopIterations = 1;
// Elapsed time (total time / iterations) for each timed iteration
double[] elapsedTime = new double[(int) (MIN_MEASUREMENT_TIME.toNanos() / MIN_HOT_LOOP_TIME.toNanos()) + 1];

// The time we started warming up
long warmupStartTime = System.nanoTime();

// The time we last saw JIT activity
long noJitStartTime = Long.MAX_VALUE;

// The time we started timing
long timingStartTime = Long.MAX_VALUE;
// Elapsed time (total time / iterations) for each timed iteration
double[] timings = new double[50];

// The time the last hot loop finished
long endTime;
// Elapsed time statistic sources
double tS = 0.0;
double tSS = 0.0;
double id = 0.0;
double ewma = 0.0;
double ewmas = 0.0;

// Memory usage across all timed iterations
long usageBeforeRun = 0;
Expand All @@ -272,65 +267,90 @@ public void run(RunNotifier notifier) {
// The hot loop we are timing
LongUnaryOperator hotLoop = hotLoopFactory.get();

// How many iterations since the last restart
int iterations = 0;
// How many timing samples we've taken
int timingSamples = 0;

// How many memory samples we've taken
int memorySamples = 0;

do {
if (iterations == 0) {
if (memorySamples == 0) {
memoryAllocationMonitor.prepareForBenchmark();
usageBeforeRun = memoryAllocationMonitor.memoryUsed();
monitor.start();
}

if (iterations == 0) {
if (timing) {
timingStartTime = System.nanoTime();
} else {
noJitStartTime = System.nanoTime();
}
if (timingSamples == 0) {
tS = 0.0;
tSS = 0.0;
}
long elapsed = hotLoop.applyAsLong(hotLoopIterations);
endTime = System.nanoTime();

long elapsed = hotLoop.applyAsLong(hotLoopIterations);
if (elapsed < MIN_HOT_LOOP_TIME.toNanos()) {
// Restart if the hot loop did not take enough time running
hotLoopIterations = hotLoopIterations + (hotLoopIterations >> 1) + 1;
iterations = -1;
} else if (timing) {
timingSamples = -1;
memorySamples = -1;
} else {
// Record elapsed time if we're in the timing loop
elapsedTime[iterations] = (double) elapsed / hotLoopIterations;
if (timings.length == timingSamples) {
timings = Arrays.copyOf(timings, timings.length * 2);
}
double iterationTime = (double) elapsed / hotLoopIterations;

// Break out of the loop if we've run enough iterations over enough time
boolean runMinIterations = iterations >= MIN_MEASUREMENT_ITERATIONS;
boolean runMinTime = (endTime - timingStartTime) > MIN_MEASUREMENT_TIME.toNanos();
if (runMinIterations && runMinTime) {
monitor.stop();
usageAfterRun = memoryAllocationMonitor.memoryUsed();
iterations++;
break;
if (timingSamples >= OUTLIER_WINDOW) {
if (isOutlier(iterationTime, ewma / id, ewmas / id, timingSamples)) {
continue;
}
}
} else {
boolean runMaxTime = (endTime - warmupStartTime) > MAX_WARMUP_TIME.toNanos();

// Restart if we saw the JIT trigger, and we're within the maximum warmup time
if (!runMaxTime && monitor.jitMetricChanged()) {
iterations = -1;
timings[timingSamples] = iterationTime;
tS += iterationTime;
tSS += iterationTime * iterationTime;
id = updateEwmav(id, 1.0);
ewma = updateEwmav(ewma, iterationTime);
ewmas = updateEwmav(ewmas, iterationTime * iterationTime);

// Remove old outliers
// We do this as we run so that the sample error calculations do not include erroneous data
if (timingSamples >= OUTLIER_WINDOW) {
int firstIndex = timingSamples - OUTLIER_WINDOW;
for (int index = timingSamples - OUTLIER_WINDOW; index >= firstIndex; index--) {
if (isOutlier(timings, index, timingSamples - index)) {
double value = timings[index];
tS -= value;
tSS -= value * value;
for (int i = index + 1; i < timingSamples; ++i) {
timings[i - 1] = timings[i];
}
timings[timingSamples] = 0.0;
firstIndex = Math.max(index - OUTLIER_WINDOW, 0);
timingSamples--;
}
}
}

// Start timing if we've run enough iterations over enough time since the last JIT.
boolean runMinIterations = iterations >= MIN_WARMUP_ITERATIONS;
boolean runMinTime = (endTime - noJitStartTime) > MIN_WARMUP_TIME.toNanos();
if (runMinIterations && runMinTime) {
timing = true;
iterations = -1;
// Calculate ongoing sample error
double sampleError = sqrt((tSS - tS*tS/(timingSamples + 1)) / ((timingSamples + 1) * timingSamples));
double confidenceInterval = sampleError * CONFIDENCE_INTERVAL_99_PERCENT;
boolean lowSampleError = confidenceInterval * timingSamples < tS * TARGET_ERROR;

// Break out of the loop if we're confident our error is low
boolean enoughSamples = timingSamples >= MIN_MEASUREMENT_ITERATIONS;
if (enoughSamples && lowSampleError) {
monitor.stop();
usageAfterRun = memoryAllocationMonitor.memoryUsed();
timingSamples++;
break;
}
}

iterations++;
timingSamples++;
memorySamples++;
} while (true);

double usagePerLoop = (double) (usageAfterRun - usageBeforeRun) / iterations / hotLoopIterations;
summarize(elapsedTime, iterations, usagePerLoop, monitor);
double memoryUsage = (double)
(usageAfterRun - usageBeforeRun - memoryAllocationMonitor.approximateBaselineError()) / hotLoopIterations;
summarize(tS, tSS, timingSamples, memoryUsage, memorySamples, monitor);
notifier.fireTestFinished(description);
} catch (Throwable t) {
System.out.print(t.getClass().getSimpleName());
Expand All @@ -343,22 +363,60 @@ public void run(RunNotifier notifier) {
}
}

private static void summarize(double[] elapsedTime, int iterations, double memoryUsage, ManagementMonitor monitor) {
String timeSummary = summarizeTime(elapsedTime, iterations);
String memorySummary = bytes((long) memoryUsage).toString();
private static boolean isOutlier(double[] timings, int index, int samples) {
double id = 0.0;
double ewma = 0.0;
double ewmas = 0.0;
for (int i = 0; i < samples; ++i) {
double value = timings[index + i + 1];
checkState(value > 0.0);
id = updateEwmav(id, 1.0);
ewma = updateEwmav(ewma, value);
ewmas = updateEwmav(ewmas, value * value);
}
return isOutlier(timings[index], ewma / id, ewmas / id, samples);
}

/**
* Discard samples more than 3 standard deviations above the mean of the subsequent readings.
*
* <p>About 99.7% of points lie within this range, so this should not be biasing results too
* significantly downwards.
*/
private static boolean isOutlier(double value, double ewma, double ewmas, int samples) {
double ewmasd = sqrt((ewmas - ewma*ewma) * samples/(samples - 1));
return value > ewma + ewmasd * 3;
}

private static double updateEwmav(double mav, double value) {
return OUTLIER_EWMAV_WEIGHT * value + (1 - OUTLIER_EWMAV_WEIGHT) * mav;
}

private static void summarize(
double tS,
double tSS,
int iterations,
double memoryUsage,
int memorySamples,
ManagementMonitor monitor) {
String timeSummary = summarizeTime(tS, tSS, iterations);
String memorySummary = summarizeMemory(memoryUsage, memorySamples);
System.out.println(timeSummary + ", " + memorySummary);
monitor.printIfChanged(System.out);
}

private static String summarizeTime(double[] elapsedTime, int iterations) {
double total = Arrays.stream(elapsedTime).limit(iterations).sum();
private static String summarizeTime(double tS, double tSS, int iterations) {
double total = tS;
double mean = total / iterations;
double totalVariance = Arrays.stream(elapsedTime).limit(iterations).map(a -> a*a - mean*mean).sum();
double sampleError = Math.sqrt(totalVariance / (iterations - 1)) * CONFIDENCE_INTERVAL_99_PERCENT;
String timeSummary = formatNanos(mean) + " (±" + formatNanos(sampleError) + ")";
double sd = sqrt((tSS - tS*tS/iterations) / (iterations - 1));
String timeSummary = formatNanos(mean) + " (±" + formatNanos(sd * CONFIDENCE_INTERVAL_99_PERCENT) + ")";
return timeSummary;
}

private static String summarizeMemory(double memoryUsage, int memorySamples) {
return bytes((long) memoryUsage/memorySamples).toString();
}

public Object config() {
return configuration;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.lang.management.GarbageCollectorMXBean;
import java.lang.management.MemoryPoolMXBean;
import java.lang.management.MemoryUsage;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
Expand Down Expand Up @@ -57,6 +58,8 @@ public static MemoryAllocationMonitor get() {

public abstract long memoryUsed();

public abstract long approximateBaselineError();

private static MemoryAllocationMonitor create() {
List<MemoryPoolMXBean> pools = parallelSweepPools();
if (pools.isEmpty()) {
Expand Down Expand Up @@ -176,6 +179,7 @@ private static long psUsage(Map<String, MemoryUsage> usage) {
@Nullable private final MemoryPoolMXBean survivorSpace;
private final SweepCount sweeps = new SweepCount();
private volatile long reclaimed;
private long baselineError = Long.MAX_VALUE;

public ParallelSweepMemoryAllocationMonitor() {
survivorSpace = getMemoryPoolBean("PS Survivor Space").orElse(null);
Expand Down Expand Up @@ -203,6 +207,33 @@ public long memoryUsed() {
return reclaimed;
}

/**
* Determine the common baseline error in this monitor. Subtracting this from any given sample
* will remove some of the bias from the result.
*
* <p>Sometimes the error is higher; frequently it is lower. Sample until the first quartile
* and the median agree.
*/
@Override
public long approximateBaselineError() {
if (baselineError == Long.MAX_VALUE) {
prepareForBenchmark();
long[] samples = new long[7];
int start = 0;
do {
long lastMemoryUsed = memoryUsed();
for (int i = start; i < samples.length; ++i) {
long memoryUsed = memoryUsed();
samples[i] = memoryUsed - lastMemoryUsed;
lastMemoryUsed = memoryUsed;
}
Arrays.sort(samples);
} while (samples[samples.length / 4] != samples[samples.length / 2]);
baselineError = samples[samples.length / 2];
}
return baselineError;
}

private void sweepAndAwait() {
long sweepsBeforeGc = sweeps.currentSweeps();
System.gc();
Expand Down Expand Up @@ -241,5 +272,10 @@ public void prepareForBenchmark() {
public long memoryUsed() {
return -1;
}

@Override
public long approximateBaselineError() {
return -1;
}
}
}

0 comments on commit 595a96e

Please sign in to comment.