Skip to content

Commit

Permalink
Add OnAfterUpdate hook, pass stop processing flag via context
Browse files Browse the repository at this point in the history
  • Loading branch information
alufers committed Oct 17, 2023
1 parent ae2b17b commit da6fc44
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 39 deletions.
12 changes: 8 additions & 4 deletions paczkobot/image_scanning_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func NewImageScanningService(app *BotApp) *ImageScanningService {
}
}

func (i *ImageScanningService) OnUpdate(ctx context.Context) bool {
func (i *ImageScanningService) OnUpdate(ctx context.Context) context.Context {
update := tghelpers.UpdateFromCtx(ctx)
if update.Message != nil && update.Message.Photo != nil && len(update.Message.Photo) > 0 {

Expand All @@ -46,16 +46,16 @@ func (i *ImageScanningService) OnUpdate(ctx context.Context) bool {
})
if err != nil {
log.Printf("Failed to get file: %v", err)
return false
return ctx
}
url := fmt.Sprintf("https://api.telegram.org/file/bot%s/%s", i.App.Bot.Token, file.FilePath)
err = i.ScanIncomingImage(ctx, tghelpers.ArgsFromCtx(ctx), url)
if err != nil {
log.Printf("Failed to ScanIncomingImage: %v", err)
}
return true
return tghelpers.WithStopProcessingCommands(ctx)
}
return false
return ctx
}

func (i *ImageScanningService) ScanIncomingImage(ctx context.Context, args *tghelpers.CommandArguments, url string) error {
Expand Down Expand Up @@ -178,6 +178,10 @@ func (i *ImageScanningService) ScanIncomingImage(ctx context.Context, args *tghe
return nil
}

func (i *ImageScanningService) OnAfterUpdate(ctx context.Context) context.Context {
return ctx
}

func (*ImageScanningService) DrawResultPoints(img image.Image, points []gozxing.ResultPoint) image.Image {
if len(points) <= 1 {
return img
Expand Down
11 changes: 6 additions & 5 deletions providers/cainiao/cainiao_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ func (pp *CainiaoProvider) MatchesNumber(trackingNumber string) bool {
}

func (pp *CainiaoProvider) Track(ctx context.Context, trackingNumber string) (*commondata.TrackingData, error) {

req, err := http.NewRequest(
"GET",
"https://global.cainiao.com/global/detail.json?mailNos="+url.QueryEscape(trackingNumber)+"&lang=en-US&language=en-US",
Expand All @@ -37,6 +36,7 @@ func (pp *CainiaoProvider) Track(ctx context.Context, trackingNumber string) (*c
if err != nil {
return nil, commonerrors.NewNetworkError(pp.GetName(), req)
}
defer httpResponse.Body.Close()

if httpResponse.StatusCode != 200 {
return nil, commonerrors.NotFoundError
Expand Down Expand Up @@ -83,20 +83,21 @@ func (pp *CainiaoProvider) Track(ctx context.Context, trackingNumber string) (*c
nil,
)
if err != nil {
return td, nil
return td, nil //nolint:nilerr
}
commondata.SetCommonHTTPHeaders(&cityReq.Header)
cityResp, err := http.DefaultClient.Do(cityReq)
if err != nil {
return td, nil
return td, nil //nolint:nilerr
}
defer cityResp.Body.Close()
if cityResp.StatusCode != 200 {
return td, nil
return td, nil //nolint:nilerr
}
var cityResponse GetCityResponse
err = json.NewDecoder(cityResp.Body).Decode(&cityResponse)
if err != nil || !cityResponse.Success {
return td, nil
return td, nil //nolint:nilerr
}

td.Destination = cityResponse.Module + ", " + td.Destination
Expand Down
18 changes: 11 additions & 7 deletions tghelpers/ask_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ func NewAskService(bot BotAPI) *AskService {
}

// Implements UpdateHook
func (a *AskService) OnUpdate(ctx context.Context) bool {
func (a *AskService) OnUpdate(ctx context.Context) context.Context {
update := UpdateFromCtx(ctx)
a.AskCallbacksMutex.Lock()
defer a.AskCallbacksMutex.Unlock()
if update.CallbackQuery != nil {
if update.CallbackQuery.Message == nil || update.CallbackQuery.Message.Chat == nil {
return false
return ctx
}
chatID := update.CallbackQuery.Message.Chat.ID
if update.CallbackQuery.Data == "/cancel" {
Expand All @@ -45,7 +45,7 @@ func (a *AskService) OnUpdate(ctx context.Context) bool {
callback("", errors.New("canceled"))
delete(a.AskCallbacks, chatID)
}
return true
return WithStopProcessingCommands(ctx)
}
if update.CallbackQuery.Data == "/yes" {
if callback, ok := a.AskCallbacks[chatID]; ok {
Expand All @@ -56,7 +56,7 @@ func (a *AskService) OnUpdate(ctx context.Context) bool {
callback("", nil)
delete(a.AskCallbacks, chatID)
}
return true
return WithStopProcessingCommands(ctx)
}
if strings.HasPrefix(update.CallbackQuery.Data, "/sugg ") {
val := strings.TrimPrefix(update.CallbackQuery.Data, "/sugg ")
Expand All @@ -77,7 +77,7 @@ func (a *AskService) OnUpdate(ctx context.Context) bool {
if callback, ok := a.AskCallbacks[update.Message.Chat.ID]; ok {
callback("", errors.New("canceled"))
delete(a.AskCallbacks, update.Message.Chat.ID)
return false
return ctx
}
}

Expand All @@ -88,11 +88,15 @@ func (a *AskService) OnUpdate(ctx context.Context) bool {
}
callback(update.Message.Text, nil)
delete(a.AskCallbacks, update.Message.Chat.ID)
return true
return WithStopProcessingCommands(ctx)
}
}

return false
return ctx
}

func (a *AskService) OnAfterUpdate(ctx context.Context) context.Context {
return ctx
}

// AskForArgument asks the user at the specified chatID for a text value.
Expand Down
4 changes: 2 additions & 2 deletions tghelpers/ask_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func TestAskServiceReturnsFalseForUnrelatedUpdates(t *testing.T) {
},
)
res := askService.OnUpdate(ctx)
assert.False(t, res)
assert.True(t, res.Value(tghelpers.StopProcessingCommandsCtxKey) == nil) // should return false because it's not a related update
}

func TestAskServiceConfirmWorks(t *testing.T) {
Expand Down Expand Up @@ -63,7 +63,7 @@ func TestAskServiceConfirmWorks(t *testing.T) {
},
)
res := askService.OnUpdate(ctx)
assert.True(t, res) // should return true because it's a related update
assert.True(t, res.Value(tghelpers.StopProcessingCommandsCtxKey) != nil) // should return true because it's a related update
}()

return msg, nil
Expand Down
36 changes: 18 additions & 18 deletions tghelpers/command_dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,30 +79,30 @@ func (d *CommandDispatcher) processIncomingUpdate(ctx context.Context, u tgbotap
}

for _, hook := range d.UpdateHooks {
if hook.OnUpdate(ctx) {
return // hook has handled the message stop processing
}
ctx = hook.OnUpdate(ctx)
}

shouldProcessCommands := ctx.Value(StopProcessingCommandsCtxKey) == nil
var err error

for _, cmd := range d.Commands {
if CommandMatches(cmd, cmdText) {
args.Command = cmd
for i, argTpl := range cmd.Arguments() {
if argTpl.Variadic {
args.NamedArguments[argTpl.Name] = strings.Join(args.Arguments[i:], " ")
break
if shouldProcessCommands {
for _, cmd := range d.Commands {
if CommandMatches(cmd, cmdText) {
args.Command = cmd
for i, argTpl := range cmd.Arguments() {
if argTpl.Variadic {
args.NamedArguments[argTpl.Name] = strings.Join(args.Arguments[i:], " ")
break
}
if i >= len(args.Arguments) {
break
}
args.NamedArguments[argTpl.Name] = args.Arguments[i]
}
if i >= len(args.Arguments) {
break
}
args.NamedArguments[argTpl.Name] = args.Arguments[i]
}

err = cmd.Execute(ctx)
err = cmd.Execute(ctx)

break
break
}
}
}

Expand Down
8 changes: 6 additions & 2 deletions tghelpers/command_dispatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@ type FakeUpdateHook struct {
didRun bool
}

func (h *FakeUpdateHook) OnUpdate(context.Context) bool {
func (h *FakeUpdateHook) OnUpdate(ctx context.Context) context.Context {
h.didRun = true
return true
return ctx
}

func (h *FakeUpdateHook) OnAfterUpdate(ctx context.Context) context.Context {
return ctx
}

// a test that chcks if command dispatcher executes update hooks
Expand Down
17 changes: 16 additions & 1 deletion tghelpers/update_hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,19 @@ import (
"context"
)

type stopProcessingCommandsCtxKeyType struct{}

// StopProcessingCommandsCtxKey is a context key that can be used to stop
// processing commands for an update.

// It should be added to the context by an UpdateHook, which
// wishes to stop processing commands.
var StopProcessingCommandsCtxKey = stopProcessingCommandsCtxKeyType{}

func WithStopProcessingCommands(ctx context.Context) context.Context {
return context.WithValue(ctx, StopProcessingCommandsCtxKey, true)
}

// UpdateHook allows a service to listen for all telegram updates
// before they are processed for commands
type UpdateHook interface {
Expand All @@ -12,5 +25,7 @@ type UpdateHook interface {
// as handled by the hook. Further processing is stopped.
// The update shall be extracted from the context using
// tghelpers.UpdateFromCtx(ctx)
OnUpdate(context.Context) bool
OnUpdate(context.Context) context.Context

OnAfterUpdate(context.Context) context.Context
}

0 comments on commit da6fc44

Please sign in to comment.