Skip to content

Commit

Permalink
Update mautrix-go
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Jan 7, 2024
1 parent fee5cf2 commit 959eb7e
Show file tree
Hide file tree
Showing 19 changed files with 75 additions and 80 deletions.
2 changes: 1 addition & 1 deletion commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ func fnSyncSpace(ce *WrappedCommandEvent) {
if portal.IsPrivateChat() {
continue
}
if ce.Bridge.StateStore.IsInRoom(portal.MXID, ce.User.MXID) && portal.addToPersonalSpace(ctx, ce.User) {
if ce.Bridge.StateStore.IsInRoom(ctx, portal.MXID, ce.User.MXID) && portal.addToPersonalSpace(ctx, ce.User) {
count++
}
}
Expand Down
2 changes: 1 addition & 1 deletion custompuppet.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (puppet *Puppet) ClearCustomMXID() {
}

func (puppet *Puppet) StartCustomMXID(reloginOnFail bool) error {
newIntent, newAccessToken, err := puppet.bridge.DoublePuppet.Setup(puppet.CustomMXID, puppet.AccessToken, reloginOnFail)
newIntent, newAccessToken, err := puppet.bridge.DoublePuppet.Setup(context.TODO(), puppet.CustomMXID, puppet.AccessToken, reloginOnFail)
if err != nil {
puppet.ClearCustomMXID()
return err
Expand Down
4 changes: 2 additions & 2 deletions database/portal.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,11 @@ func (pq *PortalQuery) GetAllWithMXID(ctx context.Context) ([]*Portal, error) {
}

func (pq *PortalQuery) FindPrivateChatsNotInSpace(ctx context.Context, receiver uuid.UUID) ([]PortalKey, error) {
rows, err := pq.GetDB().QueryContext(ctx, getChatsNotInSpaceQuery, receiver)
rows, err := pq.GetDB().Query(ctx, getChatsNotInSpaceQuery, receiver)
if err != nil {
return nil, err
}
return dbutil.NewRowIter(rows, func(rows dbutil.Rows) (key PortalKey, err error) {
return dbutil.NewRowIter(rows, func(rows dbutil.Scannable) (key PortalKey, err error) {
err = rows.Scan(&key.ChatID)
key.Receiver = receiver
return
Expand Down
5 changes: 3 additions & 2 deletions database/upgrades/upgrades.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package upgrades

import (
"context"
"embed"
"errors"

Expand All @@ -29,10 +30,10 @@ var Table dbutil.UpgradeTable
var rawUpgrades embed.FS

func init() {
Table.Register(-1, 12, 0, "Unsupported version", false, func(tx dbutil.Execable, database *dbutil.Database) error {
Table.Register(-1, 12, 0, "Unsupported version", false, func(ctx context.Context, database *dbutil.Database) error {
return errors.New("please upgrade to mautrix-signal v0.4.3 before upgrading to a newer version")
})
Table.Register(1, 13, 0, "Jump to version 13", false, func(tx dbutil.Execable, database *dbutil.Database) error {
Table.Register(1, 13, 0, "Jump to version 13", false, func(ctx context.Context, database *dbutil.Database) error {
return nil
})
Table.RegisterFS(rawUpgrades)
Expand Down
4 changes: 2 additions & 2 deletions database/userportal.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (u *User) GetLastReadTS(ctx context.Context, portal PortalKey) uint64 {
return cached
}
var ts int64
err := u.qh.GetDB().QueryRowContext(ctx, getLastReadTSQuery, u.MXID, portal.ChatID, portal.Receiver).Scan(&ts)
err := u.qh.GetDB().QueryRow(ctx, getLastReadTSQuery, u.MXID, portal.ChatID, portal.Receiver).Scan(&ts)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
zerolog.Ctx(ctx).Err(err).
Str("user_id", u.MXID.String()).
Expand Down Expand Up @@ -83,7 +83,7 @@ func (u *User) IsInSpace(ctx context.Context, portal PortalKey) bool {
return cached
}
var inSpace bool
err := u.qh.GetDB().QueryRowContext(ctx, getIsInSpaceQuery, u.MXID, portal.ChatID, portal.Receiver).Scan(&inSpace)
err := u.qh.GetDB().QueryRow(ctx, getIsInSpaceQuery, u.MXID, portal.ChatID, portal.Receiver).Scan(&inSpace)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
zerolog.Ctx(ctx).Err(err).
Str("user_id", u.MXID.String()).
Expand Down
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ require (
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/stretchr/testify v1.8.4
github.com/tidwall/gjson v1.17.0
go.mau.fi/util v0.2.2-0.20240107131103-852f29430a02
go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894
golang.org/x/crypto v0.17.0
golang.org/x/exp v0.0.0-20231226003508-02704c960a9b
golang.org/x/net v0.19.0
google.golang.org/protobuf v1.32.0
maunium.net/go/maulogger/v2 v2.4.1
maunium.net/go/mautrix v0.16.3-0.20240104125737-88631708a41b
maunium.net/go/mautrix v0.16.3-0.20240107204502-25bc36bc7ae7
nhooyr.io/websocket v1.8.10
)

Expand Down
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68=
github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.mau.fi/util v0.2.2-0.20240107131103-852f29430a02 h1:jREUBe6TF4a2HCGowTLzcvOFg44QDZ0xgoo+YJK3ugc=
go.mau.fi/util v0.2.2-0.20240107131103-852f29430a02/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs=
go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894 h1:CuR5LDSxBQLETorfwJ9vRtySeLHjMvJ7//lnCMw7Dy8=
go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs=
go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto=
go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
Expand All @@ -92,7 +92,7 @@ maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8=
maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho=
maunium.net/go/mautrix v0.16.3-0.20240104125737-88631708a41b h1:WWCD0vaAztVrrTRWcTXeOHq9U7HRcP2a1hs+0+guPPg=
maunium.net/go/mautrix v0.16.3-0.20240104125737-88631708a41b/go.mod h1:lI43hRW+/92FCqHLD5bINSPqsWrviZ5MpLl7J3hjvW4=
maunium.net/go/mautrix v0.16.3-0.20240107204502-25bc36bc7ae7 h1:Yo1S3mSazHoT/MHNheRMuRPH74rU6/ZyVaJqTEsmaN0=
maunium.net/go/mautrix v0.16.3-0.20240107204502-25bc36bc7ae7/go.mod h1:eRQu5ED1ODsP+xq1K9l1AOD+O9FMkAhodd/RVc3Bkqg=
nhooyr.io/websocket v1.8.10 h1:mv4p+MnGrLDcPlBoWsvPP7XCzTYMXP9F9eIGoKbgx7Q=
nhooyr.io/websocket v1.8.10/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c=
8 changes: 4 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func (br *SignalBridge) logLostPortals(ctx context.Context) {

func (br *SignalBridge) Start() {
go br.logLostPortals(context.TODO())
err := br.MeowStore.Upgrade()
err := br.MeowStore.Upgrade(context.TODO())
if err != nil {
br.Log.Fatalln("Failed to upgrade signalmeow database: %v", err)
os.Exit(15)
Expand Down Expand Up @@ -298,9 +298,9 @@ func (br *SignalBridge) createPrivatePortalFromInvite(ctx context.Context, roomI
log.Err(err).Msg("Failed to enable e2be")
}
}
br.AS.StateStore.SetMembership(roomID, inviter.MXID, event.MembershipJoin)
br.AS.StateStore.SetMembership(roomID, puppet.MXID, event.MembershipJoin)
br.AS.StateStore.SetMembership(roomID, br.Bot.UserID, event.MembershipJoin)
br.AS.StateStore.SetMembership(ctx, roomID, inviter.MXID, event.MembershipJoin)
br.AS.StateStore.SetMembership(ctx, roomID, puppet.MXID, event.MembershipJoin)
br.AS.StateStore.SetMembership(ctx, roomID, br.Bot.UserID, event.MembershipJoin)
portal.Encrypted = true
}
portal.UpdateDMInfo(ctx, true)
Expand Down
8 changes: 4 additions & 4 deletions metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,23 +229,23 @@ func (mh *MetricsHandler) TrackConnectionState(signalID string, connected bool)
func (mh *MetricsHandler) updateStats() {
start := time.Now()
var puppetCount int
err := mh.db.QueryRowContext(mh.ctx, "SELECT COUNT(*) FROM puppet").Scan(&puppetCount)
err := mh.db.QueryRow(mh.ctx, "SELECT COUNT(*) FROM puppet").Scan(&puppetCount)
if err != nil {
mh.log.Warnln("Failed to scan number of puppets:", err)
} else {
mh.puppetCount.Set(float64(puppetCount))
}

var userCount int
err = mh.db.QueryRowContext(mh.ctx, `SELECT COUNT(*) FROM "user"`).Scan(&userCount)
err = mh.db.QueryRow(mh.ctx, `SELECT COUNT(*) FROM "user"`).Scan(&userCount)
if err != nil {
mh.log.Warnln("Failed to scan number of users:", err)
} else {
mh.userCount.Set(float64(userCount))
}

var messageCount int
err = mh.db.QueryRowContext(mh.ctx, "SELECT COUNT(*) FROM message").Scan(&messageCount)
err = mh.db.QueryRow(mh.ctx, "SELECT COUNT(*) FROM message").Scan(&messageCount)
if err != nil {
mh.log.Warnln("Failed to scan number of messages:", err)
} else {
Expand All @@ -255,7 +255,7 @@ func (mh *MetricsHandler) updateStats() {
var encryptedGroupCount, encryptedPrivateCount, unencryptedGroupCount, unencryptedPrivateCount int
// TODO Use a more precise way to check if a chat_id is a UUID.
// It should also be compatible with both SQLite & Postgres.
err = mh.db.QueryRowContext(mh.ctx, `
err = mh.db.QueryRow(mh.ctx, `
SELECT
COUNT(CASE WHEN chat_id NOT LIKE '%-%-%-%-%' AND encrypted THEN 1 END) AS encrypted_group_portals,
COUNT(CASE WHEN chat_id LIKE '%-%-%-%-%' AND encrypted THEN 1 END) AS encrypted_private_portals,
Expand Down
12 changes: 5 additions & 7 deletions pkg/signalmeow/store/contact_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,25 +111,23 @@ func scanContact(row dbutil.Scannable) (*types.Contact, error) {
}

func (s *SQLStore) LoadContact(ctx context.Context, theirUUID uuid.UUID) (*types.Contact, error) {
return scanContact(s.db.Conn(ctx).QueryRowContext(ctx, getContactByUUIDQuery, s.ACI, theirUUID))
return scanContact(s.db.QueryRow(ctx, getContactByUUIDQuery, s.ACI, theirUUID))
}

func (s *SQLStore) LoadContactByE164(ctx context.Context, e164 string) (*types.Contact, error) {
return scanContact(s.db.Conn(ctx).QueryRowContext(ctx, getContactByPhoneQuery, s.ACI, e164))
return scanContact(s.db.QueryRow(ctx, getContactByPhoneQuery, s.ACI, e164))
}

func (s *SQLStore) AllContacts(ctx context.Context) ([]*types.Contact, error) {
rows, err := s.db.Conn(ctx).QueryContext(ctx, getAllContactsOfUserQuery, s.ACI)
rows, err := s.db.Query(ctx, getAllContactsOfUserQuery, s.ACI)
if err != nil {
return nil, err
}
return dbutil.NewRowIter(rows, func(rows dbutil.Rows) (*types.Contact, error) {
return scanContact(rows)
}).AsList()
return dbutil.NewRowIter(rows, scanContact).AsList()
}

func (s *SQLStore) StoreContact(ctx context.Context, contact types.Contact) error {
_, err := s.db.Conn(ctx).ExecContext(
_, err := s.db.Exec(
ctx,
upsertContactQuery,
s.ACI,
Expand Down
12 changes: 6 additions & 6 deletions pkg/signalmeow/store/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ FROM signalmeow_device

const getDeviceQuery = getAllDevicesQuery + " WHERE aci_uuid=$1"

func (c *StoreContainer) Upgrade() error {
return c.db.Upgrade()
func (c *StoreContainer) Upgrade(ctx context.Context) error {
return c.db.Upgrade(ctx)
}

func (c *StoreContainer) scanDevice(row dbutil.Scannable) (*Device, error) {
Expand Down Expand Up @@ -85,7 +85,7 @@ func (c *StoreContainer) scanDevice(row dbutil.Scannable) (*Device, error) {

// GetAllDevices finds all the devices in the database.
func (c *StoreContainer) GetAllDevices(ctx context.Context) ([]*Device, error) {
rows, err := c.db.Conn(ctx).QueryContext(ctx, getAllDevicesQuery)
rows, err := c.db.Query(ctx, getAllDevicesQuery)
if err != nil {
return nil, fmt.Errorf("failed to query sessions: %w", err)
}
Expand All @@ -104,7 +104,7 @@ func (c *StoreContainer) GetAllDevices(ctx context.Context) ([]*Device, error) {
// GetDevice finds the device with the specified ACI UUID in the database.
// If the device is not found, nil is returned instead.
func (c *StoreContainer) DeviceByACI(ctx context.Context, aci uuid.UUID) (*Device, error) {
sess, err := c.scanDevice(c.db.Conn(ctx).QueryRowContext(ctx, getDeviceQuery, aci))
sess, err := c.scanDevice(c.db.QueryRow(ctx, getDeviceQuery, aci))
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
Expand Down Expand Up @@ -150,7 +150,7 @@ func (c *StoreContainer) PutDevice(ctx context.Context, device *DeviceData) erro
zerolog.Ctx(ctx).Err(err).Msg("failed to serialize pni identity key pair")
return err
}
_, err = c.db.Conn(ctx).ExecContext(ctx, insertDeviceQuery,
_, err = c.db.Exec(ctx, insertDeviceQuery,
device.ACI, aciIdentityKeyPair, device.RegistrationID,
device.PNI, pniIdentityKeyPair, device.PNIRegistrationID,
device.DeviceID, device.Number, device.Password,
Expand All @@ -166,6 +166,6 @@ func (c *StoreContainer) DeleteDevice(ctx context.Context, device *DeviceData) e
if device.ACI == uuid.Nil {
return ErrDeviceIDMustBeSet
}
_, err := c.db.Conn(ctx).ExecContext(ctx, deleteDeviceQuery, device.ACI)
_, err := c.db.Exec(ctx, deleteDeviceQuery, device.ACI)
return err
}
4 changes: 2 additions & 2 deletions pkg/signalmeow/store/group_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func scanGroup(row dbutil.Scannable) (*dbGroup, error) {
}

func (s *SQLStore) MasterKeyFromGroupIdentifier(ctx context.Context, groupID types.GroupIdentifier) (types.SerializedGroupMasterKey, error) {
g, err := scanGroup(s.db.Conn(ctx).QueryRowContext(ctx, getGroupByIDQuery, s.ACI, groupID))
g, err := scanGroup(s.db.QueryRow(ctx, getGroupByIDQuery, s.ACI, groupID))
if g == nil {
return "", err
} else {
Expand All @@ -70,6 +70,6 @@ func (s *SQLStore) MasterKeyFromGroupIdentifier(ctx context.Context, groupID typ
}

func (s *SQLStore) StoreMasterKey(ctx context.Context, groupID types.GroupIdentifier, key types.SerializedGroupMasterKey) error {
_, err := s.db.Conn(ctx).ExecContext(ctx, upsertGroupMasterKeyQuery, s.ACI, groupID, key)
_, err := s.db.Exec(ctx, upsertGroupMasterKeyQuery, s.ACI, groupID, key)
return err
}
12 changes: 6 additions & 6 deletions pkg/signalmeow/store/identity_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ func scanIdentityKey(row dbutil.Scannable) (*libsignalgo.IdentityKey, error) {
}

func (s *SQLStore) GetIdentityKeyPair(ctx context.Context) (*libsignalgo.IdentityKeyPair, error) {
return scanIdentityKeyPair(s.db.Conn(ctx).QueryRowContext(ctx, getIdentityKeyPairQuery, s.ACI))
return scanIdentityKeyPair(s.db.QueryRow(ctx, getIdentityKeyPairQuery, s.ACI))
}

func (s *SQLStore) GetLocalRegistrationID(ctx context.Context) (uint32, error) {
var regID sql.NullInt64
err := s.db.Conn(ctx).QueryRowContext(ctx, getRegistrationLocalIDQuery, s.ACI).Scan(&regID)
err := s.db.QueryRow(ctx, getRegistrationLocalIDQuery, s.ACI).Scan(&regID)
if err != nil {
return 0, fmt.Errorf("failed to get local registration ID: %w", err)
}
Expand All @@ -97,7 +97,7 @@ func (s *SQLStore) SaveIdentityKey(ctx context.Context, address *libsignalgo.Add
if err != nil {
return false, fmt.Errorf("failed to get device ID: %w", err)
}
oldKey, err := scanIdentityKey(s.db.Conn(ctx).QueryRowContext(ctx, getIdentityKeyQuery, s.ACI, theirUUID, deviceID))
oldKey, err := scanIdentityKey(s.db.QueryRow(ctx, getIdentityKeyQuery, s.ACI, theirUUID, deviceID))
if err != nil {
return false, fmt.Errorf("failed to get old identity key: %w", err)
}
Expand All @@ -110,7 +110,7 @@ func (s *SQLStore) SaveIdentityKey(ctx context.Context, address *libsignalgo.Add
// We are replacing the old key if the old key exists, and it is not equal to the new key
replacing = !equal
}
_, err = s.db.Conn(ctx).ExecContext(ctx, insertIdentityKeyQuery, s.ACI, theirUUID, deviceID, serialized, trustLevel)
_, err = s.db.Exec(ctx, insertIdentityKeyQuery, s.ACI, theirUUID, deviceID, serialized, trustLevel)
if err != nil {
return replacing, fmt.Errorf("failed to insert new identity key: %w", err)
}
Expand All @@ -128,7 +128,7 @@ func (s *SQLStore) IsTrustedIdentity(ctx context.Context, address *libsignalgo.A
return false, fmt.Errorf("failed to get device ID: %w", err)
}
var trustLevel string
err = s.db.Conn(ctx).QueryRowContext(ctx, getIdentityKeyTrustLevelQuery, s.ACI, theirUUID, deviceID).Scan(&trustLevel)
err = s.db.QueryRow(ctx, getIdentityKeyTrustLevelQuery, s.ACI, theirUUID, deviceID).Scan(&trustLevel)
if errors.Is(err, sql.ErrNoRows) {
// If no rows, they are a new identity, so trust by default
return true, nil
Expand All @@ -148,7 +148,7 @@ func (s *SQLStore) GetIdentityKey(ctx context.Context, address *libsignalgo.Addr
if err != nil {
return nil, fmt.Errorf("failed to get device ID: %w", err)
}
key, err := scanIdentityKey(s.db.Conn(ctx).QueryRowContext(ctx, getIdentityKeyQuery, s.ACI, theirUUID, deviceID))
key, err := scanIdentityKey(s.db.QueryRow(ctx, getIdentityKeyQuery, s.ACI, theirUUID, deviceID))
if err != nil {
return nil, fmt.Errorf("failed to get identity key from database: %w", err)
}
Expand Down
Loading

0 comments on commit 959eb7e

Please sign in to comment.