From 9acc93617ebf13375e5541ec4351e9848261ba15 Mon Sep 17 00:00:00 2001 From: Elliott Jin Date: Tue, 5 Apr 2022 16:57:45 -0400 Subject: [PATCH] fixup: Fix type checking issues Replaced lots of `is_infinite(P)` with `P is not None` and added lots of extra `is not None` checks where that's implied by other checks (e.g. 0 < k < n, P = k * G) because the type-checking system can understand `is not None`. --- doc/musig-reference.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/doc/musig-reference.py b/doc/musig-reference.py index 758574d5b..710cfee9d 100644 --- a/doc/musig-reference.py +++ b/doc/musig-reference.py @@ -108,24 +108,27 @@ def cbytes(P: Point) -> bytes: a = b'\x02' if has_even_y(P) else b'\x03' return a + bytes_from_point(P) -def point_negate(P: Point) -> Point: - if is_infinite(P): +def point_negate(P: Optional[Point]) -> Optional[Point]: + if P is None: return P return (x(P), p - y(P)) def pointc(x: bytes) -> Point: P = lift_x(x[1:33]) + assert P is not None if x[0] == 2: return P elif x[0] == 3: - return point_negate(P) + P = point_negate(P) + assert P is not None + return P assert False def key_agg(pubkeys: List[bytes], tweaks: List[bytes], is_xonly: List[bool]) -> bytes: Q, _, _ = key_agg_internal(pubkeys, tweaks, is_xonly) return bytes_from_point(Q) -def key_agg_internal(pubkeys: List[bytes], tweaks: List[bytes], is_xonly: List[bool]) -> Point: +def key_agg_internal(pubkeys: List[bytes], tweaks: List[bytes], is_xonly: List[bool]) -> Tuple[Point, int, int]: pk2 = get_second_key(pubkeys) u = len(pubkeys) Q = infinity @@ -133,7 +136,7 @@ def key_agg_internal(pubkeys: List[bytes], tweaks: List[bytes], is_xonly: List[b P_i = lift_x(pubkeys[i]) a_i = key_agg_coeff_internal(pubkeys, pubkeys[i], pk2) Q = point_add(Q, point_mul(P_i, a_i)) - assert not is_infinite(Q) + assert Q is not None gacc = 1 tacc = 0 v = len(tweaks) @@ -169,7 +172,7 @@ def apply_tweak(Q: Point, gacc: int, tacc: int, tweak_i: bytes, is_xonly_i: bool t_i = int_from_bytes(tweak_i) assert t_i < n Q_i = point_add(point_mul(Q, g), point_mul(G, t_i)) - assert not is_infinite(Q_i) + assert Q_i is not None gacc_i = g * gacc % n tacc_i = (t_i + g * tacc) % n return Q_i, gacc_i, tacc_i @@ -177,7 +180,7 @@ def apply_tweak(Q: Point, gacc: int, tacc: int, tweak_i: bytes, is_xonly_i: bool def bytes_xor(a: bytes, b: bytes) -> bytes: return bytes(x ^ y for x, y in zip(a, b)) -def nonce_hash(rand: bytes, aggpk: bytes, i: int, msg: bytes, extra_in: bytes) -> bytes: +def nonce_hash(rand: bytes, aggpk: bytes, i: int, msg: bytes, extra_in: bytes) -> int: buf = b'' buf += rand buf += len(aggpk).to_bytes(1, 'big') @@ -204,6 +207,8 @@ def nonce_gen(sk: bytes, aggpk: bytes, msg: bytes, extra_in: bytes) -> Tuple[byt assert k_2 != 0 R_1_ = point_mul(G, k_1) R_2_ = point_mul(G, k_2) + assert R_1_ is not None + assert R_2_ is not None pubnonce = cbytes(R_1_) + cbytes(R_2_) secnonce = bytes_from_int(k_1) + bytes_from_int(k_2) return secnonce, pubnonce @@ -216,19 +221,20 @@ def nonce_agg(pubnonces: List[bytes]) -> bytes: for j in range(u): R_i_ = point_add(R_i_, pointc(pubnonces[j][(i-1)*33:i*33])) R_i = R_i_ if not is_infinite(R_i_) else G + assert R_i is not None aggnonce += cbytes(R_i) return aggnonce SessionContext = namedtuple('SessionContext', ['aggnonce', 'pubkeys', 'tweaks', 'is_xonly', 'msg']) -def get_session_values(session_ctx: SessionContext) -> tuple[bytes, List[bytes], bytes]: +def get_session_values(session_ctx: SessionContext) -> tuple[Point, int, int, int, Point, int]: (aggnonce, pubkeys, tweaks, is_xonly, msg) = session_ctx Q, gacc_v, tacc_v = key_agg_internal(pubkeys, tweaks, is_xonly) b = int_from_bytes(tagged_hash('MuSig/noncecoef', aggnonce + bytes_from_point(Q) + msg)) % n R_1 = pointc(aggnonce[0:33]) R_2 = pointc(aggnonce[33:66]) R = point_add(R_1, point_mul(R_2, b)) - assert not is_infinite(R) + assert R is not None e = int_from_bytes(tagged_hash('BIP0340/challenge', bytes_from_point(R) + bytes_from_point(Q) + msg)) % n return (Q, gacc_v, tacc_v, b, R, e) @@ -248,13 +254,18 @@ def sign(secnonce: bytes, sk: bytes, session_ctx: SessionContext) -> bytes: d_ = int_from_bytes(sk) assert 0 < d_ < n P = point_mul(G, d_) + assert P is not None a = get_session_key_agg_coeff(session_ctx, P) gp = 1 if has_even_y(P) else n - 1 g_v = 1 if has_even_y(Q) else n - 1 d = g_v * gacc_v * gp * d_ % n s = (k_1 + b * k_2 + e * a * d) % n psig = bytes_from_int(s) - pubnonce = cbytes(point_mul(G, k_1_)) + cbytes(point_mul(G, k_2_)) + R_1_ = point_mul(G, k_1_) + R_2_ = point_mul(G, k_2_) + assert R_1_ is not None + assert R_2_ is not None + pubnonce = cbytes(R_1_) + cbytes(R_2_) assert partial_sig_verify_internal(psig, pubnonce, bytes_from_point(P), session_ctx) return psig @@ -274,6 +285,7 @@ def partial_sig_verify_internal(psig: bytes, pubnonce: bytes, pk_: bytes, sessio g_v = 1 if has_even_y(Q) else n - 1 g_ = g_v * gacc_v % n P = point_mul(lift_x(pk_), g_) + assert P is not None a = get_session_key_agg_coeff(session_ctx, P) return point_mul(G, s) == point_add(R_, point_mul(P, e * a % n))