Skip to content

Commit

Permalink
Add Scan and Value to pgtype.Multirange[T]
Browse files Browse the repository at this point in the history
  • Loading branch information
jackc committed May 19, 2024
1 parent a07f175 commit 9316da5
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 0 deletions.
47 changes: 47 additions & 0 deletions pgtype/multirange.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,3 +441,50 @@ func (r Multirange[T]) ScanIndex(i int) any {
func (r Multirange[T]) ScanIndexType() any {
return new(T)
}

// Scan implements the database/sql Scanner interface.
//
// Range needs a *Map to decode the range elements. Unfortunately, the the Go type system does not support attaching
// data to a type and the Scanner interface does not have an argument that can be used to pass extra data to the Scan
// method. To work around this limitation, Scan uses a default *Map. This means Scan may only work properly with types
// that are supported by a default *Map. from a Map or Codec.
func (r *Multirange[T]) Scan(v any) error {
if v == nil {
*r = nil
return nil
}

m := databaseSQLMapPool.Get().(*Map)
defer databaseSQLMapPool.Put(m)

switch v := v.(type) {
case string:
return m.Scan(0, 0, []byte(v), r)
case []byte:
return m.Scan(0, 0, v, r)
default:
return fmt.Errorf("unsupported type: %T", v)
}
}

// Value implements the database/sql/driver Valuer interface.
//
// Range needs a *Map to encode the range elements. Unfortunately, the the Go type system does not support attaching
// data to a type and the Valuer interface does not have an argument that can be used to pass extra data to the Value
// method. To work around this limitation, Value uses a default *Map. This means Value may only work properly with types
// that are supported by a default *Map. from a Map or Codec.
func (src Multirange[T]) Value() (driver.Value, error) {
if src == nil {
return nil, nil
}

m := databaseSQLMapPool.Get().(*Map)
defer databaseSQLMapPool.Put(m)

buf, err := m.Encode(0, TextFormatCode, src, nil)
if err != nil {
return nil, err
}

return string(buf), nil
}
38 changes: 38 additions & 0 deletions stdlib/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,44 @@ func TestConnQueryPGTypeRange(t *testing.T) {
})
}

func TestConnQueryPGTypeMultirange(t *testing.T) {
testWithAllQueryExecModes(t, func(t *testing.T, db *sql.DB) {
skipCockroachDB(t, db, "Server does not support int4range")

var r pgtype.Multirange[pgtype.Range[pgtype.Int4]]
err := db.QueryRow("select int4multirange(int4range(1, 5), int4range(7,9))").Scan(&r)
require.NoError(t, err)
assert.Equal(
t,
pgtype.Multirange[pgtype.Range[pgtype.Int4]]{
{
Lower: pgtype.Int4{Int32: 1, Valid: true},
Upper: pgtype.Int4{Int32: 5, Valid: true},
LowerType: pgtype.Inclusive,
UpperType: pgtype.Exclusive,
Valid: true,
},
{
Lower: pgtype.Int4{Int32: 7, Valid: true},
Upper: pgtype.Int4{Int32: 9, Valid: true},
LowerType: pgtype.Inclusive,
UpperType: pgtype.Exclusive,
Valid: true,
},
},
r)

var equal bool
err = db.QueryRow("select int4multirange(int4range(1, 5), int4range(7,9)) = $1::int4multirange", r).Scan(&equal)
require.NoError(t, err)
require.Equal(t, true, equal)

err = db.QueryRow("select null::int4multirange").Scan(&r)
require.NoError(t, err)
require.Nil(t, r)
})
}

// Test type that pgx would handle natively in binary, but since it is not a
// database/sql native type should be passed through as a string
func TestConnQueryRowPgxBinary(t *testing.T) {
Expand Down

0 comments on commit 9316da5

Please sign in to comment.