Skip to content

Commit

Permalink
Add context parameter to all client and bridge API functions (#144)
Browse files Browse the repository at this point in the history
  • Loading branch information
recht authored Dec 15, 2023
1 parent 7f78c32 commit 753cdb2
Show file tree
Hide file tree
Showing 30 changed files with 669 additions and 663 deletions.
3 changes: 2 additions & 1 deletion appservice/appservice_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package appservice

import (
"context"
"fmt"
"net"
"net/http"
Expand Down Expand Up @@ -35,7 +36,7 @@ func TestClient_UnixSocket(t *testing.T) {
err = as.SetHomeserverURL(fmt.Sprintf("unix://%s", socket))
assert.NoError(t, err)
client := as.Client("user1")
resp, err := client.Whoami()
resp, err := client.Whoami(context.Background())
assert.NoError(t, err)
assert.Equal(t, "@joe:example.org", string(resp.UserID))
}
171 changes: 86 additions & 85 deletions appservice/intent.go

Large diffs are not rendered by default.

34 changes: 18 additions & 16 deletions bridge/bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ type Crypto interface {
Decrypt(*event.Event) (*event.Event, error)
Encrypt(id.RoomID, event.Type, *event.Content) error
WaitForSession(id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool
RequestSession(id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID)
RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID)
ResetSession(id.RoomID)
Init() error
Start()
Expand Down Expand Up @@ -287,9 +287,9 @@ func (br *Bridge) InitVersion(tag, commit, buildTime string) {

var MinSpecVersion = mautrix.SpecV11

func (br *Bridge) ensureConnection() {
func (br *Bridge) ensureConnection(ctx context.Context) {
for {
versions, err := br.Bot.Versions()
versions, err := br.Bot.Versions(ctx)
if err != nil {
br.ZLog.Err(err).Msg("Failed to connect to homeserver, retrying in 10 seconds...")
time.Sleep(10 * time.Second)
Expand All @@ -315,7 +315,7 @@ func (br *Bridge) ensureConnection() {
}
}

resp, err := br.Bot.Whoami()
resp, err := br.Bot.Whoami(ctx)
if err != nil {
if errors.Is(err, mautrix.MUnknownToken) {
br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was not accepted. Is the registration file installed in your homeserver correctly?")
Expand Down Expand Up @@ -346,7 +346,7 @@ func (br *Bridge) ensureConnection() {
const maxRetries = 6
for {
txnID = br.Bot.TxnID()
pingResp, err = br.Bot.AppservicePing(br.Config.AppService.ID, txnID)
pingResp, err = br.Bot.AppservicePing(ctx, br.Config.AppService.ID, txnID)
if err == nil {
break
}
Expand Down Expand Up @@ -385,42 +385,42 @@ func (br *Bridge) ensureConnection() {
Msg("Homeserver -> bridge connection works")
}

func (br *Bridge) fetchMediaConfig() {
cfg, err := br.Bot.GetMediaConfig()
func (br *Bridge) fetchMediaConfig(ctx context.Context) {
cfg, err := br.Bot.GetMediaConfig(ctx)
if err != nil {
br.ZLog.Warn().Err(err).Msg("Failed to fetch media config")
} else {
br.MediaConfig = *cfg
}
}

func (br *Bridge) UpdateBotProfile() {
func (br *Bridge) UpdateBotProfile(ctx context.Context) {
br.ZLog.Debug().Msg("Updating bot profile")
botConfig := &br.Config.AppService.Bot

var err error
var mxc id.ContentURI
if botConfig.Avatar == "remove" {
err = br.Bot.SetAvatarURL(mxc)
err = br.Bot.SetAvatarURL(ctx, mxc)
} else if !botConfig.ParsedAvatar.IsEmpty() {
err = br.Bot.SetAvatarURL(botConfig.ParsedAvatar)
err = br.Bot.SetAvatarURL(ctx, botConfig.ParsedAvatar)
}
if err != nil {
br.ZLog.Warn().Err(err).Msg("Failed to update bot avatar")
}

if botConfig.Displayname == "remove" {
err = br.Bot.SetDisplayName("")
err = br.Bot.SetDisplayName(ctx, "")
} else if len(botConfig.Displayname) > 0 {
err = br.Bot.SetDisplayName(botConfig.Displayname)
err = br.Bot.SetDisplayName(ctx, botConfig.Displayname)
}
if err != nil {
br.ZLog.Warn().Err(err).Msg("Failed to update bot displayname")
}

if br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) && br.BeeperNetworkName != "" {
br.ZLog.Debug().Msg("Setting contact info on the appservice bot")
br.Bot.BeeperUpdateProfile(map[string]any{
br.Bot.BeeperUpdateProfile(ctx, map[string]any{
"com.beeper.bridge.service": br.BeeperServiceName,
"com.beeper.bridge.network": br.BeeperNetworkName,
"com.beeper.bridge.is_bridge_bot": true,
Expand Down Expand Up @@ -633,8 +633,10 @@ func (br *Bridge) start() {
os.Exit(23)
}
br.ZLog.Debug().Msg("Checking connection to homeserver")
br.ensureConnection()
go br.fetchMediaConfig()

ctx := context.Background()
br.ensureConnection(ctx)
go br.fetchMediaConfig(ctx)

if br.Crypto != nil {
err = br.Crypto.Init()
Expand All @@ -647,7 +649,7 @@ func (br *Bridge) start() {
br.ZLog.Debug().Msg("Starting event processor")
br.EventProcessor.Start()

go br.UpdateBotProfile()
go br.UpdateBotProfile(ctx)
if br.Crypto != nil {
go br.Crypto.Start()
}
Expand Down
3 changes: 2 additions & 1 deletion bridge/commands/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package commands

import (
"context"
"strconv"

"maunium.net/go/mautrix/id"
Expand Down Expand Up @@ -57,7 +58,7 @@ func fnSetPowerLevel(ce *Event) {
ce.Reply("**Usage:** `set-pl [user] <level>`")
return
}
_, err = ce.Portal.MainIntent().SetPowerLevel(ce.RoomID, userID, level)
_, err = ce.Portal.MainIntent().SetPowerLevel(context.Background(), ce.RoomID, userID, level)
if err != nil {
ce.Reply("Failed to set power levels: %v", err)
}
Expand Down
4 changes: 3 additions & 1 deletion bridge/commands/doublepuppet.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

package commands

import "context"

var CommandLoginMatrix = &FullHandler{
Func: fnLoginMatrix,
Name: "login-matrix",
Expand Down Expand Up @@ -54,7 +56,7 @@ func fnPingMatrix(ce *Event) {
ce.Reply("You are not logged in with your Matrix account.")
return
}
resp, err := puppet.CustomIntent().Whoami()
resp, err := puppet.CustomIntent().Whoami(context.Background())
if err != nil {
ce.Reply("Failed to validate Matrix login: %v", err)
} else {
Expand Down
8 changes: 4 additions & 4 deletions bridge/commands/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,31 +67,31 @@ func (ce *Event) Reply(msg string, args ...interface{}) {
func (ce *Event) ReplyAdvanced(msg string, allowMarkdown, allowHTML bool) {
content := format.RenderMarkdown(msg, allowMarkdown, allowHTML)
content.MsgType = event.MsgNotice
_, err := ce.MainIntent().SendMessageEvent(ce.RoomID, event.EventMessage, content)
_, err := ce.MainIntent().SendMessageEvent(context.Background(), ce.RoomID, event.EventMessage, content)
if err != nil {
ce.ZLog.Error().Err(err).Msgf("Failed to reply to command")
}
}

// React sends a reaction to the command.
func (ce *Event) React(key string) {
_, err := ce.MainIntent().SendReaction(ce.RoomID, ce.EventID, key)
_, err := ce.MainIntent().SendReaction(context.Background(), ce.RoomID, ce.EventID, key)
if err != nil {
ce.ZLog.Error().Err(err).Msgf("Failed to react to command")
}
}

// Redact redacts the command.
func (ce *Event) Redact(req ...mautrix.ReqRedact) {
_, err := ce.MainIntent().RedactEvent(ce.RoomID, ce.EventID, req...)
_, err := ce.MainIntent().RedactEvent(context.Background(), ce.RoomID, ce.EventID, req...)
if err != nil {
ce.ZLog.Error().Err(err).Msgf("Failed to redact command")
}
}

// MarkRead marks the command event as read.
func (ce *Event) MarkRead() {
err := ce.MainIntent().SendReceipt(ce.RoomID, ce.EventID, event.ReceiptTypeRead, nil)
err := ce.MainIntent().SendReceipt(context.Background(), ce.RoomID, ce.EventID, event.ReceiptTypeRead, nil)
if err != nil {
ce.ZLog.Error().Err(err).Msgf("Failed to mark command as read")
}
Expand Down
4 changes: 3 additions & 1 deletion bridge/commands/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
package commands

import (
"context"

"maunium.net/go/mautrix/bridge"
"maunium.net/go/mautrix/bridge/bridgeconfig"
"maunium.net/go/mautrix/event"
Expand Down Expand Up @@ -76,7 +78,7 @@ func (fh *FullHandler) ShowInHelp(ce *Event) bool {
}

func (fh *FullHandler) userHasRoomPermission(ce *Event) bool {
levels, err := ce.MainIntent().PowerLevels(ce.RoomID)
levels, err := ce.MainIntent().PowerLevels(context.Background(), ce.RoomID)
if err != nil {
ce.ZLog.Warn().Err(err).Msg("Failed to check room power levels")
ce.Reply("Failed to get room power levels to see if you're allowed to use that command")
Expand Down
23 changes: 13 additions & 10 deletions bridge/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,9 @@ func (helper *CryptoHelper) Init() error {
}

func (helper *CryptoHelper) resyncEncryptionInfo() {
ctx := context.Background()
log := helper.log.With().Str("action", "resync encryption event").Logger()
rows, err := helper.bridge.DB.Query(`SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`)
rows, err := helper.bridge.DB.QueryContext(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`)
if err != nil {
log.Err(err).Msg("Failed to query rooms for resync")
return
Expand All @@ -158,10 +159,10 @@ func (helper *CryptoHelper) resyncEncryptionInfo() {
log.Debug().Interface("room_ids", roomIDs).Msg("Resyncing rooms")
for _, roomID := range roomIDs {
var evt event.EncryptionEventContent
err = helper.client.StateEvent(roomID, event.StateEncryption, "", &evt)
err = helper.client.StateEvent(ctx, roomID, event.StateEncryption, "", &evt)
if err != nil {
log.Err(err).Str("room_id", roomID.String()).Msg("Failed to get encryption event")
_, err = helper.bridge.DB.Exec(`
_, err = helper.bridge.DB.ExecContext(ctx, `
UPDATE mx_room_state SET encryption=NULL WHERE room_id=$1 AND encryption='{"resync":true}'
`, roomID)
if err != nil {
Expand All @@ -182,7 +183,7 @@ func (helper *CryptoHelper) resyncEncryptionInfo() {
Int("max_messages", maxMessages).
Interface("content", &evt).
Msg("Resynced encryption event")
_, err = helper.bridge.DB.Exec(`
_, err = helper.bridge.DB.ExecContext(ctx, `
UPDATE crypto_megolm_inbound_session
SET max_age=$1, max_messages=$2
WHERE room_id=$3 AND max_age IS NULL AND max_messages IS NULL
Expand Down Expand Up @@ -223,20 +224,21 @@ func (helper *CryptoHelper) allowKeyShare(ctx context.Context, device *id.Device
}

func (helper *CryptoHelper) loginBot() (*mautrix.Client, bool, error) {
ctx := context.Background()
deviceID := helper.store.FindDeviceID()
if len(deviceID) > 0 {
helper.log.Debug().Str("device_id", deviceID.String()).Msg("Found existing device ID for bot in database")
}
// Create a new client instance with the default AS settings (including as_token),
// the Login call will then override the access token in the client.
client := helper.bridge.AS.NewMautrixClient(helper.bridge.AS.BotMXID())
flows, err := client.GetLoginFlows()
flows, err := client.GetLoginFlows(ctx)
if err != nil {
return nil, deviceID != "", fmt.Errorf("failed to get supported login flows: %w", err)
} else if !flows.HasFlow(mautrix.AuthTypeAppservice) {
return nil, deviceID != "", fmt.Errorf("homeserver does not support appservice login")
}
resp, err := client.Login(&mautrix.ReqLogin{
resp, err := client.Login(ctx, &mautrix.ReqLogin{
Type: mautrix.AuthTypeAppservice,
Identifier: mautrix.UserIdentifier{
Type: mautrix.IdentifierTypeUser,
Expand All @@ -255,8 +257,9 @@ func (helper *CryptoHelper) loginBot() (*mautrix.Client, bool, error) {
}

func (helper *CryptoHelper) verifyKeysAreOnServer() {
ctx := context.Background()
helper.log.Debug().Msg("Making sure keys are still on server")
resp, err := helper.client.QueryKeys(&mautrix.ReqQueryKeys{
resp, err := helper.client.QueryKeys(ctx, &mautrix.ReqQueryKeys{
DeviceKeys: map[id.UserID]mautrix.DeviceIDList{
helper.client.UserID: {helper.client.DeviceID},
},
Expand Down Expand Up @@ -333,7 +336,7 @@ func (helper *CryptoHelper) Reset(startAfterReset bool) {
helper.log.Debug().Msg("Crypto syncer stopped, clearing database")
helper.clearDatabase()
helper.log.Debug().Msg("Crypto database cleared, logging out of all sessions")
_, err := helper.client.LogoutAll()
_, err := helper.client.LogoutAll(context.Background())
if err != nil {
helper.log.Warn().Err(err).Msg("Failed to log out all devices")
}
Expand Down Expand Up @@ -395,13 +398,13 @@ func (helper *CryptoHelper) WaitForSession(roomID id.RoomID, senderKey id.Sender
return helper.mach.WaitForSession(roomID, senderKey, sessionID, timeout)
}

func (helper *CryptoHelper) RequestSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) {
func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) {
helper.lock.RLock()
defer helper.lock.RUnlock()
if deviceID == "" {
deviceID = "*"
}
err := helper.mach.SendRoomKeyRequest(roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{userID: {deviceID}})
err := helper.mach.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{userID: {deviceID}})
if err != nil {
helper.log.Warn().Err(err).
Str("user_id", userID.String()).
Expand Down
Loading

0 comments on commit 753cdb2

Please sign in to comment.