mirror of
https://github.com/1f349/dendrite.git
synced 2024-11-09 22:42:58 +00:00
Refactor account data (#1150)
* Refactor account data * Tweak database fetching * Tweaks * Restore syncProducer notification * Various tweaks, update tag behaviour * Fix initial sync
This commit is contained in:
parent
3547a1768c
commit
dc0bac85d5
@ -16,21 +16,20 @@ package routing
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
"github.com/matrix-org/dendrite/clientapi/producers"
|
"github.com/matrix-org/dendrite/clientapi/producers"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
|
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetAccountData implements GET /user/{userId}/[rooms/{roomid}/]account_data/{type}
|
// GetAccountData implements GET /user/{userId}/[rooms/{roomid}/]account_data/{type}
|
||||||
func GetAccountData(
|
func GetAccountData(
|
||||||
req *http.Request, accountDB accounts.Database, device *api.Device,
|
req *http.Request, userAPI api.UserInternalAPI, device *api.Device,
|
||||||
userID string, roomID string, dataType string,
|
userID string, roomID string, dataType string,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
if userID != device.UserID {
|
if userID != device.UserID {
|
||||||
@ -40,18 +39,28 @@ func GetAccountData(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
|
dataReq := api.QueryAccountDataRequest{
|
||||||
if err != nil {
|
UserID: userID,
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
DataType: dataType,
|
||||||
return jsonerror.InternalServerError()
|
RoomID: roomID,
|
||||||
|
}
|
||||||
|
dataRes := api.QueryAccountDataResponse{}
|
||||||
|
if err := userAPI.QueryAccountData(req.Context(), &dataReq, &dataRes); err != nil {
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountData failed")
|
||||||
|
return util.ErrorResponse(fmt.Errorf("userAPI.QueryAccountData: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if data, err := accountDB.GetAccountDataByType(
|
var data json.RawMessage
|
||||||
req.Context(), localpart, roomID, dataType,
|
var ok bool
|
||||||
); err == nil {
|
if roomID != "" {
|
||||||
|
data, ok = dataRes.RoomAccountData[roomID][dataType]
|
||||||
|
} else {
|
||||||
|
data, ok = dataRes.GlobalAccountData[dataType]
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
JSON: data.Content,
|
JSON: data,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -63,7 +72,7 @@ func GetAccountData(
|
|||||||
|
|
||||||
// SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type}
|
// SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type}
|
||||||
func SaveAccountData(
|
func SaveAccountData(
|
||||||
req *http.Request, accountDB accounts.Database, device *api.Device,
|
req *http.Request, userAPI api.UserInternalAPI, device *api.Device,
|
||||||
userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer,
|
userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
if userID != device.UserID {
|
if userID != device.UserID {
|
||||||
@ -73,12 +82,6 @@ func SaveAccountData(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
|
|
||||||
if err != nil {
|
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
|
||||||
return jsonerror.InternalServerError()
|
|
||||||
}
|
|
||||||
|
|
||||||
defer req.Body.Close() // nolint: errcheck
|
defer req.Body.Close() // nolint: errcheck
|
||||||
|
|
||||||
if req.Body == http.NoBody {
|
if req.Body == http.NoBody {
|
||||||
@ -101,13 +104,19 @@ func SaveAccountData(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := accountDB.SaveAccountData(
|
dataReq := api.InputAccountDataRequest{
|
||||||
req.Context(), localpart, roomID, dataType, string(body),
|
UserID: userID,
|
||||||
); err != nil {
|
DataType: dataType,
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("accountDB.SaveAccountData failed")
|
RoomID: roomID,
|
||||||
return jsonerror.InternalServerError()
|
AccountData: json.RawMessage(body),
|
||||||
|
}
|
||||||
|
dataRes := api.InputAccountDataResponse{}
|
||||||
|
if err := userAPI.InputAccountData(req.Context(), &dataReq, &dataRes); err != nil {
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountData failed")
|
||||||
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: user API should do this since it's account data
|
||||||
if err := syncProducer.SendData(userID, roomID, dataType); err != nil {
|
if err := syncProducer.SendData(userID, roomID, dataType); err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
|
util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
|
@ -24,23 +24,14 @@ import (
|
|||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
"github.com/matrix-org/dendrite/clientapi/producers"
|
"github.com/matrix-org/dendrite/clientapi/producers"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
|
||||||
"github.com/matrix-org/gomatrix"
|
"github.com/matrix-org/gomatrix"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// newTag creates and returns a new gomatrix.TagContent
|
|
||||||
func newTag() gomatrix.TagContent {
|
|
||||||
return gomatrix.TagContent{
|
|
||||||
Tags: make(map[string]gomatrix.TagProperties),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetTags implements GET /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags
|
// GetTags implements GET /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags
|
||||||
func GetTags(
|
func GetTags(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
accountDB accounts.Database,
|
userAPI api.UserInternalAPI,
|
||||||
device *api.Device,
|
device *api.Device,
|
||||||
userID string,
|
userID string,
|
||||||
roomID string,
|
roomID string,
|
||||||
@ -54,22 +45,15 @@ func GetTags(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, data, err := obtainSavedTags(req, userID, roomID, accountDB)
|
tagContent, err := obtainSavedTags(req, userID, roomID, userAPI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed")
|
util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
if data == nil {
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
JSON: struct{}{},
|
JSON: tagContent,
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return util.JSONResponse{
|
|
||||||
Code: http.StatusOK,
|
|
||||||
JSON: data.Content,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -78,7 +62,7 @@ func GetTags(
|
|||||||
// the tag to the "map" and saving the new "map" to the DB
|
// the tag to the "map" and saving the new "map" to the DB
|
||||||
func PutTag(
|
func PutTag(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
accountDB accounts.Database,
|
userAPI api.UserInternalAPI,
|
||||||
device *api.Device,
|
device *api.Device,
|
||||||
userID string,
|
userID string,
|
||||||
roomID string,
|
roomID string,
|
||||||
@ -98,34 +82,25 @@ func PutTag(
|
|||||||
return *reqErr
|
return *reqErr
|
||||||
}
|
}
|
||||||
|
|
||||||
localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB)
|
tagContent, err := obtainSavedTags(req, userID, roomID, userAPI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed")
|
util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
var tagContent gomatrix.TagContent
|
if tagContent.Tags == nil {
|
||||||
if data != nil {
|
tagContent.Tags = make(map[string]gomatrix.TagProperties)
|
||||||
if err = json.Unmarshal(data.Content, &tagContent); err != nil {
|
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed")
|
|
||||||
return jsonerror.InternalServerError()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
tagContent = newTag()
|
|
||||||
}
|
}
|
||||||
tagContent.Tags[tag] = properties
|
tagContent.Tags[tag] = properties
|
||||||
if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil {
|
|
||||||
|
if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed")
|
util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send data to syncProducer in order to inform clients of changes
|
if err = syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
|
||||||
// Run in a goroutine in order to prevent blocking the tag request response
|
|
||||||
go func() {
|
|
||||||
if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
|
|
||||||
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
|
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
@ -138,7 +113,7 @@ func PutTag(
|
|||||||
// the "map" and then saving the new "map" in the DB
|
// the "map" and then saving the new "map" in the DB
|
||||||
func DeleteTag(
|
func DeleteTag(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
accountDB accounts.Database,
|
userAPI api.UserInternalAPI,
|
||||||
device *api.Device,
|
device *api.Device,
|
||||||
userID string,
|
userID string,
|
||||||
roomID string,
|
roomID string,
|
||||||
@ -153,28 +128,12 @@ func DeleteTag(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB)
|
tagContent, err := obtainSavedTags(req, userID, roomID, userAPI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed")
|
util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there are no tags in the database, exit
|
|
||||||
if data == nil {
|
|
||||||
// Spec only defines 200 responses for this endpoint so we don't return anything else.
|
|
||||||
return util.JSONResponse{
|
|
||||||
Code: http.StatusOK,
|
|
||||||
JSON: struct{}{},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var tagContent gomatrix.TagContent
|
|
||||||
err = json.Unmarshal(data.Content, &tagContent)
|
|
||||||
if err != nil {
|
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed")
|
|
||||||
return jsonerror.InternalServerError()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check whether the tag to be deleted exists
|
// Check whether the tag to be deleted exists
|
||||||
if _, ok := tagContent.Tags[tag]; ok {
|
if _, ok := tagContent.Tags[tag]; ok {
|
||||||
delete(tagContent.Tags, tag)
|
delete(tagContent.Tags, tag)
|
||||||
@ -185,18 +144,16 @@ func DeleteTag(
|
|||||||
JSON: struct{}{},
|
JSON: struct{}{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil {
|
|
||||||
|
if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed")
|
util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send data to syncProducer in order to inform clients of changes
|
// TODO: user API should do this since it's account data
|
||||||
// Run in a goroutine in order to prevent blocking the tag request response
|
|
||||||
go func() {
|
|
||||||
if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
|
if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
|
||||||
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
|
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
|
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
@ -210,32 +167,46 @@ func obtainSavedTags(
|
|||||||
req *http.Request,
|
req *http.Request,
|
||||||
userID string,
|
userID string,
|
||||||
roomID string,
|
roomID string,
|
||||||
accountDB accounts.Database,
|
userAPI api.UserInternalAPI,
|
||||||
) (string, *gomatrixserverlib.ClientEvent, error) {
|
) (tags gomatrix.TagContent, err error) {
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
|
dataReq := api.QueryAccountDataRequest{
|
||||||
if err != nil {
|
UserID: userID,
|
||||||
return "", nil, err
|
RoomID: roomID,
|
||||||
|
DataType: "m.tag",
|
||||||
}
|
}
|
||||||
|
dataRes := api.QueryAccountDataResponse{}
|
||||||
data, err := accountDB.GetAccountDataByType(
|
err = userAPI.QueryAccountData(req.Context(), &dataReq, &dataRes)
|
||||||
req.Context(), localpart, roomID, "m.tag",
|
if err != nil {
|
||||||
)
|
return
|
||||||
|
}
|
||||||
return localpart, data, err
|
data, ok := dataRes.RoomAccountData[roomID]["m.tag"]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err = json.Unmarshal(data, &tags); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return tags, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// saveTagData saves the provided tag data into the database
|
// saveTagData saves the provided tag data into the database
|
||||||
func saveTagData(
|
func saveTagData(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
localpart string,
|
userID string,
|
||||||
roomID string,
|
roomID string,
|
||||||
accountDB accounts.Database,
|
userAPI api.UserInternalAPI,
|
||||||
Tag gomatrix.TagContent,
|
Tag gomatrix.TagContent,
|
||||||
) error {
|
) error {
|
||||||
newTagData, err := json.Marshal(Tag)
|
newTagData, err := json.Marshal(Tag)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
dataReq := api.InputAccountDataRequest{
|
||||||
return accountDB.SaveAccountData(req.Context(), localpart, roomID, "m.tag", string(newTagData))
|
UserID: userID,
|
||||||
|
RoomID: roomID,
|
||||||
|
DataType: "m.tag",
|
||||||
|
AccountData: json.RawMessage(newTagData),
|
||||||
|
}
|
||||||
|
dataRes := api.InputAccountDataResponse{}
|
||||||
|
return userAPI.InputAccountData(req.Context(), &dataReq, &dataRes)
|
||||||
}
|
}
|
||||||
|
@ -476,7 +476,7 @@ func Setup(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return SaveAccountData(req, accountDB, device, vars["userID"], "", vars["type"], syncProducer)
|
return SaveAccountData(req, userAPI, device, vars["userID"], "", vars["type"], syncProducer)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
|
||||||
@ -486,7 +486,7 @@ func Setup(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return SaveAccountData(req, accountDB, device, vars["userID"], vars["roomID"], vars["type"], syncProducer)
|
return SaveAccountData(req, userAPI, device, vars["userID"], vars["roomID"], vars["type"], syncProducer)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
|
||||||
@ -496,7 +496,7 @@ func Setup(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return GetAccountData(req, accountDB, device, vars["userID"], "", vars["type"])
|
return GetAccountData(req, userAPI, device, vars["userID"], "", vars["type"])
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet)
|
).Methods(http.MethodGet)
|
||||||
|
|
||||||
@ -506,7 +506,7 @@ func Setup(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return GetAccountData(req, accountDB, device, vars["userID"], vars["roomID"], vars["type"])
|
return GetAccountData(req, userAPI, device, vars["userID"], vars["roomID"], vars["type"])
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet)
|
).Methods(http.MethodGet)
|
||||||
|
|
||||||
@ -604,7 +604,7 @@ func Setup(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return GetTags(req, accountDB, device, vars["userId"], vars["roomId"], syncProducer)
|
return GetTags(req, userAPI, device, vars["userId"], vars["roomId"], syncProducer)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
@ -614,7 +614,7 @@ func Setup(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return PutTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
|
return PutTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
|
||||||
@ -624,7 +624,7 @@ func Setup(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return DeleteTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
|
return DeleteTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodDelete, http.MethodOptions)
|
).Methods(http.MethodDelete, http.MethodOptions)
|
||||||
|
|
||||||
|
@ -205,22 +205,34 @@ func (rp *RequestPool) appendAccountData(
|
|||||||
if req.since == nil {
|
if req.since == nil {
|
||||||
// If this is the initial sync, we don't need to check if a data has
|
// If this is the initial sync, we don't need to check if a data has
|
||||||
// already been sent. Instead, we send the whole batch.
|
// already been sent. Instead, we send the whole batch.
|
||||||
var res userapi.QueryAccountDataResponse
|
dataReq := &userapi.QueryAccountDataRequest{
|
||||||
err := rp.userAPI.QueryAccountData(req.ctx, &userapi.QueryAccountDataRequest{
|
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
}, &res)
|
}
|
||||||
if err != nil {
|
dataRes := &userapi.QueryAccountDataResponse{}
|
||||||
|
if err := rp.userAPI.QueryAccountData(req.ctx, dataReq, dataRes); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
data.AccountData.Events = res.GlobalAccountData
|
for datatype, databody := range dataRes.GlobalAccountData {
|
||||||
|
data.AccountData.Events = append(
|
||||||
|
data.AccountData.Events,
|
||||||
|
gomatrixserverlib.ClientEvent{
|
||||||
|
Type: datatype,
|
||||||
|
Content: gomatrixserverlib.RawJSON(databody),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
for r, j := range data.Rooms.Join {
|
for r, j := range data.Rooms.Join {
|
||||||
if len(res.RoomAccountData[r]) > 0 {
|
for datatype, databody := range dataRes.RoomAccountData[r] {
|
||||||
j.AccountData.Events = res.RoomAccountData[r]
|
j.AccountData.Events = append(
|
||||||
|
j.AccountData.Events,
|
||||||
|
gomatrixserverlib.ClientEvent{
|
||||||
|
Type: datatype,
|
||||||
|
Content: gomatrixserverlib.RawJSON(databody),
|
||||||
|
},
|
||||||
|
)
|
||||||
data.Rooms.Join[r] = j
|
data.Rooms.Join[r] = j
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return data, nil
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -249,32 +261,41 @@ func (rp *RequestPool) appendAccountData(
|
|||||||
|
|
||||||
// Iterate over the rooms
|
// Iterate over the rooms
|
||||||
for roomID, dataTypes := range dataTypes {
|
for roomID, dataTypes := range dataTypes {
|
||||||
events := []gomatrixserverlib.ClientEvent{}
|
|
||||||
// Request the missing data from the database
|
// Request the missing data from the database
|
||||||
for _, dataType := range dataTypes {
|
for _, dataType := range dataTypes {
|
||||||
var res userapi.QueryAccountDataResponse
|
dataReq := userapi.QueryAccountDataRequest{
|
||||||
err = rp.userAPI.QueryAccountData(req.ctx, &userapi.QueryAccountDataRequest{
|
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
RoomID: roomID,
|
RoomID: roomID,
|
||||||
DataType: dataType,
|
DataType: dataType,
|
||||||
}, &res)
|
}
|
||||||
|
dataRes := userapi.QueryAccountDataResponse{}
|
||||||
|
err = rp.userAPI.QueryAccountData(req.ctx, &dataReq, &dataRes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
continue
|
||||||
}
|
}
|
||||||
if len(res.RoomAccountData[roomID]) > 0 {
|
if roomID == "" {
|
||||||
events = append(events, res.RoomAccountData[roomID]...)
|
if globalData, ok := dataRes.GlobalAccountData[dataType]; ok {
|
||||||
} else if len(res.GlobalAccountData) > 0 {
|
data.AccountData.Events = append(
|
||||||
events = append(events, res.GlobalAccountData...)
|
data.AccountData.Events,
|
||||||
|
gomatrixserverlib.ClientEvent{
|
||||||
|
Type: dataType,
|
||||||
|
Content: gomatrixserverlib.RawJSON(globalData),
|
||||||
|
},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Append the data to the response
|
|
||||||
if len(roomID) > 0 {
|
|
||||||
jr := data.Rooms.Join[roomID]
|
|
||||||
jr.AccountData.Events = events
|
|
||||||
data.Rooms.Join[roomID] = jr
|
|
||||||
} else {
|
} else {
|
||||||
data.AccountData.Events = events
|
if roomData, ok := dataRes.RoomAccountData[roomID][dataType]; ok {
|
||||||
|
joinData := data.Rooms.Join[roomID]
|
||||||
|
joinData.AccountData.Events = append(
|
||||||
|
joinData.AccountData.Events,
|
||||||
|
gomatrixserverlib.ClientEvent{
|
||||||
|
Type: dataType,
|
||||||
|
Content: gomatrixserverlib.RawJSON(roomData),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
data.Rooms.Join[roomID] = joinData
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -16,12 +16,14 @@ package api
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
// UserInternalAPI is the internal API for information about users and devices.
|
// UserInternalAPI is the internal API for information about users and devices.
|
||||||
type UserInternalAPI interface {
|
type UserInternalAPI interface {
|
||||||
|
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
|
||||||
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
|
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
|
||||||
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
|
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
|
||||||
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
|
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
|
||||||
@ -30,6 +32,18 @@ type UserInternalAPI interface {
|
|||||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InputAccountDataRequest is the request for InputAccountData
|
||||||
|
type InputAccountDataRequest struct {
|
||||||
|
UserID string // required: the user to set account data for
|
||||||
|
RoomID string // optional: the room to associate the account data with
|
||||||
|
DataType string // optional: the data type of the data
|
||||||
|
AccountData json.RawMessage // required: the message content
|
||||||
|
}
|
||||||
|
|
||||||
|
// InputAccountDataResponse is the response for InputAccountData
|
||||||
|
type InputAccountDataResponse struct {
|
||||||
|
}
|
||||||
|
|
||||||
// QueryAccessTokenRequest is the request for QueryAccessToken
|
// QueryAccessTokenRequest is the request for QueryAccessToken
|
||||||
type QueryAccessTokenRequest struct {
|
type QueryAccessTokenRequest struct {
|
||||||
AccessToken string
|
AccessToken string
|
||||||
@ -47,17 +61,14 @@ type QueryAccessTokenResponse struct {
|
|||||||
// QueryAccountDataRequest is the request for QueryAccountData
|
// QueryAccountDataRequest is the request for QueryAccountData
|
||||||
type QueryAccountDataRequest struct {
|
type QueryAccountDataRequest struct {
|
||||||
UserID string // required: the user to get account data for.
|
UserID string // required: the user to get account data for.
|
||||||
// TODO: This is a terribly confusing API shape :/
|
RoomID string // optional: the room ID, or global account data if not specified.
|
||||||
DataType string // optional: if specified returns only a single event matching this data type.
|
DataType string // optional: the data type, or all types if not specified.
|
||||||
// optional: Only used if DataType is set. If blank returns global account data matching the data type.
|
|
||||||
// If set, returns only room account data matching this data type.
|
|
||||||
RoomID string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryAccountDataResponse is the response for QueryAccountData
|
// QueryAccountDataResponse is the response for QueryAccountData
|
||||||
type QueryAccountDataResponse struct {
|
type QueryAccountDataResponse struct {
|
||||||
GlobalAccountData []gomatrixserverlib.ClientEvent
|
GlobalAccountData map[string]json.RawMessage // type -> data
|
||||||
RoomAccountData map[string][]gomatrixserverlib.ClientEvent
|
RoomAccountData map[string]map[string]json.RawMessage // room -> type -> data
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryDevicesRequest is the request for QueryDevices
|
// QueryDevicesRequest is the request for QueryDevices
|
||||||
|
@ -17,6 +17,7 @@ package internal
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
@ -38,6 +39,20 @@ type UserInternalAPI struct {
|
|||||||
AppServices []config.ApplicationService
|
AppServices []config.ApplicationService
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
|
||||||
|
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if domain != a.ServerName {
|
||||||
|
return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName)
|
||||||
|
}
|
||||||
|
if req.DataType == "" {
|
||||||
|
return fmt.Errorf("data type must not be empty")
|
||||||
|
}
|
||||||
|
return a.AccountDB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData)
|
||||||
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
|
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
|
||||||
if req.AccountType == api.AccountTypeGuest {
|
if req.AccountType == api.AccountTypeGuest {
|
||||||
acc, err := a.AccountDB.CreateGuestAccount(ctx)
|
acc, err := a.AccountDB.CreateGuestAccount(ctx)
|
||||||
@ -130,17 +145,21 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
|
|||||||
return fmt.Errorf("cannot query account data of remote users: got %s want %s", domain, a.ServerName)
|
return fmt.Errorf("cannot query account data of remote users: got %s want %s", domain, a.ServerName)
|
||||||
}
|
}
|
||||||
if req.DataType != "" {
|
if req.DataType != "" {
|
||||||
var event *gomatrixserverlib.ClientEvent
|
var data json.RawMessage
|
||||||
event, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
|
data, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if event != nil {
|
res.RoomAccountData = make(map[string]map[string]json.RawMessage)
|
||||||
|
res.GlobalAccountData = make(map[string]json.RawMessage)
|
||||||
|
if data != nil {
|
||||||
if req.RoomID != "" {
|
if req.RoomID != "" {
|
||||||
res.RoomAccountData = make(map[string][]gomatrixserverlib.ClientEvent)
|
if _, ok := res.RoomAccountData[req.RoomID]; !ok {
|
||||||
res.RoomAccountData[req.RoomID] = []gomatrixserverlib.ClientEvent{*event}
|
res.RoomAccountData[req.RoomID] = make(map[string]json.RawMessage)
|
||||||
|
}
|
||||||
|
res.RoomAccountData[req.RoomID][req.DataType] = data
|
||||||
} else {
|
} else {
|
||||||
res.GlobalAccountData = append(res.GlobalAccountData, *event)
|
res.GlobalAccountData[req.DataType] = data
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -26,6 +26,8 @@ import (
|
|||||||
|
|
||||||
// HTTP paths for the internal HTTP APIs
|
// HTTP paths for the internal HTTP APIs
|
||||||
const (
|
const (
|
||||||
|
InputAccountDataPath = "/userapi/inputAccountData"
|
||||||
|
|
||||||
PerformDeviceCreationPath = "/userapi/performDeviceCreation"
|
PerformDeviceCreationPath = "/userapi/performDeviceCreation"
|
||||||
PerformAccountCreationPath = "/userapi/performAccountCreation"
|
PerformAccountCreationPath = "/userapi/performAccountCreation"
|
||||||
|
|
||||||
@ -55,6 +57,14 @@ type httpUserInternalAPI struct {
|
|||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *httpUserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "InputAccountData")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.apiURL + InputAccountDataPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *httpUserInternalAPI) PerformAccountCreation(
|
func (h *httpUserInternalAPI) PerformAccountCreation(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *api.PerformAccountCreationRequest,
|
request *api.PerformAccountCreationRequest,
|
||||||
|
@ -16,6 +16,7 @@ package accounts
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
@ -39,13 +40,13 @@ type Database interface {
|
|||||||
GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error)
|
GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error)
|
||||||
GetRoomIDsByLocalPart(ctx context.Context, localpart string) ([]string, error)
|
GetRoomIDsByLocalPart(ctx context.Context, localpart string) ([]string, error)
|
||||||
GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error)
|
GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error)
|
||||||
SaveAccountData(ctx context.Context, localpart, roomID, dataType, content string) error
|
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
|
||||||
GetAccountData(ctx context.Context, localpart string) (global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error)
|
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
|
||||||
// GetAccountDataByType returns account data matching a given
|
// GetAccountDataByType returns account data matching a given
|
||||||
// localpart, room ID and type.
|
// localpart, room ID and type.
|
||||||
// If no account data could be found, returns nil
|
// If no account data could be found, returns nil
|
||||||
// Returns an error if there was an issue with the retrieval
|
// Returns an error if there was an issue with the retrieval
|
||||||
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data *gomatrixserverlib.ClientEvent, err error)
|
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
|
||||||
GetNewNumericLocalpart(ctx context.Context) (int64, error)
|
GetNewNumericLocalpart(ctx context.Context) (int64, error)
|
||||||
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
|
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
|
||||||
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
|
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
|
||||||
|
@ -17,9 +17,9 @@ package postgres
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const accountDataSchema = `
|
const accountDataSchema = `
|
||||||
@ -73,7 +73,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) insertAccountData(
|
func (s *accountDataStatements) insertAccountData(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string,
|
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := txn.Stmt(s.insertAccountDataStmt)
|
stmt := txn.Stmt(s.insertAccountDataStmt)
|
||||||
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
|
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
|
||||||
@ -83,18 +83,18 @@ func (s *accountDataStatements) insertAccountData(
|
|||||||
func (s *accountDataStatements) selectAccountData(
|
func (s *accountDataStatements) selectAccountData(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) (
|
) (
|
||||||
global []gomatrixserverlib.ClientEvent,
|
/* global */ map[string]json.RawMessage,
|
||||||
rooms map[string][]gomatrixserverlib.ClientEvent,
|
/* rooms */ map[string]map[string]json.RawMessage,
|
||||||
err error,
|
error,
|
||||||
) {
|
) {
|
||||||
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
|
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectAccountData: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectAccountData: rows.close() failed")
|
||||||
|
|
||||||
global = []gomatrixserverlib.ClientEvent{}
|
global := map[string]json.RawMessage{}
|
||||||
rooms = make(map[string][]gomatrixserverlib.ClientEvent)
|
rooms := map[string]map[string]json.RawMessage{}
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var roomID string
|
var roomID string
|
||||||
@ -102,41 +102,33 @@ func (s *accountDataStatements) selectAccountData(
|
|||||||
var content []byte
|
var content []byte
|
||||||
|
|
||||||
if err = rows.Scan(&roomID, &dataType, &content); err != nil {
|
if err = rows.Scan(&roomID, &dataType, &content); err != nil {
|
||||||
return
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ac := gomatrixserverlib.ClientEvent{
|
if roomID != "" {
|
||||||
Type: dataType,
|
if _, ok := rooms[roomID]; !ok {
|
||||||
Content: content,
|
rooms[roomID] = map[string]json.RawMessage{}
|
||||||
}
|
}
|
||||||
|
rooms[roomID][dataType] = content
|
||||||
if len(roomID) > 0 {
|
|
||||||
rooms[roomID] = append(rooms[roomID], ac)
|
|
||||||
} else {
|
} else {
|
||||||
global = append(global, ac)
|
global[dataType] = content
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return global, rooms, rows.Err()
|
return global, rooms, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) selectAccountDataByType(
|
func (s *accountDataStatements) selectAccountDataByType(
|
||||||
ctx context.Context, localpart, roomID, dataType string,
|
ctx context.Context, localpart, roomID, dataType string,
|
||||||
) (data *gomatrixserverlib.ClientEvent, err error) {
|
) (data json.RawMessage, err error) {
|
||||||
|
var bytes []byte
|
||||||
stmt := s.selectAccountDataByTypeStmt
|
stmt := s.selectAccountDataByTypeStmt
|
||||||
var content []byte
|
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
|
||||||
|
|
||||||
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil {
|
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
data = json.RawMessage(bytes)
|
||||||
data = &gomatrixserverlib.ClientEvent{
|
|
||||||
Type: dataType,
|
|
||||||
Content: content,
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -17,6 +17,7 @@ package postgres
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
@ -169,7 +170,7 @@ func (d *Database) createAccount(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{
|
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
|
||||||
"global": {
|
"global": {
|
||||||
"content": [],
|
"content": [],
|
||||||
"override": [],
|
"override": [],
|
||||||
@ -177,7 +178,7 @@ func (d *Database) createAccount(
|
|||||||
"sender": [],
|
"sender": [],
|
||||||
"underride": []
|
"underride": []
|
||||||
}
|
}
|
||||||
}`); err != nil {
|
}`)); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
|
return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
|
||||||
@ -295,7 +296,7 @@ func (d *Database) newMembership(
|
|||||||
// update the corresponding row with the new content
|
// update the corresponding row with the new content
|
||||||
// Returns a SQL error if there was an issue with the insertion/update
|
// Returns a SQL error if there was an issue with the insertion/update
|
||||||
func (d *Database) SaveAccountData(
|
func (d *Database) SaveAccountData(
|
||||||
ctx context.Context, localpart, roomID, dataType, content string,
|
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
|
||||||
) error {
|
) error {
|
||||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
|
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
|
||||||
@ -306,8 +307,8 @@ func (d *Database) SaveAccountData(
|
|||||||
// If no account data could be found, returns an empty arrays
|
// If no account data could be found, returns an empty arrays
|
||||||
// Returns an error if there was an issue with the retrieval
|
// Returns an error if there was an issue with the retrieval
|
||||||
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
||||||
global []gomatrixserverlib.ClientEvent,
|
global map[string]json.RawMessage,
|
||||||
rooms map[string][]gomatrixserverlib.ClientEvent,
|
rooms map[string]map[string]json.RawMessage,
|
||||||
err error,
|
err error,
|
||||||
) {
|
) {
|
||||||
return d.accountDatas.selectAccountData(ctx, localpart)
|
return d.accountDatas.selectAccountData(ctx, localpart)
|
||||||
@ -319,7 +320,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
|||||||
// Returns an error if there was an issue with the retrieval
|
// Returns an error if there was an issue with the retrieval
|
||||||
func (d *Database) GetAccountDataByType(
|
func (d *Database) GetAccountDataByType(
|
||||||
ctx context.Context, localpart, roomID, dataType string,
|
ctx context.Context, localpart, roomID, dataType string,
|
||||||
) (data *gomatrixserverlib.ClientEvent, err error) {
|
) (data json.RawMessage, err error) {
|
||||||
return d.accountDatas.selectAccountDataByType(
|
return d.accountDatas.selectAccountDataByType(
|
||||||
ctx, localpart, roomID, dataType,
|
ctx, localpart, roomID, dataType,
|
||||||
)
|
)
|
||||||
|
@ -17,8 +17,7 @@ package sqlite3
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const accountDataSchema = `
|
const accountDataSchema = `
|
||||||
@ -72,7 +71,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) insertAccountData(
|
func (s *accountDataStatements) insertAccountData(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string,
|
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
|
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
|
||||||
return
|
return
|
||||||
@ -81,17 +80,17 @@ func (s *accountDataStatements) insertAccountData(
|
|||||||
func (s *accountDataStatements) selectAccountData(
|
func (s *accountDataStatements) selectAccountData(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) (
|
) (
|
||||||
global []gomatrixserverlib.ClientEvent,
|
/* global */ map[string]json.RawMessage,
|
||||||
rooms map[string][]gomatrixserverlib.ClientEvent,
|
/* rooms */ map[string]map[string]json.RawMessage,
|
||||||
err error,
|
error,
|
||||||
) {
|
) {
|
||||||
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
|
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
global = []gomatrixserverlib.ClientEvent{}
|
global := map[string]json.RawMessage{}
|
||||||
rooms = make(map[string][]gomatrixserverlib.ClientEvent)
|
rooms := map[string]map[string]json.RawMessage{}
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var roomID string
|
var roomID string
|
||||||
@ -99,42 +98,33 @@ func (s *accountDataStatements) selectAccountData(
|
|||||||
var content []byte
|
var content []byte
|
||||||
|
|
||||||
if err = rows.Scan(&roomID, &dataType, &content); err != nil {
|
if err = rows.Scan(&roomID, &dataType, &content); err != nil {
|
||||||
return
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ac := gomatrixserverlib.ClientEvent{
|
if roomID != "" {
|
||||||
Type: dataType,
|
if _, ok := rooms[roomID]; !ok {
|
||||||
Content: content,
|
rooms[roomID] = map[string]json.RawMessage{}
|
||||||
}
|
}
|
||||||
|
rooms[roomID][dataType] = content
|
||||||
if len(roomID) > 0 {
|
|
||||||
rooms[roomID] = append(rooms[roomID], ac)
|
|
||||||
} else {
|
} else {
|
||||||
global = append(global, ac)
|
global[dataType] = content
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return global, rooms, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) selectAccountDataByType(
|
func (s *accountDataStatements) selectAccountDataByType(
|
||||||
ctx context.Context, localpart, roomID, dataType string,
|
ctx context.Context, localpart, roomID, dataType string,
|
||||||
) (data *gomatrixserverlib.ClientEvent, err error) {
|
) (data json.RawMessage, err error) {
|
||||||
|
var bytes []byte
|
||||||
stmt := s.selectAccountDataByTypeStmt
|
stmt := s.selectAccountDataByTypeStmt
|
||||||
var content []byte
|
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
|
||||||
|
|
||||||
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil {
|
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
data = json.RawMessage(bytes)
|
||||||
data = &gomatrixserverlib.ClientEvent{
|
|
||||||
Type: dataType,
|
|
||||||
Content: content,
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -17,6 +17,7 @@ package sqlite3
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
@ -180,7 +181,7 @@ func (d *Database) createAccount(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{
|
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
|
||||||
"global": {
|
"global": {
|
||||||
"content": [],
|
"content": [],
|
||||||
"override": [],
|
"override": [],
|
||||||
@ -188,7 +189,7 @@ func (d *Database) createAccount(
|
|||||||
"sender": [],
|
"sender": [],
|
||||||
"underride": []
|
"underride": []
|
||||||
}
|
}
|
||||||
}`); err != nil {
|
}`)); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
|
return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
|
||||||
@ -306,7 +307,7 @@ func (d *Database) newMembership(
|
|||||||
// update the corresponding row with the new content
|
// update the corresponding row with the new content
|
||||||
// Returns a SQL error if there was an issue with the insertion/update
|
// Returns a SQL error if there was an issue with the insertion/update
|
||||||
func (d *Database) SaveAccountData(
|
func (d *Database) SaveAccountData(
|
||||||
ctx context.Context, localpart, roomID, dataType, content string,
|
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
|
||||||
) error {
|
) error {
|
||||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
|
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
|
||||||
@ -317,8 +318,8 @@ func (d *Database) SaveAccountData(
|
|||||||
// If no account data could be found, returns an empty arrays
|
// If no account data could be found, returns an empty arrays
|
||||||
// Returns an error if there was an issue with the retrieval
|
// Returns an error if there was an issue with the retrieval
|
||||||
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
||||||
global []gomatrixserverlib.ClientEvent,
|
global map[string]json.RawMessage,
|
||||||
rooms map[string][]gomatrixserverlib.ClientEvent,
|
rooms map[string]map[string]json.RawMessage,
|
||||||
err error,
|
err error,
|
||||||
) {
|
) {
|
||||||
return d.accountDatas.selectAccountData(ctx, localpart)
|
return d.accountDatas.selectAccountData(ctx, localpart)
|
||||||
@ -330,7 +331,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
|||||||
// Returns an error if there was an issue with the retrieval
|
// Returns an error if there was an issue with the retrieval
|
||||||
func (d *Database) GetAccountDataByType(
|
func (d *Database) GetAccountDataByType(
|
||||||
ctx context.Context, localpart, roomID, dataType string,
|
ctx context.Context, localpart, roomID, dataType string,
|
||||||
) (data *gomatrixserverlib.ClientEvent, err error) {
|
) (data json.RawMessage, err error) {
|
||||||
return d.accountDatas.selectAccountDataByType(
|
return d.accountDatas.selectAccountDataByType(
|
||||||
ctx, localpart, roomID, dataType,
|
ctx, localpart, roomID, dataType,
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user