Skip to content

Commit

Permalink
Base 64 decoder, reject input when unused bits are not 0 (#105262)
Browse files Browse the repository at this point in the history
* Base 64 decoder, reject input when unused bits are not 0

* Update invalid json test value, fix issue in Base64Url validation

* Apply feedbacks

* Update comments

* Apply suggestions from code review

Co-authored-by: Stephen Toub <stoub@microsoft.com>

---------

Co-authored-by: Stephen Toub <stoub@microsoft.com>
  • Loading branch information
buyaa-n and stephentoub committed Jul 26, 2024
1 parent 99c9f5b commit eb765b7
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 72 deletions.
37 changes: 27 additions & 10 deletions src/libraries/System.Memory/tests/Base64/Base64DecoderUnitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ public void DecodingOutputTooSmall()
{
Span<byte> source = new byte[12];
Base64TestHelper.InitializeDecodableBytes(source);
source[9] = 65; // make sure unused bits set to 0
source[10] = Base64TestHelper.EncodingPad;
source[11] = Base64TestHelper.EncodingPad;

Expand All @@ -193,6 +194,7 @@ public void DecodingOutputTooSmall()
{
Span<byte> source = new byte[12];
Base64TestHelper.InitializeDecodableBytes(source);
source[10] = 77; // make sure unused bits set to 0
source[11] = Base64TestHelper.EncodingPad;

Span<byte> decodedBytes = new byte[7];
Expand Down Expand Up @@ -270,6 +272,23 @@ public void BasicDecodingWithFinalBlockTrueKnownInputDone(string inputString, in
Assert.True(Base64TestHelper.VerifyDecodingCorrectness(expectedConsumed, expectedWritten, source, decodedBytes));
}

[Theory]
[InlineData("AR==")]
[InlineData("AQJ=")]
[InlineData("AQIDBB==")]
[InlineData("AQIDBAV=")]
[InlineData("AQIDBAUHCAkKCwwNDz==")]
[InlineData("AQIDBAUHCAkKCwwNDxD=")]
public void BasicDecodingWithNonZeroUnusedBits(string inputString)
{
Span<byte> source = Encoding.ASCII.GetBytes(inputString);
Span<byte> decodedBytes = new byte[Base64.GetMaxDecodedFromUtf8Length(source.Length)];

Assert.False(Base64.IsValid(inputString));
Assert.Equal(OperationStatus.InvalidData, Base64.DecodeFromUtf8(source, decodedBytes, out int _, out int _));
Assert.Equal(OperationStatus.InvalidData, Base64.DecodeFromUtf8InPlace(source, out int _));
}

[Theory]
[InlineData("A", 0, 0)]
[InlineData("A===", 0, 0)]
Expand Down Expand Up @@ -468,10 +487,9 @@ public void DecodingInvalidBytesPadding(bool isFinalBlock)

// The last byte or the last 2 bytes being the padding character is valid, if isFinalBlock = true
{
Span<byte> source = new byte[] { 50, 50, 50, 50, 80, 80, 80, 80 };
Span<byte> source = new byte[] { 50, 50, 50, 50, 80, 65,
Base64TestHelper.EncodingPad, Base64TestHelper.EncodingPad }; // valid input - "2222PA=="
Span<byte> decodedBytes = new byte[Base64.GetMaxDecodedFromUtf8Length(source.Length)];
source[6] = Base64TestHelper.EncodingPad;
source[7] = Base64TestHelper.EncodingPad; // valid input - "2222PP=="

OperationStatus expectedStatus = isFinalBlock ? OperationStatus.Done : OperationStatus.InvalidData;
int expectedConsumed = isFinalBlock ? source.Length : 4;
Expand All @@ -482,9 +500,9 @@ public void DecodingInvalidBytesPadding(bool isFinalBlock)
Assert.Equal(expectedWritten, decodedByteCount);
Assert.True(Base64TestHelper.VerifyDecodingCorrectness(expectedConsumed, expectedWritten, source, decodedBytes));

source = new byte[] { 50, 50, 50, 50, 80, 80, 80, 80 };
source = new byte[] { 50, 50, 50, 50, 80, 80, 77, 80 };
decodedBytes = new byte[Base64.GetMaxDecodedFromUtf8Length(source.Length)];
source[7] = Base64TestHelper.EncodingPad; // valid input - "2222PPP="
source[7] = Base64TestHelper.EncodingPad; // valid input - "2222PPM="

expectedConsumed = isFinalBlock ? source.Length : 4;
expectedWritten = isFinalBlock ? 5 : 3;
Expand Down Expand Up @@ -661,9 +679,8 @@ public void DecodeInPlaceInvalidBytesPadding()

// The last byte or the last 2 bytes being the padding character is valid
{
Span<byte> buffer = new byte[] { 50, 50, 50, 50, 80, 80, 80, 80 };
buffer[6] = Base64TestHelper.EncodingPad;
buffer[7] = Base64TestHelper.EncodingPad; // valid input - "2222PP=="
Span<byte> buffer = new byte[] { 50, 50, 50, 50, 80, 65,
Base64TestHelper.EncodingPad, Base64TestHelper.EncodingPad }; // valid input - "2222PA=="
string sourceString = Encoding.ASCII.GetString(buffer.ToArray());
Assert.Equal(OperationStatus.Done, Base64.DecodeFromUtf8InPlace(buffer, out int bytesWritten));
Assert.Equal(4, bytesWritten);
Expand All @@ -672,8 +689,8 @@ public void DecodeInPlaceInvalidBytesPadding()
}

{
Span<byte> buffer = new byte[] { 50, 50, 50, 50, 80, 80, 80, 80 };
buffer[7] = Base64TestHelper.EncodingPad; // valid input - "2222PPP="
Span<byte> buffer = new byte[] { 50, 50, 50, 50, 80, 80, 77, 80 };
buffer[7] = Base64TestHelper.EncodingPad; // valid input - "2222PPM="
string sourceString = Encoding.ASCII.GetString(buffer.ToArray());
Assert.Equal(OperationStatus.Done, Base64.DecodeFromUtf8InPlace(buffer, out int bytesWritten));
Assert.Equal(5, bytesWritten);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public void BasicDecoding()

Span<byte> source = new byte[numBytes];
Base64TestHelper.InitializeUrlDecodableBytes(source, numBytes);
source[numBytes - 1] = 65; // make sure unused bits set 0

Span<byte> decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)];
Assert.Equal(OperationStatus.Done, Base64Url.DecodeFromUtf8(source, decodedBytes, out int consumed, out int decodedByteCount));
Expand All @@ -46,6 +47,7 @@ public void BasicDecodingByteArrayReturnOverload()

Span<byte> source = new byte[numBytes];
Base64TestHelper.InitializeUrlDecodableBytes(source, numBytes);
source[numBytes - 1] = 65; // make sure unused bits set 0

Span<byte> decodedBytes = Base64Url.DecodeFromUtf8(source);
Assert.Equal(decodedBytes.Length, Base64Url.GetMaxDecodedLength(source.Length));
Expand Down Expand Up @@ -197,6 +199,7 @@ public void DecodingOutputTooSmall()
{
Span<byte> source = new byte[12];
Base64TestHelper.InitializeUrlDecodableBytes(source);
source[9] = 65; // make sure unused bits set 0
source[10] = Base64TestHelper.EncodingPad;
source[11] = Base64TestHelper.EncodingPad;

Expand All @@ -211,6 +214,7 @@ public void DecodingOutputTooSmall()
{
Span<byte> source = new byte[12];
Base64TestHelper.InitializeUrlDecodableBytes(source);
source[10] = 77; // make sure unused bits set 0
source[11] = Base64TestHelper.EncodingPad;

Span<byte> decodedBytes = new byte[7];
Expand Down Expand Up @@ -287,6 +291,23 @@ public void BasicDecodingWithFinalBlockTrueKnownInputDone(string inputString, in
Assert.True(Base64TestHelper.VerifyUrlDecodingCorrectness(inputString.Length, expectedWritten, source, decodedBytes));
}

[Theory]
[InlineData("AR")]
[InlineData("AQJ")]
[InlineData("AQIDBB%%")]
[InlineData("AQIDBAV%")]
[InlineData("AQIDBAUHCAkKCwwNDz")]
[InlineData("AQIDBAUHCAkKCwwNDxD")]
public void BasicDecodingWithNonZeroUnusedBits(string inputString)
{
byte[] source = Encoding.ASCII.GetBytes(inputString);
Span<byte> decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)];

Assert.False(Base64Url.IsValid(inputString.AsSpan()));
Assert.Equal(OperationStatus.InvalidData, Base64Url.DecodeFromUtf8(source, decodedBytes, out int _, out int _));
Assert.Throws<FormatException>(() => Base64Url.DecodeFromUtf8InPlace(source));
}

[Theory]
[InlineData("A", 0, 0, OperationStatus.InvalidData)]
[InlineData("A===", 0, 0, OperationStatus.InvalidData)]
Expand Down Expand Up @@ -460,7 +481,7 @@ public void DecodingInvalidBytes(bool isFinalBlock)
// When isFinalBlock = true input that is not a multiple of 4 is invalid for Base64, but valid for Base64Url
if (isFinalBlock)
{
Span<byte> source = "2222PPP"u8.ToArray(); // incomplete input
Span<byte> source = "2222PPM"u8.ToArray(); // incomplete input
Span<byte> decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)];
Assert.Equal(5, Base64Url.DecodeFromUtf8(source, decodedBytes));
Assert.True(Base64TestHelper.VerifyUrlDecodingCorrectness(7, 5, source, decodedBytes));
Expand Down Expand Up @@ -517,10 +538,9 @@ public void DecodingInvalidBytesPadding(bool isFinalBlock)

// The last byte or the last 2 bytes being the padding character is valid, if isFinalBlock = true
{
Span<byte> source = new byte[] { 50, 50, 50, 50, 80, 80, 80, 80 };
Span<byte> source = new byte[] { 50, 50, 50, 50, 80, 65,
Base64TestHelper.EncodingPad, Base64TestHelper.EncodingPad }; // valid input - "2222PA=="
Span<byte> decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)];
source[6] = Base64TestHelper.EncodingPad;
source[7] = Base64TestHelper.EncodingPad; // valid input - "2222PP=="

OperationStatus expectedStatus = isFinalBlock ? OperationStatus.Done : OperationStatus.InvalidData;
int expectedConsumed = isFinalBlock ? source.Length : 4;
Expand All @@ -531,9 +551,9 @@ public void DecodingInvalidBytesPadding(bool isFinalBlock)
Assert.Equal(expectedWritten, decodedByteCount);
Assert.True(Base64TestHelper.VerifyUrlDecodingCorrectness(expectedConsumed, expectedWritten, source, decodedBytes));

source = new byte[] { 50, 50, 50, 50, 80, 80, 80, 80 };
source = new byte[] { 50, 50, 50, 50, 80, 80, 77, 80 };
decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)];
source[7] = Base64TestHelper.UrlEncodingPad; // valid input - "2222PPP="
source[7] = Base64TestHelper.UrlEncodingPad; // valid input - "2222PPM="

expectedConsumed = isFinalBlock ? source.Length : 4;
expectedWritten = isFinalBlock ? 5 : 3;
Expand Down Expand Up @@ -685,9 +705,8 @@ public void DecodeInPlaceInvalidBytesPaddingThrowsFormatException()

// The last byte or the last 2 bytes being the padding character is valid
{
Span<byte> buffer = new byte[] { 50, 50, 50, 50, 80, 80, 80, 80 };
buffer[6] = Base64TestHelper.UrlEncodingPad;
buffer[7] = Base64TestHelper.EncodingPad; // valid input - "2222PP=="
Span<byte> buffer = new byte[] { 50, 50, 50, 50, 80, 65,
Base64TestHelper.UrlEncodingPad, Base64TestHelper.EncodingPad }; // valid input - "2222PA=="
string sourceString = Encoding.ASCII.GetString(buffer.ToArray());
int bytesWritten = Base64Url.DecodeFromUtf8InPlace(buffer);

Expand All @@ -696,8 +715,7 @@ public void DecodeInPlaceInvalidBytesPaddingThrowsFormatException()
}

{
Span<byte> buffer = new byte[] { 50, 50, 50, 50, 80, 80, 80, 80 };
buffer[7] = Base64TestHelper.EncodingPad; // valid input - "2222PPP="
Span<byte> buffer = new byte[] { 50, 50, 50, 50, 80, 80, 77, Base64TestHelper.EncodingPad }; // valid input - "2222PPM="
string sourceString = Encoding.ASCII.GetString(buffer.ToArray());
int bytesWritten = Base64Url.DecodeFromUtf8InPlace(buffer);

Expand All @@ -707,7 +725,7 @@ public void DecodeInPlaceInvalidBytesPaddingThrowsFormatException()

// The last byte or the last 2 bytes being the padding character is valid
{
Span<byte> buffer = new byte[] { 50, 50, 50, 50, 80, 80 }; // valid input without padding "2222PP"
Span<byte> buffer = new byte[] { 50, 50, 50, 50, 80, 65 }; // valid input without padding "2222PA"

string sourceString = Encoding.ASCII.GetString(buffer.ToArray());
int bytesWritten = Base64Url.DecodeFromUtf8InPlace(buffer);
Expand Down Expand Up @@ -775,12 +793,12 @@ public void DecodingInPlaceWithOnlyCharsToBeIgnored(string utf8WithCharsToBeIgno
}

[Theory]
[InlineData(new byte[] { 0xa, 0xa, 0x2d, 0x2d }, 251)]
[InlineData(new byte[] { 0xa, 0x5f, 0xa, 0x2d }, 255)]
[InlineData(new byte[] { 0x5f, 0x5f, 0xa, 0xa }, 255)]
[InlineData(new byte[] { 0x70, 0xa, 0x61, 0xa }, 165)]
[InlineData(new byte[] { 0xa, 0x70, 0xa, 0x61, 0xa }, 165)]
[InlineData(new byte[] { 0x70, 0xa, 0x61, 0xa, 0x3d, 0x3d }, 165)]
[InlineData(new byte[] { 0xa, 0xa, 0x2d, 0x77 }, 251)]
[InlineData(new byte[] { 0xa, 0x5f, 0xa, 0x77 }, 255)]
[InlineData(new byte[] { 0x5f, 0x77, 0xa, 0xa }, 255)]
[InlineData(new byte[] { 0x70, 0xa, 0x51, 0xa }, 165)]
[InlineData(new byte[] { 0xa, 0x70, 0xa, 0x51, 0xa }, 165)]
[InlineData(new byte[] { 0x70, 0xa, 0x51, 0xa, 0x3d, 0x3d }, 165)]
public void DecodingLessThan4BytesWithWhiteSpaces(byte[] utf8Bytes, byte decoded)
{
Assert.True(Base64Url.IsValid(utf8Bytes, out int decodedLength));
Expand All @@ -802,12 +820,12 @@ public void DecodingLessThan4BytesWithWhiteSpaces(byte[] utf8Bytes, byte decoded
}

[Theory]
[InlineData(new char[] { '\r', '\r', '-', '-' }, 251)]
[InlineData(new char[] { '\r', '_', '\r', '-' }, 255)]
[InlineData(new char[] { '_', '_', '\r', '\r' }, 255)]
[InlineData(new char[] { 'p', '\r', 'a', '\r' }, 165)]
[InlineData(new char[] { '\r', 'p', '\r', 'a', '\r' }, 165)]
[InlineData(new char[] { 'p', '\r', 'a', '\r', '=', '=' }, 165)]
[InlineData(new char[] { '\r', '\r', '-', 'w' }, 251)]
[InlineData(new char[] { '\r', '_', '\r', 'w' }, 255)]
[InlineData(new char[] { '_', 'w', '\r', '\r' }, 255)]
[InlineData(new char[] { 'p', '\r', 'Q', '\r' }, 165)]
[InlineData(new char[] { '\r', 'p', '\r', 'Q', '\r' }, 165)]
[InlineData(new char[] { 'p', '\r', 'Q', '\r', '=', '=' }, 165)]
public void DecodingLessThan4CharsWithWhiteSpaces(char[] utf8Bytes, byte decoded)
{
Assert.True(Base64Url.IsValid(utf8Bytes, out int decodedLength));
Expand All @@ -825,8 +843,8 @@ public void DecodingLessThan4CharsWithWhiteSpaces(char[] utf8Bytes, byte decoded
}

[Theory]
[InlineData(new byte[] { 0x4a, 0x74, 0xa, 0x4a, 0x4a, 0x74, 0xa, 0x4a }, new byte[] { 38, 210, 73, 180 })]
[InlineData(new byte[] { 0xa, 0x2d, 0x56, 0xa, 0xa, 0xa, 0x2d, 0x4a, 0x4a, 0x4a, }, new byte[] { 249, 95, 137, 36 })]
[InlineData(new byte[] { 0x4a, 0x74, 0xa, 0x4a, 0x4a, 0x74, 0xa, 0x41 }, new byte[] { 38, 210, 73, 180 })]
[InlineData(new byte[] { 0xa, 0x2d, 0x56, 0xa, 0xa, 0xa, 0x2d, 0x4a, 0x4a, 0x41, }, new byte[] { 249, 95, 137, 36 })]
public void DecodingNotMultipleOf4WithWhiteSpace(byte[] utf8Bytes, byte[] decoded)
{
Assert.True(Base64Url.IsValid(utf8Bytes, out int decodedLength));
Expand All @@ -847,8 +865,8 @@ public void DecodingNotMultipleOf4WithWhiteSpace(byte[] utf8Bytes, byte[] decode
}

[Theory]
[InlineData(new char[] { 'J', 't', '\r', 'J', 'J', 't', '\r', 'J' }, new byte[] { 38, 210, 73, 180 })]
[InlineData(new char[] { '\r', '-', 'V', '\r', '\r', '\r', '-', 'J', 'J', 'J', }, new byte[] { 249, 95, 137, 36 })]
[InlineData(new char[] { 'J', 't', '\r', 'J', 'J', 't', '\r', 'A' }, new byte[] { 38, 210, 73, 180 })]
[InlineData(new char[] { '\r', '-', 'V', '\r', '\r', '\r', '-', 'J', 'J', 'A', }, new byte[] { 249, 95, 137, 36 })]
public void DecodingNotMultipleOf4CharsWithWhiteSpace(char[] utf8Bytes, byte[] decoded)
{
Assert.True(Base64Url.IsValid(utf8Bytes, out int decodedLength));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ public void DecodeWithLargeSpan()

Span<char> source = new char[numBytes];
Base64TestHelper.InitializeUrlDecodableChars(source, numBytes);
source[numBytes - 1] = 'A'; // make sure unused bits set 0

Span<byte> decodedBytes = new byte[Base64Url.GetMaxDecodedLength(source.Length)];
Assert.Equal(OperationStatus.Done, Base64Url.DecodeFromChars(source, decodedBytes, out int consumed, out int decodedByteCount));
Expand Down Expand Up @@ -262,45 +263,46 @@ public static void Roundtrip()
}

[Fact]
public static void PartialRoundtripWithoutPadding()
public static void RoundtripWithoutPadding()
{
string input = "ab";
string input = "ag";
Verify(input, result =>
{
Assert.Equal(1, result.Length);
string roundtrippedString = Base64Url.EncodeToString(result);
Assert.NotEqual(input, roundtrippedString);
Assert.Equal(input[0], roundtrippedString[0]);
Assert.Equal(input, roundtrippedString);
});
}

[Fact]
public static void PartialRoundtripWithPadding2()
public static void RoundtripWithPadding2()
{
string input = "ab==";
string input = "ag==";
Verify(input, result =>
{
Assert.Equal(1, result.Length);
string roundtrippedString = Base64Url.EncodeToString(result);
Assert.NotEqual(input, roundtrippedString);
Assert.NotEqual(input, roundtrippedString); // Padding character omitted
Assert.Equal(input[0], roundtrippedString[0]);
Assert.Equal(input[1], roundtrippedString[1]);
});
}

[Fact]
public static void PartialRoundtripWithPadding1()
public static void RoundtripWithPadding1()
{
string input = "789=";
string input = "788=";
Verify(input, result =>
{
Assert.Equal(2, result.Length);
string roundtrippedString = Base64Url.EncodeToString(result);
Assert.NotEqual(input, roundtrippedString);
Assert.NotEqual(input, roundtrippedString); // Padding character omitted
Assert.Equal(input[0], roundtrippedString[0]);
Assert.Equal(input[1], roundtrippedString[1]);
Assert.Equal(input[2], roundtrippedString[2]);
});
}

Expand Down
Loading

0 comments on commit eb765b7

Please sign in to comment.