diff --git a/doc/musig-reference.py b/doc/musig-reference.py index 758574d5b..63f92bbea 100644 --- a/doc/musig-reference.py +++ b/doc/musig-reference.py @@ -16,7 +16,7 @@ # represented by the None keyword. G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798, 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8) -Point = Tuple[int, int] +Point = Optional[Tuple[int, int]] # This implementation can be sped up by storing the midstate after hashing # tag_hash instead of rehashing it all the time. @@ -24,18 +24,18 @@ def tagged_hash(tag: str, msg: bytes) -> bytes: tag_hash = hashlib.sha256(tag.encode()).digest() return hashlib.sha256(tag_hash + tag_hash + msg).digest() -def is_infinite(P: Optional[Point]) -> bool: +def is_infinite(P: Point) -> bool: return P is None def x(P: Point) -> int: - assert not is_infinite(P) + assert P is not None return P[0] def y(P: Point) -> int: - assert not is_infinite(P) + assert P is not None return P[1] -def point_add(P1: Optional[Point], P2: Optional[Point]) -> Optional[Point]: +def point_add(P1: Point, P2: Point) -> Point: if P1 is None: return P2 if P2 is None: @@ -49,7 +49,7 @@ def point_add(P1: Optional[Point], P2: Optional[Point]) -> Optional[Point]: x3 = (lam * lam - x(P1) - x(P2)) % p return (x3, (lam * (x(P1) - x3) - y(P1)) % p) -def point_mul(P: Optional[Point], n: int) -> Optional[Point]: +def point_mul(P: Point, n: int) -> Point: R = None for i in range(256): if (n >> i) & 1: @@ -63,7 +63,7 @@ def bytes_from_int(x: int) -> bytes: def bytes_from_point(P: Point) -> bytes: return bytes_from_int(x(P)) -def lift_x(b: bytes) -> Optional[Point]: +def lift_x(b: bytes) -> Point: x = int_from_bytes(b) if x >= p: return None @@ -125,7 +125,7 @@ def key_agg(pubkeys: List[bytes], tweaks: List[bytes], is_xonly: List[bool]) -> Q, _, _ = key_agg_internal(pubkeys, tweaks, is_xonly) return bytes_from_point(Q) -def key_agg_internal(pubkeys: List[bytes], tweaks: List[bytes], is_xonly: List[bool]) -> Point: +def key_agg_internal(pubkeys: List[bytes], tweaks: List[bytes], is_xonly: List[bool]) -> Tuple[Point, int, int]: pk2 = get_second_key(pubkeys) u = len(pubkeys) Q = infinity @@ -177,7 +177,7 @@ def apply_tweak(Q: Point, gacc: int, tacc: int, tweak_i: bytes, is_xonly_i: bool def bytes_xor(a: bytes, b: bytes) -> bytes: return bytes(x ^ y for x, y in zip(a, b)) -def nonce_hash(rand: bytes, aggpk: bytes, i: int, msg: bytes, extra_in: bytes) -> bytes: +def nonce_hash(rand: bytes, aggpk: bytes, i: int, msg: bytes, extra_in: bytes) -> int: buf = b'' buf += rand buf += len(aggpk).to_bytes(1, 'big') @@ -221,7 +221,7 @@ def nonce_agg(pubnonces: List[bytes]) -> bytes: SessionContext = namedtuple('SessionContext', ['aggnonce', 'pubkeys', 'tweaks', 'is_xonly', 'msg']) -def get_session_values(session_ctx: SessionContext) -> tuple[bytes, List[bytes], bytes]: +def get_session_values(session_ctx: SessionContext) -> tuple[Point, int, int, int, Point, int]: (aggnonce, pubkeys, tweaks, is_xonly, msg) = session_ctx Q, gacc_v, tacc_v = key_agg_internal(pubkeys, tweaks, is_xonly) b = int_from_bytes(tagged_hash('MuSig/noncecoef', aggnonce + bytes_from_point(Q) + msg)) % n