Skip to content

Commit

Permalink
Update JsonWebToken to enable extensibility (#2582)
Browse files Browse the repository at this point in the history
* Extract token reading code into a separate ReadPropertyValue.

* Rename. Use IDictionary.

* Add a test. Add JsonWebTokens InternalsVisibleTo for tests.
  • Loading branch information
pmaytak authored May 15, 2024
1 parent ec0ae43 commit 909669f
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 71 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

[assembly: System.Runtime.CompilerServices.InternalsVisibleTo("Microsoft.IdentityModel.JsonWebTokens.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100b5fc90e7027f67871e773a8fde8938c81dd402ba65b9201d60593e96c492651e889cc13f1415ebb53fac1131ae0bd333c5ee6021672d9718ea31a8aebd0da0072f25d87dba6fc90ffd598ed4da35e44c398c454307e8e33b8426143daec9f596836f97c8f74750e5975c64e2189f45def46b2a2b1247adc3652bf5c308055da9")]
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ internal JsonClaimSet CreatePayloadClaimSet(byte[] bytes, int length)
}

internal JsonClaimSet CreatePayloadClaimSet(ReadOnlySpan<byte> byteSpan)
{
{
Utf8JsonReader reader = new(byteSpan);
if (!JsonSerializerPrimitives.IsReaderAtTokenType(ref reader, JsonTokenType.StartObject, true))
throw LogHelper.LogExceptionMessage(
Expand All @@ -37,71 +37,7 @@ internal JsonClaimSet CreatePayloadClaimSet(ReadOnlySpan<byte> byteSpan)
{
if (reader.TokenType == JsonTokenType.PropertyName)
{
if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Aud))
{
_audiences = [];
reader.Read();
if (reader.TokenType == JsonTokenType.StartArray)
{
JsonSerializerPrimitives.ReadStringsSkipNulls(ref reader, _audiences, JwtRegisteredClaimNames.Aud, ClassName);
claims[JwtRegisteredClaimNames.Aud] = _audiences;
}
else
{
if (reader.TokenType != JsonTokenType.Null)
{
_audiences.Add(JsonSerializerPrimitives.ReadString(ref reader, JwtRegisteredClaimNames.Aud, ClassName));
claims[JwtRegisteredClaimNames.Aud] = _audiences[0];
}
else
{
claims[JwtRegisteredClaimNames.Aud] = _audiences;
}
}
}
else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Azp))
{
_azp = JsonSerializerPrimitives.ReadString(ref reader, JwtRegisteredClaimNames.Azp, ClassName, true);
claims[JwtRegisteredClaimNames.Azp] = _azp;
}
else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Exp))
{
_exp = JsonSerializerPrimitives.ReadLong(ref reader, JwtRegisteredClaimNames.Exp, ClassName, true);
_expDateTime = EpochTime.DateTime(_exp.Value);
claims[JwtRegisteredClaimNames.Exp] = _exp;
}
else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Iat))
{
_iat = JsonSerializerPrimitives.ReadLong(ref reader, JwtRegisteredClaimNames.Iat, ClassName, true);
_iatDateTime = EpochTime.DateTime(_iat.Value);
claims[JwtRegisteredClaimNames.Iat] = _iat;
}
else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Iss))
{
_iss = JsonSerializerPrimitives.ReadString(ref reader, JwtRegisteredClaimNames.Iss, ClassName, true);
claims[JwtRegisteredClaimNames.Iss] = _iss;
}
else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Jti))
{
_jti = JsonSerializerPrimitives.ReadString(ref reader, JwtRegisteredClaimNames.Jti, ClassName, true);
claims[JwtRegisteredClaimNames.Jti] = _jti;
}
else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Nbf))
{
_nbf = JsonSerializerPrimitives.ReadLong(ref reader, JwtRegisteredClaimNames.Nbf, ClassName, true);
_nbfDateTime = EpochTime.DateTime(_nbf.Value);
claims[JwtRegisteredClaimNames.Nbf] = _nbf;
}
else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Sub))
{
_sub = JsonSerializerPrimitives.ReadStringOrNumberAsString(ref reader, JwtRegisteredClaimNames.Sub, ClassName, true);
claims[JwtRegisteredClaimNames.Sub] = _sub;
}
else
{
string propertyName = reader.GetString();
claims[propertyName] = JsonSerializerPrimitives.ReadPropertyValueAsObject(ref reader, propertyName, JsonClaimSet.ClassName, true);
}
ReadPayloadValue(ref reader, claims);
}
// We read a JsonTokenType.StartObject above, exiting and positioning reader at next token.
else if (JsonSerializerPrimitives.IsReaderAtTokenType(ref reader, JsonTokenType.EndObject, false))
Expand All @@ -112,5 +48,74 @@ internal JsonClaimSet CreatePayloadClaimSet(ReadOnlySpan<byte> byteSpan)

return new JsonClaimSet(claims);
}

private protected virtual void ReadPayloadValue(ref Utf8JsonReader reader, IDictionary<string, object> claims)
{
if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Aud))
{
_audiences = [];
reader.Read();
if (reader.TokenType == JsonTokenType.StartArray)
{
JsonSerializerPrimitives.ReadStringsSkipNulls(ref reader, _audiences, JwtRegisteredClaimNames.Aud, ClassName);
claims[JwtRegisteredClaimNames.Aud] = _audiences;
}
else
{
if (reader.TokenType != JsonTokenType.Null)
{
_audiences.Add(JsonSerializerPrimitives.ReadString(ref reader, JwtRegisteredClaimNames.Aud, ClassName));
claims[JwtRegisteredClaimNames.Aud] = _audiences[0];
}
else
{
claims[JwtRegisteredClaimNames.Aud] = _audiences;
}
}
}
else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Azp))
{
_azp = JsonSerializerPrimitives.ReadString(ref reader, JwtRegisteredClaimNames.Azp, ClassName, true);
claims[JwtRegisteredClaimNames.Azp] = _azp;
}
else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Exp))
{
_exp = JsonSerializerPrimitives.ReadLong(ref reader, JwtRegisteredClaimNames.Exp, ClassName, true);
_expDateTime = EpochTime.DateTime(_exp.Value);
claims[JwtRegisteredClaimNames.Exp] = _exp;
}
else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Iat))
{
_iat = JsonSerializerPrimitives.ReadLong(ref reader, JwtRegisteredClaimNames.Iat, ClassName, true);
_iatDateTime = EpochTime.DateTime(_iat.Value);
claims[JwtRegisteredClaimNames.Iat] = _iat;
}
else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Iss))
{
_iss = JsonSerializerPrimitives.ReadString(ref reader, JwtRegisteredClaimNames.Iss, ClassName, true);
claims[JwtRegisteredClaimNames.Iss] = _iss;
}
else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Jti))
{
_jti = JsonSerializerPrimitives.ReadString(ref reader, JwtRegisteredClaimNames.Jti, ClassName, true);
claims[JwtRegisteredClaimNames.Jti] = _jti;
}
else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Nbf))
{
_nbf = JsonSerializerPrimitives.ReadLong(ref reader, JwtRegisteredClaimNames.Nbf, ClassName, true);
_nbfDateTime = EpochTime.DateTime(_nbf.Value);
claims[JwtRegisteredClaimNames.Nbf] = _nbf;
}
else if (reader.ValueTextEquals(JwtPayloadUtf8Bytes.Sub))
{
_sub = JsonSerializerPrimitives.ReadStringOrNumberAsString(ref reader, JwtRegisteredClaimNames.Sub, ClassName, true);
claims[JwtRegisteredClaimNames.Sub] = _sub;
}
else
{
string propertyName = reader.GetString();
claims[propertyName] = JsonSerializerPrimitives.ReadPropertyValueAsObject(ref reader, propertyName, JsonClaimSet.ClassName, true);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
using System;
using System.Collections.Generic;
using System.Text.Json;
using Microsoft.IdentityModel.Tokens.Json;

namespace Microsoft.IdentityModel.JsonWebTokens.Tests
{
public class CustomJsonWebToken : JsonWebToken
{
private const string CustomClaimName = "CustomClaim";

public CustomJsonWebToken(string jwtEncodedString) : base(jwtEncodedString) { }

public CustomJsonWebToken(ReadOnlyMemory<char> encodedTokenMemory) : base(encodedTokenMemory) { }

public CustomJsonWebToken(string header, string payload) : base(header, payload) { }

private protected override void ReadPayloadValue(ref Utf8JsonReader reader, IDictionary<string, object> claims)
{
if (reader.ValueTextEquals(CustomClaimName))
{
_customClaim = JsonSerializerPrimitives.ReadString(ref reader, CustomClaimName, ClassName, true);
claims[CustomClaimName] = _customClaim;
}
else
{
base.ReadPayloadValue(ref reader, claims);
}
}

private string _customClaim;

public string CustomClaim
{
get
{
_customClaim ??= Payload.GetStringValue(CustomClaimName);
return _customClaim;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -363,14 +363,14 @@ public static TheoryData<GetPayloadValueTheoryData> CheckAudienceValuesTheoryDat
PropertyName = "aud",
PropertyValue = new List<string>(),
PropertyType = typeof(List<string>),
Json = JsonUtilities.CreateUnsignedToken("aud", new List<string>{ null, null })
Json = JsonUtilities.CreateUnsignedToken("aud", new List<string> { null, null })
});

theoryData.Add(new GetPayloadValueTheoryData("singleNonNull")
{
ClaimValue = new List<string> { "audience" },
PropertyName = "aud",
PropertyValue = new List<string> { "audience"},
PropertyValue = new List<string> { "audience" },
PropertyType = typeof(List<string>),
Json = JsonUtilities.CreateUnsignedToken("aud", "audience")
});
Expand All @@ -379,7 +379,7 @@ public static TheoryData<GetPayloadValueTheoryData> CheckAudienceValuesTheoryDat
{
ClaimValue = new List<string> { "audience1" },
PropertyName = "aud",
PropertyValue = new List<string> { "audience1"},
PropertyValue = new List<string> { "audience1" },
PropertyType = typeof(List<string>),
Json = JsonUtilities.CreateUnsignedToken("aud", new List<string> { null, "audience1", null })
});
Expand Down Expand Up @@ -726,7 +726,7 @@ public static TheoryData<GetPayloadValueTheoryData> GetPayloadSubClaimValueTheor

return theoryData;
}

}

// This test ensures that accessing claims from the payload works as expected.
Expand Down Expand Up @@ -981,7 +981,7 @@ public static TheoryData<GetPayloadValueTheoryData> GetPayloadValueTheoryData
{
PropertyName = "dateTime",
PropertyType = typeof(string[]),
PropertyValue = new string[] {dateTime.ToString("o", CultureInfo.InvariantCulture)},
PropertyValue = new string[] { dateTime.ToString("o", CultureInfo.InvariantCulture) },
Json = JsonUtilities.CreateUnsignedToken("dateTime", dateTime)
});

Expand Down Expand Up @@ -1718,6 +1718,25 @@ public void StringAndMemoryConstructors_CreateEquivalentTokens(JwtTheoryData the
}
TestUtilities.AssertFailIfErrors(context);
}

[Fact]
public void DerivedJsonWebToken_IsCreatedCorrectly()
{
var expectedCustomClaim = "customclaim";
var tokenStr = new JsonWebTokenHandler().CreateToken(new SecurityTokenDescriptor
{
Issuer = Default.Issuer,
Claims = new Dictionary<string, object>
{
{ "CustomClaim", expectedCustomClaim },
}
});

var derivedToken = new CustomJsonWebToken(tokenStr);

Assert.Equal(expectedCustomClaim, derivedToken.CustomClaim);
Assert.Equal(Default.Issuer, derivedToken.Issuer);
}
}

public class ParseTimeValuesTheoryData : TheoryDataBase
Expand Down

0 comments on commit 909669f

Please sign in to comment.