Skip to content

Commit

Permalink
refactor(ccc): allow go side to construct a rust trace from json (#807)
Browse files Browse the repository at this point in the history
  • Loading branch information
omerfirmak committed Jun 7, 2024
1 parent dc57ae2 commit 8bd2a16
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 20 deletions.
2 changes: 1 addition & 1 deletion params/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
const (
VersionMajor = 5 // Major version component of the current release
VersionMinor = 3 // Minor version component of the current release
VersionPatch = 39 // Patch version component of the current release
VersionPatch = 40 // Patch version component of the current release
VersionMeta = "mainnet" // Version metadata to append to the version string
)

Expand Down
27 changes: 24 additions & 3 deletions rollup/circuitcapacitychecker/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,19 @@ import (
"encoding/json"
"fmt"
"sync"
"time"
"unsafe"

"github.com/scroll-tech/go-ethereum/core/types"
"github.com/scroll-tech/go-ethereum/log"
"github.com/scroll-tech/go-ethereum/metrics"
)

// mutex for concurrent CircuitCapacityChecker creations
var creationMu sync.Mutex
var (
creationMu sync.Mutex
encodeTimer = metrics.NewRegisteredTimer("ccc/encode", nil)
)

func init() {
C.init()
Expand Down Expand Up @@ -67,6 +72,7 @@ func (ccc *CircuitCapacityChecker) ApplyTransaction(traces *types.BlockTrace) (*
return nil, ErrUnknown
}

encodeStart := time.Now()
ccc.jsonBuffer.Reset()
err := json.NewEncoder(&ccc.jsonBuffer).Encode(traces)
if err != nil {
Expand All @@ -79,8 +85,15 @@ func (ccc *CircuitCapacityChecker) ApplyTransaction(traces *types.BlockTrace) (*
C.free(unsafe.Pointer(tracesStr))
}()

rustTrace := C.parse_json_to_rust_trace(tracesStr)
if rustTrace == nil {
log.Error("fail to parse json in to rust trace", "id", ccc.ID, "TxHash", traces.Transactions[0].TxHash)
return nil, ErrUnknown
}
encodeTimer.UpdateSince(encodeStart)

log.Debug("start to check circuit capacity for tx", "id", ccc.ID, "TxHash", traces.Transactions[0].TxHash)
rawResult := C.apply_tx(C.uint64_t(ccc.ID), tracesStr)
rawResult := C.apply_tx(C.uint64_t(ccc.ID), rustTrace)
defer func() {
C.free_c_chars(rawResult)
}()
Expand Down Expand Up @@ -114,6 +127,7 @@ func (ccc *CircuitCapacityChecker) ApplyBlock(traces *types.BlockTrace) (*types.
ccc.Lock()
defer ccc.Unlock()

encodeStart := time.Now()
ccc.jsonBuffer.Reset()
err := json.NewEncoder(&ccc.jsonBuffer).Encode(traces)
if err != nil {
Expand All @@ -126,8 +140,15 @@ func (ccc *CircuitCapacityChecker) ApplyBlock(traces *types.BlockTrace) (*types.
C.free(unsafe.Pointer(tracesStr))
}()

rustTrace := C.parse_json_to_rust_trace(tracesStr)
if rustTrace == nil {
log.Error("fail to parse json in to rust trace", "id", ccc.ID, "TxHash", traces.Transactions[0].TxHash)
return nil, ErrUnknown
}
encodeTimer.UpdateSince(encodeStart)

log.Debug("start to check circuit capacity for block", "id", ccc.ID, "blockNumber", traces.Header.Number, "blockHash", traces.Header.Hash())
rawResult := C.apply_block(C.uint64_t(ccc.ID), tracesStr)
rawResult := C.apply_block(C.uint64_t(ccc.ID), rustTrace)
defer func() {
C.free_c_chars(rawResult)
}()
Expand Down
5 changes: 3 additions & 2 deletions rollup/circuitcapacitychecker/libzkp/libzkp.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
void init();
uint64_t new_circuit_capacity_checker();
void reset_circuit_capacity_checker(uint64_t id);
char* apply_tx(uint64_t id, char *tx_traces);
char* apply_block(uint64_t id, char *block_trace);
char* apply_tx(uint64_t id, void* tx_traces);
char* apply_block(uint64_t id, void* block_trace);
char* get_tx_num(uint64_t id);
char* set_light_mode(uint64_t id, bool light_mode);
void free_c_chars(char* ptr);
void* parse_json_to_rust_trace(char* trace_json_ptr);
38 changes: 24 additions & 14 deletions rollup/circuitcapacitychecker/libzkp/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
pub mod checker {
use crate::utils::{c_char_to_str, c_char_to_vec, vec_to_c_char};
use crate::utils::vec_to_c_char;
use anyhow::{anyhow, bail, Error};
use libc::c_char;
use prover::{
zkevm::{CircuitCapacityChecker, RowUsage},
BlockTrace,
};
use serde_derive::{Deserialize, Serialize};
use std::cell::OnceCell;
use std::{cell::OnceCell, ptr::null_mut};
use std::collections::HashMap;
use std::panic;
use std::ptr::null;
use std::ffi::CStr;

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CommonResult {
Expand Down Expand Up @@ -43,6 +44,17 @@ pub mod checker {
.expect("circuit capacity checker initialized twice");
}

/// # Safety
#[no_mangle]
pub unsafe extern "C" fn parse_json_to_rust_trace(trace_json_ptr: *const c_char) -> *mut BlockTrace {
let trace_json_cstr = unsafe { CStr::from_ptr(trace_json_ptr) };
let trace = serde_json::from_slice::<BlockTrace>(trace_json_cstr.to_bytes());
match trace {
Err(_) => return null_mut(),
Ok(t) => return Box::into_raw(Box::new(t))
}
}

/// # Safety
#[no_mangle]
pub unsafe extern "C" fn new_circuit_capacity_checker() -> u64 {
Expand All @@ -68,8 +80,9 @@ pub mod checker {

/// # Safety
#[no_mangle]
pub unsafe extern "C" fn apply_tx(id: u64, tx_traces: *const c_char) -> *const c_char {
let result = apply_tx_inner(id, tx_traces);
pub unsafe extern "C" fn apply_tx(id: u64, trace_ptr: *mut BlockTrace) -> *const c_char {
let trace = Box::from_raw(trace_ptr);
let result = apply_tx_inner(id, *trace);
let r = match result {
Ok(acc_row_usage) => {
log::debug!(
Expand All @@ -90,14 +103,12 @@ pub mod checker {
serde_json::to_vec(&r).map_or(null(), vec_to_c_char)
}

unsafe fn apply_tx_inner(id: u64, tx_traces: *const c_char) -> Result<RowUsage, Error> {
unsafe fn apply_tx_inner(id: u64, traces: BlockTrace) -> Result<RowUsage, Error> {
log::debug!(
"ccc apply_tx raw input, id: {:?}, tx_traces: {:?}",
id,
c_char_to_str(tx_traces)?
traces
);
let tx_traces_vec = c_char_to_vec(tx_traces);
let traces = serde_json::from_slice::<BlockTrace>(&tx_traces_vec)?;

if traces.transactions.len() != 1 {
bail!("traces.transactions.len() != 1");
Expand Down Expand Up @@ -131,8 +142,9 @@ pub mod checker {

/// # Safety
#[no_mangle]
pub unsafe extern "C" fn apply_block(id: u64, block_trace: *const c_char) -> *const c_char {
let result = apply_block_inner(id, block_trace);
pub unsafe extern "C" fn apply_block(id: u64, trace_ptr: *mut BlockTrace) -> *const c_char {
let trace = Box::from_raw(trace_ptr);
let result = apply_block_inner(id, *trace);
let r = match result {
Ok(acc_row_usage) => {
log::debug!(
Expand All @@ -153,14 +165,12 @@ pub mod checker {
serde_json::to_vec(&r).map_or(null(), vec_to_c_char)
}

unsafe fn apply_block_inner(id: u64, block_trace: *const c_char) -> Result<RowUsage, Error> {
unsafe fn apply_block_inner(id: u64, traces: BlockTrace) -> Result<RowUsage, Error> {
log::debug!(
"ccc apply_block raw input, id: {:?}, block_trace: {:?}",
id,
c_char_to_str(block_trace)?
traces
);
let block_trace = c_char_to_vec(block_trace);
let traces = serde_json::from_slice::<BlockTrace>(&block_trace)?;

let r = panic::catch_unwind(|| {
CHECKERS
Expand Down

0 comments on commit 8bd2a16

Please sign in to comment.