diff --git a/commands.go b/commands.go index 5b880315..f06c429a 100644 --- a/commands.go +++ b/commands.go @@ -260,7 +260,7 @@ func fnSyncSpace(ce *WrappedCommandEvent) { if portal.IsPrivateChat() { continue } - if ce.Bridge.StateStore.IsInRoom(portal.MXID, ce.User.MXID) && portal.addToPersonalSpace(ctx, ce.User) { + if ce.Bridge.StateStore.IsInRoom(ctx, portal.MXID, ce.User.MXID) && portal.addToPersonalSpace(ctx, ce.User) { count++ } } diff --git a/custompuppet.go b/custompuppet.go index 31cc1194..5ef78429 100644 --- a/custompuppet.go +++ b/custompuppet.go @@ -58,7 +58,7 @@ func (puppet *Puppet) ClearCustomMXID() { } func (puppet *Puppet) StartCustomMXID(reloginOnFail bool) error { - newIntent, newAccessToken, err := puppet.bridge.DoublePuppet.Setup(puppet.CustomMXID, puppet.AccessToken, reloginOnFail) + newIntent, newAccessToken, err := puppet.bridge.DoublePuppet.Setup(context.TODO(), puppet.CustomMXID, puppet.AccessToken, reloginOnFail) if err != nil { puppet.ClearCustomMXID() return err diff --git a/database/portal.go b/database/portal.go index 14e40bc5..31132c67 100644 --- a/database/portal.go +++ b/database/portal.go @@ -130,11 +130,11 @@ func (pq *PortalQuery) GetAllWithMXID(ctx context.Context) ([]*Portal, error) { } func (pq *PortalQuery) FindPrivateChatsNotInSpace(ctx context.Context, receiver uuid.UUID) ([]PortalKey, error) { - rows, err := pq.GetDB().QueryContext(ctx, getChatsNotInSpaceQuery, receiver) + rows, err := pq.GetDB().Query(ctx, getChatsNotInSpaceQuery, receiver) if err != nil { return nil, err } - return dbutil.NewRowIter(rows, func(rows dbutil.Rows) (key PortalKey, err error) { + return dbutil.NewRowIter(rows, func(rows dbutil.Scannable) (key PortalKey, err error) { err = rows.Scan(&key.ChatID) key.Receiver = receiver return diff --git a/database/upgrades/upgrades.go b/database/upgrades/upgrades.go index c99efe8c..895be628 100644 --- a/database/upgrades/upgrades.go +++ b/database/upgrades/upgrades.go @@ -17,6 +17,7 @@ package upgrades import ( + "context" "embed" "errors" @@ -29,10 +30,10 @@ var Table dbutil.UpgradeTable var rawUpgrades embed.FS func init() { - Table.Register(-1, 12, 0, "Unsupported version", false, func(tx dbutil.Execable, database *dbutil.Database) error { + Table.Register(-1, 12, 0, "Unsupported version", false, func(ctx context.Context, database *dbutil.Database) error { return errors.New("please upgrade to mautrix-signal v0.4.3 before upgrading to a newer version") }) - Table.Register(1, 13, 0, "Jump to version 13", false, func(tx dbutil.Execable, database *dbutil.Database) error { + Table.Register(1, 13, 0, "Jump to version 13", false, func(ctx context.Context, database *dbutil.Database) error { return nil }) Table.RegisterFS(rawUpgrades) diff --git a/database/userportal.go b/database/userportal.go index 3799f59f..c1ba828e 100644 --- a/database/userportal.go +++ b/database/userportal.go @@ -45,7 +45,7 @@ func (u *User) GetLastReadTS(ctx context.Context, portal PortalKey) uint64 { return cached } var ts int64 - err := u.qh.GetDB().QueryRowContext(ctx, getLastReadTSQuery, u.MXID, portal.ChatID, portal.Receiver).Scan(&ts) + err := u.qh.GetDB().QueryRow(ctx, getLastReadTSQuery, u.MXID, portal.ChatID, portal.Receiver).Scan(&ts) if err != nil && !errors.Is(err, sql.ErrNoRows) { zerolog.Ctx(ctx).Err(err). Str("user_id", u.MXID.String()). @@ -83,7 +83,7 @@ func (u *User) IsInSpace(ctx context.Context, portal PortalKey) bool { return cached } var inSpace bool - err := u.qh.GetDB().QueryRowContext(ctx, getIsInSpaceQuery, u.MXID, portal.ChatID, portal.Receiver).Scan(&inSpace) + err := u.qh.GetDB().QueryRow(ctx, getIsInSpaceQuery, u.MXID, portal.ChatID, portal.Receiver).Scan(&inSpace) if err != nil && !errors.Is(err, sql.ErrNoRows) { zerolog.Ctx(ctx).Err(err). Str("user_id", u.MXID.String()). diff --git a/go.mod b/go.mod index 71167af6..33de9f39 100644 --- a/go.mod +++ b/go.mod @@ -14,13 +14,13 @@ require ( github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/stretchr/testify v1.8.4 github.com/tidwall/gjson v1.17.0 - go.mau.fi/util v0.2.2-0.20240107131103-852f29430a02 + go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894 golang.org/x/crypto v0.17.0 golang.org/x/exp v0.0.0-20231226003508-02704c960a9b golang.org/x/net v0.19.0 google.golang.org/protobuf v1.32.0 maunium.net/go/maulogger/v2 v2.4.1 - maunium.net/go/mautrix v0.16.3-0.20240104125737-88631708a41b + maunium.net/go/mautrix v0.16.3-0.20240107204502-25bc36bc7ae7 nhooyr.io/websocket v1.8.10 ) diff --git a/go.sum b/go.sum index 17e51cfc..ab06e6ab 100644 --- a/go.sum +++ b/go.sum @@ -65,8 +65,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68= github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.mau.fi/util v0.2.2-0.20240107131103-852f29430a02 h1:jREUBe6TF4a2HCGowTLzcvOFg44QDZ0xgoo+YJK3ugc= -go.mau.fi/util v0.2.2-0.20240107131103-852f29430a02/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs= +go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894 h1:CuR5LDSxBQLETorfwJ9vRtySeLHjMvJ7//lnCMw7Dy8= +go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= @@ -92,7 +92,7 @@ maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8= maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho= -maunium.net/go/mautrix v0.16.3-0.20240104125737-88631708a41b h1:WWCD0vaAztVrrTRWcTXeOHq9U7HRcP2a1hs+0+guPPg= -maunium.net/go/mautrix v0.16.3-0.20240104125737-88631708a41b/go.mod h1:lI43hRW+/92FCqHLD5bINSPqsWrviZ5MpLl7J3hjvW4= +maunium.net/go/mautrix v0.16.3-0.20240107204502-25bc36bc7ae7 h1:Yo1S3mSazHoT/MHNheRMuRPH74rU6/ZyVaJqTEsmaN0= +maunium.net/go/mautrix v0.16.3-0.20240107204502-25bc36bc7ae7/go.mod h1:eRQu5ED1ODsP+xq1K9l1AOD+O9FMkAhodd/RVc3Bkqg= nhooyr.io/websocket v1.8.10 h1:mv4p+MnGrLDcPlBoWsvPP7XCzTYMXP9F9eIGoKbgx7Q= nhooyr.io/websocket v1.8.10/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c= diff --git a/main.go b/main.go index e8646330..2044fb11 100644 --- a/main.go +++ b/main.go @@ -172,7 +172,7 @@ func (br *SignalBridge) logLostPortals(ctx context.Context) { func (br *SignalBridge) Start() { go br.logLostPortals(context.TODO()) - err := br.MeowStore.Upgrade() + err := br.MeowStore.Upgrade(context.TODO()) if err != nil { br.Log.Fatalln("Failed to upgrade signalmeow database: %v", err) os.Exit(15) @@ -298,9 +298,9 @@ func (br *SignalBridge) createPrivatePortalFromInvite(ctx context.Context, roomI log.Err(err).Msg("Failed to enable e2be") } } - br.AS.StateStore.SetMembership(roomID, inviter.MXID, event.MembershipJoin) - br.AS.StateStore.SetMembership(roomID, puppet.MXID, event.MembershipJoin) - br.AS.StateStore.SetMembership(roomID, br.Bot.UserID, event.MembershipJoin) + br.AS.StateStore.SetMembership(ctx, roomID, inviter.MXID, event.MembershipJoin) + br.AS.StateStore.SetMembership(ctx, roomID, puppet.MXID, event.MembershipJoin) + br.AS.StateStore.SetMembership(ctx, roomID, br.Bot.UserID, event.MembershipJoin) portal.Encrypted = true } portal.UpdateDMInfo(ctx, true) diff --git a/metrics.go b/metrics.go index da188368..d8b5ace7 100644 --- a/metrics.go +++ b/metrics.go @@ -229,7 +229,7 @@ func (mh *MetricsHandler) TrackConnectionState(signalID string, connected bool) func (mh *MetricsHandler) updateStats() { start := time.Now() var puppetCount int - err := mh.db.QueryRowContext(mh.ctx, "SELECT COUNT(*) FROM puppet").Scan(&puppetCount) + err := mh.db.QueryRow(mh.ctx, "SELECT COUNT(*) FROM puppet").Scan(&puppetCount) if err != nil { mh.log.Warnln("Failed to scan number of puppets:", err) } else { @@ -237,7 +237,7 @@ func (mh *MetricsHandler) updateStats() { } var userCount int - err = mh.db.QueryRowContext(mh.ctx, `SELECT COUNT(*) FROM "user"`).Scan(&userCount) + err = mh.db.QueryRow(mh.ctx, `SELECT COUNT(*) FROM "user"`).Scan(&userCount) if err != nil { mh.log.Warnln("Failed to scan number of users:", err) } else { @@ -245,7 +245,7 @@ func (mh *MetricsHandler) updateStats() { } var messageCount int - err = mh.db.QueryRowContext(mh.ctx, "SELECT COUNT(*) FROM message").Scan(&messageCount) + err = mh.db.QueryRow(mh.ctx, "SELECT COUNT(*) FROM message").Scan(&messageCount) if err != nil { mh.log.Warnln("Failed to scan number of messages:", err) } else { @@ -255,7 +255,7 @@ func (mh *MetricsHandler) updateStats() { var encryptedGroupCount, encryptedPrivateCount, unencryptedGroupCount, unencryptedPrivateCount int // TODO Use a more precise way to check if a chat_id is a UUID. // It should also be compatible with both SQLite & Postgres. - err = mh.db.QueryRowContext(mh.ctx, ` + err = mh.db.QueryRow(mh.ctx, ` SELECT COUNT(CASE WHEN chat_id NOT LIKE '%-%-%-%-%' AND encrypted THEN 1 END) AS encrypted_group_portals, COUNT(CASE WHEN chat_id LIKE '%-%-%-%-%' AND encrypted THEN 1 END) AS encrypted_private_portals, diff --git a/pkg/signalmeow/store/contact_store.go b/pkg/signalmeow/store/contact_store.go index 1f465130..5718577b 100644 --- a/pkg/signalmeow/store/contact_store.go +++ b/pkg/signalmeow/store/contact_store.go @@ -111,25 +111,23 @@ func scanContact(row dbutil.Scannable) (*types.Contact, error) { } func (s *SQLStore) LoadContact(ctx context.Context, theirUUID uuid.UUID) (*types.Contact, error) { - return scanContact(s.db.Conn(ctx).QueryRowContext(ctx, getContactByUUIDQuery, s.ACI, theirUUID)) + return scanContact(s.db.QueryRow(ctx, getContactByUUIDQuery, s.ACI, theirUUID)) } func (s *SQLStore) LoadContactByE164(ctx context.Context, e164 string) (*types.Contact, error) { - return scanContact(s.db.Conn(ctx).QueryRowContext(ctx, getContactByPhoneQuery, s.ACI, e164)) + return scanContact(s.db.QueryRow(ctx, getContactByPhoneQuery, s.ACI, e164)) } func (s *SQLStore) AllContacts(ctx context.Context) ([]*types.Contact, error) { - rows, err := s.db.Conn(ctx).QueryContext(ctx, getAllContactsOfUserQuery, s.ACI) + rows, err := s.db.Query(ctx, getAllContactsOfUserQuery, s.ACI) if err != nil { return nil, err } - return dbutil.NewRowIter(rows, func(rows dbutil.Rows) (*types.Contact, error) { - return scanContact(rows) - }).AsList() + return dbutil.NewRowIter(rows, scanContact).AsList() } func (s *SQLStore) StoreContact(ctx context.Context, contact types.Contact) error { - _, err := s.db.Conn(ctx).ExecContext( + _, err := s.db.Exec( ctx, upsertContactQuery, s.ACI, diff --git a/pkg/signalmeow/store/container.go b/pkg/signalmeow/store/container.go index 1ef0fd5f..a173597b 100644 --- a/pkg/signalmeow/store/container.go +++ b/pkg/signalmeow/store/container.go @@ -40,8 +40,8 @@ FROM signalmeow_device const getDeviceQuery = getAllDevicesQuery + " WHERE aci_uuid=$1" -func (c *StoreContainer) Upgrade() error { - return c.db.Upgrade() +func (c *StoreContainer) Upgrade(ctx context.Context) error { + return c.db.Upgrade(ctx) } func (c *StoreContainer) scanDevice(row dbutil.Scannable) (*Device, error) { @@ -85,7 +85,7 @@ func (c *StoreContainer) scanDevice(row dbutil.Scannable) (*Device, error) { // GetAllDevices finds all the devices in the database. func (c *StoreContainer) GetAllDevices(ctx context.Context) ([]*Device, error) { - rows, err := c.db.Conn(ctx).QueryContext(ctx, getAllDevicesQuery) + rows, err := c.db.Query(ctx, getAllDevicesQuery) if err != nil { return nil, fmt.Errorf("failed to query sessions: %w", err) } @@ -104,7 +104,7 @@ func (c *StoreContainer) GetAllDevices(ctx context.Context) ([]*Device, error) { // GetDevice finds the device with the specified ACI UUID in the database. // If the device is not found, nil is returned instead. func (c *StoreContainer) DeviceByACI(ctx context.Context, aci uuid.UUID) (*Device, error) { - sess, err := c.scanDevice(c.db.Conn(ctx).QueryRowContext(ctx, getDeviceQuery, aci)) + sess, err := c.scanDevice(c.db.QueryRow(ctx, getDeviceQuery, aci)) if errors.Is(err, sql.ErrNoRows) { return nil, nil } @@ -150,7 +150,7 @@ func (c *StoreContainer) PutDevice(ctx context.Context, device *DeviceData) erro zerolog.Ctx(ctx).Err(err).Msg("failed to serialize pni identity key pair") return err } - _, err = c.db.Conn(ctx).ExecContext(ctx, insertDeviceQuery, + _, err = c.db.Exec(ctx, insertDeviceQuery, device.ACI, aciIdentityKeyPair, device.RegistrationID, device.PNI, pniIdentityKeyPair, device.PNIRegistrationID, device.DeviceID, device.Number, device.Password, @@ -166,6 +166,6 @@ func (c *StoreContainer) DeleteDevice(ctx context.Context, device *DeviceData) e if device.ACI == uuid.Nil { return ErrDeviceIDMustBeSet } - _, err := c.db.Conn(ctx).ExecContext(ctx, deleteDeviceQuery, device.ACI) + _, err := c.db.Exec(ctx, deleteDeviceQuery, device.ACI) return err } diff --git a/pkg/signalmeow/store/group_store.go b/pkg/signalmeow/store/group_store.go index b584296e..5ee356ea 100644 --- a/pkg/signalmeow/store/group_store.go +++ b/pkg/signalmeow/store/group_store.go @@ -61,7 +61,7 @@ func scanGroup(row dbutil.Scannable) (*dbGroup, error) { } func (s *SQLStore) MasterKeyFromGroupIdentifier(ctx context.Context, groupID types.GroupIdentifier) (types.SerializedGroupMasterKey, error) { - g, err := scanGroup(s.db.Conn(ctx).QueryRowContext(ctx, getGroupByIDQuery, s.ACI, groupID)) + g, err := scanGroup(s.db.QueryRow(ctx, getGroupByIDQuery, s.ACI, groupID)) if g == nil { return "", err } else { @@ -70,6 +70,6 @@ func (s *SQLStore) MasterKeyFromGroupIdentifier(ctx context.Context, groupID typ } func (s *SQLStore) StoreMasterKey(ctx context.Context, groupID types.GroupIdentifier, key types.SerializedGroupMasterKey) error { - _, err := s.db.Conn(ctx).ExecContext(ctx, upsertGroupMasterKeyQuery, s.ACI, groupID, key) + _, err := s.db.Exec(ctx, upsertGroupMasterKeyQuery, s.ACI, groupID, key) return err } diff --git a/pkg/signalmeow/store/identity_store.go b/pkg/signalmeow/store/identity_store.go index f0176816..14276a17 100644 --- a/pkg/signalmeow/store/identity_store.go +++ b/pkg/signalmeow/store/identity_store.go @@ -71,12 +71,12 @@ func scanIdentityKey(row dbutil.Scannable) (*libsignalgo.IdentityKey, error) { } func (s *SQLStore) GetIdentityKeyPair(ctx context.Context) (*libsignalgo.IdentityKeyPair, error) { - return scanIdentityKeyPair(s.db.Conn(ctx).QueryRowContext(ctx, getIdentityKeyPairQuery, s.ACI)) + return scanIdentityKeyPair(s.db.QueryRow(ctx, getIdentityKeyPairQuery, s.ACI)) } func (s *SQLStore) GetLocalRegistrationID(ctx context.Context) (uint32, error) { var regID sql.NullInt64 - err := s.db.Conn(ctx).QueryRowContext(ctx, getRegistrationLocalIDQuery, s.ACI).Scan(®ID) + err := s.db.QueryRow(ctx, getRegistrationLocalIDQuery, s.ACI).Scan(®ID) if err != nil { return 0, fmt.Errorf("failed to get local registration ID: %w", err) } @@ -97,7 +97,7 @@ func (s *SQLStore) SaveIdentityKey(ctx context.Context, address *libsignalgo.Add if err != nil { return false, fmt.Errorf("failed to get device ID: %w", err) } - oldKey, err := scanIdentityKey(s.db.Conn(ctx).QueryRowContext(ctx, getIdentityKeyQuery, s.ACI, theirUUID, deviceID)) + oldKey, err := scanIdentityKey(s.db.QueryRow(ctx, getIdentityKeyQuery, s.ACI, theirUUID, deviceID)) if err != nil { return false, fmt.Errorf("failed to get old identity key: %w", err) } @@ -110,7 +110,7 @@ func (s *SQLStore) SaveIdentityKey(ctx context.Context, address *libsignalgo.Add // We are replacing the old key if the old key exists, and it is not equal to the new key replacing = !equal } - _, err = s.db.Conn(ctx).ExecContext(ctx, insertIdentityKeyQuery, s.ACI, theirUUID, deviceID, serialized, trustLevel) + _, err = s.db.Exec(ctx, insertIdentityKeyQuery, s.ACI, theirUUID, deviceID, serialized, trustLevel) if err != nil { return replacing, fmt.Errorf("failed to insert new identity key: %w", err) } @@ -128,7 +128,7 @@ func (s *SQLStore) IsTrustedIdentity(ctx context.Context, address *libsignalgo.A return false, fmt.Errorf("failed to get device ID: %w", err) } var trustLevel string - err = s.db.Conn(ctx).QueryRowContext(ctx, getIdentityKeyTrustLevelQuery, s.ACI, theirUUID, deviceID).Scan(&trustLevel) + err = s.db.QueryRow(ctx, getIdentityKeyTrustLevelQuery, s.ACI, theirUUID, deviceID).Scan(&trustLevel) if errors.Is(err, sql.ErrNoRows) { // If no rows, they are a new identity, so trust by default return true, nil @@ -148,7 +148,7 @@ func (s *SQLStore) GetIdentityKey(ctx context.Context, address *libsignalgo.Addr if err != nil { return nil, fmt.Errorf("failed to get device ID: %w", err) } - key, err := scanIdentityKey(s.db.Conn(ctx).QueryRowContext(ctx, getIdentityKeyQuery, s.ACI, theirUUID, deviceID)) + key, err := scanIdentityKey(s.db.QueryRow(ctx, getIdentityKeyQuery, s.ACI, theirUUID, deviceID)) if err != nil { return nil, fmt.Errorf("failed to get identity key from database: %w", err) } diff --git a/pkg/signalmeow/store/prekey_store.go b/pkg/signalmeow/store/prekey_store.go index 5f4f5723..5b4ee7b4 100644 --- a/pkg/signalmeow/store/prekey_store.go +++ b/pkg/signalmeow/store/prekey_store.go @@ -108,7 +108,7 @@ const ( func (s *SQLStore) KyberPreKey(ctx context.Context, uuidKind types.UUIDKind, preKeyID int) (*libsignalgo.KyberPreKeyRecord, error) { var record []byte var isLastResort bool - err := s.db.Conn(ctx).QueryRowContext(ctx, getKyberPreKeyQuery, s.ACI, preKeyID, uuidKind).Scan(&record, &isLastResort) + err := s.db.QueryRow(ctx, getKyberPreKeyQuery, s.ACI, preKeyID, uuidKind).Scan(&record, &isLastResort) if errors.Is(err, sql.ErrNoRows) { return nil, nil } @@ -127,18 +127,18 @@ func (s *SQLStore) SaveKyberPreKey(ctx context.Context, uuidKind types.UUIDKind, if err != nil { return fmt.Errorf("failed to serialize kyber prekey record: %w", err) } - _, err = s.db.Conn(ctx).ExecContext(ctx, insertKyberPreKeyQuery, s.ACI, id, uuidKind, serialized, lastResort) + _, err = s.db.Exec(ctx, insertKyberPreKeyQuery, s.ACI, id, uuidKind, serialized, lastResort) return err } func (s *SQLStore) DeleteKyberPreKey(ctx context.Context, uuidKind types.UUIDKind, preKeyID int) error { - _, err := s.db.Conn(ctx).ExecContext(ctx, deleteKyberPreKeyQuery, s.ACI, preKeyID, uuidKind) + _, err := s.db.Exec(ctx, deleteKyberPreKeyQuery, s.ACI, preKeyID, uuidKind) return err } func (s *SQLStore) GetNextKyberPreKeyID(ctx context.Context, uuidKind types.UUIDKind) (uint, error) { var lastKeyID sql.NullInt64 - err := s.db.Conn(ctx).QueryRowContext(ctx, getLastKyberPreKeyIDQuery, s.ACI, uuidKind).Scan(&lastKeyID) + err := s.db.QueryRow(ctx, getLastKyberPreKeyIDQuery, s.ACI, uuidKind).Scan(&lastKeyID) if err != nil { return 0, fmt.Errorf("failed to query next kyber prekey ID: %w", err) } @@ -147,7 +147,7 @@ func (s *SQLStore) GetNextKyberPreKeyID(ctx context.Context, uuidKind types.UUID func (s *SQLStore) IsKyberPreKeyLastResort(ctx context.Context, uuidKind types.UUIDKind, preKeyID int) (bool, error) { var isLastResort bool - err := s.db.Conn(ctx).QueryRowContext(ctx, isLastResortQuery, s.ACI, preKeyID, uuidKind).Scan(&isLastResort) + err := s.db.QueryRow(ctx, isLastResortQuery, s.ACI, preKeyID, uuidKind).Scan(&isLastResort) if err != nil { return false, err } @@ -189,11 +189,11 @@ func scanSignedPreKey(row dbutil.Scannable) (*libsignalgo.SignedPreKeyRecord, er } func (s *SQLStore) PreKey(ctx context.Context, uuidKind types.UUIDKind, preKeyID int) (*libsignalgo.PreKeyRecord, error) { - return scanPreKey(s.db.Conn(ctx).QueryRowContext(ctx, getPreKeyQuery, s.ACI, preKeyID, uuidKind, false)) + return scanPreKey(s.db.QueryRow(ctx, getPreKeyQuery, s.ACI, preKeyID, uuidKind, false)) } func (s *SQLStore) SignedPreKey(ctx context.Context, uuidKind types.UUIDKind, preKeyID int) (*libsignalgo.SignedPreKeyRecord, error) { - return scanSignedPreKey(s.db.Conn(ctx).QueryRowContext(ctx, getPreKeyQuery, s.ACI, preKeyID, uuidKind, true)) + return scanSignedPreKey(s.db.QueryRow(ctx, getPreKeyQuery, s.ACI, preKeyID, uuidKind, true)) } func (s *SQLStore) SavePreKey(ctx context.Context, uuidKind types.UUIDKind, preKey *libsignalgo.PreKeyRecord, markUploaded bool) error { @@ -205,7 +205,7 @@ func (s *SQLStore) SavePreKey(ctx context.Context, uuidKind types.UUIDKind, preK if err != nil { return fmt.Errorf("failed to serialize prekey: %w", err) } - _, err = s.db.Conn(ctx).ExecContext(ctx, insertPreKeyQuery, s.ACI, id, uuidKind, false, serialized, markUploaded) + _, err = s.db.Exec(ctx, insertPreKeyQuery, s.ACI, id, uuidKind, false, serialized, markUploaded) return err } @@ -218,23 +218,23 @@ func (s *SQLStore) SaveSignedPreKey(ctx context.Context, uuidKind types.UUIDKind if err != nil { return fmt.Errorf("failed to serialize signed prekey: %w", err) } - _, err = s.db.Conn(ctx).ExecContext(ctx, insertPreKeyQuery, s.ACI, id, uuidKind, true, serialized, markUploaded) + _, err = s.db.Exec(ctx, insertPreKeyQuery, s.ACI, id, uuidKind, true, serialized, markUploaded) return err } func (s *SQLStore) DeletePreKey(ctx context.Context, uuidKind types.UUIDKind, preKeyID int) error { - _, err := s.db.Conn(ctx).ExecContext(ctx, deletePreKeyQuery, s.ACI, preKeyID, uuidKind, false) + _, err := s.db.Exec(ctx, deletePreKeyQuery, s.ACI, preKeyID, uuidKind, false) return err } func (s *SQLStore) DeleteSignedPreKey(ctx context.Context, uuidKind types.UUIDKind, preKeyID int) error { - _, err := s.db.Conn(ctx).ExecContext(ctx, deletePreKeyQuery, s.ACI, preKeyID, uuidKind, true) + _, err := s.db.Exec(ctx, deletePreKeyQuery, s.ACI, preKeyID, uuidKind, true) return err } func (s *SQLStore) GetNextPreKeyID(ctx context.Context, uuidKind types.UUIDKind) (uint, error) { var lastKeyID sql.NullInt64 - err := s.db.Conn(ctx).QueryRowContext(ctx, getLastPreKeyIDQuery, s.ACI, uuidKind, false).Scan(&lastKeyID) + err := s.db.QueryRow(ctx, getLastPreKeyIDQuery, s.ACI, uuidKind, false).Scan(&lastKeyID) if err != nil { return 0, fmt.Errorf("failed to query next prekey ID: %w", err) } @@ -243,7 +243,7 @@ func (s *SQLStore) GetNextPreKeyID(ctx context.Context, uuidKind types.UUIDKind) func (s *SQLStore) GetSignedNextPreKeyID(ctx context.Context, uuidKind types.UUIDKind) (uint, error) { var lastKeyID sql.NullInt64 - err := s.db.Conn(ctx).QueryRowContext(ctx, getLastPreKeyIDQuery, s.ACI, uuidKind, true).Scan(&lastKeyID) + err := s.db.QueryRow(ctx, getLastPreKeyIDQuery, s.ACI, uuidKind, true).Scan(&lastKeyID) if err != nil { return 0, fmt.Errorf("failed to query next signed prekey ID: %w", err) } @@ -251,22 +251,22 @@ func (s *SQLStore) GetSignedNextPreKeyID(ctx context.Context, uuidKind types.UUI } func (s *SQLStore) MarkPreKeysAsUploaded(ctx context.Context, uuidKind types.UUIDKind, upToID uint) error { - _, err := s.db.Conn(ctx).ExecContext(ctx, markPreKeysAsUploadedQuery, s.ACI, uuidKind, false, upToID) + _, err := s.db.Exec(ctx, markPreKeysAsUploadedQuery, s.ACI, uuidKind, false, upToID) return err } func (s *SQLStore) MarkSignedPreKeysAsUploaded(ctx context.Context, uuidKind types.UUIDKind, upToID uint) error { - _, err := s.db.Conn(ctx).ExecContext(ctx, markPreKeysAsUploadedQuery, s.ACI, uuidKind, true, upToID) + _, err := s.db.Exec(ctx, markPreKeysAsUploadedQuery, s.ACI, uuidKind, true, upToID) return err } func (s *SQLStore) DeleteAllPreKeys(ctx context.Context) error { return s.db.DoTxn(ctx, nil, func(ctx context.Context) error { - _, err := s.db.Conn(ctx).ExecContext(ctx, "DELETE FROM signalmeow_pre_keys WHERE aci_uuid=$1", s.ACI) + _, err := s.db.Exec(ctx, "DELETE FROM signalmeow_pre_keys WHERE aci_uuid=$1", s.ACI) if err != nil { return err } - _, err = s.db.Conn(ctx).ExecContext(ctx, "DELETE FROM signalmeow_kyber_pre_keys WHERE aci_uuid=$1", s.ACI) + _, err = s.db.Exec(ctx, "DELETE FROM signalmeow_kyber_pre_keys WHERE aci_uuid=$1", s.ACI) return err }) } diff --git a/pkg/signalmeow/store/profile_key_store.go b/pkg/signalmeow/store/profile_key_store.go index ba80ebe8..c384f4f5 100644 --- a/pkg/signalmeow/store/profile_key_store.go +++ b/pkg/signalmeow/store/profile_key_store.go @@ -55,14 +55,14 @@ func scanProfileKey(row dbutil.Scannable) (*libsignalgo.ProfileKey, error) { } func (s *SQLStore) LoadProfileKey(ctx context.Context, theirACI uuid.UUID) (*libsignalgo.ProfileKey, error) { - return scanProfileKey(s.db.Conn(ctx).QueryRowContext(ctx, loadProfileKeyQuery, s.ACI, theirACI)) + return scanProfileKey(s.db.QueryRow(ctx, loadProfileKeyQuery, s.ACI, theirACI)) } func (s *SQLStore) MyProfileKey(ctx context.Context) (*libsignalgo.ProfileKey, error) { - return scanProfileKey(s.db.Conn(ctx).QueryRowContext(ctx, loadProfileKeyQuery, s.ACI, s.ACI)) + return scanProfileKey(s.db.QueryRow(ctx, loadProfileKeyQuery, s.ACI, s.ACI)) } func (s *SQLStore) StoreProfileKey(ctx context.Context, theirACI uuid.UUID, key libsignalgo.ProfileKey) error { - _, err := s.db.Conn(ctx).ExecContext(ctx, storeProfileKeyQuery, s.ACI, theirACI, key.Slice()) + _, err := s.db.Exec(ctx, storeProfileKeyQuery, s.ACI, theirACI, key.Slice()) return err } diff --git a/pkg/signalmeow/store/sender_key_store.go b/pkg/signalmeow/store/sender_key_store.go index 9bfa5c28..e9ffea2b 100644 --- a/pkg/signalmeow/store/sender_key_store.go +++ b/pkg/signalmeow/store/sender_key_store.go @@ -55,7 +55,7 @@ func (s *SQLStore) LoadSenderKey(ctx context.Context, sender *libsignalgo.Addres if err != nil { return nil, fmt.Errorf("failed to get sender device ID: %w", err) } - return scanSenderKey(s.db.Conn(ctx).QueryRowContext(ctx, loadSenderKeyQuery, s.ACI, senderUUID, deviceID, distributionID)) + return scanSenderKey(s.db.QueryRow(ctx, loadSenderKeyQuery, s.ACI, senderUUID, deviceID, distributionID)) } func (s *SQLStore) StoreSenderKey(ctx context.Context, sender *libsignalgo.Address, distributionID uuid.UUID, record *libsignalgo.SenderKeyRecord) error { @@ -71,6 +71,6 @@ func (s *SQLStore) StoreSenderKey(ctx context.Context, sender *libsignalgo.Addre if err != nil { return fmt.Errorf("failed to serialize sender key: %w", err) } - _, err = s.db.Conn(ctx).ExecContext(ctx, storeSenderKeyQuery, s.ACI, senderUUID, deviceID, distributionID, serialized) + _, err = s.db.Exec(ctx, storeSenderKeyQuery, s.ACI, senderUUID, deviceID, distributionID, serialized) return err } diff --git a/pkg/signalmeow/store/session_store.go b/pkg/signalmeow/store/session_store.go index ca669263..5c190da0 100644 --- a/pkg/signalmeow/store/session_store.go +++ b/pkg/signalmeow/store/session_store.go @@ -69,12 +69,12 @@ func (s *SQLStore) RemoveSession(ctx context.Context, address *libsignalgo.Addre if err != nil { return fmt.Errorf("failed to get their device ID: %w", err) } - _, err = s.db.Conn(ctx).ExecContext(ctx, removeSessionQuery, s.ACI, theirUUID, deviceID) + _, err = s.db.Exec(ctx, removeSessionQuery, s.ACI, theirUUID, deviceID) return err } func (s *SQLStore) AllSessionsForUUID(ctx context.Context, theirUUID uuid.UUID) ([]*libsignalgo.Address, []*libsignalgo.SessionRecord, error) { - rows, err := s.db.Conn(ctx).QueryContext(ctx, allSessionsQuery, s.ACI, theirUUID) + rows, err := s.db.Query(ctx, allSessionsQuery, s.ACI, theirUUID) if err != nil { return nil, nil, err } @@ -105,7 +105,7 @@ func (s *SQLStore) LoadSession(ctx context.Context, address *libsignalgo.Address if err != nil { return nil, fmt.Errorf("failed to get their device ID: %w", err) } - _, record, err := scanRecord(s.db.Conn(ctx).QueryRowContext(ctx, loadSessionQuery, s.ACI, theirUUID, deviceID)) + _, record, err := scanRecord(s.db.QueryRow(ctx, loadSessionQuery, s.ACI, theirUUID, deviceID)) return record, err } @@ -122,11 +122,11 @@ func (s *SQLStore) StoreSession(ctx context.Context, address *libsignalgo.Addres if err != nil { return fmt.Errorf("failed to serialize session record: %w", err) } - _, err = s.db.Conn(ctx).ExecContext(ctx, storeSessionQuery, s.ACI, theirUUID, deviceID, serialized) + _, err = s.db.Exec(ctx, storeSessionQuery, s.ACI, theirUUID, deviceID, serialized) return err } func (s *SQLStore) RemoveAllSessions(ctx context.Context) error { - _, err := s.db.Conn(ctx).ExecContext(ctx, "DELETE FROM signalmeow_sessions WHERE our_aci_uuid=$1", s.ACI) + _, err := s.db.Exec(ctx, "DELETE FROM signalmeow_sessions WHERE our_aci_uuid=$1", s.ACI) return err } diff --git a/portal.go b/portal.go index b3d3035b..8b016b27 100644 --- a/portal.go +++ b/portal.go @@ -1371,7 +1371,7 @@ func (portal *Portal) encrypt(ctx context.Context, intent *appservice.IntentAPI, // TODO maybe the locking should be inside mautrix-go? portal.encryptLock.Lock() defer portal.encryptLock.Unlock() - err := portal.bridge.Crypto.Encrypt(portal.MXID, eventType, content) + err := portal.bridge.Crypto.Encrypt(ctx, portal.MXID, eventType, content) if err != nil { return eventType, fmt.Errorf("failed to encrypt event: %w", err) } @@ -1521,13 +1521,6 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, user *User, groupRev } portal.log.Info().Msg("Created matrix room for portal") - inviteMembership := event.MembershipInvite - if autoJoinInvites { - inviteMembership = event.MembershipJoin - } - for _, userID := range invite { - portal.bridge.StateStore.SetMembership(portal.MXID, userID, inviteMembership) - } if !autoJoinInvites { if !portal.IsPrivateChat() { portal.SyncParticipants(ctx, user, groupInfo) diff --git a/user.go b/user.go index 1e349b8a..cad7a166 100644 --- a/user.go +++ b/user.go @@ -229,7 +229,7 @@ func (user *User) GetIGhost() bridge.Ghost { func (user *User) ensureInvited(ctx context.Context, intent *appservice.IntentAPI, roomID id.RoomID, isDirect bool) (ok bool) { log := user.log.With().Str("action", "ensure_invited").Stringer("room_id", roomID).Logger() - if user.bridge.StateStore.GetMembership(roomID, user.MXID) == event.MembershipJoin { + if user.bridge.StateStore.IsMembership(ctx, roomID, user.MXID, event.MembershipJoin) { ok = true return } @@ -247,7 +247,10 @@ func (user *User) ensureInvited(ctx context.Context, intent *appservice.IntentAP _, err := intent.InviteUser(ctx, roomID, &mautrix.ReqInviteUser{UserID: user.MXID}, extraContent) var httpErr mautrix.HTTPError if err != nil && errors.As(err, &httpErr) && httpErr.RespError != nil && strings.Contains(httpErr.RespError.Err, "is already in the room") { - user.bridge.StateStore.SetMembership(roomID, user.MXID, event.MembershipJoin) + err = user.bridge.StateStore.SetMembership(ctx, roomID, user.MXID, event.MembershipJoin) + if err != nil { + log.Warn().Err(err).Msg("Failed to update membership in state store") + } ok = true return } else if err != nil {