Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move /send_leave to GMSL #387

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
124 changes: 124 additions & 0 deletions handleleave.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
package gomatrixserverlib

import (
"context"
"fmt"

"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
)

type HandleMakeLeaveResponse struct {
Expand All @@ -37,6 +39,7 @@ type HandleMakeLeaveInput struct {
BuildEventTemplate func(*ProtoEvent) (PDU, []PDU, error)
}

// HandleMakeLeave handles requests to `/make_leave`
func HandleMakeLeave(input HandleMakeLeaveInput) (*HandleMakeLeaveResponse, error) {

if input.UserID.Domain() != input.RequestOrigin {
Expand Down Expand Up @@ -98,3 +101,124 @@ func HandleMakeLeave(input HandleMakeLeaveInput) (*HandleMakeLeaveResponse, erro
}
return &makeLeaveResponse, nil
}

type CurrentStateQuerier interface {
CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (PDU, error)
}

// HandleSendLeave handles requests to `/send_leave
// Returns the parsed event and an error.
S7evinK marked this conversation as resolved.
Show resolved Hide resolved
func HandleSendLeave(ctx context.Context,
requestContent []byte,
origin spec.ServerName,
roomVersion RoomVersion,
eventID, roomID string,
querier CurrentStateQuerier,
verifier JSONVerifier,
) (PDU, error) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sidenote: this function and others like it will likely need to be modified to support pseudo IDs, as in many cases we pull out the state key and check that it is a user ID, and then check that the domain is correct on it. I would probably structure this as:

  • static function HandleSendLeave as it is in this PR
  • Pull out the room version immediately.
  • Call roomVer.HandleSendLeave(args...) to let room versions decide how to implement this.
  • Most room versions will call a private function which does the below code, checking user IDs etc.
  • But pseudo IDs will call a custom function which omits these checks.

Whilst the caller could call roomVer.HandleSendLeave this then creates bad symmetry, as there are cases where you don't know the room version at this point (e.g invites).


rID, err := spec.NewRoomID(roomID)
if err != nil {
return nil, err
}

verImpl, err := GetRoomVersion(roomVersion)
if err != nil {
return nil, spec.UnsupportedRoomVersion(fmt.Sprintf("QueryRoomVersionForRoom returned unknown version: %s", roomVersion))
}

// Decode the event JSON from the request.
event, err := verImpl.NewEventFromUntrustedJSON(requestContent)
switch err.(type) {
case BadJSONError:
return nil, spec.BadJSON(err.Error())
case nil:
default:
return nil, spec.NotJSON("The request body could not be decoded into valid JSON. " + err.Error())
}

// Check that the room ID is correct.
if (event.RoomID()) != roomID {
return nil, spec.BadJSON("The room ID in the request path must match the room ID in the leave event JSON")
}

// Check that the event ID is correct.
if event.EventID() != eventID {
return nil, spec.BadJSON("The event ID in the request path must match the event ID in the leave event JSON")

}

// Sanity check that we really received a state event
if event.StateKey() == nil || event.StateKeyEquals("") {
return nil, spec.BadJSON("No state key was provided in the leave event.")
}
if !event.StateKeyEquals(event.Sender()) {
return nil, spec.BadJSON("Event state key must match the event sender.")
}

leavingUser, err := spec.NewUserID(*event.StateKey(), true)
if err != nil {
return nil, spec.Forbidden("The leaving user ID is invalid")
}

// Check that the sender belongs to the server that is sending us
// the request. By this point we've already asserted that the sender
// and the state key are equal so we don't need to check both.
sender, err := spec.NewUserID(event.Sender(), true)
if err != nil {
return nil, spec.Forbidden("The sender of the join is invalid")
}
if sender.Domain() != origin {
return nil, spec.Forbidden("The sender does not match the server that originated the request")
}

stateEvent, err := querier.CurrentStateEvent(ctx, *rID, spec.MRoomMember, leavingUser.String())
if err != nil {
return nil, err
}
// we weren't joined at all
if stateEvent == nil {
return nil, nil
}
// We are/were joined/invited/banned or something
if mem, merr := stateEvent.Membership(); merr == nil && mem == spec.Leave {
return nil, nil
}
// we already processed this event
if event.EventID() == stateEvent.EventID() {
return nil, nil
}

// Check that the event is signed by the server sending the request.
redacted, err := verImpl.RedactEventJSON(event.JSON())
if err != nil {
util.GetLogger(ctx).WithError(err).Errorf("unable to redact event")
return nil, spec.BadJSON("The event JSON could not be redacted")
}
verifyRequests := []VerifyJSONRequest{{
ServerName: sender.Domain(),
Message: redacted,
AtTS: event.OriginServerTS(),
StrictValidityChecking: true,
}}
verifyResults, err := verifier.VerifyJSONs(ctx, verifyRequests)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("keys.VerifyJSONs failed")
return nil, spec.InternalServerError{}
}
if verifyResults[0].Error != nil {
return nil, spec.Forbidden("The leave must be signed by the server it originated on")
}

// check membership is set to leave
mem, err := event.Membership()
if err != nil {
util.GetLogger(ctx).WithError(err).Error("event.Membership failed")
return nil, spec.BadJSON("missing content.membership key")
}
if mem != spec.Leave {
return nil, spec.BadJSON("The membership in the event content must be set to leave")
}

return event, nil
}
175 changes: 175 additions & 0 deletions handleleave_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gomatrixserverlib

import (
"context"
"crypto/rand"
"fmt"
"testing"
Expand Down Expand Up @@ -221,3 +222,177 @@ func TestHandleMakeLeave(t *testing.T) {
})
}
}

type dummyQuerier struct {
pdu PDU
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be good to assert the type/state key are correct when queried via CurrentStateEvent.

}

func (d dummyQuerier) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (PDU, error) {
return d.pdu, nil
}

type noopJSONVerifier struct {
err error
results []VerifyJSONResult
}

func (v *noopJSONVerifier) VerifyJSONs(ctx context.Context, requests []VerifyJSONRequest) ([]VerifyJSONResult, error) {
return v.results, v.err
}

func TestHandleSendLeave(t *testing.T) {
type args struct {
ctx context.Context
requestContent []byte
origin spec.ServerName
roomVersion RoomVersion
eventID string
roomID string
querier CurrentStateQuerier
verifier JSONVerifier
}

_, sk, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatalf("Failed generating key: %v", err)
}
keyID := KeyID("ed25519:1234")

validUser, _ := spec.NewUserID("@valid:localhost", true)

stateKey := ""
eb := MustGetRoomVersion(RoomVersionV10).NewEventBuilderFromProtoEvent(&ProtoEvent{
Sender: validUser.String(),
RoomID: "!valid:localhost",
Type: spec.MRoomCreate,
StateKey: &stateKey,
PrevEvents: []interface{}{},
AuthEvents: []interface{}{},
Depth: 0,
Content: spec.RawJSON(`{"creator":"@user:local","m.federate":true,"room_version":"10"}`),
Unsigned: spec.RawJSON(""),
})
createEvent, err := eb.Build(time.Now(), "localhost", keyID, sk)
if err != nil {
t.Fatalf("Failed building create event: %v", err)
}

stateKey = validUser.String()
eb = MustGetRoomVersion(RoomVersionV10).NewEventBuilderFromProtoEvent(&ProtoEvent{
Sender: validUser.String(),
RoomID: "!valid:localhost",
Type: spec.MRoomMember,
StateKey: &stateKey,
PrevEvents: []interface{}{},
AuthEvents: []interface{}{},
Depth: 0,
Content: spec.RawJSON(`{"membership":"leave"}`),
Unsigned: spec.RawJSON(""),
})
leaveEvent, err := eb.Build(time.Now(), "localhost", keyID, sk)
if err != nil {
t.Fatalf("Failed building create event: %v", err)
}

eb = MustGetRoomVersion(RoomVersionV10).NewEventBuilderFromProtoEvent(&ProtoEvent{
Sender: validUser.String(),
RoomID: "!valid:localhost",
Type: spec.MRoomMember,
StateKey: &stateKey,
PrevEvents: []interface{}{},
AuthEvents: []interface{}{},
Depth: 0,
Content: spec.RawJSON(`{"membership":"join"}`),
Unsigned: spec.RawJSON(""),
})
joinEvent, err := eb.Build(time.Now(), "localhost", keyID, sk)
if err != nil {
t.Fatalf("Failed building create event: %v", err)
}

tests := []struct {
name string
args args
want PDU
wantErr assert.ErrorAssertionFunc
}{
{
name: "invalid roomID",
args: args{roomID: "@notvalid:localhost"},
wantErr: assert.Error,
},
{
name: "invalid room version",
args: args{roomID: "!notvalid:localhost", roomVersion: "-1"},
wantErr: assert.Error,
},
{
name: "invalid content body",
args: args{roomID: "!notvalid:localhost", roomVersion: RoomVersionV1, requestContent: []byte("{")},
wantErr: assert.Error,
},
{
name: "not canonical JSON",
args: args{roomID: "!notvalid:localhost", roomVersion: RoomVersionV10, requestContent: []byte(`{"int":9007199254740992}`)}, // number to large, not canonical json
wantErr: assert.Error,
},
{
name: "wrong roomID in request",
args: args{roomID: "!notvalid:localhost", roomVersion: RoomVersionV10, requestContent: createEvent.JSON()},
wantErr: assert.Error,
},
{
name: "wrong eventID in request",
args: args{roomID: "!valid:localhost", roomVersion: RoomVersionV10, requestContent: createEvent.JSON()},
wantErr: assert.Error,
},
{
name: "empty statekey",
args: args{roomID: "!valid:localhost", roomVersion: RoomVersionV10, eventID: createEvent.EventID(), requestContent: createEvent.JSON()},
wantErr: assert.Error,
},
{
name: "wrong request origin",
args: args{roomID: "!valid:localhost", roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()},
wantErr: assert.Error,
},
{
name: "never joined the room no-ops",
args: args{roomID: "!valid:localhost", querier: dummyQuerier{}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()},
wantErr: assert.NoError,
},
{
name: "already left the room no-ops",
args: args{roomID: "!valid:localhost", querier: dummyQuerier{pdu: leaveEvent}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()},
wantErr: assert.NoError,
},
{
name: "JSON validation fails",
args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{err: fmt.Errorf("err")}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()},
wantErr: assert.Error,
},
{
name: "JSON validation fails 2",
args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{Error: fmt.Errorf("err")}}}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()},
wantErr: assert.Error,
},
{
name: "membership not set to leave",
args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{}}}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: joinEvent.EventID(), requestContent: joinEvent.JSON()},
wantErr: assert.Error,
},
{
name: "membership set to leave",
args: args{ctx: context.Background(), roomID: "!valid:localhost", querier: dummyQuerier{pdu: createEvent}, verifier: &noopJSONVerifier{results: []VerifyJSONResult{{}}}, origin: validUser.Domain(), roomVersion: RoomVersionV10, eventID: leaveEvent.EventID(), requestContent: leaveEvent.JSON()},
wantErr: assert.NoError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := HandleSendLeave(tt.args.ctx, tt.args.requestContent, tt.args.origin, tt.args.roomVersion, tt.args.eventID, tt.args.roomID, tt.args.querier, tt.args.verifier)
if !tt.wantErr(t, err, fmt.Sprintf("HandleSendLeave(%v, %v, %v, %v, %v, %v, %v, %v)", tt.args.ctx, tt.args.requestContent, tt.args.origin, tt.args.roomVersion, tt.args.eventID, tt.args.roomID, tt.args.querier, tt.args.verifier)) {
return
}
})
}
}