From dc0bac85d5bad933d32ee63f8bc1aef6348ca6e9 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 18 Jun 2020 18:36:03 +0100 Subject: [PATCH] Refactor account data (#1150) * Refactor account data * Tweak database fetching * Tweaks * Restore syncProducer notification * Various tweaks, update tag behaviour * Fix initial sync --- clientapi/routing/account_data.go | 55 ++++---- clientapi/routing/room_tagging.go | 127 +++++++----------- clientapi/routing/routing.go | 14 +- syncapi/sync/requestpool.go | 75 +++++++---- userapi/api/api.go | 27 ++-- userapi/internal/api.go | 31 ++++- userapi/inthttp/client.go | 10 ++ userapi/storage/accounts/interface.go | 7 +- .../accounts/postgres/account_data_table.go | 48 +++---- userapi/storage/accounts/postgres/storage.go | 13 +- .../accounts/sqlite3/account_data_table.go | 50 +++---- userapi/storage/accounts/sqlite3/storage.go | 13 +- 12 files changed, 248 insertions(+), 222 deletions(-) diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go index 68e0dc5d..d5fafedb 100644 --- a/clientapi/routing/account_data.go +++ b/clientapi/routing/account_data.go @@ -16,21 +16,20 @@ package routing import ( "encoding/json" + "fmt" "io/ioutil" "net/http" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) // GetAccountData implements GET /user/{userId}/[rooms/{roomid}/]account_data/{type} func GetAccountData( - req *http.Request, accountDB accounts.Database, device *api.Device, + req *http.Request, userAPI api.UserInternalAPI, device *api.Device, userID string, roomID string, dataType string, ) util.JSONResponse { if userID != device.UserID { @@ -40,18 +39,28 @@ func GetAccountData( } } - localpart, _, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() + dataReq := api.QueryAccountDataRequest{ + UserID: userID, + DataType: dataType, + RoomID: roomID, + } + dataRes := api.QueryAccountDataResponse{} + if err := userAPI.QueryAccountData(req.Context(), &dataReq, &dataRes); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountData failed") + return util.ErrorResponse(fmt.Errorf("userAPI.QueryAccountData: %w", err)) } - if data, err := accountDB.GetAccountDataByType( - req.Context(), localpart, roomID, dataType, - ); err == nil { + var data json.RawMessage + var ok bool + if roomID != "" { + data, ok = dataRes.RoomAccountData[roomID][dataType] + } else { + data, ok = dataRes.GlobalAccountData[dataType] + } + if ok { return util.JSONResponse{ Code: http.StatusOK, - JSON: data.Content, + JSON: data, } } @@ -63,7 +72,7 @@ func GetAccountData( // SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type} func SaveAccountData( - req *http.Request, accountDB accounts.Database, device *api.Device, + req *http.Request, userAPI api.UserInternalAPI, device *api.Device, userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer, ) util.JSONResponse { if userID != device.UserID { @@ -73,12 +82,6 @@ func SaveAccountData( } } - localpart, _, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() - } - defer req.Body.Close() // nolint: errcheck if req.Body == http.NoBody { @@ -101,13 +104,19 @@ func SaveAccountData( } } - if err := accountDB.SaveAccountData( - req.Context(), localpart, roomID, dataType, string(body), - ); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("accountDB.SaveAccountData failed") - return jsonerror.InternalServerError() + dataReq := api.InputAccountDataRequest{ + UserID: userID, + DataType: dataType, + RoomID: roomID, + AccountData: json.RawMessage(body), + } + dataRes := api.InputAccountDataResponse{} + if err := userAPI.InputAccountData(req.Context(), &dataReq, &dataRes); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountData failed") + return util.ErrorResponse(err) } + // TODO: user API should do this since it's account data if err := syncProducer.SendData(userID, roomID, dataType); err != nil { util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed") return jsonerror.InternalServerError() diff --git a/clientapi/routing/room_tagging.go b/clientapi/routing/room_tagging.go index b1cfcca8..c683cc94 100644 --- a/clientapi/routing/room_tagging.go +++ b/clientapi/routing/room_tagging.go @@ -24,23 +24,14 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrix" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) -// newTag creates and returns a new gomatrix.TagContent -func newTag() gomatrix.TagContent { - return gomatrix.TagContent{ - Tags: make(map[string]gomatrix.TagProperties), - } -} - // GetTags implements GET /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags func GetTags( req *http.Request, - accountDB accounts.Database, + userAPI api.UserInternalAPI, device *api.Device, userID string, roomID string, @@ -54,22 +45,15 @@ func GetTags( } } - _, data, err := obtainSavedTags(req, userID, roomID, accountDB) + tagContent, err := obtainSavedTags(req, userID, roomID, userAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") return jsonerror.InternalServerError() } - if data == nil { - return util.JSONResponse{ - Code: http.StatusOK, - JSON: struct{}{}, - } - } - return util.JSONResponse{ Code: http.StatusOK, - JSON: data.Content, + JSON: tagContent, } } @@ -78,7 +62,7 @@ func GetTags( // the tag to the "map" and saving the new "map" to the DB func PutTag( req *http.Request, - accountDB accounts.Database, + userAPI api.UserInternalAPI, device *api.Device, userID string, roomID string, @@ -98,34 +82,25 @@ func PutTag( return *reqErr } - localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB) + tagContent, err := obtainSavedTags(req, userID, roomID, userAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") return jsonerror.InternalServerError() } - var tagContent gomatrix.TagContent - if data != nil { - if err = json.Unmarshal(data.Content, &tagContent); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed") - return jsonerror.InternalServerError() - } - } else { - tagContent = newTag() + if tagContent.Tags == nil { + tagContent.Tags = make(map[string]gomatrix.TagProperties) } tagContent.Tags[tag] = properties - if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil { + + if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil { util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed") return jsonerror.InternalServerError() } - // Send data to syncProducer in order to inform clients of changes - // Run in a goroutine in order to prevent blocking the tag request response - go func() { - if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil { - logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") - } - }() + if err = syncProducer.SendData(userID, roomID, "m.tag"); err != nil { + logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") + } return util.JSONResponse{ Code: http.StatusOK, @@ -138,7 +113,7 @@ func PutTag( // the "map" and then saving the new "map" in the DB func DeleteTag( req *http.Request, - accountDB accounts.Database, + userAPI api.UserInternalAPI, device *api.Device, userID string, roomID string, @@ -153,28 +128,12 @@ func DeleteTag( } } - localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB) + tagContent, err := obtainSavedTags(req, userID, roomID, userAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") return jsonerror.InternalServerError() } - // If there are no tags in the database, exit - if data == nil { - // Spec only defines 200 responses for this endpoint so we don't return anything else. - return util.JSONResponse{ - Code: http.StatusOK, - JSON: struct{}{}, - } - } - - var tagContent gomatrix.TagContent - err = json.Unmarshal(data.Content, &tagContent) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed") - return jsonerror.InternalServerError() - } - // Check whether the tag to be deleted exists if _, ok := tagContent.Tags[tag]; ok { delete(tagContent.Tags, tag) @@ -185,18 +144,16 @@ func DeleteTag( JSON: struct{}{}, } } - if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil { + + if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil { util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed") return jsonerror.InternalServerError() } - // Send data to syncProducer in order to inform clients of changes - // Run in a goroutine in order to prevent blocking the tag request response - go func() { - if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil { - logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") - } - }() + // TODO: user API should do this since it's account data + if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil { + logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") + } return util.JSONResponse{ Code: http.StatusOK, @@ -210,32 +167,46 @@ func obtainSavedTags( req *http.Request, userID string, roomID string, - accountDB accounts.Database, -) (string, *gomatrixserverlib.ClientEvent, error) { - localpart, _, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - return "", nil, err + userAPI api.UserInternalAPI, +) (tags gomatrix.TagContent, err error) { + dataReq := api.QueryAccountDataRequest{ + UserID: userID, + RoomID: roomID, + DataType: "m.tag", } - - data, err := accountDB.GetAccountDataByType( - req.Context(), localpart, roomID, "m.tag", - ) - - return localpart, data, err + dataRes := api.QueryAccountDataResponse{} + err = userAPI.QueryAccountData(req.Context(), &dataReq, &dataRes) + if err != nil { + return + } + data, ok := dataRes.RoomAccountData[roomID]["m.tag"] + if !ok { + return + } + if err = json.Unmarshal(data, &tags); err != nil { + return + } + return tags, nil } // saveTagData saves the provided tag data into the database func saveTagData( req *http.Request, - localpart string, + userID string, roomID string, - accountDB accounts.Database, + userAPI api.UserInternalAPI, Tag gomatrix.TagContent, ) error { newTagData, err := json.Marshal(Tag) if err != nil { return err } - - return accountDB.SaveAccountData(req.Context(), localpart, roomID, "m.tag", string(newTagData)) + dataReq := api.InputAccountDataRequest{ + UserID: userID, + RoomID: roomID, + DataType: "m.tag", + AccountData: json.RawMessage(newTagData), + } + dataRes := api.InputAccountDataResponse{} + return userAPI.InputAccountData(req.Context(), &dataReq, &dataRes) } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 41c7fb18..e91b07ac 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -476,7 +476,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return SaveAccountData(req, accountDB, device, vars["userID"], "", vars["type"], syncProducer) + return SaveAccountData(req, userAPI, device, vars["userID"], "", vars["type"], syncProducer) }), ).Methods(http.MethodPut, http.MethodOptions) @@ -486,7 +486,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return SaveAccountData(req, accountDB, device, vars["userID"], vars["roomID"], vars["type"], syncProducer) + return SaveAccountData(req, userAPI, device, vars["userID"], vars["roomID"], vars["type"], syncProducer) }), ).Methods(http.MethodPut, http.MethodOptions) @@ -496,7 +496,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return GetAccountData(req, accountDB, device, vars["userID"], "", vars["type"]) + return GetAccountData(req, userAPI, device, vars["userID"], "", vars["type"]) }), ).Methods(http.MethodGet) @@ -506,7 +506,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return GetAccountData(req, accountDB, device, vars["userID"], vars["roomID"], vars["type"]) + return GetAccountData(req, userAPI, device, vars["userID"], vars["roomID"], vars["type"]) }), ).Methods(http.MethodGet) @@ -604,7 +604,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return GetTags(req, accountDB, device, vars["userId"], vars["roomId"], syncProducer) + return GetTags(req, userAPI, device, vars["userId"], vars["roomId"], syncProducer) }), ).Methods(http.MethodGet, http.MethodOptions) @@ -614,7 +614,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return PutTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) + return PutTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) }), ).Methods(http.MethodPut, http.MethodOptions) @@ -624,7 +624,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return DeleteTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) + return DeleteTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) }), ).Methods(http.MethodDelete, http.MethodOptions) diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 26b925ea..8d51689e 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -205,22 +205,34 @@ func (rp *RequestPool) appendAccountData( if req.since == nil { // If this is the initial sync, we don't need to check if a data has // already been sent. Instead, we send the whole batch. - var res userapi.QueryAccountDataResponse - err := rp.userAPI.QueryAccountData(req.ctx, &userapi.QueryAccountDataRequest{ + dataReq := &userapi.QueryAccountDataRequest{ UserID: userID, - }, &res) - if err != nil { + } + dataRes := &userapi.QueryAccountDataResponse{} + if err := rp.userAPI.QueryAccountData(req.ctx, dataReq, dataRes); err != nil { return nil, err } - data.AccountData.Events = res.GlobalAccountData - + for datatype, databody := range dataRes.GlobalAccountData { + data.AccountData.Events = append( + data.AccountData.Events, + gomatrixserverlib.ClientEvent{ + Type: datatype, + Content: gomatrixserverlib.RawJSON(databody), + }, + ) + } for r, j := range data.Rooms.Join { - if len(res.RoomAccountData[r]) > 0 { - j.AccountData.Events = res.RoomAccountData[r] + for datatype, databody := range dataRes.RoomAccountData[r] { + j.AccountData.Events = append( + j.AccountData.Events, + gomatrixserverlib.ClientEvent{ + Type: datatype, + Content: gomatrixserverlib.RawJSON(databody), + }, + ) data.Rooms.Join[r] = j } } - return data, nil } @@ -249,33 +261,42 @@ func (rp *RequestPool) appendAccountData( // Iterate over the rooms for roomID, dataTypes := range dataTypes { - events := []gomatrixserverlib.ClientEvent{} // Request the missing data from the database for _, dataType := range dataTypes { - var res userapi.QueryAccountDataResponse - err = rp.userAPI.QueryAccountData(req.ctx, &userapi.QueryAccountDataRequest{ + dataReq := userapi.QueryAccountDataRequest{ UserID: userID, RoomID: roomID, DataType: dataType, - }, &res) + } + dataRes := userapi.QueryAccountDataResponse{} + err = rp.userAPI.QueryAccountData(req.ctx, &dataReq, &dataRes) if err != nil { - return nil, err + continue } - if len(res.RoomAccountData[roomID]) > 0 { - events = append(events, res.RoomAccountData[roomID]...) - } else if len(res.GlobalAccountData) > 0 { - events = append(events, res.GlobalAccountData...) + if roomID == "" { + if globalData, ok := dataRes.GlobalAccountData[dataType]; ok { + data.AccountData.Events = append( + data.AccountData.Events, + gomatrixserverlib.ClientEvent{ + Type: dataType, + Content: gomatrixserverlib.RawJSON(globalData), + }, + ) + } + } else { + if roomData, ok := dataRes.RoomAccountData[roomID][dataType]; ok { + joinData := data.Rooms.Join[roomID] + joinData.AccountData.Events = append( + joinData.AccountData.Events, + gomatrixserverlib.ClientEvent{ + Type: dataType, + Content: gomatrixserverlib.RawJSON(roomData), + }, + ) + data.Rooms.Join[roomID] = joinData + } } } - - // Append the data to the response - if len(roomID) > 0 { - jr := data.Rooms.Join[roomID] - jr.AccountData.Events = events - data.Rooms.Join[roomID] = jr - } else { - data.AccountData.Events = events - } } return data, nil diff --git a/userapi/api/api.go b/userapi/api/api.go index c953a5ba..a80adf2d 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -16,12 +16,14 @@ package api import ( "context" + "encoding/json" "github.com/matrix-org/gomatrixserverlib" ) // UserInternalAPI is the internal API for information about users and devices. type UserInternalAPI interface { + InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error @@ -30,6 +32,18 @@ type UserInternalAPI interface { QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error } +// InputAccountDataRequest is the request for InputAccountData +type InputAccountDataRequest struct { + UserID string // required: the user to set account data for + RoomID string // optional: the room to associate the account data with + DataType string // optional: the data type of the data + AccountData json.RawMessage // required: the message content +} + +// InputAccountDataResponse is the response for InputAccountData +type InputAccountDataResponse struct { +} + // QueryAccessTokenRequest is the request for QueryAccessToken type QueryAccessTokenRequest struct { AccessToken string @@ -46,18 +60,15 @@ type QueryAccessTokenResponse struct { // QueryAccountDataRequest is the request for QueryAccountData type QueryAccountDataRequest struct { - UserID string // required: the user to get account data for. - // TODO: This is a terribly confusing API shape :/ - DataType string // optional: if specified returns only a single event matching this data type. - // optional: Only used if DataType is set. If blank returns global account data matching the data type. - // If set, returns only room account data matching this data type. - RoomID string + UserID string // required: the user to get account data for. + RoomID string // optional: the room ID, or global account data if not specified. + DataType string // optional: the data type, or all types if not specified. } // QueryAccountDataResponse is the response for QueryAccountData type QueryAccountDataResponse struct { - GlobalAccountData []gomatrixserverlib.ClientEvent - RoomAccountData map[string][]gomatrixserverlib.ClientEvent + GlobalAccountData map[string]json.RawMessage // type -> data + RoomAccountData map[string]map[string]json.RawMessage // room -> type -> data } // QueryDevicesRequest is the request for QueryDevices diff --git a/userapi/internal/api.go b/userapi/internal/api.go index ae021f57..b081eca4 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -17,6 +17,7 @@ package internal import ( "context" "database/sql" + "encoding/json" "errors" "fmt" @@ -38,6 +39,20 @@ type UserInternalAPI struct { AppServices []config.ApplicationService } +func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { + local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return err + } + if domain != a.ServerName { + return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName) + } + if req.DataType == "" { + return fmt.Errorf("data type must not be empty") + } + return a.AccountDB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData) +} + func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { if req.AccountType == api.AccountTypeGuest { acc, err := a.AccountDB.CreateGuestAccount(ctx) @@ -130,17 +145,21 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc return fmt.Errorf("cannot query account data of remote users: got %s want %s", domain, a.ServerName) } if req.DataType != "" { - var event *gomatrixserverlib.ClientEvent - event, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) + var data json.RawMessage + data, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) if err != nil { return err } - if event != nil { + res.RoomAccountData = make(map[string]map[string]json.RawMessage) + res.GlobalAccountData = make(map[string]json.RawMessage) + if data != nil { if req.RoomID != "" { - res.RoomAccountData = make(map[string][]gomatrixserverlib.ClientEvent) - res.RoomAccountData[req.RoomID] = []gomatrixserverlib.ClientEvent{*event} + if _, ok := res.RoomAccountData[req.RoomID]; !ok { + res.RoomAccountData[req.RoomID] = make(map[string]json.RawMessage) + } + res.RoomAccountData[req.RoomID][req.DataType] = data } else { - res.GlobalAccountData = append(res.GlobalAccountData, *event) + res.GlobalAccountData[req.DataType] = data } } return nil diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 0e9628c5..4ab0d690 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -26,6 +26,8 @@ import ( // HTTP paths for the internal HTTP APIs const ( + InputAccountDataPath = "/userapi/inputAccountData" + PerformDeviceCreationPath = "/userapi/performDeviceCreation" PerformAccountCreationPath = "/userapi/performAccountCreation" @@ -55,6 +57,14 @@ type httpUserInternalAPI struct { httpClient *http.Client } +func (h *httpUserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "InputAccountData") + defer span.Finish() + + apiURL := h.apiURL + InputAccountDataPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + func (h *httpUserInternalAPI) PerformAccountCreation( ctx context.Context, request *api.PerformAccountCreationRequest, diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go index 13e3e289..c6692879 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/accounts/interface.go @@ -16,6 +16,7 @@ package accounts import ( "context" + "encoding/json" "errors" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" @@ -39,13 +40,13 @@ type Database interface { GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error) GetRoomIDsByLocalPart(ctx context.Context, localpart string) ([]string, error) GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error) - SaveAccountData(ctx context.Context, localpart, roomID, dataType, content string) error - GetAccountData(ctx context.Context, localpart string) (global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error) + SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error + GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error) // GetAccountDataByType returns account data matching a given // localpart, room ID and type. // If no account data could be found, returns nil // Returns an error if there was an issue with the retrieval - GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data *gomatrixserverlib.ClientEvent, err error) + GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error) GetNewNumericLocalpart(ctx context.Context) (int64, error) SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error) RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) diff --git a/userapi/storage/accounts/postgres/account_data_table.go b/userapi/storage/accounts/postgres/account_data_table.go index 2f16c5c0..90c79e87 100644 --- a/userapi/storage/accounts/postgres/account_data_table.go +++ b/userapi/storage/accounts/postgres/account_data_table.go @@ -17,9 +17,9 @@ package postgres import ( "context" "database/sql" + "encoding/json" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/gomatrixserverlib" ) const accountDataSchema = ` @@ -73,7 +73,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { } func (s *accountDataStatements) insertAccountData( - ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string, + ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ) (err error) { stmt := txn.Stmt(s.insertAccountDataStmt) _, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content) @@ -83,18 +83,18 @@ func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) selectAccountData( ctx context.Context, localpart string, ) ( - global []gomatrixserverlib.ClientEvent, - rooms map[string][]gomatrixserverlib.ClientEvent, - err error, + /* global */ map[string]json.RawMessage, + /* rooms */ map[string]map[string]json.RawMessage, + error, ) { rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) if err != nil { - return + return nil, nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectAccountData: rows.close() failed") - global = []gomatrixserverlib.ClientEvent{} - rooms = make(map[string][]gomatrixserverlib.ClientEvent) + global := map[string]json.RawMessage{} + rooms := map[string]map[string]json.RawMessage{} for rows.Next() { var roomID string @@ -102,41 +102,33 @@ func (s *accountDataStatements) selectAccountData( var content []byte if err = rows.Scan(&roomID, &dataType, &content); err != nil { - return + return nil, nil, err } - ac := gomatrixserverlib.ClientEvent{ - Type: dataType, - Content: content, - } - - if len(roomID) > 0 { - rooms[roomID] = append(rooms[roomID], ac) + if roomID != "" { + if _, ok := rooms[roomID]; !ok { + rooms[roomID] = map[string]json.RawMessage{} + } + rooms[roomID][dataType] = content } else { - global = append(global, ac) + global[dataType] = content } } + return global, rooms, rows.Err() } func (s *accountDataStatements) selectAccountDataByType( ctx context.Context, localpart, roomID, dataType string, -) (data *gomatrixserverlib.ClientEvent, err error) { +) (data json.RawMessage, err error) { + var bytes []byte stmt := s.selectAccountDataByTypeStmt - var content []byte - - if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil { + if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil { if err == sql.ErrNoRows { return nil, nil } - return } - - data = &gomatrixserverlib.ClientEvent{ - Type: dataType, - Content: content, - } - + data = json.RawMessage(bytes) return } diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index 2b88cb70..e5509980 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -17,6 +17,7 @@ package postgres import ( "context" "database/sql" + "encoding/json" "errors" "strconv" @@ -169,7 +170,7 @@ func (d *Database) createAccount( return nil, err } - if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{ + if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ "global": { "content": [], "override": [], @@ -177,7 +178,7 @@ func (d *Database) createAccount( "sender": [], "underride": [] } - }`); err != nil { + }`)); err != nil { return nil, err } return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID) @@ -295,7 +296,7 @@ func (d *Database) newMembership( // update the corresponding row with the new content // Returns a SQL error if there was an issue with the insertion/update func (d *Database) SaveAccountData( - ctx context.Context, localpart, roomID, dataType, content string, + ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, ) error { return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) @@ -306,8 +307,8 @@ func (d *Database) SaveAccountData( // If no account data could be found, returns an empty arrays // Returns an error if there was an issue with the retrieval func (d *Database) GetAccountData(ctx context.Context, localpart string) ( - global []gomatrixserverlib.ClientEvent, - rooms map[string][]gomatrixserverlib.ClientEvent, + global map[string]json.RawMessage, + rooms map[string]map[string]json.RawMessage, err error, ) { return d.accountDatas.selectAccountData(ctx, localpart) @@ -319,7 +320,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) ( // Returns an error if there was an issue with the retrieval func (d *Database) GetAccountDataByType( ctx context.Context, localpart, roomID, dataType string, -) (data *gomatrixserverlib.ClientEvent, err error) { +) (data json.RawMessage, err error) { return d.accountDatas.selectAccountDataByType( ctx, localpart, roomID, dataType, ) diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go index b6bb6361..d048dbd1 100644 --- a/userapi/storage/accounts/sqlite3/account_data_table.go +++ b/userapi/storage/accounts/sqlite3/account_data_table.go @@ -17,8 +17,7 @@ package sqlite3 import ( "context" "database/sql" - - "github.com/matrix-org/gomatrixserverlib" + "encoding/json" ) const accountDataSchema = ` @@ -72,7 +71,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { } func (s *accountDataStatements) insertAccountData( - ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string, + ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ) (err error) { _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) return @@ -81,17 +80,17 @@ func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) selectAccountData( ctx context.Context, localpart string, ) ( - global []gomatrixserverlib.ClientEvent, - rooms map[string][]gomatrixserverlib.ClientEvent, - err error, + /* global */ map[string]json.RawMessage, + /* rooms */ map[string]map[string]json.RawMessage, + error, ) { rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) if err != nil { - return + return nil, nil, err } - global = []gomatrixserverlib.ClientEvent{} - rooms = make(map[string][]gomatrixserverlib.ClientEvent) + global := map[string]json.RawMessage{} + rooms := map[string]map[string]json.RawMessage{} for rows.Next() { var roomID string @@ -99,42 +98,33 @@ func (s *accountDataStatements) selectAccountData( var content []byte if err = rows.Scan(&roomID, &dataType, &content); err != nil { - return + return nil, nil, err } - ac := gomatrixserverlib.ClientEvent{ - Type: dataType, - Content: content, - } - - if len(roomID) > 0 { - rooms[roomID] = append(rooms[roomID], ac) + if roomID != "" { + if _, ok := rooms[roomID]; !ok { + rooms[roomID] = map[string]json.RawMessage{} + } + rooms[roomID][dataType] = content } else { - global = append(global, ac) + global[dataType] = content } } - return + return global, rooms, nil } func (s *accountDataStatements) selectAccountDataByType( ctx context.Context, localpart, roomID, dataType string, -) (data *gomatrixserverlib.ClientEvent, err error) { +) (data json.RawMessage, err error) { + var bytes []byte stmt := s.selectAccountDataByTypeStmt - var content []byte - - if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil { + if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil { if err == sql.ErrNoRows { return nil, nil } - return } - - data = &gomatrixserverlib.ClientEvent{ - Type: dataType, - Content: content, - } - + data = json.RawMessage(bytes) return } diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 4dd755a7..dbf6606c 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -17,6 +17,7 @@ package sqlite3 import ( "context" "database/sql" + "encoding/json" "errors" "strconv" "sync" @@ -180,7 +181,7 @@ func (d *Database) createAccount( return nil, err } - if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{ + if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ "global": { "content": [], "override": [], @@ -188,7 +189,7 @@ func (d *Database) createAccount( "sender": [], "underride": [] } - }`); err != nil { + }`)); err != nil { return nil, err } return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID) @@ -306,7 +307,7 @@ func (d *Database) newMembership( // update the corresponding row with the new content // Returns a SQL error if there was an issue with the insertion/update func (d *Database) SaveAccountData( - ctx context.Context, localpart, roomID, dataType, content string, + ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, ) error { return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) @@ -317,8 +318,8 @@ func (d *Database) SaveAccountData( // If no account data could be found, returns an empty arrays // Returns an error if there was an issue with the retrieval func (d *Database) GetAccountData(ctx context.Context, localpart string) ( - global []gomatrixserverlib.ClientEvent, - rooms map[string][]gomatrixserverlib.ClientEvent, + global map[string]json.RawMessage, + rooms map[string]map[string]json.RawMessage, err error, ) { return d.accountDatas.selectAccountData(ctx, localpart) @@ -330,7 +331,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) ( // Returns an error if there was an issue with the retrieval func (d *Database) GetAccountDataByType( ctx context.Context, localpart, roomID, dataType string, -) (data *gomatrixserverlib.ClientEvent, err error) { +) (data json.RawMessage, err error) { return d.accountDatas.selectAccountDataByType( ctx, localpart, roomID, dataType, )