Skip to content

Commit

Permalink
Merge pull request #1 from bubunyo/feature/improve-api
Browse files Browse the repository at this point in the history
add improvements to api
  • Loading branch information
Bubunyo Nyavor authored Dec 3, 2023
2 parents f4b89d6 + e575141 commit 8e73e71
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 71 deletions.
20 changes: 9 additions & 11 deletions README.MD
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,23 @@ package main

type PingService struct{}

func (e PingService) Echo(_ context.Context, req *rpc.RequestParams) (any, error) {
func (s PingService) Echo(_ context.Context, req *rpc.RequestParams) (any, error) {
return "ok", nil
}

func main() {
server := rpc.NewServer()
server.ExecutionTimeout = 15 * time.Second // max time a function should execute for.
server.MaxBytesRead = 1 << 20 // (1mb) - the maximum size of the total request payload

ping := PingService{}
pingService := rpc.NewService("PingService")
pingService.RegisterMethod("Ping", ping.Echo)
func (s PingService) Register() (string, rpc.RequestMap) {
return "PingService", map[string]rpc.RequestFunc{
"Ping": s.Echo,
}
}

server.AddService(pingService)
func main() {
server := rpc.NewDefaultServer()
server.AddService(PingService{})

mux := http.NewServeMux()
mux.Handle("/rpc", server)
log.Fatalln(http.ListenAndServe(":8080", mux))
}
```
You can consume a single ping service resource with this curl request
Expand Down
26 changes: 16 additions & 10 deletions example/example.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,26 @@ import (

type PingService struct{}

func (e PingService) Echo(_ context.Context, req *rpc.RequestParams) (any, error) {
func (s PingService) Echo(_ context.Context, req *rpc.RequestParams) (any, error) {
return "ok", nil
}

func main() {
server := rpc.NewServer()
server.ExecutionTimeout = 15 * time.Second // max time a function should execute for.
server.MaxBytesRead = 1 << 20 // (1mb) - the maximum size of the total request payload

ping := PingService{}
pingService := rpc.NewService("PingService")
pingService.RegisterMethod("Ping", ping.Echo)
func (s PingService) Register() (string, rpc.RequestMap) {
return "PingService", map[string]rpc.RequestFunc{
"Ping": s.Echo,
}
}

server.AddService(pingService)
func main() {
// Create an rpc server
server := rpc.NewServer(rpc.Opts{
ExecutionTimeout: 15 * time.Second, // max time a function should execute for.
MaxBytesRead: 1 << 20, // (1mb) - the maximum size of the total request payload
})
// or use the default servver with
// server := rpc.NewDefaultServer()

server.AddService(PingService{})

mux := http.NewServeMux()
mux.Handle("/rpc", server)
Expand Down
68 changes: 41 additions & 27 deletions rcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,47 @@ import (
)

const (
Version = "2.0" // JSON RPC Version
MaxBytesRead = 1 << 20 // 1mb
Version = "2.0" // JSON RPC Version
MaxBytesRead = 1 << 20 // 1mb
ExecutionTimeout = 15 * time.Second // execution timout
)

var (
defaultReq = Request{JsonRpc: Version}
defaultReq = Request{JsonRpc: Version}
DefaultOpts = Opts{
MaxBytesRead: MaxBytesRead,
ExecutionTimeout: ExecutionTimeout,
}
)

// NewServer creates a new JSON RPC Server that can handle requests.
func NewServer() *Service {
return NewService("")
func NewServer(opts Opts) *Service {
return NewService(opts)
}

// NewServer creates a new JSON RPC Server that can handle requests.
func NewDefaultServer() *Service {
return NewService(DefaultOpts)
}

type (
Service struct {
name string
methodMap map[string]func(context.Context, *RequestParams) (any, error)
Opts struct {
// MaxBytesRead is the maximum bytes a request object can contain
MaxBytesRead int64
// ExecutionTimeout is the maximum time a method should execute for. If the
// execution exceeds the timeout, and ExectutionTimeout Error is returned for
// that request
ExecutionTimeout time.Duration
// MaxBytesRead is the maximum bytes a request object can contain
MaxBytesRead int64
}
Service struct {
methodMap map[string]func(context.Context, *RequestParams) (any, error)
executionTimeout time.Duration
maxBytesRead int64
}
RequestFunc = func(context.Context, *RequestParams) (any, error)
RequestMap = map[string]RequestFunc
ServiceRegistrar interface {
Register() (string, RequestMap)
}
Request struct {
JsonRpc string `json:"jsonrpc"` // must always be 2.0
Expand Down Expand Up @@ -103,7 +121,7 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
writeResponse(w, errorResponse(&defaultReq, RequestBodyIsEmpty))
return
}
r.Body = http.MaxBytesReader(w, r.Body, s.MaxBytesRead)
r.Body = http.MaxBytesReader(w, r.Body, s.maxBytesRead)
buf := bytes.NewBuffer([]byte{})
n, err := io.Copy(buf, r.Body)
if err != nil {
Expand Down Expand Up @@ -191,7 +209,7 @@ func (s *Service) handleMethod(req Request) (any, error) {
res.resp, res.err = fn(ctx, params)
result <- res
}()
delay := time.NewTimer(s.ExecutionTimeout)
delay := time.NewTimer(s.executionTimeout)
select {
case <-delay.C:
return nil, ExecutionTimeoutError
Expand All @@ -203,27 +221,23 @@ func (s *Service) handleMethod(req Request) (any, error) {
}
}

func (s Service) AddService(services ...*Service) {
prefix := ""
if s.name != "" {
prefix = s.name + "."
}
func (s Service) AddService(services ...ServiceRegistrar) {
for _, srv := range services {
for methodName, fn := range srv.methodMap {
s.methodMap[prefix+srv.name+"."+methodName] = fn
name, requestMap := srv.Register()
nameFmt := "%s"
if name != "" {
nameFmt = "%s.%s"
}
for methodName, fn := range requestMap {
s.methodMap[fmt.Sprintf(nameFmt, name, methodName)] = fn
}
}
}

func NewService(name string) *Service {
func NewService(opts Opts) *Service {
return &Service{
name: name,
methodMap: map[string]func(context.Context, *RequestParams) (any, error){},
ExecutionTimeout: 15 * time.Second,
MaxBytesRead: MaxBytesRead,
executionTimeout: opts.ExecutionTimeout,
maxBytesRead: opts.MaxBytesRead,
}
}

func (s *Service) RegisterMethod(methodName string, fn func(context.Context, *RequestParams) (any, error)) {
s.methodMap[methodName] = fn
}
55 changes: 32 additions & 23 deletions rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,16 @@ type (
}
)

func NewTestService(ts *TestService) *rpc.Service {
s := rpc.NewService("TestService")
s.RegisterMethod("Exec", ts.Exec)
return s
func (s EchoService) Register() (string, rpc.RequestMap) {
return "EchoService", map[string]rpc.RequestFunc{
"Ping": s.Ping,
}
}

func (ts TestService) Register() (string, rpc.RequestMap) {
return "TestService", map[string]rpc.RequestFunc{
"Exec": ts.Exec,
}
}

func (s TestService) MethodName() string {
Expand All @@ -42,11 +48,8 @@ func (s TestService) Exec(ctx context.Context, req *rpc.RequestParams) (any, err
return s.ProcessFn(ctx, req)
}

func NewEchoService() *rpc.Service {
echo := EchoService{}
echoService := rpc.NewService("EchoService")
echoService.RegisterMethod("Ping", echo.Ping)
return echoService
func NewEchoService() rpc.ServiceRegistrar {
return EchoService{}
}

func (s EchoService) Ping(_ context.Context, req *rpc.RequestParams) (any, error) {
Expand Down Expand Up @@ -99,7 +102,7 @@ func errorResponse(t *testing.T, resp *http.Response) (int, string) {
}

func TestRpcServerResponses(t *testing.T) {
server := rpc.NewServer()
server := rpc.NewDefaultServer()
server.AddService(NewEchoService())
req := requestObj(t, "EchoService.Ping", map[string]any{
"echo": "ping",
Expand All @@ -111,7 +114,7 @@ func TestRpcServerResponses(t *testing.T) {
}

func TestRpcServer_ErrorResponses(t *testing.T) {
server := rpc.NewServer()
server := rpc.NewDefaultServer()
server.AddService(NewEchoService())
req := requestObj(t, "EchoService.NonMethod", map[string]any{
"echo": "ping",
Expand All @@ -124,7 +127,7 @@ func TestRpcServer_ErrorResponses(t *testing.T) {
}

func TestRpcServer_InvalidJsonRpcVersion(t *testing.T) {
server := rpc.NewServer()
server := rpc.NewDefaultServer()
server.AddService(NewEchoService())
reqObj := map[string]any{
"jsonrpc": "1.0",
Expand All @@ -146,7 +149,7 @@ func TestRpcServer_InvalidJsonRpcVersion(t *testing.T) {
}

func TestRpcServer_EmptyMethodName(t *testing.T) {
server := rpc.NewServer()
server := rpc.NewDefaultServer()
server.AddService(NewEchoService())
cases := []string{" ", "", "\n\n", "\t\n"}
for _, m := range cases {
Expand All @@ -163,12 +166,12 @@ func TestRpcServer_EmptyMethodName(t *testing.T) {
}

func TestRpcServer_ValidRequestParams(t *testing.T) {
server := rpc.NewServer()
server := rpc.NewDefaultServer()
ts := &TestService{}
ts.ProcessFn = func(_ context.Context, req *rpc.RequestParams) (any, error) {
return "ok", nil
}
server.AddService(NewTestService(ts))
server.AddService(ts)
cases := []struct {
name string
param any
Expand All @@ -190,14 +193,17 @@ func TestRpcServer_ValidRequestParams(t *testing.T) {
}

func TestRpcServer_ExecutionTimeout(t *testing.T) {
server := rpc.NewServer()
server.ExecutionTimeout = time.Second
opts := rpc.Opts{
ExecutionTimeout: time.Second,
MaxBytesRead: rpc.MaxBytesRead,
}
server := rpc.NewServer(opts)
ts := &TestService{}
ts.ProcessFn = func(_ context.Context, req *rpc.RequestParams) (any, error) {
time.Sleep(server.ExecutionTimeout + (2 * time.Second))
time.Sleep(opts.ExecutionTimeout + (2 * time.Second))
return "ok", nil
}
server.AddService(NewTestService(ts))
server.AddService(ts)
rec := httptest.NewRecorder()
req := requestObj(t, ts.MethodName(), nil)
server.ServeHTTP(rec, req)
Expand All @@ -207,23 +213,26 @@ func TestRpcServer_ExecutionTimeout(t *testing.T) {
}

func TestRpcServer_ExecuteMultipleRequests(t *testing.T) {
server := rpc.NewServer()
server.ExecutionTimeout = time.Second
opts := rpc.Opts{
ExecutionTimeout: time.Second,
MaxBytesRead: rpc.MaxBytesRead,
}
server := rpc.NewServer(opts)
ts := &TestService{}
ts.ProcessFn = func(_ context.Context, req *rpc.RequestParams) (any, error) {
var s string
_ = json.Unmarshal(req.Payload, &s)
switch s {
case "wait":
time.Sleep(server.ExecutionTimeout + (2 * time.Second))
time.Sleep(opts.ExecutionTimeout + (2 * time.Second))
return "ok - " + s, nil
case "error":
return nil, errors.New("static error")
default:
return "ok - " + s, nil
}
}
server.AddService(NewTestService(ts))
server.AddService(ts)
rec := httptest.NewRecorder()

reqObj := []map[string]any{
Expand Down

0 comments on commit 8e73e71

Please sign in to comment.