Skip to content

Commit

Permalink
Merge pull request #70 from 0xPolygonHermez/optimize-sm-pils/arith
Browse files Browse the repository at this point in the history
Optimize sm pils/arith
  • Loading branch information
zkronos73 committed Nov 14, 2022
2 parents 70e9e54 + 8967ffb commit afc48f8
Show file tree
Hide file tree
Showing 9 changed files with 2,141 additions and 134 deletions.
47 changes: 20 additions & 27 deletions pil/arith.pil
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@ namespace Arith(%N);

pol constant BYTE2_BIT19;
pol constant SEL_BYTE2_BIT19;
pol constant GL_SIGNED_4BITS_C0;
pol constant GL_SIGNED_4BITS_C1;
pol constant GL_SIGNED_4BITS_C2;
pol constant GL_SIGNED_18BITS;
pol constant GL_SIGNED_22BITS;
pol constant CLK[32]; // 1 if CLK==0 and 0 if CLK!=0

pol commit x1[16];
Expand All @@ -34,6 +31,13 @@ namespace Arith(%N);
pol commit q1[16];
pol commit q2[16];

pol commit resultEq0;
pol commit resultEq1;
pol commit resultEq2;
resultEq0 * (1 - resultEq0) = 0;
resultEq1 * (1 - resultEq1) = 0;
resultEq2 * (1 - resultEq2) = 0;

/****
*
* LATCH POLS: x1,y1,x2,y2,x3,y3,s,q0,q1,q2
Expand Down Expand Up @@ -1903,29 +1907,18 @@ namespace Arith(%N);
selEq[2] * (1-selEq[2]) = 0;
selEq[3] * (1-selEq[3]) = 0;

pol commit carry[3];

pol commit carryL[3];
pol commit carryH[3];

carryL[0] * CLK[0] = 0;
carryL[1] * CLK[0] = 0;
carryL[2] * CLK[0] = 0;
carryH[0] * CLK[0] = 0;
carryH[1] * CLK[0] = 0;
carryH[2] * CLK[0] = 0;

carryL[0] in GL_SIGNED_18BITS;
carryL[1] in GL_SIGNED_18BITS;
carryL[2] in GL_SIGNED_18BITS;

{carryH[0], carryH[1], carryH[2]} in {GL_SIGNED_4BITS_C0, GL_SIGNED_4BITS_C1, GL_SIGNED_4BITS_C2}; // 3 * (4+1) = 15 bits
carry[0] * CLK[0] = 0;
carry[1] * CLK[0] = 0;
carry[2] * CLK[0] = 0;

// eq + carry = carry' * 2**16
// carry = cl + ch * 2**18
// eq + cl + ch * 2**18 = cl 2**16 + ch * 2**34
carry[0] in GL_SIGNED_22BITS;
carry[1] in GL_SIGNED_22BITS;
carry[2] in GL_SIGNED_22BITS;

selEq[0] * (eq0 + carryL[0] + 2**18 * carryH[0]) = selEq[0] * (carryL[0]' * 2**16 + carryH[0]' * 2**34);
selEq[1] * (eq1 + carryL[0] + 2**18 * carryH[0]) = selEq[1] * (carryL[0]' * 2**16 + carryH[0]' * 2**34);
selEq[2] * (eq2 + carryL[0] + 2**18 * carryH[0]) = selEq[2] * (carryL[0]' * 2**16 + carryH[0]' * 2**34);
selEq[3] * (eq3 + carryL[1] + 2**18 * carryH[1]) = selEq[3] * (carryL[1]' * 2**16 + carryH[1]' * 2**34);
selEq[3] * (eq4 + carryL[2] + 2**18 * carryH[2]) = selEq[3] * (carryL[2]' * 2**16 + carryH[2]' * 2**34);
selEq[0] * (eq0 + carry[0]) = selEq[0] * (carry[0]' * 2**16);
selEq[1] * (eq1 + carry[0]) = selEq[1] * (carry[0]' * 2**16);
selEq[2] * (eq2 + carry[0]) = selEq[2] * (carry[0]' * 2**16);
selEq[3] * (eq3 + carry[1]) = selEq[3] * (carry[1]' * 2**16);
selEq[3] * (eq4 + carry[2]) = selEq[3] * (carry[2]' * 2**16);
181 changes: 101 additions & 80 deletions pil/main.pil
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ namespace Main(%N);
// operations
pol commit mOp, mWR;
pol commit sWR, sRD;
pol commit arith;
pol commit arithEq0, arithEq1, arithEq2, arithEq3;
pol commit arithEq0, arithEq1, arithEq2;
pol commit memAlign, memAlignWR, memAlignWR8;
pol commit hashK, hashKLen, hashKDigest;
pol commit hashP, hashPLen, hashPDigest;
Expand Down Expand Up @@ -343,15 +342,15 @@ namespace Main(%N);
setA, setB, setC, setD, setE, setSR, setCTX, setSP, setPC, setGAS, setMAXMEM, setRR, setHASHPOS,
JMP, JMPN, JMPC,
isStack, isCode, isMem, incStack, incCode, useCTX, ind, indRR,
mOp, mWR, sWR, sRD, arith, arithEq0, arithEq1, arithEq2, arithEq3, hashK, hashKLen, hashKDigest, hashP, hashPLen, hashPDigest,
mOp, mWR, sWR, sRD, arithEq0, arithEq1, arithEq2, hashK, hashKLen, hashKDigest, hashP, hashPLen, hashPDigest,
bin, binOpcode, assert, memAlign, memAlignWR, memAlignWR8,
zkPC} in
{Rom.CONST0, Rom.CONST1, Rom.CONST2, Rom.CONST3, Rom.CONST4, Rom.CONST5, Rom.CONST6, Rom.CONST7, Rom.offset,
Rom.inA, Rom.inB, Rom.inC, Rom.inROTL_C, Rom.inD, Rom.inE, Rom.inSR, Rom.inCTX, Rom.inSP, Rom.inPC, Rom.inGAS, Rom.inMAXMEM, Rom.inSTEP, Rom.inFREE, Rom.inRR, Rom.inHASHPOS,
Rom.setA, Rom.setB, Rom.setC, Rom.setD, Rom.setE, Rom.setSR, Rom.setCTX, Rom.setSP, Rom.setPC, Rom.setGAS, Rom.setMAXMEM, Rom.setRR, Rom.setHASHPOS,
Rom.JMP, Rom.JMPN, Rom.JMPC,
Rom.isStack, Rom.isCode, Rom.isMem, Rom.incStack, Rom.incCode, Rom.useCTX, Rom.ind, Rom.indRR,
Rom.mOp, Rom.mWR, Rom.sWR, Rom.sRD, Rom.arith, Rom.arithEq0, Rom.arithEq1, Rom.arithEq2, Rom.arithEq3, Rom.hashK, Rom.hashKLen, Rom.hashKDigest, Rom.hashP, Rom.hashPLen, Rom.hashPDigest,
Rom.mOp, Rom.mWR, Rom.sWR, Rom.sRD, Rom.arithEq0, Rom.arithEq1, Rom.arithEq2, Rom.hashK, Rom.hashKLen, Rom.hashKDigest, Rom.hashP, Rom.hashPLen, Rom.hashPDigest,
Rom.bin, Rom.binOpcode, Rom.assert, Rom.memAlign, Rom.memAlignWR, Rom.memAlignWR8,
Rom.line};

Expand All @@ -361,89 +360,111 @@ namespace Main(%N);
/////////
// Arithmetic Plookpups
/////////
arith { arithEq0, arithEq1, arithEq2, arithEq3,

pol ax1_0 = Arith.x1[0] + Arith.x1[1]*2**16;
pol ax1_1 = Arith.x1[2] + Arith.x1[3]*2**16;
pol ax1_2 = Arith.x1[4] + Arith.x1[5]*2**16;
pol ax1_3 = Arith.x1[6] + Arith.x1[7]*2**16;
pol ax1_4 = Arith.x1[8] + Arith.x1[9]*2**16;
pol ax1_5 = Arith.x1[10] + Arith.x1[11]*2**16;
pol ax1_6 = Arith.x1[12] + Arith.x1[13]*2**16;
pol ax1_7 = Arith.x1[14] + Arith.x1[15]*2**16;

pol ay1_0 = Arith.y1[0] + Arith.y1[1]*2**16;
pol ay1_1 = Arith.y1[2] + Arith.y1[3]*2**16;
pol ay1_2 = Arith.y1[4] + Arith.y1[5]*2**16;
pol ay1_3 = Arith.y1[6] + Arith.y1[7]*2**16;
pol ay1_4 = Arith.y1[8] + Arith.y1[9]*2**16;
pol ay1_5 = Arith.y1[10] + Arith.y1[11]*2**16;
pol ay1_6 = Arith.y1[12] + Arith.y1[13]*2**16;
pol ay1_7 = Arith.y1[14] + Arith.y1[15]*2**16;

pol ax2_0 = Arith.x2[0] + Arith.x2[1]*2**16;
pol ax2_1 = Arith.x2[2] + Arith.x2[3]*2**16;
pol ax2_2 = Arith.x2[4] + Arith.x2[5]*2**16;
pol ax2_3 = Arith.x2[6] + Arith.x2[7]*2**16;
pol ax2_4 = Arith.x2[8] + Arith.x2[9]*2**16;
pol ax2_5 = Arith.x2[10] + Arith.x2[11]*2**16;
pol ax2_6 = Arith.x2[12] + Arith.x2[13]*2**16;
pol ax2_7 = Arith.x2[14] + Arith.x2[15]*2**16;

pol ay2_0 = Arith.y2[0] + Arith.y2[1]*2**16;
pol ay2_1 = Arith.y2[2] + Arith.y2[3]*2**16;
pol ay2_2 = Arith.y2[4] + Arith.y2[5]*2**16;
pol ay2_3 = Arith.y2[6] + Arith.y2[7]*2**16;
pol ay2_4 = Arith.y2[8] + Arith.y2[9]*2**16;
pol ay2_5 = Arith.y2[10] + Arith.y2[11]*2**16;
pol ay2_6 = Arith.y2[12] + Arith.y2[13]*2**16;
pol ay2_7 = Arith.y2[14] + Arith.y2[15]*2**16;

pol ax3_0 = Arith.x3[0] + Arith.x3[1]*2**16;
pol ax3_1 = Arith.x3[2] + Arith.x3[3]*2**16;
pol ax3_2 = Arith.x3[4] + Arith.x3[5]*2**16;
pol ax3_3 = Arith.x3[6] + Arith.x3[7]*2**16;
pol ax3_4 = Arith.x3[8] + Arith.x3[9]*2**16;
pol ax3_5 = Arith.x3[10] + Arith.x3[11]*2**16;
pol ax3_6 = Arith.x3[12] + Arith.x3[13]*2**16;
pol ax3_7 = Arith.x3[14] + Arith.x3[15]*2**16;

pol ay3_0 = Arith.y3[0] + Arith.y3[1]*2**16;
pol ay3_1 = Arith.y3[2] + Arith.y3[3]*2**16;
pol ay3_2 = Arith.y3[4] + Arith.y3[5]*2**16;
pol ay3_3 = Arith.y3[6] + Arith.y3[7]*2**16;
pol ay3_4 = Arith.y3[8] + Arith.y3[9]*2**16;
pol ay3_5 = Arith.y3[10] + Arith.y3[11]*2**16;
pol ay3_6 = Arith.y3[12] + Arith.y3[13]*2**16;
pol ay3_7 = Arith.y3[14] + Arith.y3[15]*2**16;

arithEq0 { 1, 0, 0, 0,
A0, A1, A2, A3, A4, A5, A6, A7,
B0, B1, B2, B3, B4, B5, B6, B7,
C0, C1, C2, C3, C4, C5, C6, C7,
D0, D1, D2, D3, D4, D5, D6, D7,
op0, op1, op2, op3, op4, op5, op6, op7 } is
Arith.resultEq0 {
Arith.selEq[0], Arith.selEq[1], Arith.selEq[2], Arith.selEq[3],
ax1_0, ax1_1, ax1_2, ax1_3, ax1_4, ax1_5, ax1_6, ax1_7,
ay1_0, ay1_1, ay1_2, ay1_3, ay1_4, ay1_5, ay1_6, ay1_7,
ax2_0, ax2_1, ax2_2, ax2_3, ax2_4, ax2_5, ax2_6, ax2_7,
ay2_0, ay2_1, ay2_2, ay2_3, ay2_4, ay2_5, ay2_6, ay2_7,
ay3_0, ay3_1, ay3_2, ay3_3, ay3_4, ay3_5, ay3_6, ay3_7
};

arithEq0*C0 + arithEq1*C0 + arithEq2*A0,
arithEq0*C1 + arithEq1*C1 + arithEq2*A1,
arithEq0*C2 + arithEq1*C2 + arithEq2*A2,
arithEq0*C3 + arithEq1*C3 + arithEq2*A3,
arithEq0*C4 + arithEq1*C4 + arithEq2*A4,
arithEq0*C5 + arithEq1*C5 + arithEq2*A5,
arithEq0*C6 + arithEq1*C6 + arithEq2*A6,
arithEq0*C7 + arithEq1*C7 + arithEq2*A7,

arithEq0*D0 + arithEq1*D0 + arithEq2*B0,
arithEq0*D1 + arithEq1*D1 + arithEq2*B1,
arithEq0*D2 + arithEq1*D2 + arithEq2*B2,
arithEq0*D3 + arithEq1*D3 + arithEq2*B3,
arithEq0*D4 + arithEq1*D4 + arithEq2*B4,
arithEq0*D5 + arithEq1*D5 + arithEq2*B5,
arithEq0*D6 + arithEq1*D6 + arithEq2*B6,
arithEq0*D7 + arithEq1*D7 + arithEq2*B7,

arithEq3 * E0, arithEq3 * E1, arithEq3 * E2, arithEq3 * E3, arithEq3 * E4, arithEq3 * E5, arithEq3 * E6, arithEq3 * E7,
op0, op1, op2, op3, op4, op5, op6, op7 } in
{
arithEq1 { 0, 1, 0, 1,
A0, A1, A2, A3, A4, A5, A6, A7,
B0, B1, B2, B3, B4, B5, B6, B7,
C0, C1, C2, C3, C4, C5, C6, C7,
D0, D1, D2, D3, D4, D5, D6, D7,
E0, E1, E2, E3, E4, E5, E6, E7,
op0, op1, op2, op3, op4, op5, op6, op7 } is
Arith.resultEq1 {
Arith.selEq[0], Arith.selEq[1], Arith.selEq[2], Arith.selEq[3],
ax1_0, ax1_1, ax1_2, ax1_3, ax1_4, ax1_5, ax1_6, ax1_7,
ay1_0, ay1_1, ay1_2, ay1_3, ay1_4, ay1_5, ay1_6, ay1_7,
ax2_0, ax2_1, ax2_2, ax2_3, ax2_4, ax2_5, ax2_6, ax2_7,
ay2_0, ay2_1, ay2_2, ay2_3, ay2_4, ay2_5, ay2_6, ay2_7,
ax3_0, ax3_1, ax3_2, ax3_3, ax3_4, ax3_5, ax3_6, ax3_7,
ay3_0, ay3_1, ay3_2, ay3_3, ay3_4, ay3_5, ay3_6, ay3_7
};

Arith.x1[0] + Arith.x1[1]*2**16,
Arith.x1[2] + Arith.x1[3]*2**16,
Arith.x1[4] + Arith.x1[5]*2**16,
Arith.x1[6] + Arith.x1[7]*2**16,
Arith.x1[8] + Arith.x1[9]*2**16,
Arith.x1[10] + Arith.x1[11]*2**16,
Arith.x1[12] + Arith.x1[13]*2**16,
Arith.x1[14] + Arith.x1[15]*2**16,

Arith.y1[0] + Arith.y1[1]*2**16,
Arith.y1[2] + Arith.y1[3]*2**16,
Arith.y1[4] + Arith.y1[5]*2**16,
Arith.y1[6] + Arith.y1[7]*2**16,
Arith.y1[8] + Arith.y1[9]*2**16,
Arith.y1[10] + Arith.y1[11]*2**16,
Arith.y1[12] + Arith.y1[13]*2**16,
Arith.y1[14] + Arith.y1[15]*2**16,

Arith.x2[0] + Arith.x2[1]*2**16,
Arith.x2[2] + Arith.x2[3]*2**16,
Arith.x2[4] + Arith.x2[5]*2**16,
Arith.x2[6] + Arith.x2[7]*2**16,
Arith.x2[8] + Arith.x2[9]*2**16,
Arith.x2[10] + Arith.x2[11]*2**16,
Arith.x2[12] + Arith.x2[13]*2**16,
Arith.x2[14] + Arith.x2[15]*2**16,

Arith.y2[0] + Arith.y2[1]*2**16,
Arith.y2[2] + Arith.y2[3]*2**16,
Arith.y2[4] + Arith.y2[5]*2**16,
Arith.y2[6] + Arith.y2[7]*2**16,
Arith.y2[8] + Arith.y2[9]*2**16,
Arith.y2[10] + Arith.y2[11]*2**16,
Arith.y2[12] + Arith.y2[13]*2**16,
Arith.y2[14] + Arith.y2[15]*2**16,

Arith.x3[0] + Arith.x3[1]*2**16,
Arith.x3[2] + Arith.x3[3]*2**16,
Arith.x3[4] + Arith.x3[5]*2**16,
Arith.x3[6] + Arith.x3[7]*2**16,
Arith.x3[8] + Arith.x3[9]*2**16,
Arith.x3[10] + Arith.x3[11]*2**16,
Arith.x3[12] + Arith.x3[13]*2**16,
Arith.x3[14] + Arith.x3[15]*2**16,

Arith.y3[0] + Arith.y3[1]*2**16,
Arith.y3[2] + Arith.y3[3]*2**16,
Arith.y3[4] + Arith.y3[5]*2**16,
Arith.y3[6] + Arith.y3[7]*2**16,
Arith.y3[8] + Arith.y3[9]*2**16,
Arith.y3[10] + Arith.y3[11]*2**16,
Arith.y3[12] + Arith.y3[13]*2**16,
Arith.y3[14] + Arith.y3[15]*2**16
arithEq2 { 0, 0, 1, 1,
A0, A1, A2, A3, A4, A5, A6, A7,
B0, B1, B2, B3, B4, B5, B6, B7,
A0, A1, A2, A3, A4, A5, A6, A7,
B0, B1, B2, B3, B4, B5, B6, B7,
E0, E1, E2, E3, E4, E5, E6, E7,
op0, op1, op2, op3, op4, op5, op6, op7 } is
Arith.resultEq2 {
Arith.selEq[0], Arith.selEq[1], Arith.selEq[2], Arith.selEq[3],
ax1_0, ax1_1, ax1_2, ax1_3, ax1_4, ax1_5, ax1_6, ax1_7,
ay1_0, ay1_1, ay1_2, ay1_3, ay1_4, ay1_5, ay1_6, ay1_7,
ax2_0, ax2_1, ax2_2, ax2_3, ax2_4, ax2_5, ax2_6, ax2_7,
ay2_0, ay2_1, ay2_2, ay2_3, ay2_4, ay2_5, ay2_6, ay2_7,
ax3_0, ax3_1, ax3_2, ax3_3, ax3_4, ax3_5, ax3_6, ax3_7,
ay3_0, ay3_1, ay3_2, ay3_3, ay3_4, ay3_5, ay3_6, ay3_7
};

cntArith' = cntArith*(1-Global.L1) + arith;
cntArith' = cntArith*(1-Global.L1) + arithEq0 + arithEq1 + arithEq2;

/////////
// Binary Plookpups
Expand Down
2 changes: 0 additions & 2 deletions pil/rom.pil
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@ namespace Rom(%N);
pol constant useCTX;
pol constant mOp, mWR;
pol constant sWR, sRD;
pol constant arith;
pol constant arithEq0;
pol constant arithEq1;
pol constant arithEq2;
pol constant arithEq3;
pol constant memAlign, memAlignWR, memAlignWR8;
pol constant hashK, hashKLen, hashKDigest;
pol constant hashP, hashPLen, hashPDigest;
Expand Down
19 changes: 9 additions & 10 deletions src/sm/sm_arith/sm_arith.js
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@ module.exports.buildConstants = async function (pols) {

buildClocks(pols, N, 32);
buildByte2Bits16(pols, N);
buildRange(pols, N, 'GL_SIGNED_4BITS_C0', -16n, 16n);
buildRange(pols, N, 'GL_SIGNED_4BITS_C1', -16n, 16n, 33);
buildRange(pols, N, 'GL_SIGNED_4BITS_C2', -16n, 16n, 33*33);
buildRange(pols, N, 'GL_SIGNED_18BITS', -(2n**18n), (2n**18n));
buildRange(pols, N, 'GL_SIGNED_22BITS', -(2n**22n), (2n**22n)-1n);
}

function buildByte2Bits16(pols, N) {
Expand Down Expand Up @@ -85,11 +82,12 @@ module.exports.execute = async function (pols, input) {
pols.q1[j][i] = 0n;
pols.q2[j][i] = 0n;
pols.s[j][i] = 0n;
if (j < pols.carryL.length) pols.carryL[j][i] = 0n;
if (j < pols.carryH.length) pols.carryH[j][i] = 0n;
if (j < pols.carry.length) pols.carry[j][i] = 0n;
if (j < pols.selEq.length) pols.selEq[j][i] = 0n;
}
pols.resultReady[i] = 0n;
pols.resultEq0[i] = 0n;
pols.resultEq1[i] = 0n;
pols.resultEq2[i] = 0n;
}
let s, q0, q1, q2;
for (let i = 0; i < input.length; i++) {
Expand Down Expand Up @@ -183,12 +181,13 @@ module.exports.execute = async function (pols, input) {
eqIndexes.forEach((eqIndex) => {
let carryIndex = eqIndexToCarryIndex[eqIndex];
eq[eqIndex] = eqCalculates[eqIndex](pols, step, offset);
pols.carryL[carryIndex][offset + step] = Fr.e((carry[carryIndex]) % (2n**18n));
pols.carryH[carryIndex][offset + step] = Fr.e((carry[carryIndex]) / (2n**18n));
pols.carry[carryIndex][offset + step] = Fr.e(carry[carryIndex]);
carry[carryIndex] = (eq[eqIndex] + carry[carryIndex]) / (2n ** 16n);
});
}
pols.resultReady[offset + 31] = 1n;
pols.resultEq0[offset + 31] = pols.selEq[0][offset] ? 1n : 0n;
pols.resultEq1[offset + 31] = pols.selEq[1][offset] ? 1n : 0n;
pols.resultEq2[offset + 31] = pols.selEq[2][offset] ? 1n : 0n;
}
}

Expand Down
Loading

0 comments on commit afc48f8

Please sign in to comment.