Skip to content

Commit

Permalink
[wasm] Implement MINT_SWITCH opcode in jiterpreter (#107423)
Browse files Browse the repository at this point in the history
* Implement MINT_SWITCH opcode (without support for backward jumps)
* Introduce runtime option for max switch size (set to 0 to disable switches)
* Disable trace generation once the trace table fills up, since there's no point to it
  • Loading branch information
kg committed Sep 7, 2024
1 parent a349912 commit 5c4686f
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 46 deletions.
8 changes: 6 additions & 2 deletions src/mono/browser/runtime/jiterpreter-enums.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ export const enum JiterpCounter {
BackBranchesNotEmitted,
ElapsedGenerationMs,
ElapsedCompilationMs,
SwitchTargetsOk,
SwitchTargetsFailed,
}

// keep in sync with jiterpreter.c, see mono_jiterp_get_member_offset
Expand Down Expand Up @@ -127,7 +129,8 @@ export const enum BailoutReason {
Icall,
UnexpectedRetIp,
LeaveCheck,
Switch,
SwitchSize,
SwitchTarget,
}

export const BailoutReasonNames = [
Expand Down Expand Up @@ -158,7 +161,8 @@ export const BailoutReasonNames = [
"Icall",
"UnexpectedRetIp",
"LeaveCheck",
"Switch",
"SwitchSize",
"SwitchTarget",
];

export const enum JitQueue {
Expand Down
60 changes: 59 additions & 1 deletion src/mono/browser/runtime/jiterpreter-support.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1150,7 +1150,14 @@ type CfgBranch = {
branchType: CfgBranchType;
}

type CfgSegment = CfgBlob | CfgBranchBlockHeader | CfgBranch;
type CfgJumpTable = {
type: "jump-table";
from: MintOpcodePtr;
targets: MintOpcodePtr[];
fallthrough: MintOpcodePtr;
}

type CfgSegment = CfgBlob | CfgBranchBlockHeader | CfgBranch | CfgJumpTable;

export const enum CfgBranchType {
Unconditional,
Expand Down Expand Up @@ -1278,6 +1285,23 @@ class Cfg {
}
}

// It's the caller's responsibility to wrap this in a block and follow it with a bailout!
jumpTable (targets: MintOpcodePtr[], fallthrough: MintOpcodePtr) {
this.appendBlob();
this.segments.push({
type: "jump-table",
from: this.ip,
targets,
fallthrough,
});
// opcode, length, fallthrough (approximate)
this.overheadBytes += 4;
// length of branch depths (approximate)
this.overheadBytes += targets.length;
// bailout for missing targets (approximate)
this.overheadBytes += 24;
}

emitBlob (segment: CfgBlob, source: Uint8Array) {
// mono_log_info(`segment @${(<any>segment.ip).toString(16)} ${segment.start}-${segment.start + segment.length}`);
const view = source.subarray(segment.start, segment.start + segment.length);
Expand Down Expand Up @@ -1415,6 +1439,38 @@ class Cfg {
this.blockStack.shift();
break;
}
case "jump-table": {
// Our caller wrapped us in a block and put a missing target bailout after us
const offset = 1;
// The selector was already loaded onto the wasm stack before cfg.jumpTable was called,
// so we just need to generate a br_table
this.builder.appendU8(WasmOpcode.br_table);
this.builder.appendULeb(segment.targets.length);
for (const target of segment.targets) {
const indexInStack = this.blockStack.indexOf(target);
if (indexInStack >= 0) {
modifyCounter(JiterpCounter.SwitchTargetsOk, 1);
this.builder.appendULeb(indexInStack + offset);
} else {
modifyCounter(JiterpCounter.SwitchTargetsFailed, 1);
if (this.trace > 0)
mono_log_info(`Switch target ${target} not found in block stack ${this.blockStack}`);
this.builder.appendULeb(0);
}
}
const fallthroughIndex = this.blockStack.indexOf(segment.fallthrough);
if (fallthroughIndex >= 0) {
modifyCounter(JiterpCounter.SwitchTargetsOk, 1);
this.builder.appendULeb(fallthroughIndex + offset);
} else {
modifyCounter(JiterpCounter.SwitchTargetsFailed, 1);
if (this.trace > 0)
mono_log_info(`Switch fallthrough ${segment.fallthrough} not found in block stack ${this.blockStack}`);
this.builder.appendULeb(0);
}
this.builder.appendU8(WasmOpcode.unreachable);
break;
}
case "branch": {
const lookupTarget = segment.isBackward ? dispatchIp : segment.target;
let indexInStack = this.blockStack.indexOf(lookupTarget),
Expand Down Expand Up @@ -1965,6 +2021,7 @@ export type JiterpreterOptions = {
tableSize: number;
aotTableSize: number;
maxModuleSize: number;
maxSwitchSize: number;
}

const optionNames: { [jsName: string]: string } = {
Expand Down Expand Up @@ -2002,6 +2059,7 @@ const optionNames: { [jsName: string]: string } = {
"tableSize": "jiterpreter-table-size",
"aotTableSize": "jiterpreter-aot-table-size",
"maxModuleSize": "jiterpreter-max-module-size",
"maxSwitchSize": "jiterpreter-max-switch-size",
};

let optionsVersion = -1;
Expand Down
151 changes: 112 additions & 39 deletions src/mono/browser/runtime/jiterpreter-trace-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,34 @@ function getOpcodeLengthU16 (ip: MintOpcodePtr, opcode: MintOpcode) {
}
}

function decodeSwitch (ip: MintOpcodePtr) : MintOpcodePtr[] {
mono_assert(getU16(ip) === MintOpcode.MINT_SWITCH, "decodeSwitch called on a non-switch");
const n = getArgU32(ip, 2);
const result = [];
/*
guint32 val = LOCAL_VAR (ip [1], guint32);
guint32 n = READ32 (ip + 2);
ip += 4;
if (val < n) {
ip += 2 * val;
int offset = READ32 (ip);
ip += offset;
} else {
ip += 2 * n;
}
*/
// mono_log_info(`switch[${n}] @${ip}`);
for (let i = 0; i < n; i++) {
const base = <any>ip + 8 + (4 * i),
offset = getU32_unaligned(base),
target = base + (offset * 2);
// mono_log_info(` ${i} -> ${target}`);
result.push(target);
}

return result;
}

// Perform a quick scan through the opcodes potentially in this trace to build a table of
// backwards branch targets, compatible with the layout of the old one that was generated in C.
// We do this here to match the exact way that the jiterp calculates branch targets, since
Expand All @@ -205,47 +233,60 @@ export function generateBackwardBranchTable (
const opcode = <MintOpcode>getU16(ip);
const opLengthU16 = getOpcodeLengthU16(ip, opcode);

// Any opcode with a branch argtype will have a decoded displacement, even if we don't
// implement the opcode. Everything else will return undefined here and be skipped
const displacement = getBranchDisplacement(ip, opcode);
if (typeof (displacement) !== "number") {
ip += <any>(opLengthU16 * 2);
continue;
}

// These checks shouldn't fail unless memory is corrupted or something is wrong with the decoder.
// We don't want to cause decoder bugs to make the application exit, though - graceful degradation.
if (displacement === 0) {
mono_log_info(`opcode @${ip} branch target is self. aborting backbranch table generation`);
break;
}
if (opcode === MintOpcode.MINT_SWITCH) {
// FIXME: Once the cfg supports back-branches in jump tables, uncomment this to
// insert the back-branch targets into the table so they'll actually work
/*
const switchTable = decodeSwitch(ip);
for (const target of switchTable) {
const rtarget16 = (<any>target - <any>startOfBody) / 2;
if (target < ip)
table.push(rtarget16);
}
*/
} else {
// Any opcode with a branch argtype will have a decoded displacement, even if we don't
// implement the opcode. Everything else will return undefined here and be skipped
const displacement = getBranchDisplacement(ip, opcode);
if (typeof (displacement) !== "number") {
ip += <any>(opLengthU16 * 2);
continue;
}

// Only record *backward* branches
// We will filter this down further in the Cfg because it takes note of which branches it sees,
// but it is also beneficial to have a null table (further down) due to seeing no potential
// back branch targets at all, as it allows the Cfg to skip additional code generation entirely
// if it knows there will never be any backwards branches in a given trace
if (displacement < 0) {
const rtarget16 = rip16 + (displacement);
if (rtarget16 < 0) {
mono_log_info(`opcode @${ip}'s displacement of ${displacement} goes before body: ${rtarget16}. aborting backbranch table generation`);
// These checks shouldn't fail unless memory is corrupted or something is wrong with the decoder.
// We don't want to cause decoder bugs to make the application exit, though - graceful degradation.
if (displacement === 0) {
mono_log_info(`opcode @${ip} branch target is self. aborting backbranch table generation`);
break;
}

// If the relative target is before the start of the trace, don't record it.
// The trace will be unable to successfully branch to it so it would just make the table bigger.
if (rtarget16 >= rbase16)
table.push(rtarget16);
}
// Only record *backward* branches
// We will filter this down further in the Cfg because it takes note of which branches it sees,
// but it is also beneficial to have a null table (further down) due to seeing no potential
// back branch targets at all, as it allows the Cfg to skip additional code generation entirely
// if it knows there will never be any backwards branches in a given trace
if (displacement < 0) {
const rtarget16 = rip16 + (displacement);
if (rtarget16 < 0) {
mono_log_info(`opcode @${ip}'s displacement of ${displacement} goes before body: ${rtarget16}. aborting backbranch table generation`);
break;
}

switch (opcode) {
case MintOpcode.MINT_CALL_HANDLER:
case MintOpcode.MINT_CALL_HANDLER_S:
// While this formally isn't a backward branch target, we want to record
// the offset of its following instruction so that the jiterpreter knows
// to generate the necessary dispatch code to enable branching back to it.
table.push(rip16 + opLengthU16);
break;
// If the relative target is before the start of the trace, don't record it.
// The trace will be unable to successfully branch to it so it would just make the table bigger.
if (rtarget16 >= rbase16)
table.push(rtarget16);
}

switch (opcode) {
case MintOpcode.MINT_CALL_HANDLER:
case MintOpcode.MINT_CALL_HANDLER_S:
// While this formally isn't a backward branch target, we want to record
// the offset of its following instruction so that the jiterpreter knows
// to generate the necessary dispatch code to enable branching back to it.
table.push(rip16 + opLengthU16);
break;
}
}

ip += <any>(opLengthU16 * 2);
Expand Down Expand Up @@ -399,7 +440,7 @@ export function generateWasmBody (

switch (opcode) {
case MintOpcode.MINT_SWITCH: {
if (!emit_switch(builder, ip))
if (!emit_switch(builder, ip, exitOpcodeCounter))
ip = abort;
break;
}
Expand Down Expand Up @@ -4036,7 +4077,39 @@ function emit_atomics (
return false;
}

function emit_switch (builder: WasmBuilder, ip: MintOpcodePtr) : boolean {
append_bailout(builder, ip, BailoutReason.Switch);
function emit_switch (builder: WasmBuilder, ip: MintOpcodePtr, exitOpcodeCounter: number) : boolean {
const lengthU16 = getOpcodeLengthU16(ip, MintOpcode.MINT_SWITCH),
table = decodeSwitch(ip);
let failed = false;

if (table.length > builder.options.maxSwitchSize) {
failed = true;
} else {
// Record all the switch's forward branch targets.
// If it contains any back branches they will bailout at runtime.
for (const target of table) {
if (target > ip)
builder.branchTargets.add(target);
}
}

if (failed) {
modifyCounter(JiterpCounter.SwitchTargetsFailed, table.length);
append_bailout(builder, ip, BailoutReason.SwitchSize);
return true;
}

const fallthrough = <any>ip + (lengthU16 * 2);
builder.branchTargets.add(fallthrough);

// Jump table needs a block so it can `br 0` for missing targets
builder.block();
// Load selector
append_ldloc(builder, getArgU16(ip, 1), WasmOpcode.i32_load);
// Dispatch
builder.cfg.jumpTable(table, fallthrough);
// Missing target
builder.endBlock();
append_exit(builder, ip, exitOpcodeCounter, BailoutReason.SwitchTarget);
return true;
}
16 changes: 13 additions & 3 deletions src/mono/browser/runtime/jiterpreter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ export const callTargetCounts: { [method: number]: number } = {};

export let mostRecentTrace: InstrumentedTraceState | undefined;
export let mostRecentOptions: JiterpreterOptions | undefined = undefined;
export let traceTableIsFull = false;

// You can disable an opcode for debugging purposes by adding it to this list,
// instead of aborting the trace it will insert a bailout instead. This means that you will
Expand Down Expand Up @@ -861,6 +862,11 @@ function generate_wasm (
idx = presetFunctionPointer;
} else {
idx = addWasmFunctionPointer(JiterpreterTable.Trace, <any>fn);
if (idx === 0) {
// Failed to add function pointer because trace table is full. Disable future
// trace generation to reduce CPU usage.
traceTableIsFull = true;
}
}
if (trace >= 2)
mono_log_info(`${traceName} -> fn index ${idx}`);
Expand Down Expand Up @@ -984,6 +990,8 @@ export function mono_interp_tier_prepare_jiterpreter (
return JITERPRETER_NOT_JITTED;
else if (mostRecentOptions.wasmBytesLimit <= getCounter(JiterpCounter.BytesGenerated))
return JITERPRETER_NOT_JITTED;
else if (traceTableIsFull)
return JITERPRETER_NOT_JITTED;

let info = traceInfo[index];

Expand Down Expand Up @@ -1078,7 +1086,9 @@ export function jiterpreter_dump_stats (concise?: boolean): void {
traceCandidates = getCounter(JiterpCounter.TraceCandidates),
bytesGenerated = getCounter(JiterpCounter.BytesGenerated),
elapsedGenerationMs = getCounter(JiterpCounter.ElapsedGenerationMs),
elapsedCompilationMs = getCounter(JiterpCounter.ElapsedCompilationMs);
elapsedCompilationMs = getCounter(JiterpCounter.ElapsedCompilationMs),
switchTargetsOk = getCounter(JiterpCounter.SwitchTargetsOk),
switchTargetsFailed = getCounter(JiterpCounter.SwitchTargetsFailed);

const backBranchHitRate = (backBranchesEmitted / (backBranchesEmitted + backBranchesNotEmitted)) * 100,
tracesRejected = cwraps.mono_jiterp_get_rejected_trace_count(),
Expand All @@ -1089,8 +1099,8 @@ export function jiterpreter_dump_stats (concise?: boolean): void {
mostRecentOptions.directJitCalls ? `direct jit calls: ${directJitCallsCompiled} (${(directJitCallsCompiled / jitCallsCompiled * 100).toFixed(1)}%)` : "direct jit calls: off"
) : "";

mono_log_info(`// jitted ${bytesGenerated} bytes; ${tracesCompiled} traces (${(tracesCompiled / traceCandidates * 100).toFixed(1)}%) (${tracesRejected} rejected); ${jitCallsCompiled} jit_calls; ${entryWrappersCompiled} interp_entries`);
mono_log_info(`// cknulls eliminated: ${nullChecksEliminatedText}, fused: ${nullChecksFusedText}; back-branches ${backBranchesEmittedText}; ${directJitCallsText}`);
mono_log_info(`// jitted ${bytesGenerated}b; ${tracesCompiled} traces (${(tracesCompiled / traceCandidates * 100).toFixed(1)}%) (${tracesRejected} rejected); ${jitCallsCompiled} jit_calls; ${entryWrappersCompiled} interp_entries`);
mono_log_info(`// cknulls pruned: ${nullChecksEliminatedText}, fused: ${nullChecksFusedText}; back-brs ${backBranchesEmittedText}; switch tgts ${switchTargetsOk}/${switchTargetsFailed + switchTargetsOk}; ${directJitCallsText}`);
mono_log_info(`// time: ${elapsedGenerationMs | 0}ms generating, ${elapsedCompilationMs | 0}ms compiling wasm.`);
if (concise)
return;
Expand Down
4 changes: 3 additions & 1 deletion src/mono/mono/mini/interp/jiterpreter.c
Original file line number Diff line number Diff line change
Expand Up @@ -1257,7 +1257,9 @@ enum {
JITERP_COUNTER_BACK_BRANCHES_NOT_EMITTED,
JITERP_COUNTER_ELAPSED_GENERATION,
JITERP_COUNTER_ELAPSED_COMPILATION,
JITERP_COUNTER_MAX = JITERP_COUNTER_ELAPSED_COMPILATION
JITERP_COUNTER_SWITCH_TARGETS_OK,
JITERP_COUNTER_SWITCH_TARGETS_FAILED,
JITERP_COUNTER_MAX = JITERP_COUNTER_SWITCH_TARGETS_FAILED
};

#define JITERP_COUNTER_UNIT 100
Expand Down
1 change: 1 addition & 0 deletions src/mono/mono/utils/options-def.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ DEFINE_INT(jiterpreter_table_size, "jiterpreter-table-size", 6 * 1024, "Size of
// FIXME: In the future if we find a way to reduce the number of unique tables we can raise this constant
DEFINE_INT(jiterpreter_aot_table_size, "jiterpreter-aot-table-size", 3 * 1024, "Size of the jiterpreter AOT trampoline function tables")
DEFINE_INT(jiterpreter_max_module_size, "jiterpreter-max-module-size", 4080, "Size limit for jiterpreter generated WASM modules")
DEFINE_INT(jiterpreter_max_switch_size, "jiterpreter-max-switch-size", 24, "Size limit for jiterpreter switch opcodes (0 to disable)")
#endif // HOST_BROWSER

#if defined(TARGET_WASM) || defined(TARGET_IOS) || defined(TARGET_TVOS) || defined (TARGET_MACCAT)
Expand Down

0 comments on commit 5c4686f

Please sign in to comment.