Skip to content

Commit

Permalink
fix: cancel request context when timeout exceeded (#244)
Browse files Browse the repository at this point in the history
* feat: requests timeout respecting CLOUD_RUN_TIMEOUT_SECONDS

* add test coverage

* fix windows test
  • Loading branch information
garethgeorge authored Jun 13, 2024
1 parent ac7db72 commit 298bc02
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 1 deletion.
4 changes: 4 additions & 0 deletions funcframework/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,10 @@ func convertBackgroundToCloudEvent(ceHandler http.Handler) http.Handler {
return
}
}
r, cancel := setContextTimeoutIfRequested(r)
if cancel != nil {
defer cancel()
}
ceHandler.ServeHTTP(w, r)
})
}
Expand Down
26 changes: 25 additions & 1 deletion funcframework/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ import (
"os"
"reflect"
"runtime/debug"
"strconv"
"strings"
"time"

"github.com/GoogleCloudPlatform/functions-framework-go/internal/registry"
cloudevents "github.com/cloudevents/sdk-go/v2"
Expand Down Expand Up @@ -196,6 +198,10 @@ func wrapHTTPFunction(fn func(http.ResponseWriter, *http.Request)) (http.Handler
defer fmt.Println()
defer fmt.Fprintln(os.Stderr)
}
r, cancel := setContextTimeoutIfRequested(r)
if cancel != nil {
defer cancel()
}
defer recoverPanic(w, "user function execution", false)
fn(w, r)
}), nil
Expand All @@ -212,7 +218,10 @@ func wrapEventFunction(fn interface{}) (http.Handler, error) {
defer fmt.Println()
defer fmt.Fprintln(os.Stderr)
}

r, cancel := setContextTimeoutIfRequested(r)
if cancel != nil {
defer cancel()
}
if shouldConvertCloudEventToBackgroundRequest(r) {
if err := convertCloudEventToBackgroundRequest(r); err != nil {
writeHTTPErrorResponse(w, http.StatusBadRequest, crashStatus, fmt.Sprintf("error converting CloudEvent to Background Event: %v", err))
Expand Down Expand Up @@ -388,3 +397,18 @@ func writeHTTPErrorResponse(w http.ResponseWriter, statusCode int, status, msg s
w.WriteHeader(statusCode)
fmt.Fprint(w, msg)
}

// setContextTimeoutIfRequested replaces the request's context with a cancellation if requested
func setContextTimeoutIfRequested(r *http.Request) (*http.Request, func()) {
timeoutStr := os.Getenv("CLOUD_RUN_TIMEOUT_SECONDS")
if timeoutStr == "" {
return r, nil
}
timeoutSecs, err := strconv.Atoi(timeoutStr)
if err != nil {
fmt.Fprintf(os.Stderr, "Could not parse CLOUD_RUN_TIMEOUT_SECONDS as an integer value in seconds: %v\n", err)
return r, nil
}
ctx, cancel := context.WithTimeout(r.Context(), time.Duration(timeoutSecs)*time.Second)
return r.WithContext(ctx), cancel
}
121 changes: 121 additions & 0 deletions funcframework/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ import (
"os"
"strings"
"testing"
"time"

"github.com/GoogleCloudPlatform/functions-framework-go/functions"
"github.com/GoogleCloudPlatform/functions-framework-go/internal/registry"
cloudevents "github.com/cloudevents/sdk-go/v2"
"github.com/cloudevents/sdk-go/v2/event"
"github.com/google/go-cmp/cmp"
)

Expand Down Expand Up @@ -995,6 +997,125 @@ func TestServeMultipleFunctions(t *testing.T) {
}
}

func TestHTTPRequestTimeout(t *testing.T) {
timeoutEnvVar := "CLOUD_RUN_TIMEOUT_SECONDS"
prev := os.Getenv(timeoutEnvVar)
defer os.Setenv(timeoutEnvVar, prev)

cloudeventsJSON := []byte(`{
"specversion" : "1.0",
"type" : "com.github.pull.create",
"source" : "https://github.com/cloudevents/spec/pull",
"subject" : "123",
"id" : "A234-1234-1234",
"time" : "2018-04-05T17:31:00Z",
"comexampleextension1" : "value",
"datacontenttype" : "application/xml",
"data" : "<much wow=\"xml\"/>"
}`)

tcs := []struct {
name string
wantDeadline bool
waitForExpiration bool
timeout string
}{
{
name: "deadline not requested",
wantDeadline: false,
waitForExpiration: false,
timeout: "",
},
{
name: "NaN deadline",
wantDeadline: false,
waitForExpiration: false,
timeout: "aaa",
},
{
name: "very long deadline",
wantDeadline: true,
waitForExpiration: false,
timeout: "3600",
},
{
name: "short deadline should terminate",
wantDeadline: true,
waitForExpiration: true,
timeout: "1",
},
}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
defer cleanup()
os.Setenv(timeoutEnvVar, tc.timeout)

var httpReqCtx context.Context
functions.HTTP("http", func(w http.ResponseWriter, r *http.Request) {
if tc.waitForExpiration {
<-r.Context().Done()
}
httpReqCtx = r.Context()
})
var ceReqCtx context.Context
functions.CloudEvent("cloudevent", func(ctx context.Context, event event.Event) error {
if tc.waitForExpiration {
<-ctx.Done()
}
ceReqCtx = ctx
return nil
})
server, err := initServer()
if err != nil {
t.Fatalf("initServer(): %v", err)
}
srv := httptest.NewServer(server)
defer srv.Close()

t.Run("http", func(t *testing.T) {
_, err = http.Get(srv.URL + "/http")
if err != nil {
t.Fatalf("expected success")
}
if httpReqCtx == nil {
t.Fatalf("expected non-nil request context")
}
deadline, ok := httpReqCtx.Deadline()
if ok != tc.wantDeadline {
t.Errorf("expected deadline %v but got %v", tc.wantDeadline, ok)
}
if expired := deadline.Before(time.Now()); ok && expired != tc.waitForExpiration {
t.Errorf("expected expired %v but got %v", tc.waitForExpiration, expired)
}
})

t.Run("cloudevent", func(t *testing.T) {
req, err := http.NewRequest("POST", srv.URL+"/cloudevent", bytes.NewBuffer(cloudeventsJSON))
if err != nil {
t.Fatalf("failed to create request")
}
req.Header.Add("Content-Type", "application/cloudevents+json")
client := &http.Client{}
_, err = client.Do(req)
if err != nil {
t.Fatalf("request failed")
}
if ceReqCtx == nil {
t.Fatalf("expected non-nil request context")
}
deadline, ok := ceReqCtx.Deadline()
if ok != tc.wantDeadline {
t.Errorf("expected deadline %v but got %v", tc.wantDeadline, ok)
}
if expired := deadline.Before(time.Now()); ok && expired != tc.waitForExpiration {
t.Errorf("expected expired %v but got %v", tc.waitForExpiration, expired)
}
})
})
}
}

func cleanup() {
os.Unsetenv("FUNCTION_TARGET")
registry.Default().Reset()
Expand Down

0 comments on commit 298bc02

Please sign in to comment.