Skip to content

Commit

Permalink
fixup: Fix type checking issues
Browse files Browse the repository at this point in the history
Replaced lots of `is_infinite(P)` with `P is not None` and
added lots of extra `is not None` checks where that's implied
by other checks (e.g. 0 < k < n, P = k * G) because the type-checking
system can understand `is not None`.
  • Loading branch information
robot-dreams committed Apr 5, 2022
1 parent a5b1aaa commit 9acc936
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions doc/musig-reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,32 +108,35 @@ def cbytes(P: Point) -> bytes:
a = b'\x02' if has_even_y(P) else b'\x03'
return a + bytes_from_point(P)

def point_negate(P: Point) -> Point:
if is_infinite(P):
def point_negate(P: Optional[Point]) -> Optional[Point]:
if P is None:
return P
return (x(P), p - y(P))

def pointc(x: bytes) -> Point:
P = lift_x(x[1:33])
assert P is not None
if x[0] == 2:
return P
elif x[0] == 3:
return point_negate(P)
P = point_negate(P)
assert P is not None
return P
assert False

def key_agg(pubkeys: List[bytes], tweaks: List[bytes], is_xonly: List[bool]) -> bytes:
Q, _, _ = key_agg_internal(pubkeys, tweaks, is_xonly)
return bytes_from_point(Q)

def key_agg_internal(pubkeys: List[bytes], tweaks: List[bytes], is_xonly: List[bool]) -> Point:
def key_agg_internal(pubkeys: List[bytes], tweaks: List[bytes], is_xonly: List[bool]) -> Tuple[Point, int, int]:
pk2 = get_second_key(pubkeys)
u = len(pubkeys)
Q = infinity
for i in range(u):
P_i = lift_x(pubkeys[i])
a_i = key_agg_coeff_internal(pubkeys, pubkeys[i], pk2)
Q = point_add(Q, point_mul(P_i, a_i))
assert not is_infinite(Q)
assert Q is not None
gacc = 1
tacc = 0
v = len(tweaks)
Expand Down Expand Up @@ -169,15 +172,15 @@ def apply_tweak(Q: Point, gacc: int, tacc: int, tweak_i: bytes, is_xonly_i: bool
t_i = int_from_bytes(tweak_i)
assert t_i < n
Q_i = point_add(point_mul(Q, g), point_mul(G, t_i))
assert not is_infinite(Q_i)
assert Q_i is not None
gacc_i = g * gacc % n
tacc_i = (t_i + g * tacc) % n
return Q_i, gacc_i, tacc_i

def bytes_xor(a: bytes, b: bytes) -> bytes:
return bytes(x ^ y for x, y in zip(a, b))

def nonce_hash(rand: bytes, aggpk: bytes, i: int, msg: bytes, extra_in: bytes) -> bytes:
def nonce_hash(rand: bytes, aggpk: bytes, i: int, msg: bytes, extra_in: bytes) -> int:
buf = b''
buf += rand
buf += len(aggpk).to_bytes(1, 'big')
Expand All @@ -204,6 +207,8 @@ def nonce_gen(sk: bytes, aggpk: bytes, msg: bytes, extra_in: bytes) -> Tuple[byt
assert k_2 != 0
R_1_ = point_mul(G, k_1)
R_2_ = point_mul(G, k_2)
assert R_1_ is not None
assert R_2_ is not None
pubnonce = cbytes(R_1_) + cbytes(R_2_)
secnonce = bytes_from_int(k_1) + bytes_from_int(k_2)
return secnonce, pubnonce
Expand All @@ -216,19 +221,20 @@ def nonce_agg(pubnonces: List[bytes]) -> bytes:
for j in range(u):
R_i_ = point_add(R_i_, pointc(pubnonces[j][(i-1)*33:i*33]))
R_i = R_i_ if not is_infinite(R_i_) else G
assert R_i is not None
aggnonce += cbytes(R_i)
return aggnonce

SessionContext = namedtuple('SessionContext', ['aggnonce', 'pubkeys', 'tweaks', 'is_xonly', 'msg'])

def get_session_values(session_ctx: SessionContext) -> tuple[bytes, List[bytes], bytes]:
def get_session_values(session_ctx: SessionContext) -> tuple[Point, int, int, int, Point, int]:
(aggnonce, pubkeys, tweaks, is_xonly, msg) = session_ctx
Q, gacc_v, tacc_v = key_agg_internal(pubkeys, tweaks, is_xonly)
b = int_from_bytes(tagged_hash('MuSig/noncecoef', aggnonce + bytes_from_point(Q) + msg)) % n
R_1 = pointc(aggnonce[0:33])
R_2 = pointc(aggnonce[33:66])
R = point_add(R_1, point_mul(R_2, b))
assert not is_infinite(R)
assert R is not None
e = int_from_bytes(tagged_hash('BIP0340/challenge', bytes_from_point(R) + bytes_from_point(Q) + msg)) % n
return (Q, gacc_v, tacc_v, b, R, e)

Expand All @@ -248,13 +254,18 @@ def sign(secnonce: bytes, sk: bytes, session_ctx: SessionContext) -> bytes:
d_ = int_from_bytes(sk)
assert 0 < d_ < n
P = point_mul(G, d_)
assert P is not None
a = get_session_key_agg_coeff(session_ctx, P)
gp = 1 if has_even_y(P) else n - 1
g_v = 1 if has_even_y(Q) else n - 1
d = g_v * gacc_v * gp * d_ % n
s = (k_1 + b * k_2 + e * a * d) % n
psig = bytes_from_int(s)
pubnonce = cbytes(point_mul(G, k_1_)) + cbytes(point_mul(G, k_2_))
R_1_ = point_mul(G, k_1_)
R_2_ = point_mul(G, k_2_)
assert R_1_ is not None
assert R_2_ is not None
pubnonce = cbytes(R_1_) + cbytes(R_2_)
assert partial_sig_verify_internal(psig, pubnonce, bytes_from_point(P), session_ctx)
return psig

Expand All @@ -274,6 +285,7 @@ def partial_sig_verify_internal(psig: bytes, pubnonce: bytes, pk_: bytes, sessio
g_v = 1 if has_even_y(Q) else n - 1
g_ = g_v * gacc_v % n
P = point_mul(lift_x(pk_), g_)
assert P is not None
a = get_session_key_agg_coeff(session_ctx, P)
return point_mul(G, s) == point_add(R_, point_mul(P, e * a % n))

Expand Down

0 comments on commit 9acc936

Please sign in to comment.