From f5221f4d1fae6f3da1bd4d1647e45f40e2b055c4 Mon Sep 17 00:00:00 2001 From: mbaxter Date: Wed, 17 Jul 2024 15:02:26 -0400 Subject: [PATCH] cannon: Add basic types for MTCannon (#11109) * cannon: Rename StepWitness.MemProofs to ProofData * cannon: Add MTState type (in progress) * cannon: Tweak MtState tests to cover more ground * cannon: Add test for MTState.UpdateCurrentThread() * cannon: Use constants for byte size vars, set byte slice capacities * cannon: Add StepsSinceLastContextSwitch field * cannon: Rename witness offset constants * cannon: Rename ThreadContext to ThreadState * cannon: Panic on unimplemented method calls * cannon: Compute thread stack roots lazily * cannon: Push initial thread to left stack --- cannon/cmd/load_elf.go | 2 +- cannon/cmd/run.go | 2 +- cannon/mipsevm/evm_test.go | 10 +- cannon/mipsevm/instrumented.go | 4 +- cannon/mipsevm/patch.go | 32 ++-- cannon/mipsevm/state.go | 44 ++++-- cannon/mipsevm/state_mt.go | 271 ++++++++++++++++++++++++++++++++ cannon/mipsevm/state_mt_test.go | 126 +++++++++++++++ cannon/mipsevm/state_test.go | 6 +- cannon/mipsevm/witness.go | 10 +- 10 files changed, 462 insertions(+), 45 deletions(-) create mode 100644 cannon/mipsevm/state_mt.go create mode 100644 cannon/mipsevm/state_mt_test.go diff --git a/cannon/cmd/load_elf.go b/cannon/cmd/load_elf.go index 7c41fc3c0ef6..f80d48984067 100644 --- a/cannon/cmd/load_elf.go +++ b/cannon/cmd/load_elf.go @@ -46,7 +46,7 @@ func LoadELF(ctx *cli.Context) error { if elfProgram.Machine != elf.EM_MIPS { return fmt.Errorf("ELF is not big-endian MIPS R3000, but got %q", elfProgram.Machine.String()) } - state, err := mipsevm.LoadELF(elfProgram) + state, err := mipsevm.LoadELF(elfProgram, mipsevm.CreateInitialState) if err != nil { return fmt.Errorf("failed to load ELF data into VM state: %w", err) } diff --git a/cannon/cmd/run.go b/cannon/cmd/run.go index df1870d73554..4b88b279e8fb 100644 --- a/cannon/cmd/run.go +++ b/cannon/cmd/run.go @@ -446,7 +446,7 @@ func Run(ctx *cli.Context) error { Pre: witness.StateHash, Post: postStateHash, StateData: witness.State, - ProofData: witness.MemProof, + ProofData: witness.ProofData, } if witness.HasPreimage() { proof.OracleKey = witness.PreimageKey[:] diff --git a/cannon/mipsevm/evm_test.go b/cannon/mipsevm/evm_test.go index cb6bfdac39da..22f21d0ce137 100644 --- a/cannon/mipsevm/evm_test.go +++ b/cannon/mipsevm/evm_test.go @@ -113,7 +113,7 @@ func (m *MIPSEVM) Step(t *testing.T, stepWitness *StepWitness, step uint64) []by } func encodeStepInput(t *testing.T, wit *StepWitness, localContext LocalContext, mips *foundry.Artifact) []byte { - input, err := mips.ABI.Pack("step", wit.State, wit.MemProof, localContext) + input, err := mips.ABI.Pack("step", wit.State, wit.ProofData, localContext) require.NoError(t, err) return input } @@ -485,8 +485,8 @@ func TestEVMFault(t *testing.T) { insnProof := initialState.Memory.MerkleProof(0) encodedWitness, _ := initialState.EncodeWitness() stepWitness := &StepWitness{ - State: encodedWitness, - MemProof: insnProof[:], + State: encodedWitness, + ProofData: insnProof[:], } input := encodeStepInput(t, stepWitness, LocalContext{}, contracts.MIPS) startingGas := uint64(30_000_000) @@ -509,7 +509,7 @@ func TestHelloEVM(t *testing.T) { elfProgram, err := elf.Open("../example/bin/hello.elf") require.NoError(t, err, "open ELF file") - state, err := LoadELF(elfProgram) + state, err := LoadELF(elfProgram, CreateInitialState) require.NoError(t, err, "load ELF into state") err = PatchGo(elfProgram, state) @@ -560,7 +560,7 @@ func TestClaimEVM(t *testing.T) { elfProgram, err := elf.Open("../example/bin/claim.elf") require.NoError(t, err, "open ELF file") - state, err := LoadELF(elfProgram) + state, err := LoadELF(elfProgram, CreateInitialState) require.NoError(t, err, "load ELF into state") err = PatchGo(elfProgram, state) diff --git a/cannon/mipsevm/instrumented.go b/cannon/mipsevm/instrumented.go index fad3e541c1a1..cb6520e719d6 100644 --- a/cannon/mipsevm/instrumented.go +++ b/cannon/mipsevm/instrumented.go @@ -83,7 +83,7 @@ func (m *InstrumentedState) Step(proof bool) (wit *StepWitness, err error) { wit = &StepWitness{ State: encodedWitness, StateHash: stateHash, - MemProof: insnProof[:], + ProofData: insnProof[:], } } err = m.mipsStep() @@ -92,7 +92,7 @@ func (m *InstrumentedState) Step(proof bool) (wit *StepWitness, err error) { } if proof { - wit.MemProof = append(wit.MemProof, m.memProof[:]...) + wit.ProofData = append(wit.ProofData, m.memProof[:]...) if m.lastPreimageOffset != ^uint32(0) { wit.PreimageOffset = m.lastPreimageOffset wit.PreimageKey = m.lastPreimageKey diff --git a/cannon/mipsevm/patch.go b/cannon/mipsevm/patch.go index 47abb41e0915..26ce8337d86a 100644 --- a/cannon/mipsevm/patch.go +++ b/cannon/mipsevm/patch.go @@ -10,21 +10,11 @@ import ( const HEAP_START = 0x05000000 -func LoadELF(f *elf.File) (*State, error) { - s := &State{ - Cpu: CpuScalars{ - PC: uint32(f.Entry), - NextPC: uint32(f.Entry + 4), - LO: 0, - HI: 0, - }, - Heap: HEAP_START, - Registers: [32]uint32{}, - Memory: NewMemory(), - ExitCode: 0, - Exited: false, - Step: 0, - } +type CreateFPVMState[T FPVMState] func(pc, heapStart uint32) T + +func LoadELF[T FPVMState](f *elf.File, initState CreateFPVMState[T]) (T, error) { + var empty T + s := initState(uint32(f.Entry), HEAP_START) for i, prog := range f.Progs { if prog.Type == 0x70000003 { // MIPS_ABIFLAGS @@ -37,21 +27,21 @@ func LoadELF(f *elf.File) (*State, error) { if prog.Filesz < prog.Memsz { r = io.MultiReader(r, bytes.NewReader(make([]byte, prog.Memsz-prog.Filesz))) } else { - return nil, fmt.Errorf("invalid PT_LOAD program segment %d, file size (%d) > mem size (%d)", i, prog.Filesz, prog.Memsz) + return empty, fmt.Errorf("invalid PT_LOAD program segment %d, file size (%d) > mem size (%d)", i, prog.Filesz, prog.Memsz) } } else { - return nil, fmt.Errorf("program segment %d has different file size (%d) than mem size (%d): filling for non PT_LOAD segments is not supported", i, prog.Filesz, prog.Memsz) + return empty, fmt.Errorf("program segment %d has different file size (%d) than mem size (%d): filling for non PT_LOAD segments is not supported", i, prog.Filesz, prog.Memsz) } } if prog.Vaddr+prog.Memsz >= uint64(1<<32) { - return nil, fmt.Errorf("program %d out of 32-bit mem range: %x - %x (size: %x)", i, prog.Vaddr, prog.Vaddr+prog.Memsz, prog.Memsz) + return empty, fmt.Errorf("program %d out of 32-bit mem range: %x - %x (size: %x)", i, prog.Vaddr, prog.Vaddr+prog.Memsz, prog.Memsz) } if prog.Vaddr+prog.Memsz >= HEAP_START { - return nil, fmt.Errorf("program %d overlaps with heap: %x - %x (size: %x). The heap start offset must be reconfigured", i, prog.Vaddr, prog.Vaddr+prog.Memsz, prog.Memsz) + return empty, fmt.Errorf("program %d overlaps with heap: %x - %x (size: %x). The heap start offset must be reconfigured", i, prog.Vaddr, prog.Vaddr+prog.Memsz, prog.Memsz) } - if err := s.Memory.SetMemoryRange(uint32(prog.Vaddr), r); err != nil { - return nil, fmt.Errorf("failed to read program segment %d: %w", i, err) + if err := s.GetMemory().SetMemoryRange(uint32(prog.Vaddr), r); err != nil { + return empty, fmt.Errorf("failed to read program segment %d: %w", i, err) } } diff --git a/cannon/mipsevm/state.go b/cannon/mipsevm/state.go index 474f80021969..0679e4d056d5 100644 --- a/cannon/mipsevm/state.go +++ b/cannon/mipsevm/state.go @@ -10,8 +10,8 @@ import ( "github.com/ethereum/go-ethereum/crypto" ) -// StateWitnessSize is the size of the state witness encoding in bytes. -var StateWitnessSize = 226 +// STATE_WITNESS_SIZE is the size of the state witness encoding in bytes. +const STATE_WITNESS_SIZE = 226 type CpuScalars struct { PC uint32 `json:"pc"` @@ -48,6 +48,32 @@ type State struct { LastHint hexutil.Bytes `json:"lastHint,omitempty"` } +func CreateEmptyState() *State { + return &State{ + Cpu: CpuScalars{ + PC: 0, + NextPC: 0, + LO: 0, + HI: 0, + }, + Heap: 0, + Registers: [32]uint32{}, + Memory: NewMemory(), + ExitCode: 0, + Exited: false, + Step: 0, + } +} + +func CreateInitialState(pc, heapStart uint32) *State { + state := CreateEmptyState() + state.Cpu.PC = pc + state.Cpu.NextPC = pc + 4 + state.Heap = heapStart + + return state +} + type stateMarshaling struct { Memory *Memory `json:"memory"` PreimageKey common.Hash `json:"preimageKey"` @@ -121,7 +147,7 @@ func (s *State) GetMemory() *Memory { } func (s *State) EncodeWitness() ([]byte, common.Hash) { - out := make([]byte, 0) + out := make([]byte, 0, STATE_WITNESS_SIZE) memRoot := s.Memory.MerkleRoot() out = append(out, memRoot[:]...) out = append(out, s.PreimageKey[:]...) @@ -132,11 +158,7 @@ func (s *State) EncodeWitness() ([]byte, common.Hash) { out = binary.BigEndian.AppendUint32(out, s.Cpu.HI) out = binary.BigEndian.AppendUint32(out, s.Heap) out = append(out, s.ExitCode) - if s.Exited { - out = append(out, 1) - } else { - out = append(out, 0) - } + out = AppendBoolToWitness(out, s.Exited) out = binary.BigEndian.AppendUint64(out, s.Step) for _, r := range s.Registers { out = binary.BigEndian.AppendUint32(out, r) @@ -154,14 +176,14 @@ const ( ) func (sw StateWitness) StateHash() (common.Hash, error) { - if len(sw) != 226 { - return common.Hash{}, fmt.Errorf("Invalid witness length. Got %d, expected 226", len(sw)) + if len(sw) != STATE_WITNESS_SIZE { + return common.Hash{}, fmt.Errorf("Invalid witness length. Got %d, expected %d", len(sw), STATE_WITNESS_SIZE) } return stateHashFromWitness(sw), nil } func stateHashFromWitness(sw []byte) common.Hash { - if len(sw) != 226 { + if len(sw) != STATE_WITNESS_SIZE { panic("Invalid witness length") } hash := crypto.Keccak256Hash(sw) diff --git a/cannon/mipsevm/state_mt.go b/cannon/mipsevm/state_mt.go new file mode 100644 index 000000000000..ed02f3261038 --- /dev/null +++ b/cannon/mipsevm/state_mt.go @@ -0,0 +1,271 @@ +package mipsevm + +import ( + "encoding/binary" + "fmt" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/crypto" +) + +// SERIALIZED_THREAD_SIZE is the size of a serialized ThreadState object +const SERIALIZED_THREAD_SIZE = 166 + +// THREAD_WITNESS_SIZE is the size of a thread witness encoded in bytes. +// +// It consists of the active thread serialized and concatenated with the +// 32 byte hash onion of the active thread stack without the active thread +const THREAD_WITNESS_SIZE = SERIALIZED_THREAD_SIZE + 32 + +// The empty thread root - keccak256(bytes32(0) ++ bytes32(0)) +var EmptyThreadsRoot common.Hash = common.HexToHash("0xad3228b676f7d3cd4284a5443f17f1962b36e491b30a40b2405849e597ba5fb5") + +type ThreadState struct { + ThreadId uint32 `json:"threadId"` + ExitCode uint8 `json:"exit"` + Exited bool `json:"exited"` + FutexAddr uint32 `json:"futexAddr"` + FutexVal uint32 `json:"futexVal"` + FutexTimeoutStep uint64 `json:"futexTimeoutStep"` + Cpu CpuScalars `json:"cpu"` + Registers [32]uint32 `json:"registers"` +} + +func (t *ThreadState) serializeThread() []byte { + out := make([]byte, 0, SERIALIZED_THREAD_SIZE) + + out = binary.BigEndian.AppendUint32(out, t.ThreadId) + out = append(out, t.ExitCode) + out = AppendBoolToWitness(out, t.Exited) + out = binary.BigEndian.AppendUint32(out, t.FutexAddr) + out = binary.BigEndian.AppendUint32(out, t.FutexVal) + out = binary.BigEndian.AppendUint64(out, t.FutexTimeoutStep) + + out = binary.BigEndian.AppendUint32(out, t.Cpu.PC) + out = binary.BigEndian.AppendUint32(out, t.Cpu.NextPC) + out = binary.BigEndian.AppendUint32(out, t.Cpu.LO) + out = binary.BigEndian.AppendUint32(out, t.Cpu.HI) + + for _, r := range t.Registers { + out = binary.BigEndian.AppendUint32(out, r) + } + + return out +} + +func computeThreadRoot(prevStackRoot common.Hash, threadToPush *ThreadState) common.Hash { + hashedThread := crypto.Keccak256Hash(threadToPush.serializeThread()) + + var hashData []byte + hashData = append(hashData, prevStackRoot[:]...) + hashData = append(hashData, hashedThread[:]...) + + return crypto.Keccak256Hash(hashData) +} + +// MT_STATE_WITNESS_SIZE is the size of the state witness encoding in bytes. +const MT_STATE_WITNESS_SIZE = 163 +const ( + MEMROOT_MT_WITNESS_OFFSET = 0 + PREIMAGE_KEY_MT_WITNESS_OFFSET = MEMROOT_MT_WITNESS_OFFSET + 32 + PREIMAGE_OFFSET_MT_WITNESS_OFFSET = PREIMAGE_KEY_MT_WITNESS_OFFSET + 32 + HEAP_MT_WITNESS_OFFSET = PREIMAGE_OFFSET_MT_WITNESS_OFFSET + 4 + EXITCODE_MT_WITNESS_OFFSET = HEAP_MT_WITNESS_OFFSET + 4 + EXITED_MT_WITNESS_OFFSET = EXITCODE_MT_WITNESS_OFFSET + 1 + STEP_MT_WITNESS_OFFSET = EXITED_MT_WITNESS_OFFSET + 1 + STEPS_SINCE_CONTEXT_SWITCH_MT_WITNESS_OFFSET = STEP_MT_WITNESS_OFFSET + 8 + WAKEUP_MT_WITNESS_OFFSET = STEPS_SINCE_CONTEXT_SWITCH_MT_WITNESS_OFFSET + 8 + TRAVERSE_RIGHT_MT_WITNESS_OFFSET = WAKEUP_MT_WITNESS_OFFSET + 4 + LEFT_THREADS_ROOT_MT_WITNESS_OFFSET = TRAVERSE_RIGHT_MT_WITNESS_OFFSET + 1 + RIGHT_THREADS_ROOT_MT_WITNESS_OFFSET = LEFT_THREADS_ROOT_MT_WITNESS_OFFSET + 32 + THREAD_ID_MT_WITNESS_OFFSET = RIGHT_THREADS_ROOT_MT_WITNESS_OFFSET + 32 +) + +type MTState struct { + Memory *Memory `json:"memory"` + + PreimageKey common.Hash `json:"preimageKey"` + PreimageOffset uint32 `json:"preimageOffset"` // note that the offset includes the 8-byte length prefix + + Heap uint32 `json:"heap"` // to handle mmap growth + + ExitCode uint8 `json:"exit"` + Exited bool `json:"exited"` + + Step uint64 `json:"step"` + StepsSinceLastContextSwitch uint64 `json:"stepsSinceLastContextSwitch"` + Wakeup uint32 `json:"wakeup"` + + TraverseRight bool `json:"traverseRight"` + LeftThreadStack []ThreadState `json:"leftThreadStack"` + RightThreadStack []ThreadState `json:"rightThreadStack"` + NextThreadId uint32 `json:"nextThreadId"` + + // LastHint is optional metadata, and not part of the VM state itself. + // It is used to remember the last pre-image hint, + // so a VM can start from any state without fetching prior pre-images, + // and instead just repeat the last hint on setup, + // to make sure pre-image requests can be served. + // The first 4 bytes are a uin32 length prefix. + // Warning: the hint MAY NOT BE COMPLETE. I.e. this is buffered, + // and should only be read when len(LastHint) > 4 && uint32(LastHint[:4]) <= len(LastHint[4:]) + LastHint hexutil.Bytes `json:"lastHint,omitempty"` +} + +func CreateEmptyMTState() *MTState { + initThreadId := uint32(0) + initThread := ThreadState{ + ThreadId: initThreadId, + ExitCode: 0, + Exited: false, + Cpu: CpuScalars{ + PC: 0, + NextPC: 0, + LO: 0, + HI: 0, + }, + FutexAddr: ^uint32(0), + FutexVal: 0, + FutexTimeoutStep: 0, + Registers: [32]uint32{}, + } + + return &MTState{ + Memory: NewMemory(), + Heap: 0, + ExitCode: 0, + Exited: false, + Step: 0, + Wakeup: ^uint32(0), + TraverseRight: false, + LeftThreadStack: []ThreadState{initThread}, + RightThreadStack: []ThreadState{}, + NextThreadId: initThreadId + 1, + } +} + +func CreateInitialMTState(pc, heapStart uint32) *MTState { + state := CreateEmptyMTState() + currentThread := state.getCurrentThread() + currentThread.Cpu.PC = pc + currentThread.Cpu.NextPC = pc + 4 + state.Heap = heapStart + + return state +} + +func (s *MTState) getCurrentThread() *ThreadState { + activeStack := s.getActiveThreadStack() + + activeStackSize := len(activeStack) + if activeStackSize == 0 { + panic("Active thread stack is empty") + } + + return &activeStack[activeStackSize-1] +} + +type ThreadMutator func(thread *ThreadState) + +func (s *MTState) getActiveThreadStack() []ThreadState { + var activeStack []ThreadState + if s.TraverseRight { + activeStack = s.RightThreadStack + } else { + activeStack = s.LeftThreadStack + } + + return activeStack +} + +func (s *MTState) getRightThreadStackRoot() common.Hash { + return s.calculateThreadStackRoot(s.RightThreadStack) +} + +func (s *MTState) getLeftThreadStackRoot() common.Hash { + return s.calculateThreadStackRoot(s.LeftThreadStack) +} + +func (s *MTState) calculateThreadStackRoot(stack []ThreadState) common.Hash { + curRoot := EmptyThreadsRoot + for _, thread := range stack { + curRoot = computeThreadRoot(curRoot, &thread) + } + + return curRoot +} + +func (s *MTState) PreemptThread() { + // TODO(CP-903) + panic("Not Implemented") +} + +func (s *MTState) PushThread(thread *ThreadState) { + // TODO(CP-903) + panic("Not Implemented") +} + +func (s *MTState) GetPC() uint32 { + activeThread := s.getCurrentThread() + return activeThread.Cpu.PC +} + +func (s *MTState) GetExitCode() uint8 { return s.ExitCode } + +func (s *MTState) GetExited() bool { return s.Exited } + +func (s *MTState) GetStep() uint64 { return s.Step } + +func (s *MTState) VMStatus() uint8 { + return vmStatus(s.Exited, s.ExitCode) +} + +func (s *MTState) GetMemory() *Memory { + return s.Memory +} + +func (s *MTState) EncodeWitness() ([]byte, common.Hash) { + out := make([]byte, 0, MT_STATE_WITNESS_SIZE) + memRoot := s.Memory.MerkleRoot() + out = append(out, memRoot[:]...) + out = append(out, s.PreimageKey[:]...) + out = binary.BigEndian.AppendUint32(out, s.PreimageOffset) + out = binary.BigEndian.AppendUint32(out, s.Heap) + out = append(out, s.ExitCode) + out = AppendBoolToWitness(out, s.Exited) + + out = binary.BigEndian.AppendUint64(out, s.Step) + out = binary.BigEndian.AppendUint64(out, s.StepsSinceLastContextSwitch) + out = binary.BigEndian.AppendUint32(out, s.Wakeup) + + leftStackRoot := s.getLeftThreadStackRoot() + rightStackRoot := s.getRightThreadStackRoot() + out = AppendBoolToWitness(out, s.TraverseRight) + out = append(out, (leftStackRoot)[:]...) + out = append(out, (rightStackRoot)[:]...) + out = binary.BigEndian.AppendUint32(out, s.NextThreadId) + + return out, mtStateHashFromWitness(out) +} + +type MTStateWitness []byte + +func (sw MTStateWitness) StateHash() (common.Hash, error) { + if len(sw) != MT_STATE_WITNESS_SIZE { + return common.Hash{}, fmt.Errorf("Invalid witness length. Got %d, expected %d", len(sw), MT_STATE_WITNESS_SIZE) + } + return mtStateHashFromWitness(sw), nil +} + +func mtStateHashFromWitness(sw []byte) common.Hash { + if len(sw) != MT_STATE_WITNESS_SIZE { + panic("Invalid witness length") + } + hash := crypto.Keccak256Hash(sw) + exitCode := sw[EXITCODE_MT_WITNESS_OFFSET] + exited := sw[EXITED_MT_WITNESS_OFFSET] + status := vmStatus(exited == 1, exitCode) + hash[0] = status + return hash +} diff --git a/cannon/mipsevm/state_mt_test.go b/cannon/mipsevm/state_mt_test.go new file mode 100644 index 000000000000..a89af66300a3 --- /dev/null +++ b/cannon/mipsevm/state_mt_test.go @@ -0,0 +1,126 @@ +package mipsevm + +import ( + "debug/elf" + "encoding/json" + "testing" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/stretchr/testify/require" +) + +func setWitnessField(witness MTStateWitness, fieldOffset int, fieldData []byte) { + start := fieldOffset + end := fieldOffset + len(fieldData) + copy(witness[start:end], fieldData) +} + +// Run through all permutations of `exited` / `exitCode` and ensure that the +// correct witness, state hash, and VM Status is produced. +func TestMTState_EncodeWitness(t *testing.T) { + cases := []struct { + exited bool + exitCode uint8 + }{ + {exited: false, exitCode: 0}, + {exited: false, exitCode: 1}, + {exited: false, exitCode: 2}, + {exited: false, exitCode: 3}, + {exited: true, exitCode: 0}, + {exited: true, exitCode: 1}, + {exited: true, exitCode: 2}, + {exited: true, exitCode: 3}, + } + + heap := uint32(12) + preimageKey := crypto.Keccak256Hash([]byte{1, 2, 3, 4}) + preimageOffset := uint32(24) + step := uint64(33) + stepsSinceContextSwitch := uint64(123) + for _, c := range cases { + state := CreateEmptyMTState() + state.Exited = c.exited + state.ExitCode = c.exitCode + state.PreimageKey = preimageKey + state.PreimageOffset = preimageOffset + state.Heap = heap + state.Step = step + state.StepsSinceLastContextSwitch = stepsSinceContextSwitch + + memRoot := state.Memory.MerkleRoot() + leftStackRoot := state.calculateThreadStackRoot(state.LeftThreadStack) + rightStackRoot := EmptyThreadsRoot + + // Set up expected witness + expectedWitness := make(MTStateWitness, MT_STATE_WITNESS_SIZE) + setWitnessField(expectedWitness, MEMROOT_MT_WITNESS_OFFSET, memRoot[:]) + setWitnessField(expectedWitness, PREIMAGE_KEY_MT_WITNESS_OFFSET, preimageKey[:]) + setWitnessField(expectedWitness, PREIMAGE_OFFSET_MT_WITNESS_OFFSET, []byte{0, 0, 0, byte(preimageOffset)}) + setWitnessField(expectedWitness, HEAP_MT_WITNESS_OFFSET, []byte{0, 0, 0, byte(heap)}) + setWitnessField(expectedWitness, EXITCODE_MT_WITNESS_OFFSET, []byte{c.exitCode}) + if c.exited { + setWitnessField(expectedWitness, EXITED_MT_WITNESS_OFFSET, []byte{1}) + } + setWitnessField(expectedWitness, STEP_MT_WITNESS_OFFSET, []byte{0, 0, 0, 0, 0, 0, 0, byte(step)}) + setWitnessField(expectedWitness, STEPS_SINCE_CONTEXT_SWITCH_MT_WITNESS_OFFSET, []byte{0, 0, 0, 0, 0, 0, 0, byte(stepsSinceContextSwitch)}) + setWitnessField(expectedWitness, WAKEUP_MT_WITNESS_OFFSET, []byte{0xFF, 0xFF, 0xFF, 0xFF}) + setWitnessField(expectedWitness, TRAVERSE_RIGHT_MT_WITNESS_OFFSET, []byte{0}) + setWitnessField(expectedWitness, LEFT_THREADS_ROOT_MT_WITNESS_OFFSET, leftStackRoot[:]) + setWitnessField(expectedWitness, RIGHT_THREADS_ROOT_MT_WITNESS_OFFSET, rightStackRoot[:]) + setWitnessField(expectedWitness, THREAD_ID_MT_WITNESS_OFFSET, []byte{0, 0, 0, 1}) + + // Validate witness + actualWitness, actualStateHash := state.EncodeWitness() + require.Equal(t, len(actualWitness), MT_STATE_WITNESS_SIZE, "Incorrect witness size") + require.EqualValues(t, expectedWitness[:], actualWitness[:], "Incorrect witness") + // Validate witness hash + expectedStateHash := crypto.Keccak256Hash(actualWitness) + expectedStateHash[0] = vmStatus(c.exited, c.exitCode) + require.Equal(t, expectedStateHash, actualStateHash, "Incorrect state hash") + } +} + +func TestMTState_JSONCodec(t *testing.T) { + elfProgram, err := elf.Open("../example/bin/hello.elf") + require.NoError(t, err, "open ELF file") + state, err := LoadELF(elfProgram, CreateInitialMTState) + require.NoError(t, err, "load ELF into state") + // Set a few additional fields + state.PreimageKey = crypto.Keccak256Hash([]byte{1, 2, 3, 4}) + state.PreimageOffset = 4 + state.Heap = 555 + state.Step = 99_999 + state.StepsSinceLastContextSwitch = 123 + state.Exited = true + state.ExitCode = 2 + state.LastHint = []byte{11, 12, 13} + + stateJSON, err := json.Marshal(state) + require.NoError(t, err) + + var newState *MTState + err = json.Unmarshal(stateJSON, &newState) + require.NoError(t, err) + + require.Equal(t, state.PreimageKey, newState.PreimageKey) + require.Equal(t, state.PreimageOffset, newState.PreimageOffset) + require.Equal(t, state.Heap, newState.Heap) + require.Equal(t, state.ExitCode, newState.ExitCode) + require.Equal(t, state.Exited, newState.Exited) + require.Equal(t, state.Memory.MerkleRoot(), newState.Memory.MerkleRoot()) + require.Equal(t, state.Step, newState.Step) + require.Equal(t, state.StepsSinceLastContextSwitch, newState.StepsSinceLastContextSwitch) + require.Equal(t, state.Wakeup, newState.Wakeup) + require.Equal(t, state.TraverseRight, newState.TraverseRight) + require.Equal(t, state.LeftThreadStack, newState.LeftThreadStack) + require.Equal(t, state.RightThreadStack, newState.RightThreadStack) + require.Equal(t, state.NextThreadId, newState.NextThreadId) + require.Equal(t, state.LastHint, newState.LastHint) +} + +func TestMTState_EmptyThreadsRoot(t *testing.T) { + data := [64]byte{} + expectedEmptyRoot := crypto.Keccak256Hash(data[:]) + + require.Equal(t, expectedEmptyRoot, EmptyThreadsRoot) +} diff --git a/cannon/mipsevm/state_test.go b/cannon/mipsevm/state_test.go index fe36267926ca..d84226ebd201 100644 --- a/cannon/mipsevm/state_test.go +++ b/cannon/mipsevm/state_test.go @@ -105,7 +105,7 @@ func TestStateHash(t *testing.T) { } actualWitness, actualStateHash := state.EncodeWitness() - require.Equal(t, len(actualWitness), StateWitnessSize, "Incorrect witness size") + require.Equal(t, len(actualWitness), STATE_WITNESS_SIZE, "Incorrect witness size") expectedWitness := make(StateWitness, 226) memRoot := state.Memory.MerkleRoot() @@ -266,7 +266,7 @@ func loadELFProgram(t *testing.T, name string) *State { elfProgram, err := elf.Open(name) require.NoError(t, err, "open ELF file") - state, err := LoadELF(elfProgram) + state, err := LoadELF(elfProgram, CreateInitialState) require.NoError(t, err, "load ELF into state") err = PatchGo(elfProgram, state) @@ -337,7 +337,7 @@ func selectOracleFixture(t *testing.T, programName string) PreimageOracle { func TestStateJSONCodec(t *testing.T) { elfProgram, err := elf.Open("../example/bin/hello.elf") require.NoError(t, err, "open ELF file") - state, err := LoadELF(elfProgram) + state, err := LoadELF(elfProgram, CreateInitialState) require.NoError(t, err, "load ELF into state") stateJSON, err := state.MarshalJSON() diff --git a/cannon/mipsevm/witness.go b/cannon/mipsevm/witness.go index ef75db69c65c..bbe1241fc003 100644 --- a/cannon/mipsevm/witness.go +++ b/cannon/mipsevm/witness.go @@ -9,7 +9,7 @@ type StepWitness struct { State []byte StateHash common.Hash - MemProof []byte + ProofData []byte PreimageKey [32]byte // zeroed when no pre-image is accessed PreimageValue []byte // including the 8-byte length prefix @@ -19,3 +19,11 @@ type StepWitness struct { func (wit *StepWitness) HasPreimage() bool { return wit.PreimageKey != ([32]byte{}) } + +func AppendBoolToWitness(witnessData []byte, boolVal bool) []byte { + if boolVal { + return append(witnessData, 1) + } else { + return append(witnessData, 0) + } +}