mirror of
https://github.com/1f349/dendrite.git
synced 2024-11-24 04:31:39 +00:00
254 lines
6.8 KiB
Go
254 lines
6.8 KiB
Go
|
package routing
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"encoding/json"
|
||
|
"fmt"
|
||
|
"net/http"
|
||
|
"testing"
|
||
|
|
||
|
rsapi "github.com/matrix-org/dendrite/roomserver/api"
|
||
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||
|
"github.com/matrix-org/dendrite/setup/config"
|
||
|
uapi "github.com/matrix-org/dendrite/userapi/api"
|
||
|
"github.com/matrix-org/gomatrixserverlib"
|
||
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||
|
"github.com/matrix-org/util"
|
||
|
"gotest.tools/v3/assert"
|
||
|
)
|
||
|
|
||
|
var ()
|
||
|
|
||
|
type stateTestRoomserverAPI struct {
|
||
|
rsapi.RoomserverInternalAPI
|
||
|
t *testing.T
|
||
|
roomState map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent
|
||
|
roomIDStr string
|
||
|
roomVersion gomatrixserverlib.RoomVersion
|
||
|
userIDStr string
|
||
|
// userID -> senderID
|
||
|
senderMapping map[string]string
|
||
|
}
|
||
|
|
||
|
func (s stateTestRoomserverAPI) QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) {
|
||
|
if roomID == s.roomIDStr {
|
||
|
return s.roomVersion, nil
|
||
|
} else {
|
||
|
s.t.Logf("room version queried for %s", roomID)
|
||
|
return "", fmt.Errorf("unknown room")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s stateTestRoomserverAPI) QueryLatestEventsAndState(
|
||
|
ctx context.Context,
|
||
|
req *rsapi.QueryLatestEventsAndStateRequest,
|
||
|
res *rsapi.QueryLatestEventsAndStateResponse,
|
||
|
) error {
|
||
|
res.RoomExists = req.RoomID == s.roomIDStr
|
||
|
if !res.RoomExists {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
res.StateEvents = []*types.HeaderedEvent{}
|
||
|
for _, stateKeyTuple := range req.StateToFetch {
|
||
|
val, ok := s.roomState[stateKeyTuple]
|
||
|
if ok && val != nil {
|
||
|
res.StateEvents = append(res.StateEvents, val)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (s stateTestRoomserverAPI) QueryMembershipForUser(
|
||
|
ctx context.Context,
|
||
|
req *rsapi.QueryMembershipForUserRequest,
|
||
|
res *rsapi.QueryMembershipForUserResponse,
|
||
|
) error {
|
||
|
if req.UserID.String() == s.userIDStr {
|
||
|
res.HasBeenInRoom = true
|
||
|
res.IsInRoom = true
|
||
|
res.RoomExists = true
|
||
|
res.Membership = spec.Join
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (s stateTestRoomserverAPI) QuerySenderIDForUser(
|
||
|
ctx context.Context,
|
||
|
roomID spec.RoomID,
|
||
|
userID spec.UserID,
|
||
|
) (*spec.SenderID, error) {
|
||
|
sID, ok := s.senderMapping[userID.String()]
|
||
|
if ok {
|
||
|
sender := spec.SenderID(sID)
|
||
|
return &sender, nil
|
||
|
} else {
|
||
|
return nil, nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s stateTestRoomserverAPI) QueryUserIDForSender(
|
||
|
ctx context.Context,
|
||
|
roomID spec.RoomID,
|
||
|
senderID spec.SenderID,
|
||
|
) (*spec.UserID, error) {
|
||
|
for uID, sID := range s.senderMapping {
|
||
|
if sID == string(senderID) {
|
||
|
parsedUserID, err := spec.NewUserID(uID, true)
|
||
|
if err != nil {
|
||
|
s.t.Fatalf("Mock QueryUserIDForSender failed: %s", err)
|
||
|
}
|
||
|
return parsedUserID, nil
|
||
|
}
|
||
|
}
|
||
|
return nil, nil
|
||
|
}
|
||
|
|
||
|
func (s stateTestRoomserverAPI) QueryStateAfterEvents(
|
||
|
ctx context.Context,
|
||
|
req *rsapi.QueryStateAfterEventsRequest,
|
||
|
res *rsapi.QueryStateAfterEventsResponse,
|
||
|
) error {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func Test_OnIncomingStateTypeRequest(t *testing.T) {
|
||
|
var tempRoomServerCfg config.RoomServer
|
||
|
tempRoomServerCfg.Defaults(config.DefaultOpts{})
|
||
|
defaultRoomVersion := tempRoomServerCfg.DefaultRoomVersion
|
||
|
pseudoIDRoomVersion := gomatrixserverlib.RoomVersionPseudoIDs
|
||
|
nonPseudoIDRoomVersion := gomatrixserverlib.RoomVersionV10
|
||
|
|
||
|
userIDStr := "@testuser:domain"
|
||
|
eventType := "com.example.test"
|
||
|
stateKey := "testStateKey"
|
||
|
roomIDStr := "!id:domain"
|
||
|
|
||
|
device := &uapi.Device{
|
||
|
UserID: userIDStr,
|
||
|
}
|
||
|
|
||
|
t.Run("request simple state key", func(t *testing.T) {
|
||
|
ctx := context.Background()
|
||
|
|
||
|
rsAPI := stateTestRoomserverAPI{
|
||
|
roomVersion: defaultRoomVersion,
|
||
|
roomIDStr: roomIDStr,
|
||
|
roomState: map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent{
|
||
|
{
|
||
|
EventType: eventType,
|
||
|
StateKey: stateKey,
|
||
|
}: mustCreateStatePDU(t, defaultRoomVersion, roomIDStr, eventType, stateKey, map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
}),
|
||
|
},
|
||
|
userIDStr: userIDStr,
|
||
|
}
|
||
|
|
||
|
jsonResp := OnIncomingStateTypeRequest(ctx, device, rsAPI, roomIDStr, eventType, stateKey, false)
|
||
|
|
||
|
assert.DeepEqual(t, jsonResp, util.JSONResponse{
|
||
|
Code: http.StatusOK,
|
||
|
JSON: spec.RawJSON(`{"foo":"bar"}`),
|
||
|
})
|
||
|
})
|
||
|
|
||
|
t.Run("user ID key translated to room key in pseudo ID rooms", func(t *testing.T) {
|
||
|
ctx := context.Background()
|
||
|
|
||
|
stateSenderUserID := "@sender:domain"
|
||
|
stateSenderRoomKey := "testsenderkey"
|
||
|
|
||
|
rsAPI := stateTestRoomserverAPI{
|
||
|
roomVersion: pseudoIDRoomVersion,
|
||
|
roomIDStr: roomIDStr,
|
||
|
roomState: map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent{
|
||
|
{
|
||
|
EventType: eventType,
|
||
|
StateKey: stateSenderRoomKey,
|
||
|
}: mustCreateStatePDU(t, pseudoIDRoomVersion, roomIDStr, eventType, stateSenderRoomKey, map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
}),
|
||
|
{
|
||
|
EventType: eventType,
|
||
|
StateKey: stateSenderUserID,
|
||
|
}: mustCreateStatePDU(t, pseudoIDRoomVersion, roomIDStr, eventType, stateSenderUserID, map[string]interface{}{
|
||
|
"not": "thisone",
|
||
|
}),
|
||
|
},
|
||
|
userIDStr: userIDStr,
|
||
|
senderMapping: map[string]string{
|
||
|
stateSenderUserID: stateSenderRoomKey,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
jsonResp := OnIncomingStateTypeRequest(ctx, device, rsAPI, roomIDStr, eventType, stateSenderUserID, false)
|
||
|
|
||
|
assert.DeepEqual(t, jsonResp, util.JSONResponse{
|
||
|
Code: http.StatusOK,
|
||
|
JSON: spec.RawJSON(`{"foo":"bar"}`),
|
||
|
})
|
||
|
})
|
||
|
|
||
|
t.Run("user ID key not translated to room key in non-pseudo ID rooms", func(t *testing.T) {
|
||
|
ctx := context.Background()
|
||
|
|
||
|
stateSenderUserID := "@sender:domain"
|
||
|
stateSenderRoomKey := "testsenderkey"
|
||
|
|
||
|
rsAPI := stateTestRoomserverAPI{
|
||
|
roomVersion: nonPseudoIDRoomVersion,
|
||
|
roomIDStr: roomIDStr,
|
||
|
roomState: map[gomatrixserverlib.StateKeyTuple]*types.HeaderedEvent{
|
||
|
{
|
||
|
EventType: eventType,
|
||
|
StateKey: stateSenderRoomKey,
|
||
|
}: mustCreateStatePDU(t, nonPseudoIDRoomVersion, roomIDStr, eventType, stateSenderRoomKey, map[string]interface{}{
|
||
|
"not": "thisone",
|
||
|
}),
|
||
|
{
|
||
|
EventType: eventType,
|
||
|
StateKey: stateSenderUserID,
|
||
|
}: mustCreateStatePDU(t, nonPseudoIDRoomVersion, roomIDStr, eventType, stateSenderUserID, map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
}),
|
||
|
},
|
||
|
userIDStr: userIDStr,
|
||
|
senderMapping: map[string]string{
|
||
|
stateSenderUserID: stateSenderUserID,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
jsonResp := OnIncomingStateTypeRequest(ctx, device, rsAPI, roomIDStr, eventType, stateSenderUserID, false)
|
||
|
|
||
|
assert.DeepEqual(t, jsonResp, util.JSONResponse{
|
||
|
Code: http.StatusOK,
|
||
|
JSON: spec.RawJSON(`{"foo":"bar"}`),
|
||
|
})
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func mustCreateStatePDU(t *testing.T, roomVer gomatrixserverlib.RoomVersion, roomID string, stateType string, stateKey string, stateContent map[string]interface{}) *types.HeaderedEvent {
|
||
|
t.Helper()
|
||
|
roomVerImpl := gomatrixserverlib.MustGetRoomVersion(roomVer)
|
||
|
|
||
|
evBytes, err := json.Marshal(map[string]interface{}{
|
||
|
"room_id": roomID,
|
||
|
"type": stateType,
|
||
|
"state_key": stateKey,
|
||
|
"content": stateContent,
|
||
|
})
|
||
|
if err != nil {
|
||
|
t.Fatalf("failed to create event: %v", err)
|
||
|
}
|
||
|
|
||
|
ev, err := roomVerImpl.NewEventFromTrustedJSON(evBytes, false)
|
||
|
if err != nil {
|
||
|
t.Fatalf("failed to create event: %v", err)
|
||
|
}
|
||
|
|
||
|
return &types.HeaderedEvent{PDU: ev}
|
||
|
}
|