diff --git a/snowflake.go b/snowflake.go index 169e2be..0eef035 100644 --- a/snowflake.go +++ b/snowflake.go @@ -11,6 +11,8 @@ import ( // Epoch is the discord epoch in milliseconds. const Epoch = 1420070400000 +var AllowUnquoted = false + var ( nullBytes = []byte("null") zeroBytes = []byte("0") @@ -75,14 +77,17 @@ func (id *ID) UnmarshalJSON(data []byte) error { if bytes.Equal(data, nullBytes) || bytes.Equal(data, zeroBytes) { return nil } + snowflake, err := strconv.Unquote(string(data)) - if err != nil { + if err != nil && !AllowUnquoted { return fmt.Errorf("failed to unquote snowflake: %w", err) } + i, err := strconv.ParseUint(snowflake, 10, 64) if err != nil { return fmt.Errorf("failed to parse snowflake as uint64: %w", err) } + *id = ID(i) return nil } diff --git a/snowflake_test.go b/snowflake_test.go new file mode 100644 index 0000000..9674448 --- /dev/null +++ b/snowflake_test.go @@ -0,0 +1,74 @@ +package snowflake + +import ( + "encoding/json" + "errors" + "strconv" + "testing" +) + +func TestID_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + data []byte + expected ID + allowUnquoted bool + err error + }{ + { + name: "null", + data: []byte("null"), + expected: 0, + allowUnquoted: false, + err: nil, + }, + { + name: "valid id", + data: []byte(`"123456"`), + expected: 123456, + allowUnquoted: false, + err: nil, + }, + { + name: "invalid id", + data: []byte(`"id"`), + expected: 0, + allowUnquoted: false, + err: strconv.ErrSyntax, + }, + { + name: "unquoted 0", + data: []byte("0"), + expected: 0, + allowUnquoted: false, + err: nil, + }, + { + name: "quoted 0", + data: []byte(`"0"`), + expected: 0, + allowUnquoted: false, + err: nil, + }, + { + name: "unquoted 1", + data: []byte("1"), + expected: 0, + allowUnquoted: true, + err: strconv.ErrSyntax, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var id ID + err := json.Unmarshal(tt.data, &id) + + if !errors.Is(err, tt.err) { + t.Errorf("expected error %v, got %v", tt.err, err) + } + if id != tt.expected { + t.Errorf("expected %d, got %d", tt.expected, id) + } + }) + } +}