diff --git a/appservice/appservice_test.go b/appservice/appservice_test.go index 50cc6fc4..eace1668 100644 --- a/appservice/appservice_test.go +++ b/appservice/appservice_test.go @@ -1,6 +1,7 @@ package appservice import ( + "context" "fmt" "net" "net/http" @@ -35,7 +36,7 @@ func TestClient_UnixSocket(t *testing.T) { err = as.SetHomeserverURL(fmt.Sprintf("unix://%s", socket)) assert.NoError(t, err) client := as.Client("user1") - resp, err := client.Whoami() + resp, err := client.Whoami(context.Background()) assert.NoError(t, err) assert.Equal(t, "@joe:example.org", string(resp.UserID)) } diff --git a/appservice/intent.go b/appservice/intent.go index 7995f44b..348eee2a 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -7,6 +7,7 @@ package appservice import ( + "context" "errors" "fmt" "strings" @@ -46,8 +47,8 @@ func (as *AppService) NewIntentAPI(localpart string) *IntentAPI { } } -func (intent *IntentAPI) Register() error { - _, _, err := intent.Client.Register(&mautrix.ReqRegister{ +func (intent *IntentAPI) Register(ctx context.Context) error { + _, _, err := intent.Client.Register(ctx, &mautrix.ReqRegister{ Username: intent.Localpart, Type: mautrix.AuthTypeAppservice, InhibitLogin: true, @@ -55,14 +56,14 @@ func (intent *IntentAPI) Register() error { return err } -func (intent *IntentAPI) EnsureRegistered() error { +func (intent *IntentAPI) EnsureRegistered(ctx context.Context) error { intent.registerLock.Lock() defer intent.registerLock.Unlock() if intent.IsCustomPuppet || intent.as.StateStore.IsRegistered(intent.UserID) { return nil } - err := intent.Register() + err := intent.Register(ctx) if err != nil && !errors.Is(err, mautrix.MUserInUse) { return fmt.Errorf("failed to ensure registered: %w", err) } @@ -75,7 +76,7 @@ type EnsureJoinedParams struct { BotOverride *mautrix.Client } -func (intent *IntentAPI) EnsureJoined(roomID id.RoomID, extra ...EnsureJoinedParams) error { +func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, extra ...EnsureJoinedParams) error { var params EnsureJoinedParams if len(extra) > 1 { panic("invalid number of extra parameters") @@ -86,11 +87,11 @@ func (intent *IntentAPI) EnsureJoined(roomID id.RoomID, extra ...EnsureJoinedPar return nil } - if err := intent.EnsureRegistered(); err != nil { + if err := intent.EnsureRegistered(ctx); err != nil { return fmt.Errorf("failed to ensure joined: %w", err) } - resp, err := intent.JoinRoomByID(roomID) + resp, err := intent.JoinRoomByID(ctx, roomID) if err != nil { bot := intent.bot if params.BotOverride != nil { @@ -99,13 +100,13 @@ func (intent *IntentAPI) EnsureJoined(roomID id.RoomID, extra ...EnsureJoinedPar if !errors.Is(err, mautrix.MForbidden) || bot == nil { return fmt.Errorf("failed to ensure joined: %w", err) } - _, inviteErr := bot.InviteUser(roomID, &mautrix.ReqInviteUser{ + _, inviteErr := bot.InviteUser(ctx, roomID, &mautrix.ReqInviteUser{ UserID: intent.UserID, }) if inviteErr != nil { return fmt.Errorf("failed to invite in ensure joined: %w", inviteErr) } - resp, err = intent.JoinRoomByID(roomID) + resp, err = intent.JoinRoomByID(ctx, roomID) if err != nil { return fmt.Errorf("failed to ensure joined after invite: %w", err) } @@ -151,55 +152,55 @@ func (intent *IntentAPI) AddDoublePuppetValue(into interface{}) interface{} { } } -func (intent *IntentAPI) SendMessageEvent(roomID id.RoomID, eventType event.Type, contentJSON interface{}) (*mautrix.RespSendEvent, error) { - if err := intent.EnsureJoined(roomID); err != nil { +func (intent *IntentAPI) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}) (*mautrix.RespSendEvent, error) { + if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } contentJSON = intent.AddDoublePuppetValue(contentJSON) - return intent.Client.SendMessageEvent(roomID, eventType, contentJSON) + return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON) } -func (intent *IntentAPI) SendMassagedMessageEvent(roomID id.RoomID, eventType event.Type, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { - if err := intent.EnsureJoined(roomID); err != nil { +func (intent *IntentAPI) SendMassagedMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { + if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } contentJSON = intent.AddDoublePuppetValue(contentJSON) - return intent.Client.SendMessageEvent(roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts}) + return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts}) } -func (intent *IntentAPI) SendStateEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (*mautrix.RespSendEvent, error) { +func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (*mautrix.RespSendEvent, error) { if eventType != event.StateMember || stateKey != string(intent.UserID) { - if err := intent.EnsureJoined(roomID); err != nil { + if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } } contentJSON = intent.AddDoublePuppetValue(contentJSON) - return intent.Client.SendStateEvent(roomID, eventType, stateKey, contentJSON) + return intent.Client.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON) } -func (intent *IntentAPI) SendMassagedStateEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { - if err := intent.EnsureJoined(roomID); err != nil { +func (intent *IntentAPI) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { + if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } contentJSON = intent.AddDoublePuppetValue(contentJSON) - return intent.Client.SendMassagedStateEvent(roomID, eventType, stateKey, contentJSON, ts) + return intent.Client.SendMassagedStateEvent(ctx, roomID, eventType, stateKey, contentJSON, ts) } -func (intent *IntentAPI) StateEvent(roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) error { - if err := intent.EnsureJoined(roomID); err != nil { +func (intent *IntentAPI) StateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) error { + if err := intent.EnsureJoined(ctx, roomID); err != nil { return err } - return intent.Client.StateEvent(roomID, eventType, stateKey, outContent) + return intent.Client.StateEvent(ctx, roomID, eventType, stateKey, outContent) } -func (intent *IntentAPI) State(roomID id.RoomID) (mautrix.RoomStateMap, error) { - if err := intent.EnsureJoined(roomID); err != nil { +func (intent *IntentAPI) State(ctx context.Context, roomID id.RoomID) (mautrix.RoomStateMap, error) { + if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } - return intent.Client.State(roomID) + return intent.Client.State(ctx, roomID) } -func (intent *IntentAPI) SendCustomMembershipEvent(roomID id.RoomID, target id.UserID, membership event.Membership, reason string, extraContent ...map[string]interface{}) (*mautrix.RespSendEvent, error) { +func (intent *IntentAPI) SendCustomMembershipEvent(ctx context.Context, roomID id.RoomID, target id.UserID, membership event.Membership, reason string, extraContent ...map[string]interface{}) (*mautrix.RespSendEvent, error) { content := &event.MemberEventContent{ Membership: membership, Reason: reason, @@ -211,7 +212,7 @@ func (intent *IntentAPI) SendCustomMembershipEvent(roomID id.RoomID, target id.U ok = memberContent != nil } if !ok { - profile, err := intent.GetProfile(target) + profile, err := intent.GetProfile(ctx, target) if err != nil { intent.Log.Debug().Err(err). Str("target_user_id", target.String()). @@ -231,21 +232,21 @@ func (intent *IntentAPI) SendCustomMembershipEvent(roomID id.RoomID, target id.U if len(extraContent) > 0 { extra = extraContent[0] } - return intent.SendStateEvent(roomID, event.StateMember, target.String(), &event.Content{ + return intent.SendStateEvent(ctx, roomID, event.StateMember, target.String(), &event.Content{ Parsed: content, Raw: extra, }) } -func (intent *IntentAPI) JoinRoomByID(roomID id.RoomID, extraContent ...map[string]interface{}) (resp *mautrix.RespJoinRoom, err error) { +func (intent *IntentAPI) JoinRoomByID(ctx context.Context, roomID id.RoomID, extraContent ...map[string]interface{}) (resp *mautrix.RespJoinRoom, err error) { if intent.IsCustomPuppet || len(extraContent) > 0 { - _, err = intent.SendCustomMembershipEvent(roomID, intent.UserID, event.MembershipJoin, "", extraContent...) + _, err = intent.SendCustomMembershipEvent(ctx, roomID, intent.UserID, event.MembershipJoin, "", extraContent...) return &mautrix.RespJoinRoom{}, err } - return intent.Client.JoinRoomByID(roomID) + return intent.Client.JoinRoomByID(ctx, roomID) } -func (intent *IntentAPI) LeaveRoom(roomID id.RoomID, extra ...interface{}) (resp *mautrix.RespLeaveRoom, err error) { +func (intent *IntentAPI) LeaveRoom(ctx context.Context, roomID id.RoomID, extra ...interface{}) (resp *mautrix.RespLeaveRoom, err error) { var extraContent map[string]interface{} leaveReq := &mautrix.ReqLeave{} for _, item := range extra { @@ -257,94 +258,94 @@ func (intent *IntentAPI) LeaveRoom(roomID id.RoomID, extra ...interface{}) (resp } } if intent.IsCustomPuppet || extraContent != nil { - _, err = intent.SendCustomMembershipEvent(roomID, intent.UserID, event.MembershipLeave, leaveReq.Reason, extraContent) + _, err = intent.SendCustomMembershipEvent(ctx, roomID, intent.UserID, event.MembershipLeave, leaveReq.Reason, extraContent) return &mautrix.RespLeaveRoom{}, err } - return intent.Client.LeaveRoom(roomID, leaveReq) + return intent.Client.LeaveRoom(ctx, roomID, leaveReq) } -func (intent *IntentAPI) InviteUser(roomID id.RoomID, req *mautrix.ReqInviteUser, extraContent ...map[string]interface{}) (resp *mautrix.RespInviteUser, err error) { +func (intent *IntentAPI) InviteUser(ctx context.Context, roomID id.RoomID, req *mautrix.ReqInviteUser, extraContent ...map[string]interface{}) (resp *mautrix.RespInviteUser, err error) { if intent.IsCustomPuppet || len(extraContent) > 0 { - _, err = intent.SendCustomMembershipEvent(roomID, req.UserID, event.MembershipInvite, req.Reason, extraContent...) + _, err = intent.SendCustomMembershipEvent(ctx, roomID, req.UserID, event.MembershipInvite, req.Reason, extraContent...) return &mautrix.RespInviteUser{}, err } - return intent.Client.InviteUser(roomID, req) + return intent.Client.InviteUser(ctx, roomID, req) } -func (intent *IntentAPI) KickUser(roomID id.RoomID, req *mautrix.ReqKickUser, extraContent ...map[string]interface{}) (resp *mautrix.RespKickUser, err error) { +func (intent *IntentAPI) KickUser(ctx context.Context, roomID id.RoomID, req *mautrix.ReqKickUser, extraContent ...map[string]interface{}) (resp *mautrix.RespKickUser, err error) { if intent.IsCustomPuppet || len(extraContent) > 0 { - _, err = intent.SendCustomMembershipEvent(roomID, req.UserID, event.MembershipLeave, req.Reason, extraContent...) + _, err = intent.SendCustomMembershipEvent(ctx, roomID, req.UserID, event.MembershipLeave, req.Reason, extraContent...) return &mautrix.RespKickUser{}, err } - return intent.Client.KickUser(roomID, req) + return intent.Client.KickUser(ctx, roomID, req) } -func (intent *IntentAPI) BanUser(roomID id.RoomID, req *mautrix.ReqBanUser, extraContent ...map[string]interface{}) (resp *mautrix.RespBanUser, err error) { +func (intent *IntentAPI) BanUser(ctx context.Context, roomID id.RoomID, req *mautrix.ReqBanUser, extraContent ...map[string]interface{}) (resp *mautrix.RespBanUser, err error) { if intent.IsCustomPuppet || len(extraContent) > 0 { - _, err = intent.SendCustomMembershipEvent(roomID, req.UserID, event.MembershipBan, req.Reason, extraContent...) + _, err = intent.SendCustomMembershipEvent(ctx, roomID, req.UserID, event.MembershipBan, req.Reason, extraContent...) return &mautrix.RespBanUser{}, err } - return intent.Client.BanUser(roomID, req) + return intent.Client.BanUser(ctx, roomID, req) } -func (intent *IntentAPI) UnbanUser(roomID id.RoomID, req *mautrix.ReqUnbanUser, extraContent ...map[string]interface{}) (resp *mautrix.RespUnbanUser, err error) { +func (intent *IntentAPI) UnbanUser(ctx context.Context, roomID id.RoomID, req *mautrix.ReqUnbanUser, extraContent ...map[string]interface{}) (resp *mautrix.RespUnbanUser, err error) { if intent.IsCustomPuppet || len(extraContent) > 0 { - _, err = intent.SendCustomMembershipEvent(roomID, req.UserID, event.MembershipLeave, req.Reason, extraContent...) + _, err = intent.SendCustomMembershipEvent(ctx, roomID, req.UserID, event.MembershipLeave, req.Reason, extraContent...) return &mautrix.RespUnbanUser{}, err } - return intent.Client.UnbanUser(roomID, req) + return intent.Client.UnbanUser(ctx, roomID, req) } -func (intent *IntentAPI) Member(roomID id.RoomID, userID id.UserID) *event.MemberEventContent { +func (intent *IntentAPI) Member(ctx context.Context, roomID id.RoomID, userID id.UserID) *event.MemberEventContent { member, ok := intent.as.StateStore.TryGetMember(roomID, userID) if !ok { - _ = intent.StateEvent(roomID, event.StateMember, string(userID), &member) + _ = intent.StateEvent(ctx, roomID, event.StateMember, string(userID), &member) } return member } -func (intent *IntentAPI) PowerLevels(roomID id.RoomID) (pl *event.PowerLevelsEventContent, err error) { +func (intent *IntentAPI) PowerLevels(ctx context.Context, roomID id.RoomID) (pl *event.PowerLevelsEventContent, err error) { pl = intent.as.StateStore.GetPowerLevels(roomID) if pl == nil { pl = &event.PowerLevelsEventContent{} - err = intent.StateEvent(roomID, event.StatePowerLevels, "", pl) + err = intent.StateEvent(ctx, roomID, event.StatePowerLevels, "", pl) } return } -func (intent *IntentAPI) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) (resp *mautrix.RespSendEvent, err error) { - return intent.SendStateEvent(roomID, event.StatePowerLevels, "", &levels) +func (intent *IntentAPI) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) (resp *mautrix.RespSendEvent, err error) { + return intent.SendStateEvent(ctx, roomID, event.StatePowerLevels, "", &levels) } -func (intent *IntentAPI) SetPowerLevel(roomID id.RoomID, userID id.UserID, level int) (*mautrix.RespSendEvent, error) { - pl, err := intent.PowerLevels(roomID) +func (intent *IntentAPI) SetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, level int) (*mautrix.RespSendEvent, error) { + pl, err := intent.PowerLevels(ctx, roomID) if err != nil { return nil, err } if pl.GetUserLevel(userID) != level { pl.SetUserLevel(userID, level) - return intent.SendStateEvent(roomID, event.StatePowerLevels, "", &pl) + return intent.SendStateEvent(ctx, roomID, event.StatePowerLevels, "", &pl) } return nil, nil } -func (intent *IntentAPI) SendText(roomID id.RoomID, text string) (*mautrix.RespSendEvent, error) { - if err := intent.EnsureJoined(roomID); err != nil { +func (intent *IntentAPI) SendText(ctx context.Context, roomID id.RoomID, text string) (*mautrix.RespSendEvent, error) { + if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } - return intent.Client.SendText(roomID, text) + return intent.Client.SendText(ctx, roomID, text) } -func (intent *IntentAPI) SendNotice(roomID id.RoomID, text string) (*mautrix.RespSendEvent, error) { - if err := intent.EnsureJoined(roomID); err != nil { +func (intent *IntentAPI) SendNotice(ctx context.Context, roomID id.RoomID, text string) (*mautrix.RespSendEvent, error) { + if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } - return intent.Client.SendNotice(roomID, text) + return intent.Client.SendNotice(ctx, roomID, text) } -func (intent *IntentAPI) RedactEvent(roomID id.RoomID, eventID id.EventID, extra ...mautrix.ReqRedact) (*mautrix.RespSendEvent, error) { - if err := intent.EnsureJoined(roomID); err != nil { +func (intent *IntentAPI) RedactEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID, extra ...mautrix.ReqRedact) (*mautrix.RespSendEvent, error) { + if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } var req mautrix.ReqRedact @@ -352,65 +353,65 @@ func (intent *IntentAPI) RedactEvent(roomID id.RoomID, eventID id.EventID, extra req = extra[0] } intent.AddDoublePuppetValue(&req.Extra) - return intent.Client.RedactEvent(roomID, eventID, req) + return intent.Client.RedactEvent(ctx, roomID, eventID, req) } -func (intent *IntentAPI) SetRoomName(roomID id.RoomID, roomName string) (*mautrix.RespSendEvent, error) { - return intent.SendStateEvent(roomID, event.StateRoomName, "", map[string]interface{}{ +func (intent *IntentAPI) SetRoomName(ctx context.Context, roomID id.RoomID, roomName string) (*mautrix.RespSendEvent, error) { + return intent.SendStateEvent(ctx, roomID, event.StateRoomName, "", map[string]interface{}{ "name": roomName, }) } -func (intent *IntentAPI) SetRoomAvatar(roomID id.RoomID, avatarURL id.ContentURI) (*mautrix.RespSendEvent, error) { - return intent.SendStateEvent(roomID, event.StateRoomAvatar, "", map[string]interface{}{ +func (intent *IntentAPI) SetRoomAvatar(ctx context.Context, roomID id.RoomID, avatarURL id.ContentURI) (*mautrix.RespSendEvent, error) { + return intent.SendStateEvent(ctx, roomID, event.StateRoomAvatar, "", map[string]interface{}{ "url": avatarURL.String(), }) } -func (intent *IntentAPI) SetRoomTopic(roomID id.RoomID, topic string) (*mautrix.RespSendEvent, error) { - return intent.SendStateEvent(roomID, event.StateTopic, "", map[string]interface{}{ +func (intent *IntentAPI) SetRoomTopic(ctx context.Context, roomID id.RoomID, topic string) (*mautrix.RespSendEvent, error) { + return intent.SendStateEvent(ctx, roomID, event.StateTopic, "", map[string]interface{}{ "topic": topic, }) } -func (intent *IntentAPI) SetDisplayName(displayName string) error { - if err := intent.EnsureRegistered(); err != nil { +func (intent *IntentAPI) SetDisplayName(ctx context.Context, displayName string) error { + if err := intent.EnsureRegistered(ctx); err != nil { return err } - resp, err := intent.Client.GetOwnDisplayName() + resp, err := intent.Client.GetOwnDisplayName(ctx) if err != nil { return fmt.Errorf("failed to check current displayname: %w", err) } else if resp.DisplayName == displayName { // No need to update return nil } - return intent.Client.SetDisplayName(displayName) + return intent.Client.SetDisplayName(ctx, displayName) } -func (intent *IntentAPI) SetAvatarURL(avatarURL id.ContentURI) error { - if err := intent.EnsureRegistered(); err != nil { +func (intent *IntentAPI) SetAvatarURL(ctx context.Context, avatarURL id.ContentURI) error { + if err := intent.EnsureRegistered(ctx); err != nil { return err } - resp, err := intent.Client.GetOwnAvatarURL() + resp, err := intent.Client.GetOwnAvatarURL(ctx) if err != nil { return fmt.Errorf("failed to check current avatar URL: %w", err) } else if resp.FileID == avatarURL.FileID && resp.Homeserver == avatarURL.Homeserver { // No need to update return nil } - return intent.Client.SetAvatarURL(avatarURL) + return intent.Client.SetAvatarURL(ctx, avatarURL) } -func (intent *IntentAPI) Whoami() (*mautrix.RespWhoami, error) { - if err := intent.EnsureRegistered(); err != nil { +func (intent *IntentAPI) Whoami(ctx context.Context) (*mautrix.RespWhoami, error) { + if err := intent.EnsureRegistered(ctx); err != nil { return nil, err } - return intent.Client.Whoami() + return intent.Client.Whoami(ctx) } -func (intent *IntentAPI) EnsureInvited(roomID id.RoomID, userID id.UserID) error { +func (intent *IntentAPI) EnsureInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) error { if !intent.as.StateStore.IsInvited(roomID, userID) { - _, err := intent.InviteUser(roomID, &mautrix.ReqInviteUser{ + _, err := intent.InviteUser(ctx, roomID, &mautrix.ReqInviteUser{ UserID: userID, }) if httpErr, ok := err.(mautrix.HTTPError); ok && diff --git a/bridge/bridge.go b/bridge/bridge.go index 291d6be9..763cb4e0 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -217,7 +217,7 @@ type Crypto interface { Decrypt(*event.Event) (*event.Event, error) Encrypt(id.RoomID, event.Type, *event.Content) error WaitForSession(id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool - RequestSession(id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) + RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) ResetSession(id.RoomID) Init() error Start() @@ -287,9 +287,9 @@ func (br *Bridge) InitVersion(tag, commit, buildTime string) { var MinSpecVersion = mautrix.SpecV11 -func (br *Bridge) ensureConnection() { +func (br *Bridge) ensureConnection(ctx context.Context) { for { - versions, err := br.Bot.Versions() + versions, err := br.Bot.Versions(ctx) if err != nil { br.ZLog.Err(err).Msg("Failed to connect to homeserver, retrying in 10 seconds...") time.Sleep(10 * time.Second) @@ -315,7 +315,7 @@ func (br *Bridge) ensureConnection() { } } - resp, err := br.Bot.Whoami() + resp, err := br.Bot.Whoami(ctx) if err != nil { if errors.Is(err, mautrix.MUnknownToken) { br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was not accepted. Is the registration file installed in your homeserver correctly?") @@ -346,7 +346,7 @@ func (br *Bridge) ensureConnection() { const maxRetries = 6 for { txnID = br.Bot.TxnID() - pingResp, err = br.Bot.AppservicePing(br.Config.AppService.ID, txnID) + pingResp, err = br.Bot.AppservicePing(ctx, br.Config.AppService.ID, txnID) if err == nil { break } @@ -385,8 +385,8 @@ func (br *Bridge) ensureConnection() { Msg("Homeserver -> bridge connection works") } -func (br *Bridge) fetchMediaConfig() { - cfg, err := br.Bot.GetMediaConfig() +func (br *Bridge) fetchMediaConfig(ctx context.Context) { + cfg, err := br.Bot.GetMediaConfig(ctx) if err != nil { br.ZLog.Warn().Err(err).Msg("Failed to fetch media config") } else { @@ -394,25 +394,25 @@ func (br *Bridge) fetchMediaConfig() { } } -func (br *Bridge) UpdateBotProfile() { +func (br *Bridge) UpdateBotProfile(ctx context.Context) { br.ZLog.Debug().Msg("Updating bot profile") botConfig := &br.Config.AppService.Bot var err error var mxc id.ContentURI if botConfig.Avatar == "remove" { - err = br.Bot.SetAvatarURL(mxc) + err = br.Bot.SetAvatarURL(ctx, mxc) } else if !botConfig.ParsedAvatar.IsEmpty() { - err = br.Bot.SetAvatarURL(botConfig.ParsedAvatar) + err = br.Bot.SetAvatarURL(ctx, botConfig.ParsedAvatar) } if err != nil { br.ZLog.Warn().Err(err).Msg("Failed to update bot avatar") } if botConfig.Displayname == "remove" { - err = br.Bot.SetDisplayName("") + err = br.Bot.SetDisplayName(ctx, "") } else if len(botConfig.Displayname) > 0 { - err = br.Bot.SetDisplayName(botConfig.Displayname) + err = br.Bot.SetDisplayName(ctx, botConfig.Displayname) } if err != nil { br.ZLog.Warn().Err(err).Msg("Failed to update bot displayname") @@ -420,7 +420,7 @@ func (br *Bridge) UpdateBotProfile() { if br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) && br.BeeperNetworkName != "" { br.ZLog.Debug().Msg("Setting contact info on the appservice bot") - br.Bot.BeeperUpdateProfile(map[string]any{ + br.Bot.BeeperUpdateProfile(ctx, map[string]any{ "com.beeper.bridge.service": br.BeeperServiceName, "com.beeper.bridge.network": br.BeeperNetworkName, "com.beeper.bridge.is_bridge_bot": true, @@ -633,8 +633,10 @@ func (br *Bridge) start() { os.Exit(23) } br.ZLog.Debug().Msg("Checking connection to homeserver") - br.ensureConnection() - go br.fetchMediaConfig() + + ctx := context.Background() + br.ensureConnection(ctx) + go br.fetchMediaConfig(ctx) if br.Crypto != nil { err = br.Crypto.Init() @@ -647,7 +649,7 @@ func (br *Bridge) start() { br.ZLog.Debug().Msg("Starting event processor") br.EventProcessor.Start() - go br.UpdateBotProfile() + go br.UpdateBotProfile(ctx) if br.Crypto != nil { go br.Crypto.Start() } diff --git a/bridge/commands/admin.go b/bridge/commands/admin.go index dde97de7..d07ada1a 100644 --- a/bridge/commands/admin.go +++ b/bridge/commands/admin.go @@ -7,6 +7,7 @@ package commands import ( + "context" "strconv" "maunium.net/go/mautrix/id" @@ -57,7 +58,7 @@ func fnSetPowerLevel(ce *Event) { ce.Reply("**Usage:** `set-pl [user] `") return } - _, err = ce.Portal.MainIntent().SetPowerLevel(ce.RoomID, userID, level) + _, err = ce.Portal.MainIntent().SetPowerLevel(context.Background(), ce.RoomID, userID, level) if err != nil { ce.Reply("Failed to set power levels: %v", err) } diff --git a/bridge/commands/doublepuppet.go b/bridge/commands/doublepuppet.go index 8c2e611e..9501d01f 100644 --- a/bridge/commands/doublepuppet.go +++ b/bridge/commands/doublepuppet.go @@ -6,6 +6,8 @@ package commands +import "context" + var CommandLoginMatrix = &FullHandler{ Func: fnLoginMatrix, Name: "login-matrix", @@ -54,7 +56,7 @@ func fnPingMatrix(ce *Event) { ce.Reply("You are not logged in with your Matrix account.") return } - resp, err := puppet.CustomIntent().Whoami() + resp, err := puppet.CustomIntent().Whoami(context.Background()) if err != nil { ce.Reply("Failed to validate Matrix login: %v", err) } else { diff --git a/bridge/commands/event.go b/bridge/commands/event.go index 0adc9237..24cf2eb9 100644 --- a/bridge/commands/event.go +++ b/bridge/commands/event.go @@ -67,7 +67,7 @@ func (ce *Event) Reply(msg string, args ...interface{}) { func (ce *Event) ReplyAdvanced(msg string, allowMarkdown, allowHTML bool) { content := format.RenderMarkdown(msg, allowMarkdown, allowHTML) content.MsgType = event.MsgNotice - _, err := ce.MainIntent().SendMessageEvent(ce.RoomID, event.EventMessage, content) + _, err := ce.MainIntent().SendMessageEvent(context.Background(), ce.RoomID, event.EventMessage, content) if err != nil { ce.ZLog.Error().Err(err).Msgf("Failed to reply to command") } @@ -75,7 +75,7 @@ func (ce *Event) ReplyAdvanced(msg string, allowMarkdown, allowHTML bool) { // React sends a reaction to the command. func (ce *Event) React(key string) { - _, err := ce.MainIntent().SendReaction(ce.RoomID, ce.EventID, key) + _, err := ce.MainIntent().SendReaction(context.Background(), ce.RoomID, ce.EventID, key) if err != nil { ce.ZLog.Error().Err(err).Msgf("Failed to react to command") } @@ -83,7 +83,7 @@ func (ce *Event) React(key string) { // Redact redacts the command. func (ce *Event) Redact(req ...mautrix.ReqRedact) { - _, err := ce.MainIntent().RedactEvent(ce.RoomID, ce.EventID, req...) + _, err := ce.MainIntent().RedactEvent(context.Background(), ce.RoomID, ce.EventID, req...) if err != nil { ce.ZLog.Error().Err(err).Msgf("Failed to redact command") } @@ -91,7 +91,7 @@ func (ce *Event) Redact(req ...mautrix.ReqRedact) { // MarkRead marks the command event as read. func (ce *Event) MarkRead() { - err := ce.MainIntent().SendReceipt(ce.RoomID, ce.EventID, event.ReceiptTypeRead, nil) + err := ce.MainIntent().SendReceipt(context.Background(), ce.RoomID, ce.EventID, event.ReceiptTypeRead, nil) if err != nil { ce.ZLog.Error().Err(err).Msgf("Failed to mark command as read") } diff --git a/bridge/commands/handler.go b/bridge/commands/handler.go index d158191a..cfed683b 100644 --- a/bridge/commands/handler.go +++ b/bridge/commands/handler.go @@ -7,6 +7,8 @@ package commands import ( + "context" + "maunium.net/go/mautrix/bridge" "maunium.net/go/mautrix/bridge/bridgeconfig" "maunium.net/go/mautrix/event" @@ -76,7 +78,7 @@ func (fh *FullHandler) ShowInHelp(ce *Event) bool { } func (fh *FullHandler) userHasRoomPermission(ce *Event) bool { - levels, err := ce.MainIntent().PowerLevels(ce.RoomID) + levels, err := ce.MainIntent().PowerLevels(context.Background(), ce.RoomID) if err != nil { ce.ZLog.Warn().Err(err).Msg("Failed to check room power levels") ce.Reply("Failed to get room power levels to see if you're allowed to use that command") diff --git a/bridge/crypto.go b/bridge/crypto.go index 065bc017..73e5dbf8 100644 --- a/bridge/crypto.go +++ b/bridge/crypto.go @@ -137,8 +137,9 @@ func (helper *CryptoHelper) Init() error { } func (helper *CryptoHelper) resyncEncryptionInfo() { + ctx := context.Background() log := helper.log.With().Str("action", "resync encryption event").Logger() - rows, err := helper.bridge.DB.Query(`SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`) + rows, err := helper.bridge.DB.QueryContext(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`) if err != nil { log.Err(err).Msg("Failed to query rooms for resync") return @@ -158,10 +159,10 @@ func (helper *CryptoHelper) resyncEncryptionInfo() { log.Debug().Interface("room_ids", roomIDs).Msg("Resyncing rooms") for _, roomID := range roomIDs { var evt event.EncryptionEventContent - err = helper.client.StateEvent(roomID, event.StateEncryption, "", &evt) + err = helper.client.StateEvent(ctx, roomID, event.StateEncryption, "", &evt) if err != nil { log.Err(err).Str("room_id", roomID.String()).Msg("Failed to get encryption event") - _, err = helper.bridge.DB.Exec(` + _, err = helper.bridge.DB.ExecContext(ctx, ` UPDATE mx_room_state SET encryption=NULL WHERE room_id=$1 AND encryption='{"resync":true}' `, roomID) if err != nil { @@ -182,7 +183,7 @@ func (helper *CryptoHelper) resyncEncryptionInfo() { Int("max_messages", maxMessages). Interface("content", &evt). Msg("Resynced encryption event") - _, err = helper.bridge.DB.Exec(` + _, err = helper.bridge.DB.ExecContext(ctx, ` UPDATE crypto_megolm_inbound_session SET max_age=$1, max_messages=$2 WHERE room_id=$3 AND max_age IS NULL AND max_messages IS NULL @@ -223,6 +224,7 @@ func (helper *CryptoHelper) allowKeyShare(ctx context.Context, device *id.Device } func (helper *CryptoHelper) loginBot() (*mautrix.Client, bool, error) { + ctx := context.Background() deviceID := helper.store.FindDeviceID() if len(deviceID) > 0 { helper.log.Debug().Str("device_id", deviceID.String()).Msg("Found existing device ID for bot in database") @@ -230,13 +232,13 @@ func (helper *CryptoHelper) loginBot() (*mautrix.Client, bool, error) { // Create a new client instance with the default AS settings (including as_token), // the Login call will then override the access token in the client. client := helper.bridge.AS.NewMautrixClient(helper.bridge.AS.BotMXID()) - flows, err := client.GetLoginFlows() + flows, err := client.GetLoginFlows(ctx) if err != nil { return nil, deviceID != "", fmt.Errorf("failed to get supported login flows: %w", err) } else if !flows.HasFlow(mautrix.AuthTypeAppservice) { return nil, deviceID != "", fmt.Errorf("homeserver does not support appservice login") } - resp, err := client.Login(&mautrix.ReqLogin{ + resp, err := client.Login(ctx, &mautrix.ReqLogin{ Type: mautrix.AuthTypeAppservice, Identifier: mautrix.UserIdentifier{ Type: mautrix.IdentifierTypeUser, @@ -255,8 +257,9 @@ func (helper *CryptoHelper) loginBot() (*mautrix.Client, bool, error) { } func (helper *CryptoHelper) verifyKeysAreOnServer() { + ctx := context.Background() helper.log.Debug().Msg("Making sure keys are still on server") - resp, err := helper.client.QueryKeys(&mautrix.ReqQueryKeys{ + resp, err := helper.client.QueryKeys(ctx, &mautrix.ReqQueryKeys{ DeviceKeys: map[id.UserID]mautrix.DeviceIDList{ helper.client.UserID: {helper.client.DeviceID}, }, @@ -333,7 +336,7 @@ func (helper *CryptoHelper) Reset(startAfterReset bool) { helper.log.Debug().Msg("Crypto syncer stopped, clearing database") helper.clearDatabase() helper.log.Debug().Msg("Crypto database cleared, logging out of all sessions") - _, err := helper.client.LogoutAll() + _, err := helper.client.LogoutAll(context.Background()) if err != nil { helper.log.Warn().Err(err).Msg("Failed to log out all devices") } @@ -395,13 +398,13 @@ func (helper *CryptoHelper) WaitForSession(roomID id.RoomID, senderKey id.Sender return helper.mach.WaitForSession(roomID, senderKey, sessionID, timeout) } -func (helper *CryptoHelper) RequestSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { +func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { helper.lock.RLock() defer helper.lock.RUnlock() if deviceID == "" { deviceID = "*" } - err := helper.mach.SendRoomKeyRequest(roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{userID: {deviceID}}) + err := helper.mach.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{userID: {deviceID}}) if err != nil { helper.log.Warn().Err(err). Str("user_id", userID.String()). diff --git a/bridge/doublepuppet.go b/bridge/doublepuppet.go index 7ddc1989..35903efd 100644 --- a/bridge/doublepuppet.go +++ b/bridge/doublepuppet.go @@ -7,6 +7,7 @@ package bridge import ( + "context" "crypto/hmac" "crypto/sha512" "encoding/hex" @@ -26,7 +27,7 @@ type doublePuppetUtil struct { log zerolog.Logger } -func (dp *doublePuppetUtil) newClient(mxid id.UserID, accessToken string) (*mautrix.Client, error) { +func (dp *doublePuppetUtil) newClient(ctx context.Context, mxid id.UserID, accessToken string) (*mautrix.Client, error) { _, homeserver, err := mxid.Parse() if err != nil { return nil, err @@ -36,7 +37,7 @@ func (dp *doublePuppetUtil) newClient(mxid id.UserID, accessToken string) (*maut if homeserver == dp.br.AS.HomeserverDomain { homeserverURL = "" } else if dp.br.Config.Bridge.GetDoublePuppetConfig().AllowDiscovery { - resp, err := mautrix.DiscoverClientAPI(homeserver) + resp, err := mautrix.DiscoverClientAPI(ctx, homeserver) if err != nil { return nil, fmt.Errorf("failed to find homeserver URL for %s: %v", homeserver, err) } @@ -53,8 +54,8 @@ func (dp *doublePuppetUtil) newClient(mxid id.UserID, accessToken string) (*maut return dp.br.AS.NewExternalMautrixClient(mxid, accessToken, homeserverURL) } -func (dp *doublePuppetUtil) newIntent(mxid id.UserID, accessToken string) (*appservice.IntentAPI, error) { - client, err := dp.newClient(mxid, accessToken) +func (dp *doublePuppetUtil) newIntent(ctx context.Context, mxid id.UserID, accessToken string) (*appservice.IntentAPI, error) { + client, err := dp.newClient(ctx, mxid, accessToken) if err != nil { return nil, err } @@ -67,9 +68,9 @@ func (dp *doublePuppetUtil) newIntent(mxid id.UserID, accessToken string) (*apps return ia, nil } -func (dp *doublePuppetUtil) autoLogin(mxid id.UserID, loginSecret string) (string, error) { +func (dp *doublePuppetUtil) autoLogin(ctx context.Context, mxid id.UserID, loginSecret string) (string, error) { dp.log.Debug().Str("user_id", mxid.String()).Msg("Logging into user account with shared secret") - client, err := dp.newClient(mxid, "") + client, err := dp.newClient(ctx, mxid, "") if err != nil { return "", fmt.Errorf("failed to create mautrix client to log in: %v", err) } @@ -83,7 +84,7 @@ func (dp *doublePuppetUtil) autoLogin(mxid id.UserID, loginSecret string) (strin client.AccessToken = dp.br.AS.Registration.AppToken req.Type = mautrix.AuthTypeAppservice } else { - loginFlows, err := client.GetLoginFlows() + loginFlows, err := client.GetLoginFlows(ctx) if err != nil { return "", fmt.Errorf("failed to get supported login flows: %w", err) } @@ -101,7 +102,7 @@ func (dp *doublePuppetUtil) autoLogin(mxid id.UserID, loginSecret string) (strin return "", fmt.Errorf("no supported auth types for shared secret auth found") } } - resp, err := client.Login(&req) + resp, err := client.Login(ctx, &req) if err != nil { return "", err } @@ -122,18 +123,19 @@ func (dp *doublePuppetUtil) Setup(mxid id.UserID, savedAccessToken string, relog err = ErrNoMXID return } + ctx := context.Background() _, homeserver, _ := mxid.Parse() loginSecret, hasSecret := dp.br.Config.Bridge.GetDoublePuppetConfig().SharedSecretMap[homeserver] // Special case appservice: prefix to not login and use it as an as_token directly. if hasSecret && strings.HasPrefix(loginSecret, asTokenModePrefix) { - intent, err = dp.newIntent(mxid, strings.TrimPrefix(loginSecret, asTokenModePrefix)) + intent, err = dp.newIntent(ctx, mxid, strings.TrimPrefix(loginSecret, asTokenModePrefix)) if err != nil { return } intent.SetAppServiceUserID = true if savedAccessToken != useConfigASToken { var resp *mautrix.RespWhoami - resp, err = intent.Whoami() + resp, err = intent.Whoami(ctx) if err == nil && resp.UserID != mxid { err = ErrMismatchingMXID } @@ -142,7 +144,7 @@ func (dp *doublePuppetUtil) Setup(mxid id.UserID, savedAccessToken string, relog } if savedAccessToken == "" || savedAccessToken == useConfigASToken { if reloginOnFail && hasSecret { - savedAccessToken, err = dp.autoLogin(mxid, loginSecret) + savedAccessToken, err = dp.autoLogin(ctx, mxid, loginSecret) } else { err = ErrNoAccessToken } @@ -150,15 +152,15 @@ func (dp *doublePuppetUtil) Setup(mxid id.UserID, savedAccessToken string, relog return } } - intent, err = dp.newIntent(mxid, savedAccessToken) + intent, err = dp.newIntent(ctx, mxid, savedAccessToken) if err != nil { return } var resp *mautrix.RespWhoami - resp, err = intent.Whoami() + resp, err = intent.Whoami(ctx) if err != nil { if reloginOnFail && hasSecret && errors.Is(err, mautrix.MUnknownToken) { - intent.AccessToken, err = dp.autoLogin(mxid, loginSecret) + intent.AccessToken, err = dp.autoLogin(ctx, mxid, loginSecret) if err == nil { newAccessToken = intent.AccessToken } diff --git a/bridge/matrix.go b/bridge/matrix.go index 3196af60..f9a86d80 100644 --- a/bridge/matrix.go +++ b/bridge/matrix.go @@ -87,7 +87,7 @@ func (mx *MatrixHandler) HandleEncryption(evt *event.Event) { Msg("Encryption was enabled in room") portal.MarkEncrypted() if portal.IsPrivateChat() { - err := mx.as.BotIntent().EnsureJoined(evt.RoomID, appservice.EnsureJoinedParams{BotOverride: portal.MainIntent().Client}) + err := mx.as.BotIntent().EnsureJoined(context.Background(), evt.RoomID, appservice.EnsureJoinedParams{BotOverride: portal.MainIntent().Client}) if err != nil { mx.log.Err(err). Str("room_id", evt.RoomID.String()). @@ -99,32 +99,32 @@ func (mx *MatrixHandler) HandleEncryption(evt *event.Event) { func (mx *MatrixHandler) joinAndCheckMembers(ctx context.Context, evt *event.Event, intent *appservice.IntentAPI) *mautrix.RespJoinedMembers { log := zerolog.Ctx(ctx) - resp, err := intent.JoinRoomByID(evt.RoomID) + resp, err := intent.JoinRoomByID(ctx, evt.RoomID) if err != nil { log.Warn().Err(err).Msg("Failed to join room with invite") return nil } - members, err := intent.JoinedMembers(resp.RoomID) + members, err := intent.JoinedMembers(ctx, resp.RoomID) if err != nil { log.Warn().Err(err).Msg("Failed to get members in room after accepting invite, leaving room") - _, _ = intent.LeaveRoom(resp.RoomID) + _, _ = intent.LeaveRoom(ctx, resp.RoomID) return nil } if len(members.Joined) < 2 { log.Debug().Msg("Leaving empty room after accepting invite") - _, _ = intent.LeaveRoom(resp.RoomID) + _, _ = intent.LeaveRoom(ctx, resp.RoomID) return nil } return members } -func (mx *MatrixHandler) sendNoticeWithMarkdown(roomID id.RoomID, message string) (*mautrix.RespSendEvent, error) { +func (mx *MatrixHandler) sendNoticeWithMarkdown(ctx context.Context, roomID id.RoomID, message string) (*mautrix.RespSendEvent, error) { intent := mx.as.BotIntent() content := format.RenderMarkdown(message, true, false) content.MsgType = event.MsgNotice - return intent.SendMessageEvent(roomID, event.EventMessage, content) + return intent.SendMessageEvent(ctx, roomID, event.EventMessage, content) } func (mx *MatrixHandler) HandleBotInvite(ctx context.Context, evt *event.Event) { @@ -141,31 +141,31 @@ func (mx *MatrixHandler) HandleBotInvite(ctx context.Context, evt *event.Event) } if user.GetPermissionLevel() < bridgeconfig.PermissionLevelUser { - _, _ = intent.SendNotice(evt.RoomID, "You are not whitelisted to use this bridge.\n"+ + _, _ = intent.SendNotice(ctx, evt.RoomID, "You are not whitelisted to use this bridge.\n"+ "If you're the owner of this bridge, see the bridge.permissions section in your config file.") - _, _ = intent.LeaveRoom(evt.RoomID) + _, _ = intent.LeaveRoom(ctx, evt.RoomID) return } texts := mx.bridge.Config.Bridge.GetManagementRoomTexts() - _, _ = mx.sendNoticeWithMarkdown(evt.RoomID, texts.Welcome) + _, _ = mx.sendNoticeWithMarkdown(ctx, evt.RoomID, texts.Welcome) if len(members.Joined) == 2 && (len(user.GetManagementRoomID()) == 0 || evt.Content.AsMember().IsDirect) { user.SetManagementRoom(evt.RoomID) - _, _ = intent.SendNotice(user.GetManagementRoomID(), "This room has been registered as your bridge management/status room.") + _, _ = intent.SendNotice(ctx, user.GetManagementRoomID(), "This room has been registered as your bridge management/status room.") zerolog.Ctx(ctx).Debug().Msg("Registered room as management room with inviter") } if evt.RoomID == user.GetManagementRoomID() { if user.IsLoggedIn() { - _, _ = mx.sendNoticeWithMarkdown(evt.RoomID, texts.WelcomeConnected) + _, _ = mx.sendNoticeWithMarkdown(ctx, evt.RoomID, texts.WelcomeConnected) } else { - _, _ = mx.sendNoticeWithMarkdown(evt.RoomID, texts.WelcomeUnconnected) + _, _ = mx.sendNoticeWithMarkdown(ctx, evt.RoomID, texts.WelcomeUnconnected) } additionalHelp := texts.AdditionalHelp if len(additionalHelp) > 0 { - _, _ = mx.sendNoticeWithMarkdown(evt.RoomID, additionalHelp) + _, _ = mx.sendNoticeWithMarkdown(ctx, evt.RoomID, additionalHelp) } } } @@ -176,7 +176,7 @@ func (mx *MatrixHandler) HandleGhostInvite(ctx context.Context, evt *event.Event if inviter.GetPermissionLevel() < bridgeconfig.PermissionLevelUser { log.Debug().Msg("Rejecting invite: inviter is not whitelisted") - _, err := intent.LeaveRoom(evt.RoomID, &mautrix.ReqLeave{ + _, err := intent.LeaveRoom(ctx, evt.RoomID, &mautrix.ReqLeave{ Reason: "You're not whitelisted to use this bridge", }) if err != nil { @@ -185,7 +185,7 @@ func (mx *MatrixHandler) HandleGhostInvite(ctx context.Context, evt *event.Event return } else if !inviter.IsLoggedIn() { log.Debug().Msg("Rejecting invite: inviter is not logged in") - _, err := intent.LeaveRoom(evt.RoomID, &mautrix.ReqLeave{ + _, err := intent.LeaveRoom(ctx, evt.RoomID, &mautrix.ReqLeave{ Reason: "You're not logged into this bridge", }) if err != nil { @@ -199,11 +199,11 @@ func (mx *MatrixHandler) HandleGhostInvite(ctx context.Context, evt *event.Event return } var createEvent event.CreateEventContent - if err := intent.StateEvent(evt.RoomID, event.StateCreate, "", &createEvent); err != nil { + if err := intent.StateEvent(ctx, evt.RoomID, event.StateCreate, "", &createEvent); err != nil { log.Warn().Err(err).Msg("Failed to check m.room.create event in room") } else if createEvent.Type != "" { log.Warn().Str("room_type", string(createEvent.Type)).Msg("Non-standard room type, leaving room") - _, err = intent.LeaveRoom(evt.RoomID, &mautrix.ReqLeave{ + _, err = intent.LeaveRoom(ctx, evt.RoomID, &mautrix.ReqLeave{ Reason: "Unsupported room type", }) if err != nil { @@ -225,10 +225,10 @@ func (mx *MatrixHandler) HandleGhostInvite(ctx context.Context, evt *event.Event mx.bridge.Child.CreatePrivatePortal(evt.RoomID, inviter, ghost) } else if !hasBridgeBot { log.Debug().Msg("Leaving multi-user room after accepting invite") - _, _ = intent.SendNotice(evt.RoomID, "Please invite the bridge bot first if you want to bridge to a remote chat.") - _, _ = intent.LeaveRoom(evt.RoomID) + _, _ = intent.SendNotice(ctx, evt.RoomID, "Please invite the bridge bot first if you want to bridge to a remote chat.") + _, _ = intent.LeaveRoom(ctx, evt.RoomID) } else { - _, _ = intent.SendNotice(evt.RoomID, "This puppet will remain inactive until this room is bridged to a remote chat.") + _, _ = intent.SendNotice(ctx, evt.RoomID, "This puppet will remain inactive until this room is bridged to a remote chat.") } } @@ -237,12 +237,12 @@ func (mx *MatrixHandler) HandleMembership(evt *event.Event) { return } defer mx.TrackEventDuration(evt.Type)() + ctx := context.Background() if mx.bridge.Crypto != nil { mx.bridge.Crypto.HandleMemberEvent(evt) } - ctx := context.Background() log := mx.log.With(). Str("sender", evt.Sender.String()). Str("target", evt.GetStateKey()). @@ -358,7 +358,7 @@ func (mx *MatrixHandler) sendCryptoStatusError(ctx context.Context, evt *event.E if !isFinal { statusEvent.Status = event.MessageStatusPending } - _, sendErr := mx.bridge.Bot.SendMessageEvent(evt.RoomID, event.BeeperMessageStatus, statusEvent) + _, sendErr := mx.bridge.Bot.SendMessageEvent(ctx, evt.RoomID, event.BeeperMessageStatus, statusEvent) if sendErr != nil { zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to send message status event") } @@ -377,7 +377,7 @@ func (mx *MatrixHandler) sendCryptoStatusError(ctx context.Context, evt *event.E } else if ok && relatable.OptionalGetRelatesTo().GetThreadParent() != "" { update.GetRelatesTo().SetThread(relatable.OptionalGetRelatesTo().GetThreadParent(), evt.ID) } - resp, sendErr := mx.bridge.Bot.SendMessageEvent(evt.RoomID, event.EventMessage, &update) + resp, sendErr := mx.bridge.Bot.SendMessageEvent(ctx, evt.RoomID, event.EventMessage, &update) if sendErr != nil { zerolog.Ctx(ctx).Error().Err(sendErr).Msg("Failed to send decryption error notice") } else if resp != nil { @@ -471,7 +471,7 @@ func (mx *MatrixHandler) postDecrypt(ctx context.Context, original, decrypted *e decrypted.Mautrix.DecryptionDuration = duration mx.bridge.EventProcessor.Dispatch(decrypted) if errorEventID != "" { - _, _ = mx.bridge.Bot.RedactEvent(decrypted.RoomID, errorEventID) + _, _ = mx.bridge.Bot.RedactEvent(ctx, decrypted.RoomID, errorEventID) } } @@ -526,7 +526,7 @@ func (mx *MatrixHandler) waitLongerForSession(ctx context.Context, evt *event.Ev Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())). Msg("Couldn't find session, requesting keys and waiting longer...") - go mx.bridge.Crypto.RequestSession(evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) + go mx.bridge.Crypto.RequestSession(context.Background(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) errorEventID := mx.sendCryptoStatusError(ctx, evt, "", fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), 1, false) if !mx.bridge.Crypto.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { @@ -587,7 +587,7 @@ func (mx *MatrixHandler) HandleMessage(evt *event.Event) { }, Status: event.MessageStatusSuccess, } - _, sendErr := mx.bridge.Bot.SendMessageEvent(evt.RoomID, event.BeeperMessageStatus, statusEvent) + _, sendErr := mx.bridge.Bot.SendMessageEvent(ctx, evt.RoomID, event.BeeperMessageStatus, statusEvent) if sendErr != nil { log.Warn().Err(sendErr).Msg("Failed to send message status event for command") } diff --git a/client.go b/client.go index 17720026..0aff8734 100644 --- a/client.go +++ b/client.go @@ -30,7 +30,7 @@ type CryptoHelper interface { Encrypt(id.RoomID, event.Type, any) (*event.EncryptedEventContent, error) Decrypt(*event.Event) (*event.Event, error) WaitForSession(id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool - RequestSession(id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) + RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) Init() error } @@ -100,14 +100,14 @@ type IdentityServerInfo struct { // DiscoverClientAPI resolves the client API URL from a Matrix server name. // Use ParseUserID to extract the server name from a user ID. // https://spec.matrix.org/v1.2/client-server-api/#server-discovery -func DiscoverClientAPI(serverName string) (*ClientWellKnown, error) { +func DiscoverClientAPI(ctx context.Context, serverName string) (*ClientWellKnown, error) { wellKnownURL := url.URL{ Scheme: "https", Host: serverName, Path: "/.well-known/matrix/client", } - req, err := http.NewRequest("GET", wellKnownURL.String(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", wellKnownURL.String(), nil) if err != nil { return nil, err } @@ -174,16 +174,16 @@ func (cli *Client) SyncWithContext(ctx context.Context) error { // We will keep syncing until the syncing state changes. Either because // Sync is called or StopSync is called. syncingID := cli.incrementSyncingID() - nextBatch := cli.Store.LoadNextBatch(cli.UserID) - filterID := cli.Store.LoadFilterID(cli.UserID) + nextBatch := cli.Store.LoadNextBatch(ctx, cli.UserID) + filterID := cli.Store.LoadFilterID(ctx, cli.UserID) if filterID == "" { filterJSON := cli.Syncer.GetFilterJSON(cli.UserID) - resFilter, err := cli.CreateFilter(filterJSON) + resFilter, err := cli.CreateFilter(ctx, filterJSON) if err != nil { return err } filterID = resFilter.FilterID - cli.Store.SaveFilterID(cli.UserID, filterID) + cli.Store.SaveFilterID(ctx, cli.UserID, filterID) } lastSuccessfulSync := time.Now().Add(-cli.StreamSyncMinAge - 1*time.Hour) for { @@ -192,13 +192,12 @@ func (cli *Client) SyncWithContext(ctx context.Context) error { cli.Log.Debug().Msg("Last sync is old, will stream next response") streamResp = true } - resSync, err := cli.FullSyncRequest(ReqSync{ + resSync, err := cli.FullSyncRequest(ctx, ReqSync{ Timeout: 30000, Since: nextBatch, FilterID: filterID, FullState: false, SetPresence: cli.SyncPresence, - Context: ctx, StreamResponse: streamResp, }) if err != nil { @@ -228,7 +227,7 @@ func (cli *Client) SyncWithContext(ctx context.Context) error { // Save the token now *before* processing it. This means it's possible // to not process some events, but it means that we won't get constantly stuck processing // a malformed/buggy event which keeps making us panic. - cli.Store.SaveNextBatch(cli.UserID, resSync.NextBatch) + cli.Store.SaveNextBatch(ctx, cli.UserID, resSync.NextBatch) if err = cli.Syncer.ProcessResponse(resSync, nextBatch); err != nil { return err } @@ -306,8 +305,8 @@ func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err er } } -func (cli *Client) MakeRequest(method string, httpURL string, reqBody interface{}, resBody interface{}) ([]byte, error) { - return cli.MakeFullRequest(FullRequest{Method: method, URL: httpURL, RequestJSON: reqBody, ResponseJSON: resBody}) +func (cli *Client) MakeRequest(ctx context.Context, method string, httpURL string, reqBody interface{}, resBody interface{}) ([]byte, error) { + return cli.MakeFullRequest(ctx, FullRequest{Method: method, URL: httpURL, RequestJSON: reqBody, ResponseJSON: resBody}) } type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) @@ -321,7 +320,6 @@ type FullRequest struct { RequestBody io.Reader RequestLength int64 ResponseJSON interface{} - Context context.Context MaxAttempts int SensitiveContent bool Handler ClientResponseHandler @@ -331,12 +329,9 @@ type FullRequest struct { var requestID int32 var logSensitiveContent = os.Getenv("MAUTRIX_LOG_SENSITIVE_CONTENT") == "yes" -func (params *FullRequest) compileRequest() (*http.Request, error) { +func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, error) { var logBody any reqBody := params.RequestBody - if params.Context == nil { - params.Context = context.Background() - } if params.RequestJSON != nil { jsonStr, err := json.Marshal(params.RequestJSON) if err != nil { @@ -363,7 +358,6 @@ func (params *FullRequest) compileRequest() (*http.Request, error) { reqBody = bytes.NewReader([]byte("{}")) } reqID := atomic.AddInt32(&requestID, 1) - ctx := params.Context logger := zerolog.Ctx(ctx) if logger.GetLevel() == zerolog.Disabled || logger == zerolog.DefaultContextLogger { logger = params.Logger @@ -398,14 +392,14 @@ func (params *FullRequest) compileRequest() (*http.Request, error) { // Returns the HTTP body as bytes on 2xx with a nil error. Returns an error if the response is not 2xx along // with the HTTP body bytes if it got that far. This error is an HTTPError which includes the returned // HTTP status code and possibly a RespError as the WrappedError, if the HTTP body could be decoded as a RespError. -func (cli *Client) MakeFullRequest(params FullRequest) ([]byte, error) { +func (cli *Client) MakeFullRequest(ctx context.Context, params FullRequest) ([]byte, error) { if params.MaxAttempts == 0 { params.MaxAttempts = 1 + cli.DefaultHTTPRetries } if params.Logger == nil { params.Logger = &cli.Log } - req, err := params.compileRequest() + req, err := params.compileRequest(ctx) if err != nil { return nil, err } @@ -567,39 +561,37 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof } // Whoami gets the user ID of the current user. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3accountwhoami -func (cli *Client) Whoami() (resp *RespWhoami, err error) { +func (cli *Client) Whoami(ctx context.Context) (resp *RespWhoami, err error) { + urlPath := cli.BuildClientURL("v3", "account", "whoami") - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } // CreateFilter makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3useruseridfilter -func (cli *Client) CreateFilter(filter *Filter) (resp *RespCreateFilter, err error) { +func (cli *Client) CreateFilter(ctx context.Context, filter *Filter) (resp *RespCreateFilter, err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "filter") - _, err = cli.MakeRequest("POST", urlPath, filter, &resp) + _, err = cli.MakeRequest(ctx, "POST", urlPath, filter, &resp) return } // SyncRequest makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3sync -func (cli *Client) SyncRequest(timeout int, since, filterID string, fullState bool, setPresence event.Presence, ctx context.Context) (resp *RespSync, err error) { - return cli.FullSyncRequest(ReqSync{ +func (cli *Client) SyncRequest(ctx context.Context, timeout int, since, filterID string, fullState bool, setPresence event.Presence) (resp *RespSync, err error) { + return cli.FullSyncRequest(ctx, ReqSync{ Timeout: timeout, Since: since, FilterID: filterID, FullState: fullState, SetPresence: setPresence, - Context: ctx, }) } type ReqSync struct { - Timeout int - Since string - FilterID string - FullState bool - SetPresence event.Presence - - Context context.Context + Timeout int + Since string + FilterID string + FullState bool + SetPresence event.Presence StreamResponse bool } @@ -623,13 +615,12 @@ func (req *ReqSync) BuildQuery() map[string]string { } // FullSyncRequest makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3sync -func (cli *Client) FullSyncRequest(req ReqSync) (resp *RespSync, err error) { +func (cli *Client) FullSyncRequest(ctx context.Context, req ReqSync) (resp *RespSync, err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "sync"}, req.BuildQuery()) fullReq := FullRequest{ Method: http.MethodGet, URL: urlPath, ResponseJSON: &resp, - Context: req.Context, // We don't want automatic retries for SyncRequest, the Sync() wrapper handles those. MaxAttempts: 1, } @@ -637,7 +628,7 @@ func (cli *Client) FullSyncRequest(req ReqSync) (resp *RespSync, err error) { fullReq.Handler = streamResponse } start := time.Now() - _, err = cli.MakeFullRequest(fullReq) + _, err = cli.MakeFullRequest(ctx, fullReq) duration := time.Now().Sub(start) timeout := time.Duration(req.Timeout) * time.Millisecond buffer := 10 * time.Second @@ -645,7 +636,7 @@ func (cli *Client) FullSyncRequest(req ReqSync) (resp *RespSync, err error) { buffer = 1 * time.Minute } if err == nil && duration > timeout+buffer { - cli.cliOrContextLog(fullReq.Context).Warn(). + cli.cliOrContextLog(ctx).Warn(). Str("since", req.Since). Dur("duration", duration). Dur("timeout", timeout). @@ -676,18 +667,18 @@ func (cli *Client) FullSyncRequest(req ReqSync) (resp *RespSync, err error) { // } else { // // Username is available // } -func (cli *Client) RegisterAvailable(username string) (resp *RespRegisterAvailable, err error) { +func (cli *Client) RegisterAvailable(ctx context.Context, username string) (resp *RespRegisterAvailable, err error) { u := cli.BuildURLWithQuery(ClientURLPath{"v3", "register", "available"}, map[string]string{"username": username}) - _, err = cli.MakeRequest(http.MethodGet, u, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp) if err == nil && !resp.Available { err = fmt.Errorf(`request returned OK status without "available": true`) } return } -func (cli *Client) register(url string, req *ReqRegister) (resp *RespRegister, uiaResp *RespUserInteractive, err error) { +func (cli *Client) register(ctx context.Context, url string, req *ReqRegister) (resp *RespRegister, uiaResp *RespUserInteractive, err error) { var bodyBytes []byte - bodyBytes, err = cli.MakeFullRequest(FullRequest{ + bodyBytes, err = cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: url, RequestJSON: req, @@ -709,21 +700,21 @@ func (cli *Client) register(url string, req *ReqRegister) (resp *RespRegister, u // Register makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register // // Registers with kind=user. For kind=guest, see RegisterGuest. -func (cli *Client) Register(req *ReqRegister) (*RespRegister, *RespUserInteractive, error) { +func (cli *Client) Register(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) { u := cli.BuildClientURL("v3", "register") - return cli.register(u, req) + return cli.register(ctx, u, req) } // RegisterGuest makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register // with kind=guest. // // For kind=user, see Register. -func (cli *Client) RegisterGuest(req *ReqRegister) (*RespRegister, *RespUserInteractive, error) { +func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) { query := map[string]string{ "kind": "guest", } u := cli.BuildURLWithQuery(ClientURLPath{"v3", "register"}, query) - return cli.register(u, req) + return cli.register(ctx, u, req) } // RegisterDummy performs m.login.dummy registration according to https://spec.matrix.org/v1.2/client-server-api/#dummy-auth @@ -741,8 +732,8 @@ func (cli *Client) RegisterGuest(req *ReqRegister) (*RespRegister, *RespUserInte // panic(err) // } // token := res.AccessToken -func (cli *Client) RegisterDummy(req *ReqRegister) (*RespRegister, error) { - res, uia, err := cli.Register(req) +func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRegister, error) { + res, uia, err := cli.Register(ctx, req) if err != nil && uia == nil { return nil, err } else if uia == nil { @@ -751,7 +742,7 @@ func (cli *Client) RegisterDummy(req *ReqRegister) (*RespRegister, error) { return nil, errors.New("server does not support m.login.dummy") } req.Auth = BaseAuthData{Type: AuthTypeDummy, Session: uia.Session} - res, _, err = cli.Register(req) + res, _, err = cli.Register(ctx, req) if err != nil { return nil, err } @@ -759,15 +750,15 @@ func (cli *Client) RegisterDummy(req *ReqRegister) (*RespRegister, error) { } // GetLoginFlows fetches the login flows that the homeserver supports using https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3login -func (cli *Client) GetLoginFlows() (resp *RespLoginFlows, err error) { +func (cli *Client) GetLoginFlows(ctx context.Context) (resp *RespLoginFlows, err error) { urlPath := cli.BuildClientURL("v3", "login") - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } // Login a user to the homeserver according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3login -func (cli *Client) Login(req *ReqLogin) (resp *RespLogin, err error) { - _, err = cli.MakeFullRequest(FullRequest{ +func (cli *Client) Login(ctx context.Context, req *ReqLogin) (resp *RespLogin, err error) { + _, err = cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: cli.BuildClientURL("v3", "login"), RequestJSON: req, @@ -803,31 +794,31 @@ func (cli *Client) Login(req *ReqLogin) (resp *RespLogin, err error) { // Logout the current user. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3logout // This does not clear the credentials from the client instance. See ClearCredentials() instead. -func (cli *Client) Logout() (resp *RespLogout, err error) { +func (cli *Client) Logout(ctx context.Context) (resp *RespLogout, err error) { urlPath := cli.BuildClientURL("v3", "logout") - _, err = cli.MakeRequest("POST", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "POST", urlPath, nil, &resp) return } // LogoutAll logs out all the devices of the current user. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3logoutall // This does not clear the credentials from the client instance. See ClearCredentials() instead. -func (cli *Client) LogoutAll() (resp *RespLogout, err error) { +func (cli *Client) LogoutAll(ctx context.Context) (resp *RespLogout, err error) { urlPath := cli.BuildClientURL("v3", "logout", "all") - _, err = cli.MakeRequest("POST", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "POST", urlPath, nil, &resp) return } // Versions returns the list of supported Matrix versions on this homeserver. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientversions -func (cli *Client) Versions() (resp *RespVersions, err error) { +func (cli *Client) Versions(ctx context.Context) (resp *RespVersions, err error) { urlPath := cli.BuildClientURL("versions") - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } // Capabilities returns capabilities on this homeserver. See https://spec.matrix.org/v1.3/client-server-api/#capabilities-negotiation -func (cli *Client) Capabilities() (resp *RespCapabilities, err error) { +func (cli *Client) Capabilities(ctx context.Context) (resp *RespCapabilities, err error) { urlPath := cli.BuildClientURL("v3", "capabilities") - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } @@ -835,7 +826,7 @@ func (cli *Client) Capabilities() (resp *RespCapabilities, err error) { // // If serverName is specified, this will be added as a query param to instruct the homeserver to join via that server. If content is specified, it will // be JSON encoded and used as the request body. -func (cli *Client) JoinRoom(roomIDorAlias, serverName string, content interface{}) (resp *RespJoinRoom, err error) { +func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias, serverName string, content interface{}) (resp *RespJoinRoom, err error) { var urlPath string if serverName != "" { urlPath = cli.BuildURLWithQuery(ClientURLPath{"v3", "join", roomIDorAlias}, map[string]string{ @@ -844,7 +835,7 @@ func (cli *Client) JoinRoom(roomIDorAlias, serverName string, content interface{ } else { urlPath = cli.BuildClientURL("v3", "join", roomIDorAlias) } - _, err = cli.MakeRequest("POST", urlPath, content, &resp) + _, err = cli.MakeRequest(ctx, "POST", urlPath, content, &resp) if err == nil && cli.StateStore != nil { cli.StateStore.SetMembership(resp.RoomID, cli.UserID, event.MembershipJoin) } @@ -855,50 +846,50 @@ func (cli *Client) JoinRoom(roomIDorAlias, serverName string, content interface{ // // Unlike JoinRoom, this method can only be used to join rooms that the server already knows about. // It's mostly intended for bridges and other things where it's already certain that the server is in the room. -func (cli *Client) JoinRoomByID(roomID id.RoomID) (resp *RespJoinRoom, err error) { - _, err = cli.MakeRequest("POST", cli.BuildClientURL("v3", "rooms", roomID, "join"), nil, &resp) +func (cli *Client) JoinRoomByID(ctx context.Context, roomID id.RoomID) (resp *RespJoinRoom, err error) { + _, err = cli.MakeRequest(ctx, "POST", cli.BuildClientURL("v3", "rooms", roomID, "join"), nil, &resp) if err == nil && cli.StateStore != nil { cli.StateStore.SetMembership(resp.RoomID, cli.UserID, event.MembershipJoin) } return } -func (cli *Client) GetProfile(mxid id.UserID) (resp *RespUserProfile, err error) { +func (cli *Client) GetProfile(ctx context.Context, mxid id.UserID) (resp *RespUserProfile, err error) { urlPath := cli.BuildClientURL("v3", "profile", mxid) - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } // GetDisplayName returns the display name of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseriddisplayname -func (cli *Client) GetDisplayName(mxid id.UserID) (resp *RespUserDisplayName, err error) { +func (cli *Client) GetDisplayName(ctx context.Context, mxid id.UserID) (resp *RespUserDisplayName, err error) { urlPath := cli.BuildClientURL("v3", "profile", mxid, "displayname") - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } // GetOwnDisplayName returns the user's display name. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseriddisplayname -func (cli *Client) GetOwnDisplayName() (resp *RespUserDisplayName, err error) { - return cli.GetDisplayName(cli.UserID) +func (cli *Client) GetOwnDisplayName(ctx context.Context) (resp *RespUserDisplayName, err error) { + return cli.GetDisplayName(ctx, cli.UserID) } // SetDisplayName sets the user's profile display name. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3profileuseriddisplayname -func (cli *Client) SetDisplayName(displayName string) (err error) { +func (cli *Client) SetDisplayName(ctx context.Context, displayName string) (err error) { urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, "displayname") s := struct { DisplayName string `json:"displayname"` }{displayName} - _, err = cli.MakeRequest("PUT", urlPath, &s, nil) + _, err = cli.MakeRequest(ctx, "PUT", urlPath, &s, nil) return } // GetAvatarURL gets the avatar URL of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseridavatar_url -func (cli *Client) GetAvatarURL(mxid id.UserID) (url id.ContentURI, err error) { +func (cli *Client) GetAvatarURL(ctx context.Context, mxid id.UserID) (url id.ContentURI, err error) { urlPath := cli.BuildClientURL("v3", "profile", mxid, "avatar_url") s := struct { AvatarURL id.ContentURI `json:"avatar_url"` }{} - _, err = cli.MakeRequest("GET", urlPath, nil, &s) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &s) if err != nil { return } @@ -907,17 +898,17 @@ func (cli *Client) GetAvatarURL(mxid id.UserID) (url id.ContentURI, err error) { } // GetOwnAvatarURL gets the user's avatar URL. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseridavatar_url -func (cli *Client) GetOwnAvatarURL() (url id.ContentURI, err error) { - return cli.GetAvatarURL(cli.UserID) +func (cli *Client) GetOwnAvatarURL(ctx context.Context) (url id.ContentURI, err error) { + return cli.GetAvatarURL(ctx, cli.UserID) } // SetAvatarURL sets the user's avatar URL. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3profileuseridavatar_url -func (cli *Client) SetAvatarURL(url id.ContentURI) (err error) { +func (cli *Client) SetAvatarURL(ctx context.Context, url id.ContentURI) (err error) { urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, "avatar_url") s := struct { AvatarURL string `json:"avatar_url"` }{url.String()} - _, err = cli.MakeRequest("PUT", urlPath, &s, nil) + _, err = cli.MakeRequest(ctx, "PUT", urlPath, &s, nil) if err != nil { return err } @@ -926,23 +917,23 @@ func (cli *Client) SetAvatarURL(url id.ContentURI) (err error) { } // BeeperUpdateProfile sets custom fields in the user's profile. -func (cli *Client) BeeperUpdateProfile(data map[string]any) (err error) { +func (cli *Client) BeeperUpdateProfile(ctx context.Context, data map[string]any) (err error) { urlPath := cli.BuildClientURL("v3", "profile", cli.UserID) - _, err = cli.MakeRequest("PATCH", urlPath, &data, nil) + _, err = cli.MakeRequest(ctx, "PATCH", urlPath, &data, nil) return } // GetAccountData gets the user's account data of this type. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3useruseridaccount_datatype -func (cli *Client) GetAccountData(name string, output interface{}) (err error) { +func (cli *Client) GetAccountData(ctx context.Context, name string, output interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "account_data", name) - _, err = cli.MakeRequest("GET", urlPath, nil, output) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, output) return } // SetAccountData sets the user's account data of this type. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3useruseridaccount_datatype -func (cli *Client) SetAccountData(name string, data interface{}) (err error) { +func (cli *Client) SetAccountData(ctx context.Context, name string, data interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "account_data", name) - _, err = cli.MakeRequest("PUT", urlPath, data, nil) + _, err = cli.MakeRequest(ctx, "PUT", urlPath, data, nil) if err != nil { return err } @@ -951,16 +942,16 @@ func (cli *Client) SetAccountData(name string, data interface{}) (err error) { } // GetRoomAccountData gets the user's account data of this type in a specific room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3useruseridaccount_datatype -func (cli *Client) GetRoomAccountData(roomID id.RoomID, name string, output interface{}) (err error) { +func (cli *Client) GetRoomAccountData(ctx context.Context, roomID id.RoomID, name string, output interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "account_data", name) - _, err = cli.MakeRequest("GET", urlPath, nil, output) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, output) return } // SetRoomAccountData sets the user's account data of this type in a specific room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3useruseridroomsroomidaccount_datatype -func (cli *Client) SetRoomAccountData(roomID id.RoomID, name string, data interface{}) (err error) { +func (cli *Client) SetRoomAccountData(ctx context.Context, roomID id.RoomID, name string, data interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "account_data", name) - _, err = cli.MakeRequest("PUT", urlPath, data, nil) + _, err = cli.MakeRequest(ctx, "PUT", urlPath, data, nil) if err != nil { return err } @@ -979,7 +970,7 @@ type ReqSendEvent struct { // SendMessageEvent sends a message event into a room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidsendeventtypetxnid // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. -func (cli *Client) SendMessageEvent(roomID id.RoomID, eventType event.Type, contentJSON interface{}, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { +func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { var req ReqSendEvent if len(extra) > 0 { req = extra[0] @@ -1011,15 +1002,15 @@ func (cli *Client) SendMessageEvent(roomID id.RoomID, eventType event.Type, cont urlData := ClientURLPath{"v3", "rooms", roomID, "send", eventType.String(), txnID} urlPath := cli.BuildURLWithQuery(urlData, queryParams) - _, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp) + _, err = cli.MakeRequest(ctx, "PUT", urlPath, contentJSON, &resp) return } // SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. -func (cli *Client) SendStateEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (resp *RespSendEvent, err error) { +func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (resp *RespSendEvent, err error) { urlPath := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey) - _, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp) + _, err = cli.MakeRequest(ctx, "PUT", urlPath, contentJSON, &resp) if err == nil && cli.StateStore != nil { cli.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, contentJSON) } @@ -1028,11 +1019,11 @@ func (cli *Client) SendStateEvent(roomID id.RoomID, eventType event.Type, stateK // SendMassagedStateEvent sends a state event into a room with a custom timestamp. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. -func (cli *Client) SendMassagedStateEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) { +func (cli *Client) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey}, map[string]string{ "ts": strconv.FormatInt(ts, 10), }) - _, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp) + _, err = cli.MakeRequest(ctx, "PUT", urlPath, contentJSON, &resp) if err == nil && cli.StateStore != nil { cli.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, contentJSON) } @@ -1041,8 +1032,8 @@ func (cli *Client) SendMassagedStateEvent(roomID id.RoomID, eventType event.Type // SendText sends an m.room.message event into the given room with a msgtype of m.text // See https://spec.matrix.org/v1.2/client-server-api/#mtext -func (cli *Client) SendText(roomID id.RoomID, text string) (*RespSendEvent, error) { - return cli.SendMessageEvent(roomID, event.EventMessage, &event.MessageEventContent{ +func (cli *Client) SendText(ctx context.Context, roomID id.RoomID, text string) (*RespSendEvent, error) { + return cli.SendMessageEvent(ctx, roomID, event.EventMessage, &event.MessageEventContent{ MsgType: event.MsgText, Body: text, }) @@ -1050,15 +1041,15 @@ func (cli *Client) SendText(roomID id.RoomID, text string) (*RespSendEvent, erro // SendNotice sends an m.room.message event into the given room with a msgtype of m.notice // See https://spec.matrix.org/v1.2/client-server-api/#mnotice -func (cli *Client) SendNotice(roomID id.RoomID, text string) (*RespSendEvent, error) { - return cli.SendMessageEvent(roomID, event.EventMessage, &event.MessageEventContent{ +func (cli *Client) SendNotice(ctx context.Context, roomID id.RoomID, text string) (*RespSendEvent, error) { + return cli.SendMessageEvent(ctx, roomID, event.EventMessage, &event.MessageEventContent{ MsgType: event.MsgNotice, Body: text, }) } -func (cli *Client) SendReaction(roomID id.RoomID, eventID id.EventID, reaction string) (*RespSendEvent, error) { - return cli.SendMessageEvent(roomID, event.EventReaction, &event.ReactionEventContent{ +func (cli *Client) SendReaction(ctx context.Context, roomID id.RoomID, eventID id.EventID, reaction string) (*RespSendEvent, error) { + return cli.SendMessageEvent(ctx, roomID, event.EventReaction, &event.ReactionEventContent{ RelatesTo: event.RelatesTo{ EventID: eventID, Type: event.RelAnnotation, @@ -1068,7 +1059,7 @@ func (cli *Client) SendReaction(roomID id.RoomID, eventID id.EventID, reaction s } // RedactEvent redacts the given event. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidredacteventidtxnid -func (cli *Client) RedactEvent(roomID id.RoomID, eventID id.EventID, extra ...ReqRedact) (resp *RespSendEvent, err error) { +func (cli *Client) RedactEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID, extra ...ReqRedact) (resp *RespSendEvent, err error) { req := ReqRedact{} if len(extra) > 0 { req = extra[0] @@ -1086,7 +1077,7 @@ func (cli *Client) RedactEvent(roomID id.RoomID, eventID id.EventID, extra ...Re txnID = cli.TxnID() } urlPath := cli.BuildClientURL("v3", "rooms", roomID, "redact", eventID, txnID) - _, err = cli.MakeRequest("PUT", urlPath, req.Extra, &resp) + _, err = cli.MakeRequest(ctx, "PUT", urlPath, req.Extra, &resp) return } @@ -1096,9 +1087,9 @@ func (cli *Client) RedactEvent(roomID id.RoomID, eventID id.EventID, extra ...Re // Preset: "public_chat", // }) // fmt.Println("Room:", resp.RoomID) -func (cli *Client) CreateRoom(req *ReqCreateRoom) (resp *RespCreateRoom, err error) { +func (cli *Client) CreateRoom(ctx context.Context, req *ReqCreateRoom) (resp *RespCreateRoom, err error) { urlPath := cli.BuildClientURL("v3", "createRoom") - _, err = cli.MakeRequest("POST", urlPath, req, &resp) + _, err = cli.MakeRequest(ctx, "POST", urlPath, req, &resp) if err == nil && cli.StateStore != nil { cli.StateStore.SetMembership(resp.RoomID, cli.UserID, event.MembershipJoin) for _, evt := range req.InitialState { @@ -1119,7 +1110,7 @@ func (cli *Client) CreateRoom(req *ReqCreateRoom) (resp *RespCreateRoom, err err } // LeaveRoom leaves the given room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidleave -func (cli *Client) LeaveRoom(roomID id.RoomID, optionalReq ...*ReqLeave) (resp *RespLeaveRoom, err error) { +func (cli *Client) LeaveRoom(ctx context.Context, roomID id.RoomID, optionalReq ...*ReqLeave) (resp *RespLeaveRoom, err error) { req := &ReqLeave{} if len(optionalReq) == 1 { req = optionalReq[0] @@ -1127,7 +1118,7 @@ func (cli *Client) LeaveRoom(roomID id.RoomID, optionalReq ...*ReqLeave) (resp * panic("invalid number of arguments to LeaveRoom") } u := cli.BuildClientURL("v3", "rooms", roomID, "leave") - _, err = cli.MakeRequest("POST", u, req, &resp) + _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) if err == nil && cli.StateStore != nil { cli.StateStore.SetMembership(roomID, cli.UserID, event.MembershipLeave) } @@ -1135,16 +1126,16 @@ func (cli *Client) LeaveRoom(roomID id.RoomID, optionalReq ...*ReqLeave) (resp * } // ForgetRoom forgets a room entirely. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidforget -func (cli *Client) ForgetRoom(roomID id.RoomID) (resp *RespForgetRoom, err error) { +func (cli *Client) ForgetRoom(ctx context.Context, roomID id.RoomID) (resp *RespForgetRoom, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "forget") - _, err = cli.MakeRequest("POST", u, struct{}{}, &resp) + _, err = cli.MakeRequest(ctx, "POST", u, struct{}{}, &resp) return } // InviteUser invites a user to a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidinvite -func (cli *Client) InviteUser(roomID id.RoomID, req *ReqInviteUser) (resp *RespInviteUser, err error) { +func (cli *Client) InviteUser(ctx context.Context, roomID id.RoomID, req *ReqInviteUser) (resp *RespInviteUser, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "invite") - _, err = cli.MakeRequest("POST", u, req, &resp) + _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) if err == nil && cli.StateStore != nil { cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipInvite) } @@ -1152,16 +1143,16 @@ func (cli *Client) InviteUser(roomID id.RoomID, req *ReqInviteUser) (resp *RespI } // InviteUserByThirdParty invites a third-party identifier to a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidinvite-1 -func (cli *Client) InviteUserByThirdParty(roomID id.RoomID, req *ReqInvite3PID) (resp *RespInviteUser, err error) { +func (cli *Client) InviteUserByThirdParty(ctx context.Context, roomID id.RoomID, req *ReqInvite3PID) (resp *RespInviteUser, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "invite") - _, err = cli.MakeRequest("POST", u, req, &resp) + _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) return } // KickUser kicks a user from a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidkick -func (cli *Client) KickUser(roomID id.RoomID, req *ReqKickUser) (resp *RespKickUser, err error) { +func (cli *Client) KickUser(ctx context.Context, roomID id.RoomID, req *ReqKickUser) (resp *RespKickUser, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "kick") - _, err = cli.MakeRequest("POST", u, req, &resp) + _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) if err == nil && cli.StateStore != nil { cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipLeave) } @@ -1169,9 +1160,9 @@ func (cli *Client) KickUser(roomID id.RoomID, req *ReqKickUser) (resp *RespKickU } // BanUser bans a user from a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidban -func (cli *Client) BanUser(roomID id.RoomID, req *ReqBanUser) (resp *RespBanUser, err error) { +func (cli *Client) BanUser(ctx context.Context, roomID id.RoomID, req *ReqBanUser) (resp *RespBanUser, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "ban") - _, err = cli.MakeRequest("POST", u, req, &resp) + _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) if err == nil && cli.StateStore != nil { cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipBan) } @@ -1179,9 +1170,9 @@ func (cli *Client) BanUser(roomID id.RoomID, req *ReqBanUser) (resp *RespBanUser } // UnbanUser unbans a user from a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidunban -func (cli *Client) UnbanUser(roomID id.RoomID, req *ReqUnbanUser) (resp *RespUnbanUser, err error) { +func (cli *Client) UnbanUser(ctx context.Context, roomID id.RoomID, req *ReqUnbanUser) (resp *RespUnbanUser, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "unban") - _, err = cli.MakeRequest("POST", u, req, &resp) + _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) if err == nil && cli.StateStore != nil { cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipLeave) } @@ -1189,30 +1180,30 @@ func (cli *Client) UnbanUser(roomID id.RoomID, req *ReqUnbanUser) (resp *RespUnb } // UserTyping sets the typing status of the user. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidtypinguserid -func (cli *Client) UserTyping(roomID id.RoomID, typing bool, timeout time.Duration) (resp *RespTyping, err error) { +func (cli *Client) UserTyping(ctx context.Context, roomID id.RoomID, typing bool, timeout time.Duration) (resp *RespTyping, err error) { req := ReqTyping{Typing: typing, Timeout: timeout.Milliseconds()} u := cli.BuildClientURL("v3", "rooms", roomID, "typing", cli.UserID) - _, err = cli.MakeRequest("PUT", u, req, &resp) + _, err = cli.MakeRequest(ctx, "PUT", u, req, &resp) return } // GetPresence gets the presence of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3presenceuseridstatus -func (cli *Client) GetPresence(userID id.UserID) (resp *RespPresence, err error) { +func (cli *Client) GetPresence(ctx context.Context, userID id.UserID) (resp *RespPresence, err error) { resp = new(RespPresence) u := cli.BuildClientURL("v3", "presence", userID, "status") - _, err = cli.MakeRequest("GET", u, nil, resp) + _, err = cli.MakeRequest(ctx, "GET", u, nil, resp) return } // GetOwnPresence gets the user's presence. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3presenceuseridstatus -func (cli *Client) GetOwnPresence() (resp *RespPresence, err error) { - return cli.GetPresence(cli.UserID) +func (cli *Client) GetOwnPresence(ctx context.Context) (resp *RespPresence, err error) { + return cli.GetPresence(ctx, cli.UserID) } -func (cli *Client) SetPresence(status event.Presence) (err error) { +func (cli *Client) SetPresence(ctx context.Context, status event.Presence) (err error) { req := ReqPresence{Presence: status} u := cli.BuildClientURL("v3", "presence", cli.UserID, "status") - _, err = cli.MakeRequest("PUT", u, req, nil) + _, err = cli.MakeRequest(ctx, "PUT", u, req, nil) return } @@ -1252,9 +1243,9 @@ func (cli *Client) updateStoreWithOutgoingEvent(roomID id.RoomID, eventType even // StateEvent gets a single state event in a room. It will attempt to JSON unmarshal into the given "outContent" struct with // the HTTP response body, or return an error. // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstateeventtypestatekey -func (cli *Client) StateEvent(roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) (err error) { +func (cli *Client) StateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) (err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey) - _, err = cli.MakeRequest("GET", u, nil, outContent) + _, err = cli.MakeRequest(ctx, "GET", u, nil, outContent) if err == nil && cli.StateStore != nil { cli.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, outContent) } @@ -1302,8 +1293,8 @@ func parseRoomStateArray(_ *http.Request, res *http.Response, responseJSON inter // State gets all state in a room. // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstate -func (cli *Client) State(roomID id.RoomID) (stateMap RoomStateMap, err error) { - _, err = cli.MakeFullRequest(FullRequest{ +func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomStateMap, err error) { + _, err = cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodGet, URL: cli.BuildClientURL("v3", "rooms", roomID, "state"), ResponseJSON: &stateMap, @@ -1321,34 +1312,35 @@ func (cli *Client) State(roomID id.RoomID) (stateMap RoomStateMap, err error) { } // GetMediaConfig fetches the configuration of the content repository, such as upload limitations. -func (cli *Client) GetMediaConfig() (resp *RespMediaConfig, err error) { +func (cli *Client) GetMediaConfig(ctx context.Context) (resp *RespMediaConfig, err error) { u := cli.BuildURL(MediaURLPath{"v3", "config"}) - _, err = cli.MakeRequest("GET", u, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", u, nil, &resp) return } // UploadLink uploads an HTTP URL and then returns an MXC URI. -func (cli *Client) UploadLink(link string) (*RespMediaUpload, error) { - res, err := cli.Client.Get(link) +func (cli *Client) UploadLink(ctx context.Context, link string) (*RespMediaUpload, error) { + req, err := http.NewRequestWithContext(ctx, "GET", link, nil) + if err != nil { + return nil, err + } + + res, err := cli.Client.Do(req) if res != nil { defer res.Body.Close() } if err != nil { return nil, err } - return cli.Upload(res.Body, res.Header.Get("Content-Type"), res.ContentLength) + return cli.Upload(ctx, res.Body, res.Header.Get("Content-Type"), res.ContentLength) } func (cli *Client) GetDownloadURL(mxcURL id.ContentURI) string { return cli.BuildURLWithQuery(MediaURLPath{"v3", "download", mxcURL.Homeserver, mxcURL.FileID}, map[string]string{"allow_redirect": "true"}) } -func (cli *Client) Download(mxcURL id.ContentURI) (io.ReadCloser, error) { - return cli.DownloadContext(context.Background(), mxcURL) -} - -func (cli *Client) DownloadContext(ctx context.Context, mxcURL id.ContentURI) (io.ReadCloser, error) { - resp, err := cli.downloadContext(ctx, mxcURL) +func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (io.ReadCloser, error) { + resp, err := cli.download(ctx, mxcURL) if err != nil { return nil, err } @@ -1411,7 +1403,7 @@ func (cli *Client) doMediaRequest(req *http.Request, retries int, backoff time.D return res, err } -func (cli *Client) downloadContext(ctx context.Context, mxcURL id.ContentURI) (*http.Response, error) { +func (cli *Client) download(ctx context.Context, mxcURL id.ContentURI) (*http.Response, error) { ctxLog := zerolog.Ctx(ctx) if ctxLog.GetLevel() == zerolog.Disabled || ctxLog == zerolog.DefaultContextLogger { ctx = cli.Log.WithContext(ctx) @@ -1424,12 +1416,8 @@ func (cli *Client) downloadContext(ctx context.Context, mxcURL id.ContentURI) (* return cli.doMediaRequest(req, cli.DefaultHTTPRetries, 4*time.Second) } -func (cli *Client) DownloadBytes(mxcURL id.ContentURI) ([]byte, error) { - return cli.DownloadBytesContext(context.Background(), mxcURL) -} - -func (cli *Client) DownloadBytesContext(ctx context.Context, mxcURL id.ContentURI) ([]byte, error) { - resp, err := cli.downloadContext(ctx, mxcURL) +func (cli *Client) DownloadBytes(ctx context.Context, mxcURL id.ContentURI) ([]byte, error) { + resp, err := cli.download(ctx, mxcURL) if err != nil { return nil, err } @@ -1440,10 +1428,10 @@ func (cli *Client) DownloadBytesContext(ctx context.Context, mxcURL id.ContentUR // CreateMXC creates a blank Matrix content URI to allow uploading the content asynchronously later. // // See https://spec.matrix.org/v1.7/client-server-api/#post_matrixmediav1create -func (cli *Client) CreateMXC() (*RespCreateMXC, error) { +func (cli *Client) CreateMXC(ctx context.Context) (*RespCreateMXC, error) { u, _ := url.Parse(cli.BuildURL(MediaURLPath{"v1", "create"})) var m RespCreateMXC - _, err := cli.MakeFullRequest(FullRequest{ + _, err := cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: u.String(), ResponseJSON: &m, @@ -1456,15 +1444,15 @@ func (cli *Client) CreateMXC() (*RespCreateMXC, error) { // // See https://spec.matrix.org/v1.7/client-server-api/#post_matrixmediav1create // and https://spec.matrix.org/v1.7/client-server-api/#put_matrixmediav3uploadservernamemediaid -func (cli *Client) UploadAsync(req ReqUploadMedia) (*RespCreateMXC, error) { - resp, err := cli.CreateMXC() +func (cli *Client) UploadAsync(ctx context.Context, req ReqUploadMedia) (*RespCreateMXC, error) { + resp, err := cli.CreateMXC(ctx) if err != nil { return nil, err } req.MXC = resp.ContentURI req.UnstableUploadURL = resp.UnstableUploadURL go func() { - _, err = cli.UploadMedia(req) + _, err = cli.UploadMedia(ctx, req) if err != nil { cli.Log.Error().Str("mxc", req.MXC.String()).Err(err).Msg("Async upload of media failed") } @@ -1472,12 +1460,12 @@ func (cli *Client) UploadAsync(req ReqUploadMedia) (*RespCreateMXC, error) { return resp, nil } -func (cli *Client) UploadBytes(data []byte, contentType string) (*RespMediaUpload, error) { - return cli.UploadBytesWithName(data, contentType, "") +func (cli *Client) UploadBytes(ctx context.Context, data []byte, contentType string) (*RespMediaUpload, error) { + return cli.UploadBytesWithName(ctx, data, contentType, "") } -func (cli *Client) UploadBytesWithName(data []byte, contentType, fileName string) (*RespMediaUpload, error) { - return cli.UploadMedia(ReqUploadMedia{ +func (cli *Client) UploadBytesWithName(ctx context.Context, data []byte, contentType, fileName string) (*RespMediaUpload, error) { + return cli.UploadMedia(ctx, ReqUploadMedia{ ContentBytes: data, ContentType: contentType, FileName: fileName, @@ -1487,8 +1475,8 @@ func (cli *Client) UploadBytesWithName(data []byte, contentType, fileName string // Upload uploads the given data to the content repository and returns an MXC URI. // // Deprecated: UploadMedia should be used instead. -func (cli *Client) Upload(content io.Reader, contentType string, contentLength int64) (*RespMediaUpload, error) { - return cli.UploadMedia(ReqUploadMedia{ +func (cli *Client) Upload(ctx context.Context, content io.Reader, contentType string, contentLength int64) (*RespMediaUpload, error) { + return cli.UploadMedia(ctx, ReqUploadMedia{ Content: content, ContentLength: contentLength, ContentType: contentType, @@ -1511,9 +1499,9 @@ type ReqUploadMedia struct { UnstableUploadURL string } -func (cli *Client) tryUploadMediaToURL(url, contentType string, content io.Reader) (*http.Response, error) { +func (cli *Client) tryUploadMediaToURL(ctx context.Context, url, contentType string, content io.Reader) (*http.Response, error) { cli.Log.Debug().Str("url", url).Msg("Uploading media to external URL") - req, err := http.NewRequest(http.MethodPut, url, content) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, content) if err != nil { return nil, err } @@ -1523,7 +1511,7 @@ func (cli *Client) tryUploadMediaToURL(url, contentType string, content io.Reade return http.DefaultClient.Do(req) } -func (cli *Client) uploadMediaToURL(data ReqUploadMedia) (*RespMediaUpload, error) { +func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (*RespMediaUpload, error) { retries := cli.DefaultHTTPRetries if data.ContentBytes == nil { // Can't retry with a reader @@ -1536,7 +1524,7 @@ func (cli *Client) uploadMediaToURL(data ReqUploadMedia) (*RespMediaUpload, erro } else { data.Content = nil } - resp, err := cli.tryUploadMediaToURL(data.UnstableUploadURL, data.ContentType, reader) + resp, err := cli.tryUploadMediaToURL(ctx, data.UnstableUploadURL, data.ContentType, reader) if err == nil { if resp.StatusCode >= 200 && resp.StatusCode < 300 { // Everything is fine @@ -1562,7 +1550,7 @@ func (cli *Client) uploadMediaToURL(data ReqUploadMedia) (*RespMediaUpload, erro notifyURL := cli.BuildURLWithQuery(MediaURLPath{"unstable", "com.beeper.msc3870", "upload", data.MXC.Homeserver, data.MXC.FileID, "complete"}, query) var m *RespMediaUpload - _, err := cli.MakeFullRequest(FullRequest{ + _, err := cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: notifyURL, ResponseJSON: m, @@ -1576,12 +1564,12 @@ func (cli *Client) uploadMediaToURL(data ReqUploadMedia) (*RespMediaUpload, erro // UploadMedia uploads the given data to the content repository and returns an MXC URI. // See https://spec.matrix.org/v1.7/client-server-api/#post_matrixmediav3upload -func (cli *Client) UploadMedia(data ReqUploadMedia) (*RespMediaUpload, error) { +func (cli *Client) UploadMedia(ctx context.Context, data ReqUploadMedia) (*RespMediaUpload, error) { if data.UnstableUploadURL != "" { if data.MXC.IsEmpty() { return nil, errors.New("MXC must also be set when uploading to external URL") } - return cli.uploadMediaToURL(data) + return cli.uploadMediaToURL(ctx, data) } u, _ := url.Parse(cli.BuildURL(MediaURLPath{"v3", "upload"})) method := http.MethodPost @@ -1601,7 +1589,7 @@ func (cli *Client) UploadMedia(data ReqUploadMedia) (*RespMediaUpload, error) { } var m RespMediaUpload - _, err := cli.MakeFullRequest(FullRequest{ + _, err := cli.MakeFullRequest(ctx, FullRequest{ Method: method, URL: u.String(), Headers: headers, @@ -1616,12 +1604,12 @@ func (cli *Client) UploadMedia(data ReqUploadMedia) (*RespMediaUpload, error) { // GetURLPreview asks the homeserver to fetch a preview for a given URL. // // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixmediav3preview_url -func (cli *Client) GetURLPreview(url string) (*RespPreviewURL, error) { +func (cli *Client) GetURLPreview(ctx context.Context, url string) (*RespPreviewURL, error) { reqURL := cli.BuildURLWithQuery(MediaURLPath{"v3", "preview_url"}, map[string]string{ "url": url, }) var output RespPreviewURL - _, err := cli.MakeRequest(http.MethodGet, reqURL, nil, &output) + _, err := cli.MakeRequest(ctx, http.MethodGet, reqURL, nil, &output) return &output, err } @@ -1629,9 +1617,9 @@ func (cli *Client) GetURLPreview(url string) (*RespPreviewURL, error) { // // In general, usage of this API is discouraged in favour of /sync, as calling this API can race with incoming membership changes. // This API is primarily designed for application services which may want to efficiently look up joined members in a room. -func (cli *Client) JoinedMembers(roomID id.RoomID) (resp *RespJoinedMembers, err error) { +func (cli *Client) JoinedMembers(ctx context.Context, roomID id.RoomID) (resp *RespJoinedMembers, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "joined_members") - _, err = cli.MakeRequest("GET", u, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", u, nil, &resp) if err == nil && cli.StateStore != nil { cli.StateStore.ClearCachedMembers(roomID, event.MembershipJoin) for userID, member := range resp.Joined { @@ -1645,7 +1633,7 @@ func (cli *Client) JoinedMembers(roomID id.RoomID) (resp *RespJoinedMembers, err return } -func (cli *Client) Members(roomID id.RoomID, req ...ReqMembers) (resp *RespMembers, err error) { +func (cli *Client) Members(ctx context.Context, roomID id.RoomID, req ...ReqMembers) (resp *RespMembers, err error) { var extra ReqMembers if len(req) > 0 { extra = req[0] @@ -1661,7 +1649,7 @@ func (cli *Client) Members(roomID id.RoomID, req ...ReqMembers) (resp *RespMembe query["not_membership"] = string(extra.NotMembership) } u := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "members"}, query) - _, err = cli.MakeRequest("GET", u, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", u, nil, &resp) if err == nil && cli.StateStore != nil { var clearMemberships []event.Membership if extra.Membership != "" { @@ -1681,9 +1669,9 @@ func (cli *Client) Members(roomID id.RoomID, req ...ReqMembers) (resp *RespMembe // // In general, usage of this API is discouraged in favour of /sync, as calling this API can race with incoming membership changes. // This API is primarily designed for application services which may want to efficiently look up joined rooms. -func (cli *Client) JoinedRooms() (resp *RespJoinedRooms, err error) { +func (cli *Client) JoinedRooms(ctx context.Context) (resp *RespJoinedRooms, err error) { u := cli.BuildClientURL("v3", "joined_rooms") - _, err = cli.MakeRequest("GET", u, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", u, nil, &resp) return } @@ -1693,16 +1681,16 @@ func (cli *Client) JoinedRooms() (resp *RespJoinedRooms, err error) { // when it encounters another space as a child it recurses into that space before returning non-space children. // // The second function parameter specifies query parameters to limit the response. No query parameters will be added if it's nil. -func (cli *Client) Hierarchy(roomID id.RoomID, req *ReqHierarchy) (resp *RespHierarchy, err error) { +func (cli *Client) Hierarchy(ctx context.Context, roomID id.RoomID, req *ReqHierarchy) (resp *RespHierarchy, err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v1", "rooms", roomID, "hierarchy"}, req.Query()) - _, err = cli.MakeRequest(http.MethodGet, urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } // Messages returns a list of message and state events for a room. It uses // pagination query parameters to paginate history in the room. // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidmessages -func (cli *Client) Messages(roomID id.RoomID, from, to string, dir Direction, filter *FilterPart, limit int) (resp *RespMessages, err error) { +func (cli *Client) Messages(ctx context.Context, roomID id.RoomID, from, to string, dir Direction, filter *FilterPart, limit int) (resp *RespMessages, err error) { query := map[string]string{ "from": from, "dir": string(dir), @@ -1722,20 +1710,20 @@ func (cli *Client) Messages(roomID id.RoomID, from, to string, dir Direction, fi } urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "messages"}, query) - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } // TimestampToEvent finds the ID of the event closest to the given timestamp. // // See https://spec.matrix.org/v1.6/client-server-api/#get_matrixclientv1roomsroomidtimestamp_to_event -func (cli *Client) TimestampToEvent(roomID id.RoomID, timestamp time.Time, dir Direction) (resp *RespTimestampToEvent, err error) { +func (cli *Client) TimestampToEvent(ctx context.Context, roomID id.RoomID, timestamp time.Time, dir Direction) (resp *RespTimestampToEvent, err error) { query := map[string]string{ "ts": strconv.FormatInt(timestamp.UnixMilli(), 10), "dir": string(dir), } urlPath := cli.BuildURLWithQuery(ClientURLPath{"v1", "rooms", roomID, "timestamp_to_event"}, query) - _, err = cli.MakeRequest(http.MethodGet, urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } @@ -1743,7 +1731,7 @@ func (cli *Client) TimestampToEvent(roomID id.RoomID, timestamp time.Time, dir D // specified event. It use pagination query parameters to paginate history in // the room. // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidcontexteventid -func (cli *Client) Context(roomID id.RoomID, eventID id.EventID, filter *FilterPart, limit int) (resp *RespContext, err error) { +func (cli *Client) Context(ctx context.Context, roomID id.RoomID, eventID id.EventID, filter *FilterPart, limit int) (resp *RespContext, err error) { query := map[string]string{} if filter != nil { filterJSON, err := json.Marshal(filter) @@ -1757,173 +1745,173 @@ func (cli *Client) Context(roomID id.RoomID, eventID id.EventID, filter *FilterP } urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "context", eventID}, query) - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } -func (cli *Client) GetEvent(roomID id.RoomID, eventID id.EventID) (resp *event.Event, err error) { +func (cli *Client) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (resp *event.Event, err error) { urlPath := cli.BuildClientURL("v3", "rooms", roomID, "event", eventID) - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } -func (cli *Client) MarkRead(roomID id.RoomID, eventID id.EventID) (err error) { - return cli.SendReceipt(roomID, eventID, event.ReceiptTypeRead, nil) +func (cli *Client) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID) (err error) { + return cli.SendReceipt(ctx, roomID, eventID, event.ReceiptTypeRead, nil) } // MarkReadWithContent sends a read receipt including custom data. // // Deprecated: Use SendReceipt instead. -func (cli *Client) MarkReadWithContent(roomID id.RoomID, eventID id.EventID, content interface{}) (err error) { - return cli.SendReceipt(roomID, eventID, event.ReceiptTypeRead, content) +func (cli *Client) MarkReadWithContent(ctx context.Context, roomID id.RoomID, eventID id.EventID, content interface{}) (err error) { + return cli.SendReceipt(ctx, roomID, eventID, event.ReceiptTypeRead, content) } // SendReceipt sends a receipt, usually specifically a read receipt. // // Passing nil as the content is safe, the library will automatically replace it with an empty JSON object. // To mark a message in a specific thread as read, use pass a ReqSendReceipt as the content. -func (cli *Client) SendReceipt(roomID id.RoomID, eventID id.EventID, receiptType event.ReceiptType, content interface{}) (err error) { +func (cli *Client) SendReceipt(ctx context.Context, roomID id.RoomID, eventID id.EventID, receiptType event.ReceiptType, content interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "rooms", roomID, "receipt", receiptType, eventID) - _, err = cli.MakeRequest("POST", urlPath, content, nil) + _, err = cli.MakeRequest(ctx, "POST", urlPath, content, nil) return } -func (cli *Client) SetReadMarkers(roomID id.RoomID, content interface{}) (err error) { +func (cli *Client) SetReadMarkers(ctx context.Context, roomID id.RoomID, content interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "rooms", roomID, "read_markers") - _, err = cli.MakeRequest("POST", urlPath, content, nil) + _, err = cli.MakeRequest(ctx, "POST", urlPath, content, nil) return } -func (cli *Client) AddTag(roomID id.RoomID, tag string, order float64) error { +func (cli *Client) AddTag(ctx context.Context, roomID id.RoomID, tag string, order float64) error { var tagData event.Tag if order == order { tagData.Order = json.Number(strconv.FormatFloat(order, 'e', -1, 64)) } - return cli.AddTagWithCustomData(roomID, tag, tagData) + return cli.AddTagWithCustomData(ctx, roomID, tag, tagData) } -func (cli *Client) AddTagWithCustomData(roomID id.RoomID, tag string, data interface{}) (err error) { +func (cli *Client) AddTagWithCustomData(ctx context.Context, roomID id.RoomID, tag string, data interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "tags", tag) - _, err = cli.MakeRequest("PUT", urlPath, data, nil) + _, err = cli.MakeRequest(ctx, "PUT", urlPath, data, nil) return } -func (cli *Client) GetTags(roomID id.RoomID) (tags event.TagEventContent, err error) { - err = cli.GetTagsWithCustomData(roomID, &tags) +func (cli *Client) GetTags(ctx context.Context, roomID id.RoomID) (tags event.TagEventContent, err error) { + err = cli.GetTagsWithCustomData(ctx, roomID, &tags) return } -func (cli *Client) GetTagsWithCustomData(roomID id.RoomID, resp interface{}) (err error) { +func (cli *Client) GetTagsWithCustomData(ctx context.Context, roomID id.RoomID, resp interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "tags") - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } -func (cli *Client) RemoveTag(roomID id.RoomID, tag string) (err error) { +func (cli *Client) RemoveTag(ctx context.Context, roomID id.RoomID, tag string) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "tags", tag) - _, err = cli.MakeRequest("DELETE", urlPath, nil, nil) + _, err = cli.MakeRequest(ctx, "DELETE", urlPath, nil, nil) return } // Deprecated: Synapse may not handle setting m.tag directly properly, so you should use the Add/RemoveTag methods instead. -func (cli *Client) SetTags(roomID id.RoomID, tags event.Tags) (err error) { - return cli.SetRoomAccountData(roomID, "m.tag", map[string]event.Tags{ +func (cli *Client) SetTags(ctx context.Context, roomID id.RoomID, tags event.Tags) (err error) { + return cli.SetRoomAccountData(ctx, roomID, "m.tag", map[string]event.Tags{ "tags": tags, }) } // TurnServer returns turn server details and credentials for the client to use when initiating calls. // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3voipturnserver -func (cli *Client) TurnServer() (resp *RespTurnServer, err error) { +func (cli *Client) TurnServer(ctx context.Context) (resp *RespTurnServer, err error) { urlPath := cli.BuildClientURL("v3", "voip", "turnServer") - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } -func (cli *Client) CreateAlias(alias id.RoomAlias, roomID id.RoomID) (resp *RespAliasCreate, err error) { +func (cli *Client) CreateAlias(ctx context.Context, alias id.RoomAlias, roomID id.RoomID) (resp *RespAliasCreate, err error) { urlPath := cli.BuildClientURL("v3", "directory", "room", alias) - _, err = cli.MakeRequest("PUT", urlPath, &ReqAliasCreate{RoomID: roomID}, &resp) + _, err = cli.MakeRequest(ctx, "PUT", urlPath, &ReqAliasCreate{RoomID: roomID}, &resp) return } -func (cli *Client) ResolveAlias(alias id.RoomAlias) (resp *RespAliasResolve, err error) { +func (cli *Client) ResolveAlias(ctx context.Context, alias id.RoomAlias) (resp *RespAliasResolve, err error) { urlPath := cli.BuildClientURL("v3", "directory", "room", alias) - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } -func (cli *Client) DeleteAlias(alias id.RoomAlias) (resp *RespAliasDelete, err error) { +func (cli *Client) DeleteAlias(ctx context.Context, alias id.RoomAlias) (resp *RespAliasDelete, err error) { urlPath := cli.BuildClientURL("v3", "directory", "room", alias) - _, err = cli.MakeRequest("DELETE", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "DELETE", urlPath, nil, &resp) return } -func (cli *Client) GetAliases(roomID id.RoomID) (resp *RespAliasList, err error) { +func (cli *Client) GetAliases(ctx context.Context, roomID id.RoomID) (resp *RespAliasList, err error) { urlPath := cli.BuildClientURL("v3", "rooms", roomID, "aliases") - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } -func (cli *Client) UploadKeys(req *ReqUploadKeys) (resp *RespUploadKeys, err error) { +func (cli *Client) UploadKeys(ctx context.Context, req *ReqUploadKeys) (resp *RespUploadKeys, err error) { urlPath := cli.BuildClientURL("v3", "keys", "upload") - _, err = cli.MakeRequest("POST", urlPath, req, &resp) + _, err = cli.MakeRequest(ctx, "POST", urlPath, req, &resp) return } -func (cli *Client) QueryKeys(req *ReqQueryKeys) (resp *RespQueryKeys, err error) { +func (cli *Client) QueryKeys(ctx context.Context, req *ReqQueryKeys) (resp *RespQueryKeys, err error) { urlPath := cli.BuildClientURL("v3", "keys", "query") - _, err = cli.MakeRequest("POST", urlPath, req, &resp) + _, err = cli.MakeRequest(ctx, "POST", urlPath, req, &resp) return } -func (cli *Client) ClaimKeys(req *ReqClaimKeys) (resp *RespClaimKeys, err error) { +func (cli *Client) ClaimKeys(ctx context.Context, req *ReqClaimKeys) (resp *RespClaimKeys, err error) { urlPath := cli.BuildClientURL("v3", "keys", "claim") - _, err = cli.MakeRequest("POST", urlPath, req, &resp) + _, err = cli.MakeRequest(ctx, "POST", urlPath, req, &resp) return } -func (cli *Client) GetKeyChanges(from, to string) (resp *RespKeyChanges, err error) { +func (cli *Client) GetKeyChanges(ctx context.Context, from, to string) (resp *RespKeyChanges, err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "keys", "changes"}, map[string]string{ "from": from, "to": to, }) - _, err = cli.MakeRequest("POST", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "POST", urlPath, nil, &resp) return } -func (cli *Client) SendToDevice(eventType event.Type, req *ReqSendToDevice) (resp *RespSendToDevice, err error) { +func (cli *Client) SendToDevice(ctx context.Context, eventType event.Type, req *ReqSendToDevice) (resp *RespSendToDevice, err error) { urlPath := cli.BuildClientURL("v3", "sendToDevice", eventType.String(), cli.TxnID()) - _, err = cli.MakeRequest("PUT", urlPath, req, &resp) + _, err = cli.MakeRequest(ctx, "PUT", urlPath, req, &resp) return } -func (cli *Client) GetDevicesInfo() (resp *RespDevicesInfo, err error) { +func (cli *Client) GetDevicesInfo(ctx context.Context) (resp *RespDevicesInfo, err error) { urlPath := cli.BuildClientURL("v3", "devices") - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } -func (cli *Client) GetDeviceInfo(deviceID id.DeviceID) (resp *RespDeviceInfo, err error) { +func (cli *Client) GetDeviceInfo(ctx context.Context, deviceID id.DeviceID) (resp *RespDeviceInfo, err error) { urlPath := cli.BuildClientURL("v3", "devices", deviceID) - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } -func (cli *Client) SetDeviceInfo(deviceID id.DeviceID, req *ReqDeviceInfo) error { +func (cli *Client) SetDeviceInfo(ctx context.Context, deviceID id.DeviceID, req *ReqDeviceInfo) error { urlPath := cli.BuildClientURL("v3", "devices", deviceID) - _, err := cli.MakeRequest("PUT", urlPath, req, nil) + _, err := cli.MakeRequest(ctx, "PUT", urlPath, req, nil) return err } -func (cli *Client) DeleteDevice(deviceID id.DeviceID, req *ReqDeleteDevice) error { +func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice) error { urlPath := cli.BuildClientURL("v3", "devices", deviceID) - _, err := cli.MakeRequest("DELETE", urlPath, req, nil) + _, err := cli.MakeRequest(ctx, "DELETE", urlPath, req, nil) return err } -func (cli *Client) DeleteDevices(req *ReqDeleteDevices) error { +func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices) error { urlPath := cli.BuildClientURL("v3", "delete_devices") - _, err := cli.MakeRequest("DELETE", urlPath, req, nil) + _, err := cli.MakeRequest(ctx, "DELETE", urlPath, req, nil) return err } @@ -1932,8 +1920,8 @@ type UIACallback = func(*RespUserInteractive) interface{} // UploadCrossSigningKeys uploads the given cross-signing keys to the server. // Because the endpoint requires user-interactive authentication a callback must be provided that, // given the UI auth parameters, produces the required result (or nil to end the flow). -func (cli *Client) UploadCrossSigningKeys(keys *UploadCrossSigningKeysReq, uiaCallback UIACallback) error { - content, err := cli.MakeFullRequest(FullRequest{ +func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq, uiaCallback UIACallback) error { + content, err := cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: cli.BuildClientURL("v3", "keys", "device_signing", "upload"), RequestJSON: keys, @@ -1948,48 +1936,48 @@ func (cli *Client) UploadCrossSigningKeys(keys *UploadCrossSigningKeysReq, uiaCa auth := uiaCallback(&uiAuthResp) if auth != nil { keys.Auth = auth - return cli.UploadCrossSigningKeys(keys, uiaCallback) + return cli.UploadCrossSigningKeys(ctx, keys, uiaCallback) } } return err } -func (cli *Client) UploadSignatures(req *ReqUploadSignatures) (resp *RespUploadSignatures, err error) { +func (cli *Client) UploadSignatures(ctx context.Context, req *ReqUploadSignatures) (resp *RespUploadSignatures, err error) { urlPath := cli.BuildClientURL("v3", "keys", "signatures", "upload") - _, err = cli.MakeRequest("POST", urlPath, req, &resp) + _, err = cli.MakeRequest(ctx, "POST", urlPath, req, &resp) return } // GetPushRules returns the push notification rules for the global scope. -func (cli *Client) GetPushRules() (*pushrules.PushRuleset, error) { - return cli.GetScopedPushRules("global") +func (cli *Client) GetPushRules(ctx context.Context) (*pushrules.PushRuleset, error) { + return cli.GetScopedPushRules(ctx, "global") } // GetScopedPushRules returns the push notification rules for the given scope. -func (cli *Client) GetScopedPushRules(scope string) (resp *pushrules.PushRuleset, err error) { +func (cli *Client) GetScopedPushRules(ctx context.Context, scope string) (resp *pushrules.PushRuleset, err error) { u, _ := url.Parse(cli.BuildClientURL("v3", "pushrules", scope)) // client.BuildURL returns the URL without a trailing slash, but the pushrules endpoint requires the slash. u.Path += "/" - _, err = cli.MakeRequest("GET", u.String(), nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", u.String(), nil, &resp) return } -func (cli *Client) GetPushRule(scope string, kind pushrules.PushRuleType, ruleID string) (resp *pushrules.PushRule, err error) { +func (cli *Client) GetPushRule(ctx context.Context, scope string, kind pushrules.PushRuleType, ruleID string) (resp *pushrules.PushRule, err error) { urlPath := cli.BuildClientURL("v3", "pushrules", scope, kind, ruleID) - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) if resp != nil { resp.Type = kind } return } -func (cli *Client) DeletePushRule(scope string, kind pushrules.PushRuleType, ruleID string) error { +func (cli *Client) DeletePushRule(ctx context.Context, scope string, kind pushrules.PushRuleType, ruleID string) error { urlPath := cli.BuildClientURL("v3", "pushrules", scope, kind, ruleID) - _, err := cli.MakeRequest("DELETE", urlPath, nil, nil) + _, err := cli.MakeRequest(ctx, "DELETE", urlPath, nil, nil) return err } -func (cli *Client) PutPushRule(scope string, kind pushrules.PushRuleType, ruleID string, req *ReqPutPushRule) error { +func (cli *Client) PutPushRule(ctx context.Context, scope string, kind pushrules.PushRuleType, ruleID string, req *ReqPutPushRule) error { query := make(map[string]string) if len(req.After) > 0 { query["after"] = req.After @@ -1998,14 +1986,14 @@ func (cli *Client) PutPushRule(scope string, kind pushrules.PushRuleType, ruleID query["before"] = req.Before } urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "pushrules", scope, kind, ruleID}, query) - _, err := cli.MakeRequest("PUT", urlPath, req, nil) + _, err := cli.MakeRequest(ctx, "PUT", urlPath, req, nil) return err } // BatchSend sends a batch of historical events into a room. This is only available for appservices. // // Deprecated: MSC2716 has been abandoned, so this is now Beeper-specific. BeeperBatchSend should be used instead. -func (cli *Client) BatchSend(roomID id.RoomID, req *ReqBatchSend) (resp *RespBatchSend, err error) { +func (cli *Client) BatchSend(ctx context.Context, roomID id.RoomID, req *ReqBatchSend) (resp *RespBatchSend, err error) { path := ClientURLPath{"unstable", "org.matrix.msc2716", "rooms", roomID, "batch_send"} query := map[string]string{ "prev_event_id": req.PrevEventID.String(), @@ -2019,12 +2007,12 @@ func (cli *Client) BatchSend(roomID id.RoomID, req *ReqBatchSend) (resp *RespBat if len(req.BatchID) > 0 { query["batch_id"] = req.BatchID.String() } - _, err = cli.MakeRequest("POST", cli.BuildURLWithQuery(path, query), req, &resp) + _, err = cli.MakeRequest(ctx, "POST", cli.BuildURLWithQuery(path, query), req, &resp) return } -func (cli *Client) AppservicePing(id, txnID string) (resp *RespAppservicePing, err error) { - _, err = cli.MakeFullRequest(FullRequest{ +func (cli *Client) AppservicePing(ctx context.Context, id, txnID string) (resp *RespAppservicePing, err error) { + _, err = cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: cli.BuildClientURL("v1", "appservice", id, "ping"), RequestJSON: &ReqAppservicePing{TxnID: txnID}, @@ -2035,27 +2023,26 @@ func (cli *Client) AppservicePing(id, txnID string) (resp *RespAppservicePing, e return } -func (cli *Client) BeeperBatchSend(roomID id.RoomID, req *ReqBeeperBatchSend) (resp *RespBeeperBatchSend, err error) { +func (cli *Client) BeeperBatchSend(ctx context.Context, roomID id.RoomID, req *ReqBeeperBatchSend) (resp *RespBeeperBatchSend, err error) { u := cli.BuildClientURL("unstable", "com.beeper.backfill", "rooms", roomID, "batch_send") - _, err = cli.MakeRequest(http.MethodPost, u, req, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, u, req, &resp) return } -func (cli *Client) BeeperMergeRooms(req *ReqBeeperMergeRoom) (resp *RespBeeperMergeRoom, err error) { +func (cli *Client) BeeperMergeRooms(ctx context.Context, req *ReqBeeperMergeRoom) (resp *RespBeeperMergeRoom, err error) { urlPath := cli.BuildClientURL("unstable", "com.beeper.chatmerging", "merge") - _, err = cli.MakeRequest(http.MethodPost, urlPath, req, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) return } -func (cli *Client) BeeperSplitRoom(req *ReqBeeperSplitRoom) (resp *RespBeeperSplitRoom, err error) { +func (cli *Client) BeeperSplitRoom(ctx context.Context, req *ReqBeeperSplitRoom) (resp *RespBeeperSplitRoom, err error) { urlPath := cli.BuildClientURL("unstable", "com.beeper.chatmerging", "rooms", req.RoomID, "split") - _, err = cli.MakeRequest(http.MethodPost, urlPath, req, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) return } - -func (cli *Client) BeeperDeleteRoom(roomID id.RoomID) (err error) { +func (cli *Client) BeeperDeleteRoom(ctx context.Context, roomID id.RoomID) (err error) { urlPath := cli.BuildClientURL("unstable", "com.beeper.yeet", "rooms", roomID, "delete") - _, err = cli.MakeRequest(http.MethodPost, urlPath, nil, nil) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, nil, nil) return } diff --git a/crypto/cross_sign_key.go b/crypto/cross_sign_key.go index d38df8f3..4528ae02 100644 --- a/crypto/cross_sign_key.go +++ b/crypto/cross_sign_key.go @@ -8,6 +8,7 @@ package crypto import ( + "context" "fmt" "maunium.net/go/mautrix" @@ -89,7 +90,7 @@ func (mach *OlmMachine) GenerateCrossSigningKeys() (*CrossSigningKeysCache, erro } // PublishCrossSigningKeys signs and uploads the public keys of the given cross-signing keys to the server. -func (mach *OlmMachine) PublishCrossSigningKeys(keys *CrossSigningKeysCache, uiaCallback mautrix.UIACallback) error { +func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *CrossSigningKeysCache, uiaCallback mautrix.UIACallback) error { userID := mach.Client.UserID masterKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey.String()) masterKey := mautrix.CrossSigningKeys{ @@ -134,7 +135,7 @@ func (mach *OlmMachine) PublishCrossSigningKeys(keys *CrossSigningKeysCache, uia }, } - err = mach.Client.UploadCrossSigningKeys(&mautrix.UploadCrossSigningKeysReq{ + err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq{ Master: masterKey, SelfSigning: selfKey, UserSigning: userKey, diff --git a/crypto/cross_sign_pubkey.go b/crypto/cross_sign_pubkey.go index 3067753a..9f4f3583 100644 --- a/crypto/cross_sign_pubkey.go +++ b/crypto/cross_sign_pubkey.go @@ -7,6 +7,7 @@ package crypto import ( + "context" "fmt" "maunium.net/go/mautrix" @@ -19,7 +20,7 @@ type CrossSigningPublicKeysCache struct { UserSigningKey id.Ed25519 } -func (mach *OlmMachine) GetOwnCrossSigningPublicKeys() *CrossSigningPublicKeysCache { +func (mach *OlmMachine) GetOwnCrossSigningPublicKeys(ctx context.Context) *CrossSigningPublicKeysCache { if mach.crossSigningPubkeys != nil { return mach.crossSigningPubkeys } @@ -30,7 +31,7 @@ func (mach *OlmMachine) GetOwnCrossSigningPublicKeys() *CrossSigningPublicKeysCa if mach.crossSigningPubkeysFetched { return nil } - cspk, err := mach.GetCrossSigningPublicKeys(mach.Client.UserID) + cspk, err := mach.GetCrossSigningPublicKeys(ctx, mach.Client.UserID) if err != nil { mach.Log.Error().Err(err).Msg("Failed to get own cross-signing public keys") return nil @@ -40,7 +41,7 @@ func (mach *OlmMachine) GetOwnCrossSigningPublicKeys() *CrossSigningPublicKeysCa return mach.crossSigningPubkeys } -func (mach *OlmMachine) GetCrossSigningPublicKeys(userID id.UserID) (*CrossSigningPublicKeysCache, error) { +func (mach *OlmMachine) GetCrossSigningPublicKeys(ctx context.Context, userID id.UserID) (*CrossSigningPublicKeysCache, error) { dbKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID) if err != nil { return nil, fmt.Errorf("failed to get keys from database: %w", err) @@ -58,7 +59,7 @@ func (mach *OlmMachine) GetCrossSigningPublicKeys(userID id.UserID) (*CrossSigni } } - keys, err := mach.Client.QueryKeys(&mautrix.ReqQueryKeys{ + keys, err := mach.Client.QueryKeys(ctx, &mautrix.ReqQueryKeys{ DeviceKeys: mautrix.DeviceKeysRequest{ userID: mautrix.DeviceIDList{}, }, diff --git a/crypto/cross_sign_signing.go b/crypto/cross_sign_signing.go index 62c41b38..1a5a0233 100644 --- a/crypto/cross_sign_signing.go +++ b/crypto/cross_sign_signing.go @@ -8,6 +8,7 @@ package crypto import ( + "context" "errors" "fmt" @@ -59,7 +60,7 @@ func (mach *OlmMachine) fetchMasterKey(device *id.Device, content *event.Verific } // SignUser creates a cross-signing signature for a user, stores it and uploads it to the server. -func (mach *OlmMachine) SignUser(userID id.UserID, masterKey id.Ed25519) error { +func (mach *OlmMachine) SignUser(ctx context.Context, userID id.UserID, masterKey id.Ed25519) error { if userID == mach.Client.UserID { return ErrCantSignOwnMasterKey } else if mach.CrossSigningKeys == nil || mach.CrossSigningKeys.UserSigningKey == nil { @@ -74,7 +75,7 @@ func (mach *OlmMachine) SignUser(userID id.UserID, masterKey id.Ed25519) error { }, } - signature, err := mach.signAndUpload(masterKeyObj, userID, masterKey.String(), mach.CrossSigningKeys.UserSigningKey) + signature, err := mach.signAndUpload(ctx, masterKeyObj, userID, masterKey.String(), mach.CrossSigningKeys.UserSigningKey) if err != nil { return err } @@ -92,7 +93,7 @@ func (mach *OlmMachine) SignUser(userID id.UserID, masterKey id.Ed25519) error { } // SignOwnMasterKey uses the current account for signing the current user's master key and uploads the signature. -func (mach *OlmMachine) SignOwnMasterKey() error { +func (mach *OlmMachine) SignOwnMasterKey(ctx context.Context) error { if mach.CrossSigningKeys == nil { return ErrCrossSigningKeysNotCached } else if mach.account == nil { @@ -124,7 +125,7 @@ func (mach *OlmMachine) SignOwnMasterKey() error { Str("signature", signature). Msg("Signed own master key with own device key") - resp, err := mach.Client.UploadSignatures(&mautrix.ReqUploadSignatures{ + resp, err := mach.Client.UploadSignatures(ctx, &mautrix.ReqUploadSignatures{ userID: map[string]mautrix.ReqKeysSignatures{ masterKey.String(): masterKeyObj, }, @@ -144,14 +145,14 @@ func (mach *OlmMachine) SignOwnMasterKey() error { } // SignOwnDevice creates a cross-signing signature for a device belonging to the current user and uploads it to the server. -func (mach *OlmMachine) SignOwnDevice(device *id.Device) error { +func (mach *OlmMachine) SignOwnDevice(ctx context.Context, device *id.Device) error { if device.UserID != mach.Client.UserID { return ErrCantSignOtherDevice } else if mach.CrossSigningKeys == nil || mach.CrossSigningKeys.SelfSigningKey == nil { return ErrSelfSigningKeyNotCached } - deviceKeys, err := mach.getFullDeviceKeys(device) + deviceKeys, err := mach.getFullDeviceKeys(ctx, device) if err != nil { return err } @@ -166,7 +167,7 @@ func (mach *OlmMachine) SignOwnDevice(device *id.Device) error { deviceKeyObj.Keys[id.KeyID(keyID)] = key } - signature, err := mach.signAndUpload(deviceKeyObj, device.UserID, device.DeviceID.String(), mach.CrossSigningKeys.SelfSigningKey) + signature, err := mach.signAndUpload(ctx, deviceKeyObj, device.UserID, device.DeviceID.String(), mach.CrossSigningKeys.SelfSigningKey) if err != nil { return err } @@ -186,8 +187,8 @@ func (mach *OlmMachine) SignOwnDevice(device *id.Device) error { // getFullDeviceKeys gets the full device keys object for the given device. // This is used because we don't cache some of the details like list of algorithms and unsupported key types. -func (mach *OlmMachine) getFullDeviceKeys(device *id.Device) (*mautrix.DeviceKeys, error) { - devicesKeys, err := mach.Client.QueryKeys(&mautrix.ReqQueryKeys{ +func (mach *OlmMachine) getFullDeviceKeys(ctx context.Context, device *id.Device) (*mautrix.DeviceKeys, error) { + devicesKeys, err := mach.Client.QueryKeys(ctx, &mautrix.ReqQueryKeys{ DeviceKeys: mautrix.DeviceKeysRequest{ device.UserID: mautrix.DeviceIDList{device.DeviceID}, }, @@ -208,7 +209,7 @@ func (mach *OlmMachine) getFullDeviceKeys(device *id.Device) (*mautrix.DeviceKey } // signAndUpload signs the given key signatures object and uploads it to the server. -func (mach *OlmMachine) signAndUpload(req mautrix.ReqKeysSignatures, userID id.UserID, signedThing string, key *olm.PkSigning) (string, error) { +func (mach *OlmMachine) signAndUpload(ctx context.Context, req mautrix.ReqKeysSignatures, userID id.UserID, signedThing string, key *olm.PkSigning) (string, error) { signature, err := key.SignJSON(req) if err != nil { return "", fmt.Errorf("failed to sign JSON: %w", err) @@ -219,7 +220,7 @@ func (mach *OlmMachine) signAndUpload(req mautrix.ReqKeysSignatures, userID id.U }, } - resp, err := mach.Client.UploadSignatures(&mautrix.ReqUploadSignatures{ + resp, err := mach.Client.UploadSignatures(ctx, &mautrix.ReqUploadSignatures{ userID: map[string]mautrix.ReqKeysSignatures{ signedThing: req, }, diff --git a/crypto/cross_sign_ssss.go b/crypto/cross_sign_ssss.go index b8ca71cb..ef8a0ad3 100644 --- a/crypto/cross_sign_ssss.go +++ b/crypto/cross_sign_ssss.go @@ -7,6 +7,7 @@ package crypto import ( + "context" "fmt" "maunium.net/go/mautrix" @@ -16,16 +17,16 @@ import ( ) // FetchCrossSigningKeysFromSSSS fetches all the cross-signing keys from SSSS, decrypts them using the given key and stores them in the olm machine. -func (mach *OlmMachine) FetchCrossSigningKeysFromSSSS(key *ssss.Key) error { - masterKey, err := mach.retrieveDecryptXSigningKey(event.AccountDataCrossSigningMaster, key) +func (mach *OlmMachine) FetchCrossSigningKeysFromSSSS(ctx context.Context, key *ssss.Key) error { + masterKey, err := mach.retrieveDecryptXSigningKey(ctx, event.AccountDataCrossSigningMaster, key) if err != nil { return err } - selfSignKey, err := mach.retrieveDecryptXSigningKey(event.AccountDataCrossSigningSelf, key) + selfSignKey, err := mach.retrieveDecryptXSigningKey(ctx, event.AccountDataCrossSigningSelf, key) if err != nil { return err } - userSignKey, err := mach.retrieveDecryptXSigningKey(event.AccountDataCrossSigningUser, key) + userSignKey, err := mach.retrieveDecryptXSigningKey(ctx, event.AccountDataCrossSigningUser, key) if err != nil { return err } @@ -38,12 +39,12 @@ func (mach *OlmMachine) FetchCrossSigningKeysFromSSSS(key *ssss.Key) error { } // retrieveDecryptXSigningKey retrieves the requested cross-signing key from SSSS and decrypts it using the given SSSS key. -func (mach *OlmMachine) retrieveDecryptXSigningKey(keyName event.Type, key *ssss.Key) ([utils.AESCTRKeyLength]byte, error) { +func (mach *OlmMachine) retrieveDecryptXSigningKey(ctx context.Context, keyName event.Type, key *ssss.Key) ([utils.AESCTRKeyLength]byte, error) { var decryptedKey [utils.AESCTRKeyLength]byte var encData ssss.EncryptedAccountDataEventContent // retrieve and parse the account data for this key type from SSSS - err := mach.Client.GetAccountData(keyName.Type, &encData) + err := mach.Client.GetAccountData(ctx, keyName.Type, &encData) if err != nil { return decryptedKey, err } @@ -62,8 +63,8 @@ func (mach *OlmMachine) retrieveDecryptXSigningKey(keyName event.Type, key *ssss // is used. The base58-formatted recovery key is the first return parameter. // // The account password of the user is required for uploading keys to the server. -func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(userPassword, passphrase string) (string, error) { - key, err := mach.SSSS.GenerateAndUploadKey(passphrase) +func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(ctx context.Context, userPassword, passphrase string) (string, error) { + key, err := mach.SSSS.GenerateAndUploadKey(ctx, passphrase) if err != nil { return "", fmt.Errorf("failed to generate and upload SSSS key: %w", err) } @@ -77,12 +78,12 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(userPassword, passphra recoveryKey := key.RecoveryKey() // Store the private keys in SSSS - if err := mach.UploadCrossSigningKeysToSSSS(key, keysCache); err != nil { + if err := mach.UploadCrossSigningKeysToSSSS(ctx, key, keysCache); err != nil { return recoveryKey, fmt.Errorf("failed to upload cross-signing keys to SSSS: %w", err) } // Publish cross-signing keys - err = mach.PublishCrossSigningKeys(keysCache, func(uiResp *mautrix.RespUserInteractive) interface{} { + err = mach.PublishCrossSigningKeys(ctx, keysCache, func(uiResp *mautrix.RespUserInteractive) interface{} { return &mautrix.ReqUIAuthLogin{ BaseAuthData: mautrix.BaseAuthData{ Type: mautrix.AuthTypePassword, @@ -96,7 +97,7 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(userPassword, passphra return recoveryKey, fmt.Errorf("failed to publish cross-signing keys: %w", err) } - err = mach.SSSS.SetDefaultKeyID(key.ID) + err = mach.SSSS.SetDefaultKeyID(ctx, key.ID) if err != nil { return recoveryKey, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err) } @@ -105,14 +106,14 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(userPassword, passphra } // UploadCrossSigningKeysToSSSS stores the given cross-signing keys on the server encrypted with the given key. -func (mach *OlmMachine) UploadCrossSigningKeysToSSSS(key *ssss.Key, keys *CrossSigningKeysCache) error { - if err := mach.SSSS.SetEncryptedAccountData(event.AccountDataCrossSigningMaster, keys.MasterKey.Seed, key); err != nil { +func (mach *OlmMachine) UploadCrossSigningKeysToSSSS(ctx context.Context, key *ssss.Key, keys *CrossSigningKeysCache) error { + if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningMaster, keys.MasterKey.Seed, key); err != nil { return err } - if err := mach.SSSS.SetEncryptedAccountData(event.AccountDataCrossSigningSelf, keys.SelfSigningKey.Seed, key); err != nil { + if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningSelf, keys.SelfSigningKey.Seed, key); err != nil { return err } - if err := mach.SSSS.SetEncryptedAccountData(event.AccountDataCrossSigningUser, keys.UserSigningKey.Seed, key); err != nil { + if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningUser, keys.UserSigningKey.Seed, key); err != nil { return err } return nil diff --git a/crypto/cross_sign_validation.go b/crypto/cross_sign_validation.go index e8a7d79a..27afeb73 100644 --- a/crypto/cross_sign_validation.go +++ b/crypto/cross_sign_validation.go @@ -89,7 +89,7 @@ func (mach *OlmMachine) IsDeviceTrusted(device *id.Device) bool { // IsUserTrusted returns whether a user has been determined to be trusted by our user-signing key having signed their master key. // In the case the user ID is our own and we have successfully retrieved our cross-signing keys, we trust our own user. func (mach *OlmMachine) IsUserTrusted(ctx context.Context, userID id.UserID) (bool, error) { - csPubkeys := mach.GetOwnCrossSigningPublicKeys() + csPubkeys := mach.GetOwnCrossSigningPublicKeys(ctx) if csPubkeys == nil { return false, nil } diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go index 35dadac8..9d071ba9 100644 --- a/crypto/cryptohelper/cryptohelper.go +++ b/crypto/cryptohelper/cryptohelper.go @@ -124,6 +124,7 @@ func (helper *CryptoHelper) Init() error { } else { stateStore = helper.client.StateStore.(crypto.StateStore) } + ctx := context.Background() var cryptoStore crypto.Store if helper.unmanagedCryptoStore == nil { managedCryptoStore := crypto.NewSQLCryptoStore(helper.dbForManagedStores, dbutil.ZeroLogger(helper.log.With().Str("db_section", "crypto").Logger()), helper.DBAccountID, helper.client.DeviceID, helper.pickleKey) @@ -146,7 +147,7 @@ func (helper *CryptoHelper) Init() error { Str("username", helper.LoginAs.Identifier.User). Str("device_id", helper.LoginAs.DeviceID.String()). Msg("Logging in") - _, err = helper.client.Login(helper.LoginAs) + _, err = helper.client.Login(ctx, helper.LoginAs) if err != nil { return err } @@ -170,7 +171,7 @@ func (helper *CryptoHelper) Init() error { err := helper.mach.Load() if err != nil { return fmt.Errorf("failed to load olm account: %w", err) - } else if err = helper.verifyDeviceKeysOnServer(); err != nil { + } else if err = helper.verifyDeviceKeysOnServer(ctx); err != nil { return err } @@ -204,9 +205,9 @@ func (helper *CryptoHelper) Machine() *crypto.OlmMachine { return helper.mach } -func (helper *CryptoHelper) verifyDeviceKeysOnServer() error { +func (helper *CryptoHelper) verifyDeviceKeysOnServer(ctx context.Context) error { helper.log.Debug().Msg("Making sure our device has the expected keys on the server") - resp, err := helper.client.QueryKeys(&mautrix.ReqQueryKeys{ + resp, err := helper.client.QueryKeys(ctx, &mautrix.ReqQueryKeys{ DeviceKeys: map[id.UserID]mautrix.DeviceIDList{ helper.client.UserID: {helper.client.DeviceID}, }, @@ -278,7 +279,7 @@ func (helper *CryptoHelper) postDecrypt(src mautrix.EventSource, decrypted *even helper.client.Syncer.(mautrix.DispatchableSyncer).Dispatch(src|mautrix.EventSourceDecrypted, decrypted) } -func (helper *CryptoHelper) RequestSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { +func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { if helper == nil { return } @@ -294,7 +295,7 @@ func (helper *CryptoHelper) RequestSession(roomID id.RoomID, senderKey id.Sender Str("device_id", deviceID.String()). Str("room_id", roomID.String()). Logger() - err := helper.mach.SendRoomKeyRequest(roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{ + err := helper.mach.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{ userID: {deviceID}, helper.client.UserID: {"*"}, }) @@ -309,7 +310,7 @@ func (helper *CryptoHelper) waitLongerForSession(log zerolog.Logger, src mautrix content := evt.Content.AsEncrypted() log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...") - go helper.RequestSession(evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) + go helper.RequestSession(context.Background(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) if !helper.mach.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { log.Debug().Msg("Didn't get session, giving up") diff --git a/crypto/devicelist.go b/crypto/devicelist.go index 7f779259..8514275c 100644 --- a/crypto/devicelist.go +++ b/crypto/devicelist.go @@ -108,7 +108,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT req.DeviceKeys[userID] = mautrix.DeviceIDList{} } log.Debug().Strs("users", strishArray(users)).Msg("Querying keys for users") - resp, err := mach.Client.QueryKeys(req) + resp, err := mach.Client.QueryKeys(ctx, req) if err != nil { log.Error().Err(err).Msg("Failed to query keys") return diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index 80aef710..078ef518 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -280,11 +280,11 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, Int("user_count", len(toDeviceWithheld.Messages)). Msg("Sending to-device messages to report withheld key") // TODO remove the next 4 lines once clients support m.room_key.withheld - _, err = mach.Client.SendToDevice(event.ToDeviceOrgMatrixRoomKeyWithheld, toDeviceWithheld) + _, err = mach.Client.SendToDevice(ctx, event.ToDeviceOrgMatrixRoomKeyWithheld, toDeviceWithheld) if err != nil { log.Warn().Err(err).Msg("Failed to report withheld keys (legacy event type)") } - _, err = mach.Client.SendToDevice(event.ToDeviceRoomKeyWithheld, toDeviceWithheld) + _, err = mach.Client.SendToDevice(ctx, event.ToDeviceRoomKeyWithheld, toDeviceWithheld) if err != nil { log.Warn().Err(err).Msg("Failed to report withheld keys") } @@ -327,7 +327,7 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session Int("device_count", deviceCount). Int("user_count", len(toDevice.Messages)). Msg("Sending to-device messages to share group session") - _, err := mach.Client.SendToDevice(event.ToDeviceEncrypted, toDevice) + _, err := mach.Client.SendToDevice(ctx, event.ToDeviceEncrypted, toDevice) return err } diff --git a/crypto/encryptolm.go b/crypto/encryptolm.go index 8ae3ba1a..f21ecd02 100644 --- a/crypto/encryptolm.go +++ b/crypto/encryptolm.go @@ -83,7 +83,7 @@ func (mach *OlmMachine) createOutboundSessions(ctx context.Context, input map[id if len(request) == 0 { return nil } - resp, err := mach.Client.ClaimKeys(&mautrix.ReqClaimKeys{ + resp, err := mach.Client.ClaimKeys(ctx, &mautrix.ReqClaimKeys{ OneTimeKeys: request, Timeout: 10 * 1000, }) diff --git a/crypto/keysharing.go b/crypto/keysharing.go index 1cbc41bd..9b8eef7e 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -48,7 +48,7 @@ func (mach *OlmMachine) RequestRoomKey(ctx context.Context, toUser id.UserID, to keyResponseReceived := make(chan struct{}) mach.roomKeyRequestFilled.Store(sessionID, keyResponseReceived) - err := mach.SendRoomKeyRequest(roomID, senderKey, sessionID, requestID, map[id.UserID][]id.DeviceID{toUser: {toDevice}}) + err := mach.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, requestID, map[id.UserID][]id.DeviceID{toUser: {toDevice}}) if err != nil { return nil, err } @@ -85,7 +85,7 @@ func (mach *OlmMachine) RequestRoomKey(ctx context.Context, toUser id.UserID, to }, } - mach.Client.SendToDevice(event.ToDeviceRoomKeyRequest, toDeviceCancel) + mach.Client.SendToDevice(ctx, event.ToDeviceRoomKeyRequest, toDeviceCancel) }() return resChan, nil } @@ -99,7 +99,7 @@ func (mach *OlmMachine) RequestRoomKey(ctx context.Context, toUser id.UserID, to // to the specific key request, but currently it only supports a single target device and is therefore deprecated. // A future function may properly support multiple targets and automatically canceling the other requests when receiving // the first response. -func (mach *OlmMachine) SendRoomKeyRequest(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, requestID string, users map[id.UserID][]id.DeviceID) error { +func (mach *OlmMachine) SendRoomKeyRequest(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, requestID string, users map[id.UserID][]id.DeviceID) error { if len(requestID) == 0 { requestID = mach.Client.TxnID() } @@ -126,7 +126,7 @@ func (mach *OlmMachine) SendRoomKeyRequest(roomID id.RoomID, senderKey id.Sender toDeviceReq.Messages[user][device] = requestEvent } } - _, err := mach.Client.SendToDevice(event.ToDeviceRoomKeyRequest, toDeviceReq) + _, err := mach.Client.SendToDevice(ctx, event.ToDeviceRoomKeyRequest, toDeviceReq) return err } @@ -188,7 +188,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt return true } -func (mach *OlmMachine) rejectKeyRequest(rejection KeyShareRejection, device *id.Device, request event.RequestedKeyInfo) { +func (mach *OlmMachine) rejectKeyRequest(ctx context.Context, rejection KeyShareRejection, device *id.Device, request event.RequestedKeyInfo) { if rejection.Code == "" { // If the rejection code is empty, it means don't share keys, but also don't tell the requester. return @@ -201,7 +201,7 @@ func (mach *OlmMachine) rejectKeyRequest(rejection KeyShareRejection, device *id Code: rejection.Code, Reason: rejection.Reason, } - err := mach.sendToOneDevice(device.UserID, device.DeviceID, event.ToDeviceRoomKeyWithheld, &content) + err := mach.sendToOneDevice(ctx, device.UserID, device.DeviceID, event.ToDeviceRoomKeyWithheld, &content) if err != nil { mach.Log.Warn().Err(err). Str("code", string(rejection.Code)). @@ -209,7 +209,7 @@ func (mach *OlmMachine) rejectKeyRequest(rejection KeyShareRejection, device *id Str("device_id", device.DeviceID.String()). Msg("Failed to send key share rejection") } - err = mach.sendToOneDevice(device.UserID, device.DeviceID, event.ToDeviceOrgMatrixRoomKeyWithheld, &content) + err = mach.sendToOneDevice(ctx, device.UserID, device.DeviceID, event.ToDeviceOrgMatrixRoomKeyWithheld, &content) if err != nil { mach.Log.Warn().Err(err). Str("code", string(rejection.Code)). @@ -270,7 +270,7 @@ func (mach *OlmMachine) handleRoomKeyRequest(ctx context.Context, sender id.User rejection := mach.AllowKeyShare(ctx, device, content.Body) if rejection != nil { - mach.rejectKeyRequest(*rejection, device, content.Body) + mach.rejectKeyRequest(ctx, *rejection, device, content.Body) return } @@ -278,15 +278,15 @@ func (mach *OlmMachine) handleRoomKeyRequest(ctx context.Context, sender id.User if err != nil { if errors.Is(err, ErrGroupSessionWithheld) { log.Debug().Err(err).Msg("Requested group session not available") - mach.rejectKeyRequest(KeyShareRejectUnavailable, device, content.Body) + mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) } else { log.Error().Err(err).Msg("Failed to get group session to forward") - mach.rejectKeyRequest(KeyShareRejectInternalError, device, content.Body) + mach.rejectKeyRequest(ctx, KeyShareRejectInternalError, device, content.Body) } return } else if igs == nil { log.Error().Msg("Didn't find group session to forward") - mach.rejectKeyRequest(KeyShareRejectUnavailable, device, content.Body) + mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) return } if internalID := igs.ID(); internalID != content.Body.SessionID { @@ -299,7 +299,7 @@ func (mach *OlmMachine) handleRoomKeyRequest(ctx context.Context, sender id.User exportedKey, err := igs.Internal.Export(firstKnownIndex) if err != nil { log.Error().Err(err).Msg("Failed to export group session to forward") - mach.rejectKeyRequest(KeyShareRejectInternalError, device, content.Body) + mach.rejectKeyRequest(ctx, KeyShareRejectInternalError, device, content.Body) return } diff --git a/crypto/machine.go b/crypto/machine.go index 2c9b63c9..37a21da3 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -381,17 +381,17 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) { mach.handleBeeperRoomKeyAck(ctx, evt.Sender, content) // verification cases case *event.VerificationStartEventContent: - mach.handleVerificationStart(evt.Sender, content, content.TransactionID, 10*time.Minute, "") + mach.handleVerificationStart(ctx, evt.Sender, content, content.TransactionID, 10*time.Minute, "") case *event.VerificationAcceptEventContent: - mach.handleVerificationAccept(evt.Sender, content, content.TransactionID) + mach.handleVerificationAccept(ctx, evt.Sender, content, content.TransactionID) case *event.VerificationKeyEventContent: - mach.handleVerificationKey(evt.Sender, content, content.TransactionID) + mach.handleVerificationKey(ctx, evt.Sender, content, content.TransactionID) case *event.VerificationMacEventContent: - mach.handleVerificationMAC(evt.Sender, content, content.TransactionID) + mach.handleVerificationMAC(ctx, evt.Sender, content, content.TransactionID) case *event.VerificationCancelEventContent: mach.handleVerificationCancel(evt.Sender, content, content.TransactionID) case *event.VerificationRequestEventContent: - mach.handleVerificationRequest(evt.Sender, content, content.TransactionID, "") + mach.handleVerificationRequest(ctx, evt.Sender, content, content.TransactionID, "") case *event.RoomKeyWithheldEventContent: mach.handleRoomKeyWithheld(ctx, content) default: @@ -473,7 +473,7 @@ func (mach *OlmMachine) SendEncryptedToDevice(ctx context.Context, device *id.De Str("to_identity_key", device.IdentityKey.String()). Str("olm_session_id", olmSess.ID().String()). Msg("Sending encrypted to-device event") - _, err = mach.Client.SendToDevice(event.ToDeviceEncrypted, + _, err = mach.Client.SendToDevice(ctx, event.ToDeviceEncrypted, &mautrix.ReqSendToDevice{ Messages: map[id.UserID]map[id.DeviceID]*event.Content{ device.UserID: { @@ -624,7 +624,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro defer mach.otkUploadLock.Unlock() if mach.lastOTKUpload.Add(1*time.Minute).After(start) || currentOTKCount < 0 { log.Debug().Msg("Checking OTK count from server due to suspiciously close share keys requests or negative OTK count") - resp, err := mach.Client.UploadKeys(&mautrix.ReqUploadKeys{}) + resp, err := mach.Client.UploadKeys(ctx, &mautrix.ReqUploadKeys{}) if err != nil { return fmt.Errorf("failed to check current OTK counts: %w", err) } @@ -649,7 +649,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro OneTimeKeys: oneTimeKeys, } log.Debug().Int("count", len(oneTimeKeys)).Msg("Uploading one-time keys") - _, err := mach.Client.UploadKeys(req) + _, err := mach.Client.UploadKeys(ctx, req) if err != nil { return err } diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 15709bdb..c73a859a 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -67,17 +67,17 @@ func (store *SQLCryptoStore) Flush() error { } // PutNextBatch stores the next sync batch token for the current account. -func (store *SQLCryptoStore) PutNextBatch(nextBatch string) error { +func (store *SQLCryptoStore) PutNextBatch(ctx context.Context, nextBatch string) error { store.SyncToken = nextBatch - _, err := store.DB.Exec(`UPDATE crypto_account SET sync_token=$1 WHERE account_id=$2`, store.SyncToken, store.AccountID) + _, err := store.DB.ExecContext(ctx, `UPDATE crypto_account SET sync_token=$1 WHERE account_id=$2`, store.SyncToken, store.AccountID) return err } // GetNextBatch retrieves the next sync batch token for the current account. -func (store *SQLCryptoStore) GetNextBatch() (string, error) { +func (store *SQLCryptoStore) GetNextBatch(ctx context.Context) (string, error) { if store.SyncToken == "" { err := store.DB. - QueryRow("SELECT sync_token FROM crypto_account WHERE account_id=$1", store.AccountID). + QueryRowContext(ctx, "SELECT sync_token FROM crypto_account WHERE account_id=$1", store.AccountID). Scan(&store.SyncToken) if !errors.Is(err, sql.ErrNoRows) { return "", err @@ -88,18 +88,18 @@ func (store *SQLCryptoStore) GetNextBatch() (string, error) { var _ mautrix.SyncStore = (*SQLCryptoStore)(nil) -func (store *SQLCryptoStore) SaveFilterID(_ id.UserID, _ string) {} -func (store *SQLCryptoStore) LoadFilterID(_ id.UserID) string { return "" } +func (store *SQLCryptoStore) SaveFilterID(ctx context.Context, _ id.UserID, _ string) {} +func (store *SQLCryptoStore) LoadFilterID(ctx context.Context, _ id.UserID) string { return "" } -func (store *SQLCryptoStore) SaveNextBatch(_ id.UserID, nextBatchToken string) { - err := store.PutNextBatch(nextBatchToken) +func (store *SQLCryptoStore) SaveNextBatch(ctx context.Context, _ id.UserID, nextBatchToken string) { + err := store.PutNextBatch(ctx, nextBatchToken) if err != nil { // TODO handle error } } -func (store *SQLCryptoStore) LoadNextBatch(_ id.UserID) string { - nb, err := store.GetNextBatch() +func (store *SQLCryptoStore) LoadNextBatch(ctx context.Context, _ id.UserID) string { + nb, err := store.GetNextBatch(ctx) if err != nil { // TODO handle error } diff --git a/crypto/ssss/client.go b/crypto/ssss/client.go index b74deca1..2dac30e1 100644 --- a/crypto/ssss/client.go +++ b/crypto/ssss/client.go @@ -7,6 +7,7 @@ package ssss import ( + "context" "fmt" "maunium.net/go/mautrix" @@ -29,9 +30,9 @@ type DefaultSecretStorageKeyContent struct { } // GetDefaultKeyID retrieves the default key ID for this account from SSSS. -func (mach *Machine) GetDefaultKeyID() (string, error) { +func (mach *Machine) GetDefaultKeyID(ctx context.Context) (string, error) { var data DefaultSecretStorageKeyContent - err := mach.Client.GetAccountData(event.AccountDataSecretStorageDefaultKey.Type, &data) + err := mach.Client.GetAccountData(ctx, event.AccountDataSecretStorageDefaultKey.Type, &data) if err != nil { if httpErr, ok := err.(mautrix.HTTPError); ok && httpErr.RespError != nil && httpErr.RespError.ErrCode == "M_NOT_FOUND" { return "", ErrNoDefaultKeyAccountDataEvent @@ -45,36 +46,36 @@ func (mach *Machine) GetDefaultKeyID() (string, error) { } // SetDefaultKeyID sets the default key ID for this account on the server. -func (mach *Machine) SetDefaultKeyID(keyID string) error { - return mach.Client.SetAccountData(event.AccountDataSecretStorageDefaultKey.Type, &DefaultSecretStorageKeyContent{keyID}) +func (mach *Machine) SetDefaultKeyID(ctx context.Context, keyID string) error { + return mach.Client.SetAccountData(ctx, event.AccountDataSecretStorageDefaultKey.Type, &DefaultSecretStorageKeyContent{keyID}) } // GetKeyData gets the details about the given key ID. -func (mach *Machine) GetKeyData(keyID string) (keyData *KeyMetadata, err error) { +func (mach *Machine) GetKeyData(ctx context.Context, keyID string) (keyData *KeyMetadata, err error) { keyData = &KeyMetadata{id: keyID} - err = mach.Client.GetAccountData(fmt.Sprintf("%s.%s", event.AccountDataSecretStorageKey.Type, keyID), keyData) + err = mach.Client.GetAccountData(ctx, fmt.Sprintf("%s.%s", event.AccountDataSecretStorageKey.Type, keyID), keyData) return } // SetKeyData stores SSSS key metadata on the server. -func (mach *Machine) SetKeyData(keyID string, keyData *KeyMetadata) error { - return mach.Client.SetAccountData(fmt.Sprintf("%s.%s", event.AccountDataSecretStorageKey.Type, keyID), keyData) +func (mach *Machine) SetKeyData(ctx context.Context, keyID string, keyData *KeyMetadata) error { + return mach.Client.SetAccountData(ctx, fmt.Sprintf("%s.%s", event.AccountDataSecretStorageKey.Type, keyID), keyData) } // GetDefaultKeyData gets the details about the default key ID (see GetDefaultKeyID). -func (mach *Machine) GetDefaultKeyData() (keyID string, keyData *KeyMetadata, err error) { - keyID, err = mach.GetDefaultKeyID() +func (mach *Machine) GetDefaultKeyData(ctx context.Context) (keyID string, keyData *KeyMetadata, err error) { + keyID, err = mach.GetDefaultKeyID(ctx) if err != nil { return } - keyData, err = mach.GetKeyData(keyID) + keyData, err = mach.GetKeyData(ctx, keyID) return } // GetDecryptedAccountData gets the account data event with the given event type and decrypts it using the given key. -func (mach *Machine) GetDecryptedAccountData(eventType event.Type, key *Key) ([]byte, error) { +func (mach *Machine) GetDecryptedAccountData(ctx context.Context, eventType event.Type, key *Key) ([]byte, error) { var encData EncryptedAccountDataEventContent - err := mach.Client.GetAccountData(eventType.Type, &encData) + err := mach.Client.GetAccountData(ctx, eventType.Type, &encData) if err != nil { return nil, err } @@ -82,7 +83,7 @@ func (mach *Machine) GetDecryptedAccountData(eventType event.Type, key *Key) ([] } // SetEncryptedAccountData encrypts the given data with the given keys and stores it on the server. -func (mach *Machine) SetEncryptedAccountData(eventType event.Type, data []byte, keys ...*Key) error { +func (mach *Machine) SetEncryptedAccountData(ctx context.Context, eventType event.Type, data []byte, keys ...*Key) error { if len(keys) == 0 { return ErrNoKeyGiven } @@ -90,17 +91,17 @@ func (mach *Machine) SetEncryptedAccountData(eventType event.Type, data []byte, for _, key := range keys { encrypted[key.ID] = key.Encrypt(eventType.Type, data) } - return mach.Client.SetAccountData(eventType.Type, &EncryptedAccountDataEventContent{Encrypted: encrypted}) + return mach.Client.SetAccountData(ctx, eventType.Type, &EncryptedAccountDataEventContent{Encrypted: encrypted}) } // GenerateAndUploadKey generates a new SSSS key and stores the metadata on the server. -func (mach *Machine) GenerateAndUploadKey(passphrase string) (key *Key, err error) { +func (mach *Machine) GenerateAndUploadKey(ctx context.Context, passphrase string) (key *Key, err error) { key, err = NewKey(passphrase) if err != nil { return nil, fmt.Errorf("failed to generate new key: %w", err) } - err = mach.SetKeyData(key.ID, key.Metadata) + err = mach.SetKeyData(ctx, key.ID, key.Metadata) if err != nil { err = fmt.Errorf("failed to upload key: %w", err) } diff --git a/crypto/store_test.go b/crypto/store_test.go index 9062d70d..ebeef393 100644 --- a/crypto/store_test.go +++ b/crypto/store_test.go @@ -54,8 +54,8 @@ func getCryptoStores(t *testing.T) map[string]Store { func TestPutNextBatch(t *testing.T) { stores := getCryptoStores(t) store := stores["sql"].(*SQLCryptoStore) - store.PutNextBatch("batch1") - if batch, _ := store.GetNextBatch(); batch != "batch1" { + store.PutNextBatch(context.Background(), "batch1") + if batch, _ := store.GetNextBatch(context.Background()); batch != "batch1" { t.Errorf("Expected batch1, got %v", batch) } } diff --git a/crypto/verification.go b/crypto/verification.go index 4925fed6..be246874 100644 --- a/crypto/verification.go +++ b/crypto/verification.go @@ -54,8 +54,8 @@ const ( ) // sendToOneDevice sends a to-device event to a single device. -func (mach *OlmMachine) sendToOneDevice(userID id.UserID, deviceID id.DeviceID, eventType event.Type, content interface{}) error { - _, err := mach.Client.SendToDevice(eventType, &mautrix.ReqSendToDevice{ +func (mach *OlmMachine) sendToOneDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID, eventType event.Type, content interface{}) error { + _, err := mach.Client.SendToDevice(ctx, eventType, &mautrix.ReqSendToDevice{ Messages: map[id.UserID]map[id.DeviceID]*event.Content{ userID: { deviceID: { @@ -118,19 +118,19 @@ type verificationState struct { } // getTransactionState retrieves the given transaction's state, or cancels the transaction if it cannot be found or there is a mismatch. -func (mach *OlmMachine) getTransactionState(transactionID string, userID id.UserID) (*verificationState, error) { +func (mach *OlmMachine) getTransactionState(ctx context.Context, transactionID string, userID id.UserID) (*verificationState, error) { verStateInterface, ok := mach.keyVerificationTransactionState.Load(userID.String() + ":" + transactionID) if !ok { - _ = mach.SendSASVerificationCancel(userID, id.DeviceID("*"), transactionID, "Unknown transaction: "+transactionID, event.VerificationCancelUnknownTransaction) + _ = mach.SendSASVerificationCancel(ctx, userID, id.DeviceID("*"), transactionID, "Unknown transaction: "+transactionID, event.VerificationCancelUnknownTransaction) return nil, ErrUnknownTransaction } verState := verStateInterface.(*verificationState) if verState.otherDevice.UserID != userID { reason := fmt.Sprintf("Unknown user for transaction %v: %v", transactionID, userID) if verState.inRoomID == "" { - _ = mach.SendSASVerificationCancel(userID, id.DeviceID("*"), transactionID, reason, event.VerificationCancelUserMismatch) + _ = mach.SendSASVerificationCancel(ctx, userID, id.DeviceID("*"), transactionID, reason, event.VerificationCancelUserMismatch) } else { - _ = mach.SendInRoomSASVerificationCancel(verState.inRoomID, userID, transactionID, reason, event.VerificationCancelUserMismatch) + _ = mach.SendInRoomSASVerificationCancel(ctx, verState.inRoomID, userID, transactionID, reason, event.VerificationCancelUserMismatch) } mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID) return nil, fmt.Errorf("%w %s: %s", ErrUnknownUserForTransaction, transactionID, userID) @@ -140,9 +140,9 @@ func (mach *OlmMachine) getTransactionState(transactionID string, userID id.User // handleVerificationStart handles an incoming m.key.verification.start message. // It initializes the state for this SAS verification process and stores it. -func (mach *OlmMachine) handleVerificationStart(userID id.UserID, content *event.VerificationStartEventContent, transactionID string, timeout time.Duration, inRoomID id.RoomID) { +func (mach *OlmMachine) handleVerificationStart(ctx context.Context, userID id.UserID, content *event.VerificationStartEventContent, transactionID string, timeout time.Duration, inRoomID id.RoomID) { mach.Log.Debug().Msgf("Received verification start from %v", content.FromDevice) - otherDevice, err := mach.GetOrFetchDevice(context.TODO(), userID, content.FromDevice) + otherDevice, err := mach.GetOrFetchDevice(ctx, userID, content.FromDevice) if err != nil { mach.Log.Error().Msgf("Could not find device %v of user %v", content.FromDevice, userID) return @@ -150,9 +150,9 @@ func (mach *OlmMachine) handleVerificationStart(userID id.UserID, content *event warnAndCancel := func(logReason, cancelReason string) { mach.Log.Warn().Msgf("Canceling verification transaction %v as it %s", transactionID, logReason) if inRoomID == "" { - _ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, cancelReason, event.VerificationCancelUnknownMethod) + _ = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, cancelReason, event.VerificationCancelUnknownMethod) } else { - _ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, cancelReason, event.VerificationCancelUnknownMethod) + _ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, cancelReason, event.VerificationCancelUnknownMethod) } } switch { @@ -168,21 +168,21 @@ func (mach *OlmMachine) handleVerificationStart(userID id.UserID, content *event case !content.SupportsSASMethod(event.SASDecimal): warnAndCancel("does not support decimal SAS", "Decimal SAS method must be supported") default: - mach.actuallyStartVerification(userID, content, otherDevice, transactionID, timeout, inRoomID) + mach.actuallyStartVerification(ctx, userID, content, otherDevice, transactionID, timeout, inRoomID) } } -func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *event.VerificationStartEventContent, otherDevice *id.Device, transactionID string, timeout time.Duration, inRoomID id.RoomID) { +func (mach *OlmMachine) actuallyStartVerification(ctx context.Context, userID id.UserID, content *event.VerificationStartEventContent, otherDevice *id.Device, transactionID string, timeout time.Duration, inRoomID id.RoomID) { if inRoomID != "" && transactionID != "" { - verState, err := mach.getTransactionState(transactionID, userID) + verState, err := mach.getTransactionState(ctx, transactionID, userID) if err != nil { mach.Log.Error().Msgf("Failed to get transaction state for in-room verification %s start: %v", transactionID, err) - _ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Internal state error in gomuks :(", "net.maunium.internal_error") + _ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "Internal state error in gomuks :(", "net.maunium.internal_error") return } - mach.timeoutAfter(verState, transactionID, timeout) + mach.timeoutAfter(ctx, verState, transactionID, timeout) sasMethods := commonSASMethods(verState.hooks, content.ShortAuthenticationString) - err = mach.SendInRoomSASVerificationAccept(inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods) + err = mach.SendInRoomSASVerificationAccept(ctx, inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods) if err != nil { mach.Log.Error().Msgf("Error accepting in-room SAS verification: %v", err) } @@ -196,9 +196,9 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve if len(sasMethods) == 0 { mach.Log.Error().Msgf("No common SAS methods: %v", content.ShortAuthenticationString) if inRoomID == "" { - _ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "No common SAS methods", event.VerificationCancelUnknownMethod) + _ = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, "No common SAS methods", event.VerificationCancelUnknownMethod) } else { - _ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "No common SAS methods", event.VerificationCancelUnknownMethod) + _ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "No common SAS methods", event.VerificationCancelUnknownMethod) } return } @@ -221,20 +221,20 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve // transaction already exists mach.Log.Error().Msgf("Transaction %v already exists, canceling", transactionID) if inRoomID == "" { - _ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Transaction already exists", event.VerificationCancelUnexpectedMessage) + _ = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, "Transaction already exists", event.VerificationCancelUnexpectedMessage) } else { - _ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Transaction already exists", event.VerificationCancelUnexpectedMessage) + _ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "Transaction already exists", event.VerificationCancelUnexpectedMessage) } return } - mach.timeoutAfter(verState, transactionID, timeout) + mach.timeoutAfter(ctx, verState, transactionID, timeout) var err error if inRoomID == "" { - err = mach.SendSASVerificationAccept(userID, content, verState.sas.GetPubkey(), sasMethods) + err = mach.SendSASVerificationAccept(ctx, userID, content, verState.sas.GetPubkey(), sasMethods) } else { - err = mach.SendInRoomSASVerificationAccept(inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods) + err = mach.SendInRoomSASVerificationAccept(ctx, inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods) } if err != nil { mach.Log.Error().Msgf("Error accepting SAS verification: %v", err) @@ -243,9 +243,9 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve mach.Log.Debug().Msgf("Not accepting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID) var err error if inRoomID == "" { - err = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Not accepted by user", event.VerificationCancelByUser) + err = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, "Not accepted by user", event.VerificationCancelByUser) } else { - err = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser) + err = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser) } if err != nil { mach.Log.Error().Msgf("Error canceling SAS verification: %v", err) @@ -255,8 +255,8 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve } } -func (mach *OlmMachine) timeoutAfter(verState *verificationState, transactionID string, timeout time.Duration) { - timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), timeout) +func (mach *OlmMachine) timeoutAfter(ctx context.Context, verState *verificationState, transactionID string, timeout time.Duration) { + timeoutCtx, timeoutCancel := context.WithTimeout(ctx, timeout) verState.extendTimeout = timeoutCancel go func() { mapKey := verState.otherDevice.UserID.String() + ":" + transactionID @@ -272,7 +272,7 @@ func (mach *OlmMachine) timeoutAfter(verState *verificationState, transactionID if timeoutCtx.Err() == context.DeadlineExceeded { // if deadline exceeded cancel due to timeout mach.keyVerificationTransactionState.Delete(mapKey) - _ = mach.callbackAndCancelSASVerification(verState, transactionID, "Timed out", event.VerificationCancelByTimeout) + _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Timed out", event.VerificationCancelByTimeout) mach.Log.Warn().Msgf("Verification transaction %v is canceled due to timing out", transactionID) verState.lock.Unlock() return @@ -288,9 +288,9 @@ func (mach *OlmMachine) timeoutAfter(verState *verificationState, transactionID // handleVerificationAccept handles an incoming m.key.verification.accept message. // It continues the SAS verification process by sending the SAS key message to the other device. -func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *event.VerificationAcceptEventContent, transactionID string) { +func (mach *OlmMachine) handleVerificationAccept(ctx context.Context, userID id.UserID, content *event.VerificationAcceptEventContent, transactionID string) { mach.Log.Debug().Msgf("Received verification accept for transaction %v", transactionID) - verState, err := mach.getTransactionState(transactionID, userID) + verState, err := mach.getTransactionState(ctx, transactionID, userID) if err != nil { mach.Log.Error().Msgf("Error getting transaction state: %v", err) return @@ -303,7 +303,7 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even // unexpected accept at this point mach.Log.Warn().Msgf("Unexpected verification accept message for transaction %v", transactionID) mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID) - _ = mach.callbackAndCancelSASVerification(verState, transactionID, "Unexpected accept message", event.VerificationCancelUnexpectedMessage) + _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Unexpected accept message", event.VerificationCancelUnexpectedMessage) return } @@ -315,7 +315,7 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even mach.Log.Warn().Msgf("Canceling verification transaction %v due to unknown parameter", transactionID) mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID) - _ = mach.callbackAndCancelSASVerification(verState, transactionID, "Verification uses unknown method", event.VerificationCancelUnknownMethod) + _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Verification uses unknown method", event.VerificationCancelUnknownMethod) return } @@ -325,9 +325,9 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even verState.verificationStarted = true if verState.inRoomID == "" { - err = mach.SendSASVerificationKey(userID, verState.otherDevice.DeviceID, transactionID, string(key)) + err = mach.SendSASVerificationKey(ctx, userID, verState.otherDevice.DeviceID, transactionID, string(key)) } else { - err = mach.SendInRoomSASVerificationKey(verState.inRoomID, userID, transactionID, string(key)) + err = mach.SendInRoomSASVerificationKey(ctx, verState.inRoomID, userID, transactionID, string(key)) } if err != nil { mach.Log.Error().Msgf("Error sending SAS key to other device: %v", err) @@ -337,9 +337,9 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even // handleVerificationKey handles an incoming m.key.verification.key message. // It stores the other device's public key in order to acquire the SAS shared secret. -func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.VerificationKeyEventContent, transactionID string) { +func (mach *OlmMachine) handleVerificationKey(ctx context.Context, userID id.UserID, content *event.VerificationKeyEventContent, transactionID string) { mach.Log.Debug().Msgf("Got verification key for transaction %v: %v", transactionID, content.Key) - verState, err := mach.getTransactionState(transactionID, userID) + verState, err := mach.getTransactionState(ctx, transactionID, userID) if err != nil { mach.Log.Error().Msgf("Error getting transaction state: %v", err) return @@ -354,7 +354,7 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V // unexpected key at this point mach.Log.Warn().Msgf("Unexpected verification key message for transaction %v", transactionID) mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID) - _ = mach.callbackAndCancelSASVerification(verState, transactionID, "Unexpected key message", event.VerificationCancelUnexpectedMessage) + _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Unexpected key message", event.VerificationCancelUnexpectedMessage) return } @@ -372,7 +372,7 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V if expectedCommitment != verState.commitment { mach.Log.Warn().Msgf("Canceling verification transaction %v due to commitment mismatch", transactionID) mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID) - _ = mach.callbackAndCancelSASVerification(verState, transactionID, "Commitment mismatch", event.VerificationCancelCommitmentMismatch) + _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Commitment mismatch", event.VerificationCancelCommitmentMismatch) return } } else { @@ -380,9 +380,9 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V key := verState.sas.GetPubkey() if verState.inRoomID == "" { - err = mach.SendSASVerificationKey(userID, device.DeviceID, transactionID, string(key)) + err = mach.SendSASVerificationKey(ctx, userID, device.DeviceID, transactionID, string(key)) } else { - err = mach.SendInRoomSASVerificationKey(verState.inRoomID, userID, transactionID, string(key)) + err = mach.SendInRoomSASVerificationKey(ctx, verState.inRoomID, userID, transactionID, string(key)) } if err != nil { mach.Log.Error().Msgf("Error sending SAS key to other device: %v", err) @@ -419,13 +419,13 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V mach.Log.Debug().Msgf("Generated SAS (%v): %v", sasMethod.Type(), sas) go func() { result := verState.hooks.VerifySASMatch(device, sas) - mach.sasCompared(result, transactionID, verState) + mach.sasCompared(ctx, result, transactionID, verState) }() } // sasCompared is called asynchronously. It waits for the SAS to be compared for the verification to proceed. // If the SAS match, then our MAC is sent out. Otherwise the transaction is canceled. -func (mach *OlmMachine) sasCompared(didMatch bool, transactionID string, verState *verificationState) { +func (mach *OlmMachine) sasCompared(ctx context.Context, didMatch bool, transactionID string, verState *verificationState) { verState.lock.Lock() defer verState.lock.Unlock() verState.extendTimeout() @@ -433,9 +433,9 @@ func (mach *OlmMachine) sasCompared(didMatch bool, transactionID string, verStat verState.sasMatched <- true var err error if verState.inRoomID == "" { - err = mach.SendSASVerificationMAC(verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, verState.sas) + err = mach.SendSASVerificationMAC(ctx, verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, verState.sas) } else { - err = mach.SendInRoomSASVerificationMAC(verState.inRoomID, verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, verState.sas) + err = mach.SendInRoomSASVerificationMAC(ctx, verState.inRoomID, verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, verState.sas) } if err != nil { mach.Log.Error().Msgf("Error sending verification MAC to other device: %v", err) @@ -447,9 +447,9 @@ func (mach *OlmMachine) sasCompared(didMatch bool, transactionID string, verStat // handleVerificationMAC handles an incoming m.key.verification.mac message. // It verifies the other device's MAC and if the MAC is valid it marks the device as trusted. -func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.VerificationMacEventContent, transactionID string) { +func (mach *OlmMachine) handleVerificationMAC(ctx context.Context, userID id.UserID, content *event.VerificationMacEventContent, transactionID string) { mach.Log.Debug().Msgf("Got MAC for verification %v: %v, MAC for keys: %v", transactionID, content.Mac, content.Keys) - verState, err := mach.getTransactionState(transactionID, userID) + verState, err := mach.getTransactionState(ctx, transactionID, userID) if err != nil { mach.Log.Error().Msgf("Error getting transaction state: %v", err) return @@ -466,7 +466,7 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V if !verState.verificationStarted || !verState.keyReceived { // unexpected MAC at this point mach.Log.Warn().Msgf("Unexpected MAC message for transaction %v", transactionID) - _ = mach.callbackAndCancelSASVerification(verState, transactionID, "Unexpected MAC message", event.VerificationCancelUnexpectedMessage) + _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Unexpected MAC message", event.VerificationCancelUnexpectedMessage) return } @@ -478,7 +478,7 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V if !matched { mach.Log.Warn().Msgf("SAS do not match! Canceling transaction %v", transactionID) - _ = mach.callbackAndCancelSASVerification(verState, transactionID, "SAS do not match", event.VerificationCancelSASMismatch) + _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "SAS do not match", event.VerificationCancelSASMismatch) return } @@ -494,14 +494,14 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V mach.Log.Debug().Msgf("Expected %s keys MAC, got %s", expectedKeysMAC, content.Keys) if content.Keys != expectedKeysMAC { mach.Log.Warn().Msgf("Canceling verification transaction %v due to mismatched keys MAC", transactionID) - _ = mach.callbackAndCancelSASVerification(verState, transactionID, "Mismatched keys MACs", event.VerificationCancelKeyMismatch) + _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Mismatched keys MACs", event.VerificationCancelKeyMismatch) return } mach.Log.Debug().Msgf("Expected %s PK MAC, got %s", expectedPKMAC, content.Mac[keyID]) if content.Mac[keyID] != expectedPKMAC { mach.Log.Warn().Msgf("Canceling verification transaction %v due to mismatched PK MAC", transactionID) - _ = mach.callbackAndCancelSASVerification(verState, transactionID, "Mismatched PK MACs", event.VerificationCancelKeyMismatch) + _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Mismatched PK MACs", event.VerificationCancelKeyMismatch) return } @@ -514,7 +514,7 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V if mach.CrossSigningKeys != nil { if device.UserID == mach.Client.UserID { - err := mach.SignOwnDevice(device) + err := mach.SignOwnDevice(ctx, device) if err != nil { mach.Log.Error().Msgf("Failed to cross-sign own device %s: %v", device.DeviceID, err) } else { @@ -525,7 +525,7 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V if err != nil { mach.Log.Warn().Msgf("Failed to fetch %s's master key: %v", device.UserID, err) } else { - if err := mach.SignUser(device.UserID, masterKey); err != nil { + if err := mach.SignUser(ctx, device.UserID, masterKey); err != nil { mach.Log.Error().Msgf("Failed to cross-sign master key of %s: %v", device.UserID, err) } else { mach.Log.Debug().Msgf("Cross-signed master key of %v after SAS verification", device.UserID) @@ -559,9 +559,9 @@ func (mach *OlmMachine) handleVerificationCancel(userID id.UserID, content *even } // handleVerificationRequest handles an incoming m.key.verification.request message. -func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *event.VerificationRequestEventContent, transactionID string, inRoomID id.RoomID) { +func (mach *OlmMachine) handleVerificationRequest(ctx context.Context, userID id.UserID, content *event.VerificationRequestEventContent, transactionID string, inRoomID id.RoomID) { mach.Log.Debug().Msgf("Received verification request from %v", content.FromDevice) - otherDevice, err := mach.GetOrFetchDevice(context.TODO(), userID, content.FromDevice) + otherDevice, err := mach.GetOrFetchDevice(ctx, userID, content.FromDevice) if err != nil { mach.Log.Error().Msgf("Could not find device %v of user %v", content.FromDevice, userID) return @@ -569,9 +569,9 @@ func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *eve if !content.SupportsVerificationMethod(event.VerificationMethodSAS) { mach.Log.Warn().Msgf("Canceling verification transaction %v as SAS is not supported", transactionID) if inRoomID == "" { - _ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Only SAS method is supported", event.VerificationCancelUnknownMethod) + _ = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, "Only SAS method is supported", event.VerificationCancelUnknownMethod) } else { - _ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Only SAS method is supported", event.VerificationCancelUnknownMethod) + _ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "Only SAS method is supported", event.VerificationCancelUnknownMethod) } return } @@ -579,14 +579,14 @@ func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *eve if resp == AcceptRequest { mach.Log.Debug().Msgf("Accepting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID) if inRoomID == "" { - _, err = mach.NewSASVerificationWith(otherDevice, hooks, transactionID, mach.DefaultSASTimeout) + _, err = mach.NewSASVerificationWith(ctx, otherDevice, hooks, transactionID, mach.DefaultSASTimeout) } else { - if err := mach.SendInRoomSASVerificationReady(inRoomID, transactionID); err != nil { + if err := mach.SendInRoomSASVerificationReady(ctx, inRoomID, transactionID); err != nil { mach.Log.Error().Msgf("Error sending in-room SAS verification ready: %v", err) } if mach.Client.UserID < otherDevice.UserID { // up to us to send the start message - _, err = mach.newInRoomSASVerificationWithInner(inRoomID, otherDevice, hooks, transactionID, mach.DefaultSASTimeout) + _, err = mach.newInRoomSASVerificationWithInner(ctx, inRoomID, otherDevice, hooks, transactionID, mach.DefaultSASTimeout) } } if err != nil { @@ -595,9 +595,9 @@ func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *eve } else if resp == RejectRequest { mach.Log.Debug().Msgf("Rejecting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID) if inRoomID == "" { - _ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Not accepted by user", event.VerificationCancelByUser) + _ = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, "Not accepted by user", event.VerificationCancelByUser) } else { - _ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser) + _ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser) } } else { mach.Log.Debug().Msgf("Ignoring SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID) @@ -606,14 +606,14 @@ func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *eve // NewSimpleSASVerificationWith starts the SAS verification process with another device with a default timeout, // a generated transaction ID and support for both emoji and decimal SAS methods. -func (mach *OlmMachine) NewSimpleSASVerificationWith(device *id.Device, hooks VerificationHooks) (string, error) { - return mach.NewSASVerificationWith(device, hooks, "", mach.DefaultSASTimeout) +func (mach *OlmMachine) NewSimpleSASVerificationWith(ctx context.Context, device *id.Device, hooks VerificationHooks) (string, error) { + return mach.NewSASVerificationWith(ctx, device, hooks, "", mach.DefaultSASTimeout) } // NewSASVerificationWith starts the SAS verification process with another device. // If the other device accepts the verification transaction, the methods in `hooks` will be used to verify the SAS match and to complete the transaction.. // If the transaction ID is empty, a new one is generated. -func (mach *OlmMachine) NewSASVerificationWith(device *id.Device, hooks VerificationHooks, transactionID string, timeout time.Duration) (string, error) { +func (mach *OlmMachine) NewSASVerificationWith(ctx context.Context, device *id.Device, hooks VerificationHooks, transactionID string, timeout time.Duration) (string, error) { if transactionID == "" { transactionID = strconv.Itoa(rand.Int()) } @@ -631,7 +631,7 @@ func (mach *OlmMachine) NewSASVerificationWith(device *id.Device, hooks Verifica verState.lock.Lock() defer verState.lock.Unlock() - startEvent, err := mach.SendSASVerificationStart(device.UserID, device.DeviceID, transactionID, hooks.VerificationMethods()) + startEvent, err := mach.SendSASVerificationStart(ctx, device.UserID, device.DeviceID, transactionID, hooks.VerificationMethods()) if err != nil { return "", err } @@ -651,13 +651,13 @@ func (mach *OlmMachine) NewSASVerificationWith(device *id.Device, hooks Verifica return "", ErrTransactionAlreadyExists } - mach.timeoutAfter(verState, transactionID, timeout) + mach.timeoutAfter(ctx, verState, transactionID, timeout) return transactionID, nil } // CancelSASVerification is used by the user to cancel a SAS verification process with the given reason. -func (mach *OlmMachine) CancelSASVerification(userID id.UserID, transactionID, reason string) error { +func (mach *OlmMachine) CancelSASVerification(ctx context.Context, userID id.UserID, transactionID, reason string) error { mapKey := userID.String() + ":" + transactionID verStateInterface, ok := mach.keyVerificationTransactionState.Load(mapKey) if !ok { @@ -668,21 +668,21 @@ func (mach *OlmMachine) CancelSASVerification(userID id.UserID, transactionID, r defer verState.lock.Unlock() mach.Log.Trace().Msgf("User canceled verification transaction %v with reason: %v", transactionID, reason) mach.keyVerificationTransactionState.Delete(mapKey) - return mach.callbackAndCancelSASVerification(verState, transactionID, reason, event.VerificationCancelByUser) + return mach.callbackAndCancelSASVerification(ctx, verState, transactionID, reason, event.VerificationCancelByUser) } // SendSASVerificationCancel is used to manually send a SAS cancel message process with the given reason and cancellation code. -func (mach *OlmMachine) SendSASVerificationCancel(userID id.UserID, deviceID id.DeviceID, transactionID string, reason string, code event.VerificationCancelCode) error { +func (mach *OlmMachine) SendSASVerificationCancel(ctx context.Context, userID id.UserID, deviceID id.DeviceID, transactionID string, reason string, code event.VerificationCancelCode) error { content := &event.VerificationCancelEventContent{ TransactionID: transactionID, Reason: reason, Code: code, } - return mach.sendToOneDevice(userID, deviceID, event.ToDeviceVerificationCancel, content) + return mach.sendToOneDevice(ctx, userID, deviceID, event.ToDeviceVerificationCancel, content) } // SendSASVerificationStart is used to manually send the SAS verification start message to another device. -func (mach *OlmMachine) SendSASVerificationStart(toUserID id.UserID, toDeviceID id.DeviceID, transactionID string, methods []VerificationMethod) (*event.VerificationStartEventContent, error) { +func (mach *OlmMachine) SendSASVerificationStart(ctx context.Context, toUserID id.UserID, toDeviceID id.DeviceID, transactionID string, methods []VerificationMethod) (*event.VerificationStartEventContent, error) { sasMethods := make([]event.SASMethod, len(methods)) for i, method := range methods { sasMethods[i] = method.Type() @@ -696,14 +696,14 @@ func (mach *OlmMachine) SendSASVerificationStart(toUserID id.UserID, toDeviceID MessageAuthenticationCodes: []event.MACMethod{event.HKDFHMACSHA256}, ShortAuthenticationString: sasMethods, } - return content, mach.sendToOneDevice(toUserID, toDeviceID, event.ToDeviceVerificationStart, content) + return content, mach.sendToOneDevice(ctx, toUserID, toDeviceID, event.ToDeviceVerificationStart, content) } // SendSASVerificationAccept is used to manually send an accept for a SAS verification process from a received m.key.verification.start event. -func (mach *OlmMachine) SendSASVerificationAccept(fromUser id.UserID, startEvent *event.VerificationStartEventContent, publicKey []byte, methods []VerificationMethod) error { +func (mach *OlmMachine) SendSASVerificationAccept(ctx context.Context, fromUser id.UserID, startEvent *event.VerificationStartEventContent, publicKey []byte, methods []VerificationMethod) error { if startEvent.Method != event.VerificationMethodSAS { reason := "Unknown verification method: " + string(startEvent.Method) - if err := mach.SendSASVerificationCancel(fromUser, startEvent.FromDevice, startEvent.TransactionID, reason, event.VerificationCancelUnknownMethod); err != nil { + if err := mach.SendSASVerificationCancel(ctx, fromUser, startEvent.FromDevice, startEvent.TransactionID, reason, event.VerificationCancelUnknownMethod); err != nil { return err } return ErrUnknownVerificationMethod @@ -730,25 +730,25 @@ func (mach *OlmMachine) SendSASVerificationAccept(fromUser id.UserID, startEvent ShortAuthenticationString: sasMethods, Commitment: hash, } - return mach.sendToOneDevice(fromUser, startEvent.FromDevice, event.ToDeviceVerificationAccept, content) + return mach.sendToOneDevice(ctx, fromUser, startEvent.FromDevice, event.ToDeviceVerificationAccept, content) } -func (mach *OlmMachine) callbackAndCancelSASVerification(verState *verificationState, transactionID, reason string, code event.VerificationCancelCode) error { +func (mach *OlmMachine) callbackAndCancelSASVerification(ctx context.Context, verState *verificationState, transactionID, reason string, code event.VerificationCancelCode) error { go verState.hooks.OnCancel(true, reason, code) - return mach.SendSASVerificationCancel(verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, reason, code) + return mach.SendSASVerificationCancel(ctx, verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, reason, code) } // SendSASVerificationKey sends the ephemeral public key for a device to the partner device. -func (mach *OlmMachine) SendSASVerificationKey(userID id.UserID, deviceID id.DeviceID, transactionID string, key string) error { +func (mach *OlmMachine) SendSASVerificationKey(ctx context.Context, userID id.UserID, deviceID id.DeviceID, transactionID string, key string) error { content := &event.VerificationKeyEventContent{ TransactionID: transactionID, Key: key, } - return mach.sendToOneDevice(userID, deviceID, event.ToDeviceVerificationKey, content) + return mach.sendToOneDevice(ctx, userID, deviceID, event.ToDeviceVerificationKey, content) } // SendSASVerificationMAC is use the MAC of a device's key to the partner device. -func (mach *OlmMachine) SendSASVerificationMAC(userID id.UserID, deviceID id.DeviceID, transactionID string, sas *olm.SAS) error { +func (mach *OlmMachine) SendSASVerificationMAC(ctx context.Context, userID id.UserID, deviceID id.DeviceID, transactionID string, sas *olm.SAS) error { keyID := id.NewKeyID(id.KeyAlgorithmEd25519, mach.Client.DeviceID.String()) signingKey := mach.account.SigningKey() @@ -784,7 +784,7 @@ func (mach *OlmMachine) SendSASVerificationMAC(userID id.UserID, deviceID id.Dev Mac: macMap, } - return mach.sendToOneDevice(userID, deviceID, event.ToDeviceVerificationMAC, content) + return mach.sendToOneDevice(ctx, userID, deviceID, event.ToDeviceVerificationMAC, content) } func commonSASMethods(hooks VerificationHooks, otherDeviceMethods []event.SASMethod) []VerificationMethod { diff --git a/crypto/verification_in_room.go b/crypto/verification_in_room.go index cc9b9212..325b45ba 100644 --- a/crypto/verification_in_room.go +++ b/crypto/verification_in_room.go @@ -38,6 +38,7 @@ func (mach *OlmMachine) ProcessInRoomVerification(evt *event.Event) error { return ErrNoRelatesTo } + ctx := context.Background() switch content := evt.Content.Parsed.(type) { case *event.MessageEventContent: if content.MsgType == event.MsgVerificationRequest { @@ -54,18 +55,18 @@ func (mach *OlmMachine) ProcessInRoomVerification(evt *event.Event) error { Timestamp: evt.Timestamp, TransactionID: evt.ID.String(), } - mach.handleVerificationRequest(evt.Sender, newContent, evt.ID.String(), evt.RoomID) + mach.handleVerificationRequest(ctx, evt.Sender, newContent, evt.ID.String(), evt.RoomID) } case *event.VerificationStartEventContent: - mach.handleVerificationStart(evt.Sender, content, content.RelatesTo.EventID.String(), 10*time.Minute, evt.RoomID) + mach.handleVerificationStart(ctx, evt.Sender, content, content.RelatesTo.EventID.String(), 10*time.Minute, evt.RoomID) case *event.VerificationReadyEventContent: - mach.handleInRoomVerificationReady(evt.Sender, evt.RoomID, content, content.RelatesTo.EventID.String()) + mach.handleInRoomVerificationReady(ctx, evt.Sender, evt.RoomID, content, content.RelatesTo.EventID.String()) case *event.VerificationAcceptEventContent: - mach.handleVerificationAccept(evt.Sender, content, content.RelatesTo.EventID.String()) + mach.handleVerificationAccept(ctx, evt.Sender, content, content.RelatesTo.EventID.String()) case *event.VerificationKeyEventContent: - mach.handleVerificationKey(evt.Sender, content, content.RelatesTo.EventID.String()) + mach.handleVerificationKey(ctx, evt.Sender, content, content.RelatesTo.EventID.String()) case *event.VerificationMacEventContent: - mach.handleVerificationMAC(evt.Sender, content, content.RelatesTo.EventID.String()) + mach.handleVerificationMAC(ctx, evt.Sender, content, content.RelatesTo.EventID.String()) case *event.VerificationCancelEventContent: mach.handleVerificationCancel(evt.Sender, content, content.RelatesTo.EventID.String()) } @@ -73,7 +74,7 @@ func (mach *OlmMachine) ProcessInRoomVerification(evt *event.Event) error { } // SendInRoomSASVerificationCancel is used to manually send an in-room SAS cancel message process with the given reason and cancellation code. -func (mach *OlmMachine) SendInRoomSASVerificationCancel(roomID id.RoomID, userID id.UserID, transactionID string, reason string, code event.VerificationCancelCode) error { +func (mach *OlmMachine) SendInRoomSASVerificationCancel(ctx context.Context, roomID id.RoomID, userID id.UserID, transactionID string, reason string, code event.VerificationCancelCode) error { content := &event.VerificationCancelEventContent{ RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)}, Reason: reason, @@ -81,16 +82,16 @@ func (mach *OlmMachine) SendInRoomSASVerificationCancel(roomID id.RoomID, userID To: userID, } - encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationCancel, content) + encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationCancel, content) if err != nil { return err } - _, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted) + _, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted) return err } // SendInRoomSASVerificationRequest is used to manually send an in-room SAS verification request message to another user. -func (mach *OlmMachine) SendInRoomSASVerificationRequest(roomID id.RoomID, toUserID id.UserID, methods []VerificationMethod) (string, error) { +func (mach *OlmMachine) SendInRoomSASVerificationRequest(ctx context.Context, roomID id.RoomID, toUserID id.UserID, methods []VerificationMethod) (string, error) { content := &event.MessageEventContent{ MsgType: event.MsgVerificationRequest, FromDevice: mach.Client.DeviceID, @@ -98,11 +99,11 @@ func (mach *OlmMachine) SendInRoomSASVerificationRequest(roomID id.RoomID, toUse To: toUserID, } - encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.EventMessage, content) + encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.EventMessage, content) if err != nil { return "", err } - resp, err := mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted) + resp, err := mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted) if err != nil { return "", err } @@ -110,23 +111,23 @@ func (mach *OlmMachine) SendInRoomSASVerificationRequest(roomID id.RoomID, toUse } // SendInRoomSASVerificationReady is used to manually send an in-room SAS verification ready message to another user. -func (mach *OlmMachine) SendInRoomSASVerificationReady(roomID id.RoomID, transactionID string) error { +func (mach *OlmMachine) SendInRoomSASVerificationReady(ctx context.Context, roomID id.RoomID, transactionID string) error { content := &event.VerificationReadyEventContent{ FromDevice: mach.Client.DeviceID, Methods: []event.VerificationMethod{event.VerificationMethodSAS}, RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)}, } - encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationReady, content) + encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationReady, content) if err != nil { return err } - _, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted) + _, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted) return err } // SendInRoomSASVerificationStart is used to manually send the in-room SAS verification start message to another user. -func (mach *OlmMachine) SendInRoomSASVerificationStart(roomID id.RoomID, toUserID id.UserID, transactionID string, methods []VerificationMethod) (*event.VerificationStartEventContent, error) { +func (mach *OlmMachine) SendInRoomSASVerificationStart(ctx context.Context, roomID id.RoomID, toUserID id.UserID, transactionID string, methods []VerificationMethod) (*event.VerificationStartEventContent, error) { sasMethods := make([]event.SASMethod, len(methods)) for i, method := range methods { sasMethods[i] = method.Type() @@ -142,19 +143,19 @@ func (mach *OlmMachine) SendInRoomSASVerificationStart(roomID id.RoomID, toUserI To: toUserID, } - encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationStart, content) + encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationStart, content) if err != nil { return nil, err } - _, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted) + _, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted) return content, err } // SendInRoomSASVerificationAccept is used to manually send an accept for an in-room SAS verification process from a received m.key.verification.start event. -func (mach *OlmMachine) SendInRoomSASVerificationAccept(roomID id.RoomID, fromUser id.UserID, startEvent *event.VerificationStartEventContent, transactionID string, publicKey []byte, methods []VerificationMethod) error { +func (mach *OlmMachine) SendInRoomSASVerificationAccept(ctx context.Context, roomID id.RoomID, fromUser id.UserID, startEvent *event.VerificationStartEventContent, transactionID string, publicKey []byte, methods []VerificationMethod) error { if startEvent.Method != event.VerificationMethodSAS { reason := "Unknown verification method: " + string(startEvent.Method) - if err := mach.SendInRoomSASVerificationCancel(roomID, fromUser, transactionID, reason, event.VerificationCancelUnknownMethod); err != nil { + if err := mach.SendInRoomSASVerificationCancel(ctx, roomID, fromUser, transactionID, reason, event.VerificationCancelUnknownMethod); err != nil { return err } return ErrUnknownVerificationMethod @@ -183,32 +184,32 @@ func (mach *OlmMachine) SendInRoomSASVerificationAccept(roomID id.RoomID, fromUs To: fromUser, } - encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationAccept, content) + encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationAccept, content) if err != nil { return err } - _, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted) + _, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted) return err } // SendInRoomSASVerificationKey sends the ephemeral public key for a device to the partner device for an in-room verification. -func (mach *OlmMachine) SendInRoomSASVerificationKey(roomID id.RoomID, userID id.UserID, transactionID string, key string) error { +func (mach *OlmMachine) SendInRoomSASVerificationKey(ctx context.Context, roomID id.RoomID, userID id.UserID, transactionID string, key string) error { content := &event.VerificationKeyEventContent{ RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)}, Key: key, To: userID, } - encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationKey, content) + encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationKey, content) if err != nil { return err } - _, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted) + _, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted) return err } // SendInRoomSASVerificationMAC sends the MAC of a device's key to the partner device for an in-room verification. -func (mach *OlmMachine) SendInRoomSASVerificationMAC(roomID id.RoomID, userID id.UserID, deviceID id.DeviceID, transactionID string, sas *olm.SAS) error { +func (mach *OlmMachine) SendInRoomSASVerificationMAC(ctx context.Context, roomID id.RoomID, userID id.UserID, deviceID id.DeviceID, transactionID string, sas *olm.SAS) error { keyID := id.NewKeyID(id.KeyAlgorithmEd25519, mach.Client.DeviceID.String()) signingKey := mach.account.SigningKey() @@ -245,28 +246,28 @@ func (mach *OlmMachine) SendInRoomSASVerificationMAC(roomID id.RoomID, userID id To: userID, } - encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationMAC, content) + encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationMAC, content) if err != nil { return err } - _, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted) + _, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted) return err } // NewInRoomSASVerificationWith starts the in-room SAS verification process with another user in the given room. // It returns the generated transaction ID. -func (mach *OlmMachine) NewInRoomSASVerificationWith(inRoomID id.RoomID, userID id.UserID, hooks VerificationHooks, timeout time.Duration) (string, error) { - return mach.newInRoomSASVerificationWithInner(inRoomID, &id.Device{UserID: userID}, hooks, "", timeout) +func (mach *OlmMachine) NewInRoomSASVerificationWith(ctx context.Context, inRoomID id.RoomID, userID id.UserID, hooks VerificationHooks, timeout time.Duration) (string, error) { + return mach.newInRoomSASVerificationWithInner(ctx, inRoomID, &id.Device{UserID: userID}, hooks, "", timeout) } -func (mach *OlmMachine) newInRoomSASVerificationWithInner(inRoomID id.RoomID, device *id.Device, hooks VerificationHooks, transactionID string, timeout time.Duration) (string, error) { +func (mach *OlmMachine) newInRoomSASVerificationWithInner(ctx context.Context, inRoomID id.RoomID, device *id.Device, hooks VerificationHooks, transactionID string, timeout time.Duration) (string, error) { mach.Log.Debug().Msgf("Starting new in-room verification transaction user %v", device.UserID) request := transactionID == "" if request { var err error // get new transaction ID from the request message event ID - transactionID, err = mach.SendInRoomSASVerificationRequest(inRoomID, device.UserID, hooks.VerificationMethods()) + transactionID, err = mach.SendInRoomSASVerificationRequest(ctx, inRoomID, device.UserID, hooks.VerificationMethods()) if err != nil { return "", err } @@ -286,7 +287,7 @@ func (mach *OlmMachine) newInRoomSASVerificationWithInner(inRoomID id.RoomID, de if !request { // start in-room verification - startEvent, err := mach.SendInRoomSASVerificationStart(inRoomID, device.UserID, transactionID, hooks.VerificationMethods()) + startEvent, err := mach.SendInRoomSASVerificationStart(ctx, inRoomID, device.UserID, transactionID, hooks.VerificationMethods()) if err != nil { return "", err } @@ -305,19 +306,19 @@ func (mach *OlmMachine) newInRoomSASVerificationWithInner(inRoomID id.RoomID, de mach.keyVerificationTransactionState.Store(device.UserID.String()+":"+transactionID, verState) - mach.timeoutAfter(verState, transactionID, timeout) + mach.timeoutAfter(ctx, verState, transactionID, timeout) return transactionID, nil } -func (mach *OlmMachine) handleInRoomVerificationReady(userID id.UserID, roomID id.RoomID, content *event.VerificationReadyEventContent, transactionID string) { - device, err := mach.GetOrFetchDevice(context.TODO(), userID, content.FromDevice) +func (mach *OlmMachine) handleInRoomVerificationReady(ctx context.Context, userID id.UserID, roomID id.RoomID, content *event.VerificationReadyEventContent, transactionID string) { + device, err := mach.GetOrFetchDevice(ctx, userID, content.FromDevice) if err != nil { mach.Log.Error().Msgf("Error fetching device %v of user %v: %v", content.FromDevice, userID, err) return } - verState, err := mach.getTransactionState(transactionID, userID) + verState, err := mach.getTransactionState(ctx, transactionID, userID) if err != nil { mach.Log.Error().Msgf("Error getting transaction state: %v", err) return @@ -327,7 +328,7 @@ func (mach *OlmMachine) handleInRoomVerificationReady(userID id.UserID, roomID i if mach.Client.UserID < userID { // up to us to send the start message verState.lock.Lock() - mach.newInRoomSASVerificationWithInner(roomID, device, verState.hooks, transactionID, 10*time.Minute) + mach.newInRoomSASVerificationWithInner(ctx, roomID, device, verState.hooks, transactionID, 10*time.Minute) verState.lock.Unlock() } } diff --git a/synapseadmin/register.go b/synapseadmin/register.go index 36b310a9..d7a94f6f 100644 --- a/synapseadmin/register.go +++ b/synapseadmin/register.go @@ -73,11 +73,10 @@ func (req *ReqSharedSecretRegister) Sign(secret string) string { // This does not need to be called manually as SharedSecretRegister will automatically call this if no nonce is provided. func (cli *Client) GetRegisterNonce(ctx context.Context) (string, error) { var resp respGetRegisterNonce - _, err := cli.MakeFullRequest(mautrix.FullRequest{ + _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ Method: http.MethodGet, URL: cli.BuildURL(mautrix.SynapseAdminURLPath{"v1", "register"}), ResponseJSON: &resp, - Context: ctx, }) if err != nil { return "", err @@ -98,12 +97,11 @@ func (cli *Client) SharedSecretRegister(ctx context.Context, sharedSecret string } req.SHA1Checksum = req.Sign(sharedSecret) var resp mautrix.RespRegister - _, err = cli.MakeFullRequest(mautrix.FullRequest{ + _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{ Method: http.MethodPost, URL: cli.BuildURL(mautrix.SynapseAdminURLPath{"v1", "register"}), RequestJSON: req, ResponseJSON: &resp, - Context: ctx, }) if err != nil { return nil, err diff --git a/synapseadmin/userapi.go b/synapseadmin/userapi.go index ee457abc..aa1ce2a7 100644 --- a/synapseadmin/userapi.go +++ b/synapseadmin/userapi.go @@ -33,11 +33,10 @@ type ReqResetPassword struct { // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#reset-password func (cli *Client) ResetPassword(ctx context.Context, req ReqResetPassword) error { reqURL := cli.BuildAdminURL("v1", "reset_password", req.UserID) - _, err := cli.MakeFullRequest(mautrix.FullRequest{ + _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ Method: http.MethodPost, URL: reqURL, RequestJSON: &req, - Context: ctx, }) return err } @@ -50,11 +49,10 @@ func (cli *Client) ResetPassword(ctx context.Context, req ReqResetPassword) erro // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#check-username-availability func (cli *Client) UsernameAvailable(ctx context.Context, username string) (resp *mautrix.RespRegisterAvailable, err error) { u := cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "username_available"}, map[string]string{"username": username}) - _, err = cli.MakeFullRequest(mautrix.FullRequest{ + _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{ Method: http.MethodGet, URL: u, ResponseJSON: &resp, - Context: ctx, }) if err == nil && !resp.Available { err = fmt.Errorf(`request returned OK status without "available": true`) @@ -76,11 +74,10 @@ type RespListDevices struct { // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#list-all-devices func (cli *Client) ListDevices(ctx context.Context, userID id.UserID) (resp *RespListDevices, err error) { - _, err = cli.MakeFullRequest(mautrix.FullRequest{ + _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{ Method: http.MethodGet, URL: cli.BuildAdminURL("v2", "users", userID, "devices"), ResponseJSON: &resp, - Context: ctx, }) return } @@ -105,11 +102,10 @@ type RespUserInfo struct { // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#query-user-account func (cli *Client) GetUserInfo(ctx context.Context, userID id.UserID) (resp *RespUserInfo, err error) { - _, err = cli.MakeFullRequest(mautrix.FullRequest{ + _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{ Method: http.MethodGet, URL: cli.BuildAdminURL("v2", "users", userID), ResponseJSON: &resp, - Context: ctx, }) return } diff --git a/syncstore.go b/syncstore.go index d5fe2db4..8b5b3a55 100644 --- a/syncstore.go +++ b/syncstore.go @@ -1,21 +1,25 @@ package mautrix import ( + "context" "errors" "maunium.net/go/mautrix/id" ) +var _ SyncStore = (*MemorySyncStore)(nil) +var _ SyncStore = (*AccountDataStore)(nil) + // SyncStore is an interface which must be satisfied to store client data. // // You can either write a struct which persists this data to disk, or you can use the // provided "MemorySyncStore" which just keeps data around in-memory which is lost on // restarts. type SyncStore interface { - SaveFilterID(userID id.UserID, filterID string) - LoadFilterID(userID id.UserID) string - SaveNextBatch(userID id.UserID, nextBatchToken string) - LoadNextBatch(userID id.UserID) string + SaveFilterID(ctx context.Context, userID id.UserID, filterID string) + LoadFilterID(ctx context.Context, userID id.UserID) string + SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) + LoadNextBatch(ctx context.Context, userID id.UserID) string } // Deprecated: renamed to SyncStore @@ -32,22 +36,22 @@ type MemorySyncStore struct { } // SaveFilterID to memory. -func (s *MemorySyncStore) SaveFilterID(userID id.UserID, filterID string) { +func (s *MemorySyncStore) SaveFilterID(ctx context.Context, userID id.UserID, filterID string) { s.Filters[userID] = filterID } // LoadFilterID from memory. -func (s *MemorySyncStore) LoadFilterID(userID id.UserID) string { +func (s *MemorySyncStore) LoadFilterID(ctx context.Context, userID id.UserID) string { return s.Filters[userID] } // SaveNextBatch to memory. -func (s *MemorySyncStore) SaveNextBatch(userID id.UserID, nextBatchToken string) { +func (s *MemorySyncStore) SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) { s.NextBatch[userID] = nextBatchToken } // LoadNextBatch from memory. -func (s *MemorySyncStore) LoadNextBatch(userID id.UserID) string { +func (s *MemorySyncStore) LoadNextBatch(ctx context.Context, userID id.UserID) string { return s.NextBatch[userID] } @@ -72,21 +76,21 @@ type accountData struct { NextBatch string `json:"next_batch"` } -func (s *AccountDataStore) SaveFilterID(userID id.UserID, filterID string) { +func (s *AccountDataStore) SaveFilterID(ctx context.Context, userID id.UserID, filterID string) { if userID.String() != s.client.UserID.String() { panic("AccountDataStore must only be used with a single account") } s.FilterID = filterID } -func (s *AccountDataStore) LoadFilterID(userID id.UserID) string { +func (s *AccountDataStore) LoadFilterID(ctx context.Context, userID id.UserID) string { if userID.String() != s.client.UserID.String() { panic("AccountDataStore must only be used with a single account") } return s.FilterID } -func (s *AccountDataStore) SaveNextBatch(userID id.UserID, nextBatchToken string) { +func (s *AccountDataStore) SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) { if userID.String() != s.client.UserID.String() { panic("AccountDataStore must only be used with a single account") } else if nextBatchToken == s.nextBatch { @@ -97,7 +101,7 @@ func (s *AccountDataStore) SaveNextBatch(userID id.UserID, nextBatchToken string NextBatch: nextBatchToken, } - err := s.client.SetAccountData(s.EventType, data) + err := s.client.SetAccountData(ctx, s.EventType, data) if err != nil { s.client.Log.Warn().Err(err).Msg("Failed to save next batch token to account data") } else { @@ -109,14 +113,14 @@ func (s *AccountDataStore) SaveNextBatch(userID id.UserID, nextBatchToken string } } -func (s *AccountDataStore) LoadNextBatch(userID id.UserID) string { +func (s *AccountDataStore) LoadNextBatch(ctx context.Context, userID id.UserID) string { if userID.String() != s.client.UserID.String() { panic("AccountDataStore must only be used with a single account") } data := &accountData{} - err := s.client.GetAccountData(s.EventType, data) + err := s.client.GetAccountData(ctx, s.EventType, data) if err != nil { if errors.Is(err, MNotFound) { s.client.Log.Debug().Msg("No next batch token found in account data")