mirror of
https://github.com/1f349/dendrite.git
synced 2024-11-25 05:01:41 +00:00
2136 lines
82 KiB
Go
2136 lines
82 KiB
Go
package clientapi
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/matrix-org/dendrite/appservice"
|
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
|
"github.com/matrix-org/dendrite/clientapi/routing"
|
|
"github.com/matrix-org/dendrite/clientapi/threepid"
|
|
"github.com/matrix-org/dendrite/internal/caching"
|
|
"github.com/matrix-org/dendrite/internal/httputil"
|
|
"github.com/matrix-org/dendrite/internal/pushrules"
|
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
"github.com/matrix-org/dendrite/roomserver"
|
|
"github.com/matrix-org/dendrite/roomserver/api"
|
|
"github.com/matrix-org/dendrite/roomserver/version"
|
|
"github.com/matrix-org/dendrite/setup/base"
|
|
"github.com/matrix-org/dendrite/setup/config"
|
|
"github.com/matrix-org/dendrite/setup/jetstream"
|
|
"github.com/matrix-org/dendrite/test"
|
|
"github.com/matrix-org/dendrite/test/testrig"
|
|
"github.com/matrix-org/dendrite/userapi"
|
|
uapi "github.com/matrix-org/dendrite/userapi/api"
|
|
"github.com/matrix-org/gomatrix"
|
|
"github.com/matrix-org/gomatrixserverlib"
|
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
|
"github.com/matrix-org/util"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/tidwall/gjson"
|
|
"maunium.net/go/mautrix"
|
|
"maunium.net/go/mautrix/crypto"
|
|
"maunium.net/go/mautrix/event"
|
|
"maunium.net/go/mautrix/id"
|
|
)
|
|
|
|
// userDevice holds the credentials a test needs to act as a logged-in user:
// the access token and device ID taken from the /login response, and the
// password the account was created with (re-used for user-interactive auth).
type userDevice struct {
	accessToken, deviceID, password string
}
|
|
|
|
func TestGetPutDevices(t *testing.T) {
|
|
alice := test.NewUser(t)
|
|
bob := test.NewUser(t)
|
|
|
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
|
testCases := []struct {
|
|
name string
|
|
requestUser *test.User
|
|
deviceUser *test.User
|
|
request *http.Request
|
|
wantStatusCode int
|
|
validateFunc func(t *testing.T, device userDevice, routers httputil.Routers)
|
|
}{
|
|
{
|
|
name: "can get all devices",
|
|
requestUser: alice,
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/devices", strings.NewReader("")),
|
|
wantStatusCode: http.StatusOK,
|
|
},
|
|
{
|
|
name: "can get specific own device",
|
|
requestUser: alice,
|
|
deviceUser: alice,
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/devices/", strings.NewReader("")),
|
|
wantStatusCode: http.StatusOK,
|
|
},
|
|
{
|
|
name: "can not get device for different user",
|
|
requestUser: alice,
|
|
deviceUser: bob,
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/devices/", strings.NewReader("")),
|
|
wantStatusCode: http.StatusNotFound,
|
|
},
|
|
{
|
|
name: "can update own device",
|
|
requestUser: alice,
|
|
deviceUser: alice,
|
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/devices/", strings.NewReader(`{"display_name":"my new displayname"}`)),
|
|
wantStatusCode: http.StatusOK,
|
|
validateFunc: func(t *testing.T, device userDevice, routers httputil.Routers) {
|
|
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/devices/"+device.deviceID, strings.NewReader(""))
|
|
req.Header.Set("Authorization", "Bearer "+device.accessToken)
|
|
rec := httptest.NewRecorder()
|
|
routers.Client.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String())
|
|
}
|
|
gotDisplayName := gjson.GetBytes(rec.Body.Bytes(), "display_name").Str
|
|
if gotDisplayName != "my new displayname" {
|
|
t.Fatalf("expected displayname '%s', got '%s'", "my new displayname", gotDisplayName)
|
|
}
|
|
},
|
|
},
|
|
{
|
|
// this should return "device does not exist"
|
|
name: "can not update device for different user",
|
|
requestUser: alice,
|
|
deviceUser: bob,
|
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/devices/", strings.NewReader(`{"display_name":"my new displayname"}`)),
|
|
wantStatusCode: http.StatusNotFound,
|
|
},
|
|
}
|
|
|
|
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
|
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
|
natsInstance := jetstream.NATSInstance{}
|
|
defer close()
|
|
|
|
routers := httputil.NewRouters()
|
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
|
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
|
|
|
|
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
|
|
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
|
|
|
accessTokens := map[*test.User]userDevice{
|
|
alice: {},
|
|
bob: {},
|
|
}
|
|
createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers)
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
dev := accessTokens[tc.requestUser]
|
|
if tc.deviceUser != nil {
|
|
tc.request = httptest.NewRequest(tc.request.Method, tc.request.RequestURI+accessTokens[tc.deviceUser].deviceID, tc.request.Body)
|
|
}
|
|
tc.request.Header.Set("Authorization", "Bearer "+dev.accessToken)
|
|
rec := httptest.NewRecorder()
|
|
routers.Client.ServeHTTP(rec, tc.request)
|
|
if rec.Code != tc.wantStatusCode {
|
|
t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String())
|
|
}
|
|
if tc.wantStatusCode != http.StatusOK && rec.Code != http.StatusOK {
|
|
return
|
|
}
|
|
if tc.validateFunc != nil {
|
|
tc.validateFunc(t, dev, routers)
|
|
}
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
// Deleting devices requires the UIA dance, so do this in a different test
|
|
func TestDeleteDevice(t *testing.T) {
|
|
alice := test.NewUser(t)
|
|
localpart, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
|
|
|
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
|
cfg, processCtx, closeDB := testrig.CreateConfig(t, dbType)
|
|
defer closeDB()
|
|
|
|
natsInstance := jetstream.NATSInstance{}
|
|
routers := httputil.NewRouters()
|
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
|
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
|
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
|
|
|
|
// We mostly need the rsAPI/ for this test, so nil for other APIs/caches etc.
|
|
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
|
|
|
accessTokens := map[*test.User]userDevice{
|
|
alice: {},
|
|
}
|
|
|
|
// create the account and an initial device
|
|
createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers)
|
|
|
|
// create some more devices
|
|
accessToken := util.RandomString(8)
|
|
devRes := &uapi.PerformDeviceCreationResponse{}
|
|
if err := userAPI.PerformDeviceCreation(processCtx.Context(), &uapi.PerformDeviceCreationRequest{
|
|
Localpart: localpart,
|
|
ServerName: serverName,
|
|
AccessToken: accessToken,
|
|
NoDeviceListUpdate: true,
|
|
}, devRes); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !devRes.DeviceCreated {
|
|
t.Fatalf("failed to create device")
|
|
}
|
|
secondDeviceID := devRes.Device.ID
|
|
|
|
// initiate UIA for the second device
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/devices/"+secondDeviceID, strings.NewReader(""))
|
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
|
routers.Client.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusUnauthorized {
|
|
t.Fatalf("expected HTTP 401, got %d: %s", rec.Code, rec.Body.String())
|
|
}
|
|
// get the session ID
|
|
sessionID := gjson.GetBytes(rec.Body.Bytes(), "session").Str
|
|
|
|
// prepare UIA request body
|
|
reqBody := bytes.Buffer{}
|
|
if err := json.NewEncoder(&reqBody).Encode(map[string]interface{}{
|
|
"auth": map[string]string{
|
|
"session": sessionID,
|
|
"type": authtypes.LoginTypePassword,
|
|
"user": alice.ID,
|
|
"password": accessTokens[alice].password,
|
|
},
|
|
}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// copy the request body, so we can use it again for the successful delete
|
|
reqBody2 := reqBody
|
|
|
|
// do the same request again, this time with our UIA, but for a different device ID, this should fail
|
|
rec = httptest.NewRecorder()
|
|
|
|
req = httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/devices/"+accessTokens[alice].deviceID, &reqBody)
|
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
|
routers.Client.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusForbidden {
|
|
t.Fatalf("expected HTTP 403, got %d: %s", rec.Code, rec.Body.String())
|
|
}
|
|
|
|
// do the same request again, this time with our UIA, but for the correct device ID, this should be fine
|
|
rec = httptest.NewRecorder()
|
|
req = httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/devices/"+secondDeviceID, &reqBody2)
|
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
|
routers.Client.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String())
|
|
}
|
|
|
|
// verify devices are deleted
|
|
rec = httptest.NewRecorder()
|
|
req = httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/devices", strings.NewReader(""))
|
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
|
routers.Client.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String())
|
|
}
|
|
for _, device := range gjson.GetBytes(rec.Body.Bytes(), "devices.#.device_id").Array() {
|
|
if device.Str == secondDeviceID {
|
|
t.Fatalf("expected device %s to be deleted, but wasn't", secondDeviceID)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
// Deleting devices requires the UIA dance, so do this in a different test
|
|
func TestDeleteDevices(t *testing.T) {
|
|
alice := test.NewUser(t)
|
|
localpart, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
|
|
|
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
|
cfg, processCtx, closeDB := testrig.CreateConfig(t, dbType)
|
|
defer closeDB()
|
|
|
|
natsInstance := jetstream.NATSInstance{}
|
|
routers := httputil.NewRouters()
|
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
|
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
|
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
|
|
|
|
// We mostly need the rsAPI/ for this test, so nil for other APIs/caches etc.
|
|
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
|
|
|
accessTokens := map[*test.User]userDevice{
|
|
alice: {},
|
|
}
|
|
|
|
// create the account and an initial device
|
|
createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers)
|
|
|
|
// create some more devices
|
|
var devices []string
|
|
for i := 0; i < 10; i++ {
|
|
accessToken := util.RandomString(8)
|
|
devRes := &uapi.PerformDeviceCreationResponse{}
|
|
if err := userAPI.PerformDeviceCreation(processCtx.Context(), &uapi.PerformDeviceCreationRequest{
|
|
Localpart: localpart,
|
|
ServerName: serverName,
|
|
AccessToken: accessToken,
|
|
NoDeviceListUpdate: true,
|
|
}, devRes); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !devRes.DeviceCreated {
|
|
t.Fatalf("failed to create device")
|
|
}
|
|
devices = append(devices, devRes.Device.ID)
|
|
}
|
|
|
|
// initiate UIA
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/delete_devices", strings.NewReader(""))
|
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
|
routers.Client.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusUnauthorized {
|
|
t.Fatalf("expected HTTP 401, got %d: %s", rec.Code, rec.Body.String())
|
|
}
|
|
// get the session ID
|
|
sessionID := gjson.GetBytes(rec.Body.Bytes(), "session").Str
|
|
|
|
// prepare UIA request body
|
|
reqBody := bytes.Buffer{}
|
|
if err := json.NewEncoder(&reqBody).Encode(map[string]interface{}{
|
|
"auth": map[string]string{
|
|
"session": sessionID,
|
|
"type": authtypes.LoginTypePassword,
|
|
"user": alice.ID,
|
|
"password": accessTokens[alice].password,
|
|
},
|
|
"devices": devices[5:],
|
|
}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// do the same request again, this time with our UIA,
|
|
rec = httptest.NewRecorder()
|
|
req = httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/delete_devices", &reqBody)
|
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
|
routers.Client.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String())
|
|
}
|
|
|
|
// verify devices are deleted
|
|
rec = httptest.NewRecorder()
|
|
req = httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/devices", strings.NewReader(""))
|
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
|
routers.Client.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String())
|
|
}
|
|
for _, device := range gjson.GetBytes(rec.Body.Bytes(), "devices.#.device_id").Array() {
|
|
for _, deletedDevice := range devices[5:] {
|
|
if device.Str == deletedDevice {
|
|
t.Fatalf("expected device %s to be deleted, but wasn't", deletedDevice)
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
func createAccessTokens(t *testing.T, accessTokens map[*test.User]userDevice, userAPI uapi.UserInternalAPI, ctx context.Context, routers httputil.Routers) {
|
|
t.Helper()
|
|
for u := range accessTokens {
|
|
localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID)
|
|
userRes := &uapi.PerformAccountCreationResponse{}
|
|
password := util.RandomString(8)
|
|
if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{
|
|
AccountType: u.AccountType,
|
|
Localpart: localpart,
|
|
ServerName: serverName,
|
|
Password: password,
|
|
}, userRes); err != nil {
|
|
t.Errorf("failed to create account: %s", err)
|
|
}
|
|
req := test.NewRequest(t, http.MethodPost, "/_matrix/client/v3/login", test.WithJSONBody(t, map[string]interface{}{
|
|
"type": authtypes.LoginTypePassword,
|
|
"identifier": map[string]interface{}{
|
|
"type": "m.id.user",
|
|
"user": u.ID,
|
|
},
|
|
"password": password,
|
|
}))
|
|
rec := httptest.NewRecorder()
|
|
routers.Client.ServeHTTP(rec, req)
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("failed to login: %s", rec.Body.String())
|
|
}
|
|
accessTokens[u] = userDevice{
|
|
accessToken: gjson.GetBytes(rec.Body.Bytes(), "access_token").String(),
|
|
deviceID: gjson.GetBytes(rec.Body.Bytes(), "device_id").String(),
|
|
password: password,
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestSetDisplayname(t *testing.T) {
|
|
alice := test.NewUser(t)
|
|
bob := test.NewUser(t)
|
|
notLocalUser := &test.User{ID: "@charlie:localhost", Localpart: "charlie"}
|
|
changeDisplayName := "my new display name"
|
|
|
|
testCases := []struct {
|
|
name string
|
|
user *test.User
|
|
wantOK bool
|
|
changeReq io.Reader
|
|
wantDisplayName string
|
|
}{
|
|
{
|
|
name: "invalid user",
|
|
user: &test.User{ID: "!notauser"},
|
|
},
|
|
{
|
|
name: "non-existent user",
|
|
user: &test.User{ID: "@doesnotexist:test"},
|
|
},
|
|
{
|
|
name: "non-local user is not allowed",
|
|
user: notLocalUser,
|
|
},
|
|
{
|
|
name: "existing user is allowed to change own name",
|
|
user: alice,
|
|
wantOK: true,
|
|
wantDisplayName: changeDisplayName,
|
|
},
|
|
{
|
|
name: "existing user is not allowed to change own name if name is empty",
|
|
user: bob,
|
|
wantOK: false,
|
|
wantDisplayName: "",
|
|
},
|
|
}
|
|
|
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
|
cfg, processCtx, closeDB := testrig.CreateConfig(t, dbType)
|
|
defer closeDB()
|
|
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
|
routers := httputil.NewRouters()
|
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
|
natsInstance := &jetstream.NATSInstance{}
|
|
|
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics)
|
|
rsAPI.SetFederationAPI(nil, nil)
|
|
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil)
|
|
asPI := appservice.NewInternalAPI(processCtx, cfg, natsInstance, userAPI, rsAPI)
|
|
|
|
AddPublicRoutes(processCtx, routers, cfg, natsInstance, base.CreateFederationClient(cfg, nil), rsAPI, asPI, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
|
|
|
accessTokens := map[*test.User]userDevice{
|
|
alice: {},
|
|
bob: {},
|
|
}
|
|
|
|
createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers)
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
wantDisplayName := tc.user.Localpart
|
|
if tc.changeReq == nil {
|
|
tc.changeReq = strings.NewReader("")
|
|
}
|
|
|
|
// check profile after initial account creation
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/profile/"+tc.user.ID, strings.NewReader(""))
|
|
t.Logf("%s", req.URL.String())
|
|
routers.Client.ServeHTTP(rec, req)
|
|
|
|
if tc.wantOK && rec.Code != http.StatusOK {
|
|
t.Fatalf("expected HTTP 200, got %d", rec.Code)
|
|
}
|
|
|
|
if gotDisplayName := gjson.GetBytes(rec.Body.Bytes(), "displayname").Str; tc.wantOK && gotDisplayName != wantDisplayName {
|
|
t.Fatalf("expected displayname to be '%s', but got '%s'", wantDisplayName, gotDisplayName)
|
|
}
|
|
|
|
// now set the new display name
|
|
wantDisplayName = tc.wantDisplayName
|
|
tc.changeReq = strings.NewReader(fmt.Sprintf(`{"displayname":"%s"}`, tc.wantDisplayName))
|
|
|
|
rec = httptest.NewRecorder()
|
|
req = httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/profile/"+tc.user.ID+"/displayname", tc.changeReq)
|
|
req.Header.Set("Authorization", "Bearer "+accessTokens[tc.user].accessToken)
|
|
|
|
routers.Client.ServeHTTP(rec, req)
|
|
if tc.wantOK && rec.Code != http.StatusOK {
|
|
t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String())
|
|
}
|
|
|
|
// now only get the display name
|
|
rec = httptest.NewRecorder()
|
|
req = httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/profile/"+tc.user.ID+"/displayname", strings.NewReader(""))
|
|
|
|
routers.Client.ServeHTTP(rec, req)
|
|
if tc.wantOK && rec.Code != http.StatusOK {
|
|
t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String())
|
|
}
|
|
|
|
if gotDisplayName := gjson.GetBytes(rec.Body.Bytes(), "displayname").Str; tc.wantOK && gotDisplayName != wantDisplayName {
|
|
t.Fatalf("expected displayname to be '%s', but got '%s'", wantDisplayName, gotDisplayName)
|
|
}
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestSetAvatarURL(t *testing.T) {
|
|
alice := test.NewUser(t)
|
|
bob := test.NewUser(t)
|
|
notLocalUser := &test.User{ID: "@charlie:localhost", Localpart: "charlie"}
|
|
changeDisplayName := "mxc://newMXID"
|
|
|
|
testCases := []struct {
|
|
name string
|
|
user *test.User
|
|
wantOK bool
|
|
changeReq io.Reader
|
|
avatar_url string
|
|
}{
|
|
{
|
|
name: "invalid user",
|
|
user: &test.User{ID: "!notauser"},
|
|
},
|
|
{
|
|
name: "non-existent user",
|
|
user: &test.User{ID: "@doesnotexist:test"},
|
|
},
|
|
{
|
|
name: "non-local user is not allowed",
|
|
user: notLocalUser,
|
|
},
|
|
{
|
|
name: "existing user is allowed to change own avatar",
|
|
user: alice,
|
|
wantOK: true,
|
|
avatar_url: changeDisplayName,
|
|
},
|
|
{
|
|
name: "existing user is not allowed to change own avatar if avatar is empty",
|
|
user: bob,
|
|
wantOK: false,
|
|
avatar_url: "",
|
|
},
|
|
}
|
|
|
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
|
cfg, processCtx, closeDB := testrig.CreateConfig(t, dbType)
|
|
defer closeDB()
|
|
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
|
routers := httputil.NewRouters()
|
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
|
natsInstance := &jetstream.NATSInstance{}
|
|
|
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics)
|
|
rsAPI.SetFederationAPI(nil, nil)
|
|
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, natsInstance, rsAPI, nil)
|
|
asPI := appservice.NewInternalAPI(processCtx, cfg, natsInstance, userAPI, rsAPI)
|
|
|
|
AddPublicRoutes(processCtx, routers, cfg, natsInstance, base.CreateFederationClient(cfg, nil), rsAPI, asPI, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
|
|
|
accessTokens := map[*test.User]userDevice{
|
|
alice: {},
|
|
bob: {},
|
|
}
|
|
|
|
createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers)
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
wantAvatarURL := ""
|
|
if tc.changeReq == nil {
|
|
tc.changeReq = strings.NewReader("")
|
|
}
|
|
|
|
// check profile after initial account creation
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/profile/"+tc.user.ID, strings.NewReader(""))
|
|
t.Logf("%s", req.URL.String())
|
|
routers.Client.ServeHTTP(rec, req)
|
|
|
|
if tc.wantOK && rec.Code != http.StatusOK {
|
|
t.Fatalf("expected HTTP 200, got %d", rec.Code)
|
|
}
|
|
|
|
if gotDisplayName := gjson.GetBytes(rec.Body.Bytes(), "avatar_url").Str; tc.wantOK && gotDisplayName != wantAvatarURL {
|
|
t.Fatalf("expected displayname to be '%s', but got '%s'", wantAvatarURL, gotDisplayName)
|
|
}
|
|
|
|
// now set the new display name
|
|
wantAvatarURL = tc.avatar_url
|
|
tc.changeReq = strings.NewReader(fmt.Sprintf(`{"avatar_url":"%s"}`, tc.avatar_url))
|
|
|
|
rec = httptest.NewRecorder()
|
|
req = httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/profile/"+tc.user.ID+"/avatar_url", tc.changeReq)
|
|
req.Header.Set("Authorization", "Bearer "+accessTokens[tc.user].accessToken)
|
|
|
|
routers.Client.ServeHTTP(rec, req)
|
|
if tc.wantOK && rec.Code != http.StatusOK {
|
|
t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String())
|
|
}
|
|
|
|
// now only get the display name
|
|
rec = httptest.NewRecorder()
|
|
req = httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/profile/"+tc.user.ID+"/avatar_url", strings.NewReader(""))
|
|
|
|
routers.Client.ServeHTTP(rec, req)
|
|
if tc.wantOK && rec.Code != http.StatusOK {
|
|
t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String())
|
|
}
|
|
|
|
if gotDisplayName := gjson.GetBytes(rec.Body.Bytes(), "avatar_url").Str; tc.wantOK && gotDisplayName != wantAvatarURL {
|
|
t.Fatalf("expected displayname to be '%s', but got '%s'", wantAvatarURL, gotDisplayName)
|
|
}
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestTyping(t *testing.T) {
|
|
alice := test.NewUser(t)
|
|
room := test.NewRoom(t, alice)
|
|
ctx := context.Background()
|
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
|
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
|
defer close()
|
|
natsInstance := jetstream.NATSInstance{}
|
|
|
|
routers := httputil.NewRouters()
|
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
|
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
|
rsAPI.SetFederationAPI(nil, nil)
|
|
// Needed to create accounts
|
|
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
|
|
// We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc.
|
|
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
|
|
|
// Create the users in the userapi and login
|
|
accessTokens := map[*test.User]userDevice{
|
|
alice: {},
|
|
}
|
|
createAccessTokens(t, accessTokens, userAPI, ctx, routers)
|
|
|
|
// Create the room
|
|
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
testCases := []struct {
|
|
name string
|
|
typingForUser string
|
|
roomID string
|
|
requestBody io.Reader
|
|
wantOK bool
|
|
}{
|
|
{
|
|
name: "can not set typing for different user",
|
|
typingForUser: "@notourself:test",
|
|
roomID: room.ID,
|
|
requestBody: strings.NewReader(""),
|
|
},
|
|
{
|
|
name: "invalid request body",
|
|
typingForUser: alice.ID,
|
|
roomID: room.ID,
|
|
requestBody: strings.NewReader(""),
|
|
},
|
|
{
|
|
name: "non-existent room",
|
|
typingForUser: alice.ID,
|
|
roomID: "!doesnotexist:test",
|
|
},
|
|
{
|
|
name: "invalid room ID",
|
|
typingForUser: alice.ID,
|
|
roomID: "@notaroomid:test",
|
|
},
|
|
{
|
|
name: "allowed to set own typing status",
|
|
typingForUser: alice.ID,
|
|
roomID: room.ID,
|
|
requestBody: strings.NewReader(`{"typing":true}`),
|
|
wantOK: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/rooms/"+tc.roomID+"/typing/"+tc.typingForUser, tc.requestBody)
|
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
|
routers.Client.ServeHTTP(rec, req)
|
|
if tc.wantOK && rec.Code != http.StatusOK {
|
|
t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String())
|
|
}
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestMembership(t *testing.T) {
|
|
alice := test.NewUser(t)
|
|
bob := test.NewUser(t)
|
|
room := test.NewRoom(t, alice)
|
|
ctx := context.Background()
|
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
|
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
|
cfg.ClientAPI.RateLimiting.Enabled = false
|
|
defer close()
|
|
natsInstance := jetstream.NATSInstance{}
|
|
|
|
routers := httputil.NewRouters()
|
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
|
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
|
rsAPI.SetFederationAPI(nil, nil)
|
|
// Needed to create accounts
|
|
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
|
|
rsAPI.SetUserAPI(userAPI)
|
|
// We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc.
|
|
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
|
|
|
// Create the users in the userapi and login
|
|
accessTokens := map[*test.User]userDevice{
|
|
alice: {},
|
|
bob: {},
|
|
}
|
|
createAccessTokens(t, accessTokens, userAPI, ctx, routers)
|
|
|
|
// Create the room
|
|
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
invalidBodyRequest := func(roomID, membershipType string) *http.Request {
|
|
return httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", roomID, membershipType), strings.NewReader(""))
|
|
}
|
|
|
|
missingUserIDRequest := func(roomID, membershipType string) *http.Request {
|
|
return httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", roomID, membershipType), strings.NewReader("{}"))
|
|
}
|
|
|
|
testCases := []struct {
|
|
name string
|
|
roomID string
|
|
request *http.Request
|
|
wantOK bool
|
|
asUser *test.User
|
|
}{
|
|
{
|
|
name: "ban - invalid request body",
|
|
request: invalidBodyRequest(room.ID, "ban"),
|
|
},
|
|
{
|
|
name: "kick - invalid request body",
|
|
request: invalidBodyRequest(room.ID, "kick"),
|
|
},
|
|
{
|
|
name: "unban - invalid request body",
|
|
request: invalidBodyRequest(room.ID, "unban"),
|
|
},
|
|
{
|
|
name: "invite - invalid request body",
|
|
request: invalidBodyRequest(room.ID, "invite"),
|
|
},
|
|
{
|
|
name: "ban - missing user_id body",
|
|
request: missingUserIDRequest(room.ID, "ban"),
|
|
},
|
|
{
|
|
name: "kick - missing user_id body",
|
|
request: missingUserIDRequest(room.ID, "kick"),
|
|
},
|
|
{
|
|
name: "unban - missing user_id body",
|
|
request: missingUserIDRequest(room.ID, "unban"),
|
|
},
|
|
{
|
|
name: "invite - missing user_id body",
|
|
request: missingUserIDRequest(room.ID, "invite"),
|
|
},
|
|
{
|
|
name: "Bob forgets invalid room",
|
|
request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", "!doesnotexist", "forget"), strings.NewReader("")),
|
|
asUser: bob,
|
|
},
|
|
{
|
|
name: "Alice can not ban Bob in non-existent room", // fails because "not joined"
|
|
request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", "!doesnotexist:test", "ban"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))),
|
|
},
|
|
{
|
|
name: "Alice can not kick Bob in non-existent room", // fails because "not joined"
|
|
request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", "!doesnotexist:test", "kick"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))),
|
|
},
|
|
// the following must run in sequence, as they build up on each other
|
|
{
|
|
name: "Alice invites Bob",
|
|
request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "invite"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))),
|
|
wantOK: true,
|
|
},
|
|
{
|
|
name: "Bob accepts invite",
|
|
request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "join"), strings.NewReader("")),
|
|
wantOK: true,
|
|
asUser: bob,
|
|
},
|
|
{
|
|
name: "Alice verifies that Bob is joined", // returns an error if no membership event can be found
|
|
request: httptest.NewRequest(http.MethodGet, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s/m.room.member/%s", room.ID, "state", bob.ID), strings.NewReader("")),
|
|
wantOK: true,
|
|
},
|
|
{
|
|
name: "Bob forgets the room but is still a member",
|
|
request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "forget"), strings.NewReader("")),
|
|
wantOK: false, // user is still in the room
|
|
asUser: bob,
|
|
},
|
|
{
|
|
name: "Bob can not kick Alice",
|
|
request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "kick"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, alice.ID))),
|
|
wantOK: false, // powerlevel too low
|
|
asUser: bob,
|
|
},
|
|
{
|
|
name: "Bob can not ban Alice",
|
|
request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "ban"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, alice.ID))),
|
|
wantOK: false, // powerlevel too low
|
|
asUser: bob,
|
|
},
|
|
{
|
|
name: "Alice can kick Bob",
|
|
request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "kick"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))),
|
|
wantOK: true,
|
|
},
|
|
{
|
|
name: "Alice can ban Bob",
|
|
request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "ban"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))),
|
|
wantOK: true,
|
|
},
|
|
{
|
|
name: "Alice can not kick Bob again",
|
|
request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "kick"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))),
|
|
wantOK: false, // can not kick banned/left user
|
|
},
|
|
{
|
|
name: "Bob can not unban himself", // mostly because of not being a member of the room
|
|
request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "unban"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))),
|
|
asUser: bob,
|
|
},
|
|
{
|
|
name: "Alice can not invite Bob again",
|
|
request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "invite"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))),
|
|
wantOK: false, // user still banned
|
|
},
|
|
{
|
|
name: "Alice can unban Bob",
|
|
request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "unban"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))),
|
|
wantOK: true,
|
|
},
|
|
{
|
|
name: "Alice can not unban Bob again",
|
|
request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "unban"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))),
|
|
wantOK: false,
|
|
},
|
|
{
|
|
name: "Alice can invite Bob again",
|
|
request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "invite"), strings.NewReader(fmt.Sprintf(`{"user_id":"%s"}`, bob.ID))),
|
|
wantOK: true,
|
|
},
|
|
{
|
|
name: "Bob can reject the invite by leaving",
|
|
request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "leave"), strings.NewReader("")),
|
|
wantOK: true,
|
|
asUser: bob,
|
|
},
|
|
{
|
|
name: "Bob can forget the room",
|
|
request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "forget"), strings.NewReader("")),
|
|
wantOK: true,
|
|
asUser: bob,
|
|
},
|
|
{
|
|
name: "Bob can forget the room again",
|
|
request: httptest.NewRequest(http.MethodPost, fmt.Sprintf("/_matrix/client/v3/rooms/%s/%s", room.ID, "forget"), strings.NewReader("")),
|
|
wantOK: true,
|
|
asUser: bob,
|
|
},
|
|
// END must run in sequence
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
if tc.asUser == nil {
|
|
tc.asUser = alice
|
|
}
|
|
rec := httptest.NewRecorder()
|
|
tc.request.Header.Set("Authorization", "Bearer "+accessTokens[tc.asUser].accessToken)
|
|
routers.Client.ServeHTTP(rec, tc.request)
|
|
if tc.wantOK && rec.Code != http.StatusOK {
|
|
t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String())
|
|
}
|
|
if !tc.wantOK && rec.Code == http.StatusOK {
|
|
t.Fatalf("expected request to fail, but didn't: %s", rec.Body.String())
|
|
}
|
|
t.Logf("%s", rec.Body.String())
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestCapabilities(t *testing.T) {
|
|
alice := test.NewUser(t)
|
|
ctx := context.Background()
|
|
|
|
// construct the expected result
|
|
versionsMap := map[gomatrixserverlib.RoomVersion]string{}
|
|
for v, desc := range version.SupportedRoomVersions() {
|
|
if desc.Stable() {
|
|
versionsMap[v] = "stable"
|
|
} else {
|
|
versionsMap[v] = "unstable"
|
|
}
|
|
}
|
|
|
|
expectedMap := map[string]interface{}{
|
|
"capabilities": map[string]interface{}{
|
|
"m.change_password": map[string]bool{
|
|
"enabled": true,
|
|
},
|
|
"m.room_versions": map[string]interface{}{
|
|
"default": version.DefaultRoomVersion(),
|
|
"available": versionsMap,
|
|
},
|
|
},
|
|
}
|
|
|
|
expectedBuf := &bytes.Buffer{}
|
|
err := json.NewEncoder(expectedBuf).Encode(expectedMap)
|
|
assert.NoError(t, err)
|
|
|
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
|
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
|
cfg.ClientAPI.RateLimiting.Enabled = false
|
|
defer close()
|
|
natsInstance := jetstream.NATSInstance{}
|
|
|
|
routers := httputil.NewRouters()
|
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
|
|
|
// Needed to create accounts
|
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, caching.DisableMetrics)
|
|
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
|
|
// We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc.
|
|
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
|
|
|
// Create the users in the userapi and login
|
|
accessTokens := map[*test.User]userDevice{
|
|
alice: {},
|
|
}
|
|
createAccessTokens(t, accessTokens, userAPI, ctx, routers)
|
|
|
|
testCases := []struct {
|
|
name string
|
|
request *http.Request
|
|
}{
|
|
{
|
|
name: "can get capabilities",
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/capabilities", strings.NewReader("")),
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
rec := httptest.NewRecorder()
|
|
tc.request.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
|
routers.Client.ServeHTTP(rec, tc.request)
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
assert.ObjectsAreEqual(expectedBuf.Bytes(), rec.Body.Bytes())
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestTurnserver(t *testing.T) {
|
|
alice := test.NewUser(t)
|
|
ctx := context.Background()
|
|
|
|
cfg, processCtx, close := testrig.CreateConfig(t, test.DBTypeSQLite)
|
|
cfg.ClientAPI.RateLimiting.Enabled = false
|
|
defer close()
|
|
natsInstance := jetstream.NATSInstance{}
|
|
|
|
routers := httputil.NewRouters()
|
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
|
|
|
// Needed to create accounts
|
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, caching.DisableMetrics)
|
|
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
|
|
//rsAPI.SetUserAPI(userAPI)
|
|
// We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc.
|
|
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
|
|
|
// Create the users in the userapi and login
|
|
accessTokens := map[*test.User]userDevice{
|
|
alice: {},
|
|
}
|
|
createAccessTokens(t, accessTokens, userAPI, ctx, routers)
|
|
|
|
testCases := []struct {
|
|
name string
|
|
turnConfig config.TURN
|
|
wantEmptyResponse bool
|
|
}{
|
|
{
|
|
name: "no turn server configured",
|
|
wantEmptyResponse: true,
|
|
},
|
|
{
|
|
name: "servers configured but not userLifeTime",
|
|
wantEmptyResponse: true,
|
|
turnConfig: config.TURN{URIs: []string{""}},
|
|
},
|
|
{
|
|
name: "missing sharedSecret/username/password",
|
|
wantEmptyResponse: true,
|
|
turnConfig: config.TURN{URIs: []string{""}, UserLifetime: "1m"},
|
|
},
|
|
{
|
|
name: "with shared secret",
|
|
turnConfig: config.TURN{URIs: []string{""}, UserLifetime: "1m", SharedSecret: "iAmSecret"},
|
|
},
|
|
{
|
|
name: "with username/password secret",
|
|
turnConfig: config.TURN{URIs: []string{""}, UserLifetime: "1m", Username: "username", Password: "iAmSecret"},
|
|
},
|
|
{
|
|
name: "only username set",
|
|
turnConfig: config.TURN{URIs: []string{""}, UserLifetime: "1m", Username: "username"},
|
|
wantEmptyResponse: true,
|
|
},
|
|
{
|
|
name: "only password set",
|
|
turnConfig: config.TURN{URIs: []string{""}, UserLifetime: "1m", Username: "username"},
|
|
wantEmptyResponse: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/voip/turnServer", strings.NewReader(""))
|
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
|
cfg.ClientAPI.TURN = tc.turnConfig
|
|
routers.Client.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
|
|
if tc.wantEmptyResponse && rec.Body.String() != "{}" {
|
|
t.Fatalf("expected an empty response, but got %s", rec.Body.String())
|
|
}
|
|
if !tc.wantEmptyResponse {
|
|
assert.NotEqual(t, "{}", rec.Body.String())
|
|
|
|
resp := gomatrix.RespTurnServer{}
|
|
err := json.NewDecoder(rec.Body).Decode(&resp)
|
|
assert.NoError(t, err)
|
|
|
|
duration, _ := time.ParseDuration(tc.turnConfig.UserLifetime)
|
|
assert.Equal(t, tc.turnConfig.URIs, resp.URIs)
|
|
assert.Equal(t, int(duration.Seconds()), resp.TTL)
|
|
if tc.turnConfig.Username != "" && tc.turnConfig.Password != "" {
|
|
assert.Equal(t, tc.turnConfig.Username, resp.Username)
|
|
assert.Equal(t, tc.turnConfig.Password, resp.Password)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test3PID(t *testing.T) {
|
|
alice := test.NewUser(t)
|
|
ctx := context.Background()
|
|
|
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
|
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
|
cfg.ClientAPI.RateLimiting.Enabled = false
|
|
cfg.FederationAPI.DisableTLSValidation = true // needed to be able to connect to our identityServer below
|
|
defer close()
|
|
natsInstance := jetstream.NATSInstance{}
|
|
|
|
routers := httputil.NewRouters()
|
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
|
|
|
// Needed to create accounts
|
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, caching.DisableMetrics)
|
|
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
|
|
// We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc.
|
|
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
|
|
|
// Create the users in the userapi and login
|
|
accessTokens := map[*test.User]userDevice{
|
|
alice: {},
|
|
}
|
|
createAccessTokens(t, accessTokens, userAPI, ctx, routers)
|
|
|
|
identityServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch {
|
|
case strings.Contains(r.URL.String(), "getValidated3pid"):
|
|
resp := threepid.GetValidatedResponse{}
|
|
switch r.URL.Query().Get("client_secret") {
|
|
case "fail":
|
|
resp.ErrCode = string(spec.ErrorSessionNotValidated)
|
|
case "fail2":
|
|
resp.ErrCode = "some other error"
|
|
case "fail3":
|
|
_, _ = w.Write([]byte("{invalidJson"))
|
|
return
|
|
case "success":
|
|
resp.Medium = "email"
|
|
case "success2":
|
|
resp.Medium = "email"
|
|
resp.Address = "somerandom@address.com"
|
|
}
|
|
_ = json.NewEncoder(w).Encode(resp)
|
|
case strings.Contains(r.URL.String(), "requestToken"):
|
|
resp := threepid.SID{SID: "randomSID"}
|
|
_ = json.NewEncoder(w).Encode(resp)
|
|
}
|
|
}))
|
|
defer identityServer.Close()
|
|
|
|
identityServerBase := strings.TrimPrefix(identityServer.URL, "https://")
|
|
|
|
testCases := []struct {
|
|
name string
|
|
request *http.Request
|
|
wantOK bool
|
|
setTrustedServer bool
|
|
wantLen3PIDs int
|
|
}{
|
|
{
|
|
name: "can get associated threepid info",
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/account/3pid", strings.NewReader("")),
|
|
wantOK: true,
|
|
},
|
|
{
|
|
name: "can not set threepid info with invalid JSON",
|
|
request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid", strings.NewReader("")),
|
|
},
|
|
{
|
|
name: "can not set threepid info with untrusted server",
|
|
request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid", strings.NewReader("{}")),
|
|
},
|
|
{
|
|
name: "can check threepid info with trusted server, but unverified",
|
|
request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid", strings.NewReader(fmt.Sprintf(`{"three_pid_creds":{"id_server":"%s","client_secret":"fail"}}`, identityServerBase))),
|
|
setTrustedServer: true,
|
|
wantOK: false,
|
|
},
|
|
{
|
|
name: "can check threepid info with trusted server, but fails for some other reason",
|
|
request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid", strings.NewReader(fmt.Sprintf(`{"three_pid_creds":{"id_server":"%s","client_secret":"fail2"}}`, identityServerBase))),
|
|
setTrustedServer: true,
|
|
wantOK: false,
|
|
},
|
|
{
|
|
name: "can check threepid info with trusted server, but fails because of invalid json",
|
|
request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid", strings.NewReader(fmt.Sprintf(`{"three_pid_creds":{"id_server":"%s","client_secret":"fail3"}}`, identityServerBase))),
|
|
setTrustedServer: true,
|
|
wantOK: false,
|
|
},
|
|
{
|
|
name: "can save threepid info with trusted server",
|
|
request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid", strings.NewReader(fmt.Sprintf(`{"three_pid_creds":{"id_server":"%s","client_secret":"success"}}`, identityServerBase))),
|
|
setTrustedServer: true,
|
|
wantOK: true,
|
|
},
|
|
{
|
|
name: "can save threepid info with trusted server using bind=true",
|
|
request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid", strings.NewReader(fmt.Sprintf(`{"three_pid_creds":{"id_server":"%s","client_secret":"success2"},"bind":true}`, identityServerBase))),
|
|
setTrustedServer: true,
|
|
wantOK: true,
|
|
},
|
|
{
|
|
name: "can get associated threepid info again",
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/account/3pid", strings.NewReader("")),
|
|
wantOK: true,
|
|
wantLen3PIDs: 2,
|
|
},
|
|
{
|
|
name: "can delete associated threepid info",
|
|
request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid/delete", strings.NewReader(`{"medium":"email","address":"somerandom@address.com"}`)),
|
|
wantOK: true,
|
|
},
|
|
{
|
|
name: "can get associated threepid after deleting association",
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/account/3pid", strings.NewReader("")),
|
|
wantOK: true,
|
|
wantLen3PIDs: 1,
|
|
},
|
|
{
|
|
name: "can not request emailToken with invalid request body",
|
|
request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid/email/requestToken", strings.NewReader("")),
|
|
},
|
|
{
|
|
name: "can not request emailToken for in use address",
|
|
request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid/email/requestToken", strings.NewReader(fmt.Sprintf(`{"client_secret":"somesecret","email":"","send_attempt":1,"id_server":"%s"}`, identityServerBase))),
|
|
},
|
|
{
|
|
name: "can request emailToken",
|
|
request: httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/account/3pid/email/requestToken", strings.NewReader(fmt.Sprintf(`{"client_secret":"somesecret","email":"somerandom@address.com","send_attempt":1,"id_server":"%s"}`, identityServerBase))),
|
|
wantOK: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
|
|
if tc.setTrustedServer {
|
|
cfg.Global.TrustedIDServers = []string{identityServerBase}
|
|
}
|
|
|
|
rec := httptest.NewRecorder()
|
|
tc.request.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
|
|
|
routers.Client.ServeHTTP(rec, tc.request)
|
|
t.Logf("Response: %s", rec.Body.String())
|
|
if tc.wantOK && rec.Code != http.StatusOK {
|
|
t.Fatalf("expected HTTP 200, got %d: %s", rec.Code, rec.Body.String())
|
|
}
|
|
if !tc.wantOK && rec.Code == http.StatusOK {
|
|
t.Fatalf("expected request to fail, but didn't: %s", rec.Body.String())
|
|
}
|
|
if tc.wantLen3PIDs > 0 {
|
|
var resp routing.ThreePIDsResponse
|
|
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(resp.ThreePIDs) != tc.wantLen3PIDs {
|
|
t.Fatalf("expected %d threepids, got %d", tc.wantLen3PIDs, len(resp.ThreePIDs))
|
|
}
|
|
}
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestPushRules(t *testing.T) {
|
|
alice := test.NewUser(t)
|
|
|
|
// create the default push rules, used when validating responses
|
|
localpart, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID)
|
|
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
|
|
defaultRules, err := json.Marshal(pushRuleSets)
|
|
assert.NoError(t, err)
|
|
|
|
ruleID1 := "myrule"
|
|
ruleID2 := "myrule2"
|
|
ruleID3 := "myrule3"
|
|
|
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
|
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
|
cfg.ClientAPI.RateLimiting.Enabled = false
|
|
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
|
natsInstance := jetstream.NATSInstance{}
|
|
defer close()
|
|
|
|
routers := httputil.NewRouters()
|
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
|
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
|
|
|
|
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
|
|
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
|
|
|
accessTokens := map[*test.User]userDevice{
|
|
alice: {},
|
|
}
|
|
createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers)
|
|
|
|
testCases := []struct {
|
|
name string
|
|
request *http.Request
|
|
wantStatusCode int
|
|
validateFunc func(t *testing.T, respBody *bytes.Buffer) // used when updating rules, otherwise wantStatusCode should be enough
|
|
queryAttr map[string]string
|
|
}{
|
|
{
|
|
name: "can not get rules without trailing slash",
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules", strings.NewReader("")),
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "can get default rules",
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/", strings.NewReader("")),
|
|
wantStatusCode: http.StatusOK,
|
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
|
assert.Equal(t, defaultRules, respBody.Bytes())
|
|
},
|
|
},
|
|
{
|
|
name: "can get rules by scope",
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/", strings.NewReader("")),
|
|
wantStatusCode: http.StatusOK,
|
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
|
assert.Equal(t, gjson.GetBytes(defaultRules, "global").Raw, respBody.String())
|
|
},
|
|
},
|
|
{
|
|
name: "can not get invalid rules by scope",
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/doesnotexist/", strings.NewReader("")),
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "can not get rules for invalid scope and kind",
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/doesnotexist/invalid/", strings.NewReader("")),
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "can not get rules for invalid kind",
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/invalid/", strings.NewReader("")),
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "can get rules by scope and kind",
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/", strings.NewReader("")),
|
|
wantStatusCode: http.StatusOK,
|
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
|
assert.Equal(t, gjson.GetBytes(defaultRules, "global.override").Raw, respBody.String())
|
|
},
|
|
},
|
|
{
|
|
name: "can get rules by scope and content kind",
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader("")),
|
|
wantStatusCode: http.StatusOK,
|
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
|
assert.Equal(t, gjson.GetBytes(defaultRules, "global.content").Raw, respBody.String())
|
|
},
|
|
},
|
|
{
|
|
name: "can not get rules by scope and room kind",
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/room/", strings.NewReader("")),
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "can not get rules by scope and sender kind",
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/sender/", strings.NewReader("")),
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "can get rules by scope and underride kind",
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/underride/", strings.NewReader("")),
|
|
wantStatusCode: http.StatusOK,
|
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
|
assert.Equal(t, gjson.GetBytes(defaultRules, "global.underride").Raw, respBody.String())
|
|
},
|
|
},
|
|
{
|
|
name: "can not get rules by scope, kind and ID for invalid scope",
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/doesnotexist/doesnotexist/.m.rule.master", strings.NewReader("")),
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "can not get rules by scope, kind and ID for invalid kind",
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/doesnotexist/.m.rule.master", strings.NewReader("")),
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "can get rules by scope, kind and ID",
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master", strings.NewReader("")),
|
|
wantStatusCode: http.StatusOK,
|
|
},
|
|
{
|
|
name: "can not get rules by scope, kind and ID for invalid ID",
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.doesnotexist", strings.NewReader("")),
|
|
wantStatusCode: http.StatusNotFound,
|
|
},
|
|
{
|
|
name: "can not get status for invalid attribute",
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/invalid", strings.NewReader("")),
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "can not get status for invalid kind",
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/invalid/.m.rule.master/enabled", strings.NewReader("")),
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "can not get enabled status for invalid scope",
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/invalid/override/.m.rule.master/enabled", strings.NewReader("")),
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "can not get enabled status for invalid rule",
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/doesnotexist/enabled", strings.NewReader("")),
|
|
wantStatusCode: http.StatusNotFound,
|
|
},
|
|
{
|
|
name: "can get enabled rules by scope, kind and ID",
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader("")),
|
|
wantStatusCode: http.StatusOK,
|
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
|
assert.False(t, gjson.GetBytes(respBody.Bytes(), "enabled").Bool(), "expected master rule to be disabled")
|
|
},
|
|
},
|
|
{
|
|
name: "can get actions scope, kind and ID",
|
|
request: httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/actions", strings.NewReader("")),
|
|
wantStatusCode: http.StatusOK,
|
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
|
actions := gjson.GetBytes(respBody.Bytes(), "actions").Array()
|
|
// only a basic check
|
|
assert.Equal(t, 1, len(actions))
|
|
},
|
|
},
|
|
{
|
|
name: "can not set enabled status with invalid JSON",
|
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader("")),
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "can not set attribute for invalid attribute",
|
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/doesnotexist", strings.NewReader("{}")),
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "can not set attribute for invalid scope",
|
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/invalid/override/.m.rule.master/enabled", strings.NewReader("{}")),
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "can not set attribute for invalid kind",
|
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/invalid/.m.rule.master/enabled", strings.NewReader("{}")),
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "can not set attribute for invalid rule",
|
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/invalid/enabled", strings.NewReader("{}")),
|
|
wantStatusCode: http.StatusNotFound,
|
|
},
|
|
{
|
|
name: "can set enabled status with valid JSON",
|
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader(`{"enabled":true}`)),
|
|
wantStatusCode: http.StatusOK,
|
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/enabled", strings.NewReader(""))
|
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
|
routers.Client.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
|
|
assert.True(t, gjson.GetBytes(rec.Body.Bytes(), "enabled").Bool(), "expected master rule to be enabled: %s", rec.Body.String())
|
|
},
|
|
},
|
|
{
|
|
name: "can set actions with valid JSON",
|
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/actions", strings.NewReader(`{"actions":["dont_notify","notify"]}`)),
|
|
wantStatusCode: http.StatusOK,
|
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/override/.m.rule.master/actions", strings.NewReader(""))
|
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
|
routers.Client.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
|
|
assert.Equal(t, 2, len(gjson.GetBytes(rec.Body.Bytes(), "actions").Array()), "expected 2 actions %s", rec.Body.String())
|
|
},
|
|
},
|
|
{
|
|
name: "can not create new push rule with invalid JSON",
|
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule", strings.NewReader("")),
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "can not create new push rule with invalid rule content",
|
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule", strings.NewReader("{}")),
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "can not create new push rule with invalid scope",
|
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/invalid/content/myrule", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)),
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "can create new push rule with valid rule content",
|
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)),
|
|
wantStatusCode: http.StatusOK,
|
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/myrule/actions", strings.NewReader(""))
|
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
|
routers.Client.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
|
|
assert.Equal(t, 1, len(gjson.GetBytes(rec.Body.Bytes(), "actions").Array()), "expected 1 action %s", rec.Body.String())
|
|
},
|
|
},
|
|
{
|
|
name: "can not create new push starting with a dot",
|
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/.myrule", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)),
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "can create new push rule after existing",
|
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)),
|
|
queryAttr: map[string]string{
|
|
"after": ruleID1,
|
|
},
|
|
wantStatusCode: http.StatusOK,
|
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader(""))
|
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
|
routers.Client.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
|
|
rules := gjson.ParseBytes(rec.Body.Bytes())
|
|
for i, rule := range rules.Array() {
|
|
if rule.Get("rule_id").Str == ruleID1 && i != 0 {
|
|
t.Fatalf("expected '%s' to be the first, but wasn't", ruleID1)
|
|
}
|
|
if rule.Get("rule_id").Str == ruleID2 && i != 1 {
|
|
t.Fatalf("expected '%s' to be the second, but wasn't", ruleID2)
|
|
}
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "can create new push rule before existing",
|
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule3", strings.NewReader(`{"actions":["notify"],"pattern":"world"}`)),
|
|
queryAttr: map[string]string{
|
|
"before": ruleID1,
|
|
},
|
|
wantStatusCode: http.StatusOK,
|
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader(""))
|
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
|
routers.Client.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
|
|
rules := gjson.ParseBytes(rec.Body.Bytes())
|
|
for i, rule := range rules.Array() {
|
|
if rule.Get("rule_id").Str == ruleID3 && i != 0 {
|
|
t.Fatalf("expected '%s' to be the first, but wasn't", ruleID3)
|
|
}
|
|
if rule.Get("rule_id").Str == ruleID1 && i != 1 {
|
|
t.Fatalf("expected '%s' to be the second, but wasn't", ruleID1)
|
|
}
|
|
if rule.Get("rule_id").Str == ruleID2 && i != 2 {
|
|
t.Fatalf("expected '%s' to be the third, but wasn't", ruleID1)
|
|
}
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "can modify existing push rule",
|
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader(`{"actions":["dont_notify"],"pattern":"world"}`)),
|
|
wantStatusCode: http.StatusOK,
|
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/myrule2/actions", strings.NewReader(""))
|
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
|
routers.Client.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
|
|
actions := gjson.GetBytes(rec.Body.Bytes(), "actions").Array()
|
|
// there should only be one action
|
|
assert.Equal(t, "dont_notify", actions[0].Str)
|
|
},
|
|
},
|
|
{
|
|
name: "can move existing push rule to the front",
|
|
request: httptest.NewRequest(http.MethodPut, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader(`{"actions":["dont_notify"],"pattern":"world"}`)),
|
|
queryAttr: map[string]string{
|
|
"before": ruleID3,
|
|
},
|
|
wantStatusCode: http.StatusOK,
|
|
validateFunc: func(t *testing.T, respBody *bytes.Buffer) {
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/pushrules/global/content/", strings.NewReader(""))
|
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
|
routers.Client.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusOK, rec.Code, rec.Body.String())
|
|
rules := gjson.ParseBytes(rec.Body.Bytes())
|
|
for i, rule := range rules.Array() {
|
|
if rule.Get("rule_id").Str == ruleID2 && i != 0 {
|
|
t.Fatalf("expected '%s' to be the first, but wasn't", ruleID2)
|
|
}
|
|
if rule.Get("rule_id").Str == ruleID3 && i != 1 {
|
|
t.Fatalf("expected '%s' to be the second, but wasn't", ruleID3)
|
|
}
|
|
if rule.Get("rule_id").Str == ruleID1 && i != 2 {
|
|
t.Fatalf("expected '%s' to be the third, but wasn't", ruleID1)
|
|
}
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "can not delete push rule with invalid scope",
|
|
request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/invalid/content/myrule2", strings.NewReader("")),
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "can not delete push rule with invalid kind",
|
|
request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/global/invalid/myrule2", strings.NewReader("")),
|
|
wantStatusCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "can not delete push rule with non-existent rule",
|
|
request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/global/content/doesnotexist", strings.NewReader("")),
|
|
wantStatusCode: http.StatusNotFound,
|
|
},
|
|
{
|
|
name: "can delete existing push rule",
|
|
request: httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/pushrules/global/content/myrule2", strings.NewReader("")),
|
|
wantStatusCode: http.StatusOK,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
rec := httptest.NewRecorder()
|
|
|
|
if tc.queryAttr != nil {
|
|
params := url.Values{}
|
|
for k, v := range tc.queryAttr {
|
|
params.Set(k, v)
|
|
}
|
|
|
|
tc.request = httptest.NewRequest(tc.request.Method, tc.request.URL.String()+"?"+params.Encode(), tc.request.Body)
|
|
}
|
|
|
|
tc.request.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
|
|
|
routers.Client.ServeHTTP(rec, tc.request)
|
|
assert.Equal(t, tc.wantStatusCode, rec.Code, rec.Body.String())
|
|
if tc.validateFunc != nil {
|
|
tc.validateFunc(t, rec.Body)
|
|
}
|
|
t.Logf("%s", rec.Body.String())
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
// Tests the `/keys` endpoints.
|
|
// Note that this only tests the happy path.
|
|
func TestKeys(t *testing.T) {
|
|
alice := test.NewUser(t)
|
|
|
|
ctx := context.Background()
|
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
|
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
|
cfg.ClientAPI.RateLimiting.Enabled = false
|
|
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
|
natsInstance := jetstream.NATSInstance{}
|
|
defer close()
|
|
|
|
routers := httputil.NewRouters()
|
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
|
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
|
|
|
|
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
|
|
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
|
|
|
accessTokens := map[*test.User]userDevice{
|
|
alice: {},
|
|
}
|
|
createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers)
|
|
|
|
// Start a TLSServer with our client mux
|
|
srv := httptest.NewTLSServer(routers.Client)
|
|
defer srv.Close()
|
|
|
|
cl, err := mautrix.NewClient(srv.URL, id.UserID(alice.ID), accessTokens[alice].accessToken)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
// Set the client so the self-signed certificate is trusted
|
|
cl.Client = srv.Client()
|
|
cl.DeviceID = id.DeviceID(accessTokens[alice].deviceID)
|
|
|
|
cs := crypto.NewMemoryStore(nil)
|
|
oc := crypto.NewOlmMachine(cl, nil, cs, dummyStore{})
|
|
if err = oc.Load(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// tests `/keys/upload`
|
|
if err = oc.ShareKeys(ctx, 0); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// tests `/keys/device_signing/upload`
|
|
_, err = oc.GenerateAndUploadCrossSigningKeys(accessTokens[alice].password, "passphrase")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// tests `/keys/query`
|
|
dev, err := oc.GetOrFetchDevice(ctx, id.UserID(alice.ID), id.DeviceID(accessTokens[alice].deviceID))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Validate that the keys returned from the server are what the client has stored
|
|
oi := oc.OwnIdentity()
|
|
if oi.SigningKey != dev.SigningKey {
|
|
t.Fatalf("expected signing key '%s', got '%s'", oi.SigningKey, dev.SigningKey)
|
|
}
|
|
if oi.IdentityKey != dev.IdentityKey {
|
|
t.Fatalf("expected identity '%s', got '%s'", oi.IdentityKey, dev.IdentityKey)
|
|
}
|
|
|
|
// tests `/keys/signatures/upload`
|
|
if err = oc.SignOwnMasterKey(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// tests `/keys/claim`
|
|
otks := make(map[string]map[string]string)
|
|
otks[alice.ID] = map[string]string{
|
|
accessTokens[alice].deviceID: string(id.KeyAlgorithmSignedCurve25519),
|
|
}
|
|
|
|
data, err := json.Marshal(claimKeysRequest{OneTimeKeys: otks})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
req, err := http.NewRequest(http.MethodPost, srv.URL+"/_matrix/client/v3/keys/claim", bytes.NewBuffer(data))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
|
resp, err := srv.Client().Do(req)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
respBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if !gjson.GetBytes(respBody, "one_time_keys."+alice.ID+"."+string(dev.DeviceID)).Exists() {
|
|
t.Fatalf("expected one time keys for alice, but didn't find any: %s", string(respBody))
|
|
}
|
|
})
|
|
}
|
|
|
|
// claimKeysRequest is the JSON body TestKeys posts to `/keys/claim`.
type claimKeysRequest struct {
	// OneTimeKeys names the keys to claim, keyed by user ID, then by device ID,
	// with the algorithm name as the value.
	OneTimeKeys map[string]map[string]string `json:"one_time_keys"`
}
|
|
|
|
type dummyStore struct{}
|
|
|
|
func (d dummyStore) IsEncrypted(roomID id.RoomID) bool {
|
|
return true
|
|
}
|
|
|
|
func (d dummyStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventContent {
|
|
return &event.EncryptionEventContent{}
|
|
}
|
|
|
|
func (d dummyStore) FindSharedRooms(userID id.UserID) []id.RoomID {
|
|
return []id.RoomID{}
|
|
}
|
|
|
|
func TestKeyBackup(t *testing.T) {
|
|
alice := test.NewUser(t)
|
|
|
|
handleResponseCode := func(t *testing.T, rec *httptest.ResponseRecorder, expectedCode int) {
|
|
t.Helper()
|
|
if rec.Code != expectedCode {
|
|
t.Fatalf("expected HTTP %d, but got %d: %s", expectedCode, rec.Code, rec.Body.String())
|
|
}
|
|
}
|
|
|
|
testCases := []struct {
|
|
name string
|
|
request func(t *testing.T) *http.Request
|
|
validate func(t *testing.T, rec *httptest.ResponseRecorder)
|
|
}{
|
|
{
|
|
name: "can not create backup with invalid JSON",
|
|
request: func(t *testing.T) *http.Request {
|
|
reqBody := strings.NewReader(`{"algorithm":"m.megolm_backup.v1"`) // missing closing braces
|
|
return httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/room_keys/version", reqBody)
|
|
},
|
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
|
handleResponseCode(t, rec, http.StatusBadRequest)
|
|
},
|
|
},
|
|
{
|
|
name: "can not create backup with missing auth_data", // as this would result in MarshalJSON errors when querying again
|
|
request: func(t *testing.T) *http.Request {
|
|
reqBody := strings.NewReader(`{"algorithm":"m.megolm_backup.v1"}`)
|
|
return httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/room_keys/version", reqBody)
|
|
},
|
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
|
handleResponseCode(t, rec, http.StatusBadRequest)
|
|
},
|
|
},
|
|
{
|
|
name: "can create backup",
|
|
request: func(t *testing.T) *http.Request {
|
|
reqBody := strings.NewReader(`{"algorithm":"m.megolm_backup.v1","auth_data":{"data":"random"}}`)
|
|
return httptest.NewRequest(http.MethodPost, "/_matrix/client/v3/room_keys/version", reqBody)
|
|
},
|
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
|
handleResponseCode(t, rec, http.StatusOK)
|
|
wantVersion := "1"
|
|
if gotVersion := gjson.GetBytes(rec.Body.Bytes(), "version").Str; gotVersion != wantVersion {
|
|
t.Fatalf("expected version '%s', got '%s'", wantVersion, gotVersion)
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "can not query backup for invalid version",
|
|
request: func(t *testing.T) *http.Request {
|
|
return httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/room_keys/version/1337", nil)
|
|
},
|
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
|
handleResponseCode(t, rec, http.StatusNotFound)
|
|
},
|
|
},
|
|
{
|
|
name: "can not query backup for invalid version string",
|
|
request: func(t *testing.T) *http.Request {
|
|
return httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/room_keys/version/notanumber", nil)
|
|
},
|
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
|
handleResponseCode(t, rec, http.StatusNotFound)
|
|
},
|
|
},
|
|
{
|
|
name: "can query backup",
|
|
request: func(t *testing.T) *http.Request {
|
|
return httptest.NewRequest(http.MethodGet, "/_matrix/client/v3/room_keys/version", nil)
|
|
},
|
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
|
handleResponseCode(t, rec, http.StatusOK)
|
|
wantVersion := "1"
|
|
if gotVersion := gjson.GetBytes(rec.Body.Bytes(), "version").Str; gotVersion != wantVersion {
|
|
t.Fatalf("expected version '%s', got '%s'", wantVersion, gotVersion)
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "can query backup without returning rooms",
|
|
request: func(t *testing.T) *http.Request {
|
|
req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys", test.WithQueryParams(map[string]string{
|
|
"version": "1",
|
|
}))
|
|
return req
|
|
},
|
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
|
handleResponseCode(t, rec, http.StatusOK)
|
|
if gotRooms := gjson.GetBytes(rec.Body.Bytes(), "rooms").Map(); len(gotRooms) > 0 {
|
|
t.Fatalf("expected no rooms in version, but got %#v", gotRooms)
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "can query backup for invalid room",
|
|
request: func(t *testing.T) *http.Request {
|
|
req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys/!abc:test", test.WithQueryParams(map[string]string{
|
|
"version": "1",
|
|
}))
|
|
return req
|
|
},
|
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
|
handleResponseCode(t, rec, http.StatusOK)
|
|
if gotSessions := gjson.GetBytes(rec.Body.Bytes(), "sessions").Map(); len(gotSessions) > 0 {
|
|
t.Fatalf("expected no sessions in version, but got %#v", gotSessions)
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "can not query backup for invalid session",
|
|
request: func(t *testing.T) *http.Request {
|
|
req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys/!abc:test/doesnotexist", test.WithQueryParams(map[string]string{
|
|
"version": "1",
|
|
}))
|
|
return req
|
|
},
|
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
|
handleResponseCode(t, rec, http.StatusNotFound)
|
|
},
|
|
},
|
|
{
|
|
name: "can not update backup with missing version",
|
|
request: func(t *testing.T) *http.Request {
|
|
return test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys")
|
|
},
|
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
|
handleResponseCode(t, rec, http.StatusBadRequest)
|
|
},
|
|
},
|
|
{
|
|
name: "can not update backup with invalid data",
|
|
request: func(t *testing.T) *http.Request {
|
|
reqBody := test.WithJSONBody(t, "")
|
|
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys", reqBody, test.WithQueryParams(map[string]string{
|
|
"version": "0",
|
|
}))
|
|
return req
|
|
},
|
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
|
handleResponseCode(t, rec, http.StatusBadRequest)
|
|
},
|
|
},
|
|
{
|
|
name: "can not update backup with wrong version",
|
|
request: func(t *testing.T) *http.Request {
|
|
reqBody := test.WithJSONBody(t, map[string]interface{}{
|
|
"rooms": map[string]interface{}{
|
|
"!testroom:test": map[string]interface{}{
|
|
"sessions": map[string]uapi.KeyBackupSession{},
|
|
},
|
|
},
|
|
})
|
|
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys", reqBody, test.WithQueryParams(map[string]string{
|
|
"version": "5",
|
|
}))
|
|
return req
|
|
},
|
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
|
handleResponseCode(t, rec, http.StatusForbidden)
|
|
},
|
|
},
|
|
{
|
|
name: "can update backup with correct version",
|
|
request: func(t *testing.T) *http.Request {
|
|
reqBody := test.WithJSONBody(t, map[string]interface{}{
|
|
"rooms": map[string]interface{}{
|
|
"!testroom:test": map[string]interface{}{
|
|
"sessions": map[string]uapi.KeyBackupSession{
|
|
"dummySession": {
|
|
FirstMessageIndex: 1,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
})
|
|
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys", reqBody, test.WithQueryParams(map[string]string{
|
|
"version": "1",
|
|
}))
|
|
return req
|
|
},
|
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
|
handleResponseCode(t, rec, http.StatusOK)
|
|
},
|
|
},
|
|
{
|
|
name: "can update backup with correct version for specific room",
|
|
request: func(t *testing.T) *http.Request {
|
|
reqBody := test.WithJSONBody(t, map[string]interface{}{
|
|
"sessions": map[string]uapi.KeyBackupSession{
|
|
"dummySession": {
|
|
FirstMessageIndex: 1,
|
|
IsVerified: true,
|
|
SessionData: json.RawMessage("{}"),
|
|
},
|
|
},
|
|
})
|
|
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys/!testroom:test", reqBody, test.WithQueryParams(map[string]string{
|
|
"version": "1",
|
|
}))
|
|
return req
|
|
},
|
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
|
handleResponseCode(t, rec, http.StatusOK)
|
|
t.Logf("%#v", rec.Body.String())
|
|
},
|
|
},
|
|
{
|
|
name: "can update backup with correct version for specific room and session",
|
|
request: func(t *testing.T) *http.Request {
|
|
reqBody := test.WithJSONBody(t, uapi.KeyBackupSession{
|
|
FirstMessageIndex: 1,
|
|
SessionData: json.RawMessage("{}"),
|
|
IsVerified: true,
|
|
ForwardedCount: 0,
|
|
})
|
|
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/keys/!testroom:test/dummySession", reqBody, test.WithQueryParams(map[string]string{
|
|
"version": "1",
|
|
}))
|
|
return req
|
|
},
|
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
|
handleResponseCode(t, rec, http.StatusOK)
|
|
},
|
|
},
|
|
{
|
|
name: "can update backup by version",
|
|
request: func(t *testing.T) *http.Request {
|
|
reqBody := test.WithJSONBody(t, uapi.KeyBackupSession{
|
|
FirstMessageIndex: 1,
|
|
SessionData: json.RawMessage("{}"),
|
|
IsVerified: true,
|
|
ForwardedCount: 0,
|
|
})
|
|
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/version/1", reqBody, test.WithQueryParams(map[string]string{"version": "1"}))
|
|
return req
|
|
},
|
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
|
handleResponseCode(t, rec, http.StatusOK)
|
|
t.Logf("%#v", rec.Body.String())
|
|
},
|
|
},
|
|
{
|
|
name: "can not update backup by version for invalid version",
|
|
request: func(t *testing.T) *http.Request {
|
|
reqBody := test.WithJSONBody(t, uapi.KeyBackupSession{
|
|
FirstMessageIndex: 1,
|
|
SessionData: json.RawMessage("{}"),
|
|
IsVerified: true,
|
|
ForwardedCount: 0,
|
|
})
|
|
req := test.NewRequest(t, http.MethodPut, "/_matrix/client/v3/room_keys/version/2", reqBody, test.WithQueryParams(map[string]string{"version": "1"}))
|
|
return req
|
|
},
|
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
|
handleResponseCode(t, rec, http.StatusOK)
|
|
},
|
|
},
|
|
{
|
|
name: "can query backup sessions",
|
|
request: func(t *testing.T) *http.Request {
|
|
req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys", test.WithQueryParams(map[string]string{
|
|
"version": "1",
|
|
}))
|
|
return req
|
|
},
|
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
|
handleResponseCode(t, rec, http.StatusOK)
|
|
if gotRooms := gjson.GetBytes(rec.Body.Bytes(), "rooms").Map(); len(gotRooms) != 1 {
|
|
t.Fatalf("expected one room in response, but got %#v", rec.Body.String())
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "can query backup sessions by room",
|
|
request: func(t *testing.T) *http.Request {
|
|
req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys/!testroom:test", test.WithQueryParams(map[string]string{
|
|
"version": "1",
|
|
}))
|
|
return req
|
|
},
|
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
|
handleResponseCode(t, rec, http.StatusOK)
|
|
if gotRooms := gjson.GetBytes(rec.Body.Bytes(), "sessions").Map(); len(gotRooms) != 1 {
|
|
t.Fatalf("expected one session in response, but got %#v", rec.Body.String())
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "can query backup sessions by room and sessionID",
|
|
request: func(t *testing.T) *http.Request {
|
|
req := test.NewRequest(t, http.MethodGet, "/_matrix/client/v3/room_keys/keys/!testroom:test/dummySession", test.WithQueryParams(map[string]string{
|
|
"version": "1",
|
|
}))
|
|
return req
|
|
},
|
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
|
handleResponseCode(t, rec, http.StatusOK)
|
|
if !gjson.GetBytes(rec.Body.Bytes(), "is_verified").Bool() {
|
|
t.Fatalf("expected session to be verified, but wasn't: %#v", rec.Body.String())
|
|
}
|
|
},
|
|
},
|
|
{
|
|
name: "can not delete invalid version backup",
|
|
request: func(t *testing.T) *http.Request {
|
|
return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/2", nil)
|
|
},
|
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
|
handleResponseCode(t, rec, http.StatusNotFound)
|
|
},
|
|
},
|
|
{
|
|
name: "can delete version backup",
|
|
request: func(t *testing.T) *http.Request {
|
|
return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/1", nil)
|
|
},
|
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
|
handleResponseCode(t, rec, http.StatusOK)
|
|
},
|
|
},
|
|
{
|
|
name: "deleting the same backup version twice doesn't error",
|
|
request: func(t *testing.T) *http.Request {
|
|
return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/1", nil)
|
|
},
|
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
|
handleResponseCode(t, rec, http.StatusOK)
|
|
},
|
|
},
|
|
{
|
|
name: "deleting an empty version doesn't work", // make sure we can't delete an empty backup version. Handled at the router level
|
|
request: func(t *testing.T) *http.Request {
|
|
return httptest.NewRequest(http.MethodDelete, "/_matrix/client/v3/room_keys/version/", nil)
|
|
},
|
|
validate: func(t *testing.T, rec *httptest.ResponseRecorder) {
|
|
handleResponseCode(t, rec, http.StatusNotFound)
|
|
},
|
|
},
|
|
}
|
|
|
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
|
cfg, processCtx, close := testrig.CreateConfig(t, dbType)
|
|
cfg.ClientAPI.RateLimiting.Enabled = false
|
|
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
|
natsInstance := jetstream.NATSInstance{}
|
|
defer close()
|
|
|
|
routers := httputil.NewRouters()
|
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics)
|
|
userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil)
|
|
|
|
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
|
|
AddPublicRoutes(processCtx, routers, cfg, &natsInstance, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, caching.DisableMetrics)
|
|
|
|
accessTokens := map[*test.User]userDevice{
|
|
alice: {},
|
|
}
|
|
createAccessTokens(t, accessTokens, userAPI, processCtx.Context(), routers)
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
rec := httptest.NewRecorder()
|
|
req := tc.request(t)
|
|
req.Header.Set("Authorization", "Bearer "+accessTokens[alice].accessToken)
|
|
routers.Client.ServeHTTP(rec, req)
|
|
tc.validate(t, rec)
|
|
})
|
|
}
|
|
})
|
|
}
|