Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.
/ corefx Public archive

Socket.Select: increase ref count while the handle is in use #41763

Merged
merged 4 commits into from
Oct 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions src/System.Net.Sockets/src/System/Net/Sockets/Socket.Windows.cs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ internal unsafe bool TransmitPackets(SafeSocketHandle socketHandle, IntPtr packe
return transmitPackets(socketHandle, packetArray, elementCount, sendSize, overlapped, flags);
}

internal static void SocketListToFileDescriptorSet(IList socketList, Span<IntPtr> fileDescriptorSet)
internal static void SocketListToFileDescriptorSet(IList socketList, Span<IntPtr> fileDescriptorSet, ref int refsAdded)
{
int count;
if (socketList == null || (count = socketList.Count) == 0)
Expand All @@ -166,18 +166,21 @@ internal static void SocketListToFileDescriptorSet(IList socketList, Span<IntPtr
fileDescriptorSet[0] = (IntPtr)count;
for (int current = 0; current < count; current++)
{
if (!(socketList[current] is Socket))
if (!(socketList[current] is Socket socket))
{
throw new ArgumentException(SR.Format(SR.net_sockets_select, socketList[current].GetType().FullName, typeof(System.Net.Sockets.Socket).FullName), nameof(socketList));
}

fileDescriptorSet[current + 1] = ((Socket)socketList[current])._handle.DangerousGetHandle();
bool success = false;
socket.InternalSafeHandle.DangerousAddRef(ref success);
fileDescriptorSet[current + 1] = socket.InternalSafeHandle.DangerousGetHandle();
refsAdded++;
}
}

// Transform the list socketList such that the only sockets left are those
// with a file descriptor contained in the array "fileDescriptorArray".
internal static void SelectFileDescriptor(IList socketList, Span<IntPtr> fileDescriptorSet)
internal static void SelectFileDescriptor(IList socketList, Span<IntPtr> fileDescriptorSet, ref int refsAdded)
{
// Walk the list in order.
//
Expand All @@ -195,6 +198,9 @@ internal static void SelectFileDescriptor(IList socketList, Span<IntPtr> fileDes
int returnedCount = (int)fileDescriptorSet[0];
if (returnedCount == 0)
{
// Unref safehandles.
SocketListDangerousReleaseRefs(socketList, ref refsAdded);

// No socket present, will never find any socket, remove them all.
socketList.Clear();
return;
Expand All @@ -219,6 +225,8 @@ internal static void SelectFileDescriptor(IList socketList, Span<IntPtr> fileDes
if (currentFileDescriptor == returnedCount)
{
// Descriptor not found: remove the current socket and start again.
socket.InternalSafeHandle.DangerousRelease();
refsAdded--;
socketList.RemoveAt(currentSocket--);
count--;
}
Expand Down
15 changes: 15 additions & 0 deletions src/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5070,6 +5070,21 @@ private void ThrowIfDisposed()

private bool IsConnectionOriented => _socketType == SocketType.Stream;

internal static void SocketListDangerousReleaseRefs(IList socketList, ref int refsAdded)
{
if (socketList == null)
{
return;
}

for (int i = 0; (i < socketList.Count) && (refsAdded > 0); i++)
{
Socket socket = (Socket)socketList[i];
socket.InternalSafeHandle.DangerousRelease();
refsAdded--;
}
tmds marked this conversation as resolved.
Show resolved Hide resolved
}

#endregion
}
}
77 changes: 52 additions & 25 deletions src/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1444,37 +1444,58 @@ private static unsafe SocketError SelectViaPoll(
// Add each of the list's contents to the events array
Debug.Assert(eventsLength == checkReadInitialCount + checkWriteInitialCount + checkErrorInitialCount, "Invalid eventsLength");
int offset = 0;
AddToPollArray(events, eventsLength, checkRead, ref offset, Interop.Sys.PollEvents.POLLIN | Interop.Sys.PollEvents.POLLHUP);
AddToPollArray(events, eventsLength, checkWrite, ref offset, Interop.Sys.PollEvents.POLLOUT);
AddToPollArray(events, eventsLength, checkError, ref offset, Interop.Sys.PollEvents.POLLPRI);
Debug.Assert(offset == eventsLength, $"Invalid adds. offset={offset}, eventsLength={eventsLength}.");

// Do the poll
uint triggered = 0;
int milliseconds = microseconds == -1 ? -1 : microseconds / 1000;
Interop.Error err = Interop.Sys.Poll(events, (uint)eventsLength, milliseconds, &triggered);
if (err != Interop.Error.SUCCESS)
int refsAdded = 0;
try
{
return GetSocketErrorForErrorCode(err);
}
// In case we can't increase the reference count for each Socket,
// we'll unref refAdded Sockets in the finally block ordered: [checkRead, checkWrite, checkError].
AddToPollArray(events, eventsLength, checkRead, ref offset, Interop.Sys.PollEvents.POLLIN | Interop.Sys.PollEvents.POLLHUP, ref refsAdded);
AddToPollArray(events, eventsLength, checkWrite, ref offset, Interop.Sys.PollEvents.POLLOUT, ref refsAdded);
AddToPollArray(events, eventsLength, checkError, ref offset, Interop.Sys.PollEvents.POLLPRI, ref refsAdded);
Debug.Assert(offset == eventsLength, $"Invalid adds. offset={offset}, eventsLength={eventsLength}.");
Debug.Assert(refsAdded == eventsLength, $"Invalid ref adds. refsAdded={refsAdded}, eventsLength={eventsLength}.");

// Do the poll
uint triggered = 0;
int milliseconds = microseconds == -1 ? -1 : microseconds / 1000;
Interop.Error err = Interop.Sys.Poll(events, (uint)eventsLength, milliseconds, &triggered);
if (err != Interop.Error.SUCCESS)
{
return GetSocketErrorForErrorCode(err);
}

// Remove from the lists any entries which weren't set
if (triggered == 0)
{
checkRead?.Clear();
checkWrite?.Clear();
checkError?.Clear();
// Remove from the lists any entries which weren't set
if (triggered == 0)
{
Socket.SocketListDangerousReleaseRefs(checkRead, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkWrite, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkError, ref refsAdded);

checkRead?.Clear();
checkWrite?.Clear();
checkError?.Clear();
}
else
{
FilterPollList(checkRead, events, checkReadInitialCount - 1, Interop.Sys.PollEvents.POLLIN | Interop.Sys.PollEvents.POLLHUP, ref refsAdded);
FilterPollList(checkWrite, events, checkWriteInitialCount + checkReadInitialCount - 1, Interop.Sys.PollEvents.POLLOUT, ref refsAdded);
FilterPollList(checkError, events, checkErrorInitialCount + checkWriteInitialCount + checkReadInitialCount - 1, Interop.Sys.PollEvents.POLLERR | Interop.Sys.PollEvents.POLLPRI, ref refsAdded);
}

return SocketError.Success;
}
else
finally
{
FilterPollList(checkRead, events, checkReadInitialCount - 1, Interop.Sys.PollEvents.POLLIN | Interop.Sys.PollEvents.POLLHUP);
FilterPollList(checkWrite, events, checkWriteInitialCount + checkReadInitialCount - 1, Interop.Sys.PollEvents.POLLOUT);
FilterPollList(checkError, events, checkErrorInitialCount + checkWriteInitialCount + checkReadInitialCount - 1, Interop.Sys.PollEvents.POLLERR | Interop.Sys.PollEvents.POLLPRI);
// This order matches with the AddToPollArray calls
// to release only the handles that were ref'd.
Socket.SocketListDangerousReleaseRefs(checkRead, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkWrite, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkError, ref refsAdded);
Debug.Assert(refsAdded == 0);
}
return SocketError.Success;
}

private static unsafe void AddToPollArray(Interop.Sys.PollEvent* arr, int arrLength, IList socketList, ref int arrOffset, Interop.Sys.PollEvents events)
private static unsafe void AddToPollArray(Interop.Sys.PollEvent* arr, int arrLength, IList socketList, ref int arrOffset, Interop.Sys.PollEvents events, ref int refsAdded)
{
if (socketList == null)
return;
Expand All @@ -1494,12 +1515,15 @@ private static unsafe void AddToPollArray(Interop.Sys.PollEvent* arr, int arrLen
throw new ArgumentException(SR.Format(SR.net_sockets_select, socket?.GetType().FullName ?? "null", typeof(Socket).FullName), nameof(socketList));
}

bool success = false;
socket.InternalSafeHandle.DangerousAddRef(ref success);
int fd = (int)socket.InternalSafeHandle.DangerousGetHandle();
arr[arrOffset++] = new Interop.Sys.PollEvent { Events = events, FileDescriptor = fd };
refsAdded++;
}
}

private static unsafe void FilterPollList(IList socketList, Interop.Sys.PollEvent* arr, int arrEndOffset, Interop.Sys.PollEvents desiredEvents)
private static unsafe void FilterPollList(IList socketList, Interop.Sys.PollEvent* arr, int arrEndOffset, Interop.Sys.PollEvents desiredEvents, ref int refsAdded)
{
if (socketList == null)
return;
Expand All @@ -1525,6 +1549,9 @@ private static unsafe void FilterPollList(IList socketList, Interop.Sys.PollEven

if ((arr[arrEndOffset].TriggeredEvents & desiredEvents) == 0)
{
Socket socket = (Socket)socketList[i];
socket.InternalSafeHandle.DangerousRelease();
refsAdded--;
socketList.RemoveAt(i);
}
}
Expand Down
23 changes: 17 additions & 6 deletions src/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs
Original file line number Diff line number Diff line change
Expand Up @@ -814,14 +814,17 @@ bool ShouldStackAlloc(IList list, ref IntPtr[] lease, out Span<IntPtr> span)
}

IntPtr[] leaseRead = null, leaseWrite = null, leaseError = null;
int refsAdded = 0;
try
{
// In case we can't increase the reference count for each Socket,
// we'll unref refAdded Sockets in the finally block ordered: [checkRead, checkWrite, checkError].
Span<IntPtr> readfileDescriptorSet = ShouldStackAlloc(checkRead, ref leaseRead, out var tmp) ? stackalloc IntPtr[StackThreshold] : tmp;
Socket.SocketListToFileDescriptorSet(checkRead, readfileDescriptorSet);
Socket.SocketListToFileDescriptorSet(checkRead, readfileDescriptorSet, ref refsAdded);
Span<IntPtr> writefileDescriptorSet = ShouldStackAlloc(checkWrite, ref leaseWrite, out tmp) ? stackalloc IntPtr[StackThreshold] : tmp;
Socket.SocketListToFileDescriptorSet(checkWrite, writefileDescriptorSet);
Socket.SocketListToFileDescriptorSet(checkWrite, writefileDescriptorSet, ref refsAdded);
Span<IntPtr> errfileDescriptorSet = ShouldStackAlloc(checkError, ref leaseError, out tmp) ? stackalloc IntPtr[StackThreshold] : tmp;
Socket.SocketListToFileDescriptorSet(checkError, errfileDescriptorSet);
Socket.SocketListToFileDescriptorSet(checkError, errfileDescriptorSet, ref refsAdded);

// This code used to erroneously pass a non-null timeval structure containing zeroes
// to select() when the caller specified (-1) for the microseconds parameter. That
Expand Down Expand Up @@ -872,9 +875,10 @@ bool ShouldStackAlloc(IList list, ref IntPtr[] lease, out Span<IntPtr> span)
return GetLastSocketError();
}

Socket.SelectFileDescriptor(checkRead, readfileDescriptorSet);
Socket.SelectFileDescriptor(checkWrite, writefileDescriptorSet);
Socket.SelectFileDescriptor(checkError, errfileDescriptorSet);
// Remove from the lists any entries which weren't set
Socket.SelectFileDescriptor(checkRead, readfileDescriptorSet, ref refsAdded);
Socket.SelectFileDescriptor(checkWrite, writefileDescriptorSet, ref refsAdded);
Socket.SelectFileDescriptor(checkError, errfileDescriptorSet, ref refsAdded);

return SocketError.Success;
}
Expand All @@ -883,6 +887,13 @@ bool ShouldStackAlloc(IList list, ref IntPtr[] lease, out Span<IntPtr> span)
if (leaseRead != null) ArrayPool<IntPtr>.Shared.Return(leaseRead);
if (leaseWrite != null) ArrayPool<IntPtr>.Shared.Return(leaseWrite);
if (leaseError != null) ArrayPool<IntPtr>.Shared.Return(leaseError);

// This order matches with the AddToPollArray calls
// to release only the handles that were ref'd.
Socket.SocketListDangerousReleaseRefs(checkRead, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkWrite, ref refsAdded);
Socket.SocketListDangerousReleaseRefs(checkError, ref refsAdded);
Debug.Assert(refsAdded == 0);
}
}

Expand Down