Skip to content

Commit

Permalink
[Android] Add UnmanagedCallersOnly attribute to SafeDeleteSslContext.…
Browse files Browse the repository at this point in the history
…ReadFromConnection/WriteToConnection methods (#69507)
  • Loading branch information
simonrozsival committed May 31, 2022
1 parent 2fbc50f commit 6010dc0
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@ internal static partial class AndroidCrypto
{
private const int UNSUPPORTED_API_LEVEL = 2;

internal unsafe delegate PAL_SSLStreamStatus SSLReadCallback(byte* data, int* length);
internal unsafe delegate void SSLWriteCallback(byte* data, int length);

internal enum PAL_SSLStreamStatus
{
OK = 0,
Expand Down Expand Up @@ -52,20 +49,22 @@ ref MemoryMarshal.GetReference(pkcs8PrivateKey),
}

[LibraryImport(Interop.Libraries.AndroidCryptoNative, EntryPoint = "AndroidCryptoNative_SSLStreamInitialize")]
private static partial int SSLStreamInitializeImpl(
private static unsafe partial int SSLStreamInitializeImpl(
SafeSslHandle sslHandle,
[MarshalAs(UnmanagedType.U1)] bool isServer,
SSLReadCallback streamRead,
SSLWriteCallback streamWrite,
IntPtr managedContextHandle,
delegate* unmanaged<IntPtr, byte*, int*, PAL_SSLStreamStatus> streamRead,
delegate* unmanaged<IntPtr, byte*, int, void> streamWrite,
int appBufferSize);
internal static void SSLStreamInitialize(
internal static unsafe void SSLStreamInitialize(
SafeSslHandle sslHandle,
bool isServer,
SSLReadCallback streamRead,
SSLWriteCallback streamWrite,
IntPtr managedContextHandle,
delegate* unmanaged<IntPtr, byte*, int*, PAL_SSLStreamStatus> streamRead,
delegate* unmanaged<IntPtr, byte*, int, void> streamWrite,
int appBufferSize)
{
int ret = SSLStreamInitializeImpl(sslHandle, isServer, streamRead, streamWrite, appBufferSize);
int ret = SSLStreamInitializeImpl(sslHandle, isServer, managedContextHandle, streamRead, streamWrite, appBufferSize);
if (ret != SUCCESS)
throw new SslException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Net.Security;
using System.Runtime.InteropServices;
using System.Security.Authentication;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
Expand Down Expand Up @@ -31,8 +32,6 @@ internal sealed class SafeDeleteSslContext : SafeDeleteContext
private static readonly Lazy<SslProtocols> s_supportedSslProtocols = new Lazy<SslProtocols>(Interop.AndroidCrypto.SSLGetSupportedProtocols);

private readonly SafeSslHandle _sslContext;
private readonly Interop.AndroidCrypto.SSLReadCallback _readCallback;
private readonly Interop.AndroidCrypto.SSLWriteCallback _writeCallback;

private ArrayBuffer _inputBuffer = new ArrayBuffer(InitialBufferSize);
private ArrayBuffer _outputBuffer = new ArrayBuffer(InitialBufferSize);
Expand All @@ -46,14 +45,8 @@ public SafeDeleteSslContext(SafeFreeSslCredentials credential, SslAuthentication

try
{
unsafe
{
_readCallback = ReadFromConnection;
_writeCallback = WriteToConnection;
}

_sslContext = CreateSslContext(credential);
InitializeSslContext(_sslContext, _readCallback, _writeCallback, credential, authOptions);
InitializeSslContext(_sslContext, credential, authOptions);
}
catch (Exception ex)
{
Expand Down Expand Up @@ -81,31 +74,39 @@ protected override void Dispose(bool disposing)
base.Dispose(disposing);
}

private unsafe void WriteToConnection(byte* data, int dataLength)
[UnmanagedCallersOnly]
private static unsafe void WriteToConnection(IntPtr connection, byte* data, int dataLength)
{
SafeDeleteSslContext? context = (SafeDeleteSslContext?)GCHandle.FromIntPtr(connection).Target;
Debug.Assert(context != null);

var inputBuffer = new ReadOnlySpan<byte>(data, dataLength);

_outputBuffer.EnsureAvailableSpace(dataLength);
inputBuffer.CopyTo(_outputBuffer.AvailableSpan);
_outputBuffer.Commit(dataLength);
context._outputBuffer.EnsureAvailableSpace(dataLength);
inputBuffer.CopyTo(context._outputBuffer.AvailableSpan);
context._outputBuffer.Commit(dataLength);
}

private unsafe PAL_SSLStreamStatus ReadFromConnection(byte* data, int* dataLength)
[UnmanagedCallersOnly]
private static unsafe PAL_SSLStreamStatus ReadFromConnection(IntPtr connection, byte* data, int* dataLength)
{
SafeDeleteSslContext? context = (SafeDeleteSslContext?)GCHandle.FromIntPtr(connection).Target;
Debug.Assert(context != null);

int toRead = *dataLength;
if (toRead == 0)
return PAL_SSLStreamStatus.OK;

if (_inputBuffer.ActiveLength == 0)
if (context._inputBuffer.ActiveLength == 0)
{
*dataLength = 0;
return PAL_SSLStreamStatus.NeedData;
}

toRead = Math.Min(toRead, _inputBuffer.ActiveLength);
toRead = Math.Min(toRead, context._inputBuffer.ActiveLength);

_inputBuffer.ActiveSpan.Slice(0, toRead).CopyTo(new Span<byte>(data, toRead));
_inputBuffer.Discard(toRead);
context._inputBuffer.ActiveSpan.Slice(0, toRead).CopyTo(new Span<byte>(data, toRead));
context._inputBuffer.Discard(toRead);

*dataLength = toRead;
return PAL_SSLStreamStatus.OK;
Expand Down Expand Up @@ -198,10 +199,8 @@ private static AsymmetricAlgorithm GetPrivateKeyAlgorithm(X509Certificate2 cert,
throw new NotSupportedException(SR.net_ssl_io_no_server_cert);
}

private static void InitializeSslContext(
private unsafe void InitializeSslContext(
SafeSslHandle handle,
Interop.AndroidCrypto.SSLReadCallback readCallback,
Interop.AndroidCrypto.SSLWriteCallback writeCallback,
SafeFreeSslCredentials credential,
SslAuthenticationOptions authOptions)
{
Expand All @@ -224,7 +223,10 @@ private static void InitializeSslContext(
throw new NotImplementedException(nameof(SafeDeleteSslContext));
}

Interop.AndroidCrypto.SSLStreamInitialize(handle, isServer, readCallback, writeCallback, InitialBufferSize);
// Make sure the class instance is associated to the session and is provided
// in the Read/Write callback connection parameter
IntPtr managedContextHandle = GCHandle.ToIntPtr(GCHandle.Alloc(this, GCHandleType.Weak));
Interop.AndroidCrypto.SSLStreamInitialize(handle, isServer, managedContextHandle, &ReadFromConnection, &WriteToConnection, InitialBufferSize);

if (credential.Protocols != SslProtocols.None)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ ARGS_NON_NULL_ALL static PAL_SSLStreamStatus Flush(JNIEnv* env, SSLStream* sslSt

uint8_t* dataPtr = (uint8_t*)xmalloc((size_t)bufferLimit);
(*env)->GetByteArrayRegion(env, data, 0, bufferLimit, (jbyte*)dataPtr);
sslStream->streamWriter(dataPtr, bufferLimit);
sslStream->streamWriter(sslStream->managedContextHandle, dataPtr, bufferLimit);
free(dataPtr);

IGNORE_RETURN((*env)->CallObjectMethod(env, sslStream->netOutBuffer, g_ByteBufferCompact));
Expand Down Expand Up @@ -177,7 +177,8 @@ ARGS_NON_NULL_ALL static PAL_SSLStreamStatus DoUnwrap(JNIEnv* env, SSLStream* ss
jbyteArray tmp = make_java_byte_array(env, netInBufferLimit);
uint8_t* tmpNative = (uint8_t*)xmalloc((size_t)netInBufferLimit);
int count = netInBufferLimit;
PAL_SSLStreamStatus status = sslStream->streamReader(tmpNative, &count);
// todo assert streamReader != 0 ?
PAL_SSLStreamStatus status = sslStream->streamReader(sslStream->managedContextHandle, tmpNative, &count);
if (status != SSLStreamStatus_OK)
{
(*env)->DeleteLocalRef(env, tmp);
Expand Down Expand Up @@ -424,7 +425,7 @@ SSLStream* AndroidCryptoNative_SSLStreamCreateWithCertificates(uint8_t* pkcs8Pri
}

int32_t AndroidCryptoNative_SSLStreamInitialize(
SSLStream* sslStream, bool isServer, STREAM_READER streamReader, STREAM_WRITER streamWriter, int32_t appBufferSize)
SSLStream* sslStream, bool isServer, ManagedContextHandle managedContextHandle, STREAM_READER streamReader, STREAM_WRITER streamWriter, int32_t appBufferSize)
{
abort_if_invalid_pointer_argument (sslStream);
abort_unless(sslStream->sslContext != NULL, "sslContext is NULL in SSL stream");
Expand Down Expand Up @@ -465,6 +466,7 @@ int32_t AndroidCryptoNative_SSLStreamInitialize(
sslStream->netInBuffer =
ToGRef(env, (*env)->CallStaticObjectMethod(env, g_ByteBuffer, g_ByteBufferAllocate, packetBufferSize));

sslStream->managedContextHandle = managedContextHandle;
sslStream->streamReader = streamReader;
sslStream->streamWriter = streamWriter;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@

#include <pal_ssl_types.h>

typedef void (*STREAM_WRITER)(uint8_t*, int32_t);
typedef int32_t (*STREAM_READER)(uint8_t*, int32_t*);
typedef intptr_t ManagedContextHandle;
typedef void (*STREAM_WRITER)(ManagedContextHandle, uint8_t*, int32_t);
typedef int32_t (*STREAM_READER)(ManagedContextHandle, uint8_t*, int32_t*);

typedef struct SSLStream
{
Expand All @@ -20,6 +21,7 @@ typedef struct SSLStream
jobject netOutBuffer;
jobject appInBuffer;
jobject netInBuffer;
ManagedContextHandle managedContextHandle;
STREAM_READER streamReader;
STREAM_WRITER streamWriter;
} SSLStream;
Expand Down Expand Up @@ -65,7 +67,7 @@ Initialize an SSL context
Returns 1 on success, 0 otherwise
*/
PALEXPORT int32_t AndroidCryptoNative_SSLStreamInitialize(
SSLStream* sslStream, bool isServer, STREAM_READER streamReader, STREAM_WRITER streamWriter, int32_t appBufferSize);
SSLStream* sslStream, bool isServer, ManagedContextHandle managedContextHandle, STREAM_READER streamReader, STREAM_WRITER streamWriter, int32_t appBufferSize);

/*
Set target host
Expand Down

0 comments on commit 6010dc0

Please sign in to comment.