Merge keyserver & userapi (#2972)

As discussed yesterday, a first draft of merging the keyserver and the
userapi.
This commit is contained in:
Till 2023-02-20 14:58:03 +01:00 committed by GitHub
parent bd6f0c14e5
commit 4594233f89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
107 changed files with 1730 additions and 1863 deletions

View File

@ -38,7 +38,7 @@ import (
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI(
base *base.BaseDendrite,
userAPI userapi.UserInternalAPI,
userAPI userapi.AppserviceUserAPI,
rsAPI roomserverAPI.RoomserverInternalAPI,
) appserviceAPI.AppServiceInternalAPI {
client := &http.Client{

View File

@ -125,7 +125,7 @@ func TestAppserviceInternalAPI(t *testing.T) {
// Create required internal APIs
rsAPI := roomserver.NewInternalAPI(base)
usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil)
usrAPI := userapi.NewInternalAPI(base, rsAPI, nil)
asAPI := appservice.NewInternalAPI(base, usrAPI, rsAPI)
runCases(t, asAPI)

View File

@ -30,7 +30,6 @@ import (
"github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing"
"github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/base"
@ -183,13 +182,11 @@ func startup() {
rsAPI := roomserver.NewInternalAPI(base)
federation := conn.CreateFederationClient(base, pSessions)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI)
serverKeyAPI := &signing.YggdrasilKeys{}
keyRing := serverKeyAPI.KeyRing()
userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(userAPI)
userAPI := userapi.NewInternalAPI(base, rsAPI, federation)
asQuery := appservice.NewInternalAPI(
base, userAPI, rsAPI,
@ -208,7 +205,6 @@ func startup() {
FederationAPI: fedSenderAPI,
RoomserverAPI: rsAPI,
UserAPI: userAPI,
KeyAPI: keyAPI,
//ServerKeyAPI: serverKeyAPI,
ExtPublicRoomsProvider: rooms.NewPineconeRoomProvider(pRouter, pSessions, fedSenderAPI, federation),
}

View File

@ -20,7 +20,6 @@ import (
"github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/base"
@ -165,9 +164,7 @@ func (m *DendriteMonolith) Start() {
base, federation, rsAPI, base.Caches, keyRing, true,
)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI)
userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(userAPI)
userAPI := userapi.NewInternalAPI(base, rsAPI, federation)
asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI)
rsAPI.SetAppserviceAPI(asAPI)
@ -186,7 +183,6 @@ func (m *DendriteMonolith) Start() {
FederationAPI: fsAPI,
RoomserverAPI: rsAPI,
UserAPI: userAPI,
KeyAPI: keyAPI,
ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider(
ygg, fsAPI, federation,
),

View File

@ -8,7 +8,6 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
@ -40,11 +39,9 @@ func TestAdminResetPassword(t *testing.T) {
rsAPI := roomserver.NewInternalAPI(base)
// Needed for changing the password/login
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
keyAPI.SetUserAPI(userAPI)
userAPI := userapi.NewInternalAPI(base, rsAPI, nil)
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil)
AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil)
// Create the users in the userapi and login
accessTokens := map[*test.User]string{
@ -155,14 +152,12 @@ func TestPurgeRoom(t *testing.T) {
fedClient := base.CreateFederationClient()
rsAPI := roomserver.NewInternalAPI(base)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fedClient, rsAPI)
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
userAPI := userapi.NewInternalAPI(base, rsAPI, nil)
// this starts the JetStream consumers
syncapi.AddPublicRoutes(base, userAPI, rsAPI, keyAPI)
syncapi.AddPublicRoutes(base, userAPI, rsAPI)
federationapi.NewInternalAPI(base, fedClient, rsAPI, base.Caches, nil, true)
rsAPI.SetFederationAPI(nil, nil)
keyAPI.SetUserAPI(userAPI)
// Create the room
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
@ -170,7 +165,7 @@ func TestPurgeRoom(t *testing.T) {
}
// We mostly need the rsAPI for this test, so nil for other APIs/caches etc.
AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil)
AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil)
// Create the users in the userapi and login
accessTokens := map[*test.User]string{

View File

@ -15,6 +15,7 @@
package clientapi
import (
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
@ -23,11 +24,9 @@ import (
"github.com/matrix-org/dendrite/clientapi/routing"
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/transactions"
keyserverAPI "github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/jetstream"
userapi "github.com/matrix-org/dendrite/userapi/api"
)
// AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component.
@ -40,7 +39,6 @@ func AddPublicRoutes(
fsAPI federationAPI.ClientFederationAPI,
userAPI userapi.ClientUserAPI,
userDirectoryProvider userapi.QuerySearchProfilesAPI,
keyAPI keyserverAPI.ClientKeyAPI,
extRoomsProvider api.ExtraPublicRoomsProvider,
) {
cfg := &base.Cfg.ClientAPI
@ -61,7 +59,7 @@ func AddPublicRoutes(
base,
cfg, rsAPI, asAPI,
userAPI, userDirectoryProvider, federation,
syncProducer, transactionsCache, fsAPI, keyAPI,
syncProducer, transactionsCache, fsAPI,
extRoomsProvider, mscCfg, natsClient,
)
}

View File

@ -16,14 +16,13 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/api"
)
func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
@ -56,7 +55,7 @@ func AdminEvacuateRoom(req *http.Request, cfg *config.ClientAPI, device *userapi
}
}
func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
@ -99,7 +98,7 @@ func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi
}
}
func AdminPurgeRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
func AdminPurgeRoom(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
@ -130,7 +129,7 @@ func AdminPurgeRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.De
}
}
func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse {
func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *api.Device, userAPI api.ClientUserAPI) util.JSONResponse {
if req.Body == nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
@ -150,8 +149,8 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
JSON: jsonerror.InvalidArgumentValue(err.Error()),
}
}
accAvailableResp := &userapi.QueryAccountAvailabilityResponse{}
if err = userAPI.QueryAccountAvailability(req.Context(), &userapi.QueryAccountAvailabilityRequest{
accAvailableResp := &api.QueryAccountAvailabilityResponse{}
if err = userAPI.QueryAccountAvailability(req.Context(), &api.QueryAccountAvailabilityRequest{
Localpart: localpart,
ServerName: serverName,
}, accAvailableResp); err != nil {
@ -186,13 +185,13 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
return *internal.PasswordResponse(err)
}
updateReq := &userapi.PerformPasswordUpdateRequest{
updateReq := &api.PerformPasswordUpdateRequest{
Localpart: localpart,
ServerName: serverName,
Password: request.Password,
LogoutDevices: true,
}
updateRes := &userapi.PerformPasswordUpdateResponse{}
updateRes := &api.PerformPasswordUpdateResponse{}
if err := userAPI.PerformPasswordUpdate(req.Context(), updateReq, updateRes); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
@ -209,7 +208,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
}
}
func AdminReindex(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, natsClient *nats.Conn) util.JSONResponse {
func AdminReindex(req *http.Request, cfg *config.ClientAPI, device *api.Device, natsClient *nats.Conn) util.JSONResponse {
_, err := natsClient.RequestMsg(nats.NewMsg(cfg.Matrix.JetStream.Prefixed(jetstream.InputFulltextReindex)), time.Second*10)
if err != nil {
logrus.WithError(err).Error("failed to publish nats message")
@ -255,7 +254,7 @@ func AdminMarkAsStale(req *http.Request, cfg *config.ClientAPI, keyAPI api.Clien
}
}
func AdminDownloadState(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
func AdminDownloadState(req *http.Request, cfg *config.ClientAPI, device *api.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)

View File

@ -10,7 +10,6 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/appservice"
"github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
@ -29,8 +28,7 @@ func TestJoinRoomByIDOrAlias(t *testing.T) {
defer baseClose()
rsAPI := roomserver.NewInternalAPI(base)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
userAPI := userapi.NewInternalAPI(base, rsAPI, nil)
asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI)
rsAPI.SetFederationAPI(nil, nil) // creates the rs.Inputer etc

View File

@ -21,9 +21,8 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
)
@ -34,8 +33,8 @@ type crossSigningRequest struct {
func UploadCrossSigningDeviceKeys(
req *http.Request, userInteractiveAuth *auth.UserInteractive,
keyserverAPI api.ClientKeyAPI, device *userapi.Device,
accountAPI userapi.ClientUserAPI, cfg *config.ClientAPI,
keyserverAPI api.ClientKeyAPI, device *api.Device,
accountAPI api.ClientUserAPI, cfg *config.ClientAPI,
) util.JSONResponse {
uploadReq := &crossSigningRequest{}
uploadRes := &api.PerformUploadDeviceKeysResponse{}
@ -107,7 +106,7 @@ func UploadCrossSigningDeviceKeys(
}
}
func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.ClientKeyAPI, device *userapi.Device) util.JSONResponse {
func UploadCrossSigningDeviceSignatures(req *http.Request, keyserverAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse {
uploadReq := &api.PerformUploadDeviceSignaturesRequest{}
uploadRes := &api.PerformUploadDeviceSignaturesResponse{}

View File

@ -23,8 +23,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/keyserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/api"
)
type uploadKeysRequest struct {
@ -32,7 +31,7 @@ type uploadKeysRequest struct {
OneTimeKeys map[string]json.RawMessage `json:"one_time_keys"`
}
func UploadKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Device) util.JSONResponse {
func UploadKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse {
var r uploadKeysRequest
resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil {
@ -106,7 +105,7 @@ func (r *queryKeysRequest) GetTimeout() time.Duration {
return timeout
}
func QueryKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *userapi.Device) util.JSONResponse {
func QueryKeys(req *http.Request, keyAPI api.ClientKeyAPI, device *api.Device) util.JSONResponse {
var r queryKeysRequest
resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil {

View File

@ -9,7 +9,6 @@ import (
"testing"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
@ -39,12 +38,10 @@ func TestLogin(t *testing.T) {
rsAPI := roomserver.NewInternalAPI(base)
// Needed for /login
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
keyAPI.SetUserAPI(userAPI)
userAPI := userapi.NewInternalAPI(base, rsAPI, nil)
// We mostly need the userAPI for this test, so nil for other APIs/caches etc.
Setup(base, &base.Cfg.ClientAPI, nil, nil, userAPI, nil, nil, nil, nil, nil, keyAPI, nil, &base.Cfg.MSCs, nil)
Setup(base, &base.Cfg.ClientAPI, nil, nil, userAPI, nil, nil, nil, nil, nil, nil, &base.Cfg.MSCs, nil)
// Create password
password := util.RandomString(8)

View File

@ -30,7 +30,6 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
@ -409,9 +408,7 @@ func Test_register(t *testing.T) {
defer baseClose()
rsAPI := roomserver.NewInternalAPI(base)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
keyAPI.SetUserAPI(userAPI)
userAPI := userapi.NewInternalAPI(base, rsAPI, nil)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
@ -582,9 +579,7 @@ func TestRegisterUserWithDisplayName(t *testing.T) {
base.Cfg.Global.ServerName = "server"
rsAPI := roomserver.NewInternalAPI(base)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
keyAPI.SetUserAPI(userAPI)
userAPI := userapi.NewInternalAPI(base, rsAPI, nil)
deviceName, deviceID := "deviceName", "deviceID"
expectedDisplayName := "DisplayName"
response := completeRegistration(
@ -623,9 +618,7 @@ func TestRegisterAdminUsingSharedSecret(t *testing.T) {
base.Cfg.ClientAPI.RegistrationSharedSecret = sharedSecret
rsAPI := roomserver.NewInternalAPI(base)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
keyAPI.SetUserAPI(userAPI)
userAPI := userapi.NewInternalAPI(base, rsAPI, nil)
expectedDisplayName := "rabbit"
jsonStr := []byte(`{"admin":true,"mac":"24dca3bba410e43fe64b9b5c28306693bf3baa9f","nonce":"759f047f312b99ff428b21d581256f8592b8976e58bc1b543972dc6147e529a79657605b52d7becd160ff5137f3de11975684319187e06901955f79e5a6c5a79","password":"wonderland","username":"alice","displayname":"rabbit"}`)

View File

@ -22,6 +22,7 @@ import (
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/setup/base"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/nats-io/nats.go"
@ -37,11 +38,9 @@ import (
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/internal/transactions"
keyserverAPI "github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
userapi "github.com/matrix-org/dendrite/userapi/api"
)
// Setup registers HTTP handlers with the given ServeMux. It also supplies the given http.Client
@ -61,7 +60,6 @@ func Setup(
syncProducer *producers.SyncAPIProducer,
transactionsCache *transactions.Cache,
federationSender federationAPI.ClientFederationAPI,
keyAPI keyserverAPI.ClientKeyAPI,
extRoomsProvider api.ExtraPublicRoomsProvider,
mscCfg *config.MSCs, natsClient *nats.Conn,
) {
@ -192,7 +190,7 @@ func Setup(
dendriteAdminRouter.Handle("/admin/refreshDevices/{userID}",
httputil.MakeAdminAPI("admin_refresh_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminMarkAsStale(req, cfg, keyAPI)
return AdminMarkAsStale(req, cfg, userAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
@ -1372,11 +1370,11 @@ func Setup(
// Cross-signing device keys
postDeviceSigningKeys := httputil.MakeAuthAPI("post_device_signing_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadCrossSigningDeviceKeys(req, userInteractiveAuth, keyAPI, device, userAPI, cfg)
return UploadCrossSigningDeviceKeys(req, userInteractiveAuth, userAPI, device, userAPI, cfg)
})
postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadCrossSigningDeviceSignatures(req, keyAPI, device)
return UploadCrossSigningDeviceSignatures(req, userAPI, device)
}, httputil.WithAllowGuests())
v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions)
@ -1388,22 +1386,22 @@ func Setup(
// Supplying a device ID is deprecated.
v3mux.Handle("/keys/upload/{deviceID}",
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadKeys(req, keyAPI, device)
return UploadKeys(req, userAPI, device)
}, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/upload",
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return UploadKeys(req, keyAPI, device)
return UploadKeys(req, userAPI, device)
}, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/query",
httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return QueryKeys(req, keyAPI, device)
return QueryKeys(req, userAPI, device)
}, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/keys/claim",
httputil.MakeAuthAPI("keys_claim", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return ClaimKeys(req, keyAPI)
return ClaimKeys(req, userAPI)
}, httputil.WithAllowGuests()),
).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}",

View File

@ -38,7 +38,6 @@ import (
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/producers"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/relayapi"
relayAPI "github.com/matrix-org/dendrite/relayapi/api"
"github.com/matrix-org/dendrite/roomserver"
@ -139,9 +138,7 @@ func (p *P2PMonolith) SetupDendrite(cfg *config.Dendrite, port int, enableRelayi
p.BaseDendrite, federation, rsAPI, p.BaseDendrite.Caches, keyRing, true,
)
keyAPI := keyserver.NewInternalAPI(p.BaseDendrite, &p.BaseDendrite.Cfg.KeyServer, fsAPI, rsComponent)
userAPI := userapi.NewInternalAPI(p.BaseDendrite, &cfg.UserAPI, nil, keyAPI, rsAPI, p.BaseDendrite.PushGatewayHTTPClient())
keyAPI.SetUserAPI(userAPI)
userAPI := userapi.NewInternalAPI(p.BaseDendrite, rsAPI, federation)
asAPI := appservice.NewInternalAPI(p.BaseDendrite, userAPI, rsAPI)
@ -175,7 +172,6 @@ func (p *P2PMonolith) SetupDendrite(cfg *config.Dendrite, port int, enableRelayi
FederationAPI: fsAPI,
RoomserverAPI: rsAPI,
UserAPI: userAPI,
KeyAPI: keyAPI,
RelayAPI: relayAPI,
ExtPublicRoomsProvider: roomProvider,
ExtUserDirectoryProvider: userProvider,

View File

@ -39,7 +39,6 @@ import (
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/base"
@ -161,10 +160,7 @@ func main() {
base,
)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI)
userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(userAPI)
userAPI := userapi.NewInternalAPI(base, rsAPI, federation)
asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI)
rsAPI.SetAppserviceAPI(asAPI)
@ -184,7 +180,6 @@ func main() {
FederationAPI: fsAPI,
RoomserverAPI: rsAPI,
UserAPI: userAPI,
KeyAPI: keyAPI,
ExtPublicRoomsProvider: yggrooms.NewYggdrasilRoomProvider(
ygg, fsAPI, federation,
),

View File

@ -21,7 +21,6 @@ import (
"github.com/matrix-org/dendrite/appservice"
"github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup"
basepkg "github.com/matrix-org/dendrite/setup/base"
@ -56,10 +55,7 @@ func main() {
keyRing := fsAPI.KeyRing()
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI)
pgClient := base.PushGatewayHTTPClient()
userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, pgClient)
userAPI := userapi.NewInternalAPI(base, rsAPI, federation)
asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI)
@ -69,7 +65,6 @@ func main() {
rsAPI.SetFederationAPI(fsAPI, keyRing)
rsAPI.SetAppserviceAPI(asAPI)
rsAPI.SetUserAPI(userAPI)
keyAPI.SetUserAPI(userAPI)
monolith := setup.Monolith{
Config: base.Cfg,
@ -83,7 +78,6 @@ func main() {
FederationAPI: fsAPI,
RoomserverAPI: rsAPI,
UserAPI: userAPI,
KeyAPI: keyAPI,
}
monolith.AddAllPublicRoutes(base)

View File

@ -26,11 +26,11 @@ import (
"github.com/matrix-org/dendrite/federationapi/queue"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/federationapi/types"
"github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/userapi/api"
)
// KeyChangeConsumer consumes events that originate in key server.

View File

@ -28,7 +28,6 @@ import (
"github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/internal/caching"
keyserverAPI "github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/jetstream"
@ -42,12 +41,11 @@ import (
// AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component.
func AddPublicRoutes(
base *base.BaseDendrite,
userAPI userapi.UserInternalAPI,
userAPI userapi.FederationUserAPI,
federation *gomatrixserverlib.FederationClient,
keyRing gomatrixserverlib.JSONVerifier,
rsAPI roomserverAPI.FederationRoomserverAPI,
fedAPI federationAPI.FederationInternalAPI,
keyAPI keyserverAPI.FederationKeyAPI,
servers federationAPI.ServersInRoomProvider,
) {
cfg := &base.Cfg.FederationAPI
@ -79,7 +77,7 @@ func AddPublicRoutes(
routing.Setup(
base,
rsAPI, f, keyRing,
federation, userAPI, keyAPI, mscCfg,
federation, userAPI, mscCfg,
servers, producer,
)
}

View File

@ -17,13 +17,13 @@ import (
"github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/internal"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
userapi "github.com/matrix-org/dendrite/userapi/api"
)
type fedRoomserverAPI struct {
@ -230,9 +230,9 @@ func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) {
// Inject a keyserver key change event and ensure we try to send it out. If we don't, then the
// federationapi is incorrectly waiting for an output room event to arrive to update the joined
// hosts table.
key := keyapi.DeviceMessage{
Type: keyapi.TypeDeviceKeyUpdate,
DeviceKeys: &keyapi.DeviceKeys{
key := userapi.DeviceMessage{
Type: userapi.TypeDeviceKeyUpdate,
DeviceKeys: &userapi.DeviceKeys{
UserID: joiningUser.ID,
DeviceID: "MY_DEVICE",
DisplayName: "BLARGLE",
@ -277,7 +277,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
keyRing := &test.NopJSONVerifier{}
// TODO: This is pretty fragile, as if anything calls anything on these nils this test will break.
// Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing.
federationapi.AddPublicRoutes(b, nil, nil, keyRing, nil, &internal.FederationInternalAPI{}, nil, nil)
federationapi.AddPublicRoutes(b, nil, nil, keyRing, nil, &internal.FederationInternalAPI{}, nil)
baseURL, cancel := test.ListenAndServe(t, b.PublicFederationAPIMux, true)
defer cancel()
serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://"))

View File

@ -41,7 +41,7 @@ type SyncAPIProducer struct {
TopicSigningKeyUpdate string
JetStream nats.JetStreamContext
Config *config.FederationAPI
UserAPI userapi.UserInternalAPI
UserAPI userapi.FederationUserAPI
}
func (p *SyncAPIProducer) SendReceipt(

View File

@ -17,7 +17,7 @@ import (
"net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/tidwall/gjson"
@ -26,11 +26,11 @@ import (
// GetUserDevices for the given user id
func GetUserDevices(
req *http.Request,
keyAPI keyapi.FederationKeyAPI,
keyAPI api.FederationKeyAPI,
userID string,
) util.JSONResponse {
var res keyapi.QueryDeviceMessagesResponse
if err := keyAPI.QueryDeviceMessages(req.Context(), &keyapi.QueryDeviceMessagesRequest{
var res api.QueryDeviceMessagesResponse
if err := keyAPI.QueryDeviceMessages(req.Context(), &api.QueryDeviceMessagesRequest{
UserID: userID,
}, &res); err != nil {
return util.ErrorResponse(err)
@ -40,12 +40,12 @@ func GetUserDevices(
return jsonerror.InternalServerError()
}
sigReq := &keyapi.QuerySignaturesRequest{
sigReq := &api.QuerySignaturesRequest{
TargetIDs: map[string][]gomatrixserverlib.KeyID{
userID: {},
},
}
sigRes := &keyapi.QuerySignaturesResponse{}
sigRes := &api.QuerySignaturesResponse{}
for _, dev := range res.Devices {
sigReq.TargetIDs[userID] = append(sigReq.TargetIDs[userID], gomatrixserverlib.KeyID(dev.DeviceID))
}

View File

@ -22,8 +22,8 @@ import (
clienthttputil "github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"

View File

@ -62,7 +62,7 @@ func TestHandleQueryProfile(t *testing.T) {
if !ok {
panic("This is a programming error.")
}
routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, nil, &base.Cfg.MSCs, nil, nil)
routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, &base.Cfg.MSCs, nil, nil)
handler := fedMux.Get(routing.QueryProfileRouteName).GetHandler().ServeHTTP
_, sk, _ := ed25519.GenerateKey(nil)

View File

@ -62,7 +62,7 @@ func TestHandleQueryDirectory(t *testing.T) {
if !ok {
panic("This is a programming error.")
}
routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, nil, &base.Cfg.MSCs, nil, nil)
routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, &base.Cfg.MSCs, nil, nil)
handler := fedMux.Get(routing.QueryDirectoryRouteName).GetHandler().ServeHTTP
_, sk, _ := ed25519.GenerateKey(nil)

View File

@ -29,7 +29,6 @@ import (
"github.com/matrix-org/dendrite/federationapi/producers"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/httputil"
keyserverAPI "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
@ -62,7 +61,6 @@ func Setup(
keys gomatrixserverlib.JSONVerifier,
federation federationAPI.FederationClient,
userAPI userapi.FederationUserAPI,
keyAPI keyserverAPI.FederationKeyAPI,
mscCfg *config.MSCs,
servers federationAPI.ServersInRoomProvider,
producer *producers.SyncAPIProducer,
@ -141,7 +139,7 @@ func Setup(
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
return Send(
httpReq, request, gomatrixserverlib.TransactionID(vars["txnID"]),
cfg, rsAPI, keyAPI, keys, federation, mu, servers, producer,
cfg, rsAPI, userAPI, keys, federation, mu, servers, producer,
)
},
)).Methods(http.MethodPut, http.MethodOptions).Name(SendRouteName)
@ -269,7 +267,7 @@ func Setup(
"federation_user_devices", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
return GetUserDevices(
httpReq, keyAPI, vars["userID"],
httpReq, userAPI, vars["userID"],
)
},
)).Methods(http.MethodGet)
@ -494,14 +492,14 @@ func Setup(
v1fedmux.Handle("/user/keys/claim", MakeFedAPI(
"federation_keys_claim", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
return ClaimOneTimeKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName)
return ClaimOneTimeKeys(httpReq, request, userAPI, cfg.Matrix.ServerName)
},
)).Methods(http.MethodPost)
v1fedmux.Handle("/user/keys/query", MakeFedAPI(
"federation_keys_query", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
return QueryDeviceKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName)
return QueryDeviceKeys(httpReq, request, userAPI, cfg.Matrix.ServerName)
},
)).Methods(http.MethodPost)

View File

@ -28,9 +28,9 @@ import (
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/producers"
"github.com/matrix-org/dendrite/internal"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
userAPI "github.com/matrix-org/dendrite/userapi/api"
)
const (
@ -59,7 +59,7 @@ func Send(
txnID gomatrixserverlib.TransactionID,
cfg *config.FederationAPI,
rsAPI api.FederationRoomserverAPI,
keyAPI keyapi.FederationKeyAPI,
keyAPI userAPI.FederationUserAPI,
keys gomatrixserverlib.JSONVerifier,
federation federationAPI.FederationClient,
mu *internal.MutexByRoom,

View File

@ -58,7 +58,7 @@ func TestHandleSend(t *testing.T) {
if !ok {
panic("This is a programming error.")
}
routing.Setup(base, nil, r, keyRing, nil, nil, nil, &base.Cfg.MSCs, nil, nil)
routing.Setup(base, nil, r, keyRing, nil, nil, &base.Cfg.MSCs, nil, nil)
handler := fedMux.Get(routing.SendRouteName).GetHandler().ServeHTTP
_, sk, _ := ed25519.GenerateKey(nil)

View File

@ -24,9 +24,9 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/federationapi/producers"
"github.com/matrix-org/dendrite/federationapi/types"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api"
syncTypes "github.com/matrix-org/dendrite/syncapi/types"
userAPI "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
@ -56,7 +56,7 @@ var (
type TxnReq struct {
gomatrixserverlib.Transaction
rsAPI api.FederationRoomserverAPI
keyAPI keyapi.FederationKeyAPI
userAPI userAPI.FederationUserAPI
ourServerName gomatrixserverlib.ServerName
keys gomatrixserverlib.JSONVerifier
roomsMu *MutexByRoom
@ -66,7 +66,7 @@ type TxnReq struct {
func NewTxnReq(
rsAPI api.FederationRoomserverAPI,
keyAPI keyapi.FederationKeyAPI,
userAPI userAPI.FederationUserAPI,
ourServerName gomatrixserverlib.ServerName,
keys gomatrixserverlib.JSONVerifier,
roomsMu *MutexByRoom,
@ -80,7 +80,7 @@ func NewTxnReq(
) TxnReq {
t := TxnReq{
rsAPI: rsAPI,
keyAPI: keyAPI,
userAPI: userAPI,
ourServerName: ourServerName,
keys: keys,
roomsMu: roomsMu,

View File

@ -23,13 +23,13 @@ import (
"time"
"github.com/matrix-org/dendrite/federationapi/producers"
keyAPI "github.com/matrix-org/dendrite/keyserver/api"
rsAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test"
keyAPI "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/stretchr/testify/assert"

View File

@ -1,19 +0,0 @@
## Key Server
This is an internal component which manages E2E keys from clients. It handles all the [Key Management APIs](https://matrix.org/docs/spec/client_server/r0.6.1#key-management-api) with the exception of `/keys/changes` which is handled by Sync API. This component is designed to shard by user ID.
Keys are uploaded and stored in this component, and key changes are emitted to a Kafka topic for downstream components such as Sync API.
### Internal APIs
- `PerformUploadKeys` stores identity keys and one-time public keys for given user(s).
- `PerformClaimKeys` acquires one-time public keys for given user(s). This may involve outbound federation calls.
- `QueryKeys` returns identity keys for given user(s). This may involve outbound federation calls. This component may then cache federated identity keys to avoid repeatedly hitting remote servers.
- A topic which emits identity keys every time there is a change (addition or deletion).
### Endpoint mappings
- Client API maps `/keys/upload` to `PerformUploadKeys`.
- Client API maps `/keys/query` to `QueryKeys`.
- Client API maps `/keys/claim` to `PerformClaimKeys`.
- Federation API maps `/user/keys/query` to `QueryKeys`.
- Federation API maps `/user/keys/claim` to `PerformClaimKeys`.
- Sync API maps `/keys/changes` to consuming from the Kafka topic.

View File

@ -1,346 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api
import (
"bytes"
"context"
"encoding/json"
"strings"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/keyserver/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
)
type KeyInternalAPI interface {
SyncKeyAPI
ClientKeyAPI
FederationKeyAPI
UserKeyAPI
// SetUserAPI assigns a user API to query when extracting device names.
SetUserAPI(i userapi.KeyserverUserAPI)
}
// API functions required by the clientapi
type ClientKeyAPI interface {
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error
PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error
PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error
// PerformClaimKeys claims one-time keys for use in pre-key messages
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
}
// API functions required by the userapi
type UserKeyAPI interface {
PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error
PerformDeleteKeys(ctx context.Context, req *PerformDeleteKeysRequest, res *PerformDeleteKeysResponse) error
}
// API functions required by the syncapi
type SyncKeyAPI interface {
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error
QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error
PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
}
type FederationKeyAPI interface {
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error
QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error
PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
}
// KeyError is returned if there was a problem performing/querying the server
type KeyError struct {
Err string `json:"error"`
IsInvalidSignature bool `json:"is_invalid_signature,omitempty"` // M_INVALID_SIGNATURE
IsMissingParam bool `json:"is_missing_param,omitempty"` // M_MISSING_PARAM
IsInvalidParam bool `json:"is_invalid_param,omitempty"` // M_INVALID_PARAM
}
func (k *KeyError) Error() string {
return k.Err
}
type DeviceMessageType int
const (
TypeDeviceKeyUpdate DeviceMessageType = iota
TypeCrossSigningUpdate
)
// DeviceMessage represents the message produced into Kafka by the key server.
type DeviceMessage struct {
Type DeviceMessageType `json:"Type,omitempty"`
*DeviceKeys `json:"DeviceKeys,omitempty"`
*OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"`
// A monotonically increasing number which represents device changes for this user.
StreamID int64
DeviceChangeID int64
}
// OutputCrossSigningKeyUpdate is an entry in the signing key update output kafka log
type OutputCrossSigningKeyUpdate struct {
CrossSigningKeyUpdate `json:"signing_keys"`
}
type CrossSigningKeyUpdate struct {
MasterKey *gomatrixserverlib.CrossSigningKey `json:"master_key,omitempty"`
SelfSigningKey *gomatrixserverlib.CrossSigningKey `json:"self_signing_key,omitempty"`
UserID string `json:"user_id"`
}
// DeviceKeysEqual returns true if the device keys updates contain the
// same display name and key JSON. This will return false if either of
// the updates is not a device keys update, or if the user ID/device ID
// differ between the two.
func (m1 *DeviceMessage) DeviceKeysEqual(m2 *DeviceMessage) bool {
if m1.DeviceKeys == nil || m2.DeviceKeys == nil {
return false
}
if m1.UserID != m2.UserID || m1.DeviceID != m2.DeviceID {
return false
}
if m1.DisplayName != m2.DisplayName {
return false // different display names
}
if len(m1.KeyJSON) == 0 || len(m2.KeyJSON) == 0 {
return false // either is empty
}
return bytes.Equal(m1.KeyJSON, m2.KeyJSON)
}
// DeviceKeys represents a set of device keys for a single device
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
type DeviceKeys struct {
// The user who owns this device
UserID string
// The device ID of this device
DeviceID string
// The device display name
DisplayName string
// The raw device key JSON
KeyJSON []byte
}
// WithStreamID returns a copy of this device message with the given stream ID
func (k *DeviceKeys) WithStreamID(streamID int64) DeviceMessage {
return DeviceMessage{
DeviceKeys: k,
StreamID: streamID,
}
}
// OneTimeKeys represents a set of one-time keys for a single device
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
type OneTimeKeys struct {
// The user who owns this device
UserID string
// The device ID of this device
DeviceID string
// A map of algorithm:key_id => key JSON
KeyJSON map[string]json.RawMessage
}
// Split a key in KeyJSON into algorithm and key ID
func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) {
segments := strings.Split(keyIDWithAlgo, ":")
return segments[0], segments[1]
}
// OneTimeKeysCount represents the counts of one-time keys for a single device
type OneTimeKeysCount struct {
// The user who owns this device
UserID string
// The device ID of this device
DeviceID string
// algorithm to count e.g:
// {
// "curve25519": 10,
// "signed_curve25519": 20
// }
KeyCount map[string]int
}
// PerformUploadKeysRequest is the request to PerformUploadKeys
type PerformUploadKeysRequest struct {
UserID string // Required - User performing the request
DeviceID string // Optional - Device performing the request, for fetching OTK count
DeviceKeys []DeviceKeys
OneTimeKeys []OneTimeKeys
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
// the display name for their respective device, and NOT to modify the keys. The key
// itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
// Without this flag, requests to modify device display names would delete device keys.
OnlyDisplayNameUpdates bool
}
// PerformUploadKeysResponse is the response to PerformUploadKeys
type PerformUploadKeysResponse struct {
// A fatal error when processing e.g database failures
Error *KeyError
// A map of user_id -> device_id -> Error for tracking failures.
KeyErrors map[string]map[string]*KeyError
OneTimeKeyCounts []OneTimeKeysCount
}
// PerformDeleteKeysRequest asks the keyserver to forget about certain
// keys, and signatures related to those keys.
type PerformDeleteKeysRequest struct {
UserID string
KeyIDs []gomatrixserverlib.KeyID
}
// PerformDeleteKeysResponse is the response to PerformDeleteKeysRequest.
type PerformDeleteKeysResponse struct {
Error *KeyError
}
// KeyError sets a key error field on KeyErrors
func (r *PerformUploadKeysResponse) KeyError(userID, deviceID string, err *KeyError) {
if r.KeyErrors[userID] == nil {
r.KeyErrors[userID] = make(map[string]*KeyError)
}
r.KeyErrors[userID][deviceID] = err
}
type PerformClaimKeysRequest struct {
// Map of user_id to device_id to algorithm name
OneTimeKeys map[string]map[string]string
Timeout time.Duration
}
type PerformClaimKeysResponse struct {
// Map of user_id to device_id to algorithm:key_id to key JSON
OneTimeKeys map[string]map[string]map[string]json.RawMessage
// Map of remote server domain to error JSON
Failures map[string]interface{}
// Set if there was a fatal error processing this action
Error *KeyError
}
type PerformUploadDeviceKeysRequest struct {
gomatrixserverlib.CrossSigningKeys
// The user that uploaded the key, should be populated by the clientapi.
UserID string
}
type PerformUploadDeviceKeysResponse struct {
Error *KeyError
}
type PerformUploadDeviceSignaturesRequest struct {
Signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice
// The user that uploaded the sig, should be populated by the clientapi.
UserID string
}
type PerformUploadDeviceSignaturesResponse struct {
Error *KeyError
}
type QueryKeysRequest struct {
// The user ID asking for the keys, e.g. if from a client API request.
// Will not be populated if the key request came from federation.
UserID string
// Maps user IDs to a list of devices
UserToDevices map[string][]string
Timeout time.Duration
}
type QueryKeysResponse struct {
// Map of remote server domain to error JSON
Failures map[string]interface{}
// Map of user_id to device_id to device_key
DeviceKeys map[string]map[string]json.RawMessage
// Maps of user_id to cross signing key
MasterKeys map[string]gomatrixserverlib.CrossSigningKey
SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey
UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey
// Set if there was a fatal error processing this query
Error *KeyError
}
type QueryKeyChangesRequest struct {
// The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning
Offset int64
// The inclusive offset where to track key changes up to. Messages with this offset are included in the response.
// Use types.OffsetNewest if the offset is unknown (then check the response Offset to avoid racing).
ToOffset int64
}
type QueryKeyChangesResponse struct {
// The set of users who have had their keys change.
UserIDs []string
// The latest offset represented in this response.
Offset int64
// Set if there was a problem handling the request.
Error *KeyError
}
type QueryOneTimeKeysRequest struct {
// The local user to query OTK counts for
UserID string
// The device to query OTK counts for
DeviceID string
}
type QueryOneTimeKeysResponse struct {
// OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84
Count OneTimeKeysCount
Error *KeyError
}
type QueryDeviceMessagesRequest struct {
UserID string
}
type QueryDeviceMessagesResponse struct {
// The latest stream ID
StreamID int64
Devices []DeviceMessage
Error *KeyError
}
type QuerySignaturesRequest struct {
// A map of target user ID -> target key/device IDs to retrieve signatures for
TargetIDs map[string][]gomatrixserverlib.KeyID `json:"target_ids"`
}
type QuerySignaturesResponse struct {
// A map of target user ID -> target key/device ID -> origin user ID -> origin key/device ID -> signatures
Signatures map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap
// A map of target user ID -> cross-signing master key
MasterKeys map[string]gomatrixserverlib.CrossSigningKey
// A map of target user ID -> cross-signing self-signing key
SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey
// A map of target user ID -> cross-signing user-signing key
UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey
// The request error, if any
Error *KeyError
}
type PerformMarkAsStaleRequest struct {
UserID string
Domain gomatrixserverlib.ServerName
DeviceID string
}

View File

@ -1,86 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package keyserver
import (
"github.com/sirupsen/logrus"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/consumers"
"github.com/matrix-org/dendrite/keyserver/internal"
"github.com/matrix-org/dendrite/keyserver/producers"
"github.com/matrix-org/dendrite/keyserver/storage"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
)
// NewInternalAPI returns a concerete implementation of the internal API. Callers
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI(
base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI,
rsAPI rsapi.KeyserverRoomserverAPI,
) api.KeyInternalAPI {
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
db, err := storage.NewDatabase(base, &cfg.Database)
if err != nil {
logrus.WithError(err).Panicf("failed to connect to key server database")
}
keyChangeProducer := &producers.KeyChange{
Topic: string(cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent)),
JetStream: js,
DB: db,
}
ap := &internal.KeyInternalAPI{
DB: db,
Cfg: cfg,
FedClient: fedClient,
Producer: keyChangeProducer,
}
updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable
ap.Updater = updater
// Remove users which we don't share a room with anymore
if err := updater.CleanUp(); err != nil {
logrus.WithError(err).Error("failed to cleanup stale device lists")
}
go func() {
if err := updater.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start device list updater")
}
}()
dlConsumer := consumers.NewDeviceListUpdateConsumer(
base.ProcessContext, cfg, js, updater,
)
if err := dlConsumer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start device list consumer")
}
sigConsumer := consumers.NewSigningKeyUpdateConsumer(
base.ProcessContext, cfg, js, ap,
)
if err := sigConsumer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start signing key consumer")
}
return ap
}

View File

@ -1,29 +0,0 @@
package keyserver
import (
"context"
"testing"
roomserver "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
)
type mockKeyserverRoomserverAPI struct {
leftUsers []string
}
func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error {
res.LeftUsers = m.leftUsers
return nil
}
// Merely tests that we can create an internal keyserver API
func Test_NewInternalAPI(t *testing.T) {
rsAPI := &mockKeyserverRoomserverAPI{}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, closeBase := testrig.CreateBaseDendrite(t, dbType)
defer closeBase()
_ = NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
})
}

View File

@ -1,93 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"context"
"encoding/json"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
type Database interface {
// ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination
// of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database.
ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
// StoreOneTimeKeys persists the given one-time keys.
StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
// OneTimeKeysCount returns a count of all OTKs for this device.
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
// StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
// for this (user, device).
// The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set.
// Returns an error if there was a problem storing the keys.
StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
// for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
// to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
// PrevIDsExists returns true if all prev IDs exist for this user.
PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error)
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
// cross-signing signatures relating to that device.
DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
// StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change.
// `userID` is the the user who has changed their keys in some way.
StoreKeyChange(ctx context.Context, userID string) (int64, error)
// KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive).
// A to offset of types.OffsetNewest means no upper limit.
// Returns the offset of the latest key change.
KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
// If no domains are given, all user IDs with stale device lists are returned.
StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
// MarkDeviceListStale sets the stale bit for this user to isStale.
MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error)
CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error)
CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error)
StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error
StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
DeleteStaleDeviceLists(
ctx context.Context,
userIDs []string,
) error
}

View File

@ -1,69 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/shared"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
)
// NewDatabase creates a new sync server database
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.Database, error) {
var err error
db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter())
if err != nil {
return nil, err
}
otk, err := NewPostgresOneTimeKeysTable(db)
if err != nil {
return nil, err
}
dk, err := NewPostgresDeviceKeysTable(db)
if err != nil {
return nil, err
}
kc, err := NewPostgresKeyChangesTable(db)
if err != nil {
return nil, err
}
sdl, err := NewPostgresStaleDeviceListsTable(db)
if err != nil {
return nil, err
}
csk, err := NewPostgresCrossSigningKeysTable(db)
if err != nil {
return nil, err
}
css, err := NewPostgresCrossSigningSigsTable(db)
if err != nil {
return nil, err
}
if err = kc.Prepare(); err != nil {
return nil, err
}
d := &shared.Database{
DB: db,
Writer: writer,
OneTimeKeysTable: otk,
DeviceKeysTable: dk,
KeyChangesTable: kc,
StaleDeviceListsTable: sdl,
CrossSigningKeysTable: csk,
CrossSigningSigsTable: css,
}
return d, nil
}

View File

@ -1,261 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package shared
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
type Database struct {
DB *sql.DB
Writer sqlutil.Writer
OneTimeKeysTable tables.OneTimeKeys
DeviceKeysTable tables.DeviceKeys
KeyChangesTable tables.KeyChanges
StaleDeviceListsTable tables.StaleDeviceLists
CrossSigningKeysTable tables.CrossSigningKeys
CrossSigningSigsTable tables.CrossSigningSigs
}
func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms)
}
func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) {
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys)
return err
})
return
}
func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
}
func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
}
func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs)
if err != nil {
return false, err
}
return count == len(prevIDs), nil
}
func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
for _, userID := range clearUserIDs {
err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID)
if err != nil {
return err
}
}
return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
})
}
func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
// work out the latest stream IDs for each user
userIDToStreamID := make(map[string]int64)
for _, k := range keys {
userIDToStreamID[k.UserID] = 0
}
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
for userID := range userIDToStreamID {
streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID)
if err != nil {
return err
}
userIDToStreamID[userID] = streamID
}
// set the stream IDs for each key
for i := range keys {
k := keys[i]
userIDToStreamID[k.UserID]++ // start stream from 1
k.StreamID = userIDToStreamID[k.UserID]
keys[i] = k
}
return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
})
}
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty)
}
func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {
var result []api.OneTimeKeys
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
for userID, deviceToAlgo := range userToDeviceToAlgorithm {
for deviceID, algo := range deviceToAlgo {
keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo)
if err != nil {
return err
}
if keyJSON != nil {
result = append(result, api.OneTimeKeys{
UserID: userID,
DeviceID: deviceID,
KeyJSON: keyJSON,
})
}
}
}
return nil
})
return result, err
}
func (d *Database) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) {
err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID)
return err
})
return
}
func (d *Database) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) {
return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset)
}
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
// If no domains are given, all user IDs with stale device lists are returned.
func (d *Database) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains)
}
// MarkDeviceListStale sets the stale bit for this user to isStale.
func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
return d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale)
})
}
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
// cross-signing signatures relating to that device.
func (d *Database) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
for _, deviceID := range deviceIDs {
if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err)
}
if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err)
}
if err := d.OneTimeKeysTable.DeleteOneTimeKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("d.OneTimeKeysTable.DeleteOneTimeKeys: %w", err)
}
}
return nil
})
}
// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any.
func (d *Database) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) {
keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID)
if err != nil {
return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err)
}
results := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{}
for purpose, key := range keyMap {
keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode())
result := gomatrixserverlib.CrossSigningKey{
UserID: userID,
Usage: []gomatrixserverlib.CrossSigningKeyPurpose{purpose},
Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{
keyID: key,
},
}
sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, userID, keyID)
if err != nil {
continue
}
for sigUserID, forSigUserID := range sigMap {
if userID != sigUserID {
continue
}
if result.Signatures == nil {
result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
}
if _, ok := result.Signatures[sigUserID]; !ok {
result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
}
for sigKeyID, sigBytes := range forSigUserID {
result.Signatures[sigUserID][sigKeyID] = sigBytes
}
}
results[purpose] = result
}
return results, nil
}
// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any.
func (d *Database) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) {
return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID)
}
// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any.
func (d *Database) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) {
return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID)
}
// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user.
func (d *Database) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
for keyType, keyData := range keyMap {
if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil {
return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err)
}
}
return nil
})
}
// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice.
func (d *Database) StoreCrossSigningSigsForTarget(
ctx context.Context,
originUserID string, originKeyID gomatrixserverlib.KeyID,
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
signature gomatrixserverlib.Base64Bytes,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
return fmt.Errorf("d.CrossSigningSigsTable.InsertCrossSigningSigsForTarget: %w", err)
}
return nil
})
}
// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore.
func (d *Database) DeleteStaleDeviceLists(
ctx context.Context,
userIDs []string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs)
})
}

View File

@ -1,68 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/shared"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
)
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.Database, error) {
db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter())
if err != nil {
return nil, err
}
otk, err := NewSqliteOneTimeKeysTable(db)
if err != nil {
return nil, err
}
dk, err := NewSqliteDeviceKeysTable(db)
if err != nil {
return nil, err
}
kc, err := NewSqliteKeyChangesTable(db)
if err != nil {
return nil, err
}
sdl, err := NewSqliteStaleDeviceListsTable(db)
if err != nil {
return nil, err
}
csk, err := NewSqliteCrossSigningKeysTable(db)
if err != nil {
return nil, err
}
css, err := NewSqliteCrossSigningSigsTable(db)
if err != nil {
return nil, err
}
if err = kc.Prepare(); err != nil {
return nil, err
}
d := &shared.Database{
DB: db,
Writer: writer,
OneTimeKeysTable: otk,
DeviceKeysTable: dk,
KeyChangesTable: kc,
StaleDeviceListsTable: sdl,
CrossSigningKeysTable: csk,
CrossSigningSigsTable: css,
}
return d, nil
}

View File

@ -1,40 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !wasm
// +build !wasm
package storage
import (
"fmt"
"github.com/matrix-org/dendrite/keyserver/storage/postgres"
"github.com/matrix-org/dendrite/keyserver/storage/sqlite3"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
)
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets postgres connection parameters
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(base, dbProperties)
case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(base, dbProperties)
default:
return nil, fmt.Errorf("unexpected database type")
}
}

View File

@ -1,197 +0,0 @@
package storage_test
import (
"context"
"reflect"
"sync"
"testing"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage"
"github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
)
var ctx = context.Background()
func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
base, close := testrig.CreateBaseDendrite(t, dbType)
db, err := storage.NewDatabase(base, &base.Cfg.KeyServer.Database)
if err != nil {
t.Fatalf("failed to create new database: %v", err)
}
return db, close
}
func MustNotError(t *testing.T, err error) {
t.Helper()
if err == nil {
return
}
t.Fatalf("operation failed: %s", err)
}
func TestKeyChanges(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := MustCreateDatabase(t, dbType)
defer clean()
_, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err)
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
MustNotError(t, err)
deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost")
MustNotError(t, err)
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest)
if err != nil {
t.Fatalf("Failed to KeyChanges: %s", err)
}
if latest != deviceChangeIDC {
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC)
}
if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
})
}
func TestKeyChangesNoDupes(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := MustCreateDatabase(t, dbType)
defer clean()
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err)
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err)
if deviceChangeIDA == deviceChangeIDB {
t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA)
}
deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err)
userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest)
if err != nil {
t.Fatalf("Failed to KeyChanges: %s", err)
}
if latest != deviceChangeID {
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID)
}
if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
})
}
func TestKeyChangesUpperLimit(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := MustCreateDatabase(t, dbType)
defer clean()
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err)
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
MustNotError(t, err)
_, err = db.StoreKeyChange(ctx, "@charlie:localhost")
MustNotError(t, err)
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB)
if err != nil {
t.Fatalf("Failed to KeyChanges: %s", err)
}
if latest != deviceChangeIDB {
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB)
}
if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
})
}
var dbLock sync.Mutex
var deviceArray = []string{"AAA", "another_device"}
// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
// and that they are returned correctly when querying for device keys.
func TestDeviceKeysStreamIDGeneration(t *testing.T) {
var err error
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := MustCreateDatabase(t, dbType)
defer clean()
alice := "@alice:TestDeviceKeysStreamIDGeneration"
bob := "@bob:TestDeviceKeysStreamIDGeneration"
msgs := []api.DeviceMessage{
{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "AAA",
UserID: alice,
KeyJSON: []byte(`{"key":"v1"}`),
},
// StreamID: 1
},
{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "AAA",
UserID: bob,
KeyJSON: []byte(`{"key":"v1"}`),
},
// StreamID: 1 as this is a different user
},
{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "another_device",
UserID: alice,
KeyJSON: []byte(`{"key":"v1"}`),
},
// StreamID: 2 as this is a 2nd device key
},
}
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
if msgs[0].StreamID != 1 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
}
if msgs[1].StreamID != 1 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
}
if msgs[2].StreamID != 2 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
}
// updating a device sets the next stream ID for that user
msgs = []api.DeviceMessage{
{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "AAA",
UserID: alice,
KeyJSON: []byte(`{"key":"v2"}`),
},
// StreamID: 3
},
}
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
if msgs[0].StreamID != 3 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
}
dbLock.Lock()
defer dbLock.Unlock()
// Querying for device keys returns the latest stream IDs
msgs, err = db.DeviceKeysForUser(ctx, alice, deviceArray, false)
if err != nil {
t.Fatalf("DeviceKeysForUser returned error: %s", err)
}
wantStreamIDs := map[string]int64{
"AAA": 3,
"another_device": 2,
}
if len(msgs) != len(wantStreamIDs) {
t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
}
for _, m := range msgs {
if m.StreamID != wantStreamIDs[m.DeviceID] {
t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
}
}
})
}

View File

@ -1,34 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"fmt"
"github.com/matrix-org/dendrite/keyserver/storage/sqlite3"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
)
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(base, dbProperties)
case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation")
default:
return nil, fmt.Errorf("unexpected database type")
}
}

View File

@ -1,71 +0,0 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tables
import (
"context"
"database/sql"
"encoding/json"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
type OneTimeKeys interface {
SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
// SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON.
// Returns an empty map if the key does not exist.
SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error)
DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
}
type DeviceKeys interface {
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error)
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
}
type KeyChanges interface {
InsertKeyChange(ctx context.Context, userID string) (int64, error)
// SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets.
// Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of types.OffsetNewest means no upper offset.
SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
Prepare() error
}
type StaleDeviceLists interface {
InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error
SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error
}
type CrossSigningKeys interface {
SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error)
UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error
}
type CrossSigningSigs interface {
SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error)
UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error
}

View File

@ -17,7 +17,6 @@ type RoomserverInternalAPI interface {
ClientRoomserverAPI
UserRoomserverAPI
FederationRoomserverAPI
KeyserverRoomserverAPI
// needed to avoid chicken and egg scenario when setting up the
// interdependencies between the roomserver and other input APIs
@ -167,6 +166,7 @@ type ClientRoomserverAPI interface {
type UserRoomserverAPI interface {
QueryLatestEventsAndStateAPI
KeyserverRoomserverAPI
QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error

View File

@ -12,7 +12,6 @@ import (
userAPI "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/syncapi"
"github.com/matrix-org/gomatrixserverlib"
@ -47,7 +46,7 @@ func TestUsers(t *testing.T) {
})
t.Run("kick users", func(t *testing.T) {
usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil)
usrAPI := userapi.NewInternalAPI(base, rsAPI, nil)
rsAPI.SetUserAPI(usrAPI)
testKickUsers(t, rsAPI, usrAPI)
})
@ -228,11 +227,10 @@ func TestPurgeRoom(t *testing.T) {
fedClient := base.CreateFederationClient()
rsAPI := roomserver.NewInternalAPI(base)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fedClient, rsAPI)
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
userAPI := userapi.NewInternalAPI(base, rsAPI, nil)
// this starts the JetStream consumers
syncapi.AddPublicRoutes(base, userAPI, rsAPI, keyAPI)
syncapi.AddPublicRoutes(base, userAPI, rsAPI)
federationapi.NewInternalAPI(base, fedClient, rsAPI, base.Caches, nil, true)
rsAPI.SetFederationAPI(nil, nil)

View File

@ -21,7 +21,6 @@ import (
"github.com/matrix-org/dendrite/federationapi"
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/transactions"
keyAPI "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/mediaapi"
"github.com/matrix-org/dendrite/relayapi"
relayAPI "github.com/matrix-org/dendrite/relayapi/api"
@ -45,7 +44,6 @@ type Monolith struct {
FederationAPI federationAPI.FederationInternalAPI
RoomserverAPI roomserverAPI.RoomserverInternalAPI
UserAPI userapi.UserInternalAPI
KeyAPI keyAPI.KeyInternalAPI
RelayAPI relayAPI.RelayInternalAPI
// Optional
@ -61,19 +59,14 @@ func (m *Monolith) AddAllPublicRoutes(base *base.BaseDendrite) {
}
clientapi.AddPublicRoutes(
base, m.FedClient, m.RoomserverAPI, m.AppserviceAPI, transactions.New(),
m.FederationAPI, m.UserAPI, userDirectoryProvider, m.KeyAPI,
m.FederationAPI, m.UserAPI, userDirectoryProvider,
m.ExtPublicRoomsProvider,
)
federationapi.AddPublicRoutes(
base, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI,
m.KeyAPI, nil,
)
mediaapi.AddPublicRoutes(
base, m.UserAPI, m.Client,
)
syncapi.AddPublicRoutes(
base, m.UserAPI, m.RoomserverAPI, m.KeyAPI,
base, m.UserAPI, m.FedClient, m.KeyRing, m.RoomserverAPI, m.FederationAPI, nil,
)
mediaapi.AddPublicRoutes(base, m.UserAPI, m.Client)
syncapi.AddPublicRoutes(base, m.UserAPI, m.RoomserverAPI)
if m.RelayAPI != nil {
relayapi.AddPublicRoutes(base, m.KeyRing, m.RelayAPI)

View File

@ -19,7 +19,6 @@ import (
"encoding/json"
"github.com/getsentry/sentry-go"
"github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
@ -28,6 +27,7 @@ import (
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
)

View File

@ -25,7 +25,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
@ -33,6 +32,7 @@ import (
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/streams"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/userapi/api"
)
// OutputSendToDeviceEventConsumer consumes events that originated in the EDU server.
@ -42,7 +42,7 @@ type OutputSendToDeviceEventConsumer struct {
durable string
topic string
db storage.Database
keyAPI keyapi.SyncKeyAPI
userAPI api.SyncKeyAPI
isLocalServerName func(gomatrixserverlib.ServerName) bool
stream streams.StreamProvider
notifier *notifier.Notifier
@ -55,7 +55,7 @@ func NewOutputSendToDeviceEventConsumer(
cfg *config.SyncAPI,
js nats.JetStreamContext,
store storage.Database,
keyAPI keyapi.SyncKeyAPI,
userAPI api.SyncKeyAPI,
notifier *notifier.Notifier,
stream streams.StreamProvider,
) *OutputSendToDeviceEventConsumer {
@ -65,7 +65,7 @@ func NewOutputSendToDeviceEventConsumer(
topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent),
durable: cfg.Matrix.JetStream.Durable("SyncAPISendToDeviceConsumer"),
db: store,
keyAPI: keyAPI,
userAPI: userAPI,
isLocalServerName: cfg.Matrix.IsLocalServerName,
notifier: notifier,
stream: stream,
@ -116,7 +116,7 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msgs []
_, senderDomain, _ := gomatrixserverlib.SplitID('@', output.Sender)
if requestingDeviceID != "" && !s.isLocalServerName(senderDomain) {
// Mark the requesting device as stale, if we don't know about it.
if err = s.keyAPI.PerformMarkAsStaleIfNeeded(ctx, &keyapi.PerformMarkAsStaleRequest{
if err = s.userAPI.PerformMarkAsStaleIfNeeded(ctx, &api.PerformMarkAsStaleRequest{
UserID: output.Sender, Domain: senderDomain, DeviceID: requestingDeviceID,
}, &struct{}{}); err != nil {
logger.WithError(err).Errorf("failed to mark as stale if needed")

View File

@ -18,22 +18,22 @@ import (
"context"
"strings"
keytypes "github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
keytypes "github.com/matrix-org/dendrite/keyserver/types"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/userapi/api"
)
// DeviceOTKCounts adds one-time key counts to the /sync response
func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, deviceID string, res *types.Response) error {
var queryRes keyapi.QueryOneTimeKeysResponse
_ = keyAPI.QueryOneTimeKeys(ctx, &keyapi.QueryOneTimeKeysRequest{
func DeviceOTKCounts(ctx context.Context, keyAPI api.SyncKeyAPI, userID, deviceID string, res *types.Response) error {
var queryRes api.QueryOneTimeKeysResponse
_ = keyAPI.QueryOneTimeKeys(ctx, &api.QueryOneTimeKeysRequest{
UserID: userID,
DeviceID: deviceID,
}, &queryRes)
@ -48,7 +48,7 @@ func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, devi
// was filled in, else false if there are no new device list changes because there is nothing to catch up on. The response MUST
// be already filled in with join/leave information.
func DeviceListCatchup(
ctx context.Context, db storage.SharedUsers, keyAPI keyapi.SyncKeyAPI, rsAPI roomserverAPI.SyncRoomserverAPI,
ctx context.Context, db storage.SharedUsers, userAPI api.SyncKeyAPI, rsAPI roomserverAPI.SyncRoomserverAPI,
userID string, res *types.Response, from, to types.StreamPosition,
) (newPos types.StreamPosition, hasNew bool, err error) {
@ -74,8 +74,8 @@ func DeviceListCatchup(
if from > 0 {
offset = int64(from)
}
var queryRes keyapi.QueryKeyChangesResponse
_ = keyAPI.QueryKeyChanges(ctx, &keyapi.QueryKeyChangesRequest{
var queryRes api.QueryKeyChangesResponse
_ = userAPI.QueryKeyChanges(ctx, &api.QueryKeyChangesRequest{
Offset: offset,
ToOffset: toOffset,
}, &queryRes)

View File

@ -9,7 +9,6 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
@ -22,44 +21,44 @@ var (
type mockKeyAPI struct{}
func (k *mockKeyAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *keyapi.PerformMarkAsStaleRequest, res *struct{}) error {
func (k *mockKeyAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *userapi.PerformMarkAsStaleRequest, res *struct{}) error {
return nil
}
func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *keyapi.PerformUploadKeysRequest, res *keyapi.PerformUploadKeysResponse) error {
func (k *mockKeyAPI) PerformUploadKeys(ctx context.Context, req *userapi.PerformUploadKeysRequest, res *userapi.PerformUploadKeysResponse) error {
return nil
}
func (k *mockKeyAPI) SetUserAPI(i userapi.UserInternalAPI) {}
// PerformClaimKeys claims one-time keys for use in pre-key messages
func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *keyapi.PerformClaimKeysRequest, res *keyapi.PerformClaimKeysResponse) error {
func (k *mockKeyAPI) PerformClaimKeys(ctx context.Context, req *userapi.PerformClaimKeysRequest, res *userapi.PerformClaimKeysResponse) error {
return nil
}
func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *keyapi.PerformDeleteKeysRequest, res *keyapi.PerformDeleteKeysResponse) error {
func (k *mockKeyAPI) PerformDeleteKeys(ctx context.Context, req *userapi.PerformDeleteKeysRequest, res *userapi.PerformDeleteKeysResponse) error {
return nil
}
func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *keyapi.PerformUploadDeviceKeysRequest, res *keyapi.PerformUploadDeviceKeysResponse) error {
func (k *mockKeyAPI) PerformUploadDeviceKeys(ctx context.Context, req *userapi.PerformUploadDeviceKeysRequest, res *userapi.PerformUploadDeviceKeysResponse) error {
return nil
}
func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *keyapi.PerformUploadDeviceSignaturesRequest, res *keyapi.PerformUploadDeviceSignaturesResponse) error {
func (k *mockKeyAPI) PerformUploadDeviceSignatures(ctx context.Context, req *userapi.PerformUploadDeviceSignaturesRequest, res *userapi.PerformUploadDeviceSignaturesResponse) error {
return nil
}
func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *keyapi.QueryKeysRequest, res *keyapi.QueryKeysResponse) error {
func (k *mockKeyAPI) QueryKeys(ctx context.Context, req *userapi.QueryKeysRequest, res *userapi.QueryKeysResponse) error {
return nil
}
func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) error {
func (k *mockKeyAPI) QueryKeyChanges(ctx context.Context, req *userapi.QueryKeyChangesRequest, res *userapi.QueryKeyChangesResponse) error {
return nil
}
func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) error {
func (k *mockKeyAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error {
return nil
}
func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *keyapi.QueryDeviceMessagesRequest, res *keyapi.QueryDeviceMessagesResponse) error {
func (k *mockKeyAPI) QueryDeviceMessages(ctx context.Context, req *userapi.QueryDeviceMessagesRequest, res *userapi.QueryDeviceMessagesResponse) error {
return nil
}
func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *keyapi.QuerySignaturesRequest, res *keyapi.QuerySignaturesResponse) error {
func (k *mockKeyAPI) QuerySignatures(ctx context.Context, req *userapi.QuerySignaturesRequest, res *userapi.QuerySignaturesResponse) error {
return nil
}

View File

@ -3,17 +3,17 @@ package streams
import (
"context"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/internal"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
)
type DeviceListStreamProvider struct {
DefaultStreamProvider
rsAPI api.SyncRoomserverAPI
keyAPI keyapi.SyncKeyAPI
userAPI userapi.SyncKeyAPI
}
func (p *DeviceListStreamProvider) CompleteSync(
@ -31,12 +31,12 @@ func (p *DeviceListStreamProvider) IncrementalSync(
from, to types.StreamPosition,
) types.StreamPosition {
var err error
to, _, err = internal.DeviceListCatchup(context.Background(), snapshot, p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to)
to, _, err = internal.DeviceListCatchup(context.Background(), snapshot, p.userAPI, p.rsAPI, req.Device.UserID, req.Response, from, to)
if err != nil {
req.Log.WithError(err).Error("internal.DeviceListCatchup failed")
return from
}
err = internal.DeviceOTKCounts(req.Context, p.keyAPI, req.Device.UserID, req.Device.ID, req.Response)
err = internal.DeviceOTKCounts(req.Context, p.userAPI, req.Device.UserID, req.Device.ID, req.Response)
if err != nil {
req.Log.WithError(err).Error("internal.DeviceListCatchup failed")
return from

View File

@ -5,7 +5,6 @@ import (
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage"
@ -27,7 +26,7 @@ type Streams struct {
func NewSyncStreamProviders(
d storage.Database, userAPI userapi.SyncUserAPI,
rsAPI rsapi.SyncRoomserverAPI, keyAPI keyapi.SyncKeyAPI,
rsAPI rsapi.SyncRoomserverAPI,
eduCache *caching.EDUCache, lazyLoadCache caching.LazyLoadCache, notifier *notifier.Notifier,
) *Streams {
streams := &Streams{
@ -60,7 +59,7 @@ func NewSyncStreamProviders(
DeviceListStreamProvider: &DeviceListStreamProvider{
DefaultStreamProvider: DefaultStreamProvider{DB: d},
rsAPI: rsAPI,
keyAPI: keyAPI,
userAPI: userAPI,
},
PresenceStreamProvider: &PresenceStreamProvider{
DefaultStreamProvider: DefaultStreamProvider{DB: d},

View File

@ -32,7 +32,6 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/sqlutil"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/internal"
@ -48,7 +47,6 @@ type RequestPool struct {
db storage.Database
cfg *config.SyncAPI
userAPI userapi.SyncUserAPI
keyAPI keyapi.SyncKeyAPI
rsAPI roomserverAPI.SyncRoomserverAPI
lastseen *sync.Map
presence *sync.Map
@ -69,7 +67,7 @@ type PresenceConsumer interface {
// NewRequestPool makes a new RequestPool
func NewRequestPool(
db storage.Database, cfg *config.SyncAPI,
userAPI userapi.SyncUserAPI, keyAPI keyapi.SyncKeyAPI,
userAPI userapi.SyncUserAPI,
rsAPI roomserverAPI.SyncRoomserverAPI,
streams *streams.Streams, notifier *notifier.Notifier,
producer PresencePublisher, consumer PresenceConsumer, enableMetrics bool,
@ -83,7 +81,6 @@ func NewRequestPool(
db: db,
cfg: cfg,
userAPI: userAPI,
keyAPI: keyAPI,
rsAPI: rsAPI,
lastseen: &sync.Map{},
presence: &sync.Map{},
@ -280,7 +277,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
// https://github.com/matrix-org/synapse/blob/29f06704b8871a44926f7c99e73cf4a978fb8e81/synapse/rest/client/sync.py#L276-L281
// Only try to get OTKs if the context isn't already done.
if syncReq.Context.Err() == nil {
err = internal.DeviceOTKCounts(syncReq.Context, rp.keyAPI, syncReq.Device.UserID, syncReq.Device.ID, syncReq.Response)
err = internal.DeviceOTKCounts(syncReq.Context, rp.userAPI, syncReq.Device.UserID, syncReq.Device.ID, syncReq.Response)
if err != nil && err != context.Canceled {
syncReq.Log.WithError(err).Warn("failed to get OTK counts")
}
@ -551,7 +548,7 @@ func (rp *RequestPool) OnIncomingKeyChangeRequest(req *http.Request, device *use
defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err)
rp.streams.PDUStreamProvider.IncrementalSync(req.Context(), snapshot, syncReq, fromToken.PDUPosition, toToken.PDUPosition)
_, _, err = internal.DeviceListCatchup(
req.Context(), snapshot, rp.keyAPI, rp.rsAPI, syncReq.Device.UserID,
req.Context(), snapshot, rp.userAPI, rp.rsAPI, syncReq.Device.UserID,
syncReq.Response, fromToken.DeviceListPosition, toToken.DeviceListPosition,
)
if err != nil {

View File

@ -21,7 +21,6 @@ import (
"github.com/matrix-org/dendrite/internal/caching"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/jetstream"
@ -42,7 +41,6 @@ func AddPublicRoutes(
base *base.BaseDendrite,
userAPI userapi.SyncUserAPI,
rsAPI api.SyncRoomserverAPI,
keyAPI keyapi.SyncKeyAPI,
) {
cfg := &base.Cfg.SyncAPI
@ -55,7 +53,7 @@ func AddPublicRoutes(
eduCache := caching.NewTypingCache()
notifier := notifier.NewNotifier()
streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, keyAPI, eduCache, base.Caches, notifier)
streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, eduCache, base.Caches, notifier)
notifier.SetCurrentPosition(streams.Latest(context.Background()))
if err = notifier.Load(context.Background(), syncDB); err != nil {
logrus.WithError(err).Panicf("failed to load notifier ")
@ -71,7 +69,7 @@ func AddPublicRoutes(
userAPI,
)
requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier, federationPresenceProducer, presenceConsumer, base.EnableMetrics)
requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, rsAPI, streams, notifier, federationPresenceProducer, presenceConsumer, base.EnableMetrics)
if err = presenceConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start presence consumer")
@ -117,7 +115,7 @@ func AddPublicRoutes(
}
sendToDeviceConsumer := consumers.NewOutputSendToDeviceEventConsumer(
base.ProcessContext, cfg, js, syncDB, keyAPI, notifier, streams.SendToDeviceStreamProvider,
base.ProcessContext, cfg, js, syncDB, userAPI, notifier, streams.SendToDeviceStreamProvider,
)
if err = sendToDeviceConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start send-to-device consumer")

View File

@ -16,7 +16,6 @@ import (
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/clientapi/producers"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/roomserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
@ -83,21 +82,18 @@ func (s *syncUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAc
return nil
}
func (s *syncUserAPI) QueryKeyChanges(ctx context.Context, req *userapi.QueryKeyChangesRequest, res *userapi.QueryKeyChangesResponse) error {
return nil
}
func (s *syncUserAPI) QueryOneTimeKeys(ctx context.Context, req *userapi.QueryOneTimeKeysRequest, res *userapi.QueryOneTimeKeysResponse) error {
return nil
}
func (s *syncUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.PerformLastSeenUpdateRequest, res *userapi.PerformLastSeenUpdateResponse) error {
return nil
}
type syncKeyAPI struct {
keyapi.SyncKeyAPI
}
func (s *syncKeyAPI) QueryKeyChanges(ctx context.Context, req *keyapi.QueryKeyChangesRequest, res *keyapi.QueryKeyChangesResponse) error {
return nil
}
func (s *syncKeyAPI) QueryOneTimeKeys(ctx context.Context, req *keyapi.QueryOneTimeKeysRequest, res *keyapi.QueryOneTimeKeysResponse) error {
return nil
}
func TestSyncAPIAccessTokens(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
testSyncAccessTokens(t, dbType)
@ -121,7 +117,7 @@ func testSyncAccessTokens(t *testing.T, dbType test.DBType) {
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
msgs := toNATSMsgs(t, base, room.Events()...)
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}})
testrig.MustPublishMsgs(t, jsctx, msgs...)
testCases := []struct {
@ -220,7 +216,7 @@ func testSyncAPICreateRoomSyncEarly(t *testing.T, dbType test.DBType) {
// m.room.history_visibility
msgs := toNATSMsgs(t, base, room.Events()...)
sinceTokens := make([]string, len(msgs))
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}}, &syncKeyAPI{})
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{rooms: []*test.Room{room}})
for i, msg := range msgs {
testrig.MustPublishMsgs(t, jsctx, msg)
time.Sleep(100 * time.Millisecond)
@ -304,7 +300,7 @@ func testSyncAPIUpdatePresenceImmediately(t *testing.T, dbType test.DBType) {
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, &syncKeyAPI{})
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{})
w := httptest.NewRecorder()
base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", "/_matrix/client/v3/sync", test.WithQueryParams(map[string]string{
"access_token": alice.AccessToken,
@ -422,7 +418,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) {
rsAPI := roomserver.NewInternalAPI(base)
rsAPI.SetFederationAPI(nil, nil)
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, &syncKeyAPI{})
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI)
for _, tc := range testCases {
testname := fmt.Sprintf("%s - %s", tc.historyVisibility, userType)
@ -722,7 +718,7 @@ func TestGetMembership(t *testing.T) {
rsAPI := roomserver.NewInternalAPI(base)
rsAPI.SetFederationAPI(nil, nil)
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, &syncKeyAPI{})
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
@ -789,7 +785,7 @@ func testSendToDevice(t *testing.T, dbType test.DBType) {
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{}, &syncKeyAPI{})
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, &syncRoomserverAPI{})
producer := producers.SyncAPIProducer{
TopicSendToDeviceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent),
@ -1008,7 +1004,7 @@ func testContext(t *testing.T, dbType test.DBType) {
rsAPI := roomserver.NewInternalAPI(base)
rsAPI.SetFederationAPI(nil, nil)
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, rsAPI, &syncKeyAPI{})
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, rsAPI)
room := test.NewRoom(t, user)

View File

@ -15,9 +15,13 @@
package api
import (
"bytes"
"context"
"encoding/json"
"strings"
"time"
"github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -26,15 +30,12 @@ import (
// UserInternalAPI is the internal API for information about users and devices.
type UserInternalAPI interface {
AppserviceUserAPI
SyncUserAPI
ClientUserAPI
MediaUserAPI
FederationUserAPI
RoomserverUserAPI
KeyserverUserAPI
QuerySearchProfilesAPI // used by p2p demos
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
}
// api functions required by the appservice api
@ -43,11 +44,6 @@ type AppserviceUserAPI interface {
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
}
type KeyserverUserAPI interface {
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error
}
type RoomserverUserAPI interface {
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
@ -60,13 +56,20 @@ type MediaUserAPI interface {
// api functions required by the federation api
type FederationUserAPI interface {
UploadDeviceKeysAPI
QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error
QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
}
// api functions required by the sync api
type SyncUserAPI interface {
QueryAcccessTokenAPI
SyncKeyAPI
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
@ -79,6 +82,7 @@ type ClientUserAPI interface {
QueryAcccessTokenAPI
LoginTokenInternalAPI
UserLoginAPI
ClientKeyAPI
QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
@ -681,3 +685,310 @@ type QueryAccountByLocalpartRequest struct {
type QueryAccountByLocalpartResponse struct {
Account *Account
}
// API functions required by the clientapi
type ClientKeyAPI interface {
UploadDeviceKeysAPI
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error
PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error
// PerformClaimKeys claims one-time keys for use in pre-key messages
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
}
type UploadDeviceKeysAPI interface {
PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error
}
// API functions required by the syncapi
type SyncKeyAPI interface {
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error
QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error
PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
}
type FederationKeyAPI interface {
UploadDeviceKeysAPI
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error
QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
}
// KeyError is returned if there was a problem performing/querying the server
type KeyError struct {
Err string `json:"error"`
IsInvalidSignature bool `json:"is_invalid_signature,omitempty"` // M_INVALID_SIGNATURE
IsMissingParam bool `json:"is_missing_param,omitempty"` // M_MISSING_PARAM
IsInvalidParam bool `json:"is_invalid_param,omitempty"` // M_INVALID_PARAM
}
func (k *KeyError) Error() string {
return k.Err
}
type DeviceMessageType int
const (
TypeDeviceKeyUpdate DeviceMessageType = iota
TypeCrossSigningUpdate
)
// DeviceMessage represents the message produced into Kafka by the key server.
type DeviceMessage struct {
Type DeviceMessageType `json:"Type,omitempty"`
*DeviceKeys `json:"DeviceKeys,omitempty"`
*OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"`
// A monotonically increasing number which represents device changes for this user.
StreamID int64
DeviceChangeID int64
}
// OutputCrossSigningKeyUpdate is an entry in the signing key update output kafka log
type OutputCrossSigningKeyUpdate struct {
CrossSigningKeyUpdate `json:"signing_keys"`
}
type CrossSigningKeyUpdate struct {
MasterKey *gomatrixserverlib.CrossSigningKey `json:"master_key,omitempty"`
SelfSigningKey *gomatrixserverlib.CrossSigningKey `json:"self_signing_key,omitempty"`
UserID string `json:"user_id"`
}
// DeviceKeysEqual returns true if the device keys updates contain the
// same display name and key JSON. This will return false if either of
// the updates is not a device keys update, or if the user ID/device ID
// differ between the two.
func (m1 *DeviceMessage) DeviceKeysEqual(m2 *DeviceMessage) bool {
if m1.DeviceKeys == nil || m2.DeviceKeys == nil {
return false
}
if m1.UserID != m2.UserID || m1.DeviceID != m2.DeviceID {
return false
}
if m1.DisplayName != m2.DisplayName {
return false // different display names
}
if len(m1.KeyJSON) == 0 || len(m2.KeyJSON) == 0 {
return false // either is empty
}
return bytes.Equal(m1.KeyJSON, m2.KeyJSON)
}
// DeviceKeys represents a set of device keys for a single device
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
type DeviceKeys struct {
// The user who owns this device
UserID string
// The device ID of this device
DeviceID string
// The device display name
DisplayName string
// The raw device key JSON
KeyJSON []byte
}
// WithStreamID returns a copy of this device message with the given stream ID
func (k *DeviceKeys) WithStreamID(streamID int64) DeviceMessage {
return DeviceMessage{
DeviceKeys: k,
StreamID: streamID,
}
}
// OneTimeKeys represents a set of one-time keys for a single device
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
type OneTimeKeys struct {
// The user who owns this device
UserID string
// The device ID of this device
DeviceID string
// A map of algorithm:key_id => key JSON
KeyJSON map[string]json.RawMessage
}
// Split a key in KeyJSON into algorithm and key ID
func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) {
segments := strings.Split(keyIDWithAlgo, ":")
return segments[0], segments[1]
}
// OneTimeKeysCount represents the counts of one-time keys for a single device
type OneTimeKeysCount struct {
// The user who owns this device
UserID string
// The device ID of this device
DeviceID string
// algorithm to count e.g:
// {
// "curve25519": 10,
// "signed_curve25519": 20
// }
KeyCount map[string]int
}
// PerformUploadKeysRequest is the request to PerformUploadKeys
type PerformUploadKeysRequest struct {
UserID string // Required - User performing the request
DeviceID string // Optional - Device performing the request, for fetching OTK count
DeviceKeys []DeviceKeys
OneTimeKeys []OneTimeKeys
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
// the display name for their respective device, and NOT to modify the keys. The key
// itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
// Without this flag, requests to modify device display names would delete device keys.
OnlyDisplayNameUpdates bool
}
// PerformUploadKeysResponse is the response to PerformUploadKeys
type PerformUploadKeysResponse struct {
// A fatal error when processing e.g database failures
Error *KeyError
// A map of user_id -> device_id -> Error for tracking failures.
KeyErrors map[string]map[string]*KeyError
OneTimeKeyCounts []OneTimeKeysCount
}
// PerformDeleteKeysRequest asks the keyserver to forget about certain
// keys, and signatures related to those keys.
type PerformDeleteKeysRequest struct {
UserID string
KeyIDs []gomatrixserverlib.KeyID
}
// PerformDeleteKeysResponse is the response to PerformDeleteKeysRequest.
type PerformDeleteKeysResponse struct {
Error *KeyError
}
// KeyError sets a key error field on KeyErrors
func (r *PerformUploadKeysResponse) KeyError(userID, deviceID string, err *KeyError) {
if r.KeyErrors[userID] == nil {
r.KeyErrors[userID] = make(map[string]*KeyError)
}
r.KeyErrors[userID][deviceID] = err
}
type PerformClaimKeysRequest struct {
// Map of user_id to device_id to algorithm name
OneTimeKeys map[string]map[string]string
Timeout time.Duration
}
type PerformClaimKeysResponse struct {
// Map of user_id to device_id to algorithm:key_id to key JSON
OneTimeKeys map[string]map[string]map[string]json.RawMessage
// Map of remote server domain to error JSON
Failures map[string]interface{}
// Set if there was a fatal error processing this action
Error *KeyError
}
type PerformUploadDeviceKeysRequest struct {
gomatrixserverlib.CrossSigningKeys
// The user that uploaded the key, should be populated by the clientapi.
UserID string
}
type PerformUploadDeviceKeysResponse struct {
Error *KeyError
}
type PerformUploadDeviceSignaturesRequest struct {
Signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice
// The user that uploaded the sig, should be populated by the clientapi.
UserID string
}
type PerformUploadDeviceSignaturesResponse struct {
Error *KeyError
}
type QueryKeysRequest struct {
// The user ID asking for the keys, e.g. if from a client API request.
// Will not be populated if the key request came from federation.
UserID string
// Maps user IDs to a list of devices
UserToDevices map[string][]string
Timeout time.Duration
}
type QueryKeysResponse struct {
// Map of remote server domain to error JSON
Failures map[string]interface{}
// Map of user_id to device_id to device_key
DeviceKeys map[string]map[string]json.RawMessage
// Maps of user_id to cross signing key
MasterKeys map[string]gomatrixserverlib.CrossSigningKey
SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey
UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey
// Set if there was a fatal error processing this query
Error *KeyError
}
type QueryKeyChangesRequest struct {
// The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning
Offset int64
// The inclusive offset where to track key changes up to. Messages with this offset are included in the response.
// Use types.OffsetNewest if the offset is unknown (then check the response Offset to avoid racing).
ToOffset int64
}
type QueryKeyChangesResponse struct {
// The set of users who have had their keys change.
UserIDs []string
// The latest offset represented in this response.
Offset int64
// Set if there was a problem handling the request.
Error *KeyError
}
type QueryOneTimeKeysRequest struct {
// The local user to query OTK counts for
UserID string
// The device to query OTK counts for
DeviceID string
}
type QueryOneTimeKeysResponse struct {
// OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84
Count OneTimeKeysCount
Error *KeyError
}
type QueryDeviceMessagesRequest struct {
UserID string
}
type QueryDeviceMessagesResponse struct {
// The latest stream ID
StreamID int64
Devices []DeviceMessage
Error *KeyError
}
type QuerySignaturesRequest struct {
// A map of target user ID -> target key/device IDs to retrieve signatures for
TargetIDs map[string][]gomatrixserverlib.KeyID `json:"target_ids"`
}
type QuerySignaturesResponse struct {
// A map of target user ID -> target key/device ID -> origin user ID -> origin key/device ID -> signatures
Signatures map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap
// A map of target user ID -> cross-signing master key
MasterKeys map[string]gomatrixserverlib.CrossSigningKey
// A map of target user ID -> cross-signing self-signing key
SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey
// A map of target user ID -> cross-signing user-signing key
UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey
// The request error, if any
Error *KeyError
}
type PerformMarkAsStaleRequest struct {
UserID string
Domain gomatrixserverlib.ServerName
DeviceID string
}

View File

@ -37,7 +37,7 @@ type OutputReceiptEventConsumer struct {
jetstream nats.JetStreamContext
durable string
topic string
db storage.Database
db storage.UserDatabase
serverName gomatrixserverlib.ServerName
syncProducer *producers.SyncAPI
pgClient pushgateway.Client
@ -49,7 +49,7 @@ func NewOutputReceiptEventConsumer(
process *process.ProcessContext,
cfg *config.UserAPI,
js nats.JetStreamContext,
store storage.Database,
store storage.UserDatabase,
syncProducer *producers.SyncAPI,
pgClient pushgateway.Client,
) *OutputReceiptEventConsumer {

View File

@ -18,11 +18,11 @@ import (
"context"
"encoding/json"
"github.com/matrix-org/dendrite/userapi/internal"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/keyserver/internal"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
@ -41,7 +41,7 @@ type DeviceListUpdateConsumer struct {
// NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers.
func NewDeviceListUpdateConsumer(
process *process.ProcessContext,
cfg *config.KeyServer,
cfg *config.UserAPI,
js nats.JetStreamContext,
updater *internal.DeviceListUpdater,
) *DeviceListUpdateConsumer {

View File

@ -38,7 +38,7 @@ type OutputRoomEventConsumer struct {
rsAPI rsapi.UserRoomserverAPI
jetstream nats.JetStreamContext
durable string
db storage.Database
db storage.UserDatabase
topic string
pgClient pushgateway.Client
syncProducer *producers.SyncAPI
@ -53,7 +53,7 @@ func NewOutputRoomEventConsumer(
process *process.ProcessContext,
cfg *config.UserAPI,
js nats.JetStreamContext,
store storage.Database,
store storage.UserDatabase,
pgClient pushgateway.Client,
rsAPI rsapi.UserRoomserverAPI,
syncProducer *producers.SyncAPI,

View File

@ -18,11 +18,11 @@ import (
userAPITypes "github.com/matrix-org/dendrite/userapi/types"
)
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) {
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, "", 4, 0, 0, "")
if err != nil {

View File

@ -22,11 +22,10 @@ import (
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/internal"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/userapi/api"
)
// SigningKeyUpdateConsumer consumes signing key updates that came in over federation.
@ -35,24 +34,24 @@ type SigningKeyUpdateConsumer struct {
jetstream nats.JetStreamContext
durable string
topic string
keyAPI *internal.KeyInternalAPI
cfg *config.KeyServer
userAPI api.UploadDeviceKeysAPI
cfg *config.UserAPI
isLocalServerName func(gomatrixserverlib.ServerName) bool
}
// NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers.
func NewSigningKeyUpdateConsumer(
process *process.ProcessContext,
cfg *config.KeyServer,
cfg *config.UserAPI,
js nats.JetStreamContext,
keyAPI *internal.KeyInternalAPI,
userAPI api.UploadDeviceKeysAPI,
) *SigningKeyUpdateConsumer {
return &SigningKeyUpdateConsumer{
ctx: process.Context(),
jetstream: js,
durable: cfg.Matrix.JetStream.Prefixed("KeyServerSigningKeyConsumer"),
topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate),
keyAPI: keyAPI,
userAPI: userAPI,
cfg: cfg,
isLocalServerName: cfg.Matrix.IsLocalServerName,
}
@ -70,7 +69,7 @@ func (t *SigningKeyUpdateConsumer) Start() error {
// signing key update events topic from the key server.
func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
msg := msgs[0] // Guaranteed to exist if onMessage is called
var updatePayload keyapi.CrossSigningKeyUpdate
var updatePayload api.CrossSigningKeyUpdate
if err := json.Unmarshal(msg.Data, &updatePayload); err != nil {
logrus.WithError(err).Errorf("Failed to read from signing key update input topic")
return true
@ -94,12 +93,12 @@ func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M
if updatePayload.SelfSigningKey != nil {
keys.SelfSigningKey = *updatePayload.SelfSigningKey
}
uploadReq := &keyapi.PerformUploadDeviceKeysRequest{
uploadReq := &api.PerformUploadDeviceKeysRequest{
CrossSigningKeys: keys,
UserID: updatePayload.UserID,
}
uploadRes := &keyapi.PerformUploadDeviceKeysResponse{}
if err := t.keyAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes); err != nil {
uploadRes := &api.PerformUploadDeviceKeysResponse{}
if err := t.userAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes); err != nil {
logrus.WithError(err).Error("failed to upload device keys")
return false
}

View File

@ -22,8 +22,8 @@ import (
"fmt"
"strings"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/curve25519"
@ -103,7 +103,7 @@ func sanityCheckKey(key gomatrixserverlib.CrossSigningKey, userID string, purpos
}
// nolint:gocyclo
func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error {
func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error {
// Find the keys to store.
byPurpose := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{}
toStore := types.CrossSigningKeyMap{}
@ -169,7 +169,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
// something if any of the specified keys in the request are different
// to what we've got in the database, to avoid generating key change
// notifications unnecessarily.
existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID)
existingKeys, err := a.KeyDatabase.CrossSigningKeysDataForUser(ctx, req.UserID)
if err != nil {
res.Error = &api.KeyError{
Err: "Retrieving cross-signing keys from database failed: " + err.Error(),
@ -216,7 +216,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
}
// Store the keys.
if err := a.DB.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil {
if err := a.KeyDatabase.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err),
}
@ -234,7 +234,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
continue
}
for sigKeyID, sigBytes := range forSigUserID {
if err := a.DB.StoreCrossSigningSigsForTarget(ctx, sigUserID, sigKeyID, req.UserID, targetKeyID, sigBytes); err != nil {
if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(ctx, sigUserID, sigKeyID, req.UserID, targetKeyID, sigBytes); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.StoreCrossSigningSigsForTarget: %s", err),
}
@ -257,7 +257,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
if update.MasterKey == nil && update.SelfSigningKey == nil {
return nil
}
if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil {
if err := a.KeyChangeProducer.ProduceSigningKeyUpdate(update); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err),
}
@ -266,7 +266,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
return nil
}
func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error {
func (a *UserInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error {
// Before we do anything, we need the master and self-signing keys for this user.
// Then we can verify the signatures make sense.
queryReq := &api.QueryKeysRequest{
@ -342,7 +342,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req
MasterKey: &masterKey,
SelfSigningKey: &selfSigningKey,
}
if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil {
if err := a.KeyChangeProducer.ProduceSigningKeyUpdate(update); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err),
}
@ -352,7 +352,7 @@ func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req
return nil
}
func (a *KeyInternalAPI) processSelfSignatures(
func (a *UserInternalAPI) processSelfSignatures(
ctx context.Context,
signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice,
) error {
@ -373,7 +373,7 @@ func (a *KeyInternalAPI) processSelfSignatures(
}
for originUserID, forOriginUserID := range sig.Signatures {
for originKeyID, originSig := range forOriginUserID {
if err := a.DB.StoreCrossSigningSigsForTarget(
if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(
ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig,
); err != nil {
return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err)
@ -384,7 +384,7 @@ func (a *KeyInternalAPI) processSelfSignatures(
case *gomatrixserverlib.DeviceKeys:
for originUserID, forOriginUserID := range sig.Signatures {
for originKeyID, originSig := range forOriginUserID {
if err := a.DB.StoreCrossSigningSigsForTarget(
if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(
ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig,
); err != nil {
return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err)
@ -401,7 +401,7 @@ func (a *KeyInternalAPI) processSelfSignatures(
return nil
}
func (a *KeyInternalAPI) processOtherSignatures(
func (a *UserInternalAPI) processOtherSignatures(
ctx context.Context, userID string, queryRes *api.QueryKeysResponse,
signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice,
) error {
@ -442,7 +442,7 @@ func (a *KeyInternalAPI) processOtherSignatures(
}
for originKeyID, originSig := range userSigs {
if err := a.DB.StoreCrossSigningSigsForTarget(
if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(
ctx, userID, originKeyID, targetUserID, targetKeyID, originSig,
); err != nil {
return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err)
@ -461,11 +461,11 @@ func (a *KeyInternalAPI) processOtherSignatures(
return nil
}
func (a *KeyInternalAPI) crossSigningKeysFromDatabase(
func (a *UserInternalAPI) crossSigningKeysFromDatabase(
ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse,
) {
for targetUserID := range req.UserToDevices {
keys, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID)
keys, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID)
if err != nil {
logrus.WithError(err).Errorf("Failed to get cross-signing keys for user %q", targetUserID)
continue
@ -478,7 +478,7 @@ func (a *KeyInternalAPI) crossSigningKeysFromDatabase(
break
}
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID)
sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID)
if err != nil && err != sql.ErrNoRows {
logrus.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", targetUserID, keyID)
continue
@ -522,9 +522,9 @@ func (a *KeyInternalAPI) crossSigningKeysFromDatabase(
}
}
func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error {
func (a *UserInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error {
for targetUserID, forTargetUser := range req.TargetIDs {
keyMap, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID)
keyMap, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID)
if err != nil && err != sql.ErrNoRows {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.CrossSigningKeysForUser: %s", err),
@ -556,7 +556,7 @@ func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySign
for _, targetKeyID := range forTargetUser {
// Get own signatures only.
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, targetUserID, targetUserID, targetKeyID)
sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, targetUserID, targetUserID, targetKeyID)
if err != nil && err != sql.ErrNoRows {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err),

View File

@ -33,8 +33,8 @@ import (
"github.com/sirupsen/logrus"
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/userapi/api"
)
var (

View File

@ -29,13 +29,13 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage"
roomserver "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage"
)
var (
@ -360,12 +360,12 @@ func TestDebounce(t *testing.T) {
}
}
func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.Database, func()) {
func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) {
t.Helper()
base, _, _ := testrig.Base(nil)
connStr, clearDB := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)})
db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)})
if err != nil {
t.Fatal(err)
}

View File

@ -29,29 +29,11 @@ import (
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/producers"
"github.com/matrix-org/dendrite/keyserver/storage"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/api"
)
type KeyInternalAPI struct {
DB storage.Database
Cfg *config.KeyServer
FedClient fedsenderapi.KeyserverFederationAPI
UserAPI userapi.KeyserverUserAPI
Producer *producers.KeyChange
Updater *DeviceListUpdater
}
func (a *KeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) {
a.UserAPI = i
}
func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error {
userIDs, latest, err := a.DB.KeyChanges(ctx, req.Offset, req.ToOffset)
func (a *UserInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error {
userIDs, latest, err := a.KeyDatabase.KeyChanges(ctx, req.Offset, req.ToOffset)
if err != nil {
res.Error = &api.KeyError{
Err: err.Error(),
@ -63,7 +45,7 @@ func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyC
return nil
}
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error {
func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error {
res.KeyErrors = make(map[string]map[string]*api.KeyError)
if len(req.DeviceKeys) > 0 {
a.uploadLocalDeviceKeys(ctx, req, res)
@ -71,7 +53,7 @@ func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perform
if len(req.OneTimeKeys) > 0 {
a.uploadOneTimeKeys(ctx, req, res)
}
otks, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
if err != nil {
return err
}
@ -79,7 +61,7 @@ func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perform
return nil
}
func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error {
func (a *UserInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error {
res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage)
res.Failures = make(map[string]interface{})
// wrap request map in a top-level by-domain map
@ -97,11 +79,11 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
domainToDeviceKeys[string(serverName)] = nested
}
for domain, local := range domainToDeviceKeys {
if !a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
if !a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
continue
}
// claim local keys
keys, err := a.DB.ClaimKeys(ctx, local)
keys, err := a.KeyDatabase.ClaimKeys(ctx, local)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err),
@ -129,7 +111,7 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
return nil
}
func (a *KeyInternalAPI) claimRemoteKeys(
func (a *UserInternalAPI) claimRemoteKeys(
ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string,
) {
var wg sync.WaitGroup // Wait for fan-out goroutines to finish
@ -146,7 +128,7 @@ func (a *KeyInternalAPI) claimRemoteKeys(
defer cancel()
defer wg.Done()
claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim)
claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim)
mu.Lock()
defer mu.Unlock()
@ -177,8 +159,8 @@ func (a *KeyInternalAPI) claimRemoteKeys(
}).Infof("Claimed remote keys from %d server(s)", len(domainToDeviceKeys))
}
func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error {
if err := a.DB.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil {
func (a *UserInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error {
if err := a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("Failed to delete device keys: %s", err),
}
@ -186,8 +168,8 @@ func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.Perform
return nil
}
func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error {
count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error {
count, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("Failed to query OTK counts: %s", err),
@ -198,8 +180,8 @@ func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOne
return nil
}
func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error {
msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false)
func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error {
msgs, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, false)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query DB for device keys: %s", err),
@ -225,8 +207,8 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
// PerformMarkAsStaleIfNeeded marks the users device list as stale, if the given deviceID is not present
// in our database.
func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *api.PerformMarkAsStaleRequest, res *struct{}) error {
knownDevices, err := a.DB.DeviceKeysForUser(ctx, req.UserID, []string{}, true)
func (a *UserInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *api.PerformMarkAsStaleRequest, res *struct{}) error {
knownDevices, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, []string{}, true)
if err != nil {
return err
}
@ -244,7 +226,7 @@ func (a *KeyInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *ap
}
// nolint:gocyclo
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error {
func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error {
var respMu sync.Mutex
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
@ -262,8 +244,8 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
}
domain := string(serverName)
// query local devices
if a.Cfg.Matrix.IsLocalServerName(serverName) {
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
if a.Config.Matrix.IsLocalServerName(serverName) {
deviceKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query local device keys: %s", err),
@ -276,8 +258,8 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
for _, dk := range deviceKeys {
dids = append(dids, dk.DeviceID)
}
var queryRes userapi.QueryDeviceInfosResponse
err = a.UserAPI.QueryDeviceInfos(ctx, &userapi.QueryDeviceInfosRequest{
var queryRes api.QueryDeviceInfosResponse
err = a.QueryDeviceInfos(ctx, &api.QueryDeviceInfosRequest{
DeviceIDs: dids,
}, &queryRes)
if err != nil {
@ -341,14 +323,14 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
masterKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
}
for targetKeyID := range masterKey.Keys {
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID)
sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID)
if err != nil {
// Stop executing the function if the context was canceled/the deadline was exceeded,
// as we can't continue without a valid context.
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil
}
logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed")
continue
}
if len(sigMap) == 0 {
@ -367,14 +349,14 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
for targetUserID, forUserID := range res.DeviceKeys {
for targetKeyID, key := range forUserID {
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID))
sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID))
if err != nil {
// Stop executing the function if the context was canceled/the deadline was exceeded,
// as we can't continue without a valid context.
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil
}
logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed")
continue
}
if len(sigMap) == 0 {
@ -403,7 +385,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
return nil
}
func (a *KeyInternalAPI) remoteKeysFromDatabase(
func (a *UserInternalAPI) remoteKeysFromDatabase(
ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string,
) map[string]map[string][]string {
fetchRemote := make(map[string]map[string][]string)
@ -429,7 +411,7 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase(
return fetchRemote
}
func (a *KeyInternalAPI) queryRemoteKeys(
func (a *UserInternalAPI) queryRemoteKeys(
ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse,
domainToDeviceKeys map[string]map[string][]string, domainToCrossSigningKeys map[string]map[string]struct{},
) {
@ -441,13 +423,13 @@ func (a *KeyInternalAPI) queryRemoteKeys(
domains := map[string]struct{}{}
for domain := range domainToDeviceKeys {
if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
continue
}
domains[domain] = struct{}{}
}
for domain := range domainToCrossSigningKeys {
if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
continue
}
domains[domain] = struct{}{}
@ -499,7 +481,7 @@ func (a *KeyInternalAPI) queryRemoteKeys(
}
}
func (a *KeyInternalAPI) queryRemoteKeysOnServer(
func (a *UserInternalAPI) queryRemoteKeysOnServer(
ctx context.Context, serverName string, devKeys map[string][]string, crossSigningKeys map[string]struct{},
wg *sync.WaitGroup, respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *gomatrixserverlib.RespQueryKeys,
res *api.QueryKeysResponse,
@ -559,7 +541,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
if len(devKeys) == 0 {
return
}
queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys)
queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys)
if err == nil {
resultCh <- &queryKeysResp
return
@ -586,10 +568,10 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
respMu.Unlock()
}
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
func (a *UserInternalAPI) populateResponseWithDeviceKeysFromDatabase(
ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string,
) error {
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
keys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false)
// if we can't query the db or there are fewer keys than requested, fetch from remote.
if err != nil {
return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
@ -621,11 +603,11 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
return nil
}
func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
// get a list of devices from the user API that actually exist, as
// we won't store keys for devices that don't exist
uapidevices := &userapi.QueryDevicesResponse{}
if err := a.UserAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil {
uapidevices := &api.QueryDevicesResponse{}
if err := a.QueryDevices(ctx, &api.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil {
res.Error = &api.KeyError{
Err: err.Error(),
}
@ -643,7 +625,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
}
// Get all of the user existing device keys so we can check for changes.
existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, true)
existingKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, true)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
@ -662,7 +644,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
}
if len(toClean) > 0 {
if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil {
if err = a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil {
logrus.WithField("user_id", req.UserID).WithError(err).Errorf("Failed to clean up %d stale keyserver device key entries", len(toClean))
} else {
logrus.WithField("user_id", req.UserID).Debugf("Cleaned up %d stale keyserver device key entries", len(toClean))
@ -693,7 +675,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
if err != nil {
continue // ignore invalid users
}
if !a.Cfg.Matrix.IsLocalServerName(serverName) {
if !a.Config.Matrix.IsLocalServerName(serverName) {
continue // ignore remote users
}
if len(key.KeyJSON) == 0 {
@ -722,30 +704,30 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
}
// store the device keys and emit changes
err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
err = a.KeyDatabase.StoreLocalDeviceKeys(ctx, keysToStore)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
}
return
}
err = emitDeviceKeyChanges(a.Producer, existingKeys, keysToStore, req.OnlyDisplayNameUpdates)
err = emitDeviceKeyChanges(a.KeyChangeProducer, existingKeys, keysToStore, req.OnlyDisplayNameUpdates)
if err != nil {
util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err)
}
}
func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
if req.UserID == "" {
res.Error = &api.KeyError{
Err: "user ID missing",
}
}
if req.DeviceID != "" && len(req.OneTimeKeys) == 0 {
counts, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
counts, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.OneTimeKeysCount: %s", err),
Err: fmt.Sprintf("a.KeyDatabase.OneTimeKeysCount: %s", err),
}
}
if counts != nil {
@ -761,7 +743,7 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
keyIDsWithAlgorithms[i] = keyIDWithAlgo
i++
}
existingKeys, err := a.DB.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms)
existingKeys, err := a.KeyDatabase.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms)
if err != nil {
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: "failed to query existing one-time keys: " + err.Error(),
@ -778,7 +760,7 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
}
}
// store one-time keys
counts, err := a.DB.StoreOneTimeKeys(ctx, key)
counts, err := a.KeyDatabase.StoreOneTimeKeys(ctx, key)
if err != nil {
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", req.UserID, req.DeviceID, err.Error()),

View File

@ -5,23 +5,28 @@ import (
"reflect"
"testing"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/internal"
"github.com/matrix-org/dendrite/keyserver/storage"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/internal"
"github.com/matrix-org/dendrite/userapi/storage"
)
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewDatabase(nil, &config.DatabaseOptions{
base, _, _ := testrig.Base(nil)
db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
})
if err != nil {
t.Fatalf("failed to create new user db: %v", err)
}
return db, close
return db, func() {
base.Close()
close()
}
}
func Test_QueryDeviceMessages(t *testing.T) {
@ -140,8 +145,8 @@ func Test_QueryDeviceMessages(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &internal.KeyInternalAPI{
DB: db,
a := &internal.UserInternalAPI{
KeyDatabase: db,
}
if err := a.QueryDeviceMessages(ctx, tt.args.req, tt.args.res); (err != nil) != tt.wantErr {
t.Errorf("QueryDeviceMessages() error = %v, wantErr %v", err, tt.wantErr)

View File

@ -23,6 +23,7 @@ import (
"strconv"
"time"
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
@ -32,7 +33,6 @@ import (
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/internal/sqlutil"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
synctypes "github.com/matrix-org/dendrite/syncapi/types"
@ -44,17 +44,19 @@ import (
)
type UserInternalAPI struct {
DB storage.Database
DB storage.UserDatabase
KeyDatabase storage.KeyDatabase
SyncProducer *producers.SyncAPI
KeyChangeProducer *producers.KeyChange
Config *config.UserAPI
DisableTLSValidation bool
// AppServices is the list of all registered AS
AppServices []config.ApplicationService
KeyAPI keyapi.UserKeyAPI
RSAPI rsapi.UserRoomserverAPI
PgClient pushgateway.Client
Cfg *config.UserAPI
FedClient fedsenderapi.KeyserverFederationAPI
Updater *DeviceListUpdater
}
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
@ -221,7 +223,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
return fmt.Errorf("a.DB.SetDisplayName: %w", err)
}
postRegisterJoinRooms(a.Cfg, acc, a.RSAPI)
postRegisterJoinRooms(a.Config, acc, a.RSAPI)
res.AccountCreated = true
res.Account = acc
@ -293,14 +295,14 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
return err
}
// Ask the keyserver to delete device keys and signatures for those devices
deleteReq := &keyapi.PerformDeleteKeysRequest{
deleteReq := &api.PerformDeleteKeysRequest{
UserID: req.UserID,
}
for _, keyID := range req.DeviceIDs {
deleteReq.KeyIDs = append(deleteReq.KeyIDs, gomatrixserverlib.KeyID(keyID))
}
deleteRes := &keyapi.PerformDeleteKeysResponse{}
if err := a.KeyAPI.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil {
deleteRes := &api.PerformDeleteKeysResponse{}
if err := a.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil {
return err
}
if err := deleteRes.Error; err != nil {
@ -311,17 +313,17 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
}
func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error {
deviceKeys := make([]keyapi.DeviceKeys, len(deviceIDs))
deviceKeys := make([]api.DeviceKeys, len(deviceIDs))
for i, did := range deviceIDs {
deviceKeys[i] = keyapi.DeviceKeys{
deviceKeys[i] = api.DeviceKeys{
UserID: userID,
DeviceID: did,
KeyJSON: nil,
}
}
var uploadRes keyapi.PerformUploadKeysResponse
if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
var uploadRes api.PerformUploadKeysResponse
if err := a.PerformUploadKeys(context.Background(), &api.PerformUploadKeysRequest{
UserID: userID,
DeviceKeys: deviceKeys,
}, &uploadRes); err != nil {
@ -385,10 +387,10 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
}
if req.DisplayName != nil && dev.DisplayName != *req.DisplayName {
// display name has changed: update the device key
var uploadRes keyapi.PerformUploadKeysResponse
if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
var uploadRes api.PerformUploadKeysResponse
if err := a.PerformUploadKeys(context.Background(), &api.PerformUploadKeysRequest{
UserID: req.RequestingUserID,
DeviceKeys: []keyapi.DeviceKeys{
DeviceKeys: []api.DeviceKeys{
{
DeviceID: dev.ID,
DisplayName: *req.DisplayName,

View File

@ -18,9 +18,9 @@ import (
"context"
"encoding/json"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
)
@ -28,8 +28,8 @@ import (
// KeyChange produces key change events for the sync API and federation sender to consume
type KeyChange struct {
Topic string
JetStream nats.JetStreamContext
DB storage.Database
JetStream JetStreamPublisher
DB storage.KeyChangeDatabase
}
// ProduceKeyChanges creates new change events for each key

View File

@ -19,13 +19,13 @@ type JetStreamPublisher interface {
// SyncAPI produces messages for the Sync API server to consume.
type SyncAPI struct {
db storage.Database
db storage.Notification
producer JetStreamPublisher
clientDataTopic string
notificationDataTopic string
}
func NewSyncAPI(db storage.Database, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI {
func NewSyncAPI(db storage.UserDatabase, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI {
return &SyncAPI{
db: db,
producer: js,

View File

@ -90,7 +90,7 @@ type KeyBackup interface {
type LoginToken interface {
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
// determined by the loginTokenLifetime given to the UserDatabase constructor.
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
@ -130,7 +130,7 @@ type Notification interface {
DeleteOldNotifications(ctx context.Context) error
}
type Database interface {
type UserDatabase interface {
Account
AccountData
Device
@ -144,6 +144,78 @@ type Database interface {
ThreePID
}
type KeyChangeDatabase interface {
// StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change.
// `userID` is the the user who has changed their keys in some way.
StoreKeyChange(ctx context.Context, userID string) (int64, error)
}
type KeyDatabase interface {
KeyChangeDatabase
// ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination
// of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database.
ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
// StoreOneTimeKeys persists the given one-time keys.
StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
// OneTimeKeysCount returns a count of all OTKs for this device.
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
// StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
// for this (user, device).
// The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set.
// Returns an error if there was a problem storing the keys.
StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
// for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
// to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
// PrevIDsExists returns true if all prev IDs exist for this user.
PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error)
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
// cross-signing signatures relating to that device.
DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
// KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive).
// A to offset of types.OffsetNewest means no upper limit.
// Returns the offset of the latest key change.
KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
// If no domains are given, all user IDs with stale device lists are returned.
StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
// MarkDeviceListStale sets the stale bit for this user to isStale.
MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error)
CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error)
CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error)
StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error
StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
DeleteStaleDeviceLists(
ctx context.Context,
userIDs []string,
) error
}
type Statistics interface {
UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error)
DailyRoomsMessages(ctx context.Context, serverName gomatrixserverlib.ServerName) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error)

View File

@ -78,7 +78,13 @@ func (s *accountDataStatements) InsertAccountData(
roomID, dataType string, content json.RawMessage,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
_, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content)
// Empty/nil json.RawMessage is not interpreted as "nil", so use *json.RawMessage
// when passing the data to trigger "NOT NULL" constraint
var data *json.RawMessage
if len(content) > 0 {
data = &content
}
_, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, data)
return
}

View File

@ -21,8 +21,8 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
)

View File

@ -21,9 +21,9 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
)

View File

@ -23,8 +23,8 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
var deviceKeysSchema = `
@ -92,31 +92,16 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if err != nil {
return nil, err
}
if s.upsertDeviceKeysStmt, err = db.Prepare(upsertDeviceKeysSQL); err != nil {
return nil, err
}
if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil {
return nil, err
}
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err
}
if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
return nil, err
}
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err
}
if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil {
return nil, err
}
if s.deleteDeviceKeysStmt, err = db.Prepare(deleteDeviceKeysSQL); err != nil {
return nil, err
}
if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
return nil, err
}
return s, nil
return s, sqlutil.StatementList{
{&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL},
{&s.selectDeviceKeysStmt, selectDeviceKeysSQL},
{&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL},
{&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL},
{&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL},
{&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL},
{&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL},
{&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL},
}.Prepare(db)
}
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {

View File

@ -160,7 +160,7 @@ func (s *devicesStatements) InsertDevice(
if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
return nil, fmt.Errorf("insertDeviceStmt: %w", err)
}
return &api.Device{
dev := &api.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, serverName),
AccessToken: accessToken,
@ -168,7 +168,11 @@ func (s *devicesStatements) InsertDevice(
LastSeenTS: createdTimeMS,
LastSeenIP: ipAddr,
UserAgent: userAgent,
}, nil
}
if displayName != nil {
dev.DisplayName = *displayName
}
return dev, nil
}
func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id,

View File

@ -52,7 +52,7 @@ const updateBackupKeySQL = "" +
const countKeysSQL = "" +
"SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2"
const selectKeysSQL = "" +
const selectBackupKeysSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2"
@ -83,7 +83,7 @@ func NewPostgresKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
{&s.insertBackupKeyStmt, insertBackupKeySQL},
{&s.updateBackupKeyStmt, updateBackupKeySQL},
{&s.countKeysStmt, countKeysSQL},
{&s.selectKeysStmt, selectKeysSQL},
{&s.selectKeysStmt, selectBackupKeysSQL},
{&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL},
{&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL},
}.Prepare(db)

View File

@ -22,8 +22,8 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
var keyChangesSchema = `
@ -66,7 +66,10 @@ func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
if err = executeMigration(context.Background(), db); err != nil {
return nil, err
}
return s, nil
return s, sqlutil.StatementList{
{&s.upsertKeyChangeStmt, upsertKeyChangeSQL},
{&s.selectKeyChangesStmt, selectKeyChangesSQL},
}.Prepare(db)
}
func executeMigration(ctx context.Context, db *sql.DB) error {
@ -95,16 +98,6 @@ func executeMigration(ctx context.Context, db *sql.DB) error {
return m.Up(ctx)
}
func (s *keyChangesStatements) Prepare() (err error) {
if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil {
return err
}
if s.selectKeyChangesStmt, err = s.db.Prepare(selectKeyChangesSQL); err != nil {
return err
}
return nil
}
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) {
err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID)
return

View File

@ -23,8 +23,8 @@ import (
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
var oneTimeKeysSchema = `
@ -49,7 +49,7 @@ const upsertKeysSQL = "" +
" ON CONFLICT ON CONSTRAINT keyserver_one_time_keys_unique" +
" DO UPDATE SET key_json = $6"
const selectKeysSQL = "" +
const selectOneTimeKeysSQL = "" +
"SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 AND concat(algorithm, ':', key_id) = ANY($3);"
const selectKeysCountSQL = "" +
@ -84,25 +84,14 @@ func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
if err != nil {
return nil, err
}
if s.upsertKeysStmt, err = db.Prepare(upsertKeysSQL); err != nil {
return nil, err
}
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
return nil, err
}
if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil {
return nil, err
}
if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil {
return nil, err
}
if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil {
return nil, err
}
if s.deleteOneTimeKeysStmt, err = db.Prepare(deleteOneTimeKeysSQL); err != nil {
return nil, err
}
return s, nil
return s, sqlutil.StatementList{
{&s.upsertKeysStmt, upsertKeysSQL},
{&s.selectKeysStmt, selectOneTimeKeysSQL},
{&s.selectKeysCountStmt, selectKeysCountSQL},
{&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL},
{&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL},
{&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL},
}.Prepare(db)
}
func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {

View File

@ -24,7 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)

View File

@ -136,3 +136,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
}, nil
}
func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) {
db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter())
if err != nil {
return nil, err
}
otk, err := NewPostgresOneTimeKeysTable(db)
if err != nil {
return nil, err
}
dk, err := NewPostgresDeviceKeysTable(db)
if err != nil {
return nil, err
}
kc, err := NewPostgresKeyChangesTable(db)
if err != nil {
return nil, err
}
sdl, err := NewPostgresStaleDeviceListsTable(db)
if err != nil {
return nil, err
}
csk, err := NewPostgresCrossSigningKeysTable(db)
if err != nil {
return nil, err
}
css, err := NewPostgresCrossSigningSigsTable(db)
if err != nil {
return nil, err
}
return &shared.KeyDatabase{
OneTimeKeysTable: otk,
DeviceKeysTable: dk,
KeyChangesTable: kc,
StaleDeviceListsTable: sdl,
CrossSigningKeysTable: csk,
CrossSigningSigsTable: css,
Writer: writer,
}, nil
}

View File

@ -59,6 +59,17 @@ type Database struct {
OpenIDTokenLifetimeMS int64
}
type KeyDatabase struct {
OneTimeKeysTable tables.OneTimeKeys
DeviceKeysTable tables.DeviceKeys
KeyChangesTable tables.KeyChanges
StaleDeviceListsTable tables.StaleDeviceLists
CrossSigningKeysTable tables.CrossSigningKeys
CrossSigningSigsTable tables.CrossSigningSigs
DB *sql.DB
Writer sqlutil.Writer
}
const (
// The length of generated device IDs
deviceIDByteLength = 6
@ -875,3 +886,227 @@ func (d *Database) DailyRoomsMessages(
) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) {
return d.Stats.DailyRoomsMessages(ctx, nil, serverName)
}
//
func (d *KeyDatabase) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms)
}
func (d *KeyDatabase) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) {
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys)
return err
})
return
}
func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
}
func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
}
func (d *KeyDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs)
if err != nil {
return false, err
}
return count == len(prevIDs), nil
}
func (d *KeyDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
for _, userID := range clearUserIDs {
err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID)
if err != nil {
return err
}
}
return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
})
}
func (d *KeyDatabase) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
// work out the latest stream IDs for each user
userIDToStreamID := make(map[string]int64)
for _, k := range keys {
userIDToStreamID[k.UserID] = 0
}
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
for userID := range userIDToStreamID {
streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID)
if err != nil {
return err
}
userIDToStreamID[userID] = streamID
}
// set the stream IDs for each key
for i := range keys {
k := keys[i]
userIDToStreamID[k.UserID]++ // start stream from 1
k.StreamID = userIDToStreamID[k.UserID]
keys[i] = k
}
return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
})
}
func (d *KeyDatabase) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty)
}
func (d *KeyDatabase) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {
var result []api.OneTimeKeys
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
for userID, deviceToAlgo := range userToDeviceToAlgorithm {
for deviceID, algo := range deviceToAlgo {
keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo)
if err != nil {
return err
}
if keyJSON != nil {
result = append(result, api.OneTimeKeys{
UserID: userID,
DeviceID: deviceID,
KeyJSON: keyJSON,
})
}
}
}
return nil
})
return result, err
}
func (d *KeyDatabase) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) {
err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID)
return err
})
return
}
func (d *KeyDatabase) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) {
return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset)
}
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
// If no domains are given, all user IDs with stale device lists are returned.
func (d *KeyDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains)
}
// MarkDeviceListStale sets the stale bit for this user to isStale.
func (d *KeyDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
return d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale)
})
}
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
// cross-signing signatures relating to that device.
func (d *KeyDatabase) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
for _, deviceID := range deviceIDs {
if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err)
}
if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err)
}
if err := d.OneTimeKeysTable.DeleteOneTimeKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("d.OneTimeKeysTable.DeleteOneTimeKeys: %w", err)
}
}
return nil
})
}
// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any.
func (d *KeyDatabase) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) {
keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID)
if err != nil {
return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err)
}
results := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{}
for purpose, key := range keyMap {
keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode())
result := gomatrixserverlib.CrossSigningKey{
UserID: userID,
Usage: []gomatrixserverlib.CrossSigningKeyPurpose{purpose},
Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{
keyID: key,
},
}
sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, userID, keyID)
if err != nil {
continue
}
for sigUserID, forSigUserID := range sigMap {
if userID != sigUserID {
continue
}
if result.Signatures == nil {
result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
}
if _, ok := result.Signatures[sigUserID]; !ok {
result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
}
for sigKeyID, sigBytes := range forSigUserID {
result.Signatures[sigUserID][sigKeyID] = sigBytes
}
}
results[purpose] = result
}
return results, nil
}
// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any.
func (d *KeyDatabase) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) {
return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID)
}
// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any.
func (d *KeyDatabase) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) {
return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID)
}
// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user.
func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
for keyType, keyData := range keyMap {
if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil {
return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err)
}
}
return nil
})
}
// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice.
func (d *KeyDatabase) StoreCrossSigningSigsForTarget(
ctx context.Context,
originUserID string, originKeyID gomatrixserverlib.KeyID,
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
signature gomatrixserverlib.Base64Bytes,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
return fmt.Errorf("d.CrossSigningSigsTable.InsertCrossSigningSigsForTarget: %w", err)
}
return nil
})
}
// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore.
func (d *KeyDatabase) DeleteStaleDeviceLists(
ctx context.Context,
userIDs []string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs)
})
}

View File

@ -21,8 +21,8 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
)

View File

@ -21,9 +21,9 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
)

View File

@ -22,8 +22,8 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
var deviceKeysSchema = `
@ -86,28 +86,16 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if err != nil {
return nil, err
}
if s.upsertDeviceKeysStmt, err = db.Prepare(upsertDeviceKeysSQL); err != nil {
return nil, err
}
if s.selectDeviceKeysStmt, err = db.Prepare(selectDeviceKeysSQL); err != nil {
return nil, err
}
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err
}
if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
return nil, err
}
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err
}
if s.deleteDeviceKeysStmt, err = db.Prepare(deleteDeviceKeysSQL); err != nil {
return nil, err
}
if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
return nil, err
}
return s, nil
return s, sqlutil.StatementList{
{&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL},
{&s.selectDeviceKeysStmt, selectDeviceKeysSQL},
{&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL},
{&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL},
{&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL},
// {&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL}, // prepared at runtime
{&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL},
{&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL},
}.Prepare(db)
}
func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {

View File

@ -151,7 +151,7 @@ func (s *devicesStatements) InsertDevice(
if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
return nil, err
}
return &api.Device{
dev := &api.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, serverName),
AccessToken: accessToken,
@ -159,7 +159,11 @@ func (s *devicesStatements) InsertDevice(
LastSeenTS: createdTimeMS,
LastSeenIP: ipAddr,
UserAgent: userAgent,
}, nil
}
if displayName != nil {
dev.DisplayName = *displayName
}
return dev, nil
}
func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id,
@ -172,7 +176,7 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *
if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
return nil, err
}
return &api.Device{
dev := &api.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, serverName),
AccessToken: accessToken,
@ -180,7 +184,11 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *
LastSeenTS: createdTimeMS,
LastSeenIP: ipAddr,
UserAgent: userAgent,
}, nil
}
if displayName != nil {
dev.DisplayName = *displayName
}
return dev, nil
}
func (s *devicesStatements) DeleteDevice(
@ -202,6 +210,7 @@ func (s *devicesStatements) DeleteDevices(
if err != nil {
return err
}
defer internal.CloseAndLogIfError(ctx, prep, "DeleteDevices.StmtClose() failed")
stmt := sqlutil.TxStmt(txn, prep)
params := make([]interface{}, len(devices)+2)
params[0] = localpart

View File

@ -52,7 +52,7 @@ const updateBackupKeySQL = "" +
const countKeysSQL = "" +
"SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2"
const selectKeysSQL = "" +
const selectBackupKeysSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2"
@ -83,7 +83,7 @@ func NewSQLiteKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
{&s.insertBackupKeyStmt, insertBackupKeySQL},
{&s.updateBackupKeyStmt, updateBackupKeySQL},
{&s.countKeysStmt, countKeysSQL},
{&s.selectKeysStmt, selectKeysSQL},
{&s.selectKeysStmt, selectBackupKeysSQL},
{&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL},
{&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL},
}.Prepare(db)

View File

@ -22,8 +22,8 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
var keyChangesSchema = `
@ -65,7 +65,10 @@ func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
return nil, err
}
return s, nil
return s, sqlutil.StatementList{
{&s.upsertKeyChangeStmt, upsertKeyChangeSQL},
{&s.selectKeyChangesStmt, selectKeyChangesSQL},
}.Prepare(db)
}
func executeMigration(ctx context.Context, db *sql.DB) error {
@ -93,16 +96,6 @@ func executeMigration(ctx context.Context, db *sql.DB) error {
return m.Up(ctx)
}
func (s *keyChangesStatements) Prepare() (err error) {
if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil {
return err
}
if s.selectKeyChangesStmt, err = s.db.Prepare(selectKeyChangesSQL); err != nil {
return err
}
return nil
}
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) {
err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID)
return

View File

@ -22,8 +22,8 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
var oneTimeKeysSchema = `
@ -48,7 +48,7 @@ const upsertKeysSQL = "" +
" ON CONFLICT (user_id, device_id, key_id, algorithm)" +
" DO UPDATE SET key_json = $6"
const selectKeysSQL = "" +
const selectOneTimeKeysSQL = "" +
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
const selectKeysCountSQL = "" +
@ -83,25 +83,14 @@ func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
if err != nil {
return nil, err
}
if s.upsertKeysStmt, err = db.Prepare(upsertKeysSQL); err != nil {
return nil, err
}
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
return nil, err
}
if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil {
return nil, err
}
if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil {
return nil, err
}
if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil {
return nil, err
}
if s.deleteOneTimeKeysStmt, err = db.Prepare(deleteOneTimeKeysSQL); err != nil {
return nil, err
}
return s, nil
return s, sqlutil.StatementList{
{&s.upsertKeysStmt, upsertKeysSQL},
{&s.selectKeysStmt, selectOneTimeKeysSQL},
{&s.selectKeysCountStmt, selectKeysCountSQL},
{&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL},
{&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL},
{&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL},
}.Prepare(db)
}
func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {

View File

@ -23,7 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)

View File

@ -256,6 +256,7 @@ func (s *statsStatements) allUsers(ctx context.Context, txn *sql.Tx) (result int
if err != nil {
return 0, err
}
defer internal.CloseAndLogIfError(ctx, queryStmt, "allUsers.StmtClose() failed")
stmt := sqlutil.TxStmt(txn, queryStmt)
err = stmt.QueryRowContext(ctx,
1, 2, 3, 4,
@ -269,6 +270,7 @@ func (s *statsStatements) nonBridgedUsers(ctx context.Context, txn *sql.Tx) (res
if err != nil {
return 0, err
}
defer internal.CloseAndLogIfError(ctx, queryStmt, "nonBridgedUsers.StmtClose() failed")
stmt := sqlutil.TxStmt(txn, queryStmt)
err = stmt.QueryRowContext(ctx,
1, 2, 3,
@ -286,6 +288,7 @@ func (s *statsStatements) registeredUserByType(ctx context.Context, txn *sql.Tx)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, queryStmt, "registeredUserByType.StmtClose() failed")
stmt := sqlutil.TxStmt(txn, queryStmt)
registeredAfter := time.Now().AddDate(0, 0, -30)

View File

@ -30,8 +30,8 @@ import (
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
)
// NewDatabase creates a new accounts and profiles database
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) {
// NewUserDatabase creates a new accounts and profiles database
func NewUserDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) {
db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter())
if err != nil {
return nil, err
@ -134,3 +134,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
}, nil
}
func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) {
db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter())
if err != nil {
return nil, err
}
otk, err := NewSqliteOneTimeKeysTable(db)
if err != nil {
return nil, err
}
dk, err := NewSqliteDeviceKeysTable(db)
if err != nil {
return nil, err
}
kc, err := NewSqliteKeyChangesTable(db)
if err != nil {
return nil, err
}
sdl, err := NewSqliteStaleDeviceListsTable(db)
if err != nil {
return nil, err
}
csk, err := NewSqliteCrossSigningKeysTable(db)
if err != nil {
return nil, err
}
css, err := NewSqliteCrossSigningSigsTable(db)
if err != nil {
return nil, err
}
return &shared.KeyDatabase{
OneTimeKeysTable: otk,
DeviceKeysTable: dk,
KeyChangesTable: kc,
StaleDeviceListsTable: sdl,
CrossSigningKeysTable: csk,
CrossSigningSigsTable: css,
Writer: writer,
}, nil
}

View File

@ -29,15 +29,36 @@ import (
"github.com/matrix-org/dendrite/userapi/storage/sqlite3"
)
// NewUserAPIDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// NewUserDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets postgres connection parameters
func NewUserAPIDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) {
func NewUserDatabase(
base *base.BaseDendrite,
dbProperties *config.DatabaseOptions,
serverName gomatrixserverlib.ServerName,
bcryptCost int,
openIDTokenLifetimeMS int64,
loginTokenLifetime time.Duration,
serverNoticesLocalpart string,
) (UserDatabase, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
default:
return nil, fmt.Errorf("unexpected database type")
}
}
// NewKeyDatabase opens a new Postgres or Sqlite database (base on dataSourceName) scheme)
// and sets postgres connection parameters.
func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (KeyDatabase, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewKeyDatabase(base, dbProperties)
case dbProperties.ConnectionString.IsPostgres():
return postgres.NewKeyDatabase(base, dbProperties)
default:
return nil, fmt.Errorf("unexpected database type")
}
}

View File

@ -4,9 +4,12 @@ import (
"context"
"encoding/json"
"fmt"
"reflect"
"sync"
"testing"
"time"
"github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
@ -29,14 +32,14 @@ var (
ctx = context.Background()
)
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
func mustCreateUserDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) {
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server")
if err != nil {
t.Fatalf("NewUserAPIDatabase returned %s", err)
t.Fatalf("NewUserDatabase returned %s", err)
}
return db, func() {
close()
@ -47,7 +50,7 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, fun
// Tests storing and getting account data
func Test_AccountData(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
db, close := mustCreateUserDatabase(t, dbType)
defer close()
alice := test.NewUser(t)
localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID)
@ -78,7 +81,7 @@ func Test_AccountData(t *testing.T) {
// Tests the creation of accounts
func Test_Accounts(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
db, close := mustCreateUserDatabase(t, dbType)
defer close()
alice := test.NewUser(t)
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
@ -158,7 +161,7 @@ func Test_Devices(t *testing.T) {
accessToken := util.RandomString(16)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
db, close := mustCreateUserDatabase(t, dbType)
defer close()
deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "")
@ -238,7 +241,7 @@ func Test_KeyBackup(t *testing.T) {
room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
db, close := mustCreateUserDatabase(t, dbType)
defer close()
wantAuthData := json.RawMessage("my auth data")
@ -315,7 +318,7 @@ func Test_KeyBackup(t *testing.T) {
func Test_LoginToken(t *testing.T) {
alice := test.NewUser(t)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
db, close := mustCreateUserDatabase(t, dbType)
defer close()
// create a new token
@ -347,7 +350,7 @@ func Test_OpenID(t *testing.T) {
token := util.RandomString(24)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
db, close := mustCreateUserDatabase(t, dbType)
defer close()
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + openIDLifetimeMS
@ -368,7 +371,7 @@ func Test_Profile(t *testing.T) {
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
db, close := mustCreateUserDatabase(t, dbType)
defer close()
// create account, which also creates a profile
@ -417,7 +420,7 @@ func Test_Pusher(t *testing.T) {
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
db, close := mustCreateUserDatabase(t, dbType)
defer close()
appID := util.RandomString(8)
@ -468,7 +471,7 @@ func Test_ThreePID(t *testing.T) {
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
db, close := mustCreateUserDatabase(t, dbType)
defer close()
threePID := util.RandomString(8)
medium := util.RandomString(8)
@ -507,7 +510,7 @@ func Test_Notification(t *testing.T) {
room := test.NewRoom(t, alice)
room2 := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
db, close := mustCreateUserDatabase(t, dbType)
defer close()
// generate some dummy notifications
for i := 0; i < 10; i++ {
@ -571,3 +574,184 @@ func Test_Notification(t *testing.T) {
assert.Equal(t, int64(0), total)
})
}
func mustCreateKeyDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) {
base, close := testrig.CreateBaseDendrite(t, dbType)
db, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database)
if err != nil {
t.Fatalf("failed to create new database: %v", err)
}
return db, close
}
func MustNotError(t *testing.T, err error) {
t.Helper()
if err == nil {
return
}
t.Fatalf("operation failed: %s", err)
}
func TestKeyChanges(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := mustCreateKeyDatabase(t, dbType)
defer clean()
_, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err)
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
MustNotError(t, err)
deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost")
MustNotError(t, err)
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest)
if err != nil {
t.Fatalf("Failed to KeyChanges: %s", err)
}
if latest != deviceChangeIDC {
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC)
}
if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
})
}
func TestKeyChangesNoDupes(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := mustCreateKeyDatabase(t, dbType)
defer clean()
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err)
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err)
if deviceChangeIDA == deviceChangeIDB {
t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA)
}
deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err)
userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest)
if err != nil {
t.Fatalf("Failed to KeyChanges: %s", err)
}
if latest != deviceChangeID {
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID)
}
if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
})
}
func TestKeyChangesUpperLimit(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := mustCreateKeyDatabase(t, dbType)
defer clean()
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err)
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
MustNotError(t, err)
_, err = db.StoreKeyChange(ctx, "@charlie:localhost")
MustNotError(t, err)
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB)
if err != nil {
t.Fatalf("Failed to KeyChanges: %s", err)
}
if latest != deviceChangeIDB {
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB)
}
if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
})
}
var dbLock sync.Mutex
var deviceArray = []string{"AAA", "another_device"}
// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
// and that they are returned correctly when querying for device keys.
func TestDeviceKeysStreamIDGeneration(t *testing.T) {
var err error
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := mustCreateKeyDatabase(t, dbType)
defer clean()
alice := "@alice:TestDeviceKeysStreamIDGeneration"
bob := "@bob:TestDeviceKeysStreamIDGeneration"
msgs := []api.DeviceMessage{
{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "AAA",
UserID: alice,
KeyJSON: []byte(`{"key":"v1"}`),
},
// StreamID: 1
},
{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "AAA",
UserID: bob,
KeyJSON: []byte(`{"key":"v1"}`),
},
// StreamID: 1 as this is a different user
},
{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "another_device",
UserID: alice,
KeyJSON: []byte(`{"key":"v1"}`),
},
// StreamID: 2 as this is a 2nd device key
},
}
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
if msgs[0].StreamID != 1 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
}
if msgs[1].StreamID != 1 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
}
if msgs[2].StreamID != 2 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
}
// updating a device sets the next stream ID for that user
msgs = []api.DeviceMessage{
{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "AAA",
UserID: alice,
KeyJSON: []byte(`{"key":"v2"}`),
},
// StreamID: 3
},
}
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
if msgs[0].StreamID != 3 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
}
dbLock.Lock()
defer dbLock.Unlock()
// Querying for device keys returns the latest stream IDs
msgs, err = db.DeviceKeysForUser(ctx, alice, deviceArray, false)
if err != nil {
t.Fatalf("DeviceKeysForUser returned error: %s", err)
}
wantStreamIDs := map[string]int64{
"AAA": 3,
"another_device": 2,
}
if len(msgs) != len(wantStreamIDs) {
t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
}
for _, m := range msgs {
if m.StreamID != wantStreamIDs[m.DeviceID] {
t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
}
}
})
}

View File

@ -32,10 +32,10 @@ func NewUserAPIDatabase(
openIDTokenLifetimeMS int64,
loginTokenLifetime time.Duration,
serverNoticesLocalpart string,
) (Database, error) {
) (UserDatabase, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation")
default:

View File

@ -20,10 +20,10 @@ import (
"encoding/json"
"time"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/types"
)
@ -145,3 +145,47 @@ const (
// uint32.
AllNotifications NotificationFilter = (1 << 31) - 1
)
type OneTimeKeys interface {
SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
// SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON.
// Returns an empty map if the key does not exist.
SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error)
DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
}
type DeviceKeys interface {
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error)
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
}
type KeyChanges interface {
InsertKeyChange(ctx context.Context, userID string) (int64, error)
// SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets.
// Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of types.OffsetNewest means no upper offset.
SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
}
type StaleDeviceLists interface {
InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error
SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error
}
type CrossSigningKeys interface {
SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error)
UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error
}
type CrossSigningSigs interface {
SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error)
UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error
}

View File

@ -4,15 +4,15 @@ import (
"context"
"testing"
"github.com/matrix-org/dendrite/userapi/storage/postgres"
"github.com/matrix-org/dendrite/userapi/storage/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/sqlite3"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/keyserver/storage/postgres"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) {

Some files were not shown because too many files have changed in this diff Show More