Skip to content

Commit

Permalink
bindings/go: parallelize MSM for N<32.
Browse files (browse the repository at this point in the history)
  • Loading branch information
dot-asm committed Jul 24, 2024
1 parent e4be9e6 commit 1a89355
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 3 deletions.
134 changes: 132 additions & 2 deletions bindings/go/blst.go
Original file line number Diff line number Diff line change
Expand Up @@ -2112,7 +2112,7 @@ func P1AffinesMult(pointsIf interface{}, scalarsIf interface{}, nbits int) *P1 {

numThreads := numThreads(0)

if numThreads < 2 || npoints < 32 {
if numThreads < 2 {
sz := int(C.blst_p1s_mult_pippenger_scratch_sizeof(C.size_t(npoints))) / 8
scratch := make([]uint64, sz)

Expand Down Expand Up @@ -2161,6 +2161,71 @@ func P1AffinesMult(pointsIf interface{}, scalarsIf interface{}, nbits int) *P1 {
return &ret
}

if npoints < 32 {
if numThreads > npoints {
numThreads = npoints
}

curItem := uint32(0)
msgs := make(chan P1, numThreads)

for tid := 0; tid < numThreads; tid++ {
go func() {
var acc P1

for {
workItem := int(atomic.AddUint32(&curItem, 1) - 1)
if workItem >= npoints {
break
}

var point *P1Affine
switch val := pointsIf.(type) {
case []*P1Affine:
point = val[workItem]
case []P1Affine:
point = &val[workItem]
case P1Affines:
point = &val[workItem]
}

var scalar *C.byte
switch val := scalarsIf.(type) {
case []byte:
scalar = (*C.byte)(&val[workItem*nbytes])
case [][]byte:
scalar = scalars[workItem]
case []Scalar:
if nbits > 248 {
scalar = &val[workItem].b[0]
} else {
scalar = scalars[workItem]
}
case []*Scalar:
scalar = scalars[workItem]
}

C.go_p1_mult_n_acc(&acc, &point.x, true,
scalar, C.size_t(nbits))
}

msgs <- acc
}()
}

ret := <-msgs
for tid := 1; tid < numThreads; tid++ {
point := <-msgs
C.blst_p1_add_or_double(&ret, &ret, &point)
}

for i := range scalars {
scalars[i] = nil
}

return &ret
}

// this is sizeof(scratch[0])
sz := int(C.blst_p1s_mult_pippenger_scratch_sizeof(0)) / 8

Expand Down Expand Up @@ -2852,7 +2917,7 @@ func P2AffinesMult(pointsIf interface{}, scalarsIf interface{}, nbits int) *P2 {

numThreads := numThreads(0)

if numThreads < 2 || npoints < 32 {
if numThreads < 2 {
sz := int(C.blst_p2s_mult_pippenger_scratch_sizeof(C.size_t(npoints))) / 8
scratch := make([]uint64, sz)

Expand Down Expand Up @@ -2901,6 +2966,71 @@ func P2AffinesMult(pointsIf interface{}, scalarsIf interface{}, nbits int) *P2 {
return &ret
}

if npoints < 32 {
if numThreads > npoints {
numThreads = npoints
}

curItem := uint32(0)
msgs := make(chan P2, numThreads)

for tid := 0; tid < numThreads; tid++ {
go func() {
var acc P2

for {
workItem := int(atomic.AddUint32(&curItem, 1) - 1)
if workItem >= npoints {
break
}

var point *P2Affine
switch val := pointsIf.(type) {
case []*P2Affine:
point = val[workItem]
case []P2Affine:
point = &val[workItem]
case P2Affines:
point = &val[workItem]
}

var scalar *C.byte
switch val := scalarsIf.(type) {
case []byte:
scalar = (*C.byte)(&val[workItem*nbytes])
case [][]byte:
scalar = scalars[workItem]
case []Scalar:
if nbits > 248 {
scalar = &val[workItem].b[0]
} else {
scalar = scalars[workItem]
}
case []*Scalar:
scalar = scalars[workItem]
}

C.go_p2_mult_n_acc(&acc, &point.x, true,
scalar, C.size_t(nbits))
}

msgs <- acc
}()
}

ret := <-msgs
for tid := 1; tid < numThreads; tid++ {
point := <-msgs
C.blst_p2_add_or_double(&ret, &ret, &point)
}

for i := range scalars {
scalars[i] = nil
}

return &ret
}

// this is sizeof(scratch[0])
sz := int(C.blst_p2s_mult_pippenger_scratch_sizeof(0)) / 8

Expand Down
7 changes: 7 additions & 0 deletions bindings/go/blst_minpk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,13 @@ func TestMultiScalarP1(t *testing.T) {
for i := range points {
points[i] = *generator.Mult(scalars[i*4:(i+1)*4])
refs[i] = *points[i].Mult(scalars[i*16:(i+1)*16], 128)
if i < 27 {
ref := P1s(refs[:i+1]).Add()
ret := P1s(points[:i+1]).Mult(scalars, 128)
if !ref.Equals(ret) {
t.Errorf("failed self-consistency multi-scalar test")
}
}
}
ref := P1s(refs).Add()
ret := P1s(points).Mult(scalars, 128)
Expand Down
7 changes: 7 additions & 0 deletions bindings/go/blst_minsig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,13 @@ func TestMultiScalarP2(t *testing.T) {
for i := range points {
points[i] = *generator.Mult(scalars[i*4:(i+1)*4])
refs[i] = *points[i].Mult(scalars[i*16:(i+1)*16], 128)
if i < 27 {
ref := P2s(refs[:i+1]).Add()
ret := P2s(points[:i+1]).Mult(scalars, 128)
if !ref.Equals(ret) {
t.Errorf("failed self-consistency multi-scalar test")
}
}
}
ref := P2s(refs).Add()
ret := P2s(points).Mult(scalars, 128)
Expand Down
67 changes: 66 additions & 1 deletion bindings/go/blst_px.tgo
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ func P1AffinesMult(pointsIf interface{}, scalarsIf interface{}, nbits int) *P1 {

numThreads := numThreads(0)

if numThreads < 2 || npoints < 32 {
if numThreads < 2 {
sz := int(C.blst_p1s_mult_pippenger_scratch_sizeof(C.size_t(npoints)))/8
scratch := make([]uint64, sz)

Expand Down Expand Up @@ -508,6 +508,71 @@ func P1AffinesMult(pointsIf interface{}, scalarsIf interface{}, nbits int) *P1 {
return &ret
}

if npoints < 32 {
if numThreads > npoints {
numThreads = npoints
}

curItem := uint32(0)
msgs := make(chan P1, numThreads)

for tid := 0; tid < numThreads; tid++ {
go func() {
var acc P1

for {
workItem := int(atomic.AddUint32(&curItem, 1) - 1)
if workItem >= npoints {
break
}

var point *P1Affine
switch val := pointsIf.(type) {
case []*P1Affine:
point = val[workItem]
case []P1Affine:
point = &val[workItem]
case P1Affines:
point = &val[workItem]
}

var scalar *C.byte
switch val := scalarsIf.(type) {
case []byte:
scalar = (*C.byte)(&val[workItem*nbytes])
case [][]byte:
scalar = scalars[workItem]
case []Scalar:
if nbits > 248 {
scalar = &val[workItem].b[0]
} else {
scalar = scalars[workItem]
}
case []*Scalar:
scalar = scalars[workItem]
}

C.go_p1_mult_n_acc(&acc, &point.x, true,
scalar, C.size_t(nbits))
}

msgs <- acc
}()
}

ret := <-msgs
for tid := 1; tid < numThreads; tid++ {
point := <- msgs
C.blst_p1_add_or_double(&ret, &ret, &point);
}

for i := range(scalars) {
scalars[i] = nil
}

return &ret
}

// this is sizeof(scratch[0])
sz := int(C.blst_p1s_mult_pippenger_scratch_sizeof(0))/8

Expand Down

0 comments on commit 1a89355

Please sign in to comment.