diff --git a/src/mono/browser/runtime/jiterpreter-enums.ts b/src/mono/browser/runtime/jiterpreter-enums.ts index b65315ec4cf9b..3caa16bb61349 100644 --- a/src/mono/browser/runtime/jiterpreter-enums.ts +++ b/src/mono/browser/runtime/jiterpreter-enums.ts @@ -21,6 +21,8 @@ export const enum JiterpCounter { BackBranchesNotEmitted, ElapsedGenerationMs, ElapsedCompilationMs, + SwitchTargetsOk, + SwitchTargetsFailed, } // keep in sync with jiterpreter.c, see mono_jiterp_get_member_offset @@ -127,7 +129,8 @@ export const enum BailoutReason { Icall, UnexpectedRetIp, LeaveCheck, - Switch, + SwitchSize, + SwitchTarget, } export const BailoutReasonNames = [ @@ -158,7 +161,8 @@ export const BailoutReasonNames = [ "Icall", "UnexpectedRetIp", "LeaveCheck", - "Switch", + "SwitchSize", + "SwitchTarget", ]; export const enum JitQueue { diff --git a/src/mono/browser/runtime/jiterpreter-support.ts b/src/mono/browser/runtime/jiterpreter-support.ts index 3ba0a3c6db506..6fe9453662fff 100644 --- a/src/mono/browser/runtime/jiterpreter-support.ts +++ b/src/mono/browser/runtime/jiterpreter-support.ts @@ -1150,7 +1150,14 @@ type CfgBranch = { branchType: CfgBranchType; } -type CfgSegment = CfgBlob | CfgBranchBlockHeader | CfgBranch; +type CfgJumpTable = { + type: "jump-table"; + from: MintOpcodePtr; + targets: MintOpcodePtr[]; + fallthrough: MintOpcodePtr; +} + +type CfgSegment = CfgBlob | CfgBranchBlockHeader | CfgBranch | CfgJumpTable; export const enum CfgBranchType { Unconditional, @@ -1278,6 +1285,23 @@ class Cfg { } } + // It's the caller's responsibility to wrap this in a block and follow it with a bailout! + jumpTable (targets: MintOpcodePtr[], fallthrough: MintOpcodePtr) { + this.appendBlob(); + this.segments.push({ + type: "jump-table", + from: this.ip, + targets, + fallthrough, + }); + // opcode, length, fallthrough (approximate) + this.overheadBytes += 4; + // length of branch depths (approximate) + this.overheadBytes += targets.length; + // bailout for missing targets (approximate) + this.overheadBytes += 24; + } + emitBlob (segment: CfgBlob, source: Uint8Array) { // mono_log_info(`segment @${(segment.ip).toString(16)} ${segment.start}-${segment.start + segment.length}`); const view = source.subarray(segment.start, segment.start + segment.length); @@ -1415,6 +1439,38 @@ class Cfg { this.blockStack.shift(); break; } + case "jump-table": { + // Our caller wrapped us in a block and put a missing target bailout after us + const offset = 1; + // The selector was already loaded onto the wasm stack before cfg.jumpTable was called, + // so we just need to generate a br_table + this.builder.appendU8(WasmOpcode.br_table); + this.builder.appendULeb(segment.targets.length); + for (const target of segment.targets) { + const indexInStack = this.blockStack.indexOf(target); + if (indexInStack >= 0) { + modifyCounter(JiterpCounter.SwitchTargetsOk, 1); + this.builder.appendULeb(indexInStack + offset); + } else { + modifyCounter(JiterpCounter.SwitchTargetsFailed, 1); + if (this.trace > 0) + mono_log_info(`Switch target ${target} not found in block stack ${this.blockStack}`); + this.builder.appendULeb(0); + } + } + const fallthroughIndex = this.blockStack.indexOf(segment.fallthrough); + if (fallthroughIndex >= 0) { + modifyCounter(JiterpCounter.SwitchTargetsOk, 1); + this.builder.appendULeb(fallthroughIndex + offset); + } else { + modifyCounter(JiterpCounter.SwitchTargetsFailed, 1); + if (this.trace > 0) + mono_log_info(`Switch fallthrough ${segment.fallthrough} not found in block stack ${this.blockStack}`); + this.builder.appendULeb(0); + } + this.builder.appendU8(WasmOpcode.unreachable); + break; + } case "branch": { const lookupTarget = segment.isBackward ? dispatchIp : segment.target; let indexInStack = this.blockStack.indexOf(lookupTarget), @@ -1965,6 +2021,7 @@ export type JiterpreterOptions = { tableSize: number; aotTableSize: number; maxModuleSize: number; + maxSwitchSize: number; } const optionNames: { [jsName: string]: string } = { @@ -2002,6 +2059,7 @@ const optionNames: { [jsName: string]: string } = { "tableSize": "jiterpreter-table-size", "aotTableSize": "jiterpreter-aot-table-size", "maxModuleSize": "jiterpreter-max-module-size", + "maxSwitchSize": "jiterpreter-max-switch-size", }; let optionsVersion = -1; diff --git a/src/mono/browser/runtime/jiterpreter-trace-generator.ts b/src/mono/browser/runtime/jiterpreter-trace-generator.ts index da85d7ab02163..b994927d308d8 100644 --- a/src/mono/browser/runtime/jiterpreter-trace-generator.ts +++ b/src/mono/browser/runtime/jiterpreter-trace-generator.ts @@ -184,6 +184,34 @@ function getOpcodeLengthU16 (ip: MintOpcodePtr, opcode: MintOpcode) { } } +function decodeSwitch (ip: MintOpcodePtr) : MintOpcodePtr[] { + mono_assert(getU16(ip) === MintOpcode.MINT_SWITCH, "decodeSwitch called on a non-switch"); + const n = getArgU32(ip, 2); + const result = []; + /* + guint32 val = LOCAL_VAR (ip [1], guint32); + guint32 n = READ32 (ip + 2); + ip += 4; + if (val < n) { + ip += 2 * val; + int offset = READ32 (ip); + ip += offset; + } else { + ip += 2 * n; + } + */ + // mono_log_info(`switch[${n}] @${ip}`); + for (let i = 0; i < n; i++) { + const base = ip + 8 + (4 * i), + offset = getU32_unaligned(base), + target = base + (offset * 2); + // mono_log_info(` ${i} -> ${target}`); + result.push(target); + } + + return result; +} + // Perform a quick scan through the opcodes potentially in this trace to build a table of // backwards branch targets, compatible with the layout of the old one that was generated in C. // We do this here to match the exact way that the jiterp calculates branch targets, since @@ -205,47 +233,60 @@ export function generateBackwardBranchTable ( const opcode = getU16(ip); const opLengthU16 = getOpcodeLengthU16(ip, opcode); - // Any opcode with a branch argtype will have a decoded displacement, even if we don't - // implement the opcode. Everything else will return undefined here and be skipped - const displacement = getBranchDisplacement(ip, opcode); - if (typeof (displacement) !== "number") { - ip += (opLengthU16 * 2); - continue; - } - - // These checks shouldn't fail unless memory is corrupted or something is wrong with the decoder. - // We don't want to cause decoder bugs to make the application exit, though - graceful degradation. - if (displacement === 0) { - mono_log_info(`opcode @${ip} branch target is self. aborting backbranch table generation`); - break; - } + if (opcode === MintOpcode.MINT_SWITCH) { + // FIXME: Once the cfg supports back-branches in jump tables, uncomment this to + // insert the back-branch targets into the table so they'll actually work + /* + const switchTable = decodeSwitch(ip); + for (const target of switchTable) { + const rtarget16 = (target - startOfBody) / 2; + if (target < ip) + table.push(rtarget16); + } + */ + } else { + // Any opcode with a branch argtype will have a decoded displacement, even if we don't + // implement the opcode. Everything else will return undefined here and be skipped + const displacement = getBranchDisplacement(ip, opcode); + if (typeof (displacement) !== "number") { + ip += (opLengthU16 * 2); + continue; + } - // Only record *backward* branches - // We will filter this down further in the Cfg because it takes note of which branches it sees, - // but it is also beneficial to have a null table (further down) due to seeing no potential - // back branch targets at all, as it allows the Cfg to skip additional code generation entirely - // if it knows there will never be any backwards branches in a given trace - if (displacement < 0) { - const rtarget16 = rip16 + (displacement); - if (rtarget16 < 0) { - mono_log_info(`opcode @${ip}'s displacement of ${displacement} goes before body: ${rtarget16}. aborting backbranch table generation`); + // These checks shouldn't fail unless memory is corrupted or something is wrong with the decoder. + // We don't want to cause decoder bugs to make the application exit, though - graceful degradation. + if (displacement === 0) { + mono_log_info(`opcode @${ip} branch target is self. aborting backbranch table generation`); break; } - // If the relative target is before the start of the trace, don't record it. - // The trace will be unable to successfully branch to it so it would just make the table bigger. - if (rtarget16 >= rbase16) - table.push(rtarget16); - } + // Only record *backward* branches + // We will filter this down further in the Cfg because it takes note of which branches it sees, + // but it is also beneficial to have a null table (further down) due to seeing no potential + // back branch targets at all, as it allows the Cfg to skip additional code generation entirely + // if it knows there will never be any backwards branches in a given trace + if (displacement < 0) { + const rtarget16 = rip16 + (displacement); + if (rtarget16 < 0) { + mono_log_info(`opcode @${ip}'s displacement of ${displacement} goes before body: ${rtarget16}. aborting backbranch table generation`); + break; + } - switch (opcode) { - case MintOpcode.MINT_CALL_HANDLER: - case MintOpcode.MINT_CALL_HANDLER_S: - // While this formally isn't a backward branch target, we want to record - // the offset of its following instruction so that the jiterpreter knows - // to generate the necessary dispatch code to enable branching back to it. - table.push(rip16 + opLengthU16); - break; + // If the relative target is before the start of the trace, don't record it. + // The trace will be unable to successfully branch to it so it would just make the table bigger. + if (rtarget16 >= rbase16) + table.push(rtarget16); + } + + switch (opcode) { + case MintOpcode.MINT_CALL_HANDLER: + case MintOpcode.MINT_CALL_HANDLER_S: + // While this formally isn't a backward branch target, we want to record + // the offset of its following instruction so that the jiterpreter knows + // to generate the necessary dispatch code to enable branching back to it. + table.push(rip16 + opLengthU16); + break; + } } ip += (opLengthU16 * 2); @@ -399,7 +440,7 @@ export function generateWasmBody ( switch (opcode) { case MintOpcode.MINT_SWITCH: { - if (!emit_switch(builder, ip)) + if (!emit_switch(builder, ip, exitOpcodeCounter)) ip = abort; break; } @@ -4036,7 +4077,39 @@ function emit_atomics ( return false; } -function emit_switch (builder: WasmBuilder, ip: MintOpcodePtr) : boolean { - append_bailout(builder, ip, BailoutReason.Switch); +function emit_switch (builder: WasmBuilder, ip: MintOpcodePtr, exitOpcodeCounter: number) : boolean { + const lengthU16 = getOpcodeLengthU16(ip, MintOpcode.MINT_SWITCH), + table = decodeSwitch(ip); + let failed = false; + + if (table.length > builder.options.maxSwitchSize) { + failed = true; + } else { + // Record all the switch's forward branch targets. + // If it contains any back branches they will bailout at runtime. + for (const target of table) { + if (target > ip) + builder.branchTargets.add(target); + } + } + + if (failed) { + modifyCounter(JiterpCounter.SwitchTargetsFailed, table.length); + append_bailout(builder, ip, BailoutReason.SwitchSize); + return true; + } + + const fallthrough = ip + (lengthU16 * 2); + builder.branchTargets.add(fallthrough); + + // Jump table needs a block so it can `br 0` for missing targets + builder.block(); + // Load selector + append_ldloc(builder, getArgU16(ip, 1), WasmOpcode.i32_load); + // Dispatch + builder.cfg.jumpTable(table, fallthrough); + // Missing target + builder.endBlock(); + append_exit(builder, ip, exitOpcodeCounter, BailoutReason.SwitchTarget); return true; } diff --git a/src/mono/browser/runtime/jiterpreter.ts b/src/mono/browser/runtime/jiterpreter.ts index 68130b681358a..f665d0cccc7dc 100644 --- a/src/mono/browser/runtime/jiterpreter.ts +++ b/src/mono/browser/runtime/jiterpreter.ts @@ -79,6 +79,7 @@ export const callTargetCounts: { [method: number]: number } = {}; export let mostRecentTrace: InstrumentedTraceState | undefined; export let mostRecentOptions: JiterpreterOptions | undefined = undefined; +export let traceTableIsFull = false; // You can disable an opcode for debugging purposes by adding it to this list, // instead of aborting the trace it will insert a bailout instead. This means that you will @@ -861,6 +862,11 @@ function generate_wasm ( idx = presetFunctionPointer; } else { idx = addWasmFunctionPointer(JiterpreterTable.Trace, fn); + if (idx === 0) { + // Failed to add function pointer because trace table is full. Disable future + // trace generation to reduce CPU usage. + traceTableIsFull = true; + } } if (trace >= 2) mono_log_info(`${traceName} -> fn index ${idx}`); @@ -984,6 +990,8 @@ export function mono_interp_tier_prepare_jiterpreter ( return JITERPRETER_NOT_JITTED; else if (mostRecentOptions.wasmBytesLimit <= getCounter(JiterpCounter.BytesGenerated)) return JITERPRETER_NOT_JITTED; + else if (traceTableIsFull) + return JITERPRETER_NOT_JITTED; let info = traceInfo[index]; @@ -1078,7 +1086,9 @@ export function jiterpreter_dump_stats (concise?: boolean): void { traceCandidates = getCounter(JiterpCounter.TraceCandidates), bytesGenerated = getCounter(JiterpCounter.BytesGenerated), elapsedGenerationMs = getCounter(JiterpCounter.ElapsedGenerationMs), - elapsedCompilationMs = getCounter(JiterpCounter.ElapsedCompilationMs); + elapsedCompilationMs = getCounter(JiterpCounter.ElapsedCompilationMs), + switchTargetsOk = getCounter(JiterpCounter.SwitchTargetsOk), + switchTargetsFailed = getCounter(JiterpCounter.SwitchTargetsFailed); const backBranchHitRate = (backBranchesEmitted / (backBranchesEmitted + backBranchesNotEmitted)) * 100, tracesRejected = cwraps.mono_jiterp_get_rejected_trace_count(), @@ -1089,8 +1099,8 @@ export function jiterpreter_dump_stats (concise?: boolean): void { mostRecentOptions.directJitCalls ? `direct jit calls: ${directJitCallsCompiled} (${(directJitCallsCompiled / jitCallsCompiled * 100).toFixed(1)}%)` : "direct jit calls: off" ) : ""; - mono_log_info(`// jitted ${bytesGenerated} bytes; ${tracesCompiled} traces (${(tracesCompiled / traceCandidates * 100).toFixed(1)}%) (${tracesRejected} rejected); ${jitCallsCompiled} jit_calls; ${entryWrappersCompiled} interp_entries`); - mono_log_info(`// cknulls eliminated: ${nullChecksEliminatedText}, fused: ${nullChecksFusedText}; back-branches ${backBranchesEmittedText}; ${directJitCallsText}`); + mono_log_info(`// jitted ${bytesGenerated}b; ${tracesCompiled} traces (${(tracesCompiled / traceCandidates * 100).toFixed(1)}%) (${tracesRejected} rejected); ${jitCallsCompiled} jit_calls; ${entryWrappersCompiled} interp_entries`); + mono_log_info(`// cknulls pruned: ${nullChecksEliminatedText}, fused: ${nullChecksFusedText}; back-brs ${backBranchesEmittedText}; switch tgts ${switchTargetsOk}/${switchTargetsFailed + switchTargetsOk}; ${directJitCallsText}`); mono_log_info(`// time: ${elapsedGenerationMs | 0}ms generating, ${elapsedCompilationMs | 0}ms compiling wasm.`); if (concise) return; diff --git a/src/mono/mono/mini/interp/jiterpreter.c b/src/mono/mono/mini/interp/jiterpreter.c index fa75f766cec02..362f01a2390b2 100644 --- a/src/mono/mono/mini/interp/jiterpreter.c +++ b/src/mono/mono/mini/interp/jiterpreter.c @@ -1257,7 +1257,9 @@ enum { JITERP_COUNTER_BACK_BRANCHES_NOT_EMITTED, JITERP_COUNTER_ELAPSED_GENERATION, JITERP_COUNTER_ELAPSED_COMPILATION, - JITERP_COUNTER_MAX = JITERP_COUNTER_ELAPSED_COMPILATION + JITERP_COUNTER_SWITCH_TARGETS_OK, + JITERP_COUNTER_SWITCH_TARGETS_FAILED, + JITERP_COUNTER_MAX = JITERP_COUNTER_SWITCH_TARGETS_FAILED }; #define JITERP_COUNTER_UNIT 100 diff --git a/src/mono/mono/utils/options-def.h b/src/mono/mono/utils/options-def.h index 4d3bdddaf7073..ddfc314b6bc7b 100644 --- a/src/mono/mono/utils/options-def.h +++ b/src/mono/mono/utils/options-def.h @@ -177,6 +177,7 @@ DEFINE_INT(jiterpreter_table_size, "jiterpreter-table-size", 6 * 1024, "Size of // FIXME: In the future if we find a way to reduce the number of unique tables we can raise this constant DEFINE_INT(jiterpreter_aot_table_size, "jiterpreter-aot-table-size", 3 * 1024, "Size of the jiterpreter AOT trampoline function tables") DEFINE_INT(jiterpreter_max_module_size, "jiterpreter-max-module-size", 4080, "Size limit for jiterpreter generated WASM modules") +DEFINE_INT(jiterpreter_max_switch_size, "jiterpreter-max-switch-size", 24, "Size limit for jiterpreter switch opcodes (0 to disable)") #endif // HOST_BROWSER #if defined(TARGET_WASM) || defined(TARGET_IOS) || defined(TARGET_TVOS) || defined (TARGET_MACCAT)