Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing the memoryview issues #2926

Merged
merged 8 commits on Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions nvflare/fuel/f3/streaming/blob_streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,17 @@ def _read_stream(blob_task: BlobTask):
length = len(buf)
try:
if blob_task.pre_allocated:
blob_task.buffer[buf_size : buf_size + length] = buf
remaining = len(blob_task.buffer) - buf_size
if length > remaining:
log.error(f"Buffer overrun: {remaining=} {length=} {buf_size=}")
if remaining > 0:
blob_task.buffer[buf_size : buf_size + remaining] = buf[0:remaining]
nvidianz marked this conversation as resolved.
Show resolved Hide resolved
else:
blob_task.buffer[buf_size : buf_size + length] = buf
else:
blob_task.buffer.append(buf)
except Exception as ex:
log.error(
f"memory view error: {ex} "
f"Debug info: {length=} {buf_size=} {len(blob_task.pre_allocated)=} {type(buf)=}"
)
log.error(f"memory view error: {ex} Debug info: {length=} {buf_size=} {type(buf)=}")
raise ex

buf_size += length
Expand Down
55 changes: 39 additions & 16 deletions nvflare/fuel/f3/streaming/byte_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import logging
import threading
from collections import deque
from typing import Callable, Dict, Tuple
from typing import Callable, Dict, Optional, Tuple

from nvflare.fuel.f3.cellnet.core_cell import CoreCell
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey
Expand All @@ -41,6 +41,9 @@
ACK_INTERVAL = 1024 * 1024 * 4
READ_TIMEOUT = 300
COUNTER_NAME_RECEIVED = "received"
RESULT_DATA = 0
RESULT_WAIT = 1
RESULT_EOS = 2


class RxTask:
Expand Down Expand Up @@ -78,30 +81,44 @@ def __init__(self, byte_receiver: "ByteReceiver", task: RxTask):
super().__init__(task.size, task.headers)
self.byte_receiver = byte_receiver
self.task = task
self.timeout = CommConfigurator().get_streaming_read_timeout(READ_TIMEOUT)
self.ack_interval = CommConfigurator().get_streaming_ack_interval(ACK_INTERVAL)

def read(self, chunk_size: int) -> bytes:
if self.closed:
raise StreamError("Read from closed stream")

if (not self.task.buffers) and self.task.eos:
return EOS

# Block if buffers are empty
count = 0
while not self.task.buffers:
while True:
result_code, result = self._read_chunk(chunk_size)
if result_code == RESULT_EOS:
return EOS
elif result_code == RESULT_DATA:
return result

# Block if buffers are empty
if count > 0:
log.debug(f"Read block is unblocked multiple times: {count}")
log.warning(f"Read block is unblocked multiple times: {count}")

self.task.waiter.clear()
timeout = CommConfigurator().get_streaming_read_timeout(READ_TIMEOUT)
if not self.task.waiter.wait(timeout):
error = StreamError(f"{self.task} read timed out after {timeout} seconds")

if not self.task.waiter.wait(self.timeout):
error = StreamError(f"{self.task} read timed out after {self.timeout} seconds")
self.byte_receiver.stop_task(self.task, error)
raise error

count += 1

def _read_chunk(self, chunk_size: int) -> Tuple[int, Optional[BytesAlike]]:

with self.task.task_lock:

if not self.task.buffers:
if self.task.eos:
return RESULT_EOS, None
else:
return RESULT_WAIT, None

last_chunk, buf = self.task.buffers.popleft()
if buf is None:
buf = bytes(0)
Expand All @@ -117,8 +134,7 @@ def read(self, chunk_size: int) -> bytes:

self.task.offset += len(result)

ack_interval = CommConfigurator().get_streaming_ack_interval(ACK_INTERVAL)
if not self.task.last_chunk_received and (self.task.offset - self.task.offset_ack > ack_interval):
if not self.task.last_chunk_received and (self.task.offset - self.task.offset_ack > self.ack_interval):
# Send ACK
message = Message()
message.add_headers(
Expand All @@ -133,7 +149,7 @@ def read(self, chunk_size: int) -> bytes:

self.task.stream_future.set_progress(self.task.offset)

return result
return RESULT_DATA, result

def close(self):
if not self.task.stream_future.done():
Expand All @@ -148,6 +164,7 @@ def __init__(self, cell: CoreCell):
self.registry = Registry()
self.rx_task_map = {}
self.map_lock = threading.Lock()
self.max_out_seq = CommConfigurator().get_streaming_max_out_seq_chunks(MAX_OUT_SEQ_CHUNKS)

self.received_stream_counter_pool = StatsPoolManager.add_counter_pool(
name="Received_Stream_Counters",
Expand Down Expand Up @@ -254,6 +271,10 @@ def _data_handler(self, message: Message):
if last_chunk:
task.last_chunk_received = True

if seq < task.next_seq:
log.warning(f"{task} Duplicate chunk ignored {seq=}")
return

if seq == task.next_seq:
self._append(task, (last_chunk, payload))
task.next_seq += 1
Expand All @@ -266,8 +287,7 @@ def _data_handler(self, message: Message):

else:
# Out-of-seq chunk reassembly
max_out_seq = CommConfigurator().get_streaming_max_out_seq_chunks(MAX_OUT_SEQ_CHUNKS)
if len(task.out_seq_buffers) >= max_out_seq:
if len(task.out_seq_buffers) >= self.max_out_seq:
self.stop_task(task, StreamError(f"Too many out-of-sequence chunks: {len(task.out_seq_buffers)}"))
return
else:
Expand All @@ -294,7 +314,10 @@ def _append(task: RxTask, buf: Tuple[bool, BytesAlike]):
if not buf:
return

task.buffers.append(buf)
if task.eos:
log.error(f"{task} Data after EOS is ignored")
else:
task.buffers.append(buf)

# Wake up blocking read()
if not task.waiter.is_set():
Expand Down
Loading