Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented PING fully-featured #409

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 80 additions & 10 deletions libp2p/host/ping.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import logging
import math
import secrets
import time
from typing import Union

import trio

from libp2p.exceptions import ValidationError
from libp2p.host.host_interface import IHost
from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset
from libp2p.network.stream.net_stream_interface import INetStream
from libp2p.peer.id import ID as PeerID
Expand All @@ -14,6 +20,21 @@
logger = logging.getLogger("libp2p.host.ping")


async def handle_ping(stream: INetStream) -> None:
"""``handle_ping`` responds to incoming ping requests until one side errors
or closes the ``stream``."""
peer_id = stream.muxed_conn.peer_id

while True:
try:
should_continue = await _handle_ping(stream, peer_id)
if not should_continue:
return
except Exception:
await stream.reset()
return
Comment on lines +23 to +35
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any particular reason to move this chunk of code up here?

just curious.... at first pass on this review i figured you had changed some of the logic but it seems to be the same(!)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I'm sorry, I think this happened because at first I moved it inside PingService as a method. But at the end I left it out, and I think I moved it accidentally



async def _handle_ping(stream: INetStream, peer_id: PeerID) -> bool:
"""Return a boolean indicating if we expect more pings from the peer at
``peer_id``."""
Expand Down Expand Up @@ -45,16 +66,65 @@ async def _handle_ping(stream: INetStream, peer_id: PeerID) -> bool:
return True


async def handle_ping(stream: INetStream) -> None:
"""``handle_ping`` responds to incoming ping requests until one side errors
or closes the ``stream``."""
peer_id = stream.muxed_conn.peer_id
class PingService:
"""PingService executes pings and returns RTT in miliseconds."""

while True:
def __init__(self, host: IHost):
self._host = host

async def ping(self, peer_id: PeerID) -> int:
stream = await self._host.new_stream(peer_id, (ID,))
try:
should_continue = await _handle_ping(stream, peer_id)
if not should_continue:
return
rtt = await _ping(stream)
await _close_stream(stream)
return rtt
except Exception:
await stream.reset()
return
await _close_stream(stream)
raise

async def ping_loop(
self, peer_id: PeerID, ping_amount: Union[int, float] = math.inf
) -> "PingIterator":
stream = await self._host.new_stream(peer_id, (ID,))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be helpful to leave a docstring describing why we have ping and ping_loop.

it turns out that ping is just ping_loop specialized to ping_amount == 1. which also suggest we may be able to simplify the implementation here...

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, it is a good idea to write a docstring!

The reason why I thought of separating into two methods is because ping_loop is a way to use it as an intuitive for loop. However, it is not so straightforward, you must generate the iterator with ping_loop and then construct a for loop. So I created the second method ping for when you just want a quick, simple and one-line ping (maybe for a rapid keep-alive, or something else).

If am open to discussion whether it really is worth it or useful or not :)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also the loop is useful in case you want to add some more functionality between each ping. (or maybe it is not a good feature?)

However, I have been thinking and maybe it is a good idea to add a second argument to the ping(peer_id), which specifies the amount of pings and return a list of RTTs. Something like rtts = ping_service.ping(peer_id, amount=5).

ping_iterator = PingIterator(stream, ping_amount)
return ping_iterator


class PingIterator:
def __init__(self, stream: INetStream, ping_amount: Union[int, float]):
self._stream = stream
self._ping_limit = ping_amount
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think ping_limit is much clearer than ping_amount.

want to just use ping_limit everywhere?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like it! I'll change it to ping_limit

self._ping_counter = 0

def __aiter__(self) -> "PingIterator":
return self

async def __anext__(self) -> int:
if self._ping_counter > self._ping_limit:
await _close_stream(self._stream)
raise StopAsyncIteration

self._ping_counter += 1
try:
return await _ping(self._stream)
except trio.EndOfChannel:
await _close_stream(self._stream)
raise StopAsyncIteration


async def _ping(stream: INetStream) -> int:
ping_bytes = secrets.token_bytes(PING_LENGTH)
before = int(time.time() * 10 ** 6) # convert float of seconds to int miliseconds
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what do you think about just keeping the native time.time() value as a float for as long as possible (possibly just letting the consumer of the rtt convert to int if they want...)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First of all thank you for your comments. I learn a lot from them!

I think it is a good idea to keep it as float. Do you think it is better to leave it also as native seconds? I think I prefer to return it as a miliseconds float

await stream.write(ping_bytes)
pong_bytes = await stream.read(PING_LENGTH)
rtt = int(time.time() * 10 ** 6) - before
if ping_bytes != pong_bytes:
raise ValidationError("Invalid PING response")
return rtt


async def _close_stream(stream: INetStream) -> None:
try:
await stream.close()
except Exception:
pass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

worth at least logging this exception, if not letting bubble up to some caller

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you are right, it is better to let bubble up to the caller

2 changes: 1 addition & 1 deletion libp2p/pubsub/gossipsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ async def heartbeat(self) -> None:
await trio.sleep(self.heartbeat_interval)

def mesh_heartbeat(
self
self,
) -> Tuple[DefaultDict[ID, List[str]], DefaultDict[ID, List[str]]]:
peers_to_graft: DefaultDict[ID, List[str]] = defaultdict(list)
peers_to_prune: DefaultDict[ID, List[str]] = defaultdict(list)
Expand Down
10 changes: 5 additions & 5 deletions libp2p/tools/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def security_transport_factory(

@asynccontextmanager
async def raw_conn_factory(
nursery: trio.Nursery
nursery: trio.Nursery,
) -> AsyncIterator[Tuple[IRawConnection, IRawConnection]]:
conn_0 = None
conn_1 = None
Expand Down Expand Up @@ -351,7 +351,7 @@ async def swarm_pair_factory(

@asynccontextmanager
async def host_pair_factory(
is_secure: bool
is_secure: bool,
) -> AsyncIterator[Tuple[BasicHost, BasicHost]]:
async with HostFactory.create_batch_and_listen(is_secure, 2) as hosts:
await connect(hosts[0], hosts[1])
Expand All @@ -370,7 +370,7 @@ async def swarm_conn_pair_factory(

@asynccontextmanager
async def mplex_conn_pair_factory(
is_secure: bool
is_secure: bool,
) -> AsyncIterator[Tuple[Mplex, Mplex]]:
muxer_opt = {MPLEX_PROTOCOL_ID: Mplex}
async with swarm_conn_pair_factory(is_secure, muxer_opt=muxer_opt) as swarm_pair:
Expand All @@ -382,7 +382,7 @@ async def mplex_conn_pair_factory(

@asynccontextmanager
async def mplex_stream_pair_factory(
is_secure: bool
is_secure: bool,
) -> AsyncIterator[Tuple[MplexStream, MplexStream]]:
async with mplex_conn_pair_factory(is_secure) as mplex_conn_pair_info:
mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info
Expand All @@ -398,7 +398,7 @@ async def mplex_stream_pair_factory(

@asynccontextmanager
async def net_stream_pair_factory(
is_secure: bool
is_secure: bool,
) -> AsyncIterator[Tuple[INetStream, INetStream]]:
protocol_id = TProtocol("/example/id/1")

Expand Down
2 changes: 1 addition & 1 deletion libp2p/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def connect(node1: IHost, node2: IHost) -> None:


def create_echo_stream_handler(
ack_prefix: str
ack_prefix: str,
) -> Callable[[INetStream], Awaitable[None]]:
async def echo_stream_handler(stream: INetStream) -> None:
while True:
Expand Down
31 changes: 30 additions & 1 deletion tests/host/test_ping.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
import trio

from libp2p.host.ping import ID, PING_LENGTH
from libp2p.host.ping import ID, PING_LENGTH, PingService
from libp2p.tools.factories import host_pair_factory


Expand Down Expand Up @@ -36,3 +36,32 @@ async def test_ping_several(is_host_secure):
# NOTE: this interval can be `0` for this test.
await trio.sleep(0)
await stream.close()


@pytest.mark.trio
async def test_ping_service_once(is_host_secure):
async with host_pair_factory(is_host_secure) as (host_a, host_b):
ping_service = PingService(host_b)
rtt = await ping_service.ping(host_a.get_id())
assert rtt < 10 ** 6 # rtt is in miliseconds


@pytest.mark.trio
async def test_ping_service_loop(is_host_secure):
async with host_pair_factory(is_host_secure) as (host_a, host_b):
ping_service = PingService(host_b)
ping_loop = await ping_service.ping_loop(
host_a.get_id(), ping_amount=SOME_PING_COUNT
)
async for rtt in ping_loop:
assert rtt < 10 ** 6


@pytest.mark.trio
async def test_ping_service_loop_infinite(is_host_secure):
async with host_pair_factory(is_host_secure) as (host_a, host_b):
ping_service = PingService(host_b)
ping_loop = await ping_service.ping_loop(host_a.get_id())
with trio.move_on_after(1): # breaking loop after one second
async for rtt in ping_loop:
assert rtt < 10 ** 6