mirror of
https://github.com/1f349/dendrite.git
synced 2025-01-21 23:06:32 +00:00
Virtual hosting schema and logic changes (#2876)
Note that virtual users cannot federate correctly yet.
This commit is contained in:
parent
e177e0ae73
commit
529df30b56
@ -32,6 +32,7 @@ import (
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// AddInternalRoutes registers HTTP handlers for internal API calls
|
||||
@ -74,7 +75,7 @@ func NewInternalAPI(
|
||||
// events to be sent out.
|
||||
for _, appservice := range base.Cfg.Derived.ApplicationServices {
|
||||
// Create bot account for this AS if it doesn't already exist
|
||||
if err := generateAppServiceAccount(userAPI, appservice); err != nil {
|
||||
if err := generateAppServiceAccount(userAPI, appservice, base.Cfg.Global.ServerName); err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"appservice": appservice.ID,
|
||||
}).WithError(err).Panicf("failed to generate bot account for appservice")
|
||||
@ -101,11 +102,13 @@ func NewInternalAPI(
|
||||
func generateAppServiceAccount(
|
||||
userAPI userapi.AppserviceUserAPI,
|
||||
as config.ApplicationService,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
var accRes userapi.PerformAccountCreationResponse
|
||||
err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{
|
||||
AccountType: userapi.AccountTypeAppService,
|
||||
Localpart: as.SenderLocalpart,
|
||||
ServerName: serverName,
|
||||
AppServiceID: as.ID,
|
||||
OnConflict: userapi.ConflictUpdate,
|
||||
}, &accRes)
|
||||
@ -115,6 +118,7 @@ func generateAppServiceAccount(
|
||||
var devRes userapi.PerformDeviceCreationResponse
|
||||
err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{
|
||||
Localpart: as.SenderLocalpart,
|
||||
ServerName: serverName,
|
||||
AccessToken: as.ASToken,
|
||||
DeviceID: &as.SenderLocalpart,
|
||||
DeviceDisplayName: &as.SenderLocalpart,
|
||||
|
@ -61,7 +61,7 @@ func (t *LoginTypePassword) LoginFromJSON(ctx context.Context, reqBytes []byte)
|
||||
|
||||
func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) {
|
||||
r := req.(*PasswordRequest)
|
||||
username := strings.ToLower(r.Username())
|
||||
username := r.Username()
|
||||
if username == "" {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusUnauthorized,
|
||||
@ -74,32 +74,43 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login,
|
||||
JSON: jsonerror.BadJSON("A password must be supplied."),
|
||||
}
|
||||
}
|
||||
localpart, _, err := userutil.ParseUsernameParam(username, t.Config.Matrix)
|
||||
localpart, domain, err := userutil.ParseUsernameParam(username, t.Config.Matrix)
|
||||
if err != nil {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusUnauthorized,
|
||||
JSON: jsonerror.InvalidUsername(err.Error()),
|
||||
}
|
||||
}
|
||||
if !t.Config.Matrix.IsLocalServerName(domain) {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusUnauthorized,
|
||||
JSON: jsonerror.InvalidUsername("The server name is not known."),
|
||||
}
|
||||
}
|
||||
// Squash username to all lowercase letters
|
||||
res := &api.QueryAccountByPasswordResponse{}
|
||||
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{Localpart: strings.ToLower(localpart), PlaintextPassword: r.Password}, res)
|
||||
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{
|
||||
Localpart: strings.ToLower(localpart),
|
||||
ServerName: domain,
|
||||
PlaintextPassword: r.Password,
|
||||
}, res)
|
||||
if err != nil {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: jsonerror.Unknown("unable to fetch account by password"),
|
||||
JSON: jsonerror.Unknown("Unable to fetch account by password."),
|
||||
}
|
||||
}
|
||||
|
||||
if !res.Exists {
|
||||
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{
|
||||
Localpart: localpart,
|
||||
ServerName: domain,
|
||||
PlaintextPassword: r.Password,
|
||||
}, res)
|
||||
if err != nil {
|
||||
return nil, &util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: jsonerror.Unknown("unable to fetch account by password"),
|
||||
JSON: jsonerror.Unknown("Unable to fetch account by password."),
|
||||
}
|
||||
}
|
||||
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows
|
||||
|
@ -102,6 +102,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
serverName := cfg.Matrix.ServerName
|
||||
localpart, ok := vars["localpart"]
|
||||
if !ok {
|
||||
return util.JSONResponse{
|
||||
@ -109,6 +110,9 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
|
||||
JSON: jsonerror.MissingArgument("Expecting user localpart."),
|
||||
}
|
||||
}
|
||||
if l, s, err := gomatrixserverlib.SplitID('@', localpart); err == nil {
|
||||
localpart, serverName = l, s
|
||||
}
|
||||
request := struct {
|
||||
Password string `json:"password"`
|
||||
}{}
|
||||
@ -126,6 +130,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
|
||||
}
|
||||
updateReq := &userapi.PerformPasswordUpdateRequest{
|
||||
Localpart: localpart,
|
||||
ServerName: serverName,
|
||||
Password: request.Password,
|
||||
LogoutDevices: true,
|
||||
}
|
||||
|
@ -100,6 +100,7 @@ func completeAuth(
|
||||
DeviceID: login.DeviceID,
|
||||
AccessToken: token,
|
||||
Localpart: localpart,
|
||||
ServerName: serverName,
|
||||
IPAddr: ipAddr,
|
||||
UserAgent: userAgent,
|
||||
}, &performRes)
|
||||
|
@ -40,16 +40,17 @@ func GetNotifications(
|
||||
}
|
||||
|
||||
var queryRes userapi.QueryNotificationsResponse
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{
|
||||
Localpart: localpart,
|
||||
From: req.URL.Query().Get("from"),
|
||||
Limit: int(limit),
|
||||
Only: req.URL.Query().Get("only"),
|
||||
Localpart: localpart,
|
||||
ServerName: domain,
|
||||
From: req.URL.Query().Get("from"),
|
||||
Limit: int(limit),
|
||||
Only: req.URL.Query().Get("only"),
|
||||
}, &queryRes)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("QueryNotifications failed")
|
||||
|
@ -86,7 +86,7 @@ func Password(
|
||||
}
|
||||
|
||||
// Get the local part.
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
@ -94,8 +94,9 @@ func Password(
|
||||
|
||||
// Ask the user API to perform the password change.
|
||||
passwordReq := &api.PerformPasswordUpdateRequest{
|
||||
Localpart: localpart,
|
||||
Password: r.NewPassword,
|
||||
Localpart: localpart,
|
||||
ServerName: domain,
|
||||
Password: r.NewPassword,
|
||||
}
|
||||
passwordRes := &api.PerformPasswordUpdateResponse{}
|
||||
if err := userAPI.PerformPasswordUpdate(req.Context(), passwordReq, passwordRes); err != nil {
|
||||
@ -122,8 +123,9 @@ func Password(
|
||||
}
|
||||
|
||||
pushersReq := &api.PerformPusherDeletionRequest{
|
||||
Localpart: localpart,
|
||||
SessionID: device.SessionID,
|
||||
Localpart: localpart,
|
||||
ServerName: domain,
|
||||
SessionID: device.SessionID,
|
||||
}
|
||||
if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed")
|
||||
|
@ -31,13 +31,14 @@ func GetPushers(
|
||||
userAPI userapi.ClientUserAPI,
|
||||
) util.JSONResponse {
|
||||
var queryRes userapi.QueryPushersResponse
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{
|
||||
Localpart: localpart,
|
||||
Localpart: localpart,
|
||||
ServerName: domain,
|
||||
}, &queryRes)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed")
|
||||
@ -59,7 +60,7 @@ func SetPusher(
|
||||
req *http.Request, device *userapi.Device,
|
||||
userAPI userapi.ClientUserAPI,
|
||||
) util.JSONResponse {
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
@ -93,6 +94,7 @@ func SetPusher(
|
||||
|
||||
}
|
||||
body.Localpart = localpart
|
||||
body.ServerName = domain
|
||||
body.SessionID = device.SessionID
|
||||
err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{})
|
||||
if err != nil {
|
||||
|
@ -588,12 +588,15 @@ func Register(
|
||||
}
|
||||
// Auto generate a numeric username if r.Username is empty
|
||||
if r.Username == "" {
|
||||
res := &userapi.QueryNumericLocalpartResponse{}
|
||||
if err := userAPI.QueryNumericLocalpart(req.Context(), res); err != nil {
|
||||
nreq := &userapi.QueryNumericLocalpartRequest{
|
||||
ServerName: cfg.Matrix.ServerName, // TODO: might not be right
|
||||
}
|
||||
nres := &userapi.QueryNumericLocalpartResponse{}
|
||||
if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
r.Username = strconv.FormatInt(res.ID, 10)
|
||||
r.Username = strconv.FormatInt(nres.ID, 10)
|
||||
}
|
||||
|
||||
// Is this an appservice registration? It will be if the access
|
||||
@ -676,6 +679,7 @@ func handleGuestRegistration(
|
||||
var devRes userapi.PerformDeviceCreationResponse
|
||||
err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{
|
||||
Localpart: res.Account.Localpart,
|
||||
ServerName: res.Account.ServerName,
|
||||
DeviceDisplayName: r.InitialDisplayName,
|
||||
AccessToken: token,
|
||||
IPAddr: req.RemoteAddr,
|
||||
|
@ -157,7 +157,7 @@ func Setup(
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
dendriteAdminRouter.Handle("/admin/resetPassword/{localpart}",
|
||||
dendriteAdminRouter.Handle("/admin/resetPassword/{userID}",
|
||||
httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
return AdminResetPassword(req, cfg, device, userAPI)
|
||||
}),
|
||||
|
@ -286,6 +286,7 @@ func getSenderDevice(
|
||||
err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{
|
||||
AccountType: userapi.AccountTypeUser,
|
||||
Localpart: cfg.Matrix.ServerNotices.LocalPart,
|
||||
ServerName: cfg.Matrix.ServerName,
|
||||
OnConflict: userapi.ConflictUpdate,
|
||||
}, &accRes)
|
||||
if err != nil {
|
||||
@ -295,8 +296,9 @@ func getSenderDevice(
|
||||
// Set the avatarurl for the user
|
||||
avatarRes := &userapi.PerformSetAvatarURLResponse{}
|
||||
if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{
|
||||
Localpart: cfg.Matrix.ServerNotices.LocalPart,
|
||||
AvatarURL: cfg.Matrix.ServerNotices.AvatarURL,
|
||||
Localpart: cfg.Matrix.ServerNotices.LocalPart,
|
||||
ServerName: cfg.Matrix.ServerName,
|
||||
AvatarURL: cfg.Matrix.ServerNotices.AvatarURL,
|
||||
}, avatarRes); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed")
|
||||
return nil, err
|
||||
@ -308,6 +310,7 @@ func getSenderDevice(
|
||||
displayNameRes := &userapi.PerformUpdateDisplayNameResponse{}
|
||||
if err = userAPI.SetDisplayName(ctx, &userapi.PerformUpdateDisplayNameRequest{
|
||||
Localpart: cfg.Matrix.ServerNotices.LocalPart,
|
||||
ServerName: cfg.Matrix.ServerName,
|
||||
DisplayName: cfg.Matrix.ServerNotices.DisplayName,
|
||||
}, displayNameRes); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("userAPI.SetDisplayName failed")
|
||||
@ -353,6 +356,7 @@ func getSenderDevice(
|
||||
var devRes userapi.PerformDeviceCreationResponse
|
||||
err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{
|
||||
Localpart: cfg.Matrix.ServerNotices.LocalPart,
|
||||
ServerName: cfg.Matrix.ServerName,
|
||||
DeviceDisplayName: &cfg.Matrix.ServerNotices.LocalPart,
|
||||
AccessToken: token,
|
||||
NoDeviceListUpdate: true,
|
||||
|
@ -136,16 +136,17 @@ func CheckAndSave3PIDAssociation(
|
||||
}
|
||||
|
||||
// Save the association in the database
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
||||
if err = threePIDAPI.PerformSaveThreePIDAssociation(req.Context(), &api.PerformSaveThreePIDAssociationRequest{
|
||||
ThreePID: address,
|
||||
Localpart: localpart,
|
||||
Medium: medium,
|
||||
ThreePID: address,
|
||||
Localpart: localpart,
|
||||
ServerName: domain,
|
||||
Medium: medium,
|
||||
}, &struct{}{}); err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("threePIDAPI.PerformSaveThreePIDAssociation failed")
|
||||
return jsonerror.InternalServerError()
|
||||
@ -161,7 +162,7 @@ func CheckAndSave3PIDAssociation(
|
||||
func GetAssociated3PIDs(
|
||||
req *http.Request, threepidAPI api.ClientUserAPI, device *api.Device,
|
||||
) util.JSONResponse {
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
@ -169,7 +170,8 @@ func GetAssociated3PIDs(
|
||||
|
||||
res := &api.QueryThreePIDsForLocalpartResponse{}
|
||||
err = threepidAPI.QueryThreePIDsForLocalpart(req.Context(), &api.QueryThreePIDsForLocalpartRequest{
|
||||
Localpart: localpart,
|
||||
Localpart: localpart,
|
||||
ServerName: domain,
|
||||
}, res)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("threepidAPI.QueryThreePIDsForLocalpart failed")
|
||||
|
@ -120,15 +120,23 @@ func NewInternalAPI(
|
||||
|
||||
js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
|
||||
|
||||
signingInfo := map[gomatrixserverlib.ServerName]*queue.SigningInfo{}
|
||||
for _, serverName := range append(
|
||||
[]gomatrixserverlib.ServerName{base.Cfg.Global.ServerName},
|
||||
base.Cfg.Global.SecondaryServerNames...,
|
||||
) {
|
||||
signingInfo[serverName] = &queue.SigningInfo{
|
||||
KeyID: cfg.Matrix.KeyID,
|
||||
PrivateKey: cfg.Matrix.PrivateKey,
|
||||
ServerName: serverName,
|
||||
}
|
||||
}
|
||||
|
||||
queues := queue.NewOutgoingQueues(
|
||||
federationDB, base.ProcessContext,
|
||||
cfg.Matrix.DisableFederation,
|
||||
cfg.Matrix.ServerName, federation, rsAPI, &stats,
|
||||
&queue.SigningInfo{
|
||||
KeyID: cfg.Matrix.KeyID,
|
||||
PrivateKey: cfg.Matrix.PrivateKey,
|
||||
ServerName: cfg.Matrix.ServerName,
|
||||
},
|
||||
signingInfo,
|
||||
)
|
||||
|
||||
rsConsumer := consumers.NewOutputRoomEventConsumer(
|
||||
|
@ -137,7 +137,7 @@ func (m *MockRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err
|
||||
}
|
||||
|
||||
// Get the keys and JSON-ify them.
|
||||
keys := routing.LocalKeys(s.config)
|
||||
keys := routing.LocalKeys(s.config, gomatrixserverlib.ServerName(req.Host))
|
||||
body, err := json.MarshalIndent(keys.JSON, "", " ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -50,7 +50,7 @@ type destinationQueue struct {
|
||||
queues *OutgoingQueues
|
||||
db storage.Database
|
||||
process *process.ProcessContext
|
||||
signing *SigningInfo
|
||||
signing map[gomatrixserverlib.ServerName]*SigningInfo
|
||||
rsAPI api.FederationRoomserverAPI
|
||||
client fedapi.FederationClient // federation client
|
||||
origin gomatrixserverlib.ServerName // origin of requests
|
||||
|
@ -46,7 +46,7 @@ type OutgoingQueues struct {
|
||||
origin gomatrixserverlib.ServerName
|
||||
client fedapi.FederationClient
|
||||
statistics *statistics.Statistics
|
||||
signing *SigningInfo
|
||||
signing map[gomatrixserverlib.ServerName]*SigningInfo
|
||||
queuesMutex sync.Mutex // protects the below
|
||||
queues map[gomatrixserverlib.ServerName]*destinationQueue
|
||||
}
|
||||
@ -91,7 +91,7 @@ func NewOutgoingQueues(
|
||||
client fedapi.FederationClient,
|
||||
rsAPI api.FederationRoomserverAPI,
|
||||
statistics *statistics.Statistics,
|
||||
signing *SigningInfo,
|
||||
signing map[gomatrixserverlib.ServerName]*SigningInfo,
|
||||
) *OutgoingQueues {
|
||||
queues := &OutgoingQueues{
|
||||
disabled: disabled,
|
||||
@ -199,11 +199,10 @@ func (oqs *OutgoingQueues) SendEvent(
|
||||
log.Trace("Federation is disabled, not sending event")
|
||||
return nil
|
||||
}
|
||||
if origin != oqs.origin {
|
||||
// TODO: Support virtual hosting; gh issue #577.
|
||||
if _, ok := oqs.signing[origin]; !ok {
|
||||
return fmt.Errorf(
|
||||
"sendevent: unexpected server to send as: got %q expected %q",
|
||||
origin, oqs.origin,
|
||||
"sendevent: unexpected server to send as %q",
|
||||
origin,
|
||||
)
|
||||
}
|
||||
|
||||
@ -214,7 +213,9 @@ func (oqs *OutgoingQueues) SendEvent(
|
||||
destmap[d] = struct{}{}
|
||||
}
|
||||
delete(destmap, oqs.origin)
|
||||
delete(destmap, oqs.signing.ServerName)
|
||||
for local := range oqs.signing {
|
||||
delete(destmap, local)
|
||||
}
|
||||
|
||||
// Check if any of the destinations are prohibited by server ACLs.
|
||||
for destination := range destmap {
|
||||
@ -288,11 +289,10 @@ func (oqs *OutgoingQueues) SendEDU(
|
||||
log.Trace("Federation is disabled, not sending EDU")
|
||||
return nil
|
||||
}
|
||||
if origin != oqs.origin {
|
||||
// TODO: Support virtual hosting; gh issue #577.
|
||||
if _, ok := oqs.signing[origin]; !ok {
|
||||
return fmt.Errorf(
|
||||
"sendevent: unexpected server to send as: got %q expected %q",
|
||||
origin, oqs.origin,
|
||||
"sendevent: unexpected server to send as %q",
|
||||
origin,
|
||||
)
|
||||
}
|
||||
|
||||
@ -303,7 +303,9 @@ func (oqs *OutgoingQueues) SendEDU(
|
||||
destmap[d] = struct{}{}
|
||||
}
|
||||
delete(destmap, oqs.origin)
|
||||
delete(destmap, oqs.signing.ServerName)
|
||||
for local := range oqs.signing {
|
||||
delete(destmap, local)
|
||||
}
|
||||
|
||||
// There is absolutely no guarantee that the EDU will have a room_id
|
||||
// field, as it is not required by the spec. However, if it *does*
|
||||
|
@ -350,10 +350,12 @@ func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T
|
||||
}
|
||||
rs := &stubFederationRoomServerAPI{}
|
||||
stats := statistics.NewStatistics(db, failuresUntilBlacklist)
|
||||
signingInfo := &SigningInfo{
|
||||
KeyID: "ed21019:auto",
|
||||
PrivateKey: test.PrivateKeyA,
|
||||
ServerName: "localhost",
|
||||
signingInfo := map[gomatrixserverlib.ServerName]*SigningInfo{
|
||||
"localhost": {
|
||||
KeyID: "ed21019:auto",
|
||||
PrivateKey: test.PrivateKeyA,
|
||||
ServerName: "localhost",
|
||||
},
|
||||
}
|
||||
queues := NewOutgoingQueues(db, processContext, false, "localhost", fc, rs, &stats, signingInfo)
|
||||
|
||||
|
@ -16,6 +16,7 @@ package routing
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@ -134,18 +135,21 @@ func ClaimOneTimeKeys(
|
||||
|
||||
// LocalKeys returns the local keys for the server.
|
||||
// See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys
|
||||
func LocalKeys(cfg *config.FederationAPI) util.JSONResponse {
|
||||
keys, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod))
|
||||
func LocalKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) util.JSONResponse {
|
||||
keys, err := localKeys(cfg, serverName, time.Now().Add(cfg.Matrix.KeyValidityPeriod))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: keys}
|
||||
}
|
||||
|
||||
func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) {
|
||||
func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) {
|
||||
var keys gomatrixserverlib.ServerKeys
|
||||
if !cfg.Matrix.IsLocalServerName(serverName) {
|
||||
return nil, fmt.Errorf("server name not known")
|
||||
}
|
||||
|
||||
keys.ServerName = cfg.Matrix.ServerName
|
||||
keys.ServerName = serverName
|
||||
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(validUntil)
|
||||
|
||||
publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey)
|
||||
@ -172,7 +176,7 @@ func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserver
|
||||
}
|
||||
|
||||
keys.Raw, err = gomatrixserverlib.SignJSON(
|
||||
string(cfg.Matrix.ServerName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, toSign,
|
||||
string(serverName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, toSign,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -186,6 +190,14 @@ func NotaryKeys(
|
||||
fsAPI federationAPI.FederationInternalAPI,
|
||||
req *gomatrixserverlib.PublicKeyNotaryLookupRequest,
|
||||
) util.JSONResponse {
|
||||
serverName := gomatrixserverlib.ServerName(httpReq.Host) // TODO: this is not ideal
|
||||
if !cfg.Matrix.IsLocalServerName(serverName) {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusNotFound,
|
||||
JSON: jsonerror.NotFound("Server name not known"),
|
||||
}
|
||||
}
|
||||
|
||||
if req == nil {
|
||||
req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{}
|
||||
if reqErr := clienthttputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil {
|
||||
@ -201,7 +213,7 @@ func NotaryKeys(
|
||||
for serverName, kidToCriteria := range req.ServerKeys {
|
||||
var keyList []gomatrixserverlib.ServerKeys
|
||||
if serverName == cfg.Matrix.ServerName {
|
||||
if k, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod)); err == nil {
|
||||
if k, err := localKeys(cfg, serverName, time.Now().Add(cfg.Matrix.KeyValidityPeriod)); err == nil {
|
||||
keyList = append(keyList, *k)
|
||||
} else {
|
||||
return util.ErrorResponse(err)
|
||||
|
@ -74,7 +74,7 @@ func Setup(
|
||||
}
|
||||
|
||||
localKeys := httputil.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse {
|
||||
return LocalKeys(cfg)
|
||||
return LocalKeys(cfg, gomatrixserverlib.ServerName(req.Host))
|
||||
})
|
||||
|
||||
notaryKeys := httputil.MakeExternalAPI("notarykeys", func(req *http.Request) util.JSONResponse {
|
||||
|
@ -33,16 +33,17 @@ import (
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/producers"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
)
|
||||
|
||||
type KeyInternalAPI struct {
|
||||
DB storage.Database
|
||||
ThisServer gomatrixserverlib.ServerName
|
||||
FedClient fedsenderapi.KeyserverFederationAPI
|
||||
UserAPI userapi.KeyserverUserAPI
|
||||
Producer *producers.KeyChange
|
||||
Updater *DeviceListUpdater
|
||||
DB storage.Database
|
||||
Cfg *config.KeyServer
|
||||
FedClient fedsenderapi.KeyserverFederationAPI
|
||||
UserAPI userapi.KeyserverUserAPI
|
||||
Producer *producers.KeyChange
|
||||
Updater *DeviceListUpdater
|
||||
}
|
||||
|
||||
func (a *KeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) {
|
||||
@ -95,8 +96,11 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
|
||||
nested[userID] = val
|
||||
domainToDeviceKeys[string(serverName)] = nested
|
||||
}
|
||||
// claim local keys
|
||||
if local, ok := domainToDeviceKeys[string(a.ThisServer)]; ok {
|
||||
for domain, local := range domainToDeviceKeys {
|
||||
if !a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
|
||||
continue
|
||||
}
|
||||
// claim local keys
|
||||
keys, err := a.DB.ClaimKeys(ctx, local)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
@ -117,7 +121,7 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
|
||||
res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON
|
||||
}
|
||||
}
|
||||
delete(domainToDeviceKeys, string(a.ThisServer))
|
||||
delete(domainToDeviceKeys, domain)
|
||||
}
|
||||
if len(domainToDeviceKeys) > 0 {
|
||||
a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
|
||||
@ -258,7 +262,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
|
||||
}
|
||||
domain := string(serverName)
|
||||
// query local devices
|
||||
if serverName == a.ThisServer {
|
||||
if a.Cfg.Matrix.IsLocalServerName(serverName) {
|
||||
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
@ -437,13 +441,13 @@ func (a *KeyInternalAPI) queryRemoteKeys(
|
||||
|
||||
domains := map[string]struct{}{}
|
||||
for domain := range domainToDeviceKeys {
|
||||
if domain == string(a.ThisServer) {
|
||||
if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
|
||||
continue
|
||||
}
|
||||
domains[domain] = struct{}{}
|
||||
}
|
||||
for domain := range domainToCrossSigningKeys {
|
||||
if domain == string(a.ThisServer) {
|
||||
if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
|
||||
continue
|
||||
}
|
||||
domains[domain] = struct{}{}
|
||||
@ -689,7 +693,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
|
||||
if err != nil {
|
||||
continue // ignore invalid users
|
||||
}
|
||||
if serverName != a.ThisServer {
|
||||
if !a.Cfg.Matrix.IsLocalServerName(serverName) {
|
||||
continue // ignore remote users
|
||||
}
|
||||
if len(key.KeyJSON) == 0 {
|
||||
|
@ -53,10 +53,10 @@ func NewInternalAPI(
|
||||
DB: db,
|
||||
}
|
||||
ap := &internal.KeyInternalAPI{
|
||||
DB: db,
|
||||
ThisServer: cfg.Matrix.ServerName,
|
||||
FedClient: fedClient,
|
||||
Producer: keyChangeProducer,
|
||||
DB: db,
|
||||
Cfg: cfg,
|
||||
FedClient: fedClient,
|
||||
Producer: keyChangeProducer,
|
||||
}
|
||||
updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable
|
||||
ap.Updater = updater
|
||||
|
@ -78,7 +78,7 @@ type ClientUserAPI interface {
|
||||
QueryAcccessTokenAPI
|
||||
LoginTokenInternalAPI
|
||||
UserLoginAPI
|
||||
QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error
|
||||
QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error
|
||||
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
||||
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
|
||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||
@ -335,9 +335,10 @@ type PerformAccountCreationResponse struct {
|
||||
|
||||
// PerformAccountCreationRequest is the request for PerformAccountCreation
|
||||
type PerformPasswordUpdateRequest struct {
|
||||
Localpart string // Required: The localpart for this account.
|
||||
Password string // Required: The new password to set.
|
||||
LogoutDevices bool // Optional: Whether to log out all user devices.
|
||||
Localpart string // Required: The localpart for this account.
|
||||
ServerName gomatrixserverlib.ServerName // Required: The domain for this account.
|
||||
Password string // Required: The new password to set.
|
||||
LogoutDevices bool // Optional: Whether to log out all user devices.
|
||||
}
|
||||
|
||||
// PerformAccountCreationResponse is the response for PerformAccountCreation
|
||||
@ -518,7 +519,8 @@ const (
|
||||
)
|
||||
|
||||
type QueryPushersRequest struct {
|
||||
Localpart string
|
||||
Localpart string
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
type QueryPushersResponse struct {
|
||||
@ -526,14 +528,16 @@ type QueryPushersResponse struct {
|
||||
}
|
||||
|
||||
type PerformPusherSetRequest struct {
|
||||
Pusher // Anonymous field because that's how clientapi unmarshals it.
|
||||
Localpart string
|
||||
Append bool `json:"append"`
|
||||
Pusher // Anonymous field because that's how clientapi unmarshals it.
|
||||
Localpart string
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
Append bool `json:"append"`
|
||||
}
|
||||
|
||||
type PerformPusherDeletionRequest struct {
|
||||
Localpart string
|
||||
SessionID int64
|
||||
Localpart string
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
SessionID int64
|
||||
}
|
||||
|
||||
// Pusher represents a push notification subscriber
|
||||
@ -571,10 +575,11 @@ type QueryPushRulesResponse struct {
|
||||
}
|
||||
|
||||
type QueryNotificationsRequest struct {
|
||||
Localpart string `json:"localpart"` // Required.
|
||||
From string `json:"from,omitempty"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
Only string `json:"only,omitempty"`
|
||||
Localpart string `json:"localpart"` // Required.
|
||||
ServerName gomatrixserverlib.ServerName `json:"server_name"` // Required.
|
||||
From string `json:"from,omitempty"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
Only string `json:"only,omitempty"`
|
||||
}
|
||||
|
||||
type QueryNotificationsResponse struct {
|
||||
@ -601,12 +606,17 @@ type PerformSetAvatarURLResponse struct {
|
||||
Changed bool `json:"changed"`
|
||||
}
|
||||
|
||||
type QueryNumericLocalpartRequest struct {
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
type QueryNumericLocalpartResponse struct {
|
||||
ID int64
|
||||
}
|
||||
|
||||
type QueryAccountAvailabilityRequest struct {
|
||||
Localpart string
|
||||
Localpart string
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
type QueryAccountAvailabilityResponse struct {
|
||||
@ -614,7 +624,9 @@ type QueryAccountAvailabilityResponse struct {
|
||||
}
|
||||
|
||||
type QueryAccountByPasswordRequest struct {
|
||||
Localpart, PlaintextPassword string
|
||||
Localpart string
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
PlaintextPassword string
|
||||
}
|
||||
|
||||
type QueryAccountByPasswordResponse struct {
|
||||
@ -638,11 +650,13 @@ type QueryLocalpartForThreePIDRequest struct {
|
||||
}
|
||||
|
||||
type QueryLocalpartForThreePIDResponse struct {
|
||||
Localpart string
|
||||
Localpart string
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
type QueryThreePIDsForLocalpartRequest struct {
|
||||
Localpart string
|
||||
Localpart string
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
type QueryThreePIDsForLocalpartResponse struct {
|
||||
@ -652,5 +666,8 @@ type QueryThreePIDsForLocalpartResponse struct {
|
||||
type PerformForgetThreePIDRequest QueryLocalpartForThreePIDRequest
|
||||
|
||||
type PerformSaveThreePIDAssociationRequest struct {
|
||||
ThreePID, Localpart, Medium string
|
||||
ThreePID string
|
||||
Localpart string
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
Medium string
|
||||
}
|
||||
|
@ -156,8 +156,8 @@ func (t *UserInternalAPITrace) SetAvatarURL(ctx context.Context, req *PerformSet
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error {
|
||||
err := t.Impl.QueryNumericLocalpart(ctx, res)
|
||||
func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error {
|
||||
err := t.Impl.QueryNumericLocalpart(ctx, req, res)
|
||||
util.GetLogger(ctx).Infof("QueryNumericLocalpart req= res=%+v", js(res))
|
||||
return err
|
||||
}
|
||||
|
@ -104,7 +104,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats
|
||||
return false
|
||||
}
|
||||
|
||||
updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true)
|
||||
updated, err := s.db.SetNotificationsRead(ctx, localpart, domain, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("userapi EDU consumer")
|
||||
return false
|
||||
@ -118,7 +118,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats
|
||||
if !updated {
|
||||
return true
|
||||
}
|
||||
if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
|
||||
if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, domain, s.db); err != nil {
|
||||
log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed")
|
||||
return false
|
||||
}
|
||||
|
@ -192,25 +192,25 @@ func (s *OutputRoomEventConsumer) storeMessageStats(ctx context.Context, eventTy
|
||||
func (s *OutputRoomEventConsumer) handleRoomUpgrade(ctx context.Context, oldRoomID, newRoomID string, localMembers []*localMembership, roomSize int) error {
|
||||
for _, membership := range localMembers {
|
||||
// Copy any existing push rules from old -> new room
|
||||
if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil {
|
||||
if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// preserve m.direct room state
|
||||
if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, roomSize); err != nil {
|
||||
if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain, roomSize); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// copy existing m.tag entries, if any
|
||||
if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil {
|
||||
if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string) error {
|
||||
pushRules, err := s.db.QueryPushRules(ctx, localpart)
|
||||
func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string, serverName gomatrixserverlib.ServerName) error {
|
||||
pushRules, err := s.db.QueryPushRules(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query pushrules for user: %w", err)
|
||||
}
|
||||
@ -229,7 +229,7 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID,
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = s.db.SaveAccountData(ctx, localpart, "", "m.push_rules", rules); err != nil {
|
||||
if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.push_rules", rules); err != nil {
|
||||
return fmt.Errorf("failed to update pushrules: %w", err)
|
||||
}
|
||||
}
|
||||
@ -237,13 +237,13 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID,
|
||||
}
|
||||
|
||||
// updateMDirect copies the "is_direct" flag from oldRoomID to newROomID
|
||||
func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, roomSize int) error {
|
||||
func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName, roomSize int) error {
|
||||
// this is most likely not a DM, so skip updating m.direct state
|
||||
if roomSize > 2 {
|
||||
return nil
|
||||
}
|
||||
// Get direct message state
|
||||
directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, "", "m.direct")
|
||||
directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, serverName, "", "m.direct")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get m.direct from database: %w", err)
|
||||
}
|
||||
@ -267,7 +267,7 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID,
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
if err = s.db.SaveAccountData(ctx, localpart, "", "m.direct", data); err != nil {
|
||||
if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.direct", data); err != nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@ -279,15 +279,15 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string) error {
|
||||
tag, err := s.db.GetAccountDataByType(ctx, localpart, oldRoomID, "m.tag")
|
||||
func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName) error {
|
||||
tag, err := s.db.GetAccountDataByType(ctx, localpart, serverName, oldRoomID, "m.tag")
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return err
|
||||
}
|
||||
if tag == nil {
|
||||
return nil
|
||||
}
|
||||
return s.db.SaveAccountData(ctx, localpart, newRoomID, "m.tag", tag)
|
||||
return s.db.SaveAccountData(ctx, localpart, serverName, newRoomID, "m.tag", tag)
|
||||
}
|
||||
|
||||
func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error {
|
||||
@ -492,11 +492,11 @@ func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, er
|
||||
func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string, streamPos uint64) error {
|
||||
actions, err := s.evaluatePushRules(ctx, event, mem, roomSize)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("s.evaluatePushRules: %w", err)
|
||||
}
|
||||
a, tweaks, err := pushrules.ActionsToTweaks(actions)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("pushrules.ActionsToTweaks: %w", err)
|
||||
}
|
||||
// TODO: support coalescing.
|
||||
if a != pushrules.NotifyAction && a != pushrules.CoalesceAction {
|
||||
@ -508,9 +508,9 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
|
||||
return nil
|
||||
}
|
||||
|
||||
devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, tweaks)
|
||||
devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, mem.Domain, tweaks)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("s.localPushDevices: %w", err)
|
||||
}
|
||||
|
||||
n := &api.Notification{
|
||||
@ -527,18 +527,18 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
|
||||
RoomID: event.RoomID(),
|
||||
TS: gomatrixserverlib.AsTimestamp(time.Now()),
|
||||
}
|
||||
if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), streamPos, tweaks, n); err != nil {
|
||||
return err
|
||||
if err = s.db.InsertNotification(ctx, mem.Localpart, mem.Domain, event.EventID(), streamPos, tweaks, n); err != nil {
|
||||
return fmt.Errorf("s.db.InsertNotification: %w", err)
|
||||
}
|
||||
|
||||
if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("s.syncProducer.GetAndSendNotificationData: %w", err)
|
||||
}
|
||||
|
||||
// We do this after InsertNotification. Thus, this should always return >=1.
|
||||
userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, tables.AllNotifications)
|
||||
userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, mem.Domain, tables.AllNotifications)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("s.db.GetNotificationCount: %w", err)
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
@ -589,7 +589,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
|
||||
}
|
||||
|
||||
if len(rejected) > 0 {
|
||||
s.deleteRejectedPushers(ctx, rejected, mem.Localpart)
|
||||
s.deleteRejectedPushers(ctx, rejected, mem.Localpart, mem.Domain)
|
||||
}
|
||||
}()
|
||||
|
||||
@ -606,7 +606,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
|
||||
}
|
||||
|
||||
// Get accountdata to check if the event.Sender() is ignored by mem.LocalPart
|
||||
data, err := s.db.GetAccountDataByType(ctx, mem.Localpart, "", "m.ignored_user_list")
|
||||
data, err := s.db.GetAccountDataByType(ctx, mem.Localpart, mem.Domain, "", "m.ignored_user_list")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -621,7 +621,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
|
||||
return nil, fmt.Errorf("user %s is ignored", sender)
|
||||
}
|
||||
}
|
||||
ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart)
|
||||
ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart, mem.Domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -693,10 +693,10 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err
|
||||
|
||||
// localPushDevices pushes to the configured devices of a local
|
||||
// user. The map keys are [url][format].
|
||||
func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
|
||||
pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db)
|
||||
func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
|
||||
pusherDevices, err := util.GetPushDevices(ctx, localpart, serverName, tweaks, s.db)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
return nil, "", fmt.Errorf("util.GetPushDevices: %w", err)
|
||||
}
|
||||
|
||||
var profileTag string
|
||||
@ -791,7 +791,7 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *gomatri
|
||||
}
|
||||
|
||||
// deleteRejectedPushers deletes the pushers associated with the given devices.
|
||||
func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) {
|
||||
func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string, serverName gomatrixserverlib.ServerName) {
|
||||
log.WithFields(log.Fields{
|
||||
"localpart": localpart,
|
||||
"app_id0": devices[0].AppID,
|
||||
@ -799,7 +799,7 @@ func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, dev
|
||||
}).Warnf("Deleting pushers rejected by the HTTP push gateway")
|
||||
|
||||
for _, d := range devices {
|
||||
if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart); err != nil {
|
||||
if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart, serverName); err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"localpart": localpart,
|
||||
}).WithError(err).Errorf("Unable to delete rejected pusher")
|
||||
|
@ -68,7 +68,7 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
|
||||
if req.DataType == "" {
|
||||
return fmt.Errorf("data type must not be empty")
|
||||
}
|
||||
if err := a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData); err != nil {
|
||||
if err := a.DB.SaveAccountData(ctx, local, domain, req.RoomID, req.DataType, req.AccountData); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("a.DB.SaveAccountData failed")
|
||||
return fmt.Errorf("failed to save account data: %w", err)
|
||||
}
|
||||
@ -108,7 +108,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun
|
||||
return nil
|
||||
}
|
||||
|
||||
deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now())))
|
||||
deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, domain, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now())))
|
||||
if err != nil {
|
||||
logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed")
|
||||
return err
|
||||
@ -124,7 +124,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun
|
||||
return nil
|
||||
}
|
||||
|
||||
if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, a.DB); err != nil {
|
||||
if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, domain, a.DB); err != nil {
|
||||
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed")
|
||||
return err
|
||||
}
|
||||
@ -175,8 +175,10 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
||||
if serverName == "" {
|
||||
serverName = a.Config.Matrix.ServerName
|
||||
}
|
||||
// XXXX: Use the server name here
|
||||
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
|
||||
if !a.Config.Matrix.IsLocalServerName(serverName) {
|
||||
return fmt.Errorf("server name %s is not local", serverName)
|
||||
}
|
||||
acc, err := a.DB.CreateAccount(ctx, req.Localpart, serverName, req.Password, req.AppServiceID, req.AccountType)
|
||||
if err != nil {
|
||||
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
|
||||
switch req.OnConflict {
|
||||
@ -215,8 +217,8 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
|
||||
return err
|
||||
if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, serverName, req.Localpart); err != nil {
|
||||
return fmt.Errorf("a.DB.SetDisplayName: %w", err)
|
||||
}
|
||||
|
||||
postRegisterJoinRooms(a.Cfg, acc, a.RSAPI)
|
||||
@ -227,11 +229,14 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
|
||||
if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
|
||||
if !a.Config.Matrix.IsLocalServerName(req.ServerName) {
|
||||
return fmt.Errorf("server name %s is not local", req.ServerName)
|
||||
}
|
||||
if err := a.DB.SetPassword(ctx, req.Localpart, req.ServerName, req.Password); err != nil {
|
||||
return err
|
||||
}
|
||||
if req.LogoutDevices {
|
||||
if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, ""); err != nil {
|
||||
if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, req.ServerName, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -244,14 +249,15 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
|
||||
if serverName == "" {
|
||||
serverName = a.Config.Matrix.ServerName
|
||||
}
|
||||
_ = serverName
|
||||
// XXXX: Use the server name here
|
||||
if !a.Config.Matrix.IsLocalServerName(serverName) {
|
||||
return fmt.Errorf("server name %s is not local", serverName)
|
||||
}
|
||||
util.GetLogger(ctx).WithFields(logrus.Fields{
|
||||
"localpart": req.Localpart,
|
||||
"device_id": req.DeviceID,
|
||||
"display_name": req.DeviceDisplayName,
|
||||
}).Info("PerformDeviceCreation")
|
||||
dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
|
||||
dev, err := a.DB.CreateDevice(ctx, req.Localpart, serverName, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -276,12 +282,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
|
||||
deletedDeviceIDs := req.DeviceIDs
|
||||
if len(req.DeviceIDs) == 0 {
|
||||
var devices []api.Device
|
||||
devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
|
||||
devices, err = a.DB.RemoveAllDevices(ctx, local, domain, req.ExceptDeviceID)
|
||||
for _, d := range devices {
|
||||
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
|
||||
}
|
||||
} else {
|
||||
err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs)
|
||||
err = a.DB.RemoveDevices(ctx, local, domain, req.DeviceIDs)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
@ -335,23 +341,29 @@ func (a *UserInternalAPI) PerformLastSeenUpdate(
|
||||
req *api.PerformLastSeenUpdateRequest,
|
||||
res *api.PerformLastSeenUpdateResponse,
|
||||
) error {
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
|
||||
}
|
||||
if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil {
|
||||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||
return fmt.Errorf("server name %s is not local", domain)
|
||||
}
|
||||
if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, domain, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil {
|
||||
return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error {
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', req.RequestingUserID)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||
return err
|
||||
}
|
||||
dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID)
|
||||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||
return fmt.Errorf("server name %s is not local", domain)
|
||||
}
|
||||
dev, err := a.DB.GetDeviceByID(ctx, localpart, domain, req.DeviceID)
|
||||
if err == sql.ErrNoRows {
|
||||
res.DeviceExists = false
|
||||
return nil
|
||||
@ -366,7 +378,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
|
||||
return nil
|
||||
}
|
||||
|
||||
err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
|
||||
err = a.DB.UpdateDevice(ctx, localpart, domain, req.DeviceID, req.DisplayName)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
|
||||
return err
|
||||
@ -406,7 +418,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
|
||||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||
return fmt.Errorf("cannot query profile of remote users (server name %s)", domain)
|
||||
}
|
||||
prof, err := a.DB.GetProfileByLocalpart(ctx, local)
|
||||
prof, err := a.DB.GetProfileByLocalpart(ctx, local, domain)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil
|
||||
@ -457,7 +469,7 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
|
||||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||
return fmt.Errorf("cannot query devices of remote users (server name %s)", domain)
|
||||
}
|
||||
devs, err := a.DB.GetDevicesByLocalpart(ctx, local)
|
||||
devs, err := a.DB.GetDevicesByLocalpart(ctx, local, domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -476,7 +488,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
|
||||
}
|
||||
if req.DataType != "" {
|
||||
var data json.RawMessage
|
||||
data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
|
||||
data, err = a.DB.GetAccountDataByType(ctx, local, domain, req.RoomID, req.DataType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -494,7 +506,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
|
||||
}
|
||||
return nil
|
||||
}
|
||||
global, rooms, err := a.DB.GetAccountData(ctx, local)
|
||||
global, rooms, err := a.DB.GetAccountData(ctx, local, domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -527,7 +539,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
|
||||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||
return nil
|
||||
}
|
||||
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart)
|
||||
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart, domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -561,14 +573,14 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
|
||||
AccountType: api.AccountTypeAppService,
|
||||
}
|
||||
|
||||
localpart, _, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix)
|
||||
localpart, domain, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if localpart != "" { // AS is masquerading as another user
|
||||
// Verify that the user is registered
|
||||
account, err := a.DB.GetAccountByLocalpart(ctx, localpart)
|
||||
account, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain)
|
||||
// Verify that the account exists and either appServiceID matches or
|
||||
// it belongs to the appservice user namespaces
|
||||
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
|
||||
@ -620,7 +632,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
|
||||
return err
|
||||
}
|
||||
|
||||
err := a.DB.DeactivateAccount(ctx, req.Localpart)
|
||||
err := a.DB.DeactivateAccount(ctx, req.Localpart, serverName)
|
||||
res.AccountDeactivated = err == nil
|
||||
return err
|
||||
}
|
||||
@ -783,7 +795,7 @@ func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.Query
|
||||
if req.Only == "highlight" {
|
||||
filter = tables.HighlightNotifications
|
||||
}
|
||||
notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, fromID, req.Limit, filter)
|
||||
notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, req.ServerName, fromID, req.Limit, filter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -811,23 +823,23 @@ func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.Perform
|
||||
}
|
||||
}
|
||||
if req.Pusher.Kind == "" {
|
||||
return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart)
|
||||
return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart, req.ServerName)
|
||||
}
|
||||
if req.Pusher.PushKeyTS == 0 {
|
||||
req.Pusher.PushKeyTS = int64(time.Now().Unix())
|
||||
}
|
||||
return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart)
|
||||
return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart, req.ServerName)
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error {
|
||||
pushers, err := a.DB.GetPushers(ctx, req.Localpart)
|
||||
pushers, err := a.DB.GetPushers(ctx, req.Localpart, req.ServerName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i := range pushers {
|
||||
logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID)
|
||||
if pushers[i].SessionID != req.SessionID {
|
||||
err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart)
|
||||
err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart, req.ServerName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -838,7 +850,7 @@ func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.Pe
|
||||
|
||||
func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error {
|
||||
var err error
|
||||
res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart)
|
||||
res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart, req.ServerName)
|
||||
return err
|
||||
}
|
||||
|
||||
@ -864,11 +876,11 @@ func (a *UserInternalAPI) PerformPushRulesPut(
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to split user ID %q for push rules", req.UserID)
|
||||
}
|
||||
pushRules, err := a.DB.QueryPushRules(ctx, localpart)
|
||||
pushRules, err := a.DB.QueryPushRules(ctx, localpart, domain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query push rules: %w", err)
|
||||
}
|
||||
@ -877,14 +889,14 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
|
||||
profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL)
|
||||
profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.ServerName, req.AvatarURL)
|
||||
res.Profile = profile
|
||||
res.Changed = changed
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error {
|
||||
id, err := a.DB.GetNewNumericLocalpart(ctx)
|
||||
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, req *api.QueryNumericLocalpartRequest, res *api.QueryNumericLocalpartResponse) error {
|
||||
id, err := a.DB.GetNewNumericLocalpart(ctx, req.ServerName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -894,12 +906,12 @@ func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.Qu
|
||||
|
||||
func (a *UserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error {
|
||||
var err error
|
||||
res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart)
|
||||
res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart, req.ServerName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error {
|
||||
acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.PlaintextPassword)
|
||||
acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.ServerName, req.PlaintextPassword)
|
||||
switch err {
|
||||
case sql.ErrNoRows: // user does not exist
|
||||
return nil
|
||||
@ -915,23 +927,24 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error {
|
||||
profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName)
|
||||
profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.ServerName, req.DisplayName)
|
||||
res.Profile = profile
|
||||
res.Changed = changed
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {
|
||||
localpart, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium)
|
||||
localpart, domain, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.Localpart = localpart
|
||||
res.ServerName = domain
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error {
|
||||
r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart)
|
||||
r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart, req.ServerName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -944,7 +957,7 @@ func (a *UserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.Pe
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error {
|
||||
return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.Medium)
|
||||
return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.ServerName, req.Medium)
|
||||
}
|
||||
|
||||
const pushRulesAccountDataType = "m.push_rules"
|
||||
|
@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog
|
||||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||
return fmt.Errorf("cannot return a login token for a remote user (server name %s)", domain)
|
||||
}
|
||||
if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil {
|
||||
if _, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain); err != nil {
|
||||
res.Data = nil
|
||||
if err == sql.ErrNoRows {
|
||||
return nil
|
||||
|
@ -355,11 +355,12 @@ func (h *httpUserInternalAPI) SetAvatarURL(
|
||||
|
||||
func (h *httpUserInternalAPI) QueryNumericLocalpart(
|
||||
ctx context.Context,
|
||||
request *api.QueryNumericLocalpartRequest,
|
||||
response *api.QueryNumericLocalpartResponse,
|
||||
) error {
|
||||
return httputil.CallInternalRPCAPI(
|
||||
"QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath,
|
||||
h.httpClient, ctx, &struct{}{}, response,
|
||||
h.httpClient, ctx, request, response,
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -15,12 +15,9 @@
|
||||
package inthttp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/matrix-org/dendrite/internal/httputil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
// nolint: gocyclo
|
||||
@ -152,15 +149,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
||||
httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL),
|
||||
)
|
||||
|
||||
// TODO: Look at the shape of this
|
||||
internalAPIMux.Handle(QueryNumericLocalpartPath,
|
||||
httputil.MakeInternalAPI("UserAPIQueryNumericLocalpart", func(req *http.Request) util.JSONResponse {
|
||||
response := api.QueryNumericLocalpartResponse{}
|
||||
if err := s.QueryNumericLocalpart(req.Context(), &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
internalAPIMux.Handle(
|
||||
QueryNumericLocalpartPath,
|
||||
httputil.MakeInternalRPCAPI("UserAPIQueryNumericLocalpart", s.QueryNumericLocalpart),
|
||||
)
|
||||
|
||||
internalAPIMux.Handle(
|
||||
|
@ -61,12 +61,12 @@ func (p *SyncAPI) SendAccountData(userID string, data eventutil.AccountData) err
|
||||
// GetAndSendNotificationData reads the database and sends data about unread
|
||||
// notifications to the Sync API server.
|
||||
func (p *SyncAPI) GetAndSendNotificationData(ctx context.Context, userID, roomID string) error {
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, roomID)
|
||||
ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, domain, roomID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -29,40 +29,40 @@ import (
|
||||
)
|
||||
|
||||
type Profile interface {
|
||||
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
|
||||
GetProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error)
|
||||
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
||||
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
|
||||
SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, bool, error)
|
||||
SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error)
|
||||
SetDisplayName(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error)
|
||||
}
|
||||
|
||||
type Account interface {
|
||||
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||
// account already exists, it will return nil, ErrUserExists.
|
||||
CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
||||
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
|
||||
GetNewNumericLocalpart(ctx context.Context) (int64, error)
|
||||
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
|
||||
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
||||
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
||||
SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
|
||||
CreateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
||||
GetAccountByPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) (*api.Account, error)
|
||||
GetNewNumericLocalpart(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
|
||||
CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error)
|
||||
GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error)
|
||||
DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error)
|
||||
SetPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) error
|
||||
}
|
||||
|
||||
type AccountData interface {
|
||||
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
|
||||
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
|
||||
SaveAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error
|
||||
GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
|
||||
// GetAccountDataByType returns account data matching a given
|
||||
// localpart, room ID and type.
|
||||
// If no account data could be found, returns nil
|
||||
// Returns an error if there was an issue with the retrieval
|
||||
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
|
||||
QueryPushRules(ctx context.Context, localpart string) (*pushrules.AccountRuleSets, error)
|
||||
GetAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error)
|
||||
QueryPushRules(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*pushrules.AccountRuleSets, error)
|
||||
}
|
||||
|
||||
type Device interface {
|
||||
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
|
||||
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
|
||||
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
|
||||
GetDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error)
|
||||
GetDevicesByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Device, error)
|
||||
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
|
||||
// CreateDevice makes a new device associated with the given user ID localpart.
|
||||
// If there is already a device with the same device ID for this user, that access token will be revoked
|
||||
@ -70,12 +70,12 @@ type Device interface {
|
||||
// an error will be returned.
|
||||
// If no device ID is given one is generated.
|
||||
// Returns the device on success.
|
||||
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
|
||||
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
|
||||
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error
|
||||
RemoveDevices(ctx context.Context, localpart string, devices []string) error
|
||||
CreateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
|
||||
UpdateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error
|
||||
UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error
|
||||
RemoveDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error
|
||||
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
|
||||
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
|
||||
RemoveAllDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) (devices []api.Device, err error)
|
||||
}
|
||||
|
||||
type KeyBackup interface {
|
||||
@ -107,26 +107,26 @@ type OpenID interface {
|
||||
}
|
||||
|
||||
type Pusher interface {
|
||||
UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error
|
||||
GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error)
|
||||
RemovePusher(ctx context.Context, appid, pushkey, localpart string) error
|
||||
UpsertPusher(ctx context.Context, p api.Pusher, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||
GetPushers(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error)
|
||||
RemovePusher(ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||
RemovePushers(ctx context.Context, appid, pushkey string) error
|
||||
}
|
||||
|
||||
type ThreePID interface {
|
||||
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
|
||||
SaveThreePIDAssociation(ctx context.Context, threepid, localpart string, serverName gomatrixserverlib.ServerName, medium string) (err error)
|
||||
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
|
||||
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
|
||||
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
|
||||
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error)
|
||||
GetThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error)
|
||||
}
|
||||
|
||||
type Notification interface {
|
||||
InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error
|
||||
DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error)
|
||||
SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, read bool) (affected bool, err error)
|
||||
GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
|
||||
GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error)
|
||||
GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)
|
||||
InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error
|
||||
DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error)
|
||||
SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, read bool) (affected bool, err error)
|
||||
GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
|
||||
GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error)
|
||||
GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error)
|
||||
DeleteOldNotifications(ctx context.Context) error
|
||||
}
|
||||
|
||||
|
@ -22,6 +22,7 @@ import (
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const accountDataSchema = `
|
||||
@ -29,27 +30,28 @@ const accountDataSchema = `
|
||||
CREATE TABLE IF NOT EXISTS userapi_account_datas (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
-- The room ID for this data (empty string if not specific to a room)
|
||||
room_id TEXT,
|
||||
-- The account data type
|
||||
type TEXT NOT NULL,
|
||||
-- The account data content
|
||||
content TEXT NOT NULL,
|
||||
|
||||
PRIMARY KEY(localpart, room_id, type)
|
||||
content TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_account_datas_idx ON userapi_account_datas(localpart, server_name, room_id, type);
|
||||
`
|
||||
|
||||
const insertAccountDataSQL = `
|
||||
INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
|
||||
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content
|
||||
INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (localpart, server_name, room_id, type) DO UPDATE SET content = EXCLUDED.content
|
||||
`
|
||||
|
||||
const selectAccountDataSQL = "" +
|
||||
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1"
|
||||
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const selectAccountDataByTypeSQL = "" +
|
||||
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
||||
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND type = $4"
|
||||
|
||||
type accountDataStatements struct {
|
||||
insertAccountDataStmt *sql.Stmt
|
||||
@ -71,21 +73,24 @@ func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) InsertAccountData(
|
||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
roomID, dataType string, content json.RawMessage,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
|
||||
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
|
||||
_, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) SelectAccountData(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (
|
||||
/* global */ map[string]json.RawMessage,
|
||||
/* rooms */ map[string]map[string]json.RawMessage,
|
||||
error,
|
||||
) {
|
||||
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
|
||||
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@ -117,11 +122,13 @@ func (s *accountDataStatements) SelectAccountData(
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) SelectAccountDataByType(
|
||||
ctx context.Context, localpart, roomID, dataType string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
roomID, dataType string,
|
||||
) (data json.RawMessage, err error) {
|
||||
var bytes []byte
|
||||
stmt := s.selectAccountDataByTypeStmt
|
||||
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
|
||||
if err = stmt.QueryRowContext(ctx, localpart, serverName, roomID, dataType).Scan(&bytes); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -17,6 +17,7 @@ package postgres
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
@ -34,7 +35,8 @@ const accountsSchema = `
|
||||
-- Stores data about accounts.
|
||||
CREATE TABLE IF NOT EXISTS userapi_accounts (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
-- When this account was first created, as a unix timestamp (ms resolution).
|
||||
created_ts BIGINT NOT NULL,
|
||||
-- The password hash for this account. Can be NULL if this is a passwordless account.
|
||||
@ -48,25 +50,27 @@ CREATE TABLE IF NOT EXISTS userapi_accounts (
|
||||
-- TODO:
|
||||
-- upgraded_ts, devices, any email reset stuff?
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_accounts_idx ON userapi_accounts(localpart, server_name);
|
||||
`
|
||||
|
||||
const insertAccountSQL = "" +
|
||||
"INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
|
||||
"INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)"
|
||||
|
||||
const updatePasswordSQL = "" +
|
||||
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3"
|
||||
|
||||
const deactivateAccountSQL = "" +
|
||||
"UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1"
|
||||
"UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const selectAccountByLocalpartSQL = "" +
|
||||
"SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
|
||||
"SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const selectPasswordHashSQL = "" +
|
||||
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
|
||||
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = FALSE"
|
||||
|
||||
const selectNewNumericLocalpartSQL = "" +
|
||||
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$'"
|
||||
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$' AND server_name = $1"
|
||||
|
||||
type accountsStatements struct {
|
||||
insertAccountStmt *sql.Stmt
|
||||
@ -117,59 +121,62 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
|
||||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||
// on success.
|
||||
func (s *accountsStatements) InsertAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
hash, appserviceID string, accountType api.AccountType,
|
||||
) (*api.Account, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
|
||||
|
||||
var err error
|
||||
if accountType != api.AccountTypeAppService {
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
|
||||
_, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType)
|
||||
} else {
|
||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
|
||||
_, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("insertAccountStmt: %w", err)
|
||||
}
|
||||
|
||||
return &api.Account{
|
||||
Localpart: localpart,
|
||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||
ServerName: s.serverName,
|
||||
UserID: userutil.MakeUserID(localpart, serverName),
|
||||
ServerName: serverName,
|
||||
AppServiceID: appserviceID,
|
||||
AccountType: accountType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *accountsStatements) UpdatePassword(
|
||||
ctx context.Context, localpart, passwordHash string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
passwordHash string,
|
||||
) (err error) {
|
||||
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
|
||||
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) DeactivateAccount(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (err error) {
|
||||
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
|
||||
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) SelectPasswordHash(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (hash string, err error) {
|
||||
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
|
||||
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) SelectAccountByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (*api.Account, error) {
|
||||
var appserviceIDPtr sql.NullString
|
||||
var acc api.Account
|
||||
|
||||
stmt := s.selectAccountByLocalpartStmt
|
||||
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType)
|
||||
err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
log.WithError(err).Error("Unable to retrieve user from the db")
|
||||
@ -180,19 +187,17 @@ func (s *accountsStatements) SelectAccountByLocalpart(
|
||||
acc.AppServiceID = appserviceIDPtr.String
|
||||
}
|
||||
|
||||
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
acc.ServerName = s.serverName
|
||||
|
||||
acc.UserID = userutil.MakeUserID(acc.Localpart, acc.ServerName)
|
||||
return &acc, nil
|
||||
}
|
||||
|
||||
func (s *accountsStatements) SelectNewNumericLocalpart(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) (id int64, err error) {
|
||||
stmt := s.selectNewNumericLocalpartStmt
|
||||
if txn != nil {
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
}
|
||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
||||
err = stmt.QueryRowContext(ctx, serverName).Scan(&id)
|
||||
return id + 1, err
|
||||
}
|
||||
|
@ -0,0 +1,81 @@
|
||||
package deltas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
var serverNamesTables = []string{
|
||||
"userapi_accounts",
|
||||
"userapi_account_datas",
|
||||
"userapi_devices",
|
||||
"userapi_notifications",
|
||||
"userapi_openid_tokens",
|
||||
"userapi_profiles",
|
||||
"userapi_pushers",
|
||||
"userapi_threepids",
|
||||
}
|
||||
|
||||
// These tables have a PRIMARY KEY constraint which we need to drop so
|
||||
// that we can recreate a new unique index that contains the server name.
|
||||
// If the new key doesn't exist (i.e. the database was created before the
|
||||
// table rename migration) we'll try to drop the old one instead.
|
||||
var serverNamesDropPK = map[string]string{
|
||||
"userapi_accounts": "account_accounts",
|
||||
"userapi_account_datas": "account_data",
|
||||
"userapi_profiles": "account_profiles",
|
||||
}
|
||||
|
||||
// These indices are out of date so let's drop them. They will get recreated
|
||||
// automatically.
|
||||
var serverNamesDropIndex = []string{
|
||||
"userapi_pusher_localpart_idx",
|
||||
"userapi_pusher_app_id_pushkey_localpart_idx",
|
||||
}
|
||||
|
||||
// I know what you're thinking: you're wondering "why doesn't this use $1
|
||||
// and pass variadic parameters to ExecContext?" — the answer is because
|
||||
// PostgreSQL doesn't expect the table name to be specified as a substituted
|
||||
// argument in that way so it results in a syntax error in the query.
|
||||
|
||||
func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
|
||||
for _, table := range serverNamesTables {
|
||||
q := fmt.Sprintf(
|
||||
"ALTER TABLE IF EXISTS %s ADD COLUMN IF NOT EXISTS server_name TEXT NOT NULL DEFAULT '';",
|
||||
pq.QuoteIdentifier(table),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("add server name to %q error: %w", table, err)
|
||||
}
|
||||
}
|
||||
for newTable, oldTable := range serverNamesDropPK {
|
||||
q := fmt.Sprintf(
|
||||
"ALTER TABLE IF EXISTS %s DROP CONSTRAINT IF EXISTS %s;",
|
||||
pq.QuoteIdentifier(newTable), pq.QuoteIdentifier(newTable+"_pkey"),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("drop new PK from %q error: %w", newTable, err)
|
||||
}
|
||||
q = fmt.Sprintf(
|
||||
"ALTER TABLE IF EXISTS %s DROP CONSTRAINT IF EXISTS %s;",
|
||||
pq.QuoteIdentifier(newTable), pq.QuoteIdentifier(oldTable+"_pkey"),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("drop old PK from %q error: %w", newTable, err)
|
||||
}
|
||||
}
|
||||
for _, index := range serverNamesDropIndex {
|
||||
q := fmt.Sprintf(
|
||||
"DROP INDEX IF EXISTS %s;",
|
||||
pq.QuoteIdentifier(index),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("drop index %q error: %w", index, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
@ -0,0 +1,28 @@
|
||||
package deltas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// I know what you're thinking: you're wondering "why doesn't this use $1
|
||||
// and pass variadic parameters to ExecContext?" — the answer is because
|
||||
// PostgreSQL doesn't expect the table name to be specified as a substituted
|
||||
// argument in that way so it results in a syntax error in the query.
|
||||
|
||||
func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
|
||||
for _, table := range serverNamesTables {
|
||||
q := fmt.Sprintf(
|
||||
"UPDATE %s SET server_name = %s WHERE server_name = '';",
|
||||
pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("write server names to %q error: %w", table, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
@ -17,6 +17,7 @@ package postgres
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
@ -50,6 +51,7 @@ CREATE TABLE IF NOT EXISTS userapi_devices (
|
||||
-- as it is smaller, makes it clearer that we only manage devices for our own users, and may make
|
||||
-- migration to different domain names easier.
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
-- When this devices was first recognised on the network, as a unix timestamp (ms resolution).
|
||||
created_ts BIGINT NOT NULL,
|
||||
-- The display name, human friendlier than device_id and updatable
|
||||
@ -65,39 +67,39 @@ CREATE TABLE IF NOT EXISTS userapi_devices (
|
||||
);
|
||||
|
||||
-- Device IDs must be unique for a given user.
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, device_id);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, server_name, device_id);
|
||||
`
|
||||
|
||||
const insertDeviceSQL = "" +
|
||||
"INSERT INTO userapi_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" +
|
||||
"INSERT INTO userapi_devices(device_id, localpart, server_name, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" +
|
||||
" RETURNING session_id"
|
||||
|
||||
const selectDeviceByTokenSQL = "" +
|
||||
"SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1"
|
||||
"SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1"
|
||||
|
||||
const selectDeviceByIDSQL = "" +
|
||||
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2"
|
||||
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3"
|
||||
|
||||
const selectDevicesByLocalpartSQL = "" +
|
||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
|
||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC"
|
||||
|
||||
const updateDeviceNameSQL = "" +
|
||||
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4"
|
||||
|
||||
const deleteDeviceSQL = "" +
|
||||
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2"
|
||||
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3"
|
||||
|
||||
const deleteDevicesByLocalpartSQL = "" +
|
||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
|
||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3"
|
||||
|
||||
const deleteDevicesSQL = "" +
|
||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id = ANY($2)"
|
||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = ANY($3)"
|
||||
|
||||
const selectDevicesByIDSQL = "" +
|
||||
"SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
|
||||
"SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
|
||||
|
||||
const updateDeviceLastSeen = "" +
|
||||
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
|
||||
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6"
|
||||
|
||||
type devicesStatements struct {
|
||||
insertDeviceStmt *sql.Stmt
|
||||
@ -148,18 +150,19 @@ func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName
|
||||
// Returns an error if the user already has a device with the given device ID.
|
||||
// Returns the device on success.
|
||||
func (s *devicesStatements) InsertDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
||||
displayName *string, ipAddr, userAgent string,
|
||||
ctx context.Context, txn *sql.Tx, id string,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
accessToken string, displayName *string, ipAddr, userAgent string,
|
||||
) (*api.Device, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
var sessionID int64
|
||||
stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
|
||||
if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
|
||||
return nil, err
|
||||
if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
|
||||
return nil, fmt.Errorf("insertDeviceStmt: %w", err)
|
||||
}
|
||||
return &api.Device{
|
||||
ID: id,
|
||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||
UserID: userutil.MakeUserID(localpart, serverName),
|
||||
AccessToken: accessToken,
|
||||
SessionID: sessionID,
|
||||
LastSeenTS: createdTimeMS,
|
||||
@ -170,38 +173,45 @@ func (s *devicesStatements) InsertDevice(
|
||||
|
||||
// deleteDevice removes a single device by id and user localpart.
|
||||
func (s *devicesStatements) DeleteDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart string,
|
||||
ctx context.Context, txn *sql.Tx, id string,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
||||
_, err := stmt.ExecContext(ctx, id, localpart)
|
||||
_, err := stmt.ExecContext(ctx, id, localpart, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
// deleteDevices removes a single or multiple devices by ids and user localpart.
|
||||
// Returns an error if the execution failed.
|
||||
func (s *devicesStatements) DeleteDevices(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
devices []string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt)
|
||||
_, err := stmt.ExecContext(ctx, localpart, pq.Array(devices))
|
||||
_, err := stmt.ExecContext(ctx, localpart, serverName, pq.Array(devices))
|
||||
return err
|
||||
}
|
||||
|
||||
// deleteDevicesByLocalpart removes all devices for the
|
||||
// given user localpart.
|
||||
func (s *devicesStatements) DeleteDevicesByLocalpart(
|
||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
exceptDeviceID string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
||||
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
|
||||
_, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) UpdateDeviceName(
|
||||
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
deviceID string, displayName *string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
||||
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
|
||||
_, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID)
|
||||
return err
|
||||
}
|
||||
|
||||
@ -210,10 +220,11 @@ func (s *devicesStatements) SelectDeviceByToken(
|
||||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
var localpart string
|
||||
var serverName gomatrixserverlib.ServerName
|
||||
stmt := s.selectDeviceByTokenStmt
|
||||
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
|
||||
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName)
|
||||
if err == nil {
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||
dev.AccessToken = accessToken
|
||||
}
|
||||
return &dev, err
|
||||
@ -222,16 +233,18 @@ func (s *devicesStatements) SelectDeviceByToken(
|
||||
// selectDeviceByID retrieves a device from the database with the given user
|
||||
// localpart and deviceID
|
||||
func (s *devicesStatements) SelectDeviceByID(
|
||||
ctx context.Context, localpart, deviceID string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
deviceID string,
|
||||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
var displayName, ip sql.NullString
|
||||
var lastseenTS sql.NullInt64
|
||||
stmt := s.selectDeviceByIDStmt
|
||||
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip)
|
||||
err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip)
|
||||
if err == nil {
|
||||
dev.ID = deviceID
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||
if displayName.Valid {
|
||||
dev.DisplayName = displayName.String
|
||||
}
|
||||
@ -254,10 +267,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
||||
var devices []api.Device
|
||||
var dev api.Device
|
||||
var localpart string
|
||||
var serverName gomatrixserverlib.ServerName
|
||||
var lastseents sql.NullInt64
|
||||
var displayName sql.NullString
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
|
||||
if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if displayName.Valid {
|
||||
@ -266,17 +280,19 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
||||
if lastseents.Valid {
|
||||
dev.LastSeenTS = lastseents.Int64
|
||||
}
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||
devices = append(devices, dev)
|
||||
}
|
||||
return devices, rows.Err()
|
||||
}
|
||||
|
||||
func (s *devicesStatements) SelectDevicesByLocalpart(
|
||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
exceptDeviceID string,
|
||||
) ([]api.Device, error) {
|
||||
devices := []api.Device{}
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID)
|
||||
|
||||
if err != nil {
|
||||
return devices, err
|
||||
@ -307,16 +323,16 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
|
||||
dev.UserAgent = useragent.String
|
||||
}
|
||||
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||
devices = append(devices, dev)
|
||||
}
|
||||
|
||||
return devices, rows.Err()
|
||||
}
|
||||
|
||||
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error {
|
||||
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
|
||||
lastSeenTs := time.Now().UnixNano() / 1000000
|
||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
|
||||
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID)
|
||||
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID)
|
||||
return err
|
||||
}
|
||||
|
@ -43,6 +43,7 @@ const notificationSchema = `
|
||||
CREATE TABLE IF NOT EXISTS userapi_notifications (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
event_id TEXT NOT NULL,
|
||||
stream_pos BIGINT NOT NULL,
|
||||
@ -52,33 +53,33 @@ CREATE TABLE IF NOT EXISTS userapi_notifications (
|
||||
read BOOLEAN NOT NULL DEFAULT FALSE
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, server_name, room_id, event_id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, server_name, room_id, id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, server_name, id);
|
||||
`
|
||||
|
||||
const insertNotificationSQL = "" +
|
||||
"INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
|
||||
"INSERT INTO userapi_notifications (localpart, server_name, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
|
||||
|
||||
const deleteNotificationsUpToSQL = "" +
|
||||
"DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
|
||||
"DELETE FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND stream_pos <= $4"
|
||||
|
||||
const updateNotificationReadSQL = "" +
|
||||
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
|
||||
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND server_name = $3 AND room_id = $4 AND stream_pos <= $5 AND read <> $1"
|
||||
|
||||
const selectNotificationSQL = "" +
|
||||
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
|
||||
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
|
||||
") AND NOT read ORDER BY localpart, id LIMIT $4"
|
||||
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND id > $3 AND (" +
|
||||
"(($4 & 1) <> 0 AND highlight) OR (($4 & 2) <> 0 AND NOT highlight)" +
|
||||
") AND NOT read ORDER BY localpart, id LIMIT $5"
|
||||
|
||||
const selectNotificationCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
|
||||
"(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
|
||||
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND (" +
|
||||
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
|
||||
") AND NOT read"
|
||||
|
||||
const selectRoomNotificationCountsSQL = "" +
|
||||
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
|
||||
"WHERE localpart = $1 AND room_id = $2 AND NOT read"
|
||||
"WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND NOT read"
|
||||
|
||||
const cleanNotificationsSQL = "" +
|
||||
"DELETE FROM userapi_notifications WHERE" +
|
||||
@ -111,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
|
||||
}
|
||||
|
||||
// Insert inserts a notification into the database.
|
||||
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error {
|
||||
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error {
|
||||
roomID, tsMS := n.RoomID, n.TS
|
||||
nn := *n
|
||||
// Clears out fields that have their own columns to (1) shrink the
|
||||
@ -122,13 +123,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
|
||||
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, serverName, roomID, eventID, pos, tsMS, highlight, string(bs))
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteUpTo deletes all previous notifications, up to and including the event.
|
||||
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) {
|
||||
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
|
||||
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) {
|
||||
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -141,8 +142,8 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
|
||||
}
|
||||
|
||||
// UpdateRead updates the "read" value for an event.
|
||||
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) {
|
||||
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
|
||||
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) {
|
||||
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -154,8 +155,8 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l
|
||||
return nrows > 0, nil
|
||||
}
|
||||
|
||||
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
|
||||
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit)
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@ -197,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local
|
||||
return notifs, maxID, rows.Err()
|
||||
}
|
||||
|
||||
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) {
|
||||
err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count)
|
||||
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) {
|
||||
err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) {
|
||||
err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight)
|
||||
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) {
|
||||
err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight)
|
||||
return
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package postgres
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
@ -18,16 +19,17 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
|
||||
token TEXT NOT NULL PRIMARY KEY,
|
||||
-- The Matrix user ID for this account
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
-- When the token expires, as a unix timestamp (ms resolution).
|
||||
token_expires_at_ms BIGINT NOT NULL
|
||||
);
|
||||
`
|
||||
|
||||
const insertOpenIDTokenSQL = "" +
|
||||
"INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
const selectOpenIDTokenSQL = "" +
|
||||
"SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
|
||||
"SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
|
||||
|
||||
type openIDTokenStatements struct {
|
||||
insertTokenStmt *sql.Stmt
|
||||
@ -54,11 +56,11 @@ func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
|
||||
func (s *openIDTokenStatements) InsertOpenIDToken(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
token, localpart string,
|
||||
token, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
expiresAtMS int64,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
|
||||
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
|
||||
_, err = stmt.ExecContext(ctx, token, localpart, serverName, expiresAtMS)
|
||||
return
|
||||
}
|
||||
|
||||
@ -69,10 +71,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
|
||||
token string,
|
||||
) (*api.OpenIDTokenAttributes, error) {
|
||||
var openIDTokenAttrs api.OpenIDTokenAttributes
|
||||
var localpart string
|
||||
var serverName gomatrixserverlib.ServerName
|
||||
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
|
||||
&openIDTokenAttrs.UserID,
|
||||
&localpart, &serverName,
|
||||
&openIDTokenAttrs.ExpiresAtMS,
|
||||
)
|
||||
openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
log.WithError(err).Error("Unable to retrieve token from the db")
|
||||
|
@ -23,42 +23,46 @@ import (
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const profilesSchema = `
|
||||
-- Stores data about accounts profiles.
|
||||
CREATE TABLE IF NOT EXISTS userapi_profiles (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
-- The display name for this account
|
||||
display_name TEXT,
|
||||
-- The URL of the avatar for this account
|
||||
avatar_url TEXT
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_profiles_idx ON userapi_profiles(localpart, server_name);
|
||||
`
|
||||
|
||||
const insertProfileSQL = "" +
|
||||
"INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO userapi_profiles(localpart, server_name, display_name, avatar_url) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
const selectProfileByLocalpartSQL = "" +
|
||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
|
||||
"SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const setAvatarURLSQL = "" +
|
||||
"UPDATE userapi_profiles AS new" +
|
||||
" SET avatar_url = $1" +
|
||||
" FROM userapi_profiles AS old" +
|
||||
" WHERE new.localpart = $2" +
|
||||
" WHERE new.localpart = $2 AND new.server_name = $3" +
|
||||
" RETURNING new.display_name, old.avatar_url <> new.avatar_url"
|
||||
|
||||
const setDisplayNameSQL = "" +
|
||||
"UPDATE userapi_profiles AS new" +
|
||||
" SET display_name = $1" +
|
||||
" FROM userapi_profiles AS old" +
|
||||
" WHERE new.localpart = $2" +
|
||||
" WHERE new.localpart = $2 AND new.server_name = $3" +
|
||||
" RETURNING new.avatar_url, old.display_name <> new.display_name"
|
||||
|
||||
const selectProfilesBySearchSQL = "" +
|
||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||
"SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||
|
||||
type profilesStatements struct {
|
||||
serverNoticesLocalpart string
|
||||
@ -87,18 +91,20 @@ func NewPostgresProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables
|
||||
}
|
||||
|
||||
func (s *profilesStatements) InsertProfile(
|
||||
ctx context.Context, txn *sql.Tx, localpart string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
||||
_, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "")
|
||||
return
|
||||
}
|
||||
|
||||
func (s *profilesStatements) SelectProfileByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (*authtypes.Profile, error) {
|
||||
var profile authtypes.Profile
|
||||
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
|
||||
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
|
||||
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan(
|
||||
&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -107,28 +113,34 @@ func (s *profilesStatements) SelectProfileByLocalpart(
|
||||
}
|
||||
|
||||
func (s *profilesStatements) SetAvatarURL(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
avatarURL string,
|
||||
) (*authtypes.Profile, bool, error) {
|
||||
profile := &authtypes.Profile{
|
||||
Localpart: localpart,
|
||||
AvatarURL: avatarURL,
|
||||
Localpart: localpart,
|
||||
ServerName: string(serverName),
|
||||
AvatarURL: avatarURL,
|
||||
}
|
||||
var changed bool
|
||||
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
|
||||
err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName, &changed)
|
||||
err := stmt.QueryRowContext(ctx, avatarURL, localpart, serverName).Scan(&profile.DisplayName, &changed)
|
||||
return profile, changed, err
|
||||
}
|
||||
|
||||
func (s *profilesStatements) SetDisplayName(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
displayName string,
|
||||
) (*authtypes.Profile, bool, error) {
|
||||
profile := &authtypes.Profile{
|
||||
Localpart: localpart,
|
||||
ServerName: string(serverName),
|
||||
DisplayName: displayName,
|
||||
}
|
||||
var changed bool
|
||||
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
|
||||
err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL, &changed)
|
||||
err := stmt.QueryRowContext(ctx, displayName, localpart, serverName).Scan(&profile.AvatarURL, &changed)
|
||||
return profile, changed, err
|
||||
}
|
||||
|
||||
@ -146,7 +158,7 @@ func (s *profilesStatements) SelectProfilesBySearch(
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var profile authtypes.Profile
|
||||
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
|
||||
if err := rows.Scan(&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if profile.Localpart != s.serverNoticesLocalpart {
|
||||
|
@ -25,6 +25,7 @@ import (
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
|
||||
@ -33,6 +34,7 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
-- The Matrix user ID localpart for this pusher
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
session_id BIGINT DEFAULT NULL,
|
||||
profile_tag TEXT,
|
||||
kind TEXT NOT NULL,
|
||||
@ -49,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
|
||||
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
|
||||
|
||||
-- For faster retrieving by localpart.
|
||||
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
|
||||
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name);
|
||||
|
||||
-- Pushkey must be unique for a given user and app.
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name);
|
||||
`
|
||||
|
||||
const insertPusherSQL = "" +
|
||||
"INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
|
||||
"ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
|
||||
"INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" +
|
||||
"ON CONFLICT (app_id, pushkey, localpart, server_name) DO UPDATE SET session_id = $3, pushkey_ts_ms = $5, kind = $6, app_display_name = $8, device_display_name = $9, profile_tag = $10, lang = $11, data = $12"
|
||||
|
||||
const selectPushersSQL = "" +
|
||||
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
|
||||
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const deletePusherSQL = "" +
|
||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
|
||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4"
|
||||
|
||||
const deletePushersByAppIdAndPushKeySQL = "" +
|
||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
|
||||
@ -95,18 +97,19 @@ type pushersStatements struct {
|
||||
// Returns nil error success.
|
||||
func (s *pushersStatements) InsertPusher(
|
||||
ctx context.Context, txn *sql.Tx, session_id int64,
|
||||
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
|
||||
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
|
||||
logrus.Debugf("Created pusher %d", session_id)
|
||||
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *pushersStatements) SelectPushers(
|
||||
ctx context.Context, txn *sql.Tx, localpart string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) ([]api.Pusher, error) {
|
||||
pushers := []api.Pusher{}
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart)
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart, serverName)
|
||||
|
||||
if err != nil {
|
||||
return pushers, err
|
||||
@ -143,9 +146,10 @@ func (s *pushersStatements) SelectPushers(
|
||||
|
||||
// deletePusher removes a single pusher by pushkey and user localpart.
|
||||
func (s *pushersStatements) DeletePusher(
|
||||
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
|
||||
ctx context.Context, txn *sql.Tx, appid, pushkey,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart)
|
||||
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -15,6 +15,8 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@ -43,18 +45,24 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
||||
Up: deltas.UpRenameTables,
|
||||
Down: deltas.DownRenameTables,
|
||||
})
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "userapi: server names",
|
||||
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||
return deltas.UpServerNames(ctx, txn, serverName)
|
||||
},
|
||||
})
|
||||
if err = m.Up(base.Context()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accountDataTable, err := NewPostgresAccountDataTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)
|
||||
}
|
||||
accountsTable, err := NewPostgresAccountsTable(db, serverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err)
|
||||
}
|
||||
accountDataTable, err := NewPostgresAccountDataTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)
|
||||
}
|
||||
devicesTable, err := NewPostgresDevicesTable(db, serverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresDevicesTable: %w", err)
|
||||
@ -95,6 +103,18 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresStatsTable: %w", err)
|
||||
}
|
||||
|
||||
m = sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "userapi: server names populate",
|
||||
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||
return deltas.UpServerNamesPopulate(ctx, txn, serverName)
|
||||
},
|
||||
})
|
||||
if err = m.Up(base.Context()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &shared.Database{
|
||||
AccountDatas: accountDataTable,
|
||||
Accounts: accountsTable,
|
||||
|
@ -20,6 +20,7 @@ import (
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
)
|
||||
@ -33,21 +34,22 @@ CREATE TABLE IF NOT EXISTS userapi_threepids (
|
||||
medium TEXT NOT NULL DEFAULT 'email',
|
||||
-- The localpart of the Matrix user ID associated to this 3PID
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
|
||||
PRIMARY KEY(threepid, medium)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart);
|
||||
CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart, server_name);
|
||||
`
|
||||
|
||||
const selectLocalpartForThreePIDSQL = "" +
|
||||
"SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
||||
"SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
||||
|
||||
const selectThreePIDsForLocalpartSQL = "" +
|
||||
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1"
|
||||
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const insertThreePIDSQL = "" +
|
||||
"INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO userapi_threepids (threepid, medium, localpart, server_name) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
const deleteThreePIDSQL = "" +
|
||||
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
||||
@ -75,19 +77,20 @@ func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
|
||||
|
||||
func (s *threepidStatements) SelectLocalpartForThreePID(
|
||||
ctx context.Context, txn *sql.Tx, threepid string, medium string,
|
||||
) (localpart string, err error) {
|
||||
) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
|
||||
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
|
||||
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", nil
|
||||
return "", "", nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (threepids []authtypes.ThreePID, err error) {
|
||||
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
|
||||
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -109,10 +112,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
||||
}
|
||||
|
||||
func (s *threepidStatements) InsertThreePID(
|
||||
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
|
||||
ctx context.Context, txn *sql.Tx, threepid, medium,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
|
||||
_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
|
||||
_, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -68,9 +68,10 @@ const (
|
||||
// GetAccountByPassword returns the account associated with the given localpart and password.
|
||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||
func (d *Database) GetAccountByPassword(
|
||||
ctx context.Context, localpart, plaintextPassword string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
plaintextPassword string,
|
||||
) (*api.Account, error) {
|
||||
hash, err := d.Accounts.SelectPasswordHash(ctx, localpart)
|
||||
hash, err := d.Accounts.SelectPasswordHash(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -80,24 +81,27 @@ func (d *Database) GetAccountByPassword(
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d.Accounts.SelectAccountByLocalpart(ctx, localpart)
|
||||
return d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName)
|
||||
}
|
||||
|
||||
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
||||
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
|
||||
func (d *Database) GetProfileByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (*authtypes.Profile, error) {
|
||||
return d.Profiles.SelectProfileByLocalpart(ctx, localpart)
|
||||
return d.Profiles.SelectProfileByLocalpart(ctx, localpart, serverName)
|
||||
}
|
||||
|
||||
// SetAvatarURL updates the avatar URL of the profile associated with the given
|
||||
// localpart. Returns an error if something went wrong with the SQL query
|
||||
func (d *Database) SetAvatarURL(
|
||||
ctx context.Context, localpart string, avatarURL string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
avatarURL string,
|
||||
) (profile *authtypes.Profile, changed bool, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
|
||||
profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, serverName, avatarURL)
|
||||
return err
|
||||
})
|
||||
return
|
||||
@ -106,10 +110,12 @@ func (d *Database) SetAvatarURL(
|
||||
// SetDisplayName updates the display name of the profile associated with the given
|
||||
// localpart. Returns an error if something went wrong with the SQL query
|
||||
func (d *Database) SetDisplayName(
|
||||
ctx context.Context, localpart string, displayName string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
displayName string,
|
||||
) (profile *authtypes.Profile, changed bool, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
|
||||
profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, serverName, displayName)
|
||||
return err
|
||||
})
|
||||
return
|
||||
@ -117,14 +123,15 @@ func (d *Database) SetDisplayName(
|
||||
|
||||
// SetPassword sets the account password to the given hash.
|
||||
func (d *Database) SetPassword(
|
||||
ctx context.Context, localpart, plaintextPassword string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
plaintextPassword string,
|
||||
) error {
|
||||
hash, err := d.hashPassword(plaintextPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
|
||||
return d.Accounts.UpdatePassword(ctx, localpart, hash)
|
||||
return d.Accounts.UpdatePassword(ctx, localpart, serverName, hash)
|
||||
})
|
||||
}
|
||||
|
||||
@ -132,21 +139,22 @@ func (d *Database) SetPassword(
|
||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||
// account already exists, it will return nil, ErrUserExists.
|
||||
func (d *Database) CreateAccount(
|
||||
ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
plaintextPassword, appserviceID string, accountType api.AccountType,
|
||||
) (acc *api.Account, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
// For guest accounts, we create a new numeric local part
|
||||
if accountType == api.AccountTypeGuest {
|
||||
var numLocalpart int64
|
||||
numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn)
|
||||
numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn, serverName)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("d.Accounts.SelectNewNumericLocalpart: %w", err)
|
||||
}
|
||||
localpart = strconv.FormatInt(numLocalpart, 10)
|
||||
plaintextPassword = ""
|
||||
appserviceID = ""
|
||||
}
|
||||
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType)
|
||||
acc, err = d.createAccount(ctx, txn, localpart, serverName, plaintextPassword, appserviceID, accountType)
|
||||
return err
|
||||
})
|
||||
return
|
||||
@ -155,7 +163,9 @@ func (d *Database) CreateAccount(
|
||||
// WARNING! This function assumes that the relevant mutexes have already
|
||||
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
|
||||
func (d *Database) createAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
plaintextPassword, appserviceID string, accountType api.AccountType,
|
||||
) (*api.Account, error) {
|
||||
var err error
|
||||
var account *api.Account
|
||||
@ -167,28 +177,28 @@ func (d *Database) createAccount(
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
|
||||
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, serverName, hash, appserviceID, accountType); err != nil {
|
||||
return nil, sqlutil.ErrUserExists
|
||||
}
|
||||
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
|
||||
return nil, err
|
||||
if err = d.Profiles.InsertProfile(ctx, txn, localpart, serverName); err != nil {
|
||||
return nil, fmt.Errorf("d.Profiles.InsertProfile: %w", err)
|
||||
}
|
||||
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)
|
||||
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
|
||||
prbs, err := json.Marshal(pushRuleSets)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("json.Marshal: %w", err)
|
||||
}
|
||||
if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
|
||||
return nil, err
|
||||
if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
|
||||
return nil, fmt.Errorf("d.AccountDatas.InsertAccountData: %w", err)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (d *Database) QueryPushRules(
|
||||
ctx context.Context,
|
||||
localpart string,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (*pushrules.AccountRuleSets, error) {
|
||||
data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, "", "m.push_rules")
|
||||
data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, serverName, "", "m.push_rules")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -196,13 +206,13 @@ func (d *Database) QueryPushRules(
|
||||
// If we didn't find any default push rules then we should just generate some
|
||||
// fresh ones.
|
||||
if len(data) == 0 {
|
||||
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)
|
||||
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
|
||||
prbs, err := json.Marshal(pushRuleSets)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal default push rules: %w", err)
|
||||
}
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
if dbErr := d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", prbs); dbErr != nil {
|
||||
if dbErr := d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", prbs); dbErr != nil {
|
||||
return fmt.Errorf("failed to save default push rules: %w", dbErr)
|
||||
}
|
||||
return nil
|
||||
@ -225,22 +235,23 @@ func (d *Database) QueryPushRules(
|
||||
// update the corresponding row with the new content
|
||||
// Returns a SQL error if there was an issue with the insertion/update
|
||||
func (d *Database) SaveAccountData(
|
||||
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
roomID, dataType string, content json.RawMessage,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.AccountDatas.InsertAccountData(ctx, txn, localpart, roomID, dataType, content)
|
||||
return d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, roomID, dataType, content)
|
||||
})
|
||||
}
|
||||
|
||||
// GetAccountData returns account data related to a given localpart
|
||||
// If no account data could be found, returns an empty arrays
|
||||
// Returns an error if there was an issue with the retrieval
|
||||
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
||||
func (d *Database) GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (
|
||||
global map[string]json.RawMessage,
|
||||
rooms map[string]map[string]json.RawMessage,
|
||||
err error,
|
||||
) {
|
||||
return d.AccountDatas.SelectAccountData(ctx, localpart)
|
||||
return d.AccountDatas.SelectAccountData(ctx, localpart, serverName)
|
||||
}
|
||||
|
||||
// GetAccountDataByType returns account data matching a given
|
||||
@ -248,18 +259,19 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
||||
// If no account data could be found, returns nil
|
||||
// Returns an error if there was an issue with the retrieval
|
||||
func (d *Database) GetAccountDataByType(
|
||||
ctx context.Context, localpart, roomID, dataType string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
roomID, dataType string,
|
||||
) (data json.RawMessage, err error) {
|
||||
return d.AccountDatas.SelectAccountDataByType(
|
||||
ctx, localpart, roomID, dataType,
|
||||
ctx, localpart, serverName, roomID, dataType,
|
||||
)
|
||||
}
|
||||
|
||||
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
|
||||
func (d *Database) GetNewNumericLocalpart(
|
||||
ctx context.Context,
|
||||
ctx context.Context, serverName gomatrixserverlib.ServerName,
|
||||
) (int64, error) {
|
||||
return d.Accounts.SelectNewNumericLocalpart(ctx, nil)
|
||||
return d.Accounts.SelectNewNumericLocalpart(ctx, nil, serverName)
|
||||
}
|
||||
|
||||
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
|
||||
@ -276,10 +288,12 @@ var Err3PIDInUse = errors.New("this third-party identifier is already in use")
|
||||
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
|
||||
// Returns an error if there was a problem talking to the database.
|
||||
func (d *Database) SaveThreePIDAssociation(
|
||||
ctx context.Context, threepid, localpart, medium string,
|
||||
ctx context.Context, threepid string,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
medium string,
|
||||
) (err error) {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
user, err := d.ThreePIDs.SelectLocalpartForThreePID(
|
||||
user, _, err := d.ThreePIDs.SelectLocalpartForThreePID(
|
||||
ctx, txn, threepid, medium,
|
||||
)
|
||||
if err != nil {
|
||||
@ -290,7 +304,7 @@ func (d *Database) SaveThreePIDAssociation(
|
||||
return Err3PIDInUse
|
||||
}
|
||||
|
||||
return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart)
|
||||
return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart, serverName)
|
||||
})
|
||||
}
|
||||
|
||||
@ -313,7 +327,7 @@ func (d *Database) RemoveThreePIDAssociation(
|
||||
// Returns an error if there was a problem talking to the database.
|
||||
func (d *Database) GetLocalpartForThreePID(
|
||||
ctx context.Context, threepid string, medium string,
|
||||
) (localpart string, err error) {
|
||||
) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
|
||||
return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium)
|
||||
}
|
||||
|
||||
@ -322,16 +336,17 @@ func (d *Database) GetLocalpartForThreePID(
|
||||
// If no association is known for this user, returns an empty slice.
|
||||
// Returns an error if there was an issue talking to the database.
|
||||
func (d *Database) GetThreePIDsForLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (threepids []authtypes.ThreePID, err error) {
|
||||
return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart)
|
||||
return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart, serverName)
|
||||
}
|
||||
|
||||
// CheckAccountAvailability checks if the username/localpart is already present
|
||||
// in the database.
|
||||
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
|
||||
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
|
||||
_, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart)
|
||||
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error) {
|
||||
_, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName)
|
||||
if err == sql.ErrNoRows {
|
||||
return true, nil
|
||||
}
|
||||
@ -341,12 +356,12 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin
|
||||
// GetAccountByLocalpart returns the account associated with the given localpart.
|
||||
// This function assumes the request is authenticated or the account data is used only internally.
|
||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
|
||||
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (*api.Account, error) {
|
||||
// try to get the account with lowercase localpart (majority)
|
||||
acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart))
|
||||
acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart), serverName)
|
||||
if err == sql.ErrNoRows {
|
||||
acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart) // try with localpart as passed by the request
|
||||
acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName) // try with localpart as passed by the request
|
||||
}
|
||||
return acc, err
|
||||
}
|
||||
@ -359,20 +374,24 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi
|
||||
}
|
||||
|
||||
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
|
||||
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
|
||||
func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) {
|
||||
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
|
||||
return d.Accounts.DeactivateAccount(ctx, localpart)
|
||||
return d.Accounts.DeactivateAccount(ctx, localpart, serverName)
|
||||
})
|
||||
}
|
||||
|
||||
// CreateOpenIDToken persists a new token that was issued for OpenID Connect
|
||||
func (d *Database) CreateOpenIDToken(
|
||||
ctx context.Context,
|
||||
token, localpart string,
|
||||
token, userID string,
|
||||
) (int64, error) {
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
return 0, nil
|
||||
}
|
||||
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS
|
||||
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, expiresAtMS)
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, domain, expiresAtMS)
|
||||
})
|
||||
return expiresAtMS, err
|
||||
}
|
||||
@ -539,16 +558,19 @@ func (d *Database) GetDeviceByAccessToken(
|
||||
// GetDeviceByID returns the device matching the given ID.
|
||||
// Returns sql.ErrNoRows if no matching device was found.
|
||||
func (d *Database) GetDeviceByID(
|
||||
ctx context.Context, localpart, deviceID string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
deviceID string,
|
||||
) (*api.Device, error) {
|
||||
return d.Devices.SelectDeviceByID(ctx, localpart, deviceID)
|
||||
return d.Devices.SelectDeviceByID(ctx, localpart, serverName, deviceID)
|
||||
}
|
||||
|
||||
// GetDevicesByLocalpart returns the devices matching the given localpart.
|
||||
func (d *Database) GetDevicesByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) ([]api.Device, error) {
|
||||
return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, "")
|
||||
return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, serverName, "")
|
||||
}
|
||||
|
||||
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||
@ -562,18 +584,18 @@ func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]ap
|
||||
// If no device ID is given one is generated.
|
||||
// Returns the device on success.
|
||||
func (d *Database) CreateDevice(
|
||||
ctx context.Context, localpart string, deviceID *string, accessToken string,
|
||||
displayName *string, ipAddr, userAgent string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string,
|
||||
) (dev *api.Device, returnErr error) {
|
||||
if deviceID != nil {
|
||||
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
var err error
|
||||
// Revoke existing tokens for this device
|
||||
if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
||||
if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
|
||||
return err
|
||||
})
|
||||
} else {
|
||||
@ -588,7 +610,7 @@ func (d *Database) CreateDevice(
|
||||
|
||||
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
var err error
|
||||
dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||
dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
|
||||
return err
|
||||
})
|
||||
if returnErr == nil {
|
||||
@ -614,10 +636,12 @@ func generateDeviceID() (string, error) {
|
||||
// UpdateDevice updates the given device with the display name.
|
||||
// Returns SQL error if there are problems and nil on success.
|
||||
func (d *Database) UpdateDevice(
|
||||
ctx context.Context, localpart, deviceID string, displayName *string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
deviceID string, displayName *string,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.Devices.UpdateDeviceName(ctx, txn, localpart, deviceID, displayName)
|
||||
return d.Devices.UpdateDeviceName(ctx, txn, localpart, serverName, deviceID, displayName)
|
||||
})
|
||||
}
|
||||
|
||||
@ -626,10 +650,12 @@ func (d *Database) UpdateDevice(
|
||||
// If the devices don't exist, it will not return an error
|
||||
// If something went wrong during the deletion, it will return the SQL error.
|
||||
func (d *Database) RemoveDevices(
|
||||
ctx context.Context, localpart string, devices []string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
devices []string,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
if err := d.Devices.DeleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
|
||||
if err := d.Devices.DeleteDevices(ctx, txn, localpart, serverName, devices); err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@ -640,14 +666,16 @@ func (d *Database) RemoveDevices(
|
||||
// database matching the given user ID localpart.
|
||||
// If something went wrong during the deletion, it will return the SQL error.
|
||||
func (d *Database) RemoveAllDevices(
|
||||
ctx context.Context, localpart, exceptDeviceID string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
exceptDeviceID string,
|
||||
) (devices []api.Device, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
|
||||
devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
|
||||
if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID); err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@ -656,9 +684,9 @@ func (d *Database) RemoveAllDevices(
|
||||
}
|
||||
|
||||
// UpdateDeviceLastSeen updates a last seen timestamp and the ip address.
|
||||
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error {
|
||||
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr, userAgent)
|
||||
return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, serverName, deviceID, ipAddr, userAgent)
|
||||
})
|
||||
}
|
||||
|
||||
@ -706,38 +734,38 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (
|
||||
return d.LoginTokens.SelectLoginToken(ctx, token)
|
||||
}
|
||||
|
||||
func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error {
|
||||
func (d *Database) InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
|
||||
return d.Notifications.Insert(ctx, txn, localpart, serverName, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) {
|
||||
func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos)
|
||||
affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, serverName, roomID, pos)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, b bool) (affected bool, err error) {
|
||||
func (d *Database) SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, b bool) (affected bool, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b)
|
||||
affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, serverName, roomID, pos, b)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
||||
return d.Notifications.Select(ctx, nil, localpart, fromID, limit, filter)
|
||||
func (d *Database) GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
||||
return d.Notifications.Select(ctx, nil, localpart, serverName, fromID, limit, filter)
|
||||
}
|
||||
|
||||
func (d *Database) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) {
|
||||
return d.Notifications.SelectCount(ctx, nil, localpart, filter)
|
||||
func (d *Database) GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error) {
|
||||
return d.Notifications.SelectCount(ctx, nil, localpart, serverName, filter)
|
||||
}
|
||||
|
||||
func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) {
|
||||
return d.Notifications.SelectRoomCounts(ctx, nil, localpart, roomID)
|
||||
func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error) {
|
||||
return d.Notifications.SelectRoomCounts(ctx, nil, localpart, serverName, roomID)
|
||||
}
|
||||
|
||||
func (d *Database) DeleteOldNotifications(ctx context.Context) error {
|
||||
@ -747,7 +775,8 @@ func (d *Database) DeleteOldNotifications(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (d *Database) UpsertPusher(
|
||||
ctx context.Context, p api.Pusher, localpart string,
|
||||
ctx context.Context, p api.Pusher,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
data, err := json.Marshal(p.Data)
|
||||
if err != nil {
|
||||
@ -766,25 +795,26 @@ func (d *Database) UpsertPusher(
|
||||
p.ProfileTag,
|
||||
p.Language,
|
||||
string(data),
|
||||
localpart)
|
||||
localpart,
|
||||
serverName)
|
||||
})
|
||||
}
|
||||
|
||||
// GetPushers returns the pushers matching the given localpart.
|
||||
func (d *Database) GetPushers(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) ([]api.Pusher, error) {
|
||||
return d.Pushers.SelectPushers(ctx, nil, localpart)
|
||||
return d.Pushers.SelectPushers(ctx, nil, localpart, serverName)
|
||||
}
|
||||
|
||||
// RemovePusher deletes one pusher
|
||||
// Invoked when `append` is true and `kind` is null in
|
||||
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set
|
||||
func (d *Database) RemovePusher(
|
||||
ctx context.Context, appid, pushkey, localpart string,
|
||||
ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart)
|
||||
err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart, serverName)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil
|
||||
}
|
||||
|
@ -21,6 +21,7 @@ import (
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const accountDataSchema = `
|
||||
@ -28,27 +29,28 @@ const accountDataSchema = `
|
||||
CREATE TABLE IF NOT EXISTS userapi_account_datas (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
-- The room ID for this data (empty string if not specific to a room)
|
||||
room_id TEXT,
|
||||
-- The account data type
|
||||
type TEXT NOT NULL,
|
||||
-- The account data content
|
||||
content TEXT NOT NULL,
|
||||
|
||||
PRIMARY KEY(localpart, room_id, type)
|
||||
content TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_account_datas_idx ON userapi_account_datas(localpart, server_name, room_id, type);
|
||||
`
|
||||
|
||||
const insertAccountDataSQL = `
|
||||
INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
|
||||
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
|
||||
INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (localpart, server_name, room_id, type) DO UPDATE SET content = $5
|
||||
`
|
||||
|
||||
const selectAccountDataSQL = "" +
|
||||
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1"
|
||||
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const selectAccountDataByTypeSQL = "" +
|
||||
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
||||
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND type = $4"
|
||||
|
||||
type accountDataStatements struct {
|
||||
db *sql.DB
|
||||
@ -73,20 +75,23 @@ func NewSQLiteAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) InsertAccountData(
|
||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
roomID, dataType string, content json.RawMessage,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
|
||||
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, serverName, roomID, dataType, content)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) SelectAccountData(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (
|
||||
/* global */ map[string]json.RawMessage,
|
||||
/* rooms */ map[string]map[string]json.RawMessage,
|
||||
error,
|
||||
) {
|
||||
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
|
||||
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@ -117,11 +122,13 @@ func (s *accountDataStatements) SelectAccountData(
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) SelectAccountDataByType(
|
||||
ctx context.Context, localpart, roomID, dataType string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
roomID, dataType string,
|
||||
) (data json.RawMessage, err error) {
|
||||
var bytes []byte
|
||||
stmt := s.selectAccountDataByTypeStmt
|
||||
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
|
||||
if err = stmt.QueryRowContext(ctx, localpart, serverName, roomID, dataType).Scan(&bytes); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -34,7 +34,8 @@ const accountsSchema = `
|
||||
-- Stores data about accounts.
|
||||
CREATE TABLE IF NOT EXISTS userapi_accounts (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
-- When this account was first created, as a unix timestamp (ms resolution).
|
||||
created_ts BIGINT NOT NULL,
|
||||
-- The password hash for this account. Can be NULL if this is a passwordless account.
|
||||
@ -48,25 +49,27 @@ CREATE TABLE IF NOT EXISTS userapi_accounts (
|
||||
-- TODO:
|
||||
-- upgraded_ts, devices, any email reset stuff?
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_accounts_idx ON userapi_accounts(localpart, server_name);
|
||||
`
|
||||
|
||||
const insertAccountSQL = "" +
|
||||
"INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
|
||||
"INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)"
|
||||
|
||||
const updatePasswordSQL = "" +
|
||||
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3"
|
||||
|
||||
const deactivateAccountSQL = "" +
|
||||
"UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1"
|
||||
"UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const selectAccountByLocalpartSQL = "" +
|
||||
"SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
|
||||
"SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const selectPasswordHashSQL = "" +
|
||||
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
||||
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = 0"
|
||||
|
||||
const selectNewNumericLocalpartSQL = "" +
|
||||
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0"
|
||||
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0 AND server_name = $1"
|
||||
|
||||
type accountsStatements struct {
|
||||
db *sql.DB
|
||||
@ -119,16 +122,17 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
|
||||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||
// on success.
|
||||
func (s *accountsStatements) InsertAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
|
||||
ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
hash, appserviceID string, accountType api.AccountType,
|
||||
) (*api.Account, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
stmt := s.insertAccountStmt
|
||||
|
||||
var err error
|
||||
if accountType != api.AccountTypeAppService {
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType)
|
||||
} else {
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -136,42 +140,43 @@ func (s *accountsStatements) InsertAccount(
|
||||
|
||||
return &api.Account{
|
||||
Localpart: localpart,
|
||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||
ServerName: s.serverName,
|
||||
UserID: userutil.MakeUserID(localpart, serverName),
|
||||
ServerName: serverName,
|
||||
AppServiceID: appserviceID,
|
||||
AccountType: accountType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *accountsStatements) UpdatePassword(
|
||||
ctx context.Context, localpart, passwordHash string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
passwordHash string,
|
||||
) (err error) {
|
||||
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
|
||||
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) DeactivateAccount(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (err error) {
|
||||
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
|
||||
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) SelectPasswordHash(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (hash string, err error) {
|
||||
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
|
||||
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) SelectAccountByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (*api.Account, error) {
|
||||
var appserviceIDPtr sql.NullString
|
||||
var acc api.Account
|
||||
|
||||
stmt := s.selectAccountByLocalpartStmt
|
||||
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType)
|
||||
err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
log.WithError(err).Error("Unable to retrieve user from the db")
|
||||
@ -182,20 +187,18 @@ func (s *accountsStatements) SelectAccountByLocalpart(
|
||||
acc.AppServiceID = appserviceIDPtr.String
|
||||
}
|
||||
|
||||
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
acc.ServerName = s.serverName
|
||||
|
||||
acc.UserID = userutil.MakeUserID(acc.Localpart, acc.ServerName)
|
||||
return &acc, nil
|
||||
}
|
||||
|
||||
func (s *accountsStatements) SelectNewNumericLocalpart(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) (id int64, err error) {
|
||||
stmt := s.selectNewNumericLocalpartStmt
|
||||
if txn != nil {
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
}
|
||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
||||
err = stmt.QueryRowContext(ctx, serverName).Scan(&id)
|
||||
if err == sql.ErrNoRows {
|
||||
return 1, nil
|
||||
}
|
||||
|
@ -11,6 +11,7 @@ func UpIsActive(ctx context.Context, tx *sql.Tx) error {
|
||||
ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
|
||||
CREATE TABLE userapi_accounts (
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
server_name TEXT NOT NULL,
|
||||
created_ts BIGINT NOT NULL,
|
||||
password_hash TEXT,
|
||||
appservice_id TEXT,
|
||||
|
@ -14,6 +14,7 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
|
||||
session_id INTEGER,
|
||||
device_id TEXT ,
|
||||
localpart TEXT ,
|
||||
server_name TEXT NOT NULL,
|
||||
created_ts BIGINT,
|
||||
display_name TEXT,
|
||||
last_seen_ts BIGINT,
|
||||
|
@ -12,6 +12,7 @@ func UpAddAccountType(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
|
||||
CREATE TABLE userapi_accounts (
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
server_name TEXT NOT NULL,
|
||||
created_ts BIGINT NOT NULL,
|
||||
password_hash TEXT,
|
||||
appservice_id TEXT,
|
||||
|
108
userapi/storage/sqlite3/deltas/2022110411000000_server_names.go
Normal file
108
userapi/storage/sqlite3/deltas/2022110411000000_server_names.go
Normal file
@ -0,0 +1,108 @@
|
||||
package deltas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var serverNamesTables = []string{
|
||||
"userapi_accounts",
|
||||
"userapi_account_datas",
|
||||
"userapi_devices",
|
||||
"userapi_notifications",
|
||||
"userapi_openid_tokens",
|
||||
"userapi_profiles",
|
||||
"userapi_pushers",
|
||||
"userapi_threepids",
|
||||
}
|
||||
|
||||
// These tables have a PRIMARY KEY constraint which we need to drop so
|
||||
// that we can recreate a new unique index that contains the server name.
|
||||
var serverNamesDropPK = []string{
|
||||
"userapi_accounts",
|
||||
"userapi_account_datas",
|
||||
"userapi_profiles",
|
||||
}
|
||||
|
||||
// These indices are out of date so let's drop them. They will get recreated
|
||||
// automatically.
|
||||
var serverNamesDropIndex = []string{
|
||||
"userapi_pusher_localpart_idx",
|
||||
"userapi_pusher_app_id_pushkey_localpart_idx",
|
||||
}
|
||||
|
||||
// I know what you're thinking: you're wondering "why doesn't this use $1
|
||||
// and pass variadic parameters to ExecContext?" — the answer is because
|
||||
// PostgreSQL doesn't expect the table name to be specified as a substituted
|
||||
// argument in that way so it results in a syntax error in the query.
|
||||
|
||||
func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
|
||||
for _, table := range serverNamesTables {
|
||||
q := fmt.Sprintf(
|
||||
"SELECT COUNT(name) FROM sqlite_schema WHERE type='table' AND name=%s;",
|
||||
pq.QuoteIdentifier(table),
|
||||
)
|
||||
var c int
|
||||
if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 0 {
|
||||
continue
|
||||
}
|
||||
q = fmt.Sprintf(
|
||||
"SELECT COUNT(*) FROM pragma_table_info(%s) WHERE name='server_name'",
|
||||
pq.QuoteIdentifier(table),
|
||||
)
|
||||
if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 1 {
|
||||
logrus.Infof("Table %s already has column, skipping", table)
|
||||
continue
|
||||
}
|
||||
if c == 0 {
|
||||
q = fmt.Sprintf(
|
||||
"ALTER TABLE %s ADD COLUMN server_name TEXT NOT NULL DEFAULT '';",
|
||||
pq.QuoteIdentifier(table),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("add server name to %q error: %w", table, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, table := range serverNamesDropPK {
|
||||
q := fmt.Sprintf(
|
||||
"SELECT COUNT(name), sql FROM sqlite_schema WHERE type='table' AND name=%s;",
|
||||
pq.QuoteIdentifier(table),
|
||||
)
|
||||
var c int
|
||||
var sql string
|
||||
if err := tx.QueryRowContext(ctx, q).Scan(&c, &sql); err != nil || c == 0 {
|
||||
continue
|
||||
}
|
||||
q = fmt.Sprintf(`
|
||||
%s; -- create temporary table
|
||||
INSERT INTO %s SELECT * FROM %s; -- copy data
|
||||
DROP TABLE %s; -- drop original table
|
||||
ALTER TABLE %s RENAME TO %s; -- rename new table
|
||||
`,
|
||||
strings.Replace(sql, table, table+"_tmp", 1), // create temporary table
|
||||
table+"_tmp", table, // copy data
|
||||
table, // drop original table
|
||||
table+"_tmp", table, // rename new table
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("drop PK from %q error: %w", table, err)
|
||||
}
|
||||
}
|
||||
for _, index := range serverNamesDropIndex {
|
||||
q := fmt.Sprintf(
|
||||
"DROP INDEX IF EXISTS %s;",
|
||||
pq.QuoteIdentifier(index),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("drop index %q error: %w", index, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
@ -0,0 +1,28 @@
|
||||
package deltas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// I know what you're thinking: you're wondering "why doesn't this use $1
|
||||
// and pass variadic parameters to ExecContext?" — the answer is because
|
||||
// PostgreSQL doesn't expect the table name to be specified as a substituted
|
||||
// argument in that way so it results in a syntax error in the query.
|
||||
|
||||
func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
|
||||
for _, table := range serverNamesTables {
|
||||
q := fmt.Sprintf(
|
||||
"UPDATE %s SET server_name = %s WHERE server_name = '';",
|
||||
pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("write server names to %q error: %w", table, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
@ -40,49 +40,50 @@ CREATE TABLE IF NOT EXISTS userapi_devices (
|
||||
session_id INTEGER,
|
||||
device_id TEXT ,
|
||||
localpart TEXT ,
|
||||
server_name TEXT NOT NULL,
|
||||
created_ts BIGINT,
|
||||
display_name TEXT,
|
||||
last_seen_ts BIGINT,
|
||||
ip TEXT,
|
||||
user_agent TEXT,
|
||||
|
||||
UNIQUE (localpart, device_id)
|
||||
UNIQUE (localpart, server_name, device_id)
|
||||
);
|
||||
`
|
||||
|
||||
const insertDeviceSQL = "" +
|
||||
"INSERT INTO userapi_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
|
||||
"INSERT INTO userapi_devices (device_id, localpart, server_name, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"
|
||||
|
||||
const selectDevicesCountSQL = "" +
|
||||
"SELECT COUNT(access_token) FROM userapi_devices"
|
||||
|
||||
const selectDeviceByTokenSQL = "" +
|
||||
"SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1"
|
||||
"SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1"
|
||||
|
||||
const selectDeviceByIDSQL = "" +
|
||||
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2"
|
||||
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3"
|
||||
|
||||
const selectDevicesByLocalpartSQL = "" +
|
||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
|
||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC"
|
||||
|
||||
const updateDeviceNameSQL = "" +
|
||||
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4"
|
||||
|
||||
const deleteDeviceSQL = "" +
|
||||
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2"
|
||||
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3"
|
||||
|
||||
const deleteDevicesByLocalpartSQL = "" +
|
||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
|
||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3"
|
||||
|
||||
const deleteDevicesSQL = "" +
|
||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id IN ($2)"
|
||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id IN ($3)"
|
||||
|
||||
const selectDevicesByIDSQL = "" +
|
||||
"SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
|
||||
"SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
|
||||
|
||||
const updateDeviceLastSeen = "" +
|
||||
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
|
||||
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6"
|
||||
|
||||
type devicesStatements struct {
|
||||
db *sql.DB
|
||||
@ -135,8 +136,9 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
|
||||
// Returns an error if the user already has a device with the given device ID.
|
||||
// Returns the device on success.
|
||||
func (s *devicesStatements) InsertDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
||||
displayName *string, ipAddr, userAgent string,
|
||||
ctx context.Context, txn *sql.Tx, id string,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
accessToken string, displayName *string, ipAddr, userAgent string,
|
||||
) (*api.Device, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
var sessionID int64
|
||||
@ -146,12 +148,12 @@ func (s *devicesStatements) InsertDevice(
|
||||
return nil, err
|
||||
}
|
||||
sessionID++
|
||||
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
|
||||
if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &api.Device{
|
||||
ID: id,
|
||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||
UserID: userutil.MakeUserID(localpart, serverName),
|
||||
AccessToken: accessToken,
|
||||
SessionID: sessionID,
|
||||
LastSeenTS: createdTimeMS,
|
||||
@ -161,44 +163,52 @@ func (s *devicesStatements) InsertDevice(
|
||||
}
|
||||
|
||||
func (s *devicesStatements) DeleteDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart string,
|
||||
ctx context.Context, txn *sql.Tx, id string,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
||||
_, err := stmt.ExecContext(ctx, id, localpart)
|
||||
_, err := stmt.ExecContext(ctx, id, localpart, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) DeleteDevices(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
devices []string,
|
||||
) error {
|
||||
orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1)
|
||||
orig := strings.Replace(deleteDevicesSQL, "($3)", sqlutil.QueryVariadicOffset(len(devices), 2), 1)
|
||||
prep, err := s.db.Prepare(orig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, prep)
|
||||
params := make([]interface{}, len(devices)+1)
|
||||
params := make([]interface{}, len(devices)+2)
|
||||
params[0] = localpart
|
||||
params[1] = serverName
|
||||
for i, v := range devices {
|
||||
params[i+1] = v
|
||||
params[i+2] = v
|
||||
}
|
||||
_, err = stmt.ExecContext(ctx, params...)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) DeleteDevicesByLocalpart(
|
||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
exceptDeviceID string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
||||
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
|
||||
_, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) UpdateDeviceName(
|
||||
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
deviceID string, displayName *string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
||||
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
|
||||
_, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID)
|
||||
return err
|
||||
}
|
||||
|
||||
@ -207,10 +217,11 @@ func (s *devicesStatements) SelectDeviceByToken(
|
||||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
var localpart string
|
||||
var serverName gomatrixserverlib.ServerName
|
||||
stmt := s.selectDeviceByTokenStmt
|
||||
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
|
||||
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName)
|
||||
if err == nil {
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||
dev.AccessToken = accessToken
|
||||
}
|
||||
return &dev, err
|
||||
@ -219,16 +230,18 @@ func (s *devicesStatements) SelectDeviceByToken(
|
||||
// selectDeviceByID retrieves a device from the database with the given user
|
||||
// localpart and deviceID
|
||||
func (s *devicesStatements) SelectDeviceByID(
|
||||
ctx context.Context, localpart, deviceID string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
deviceID string,
|
||||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
var displayName, ip sql.NullString
|
||||
stmt := s.selectDeviceByIDStmt
|
||||
var lastseenTS sql.NullInt64
|
||||
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip)
|
||||
err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip)
|
||||
if err == nil {
|
||||
dev.ID = deviceID
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||
if displayName.Valid {
|
||||
dev.DisplayName = displayName.String
|
||||
}
|
||||
@ -243,10 +256,12 @@ func (s *devicesStatements) SelectDeviceByID(
|
||||
}
|
||||
|
||||
func (s *devicesStatements) SelectDevicesByLocalpart(
|
||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
exceptDeviceID string,
|
||||
) ([]api.Device, error) {
|
||||
devices := []api.Device{}
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID)
|
||||
|
||||
if err != nil {
|
||||
return devices, err
|
||||
@ -276,7 +291,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
|
||||
dev.UserAgent = useragent.String
|
||||
}
|
||||
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||
devices = append(devices, dev)
|
||||
}
|
||||
|
||||
@ -298,10 +313,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
||||
var devices []api.Device
|
||||
var dev api.Device
|
||||
var localpart string
|
||||
var serverName gomatrixserverlib.ServerName
|
||||
var displayName sql.NullString
|
||||
var lastseents sql.NullInt64
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
|
||||
if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if displayName.Valid {
|
||||
@ -310,15 +326,15 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
||||
if lastseents.Valid {
|
||||
dev.LastSeenTS = lastseents.Int64
|
||||
}
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||
devices = append(devices, dev)
|
||||
}
|
||||
return devices, rows.Err()
|
||||
}
|
||||
|
||||
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error {
|
||||
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
|
||||
lastSeenTs := time.Now().UnixNano() / 1000000
|
||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
|
||||
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID)
|
||||
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID)
|
||||
return err
|
||||
}
|
||||
|
@ -43,6 +43,7 @@ const notificationSchema = `
|
||||
CREATE TABLE IF NOT EXISTS userapi_notifications (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
event_id TEXT NOT NULL,
|
||||
stream_pos BIGINT NOT NULL,
|
||||
@ -52,33 +53,33 @@ CREATE TABLE IF NOT EXISTS userapi_notifications (
|
||||
read BOOLEAN NOT NULL DEFAULT FALSE
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, server_name, room_id, event_id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, server_name, room_id, id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, server_name, id);
|
||||
`
|
||||
|
||||
const insertNotificationSQL = "" +
|
||||
"INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
|
||||
"INSERT INTO userapi_notifications (localpart, server_name, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
|
||||
|
||||
const deleteNotificationsUpToSQL = "" +
|
||||
"DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
|
||||
"DELETE FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND stream_pos <= $4"
|
||||
|
||||
const updateNotificationReadSQL = "" +
|
||||
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
|
||||
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND server_name = $3 AND room_id = $4 AND stream_pos <= $5 AND read <> $1"
|
||||
|
||||
const selectNotificationSQL = "" +
|
||||
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
|
||||
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
|
||||
") AND NOT read ORDER BY localpart, id LIMIT $4"
|
||||
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND id > $3 AND (" +
|
||||
"(($4 & 1) <> 0 AND highlight) OR (($4 & 2) <> 0 AND NOT highlight)" +
|
||||
") AND NOT read ORDER BY localpart, id LIMIT $5"
|
||||
|
||||
const selectNotificationCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
|
||||
"(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
|
||||
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND (" +
|
||||
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
|
||||
") AND NOT read"
|
||||
|
||||
const selectRoomNotificationCountsSQL = "" +
|
||||
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
|
||||
"WHERE localpart = $1 AND room_id = $2 AND NOT read"
|
||||
"WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND NOT read"
|
||||
|
||||
const cleanNotificationsSQL = "" +
|
||||
"DELETE FROM userapi_notifications WHERE" +
|
||||
@ -111,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
|
||||
}
|
||||
|
||||
// Insert inserts a notification into the database.
|
||||
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error {
|
||||
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error {
|
||||
roomID, tsMS := n.RoomID, n.TS
|
||||
nn := *n
|
||||
// Clears out fields that have their own columns to (1) shrink the
|
||||
@ -122,13 +123,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
|
||||
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, serverName, roomID, eventID, pos, tsMS, highlight, string(bs))
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteUpTo deletes all previous notifications, up to and including the event.
|
||||
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) {
|
||||
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
|
||||
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) {
|
||||
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -141,8 +142,8 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
|
||||
}
|
||||
|
||||
// UpdateRead updates the "read" value for an event.
|
||||
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) {
|
||||
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
|
||||
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) {
|
||||
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -154,8 +155,8 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l
|
||||
return nrows > 0, nil
|
||||
}
|
||||
|
||||
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
|
||||
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit)
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@ -197,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local
|
||||
return notifs, maxID, rows.Err()
|
||||
}
|
||||
|
||||
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) {
|
||||
err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count)
|
||||
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) {
|
||||
err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) {
|
||||
err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight)
|
||||
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) {
|
||||
err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight)
|
||||
return
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package sqlite3
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
@ -18,16 +19,17 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
|
||||
token TEXT NOT NULL PRIMARY KEY,
|
||||
-- The Matrix user ID for this account
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
-- When the token expires, as a unix timestamp (ms resolution).
|
||||
token_expires_at_ms BIGINT NOT NULL
|
||||
);
|
||||
`
|
||||
|
||||
const insertOpenIDTokenSQL = "" +
|
||||
"INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
const selectOpenIDTokenSQL = "" +
|
||||
"SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
|
||||
"SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
|
||||
|
||||
type openIDTokenStatements struct {
|
||||
db *sql.DB
|
||||
@ -56,11 +58,11 @@ func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (
|
||||
func (s *openIDTokenStatements) InsertOpenIDToken(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
token, localpart string,
|
||||
token, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
expiresAtMS int64,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
|
||||
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
|
||||
_, err = stmt.ExecContext(ctx, token, localpart, serverName, expiresAtMS)
|
||||
return
|
||||
}
|
||||
|
||||
@ -71,10 +73,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
|
||||
token string,
|
||||
) (*api.OpenIDTokenAttributes, error) {
|
||||
var openIDTokenAttrs api.OpenIDTokenAttributes
|
||||
var localpart string
|
||||
var serverName gomatrixserverlib.ServerName
|
||||
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
|
||||
&openIDTokenAttrs.UserID,
|
||||
&localpart, &serverName,
|
||||
&openIDTokenAttrs.ExpiresAtMS,
|
||||
)
|
||||
openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
log.WithError(err).Error("Unable to retrieve token from the db")
|
||||
|
@ -23,36 +23,40 @@ import (
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const profilesSchema = `
|
||||
-- Stores data about accounts profiles.
|
||||
CREATE TABLE IF NOT EXISTS userapi_profiles (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
-- The display name for this account
|
||||
display_name TEXT,
|
||||
-- The URL of the avatar for this account
|
||||
avatar_url TEXT
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_profiles_idx ON userapi_profiles(localpart, server_name);
|
||||
`
|
||||
|
||||
const insertProfileSQL = "" +
|
||||
"INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO userapi_profiles(localpart, server_name, display_name, avatar_url) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
const selectProfileByLocalpartSQL = "" +
|
||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
|
||||
"SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const setAvatarURLSQL = "" +
|
||||
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" +
|
||||
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2 AND server_name = $3" +
|
||||
" RETURNING display_name"
|
||||
|
||||
const setDisplayNameSQL = "" +
|
||||
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" +
|
||||
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2 AND server_name = $3" +
|
||||
" RETURNING avatar_url"
|
||||
|
||||
const selectProfilesBySearchSQL = "" +
|
||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||
"SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||
|
||||
type profilesStatements struct {
|
||||
db *sql.DB
|
||||
@ -83,18 +87,20 @@ func NewSQLiteProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables.P
|
||||
}
|
||||
|
||||
func (s *profilesStatements) InsertProfile(
|
||||
ctx context.Context, txn *sql.Tx, localpart string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
||||
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "")
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *profilesStatements) SelectProfileByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (*authtypes.Profile, error) {
|
||||
var profile authtypes.Profile
|
||||
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
|
||||
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
|
||||
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan(
|
||||
&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -103,13 +109,16 @@ func (s *profilesStatements) SelectProfileByLocalpart(
|
||||
}
|
||||
|
||||
func (s *profilesStatements) SetAvatarURL(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
avatarURL string,
|
||||
) (*authtypes.Profile, bool, error) {
|
||||
profile := &authtypes.Profile{
|
||||
Localpart: localpart,
|
||||
AvatarURL: avatarURL,
|
||||
Localpart: localpart,
|
||||
ServerName: string(serverName),
|
||||
AvatarURL: avatarURL,
|
||||
}
|
||||
old, err := s.SelectProfileByLocalpart(ctx, localpart)
|
||||
old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return old, false, err
|
||||
}
|
||||
@ -117,18 +126,21 @@ func (s *profilesStatements) SetAvatarURL(
|
||||
return old, false, nil
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
|
||||
err = stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName)
|
||||
err = stmt.QueryRowContext(ctx, avatarURL, localpart, serverName).Scan(&profile.DisplayName)
|
||||
return profile, true, err
|
||||
}
|
||||
|
||||
func (s *profilesStatements) SetDisplayName(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
displayName string,
|
||||
) (*authtypes.Profile, bool, error) {
|
||||
profile := &authtypes.Profile{
|
||||
Localpart: localpart,
|
||||
ServerName: string(serverName),
|
||||
DisplayName: displayName,
|
||||
}
|
||||
old, err := s.SelectProfileByLocalpart(ctx, localpart)
|
||||
old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return old, false, err
|
||||
}
|
||||
@ -136,7 +148,7 @@ func (s *profilesStatements) SetDisplayName(
|
||||
return old, false, nil
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
|
||||
err = stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL)
|
||||
err = stmt.QueryRowContext(ctx, displayName, localpart, serverName).Scan(&profile.AvatarURL)
|
||||
return profile, true, err
|
||||
}
|
||||
|
||||
@ -154,7 +166,7 @@ func (s *profilesStatements) SelectProfilesBySearch(
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var profile authtypes.Profile
|
||||
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
|
||||
if err := rows.Scan(&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if profile.Localpart != s.serverNoticesLocalpart {
|
||||
|
@ -25,6 +25,7 @@ import (
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
|
||||
@ -33,6 +34,7 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
-- The Matrix user ID localpart for this pusher
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
session_id BIGINT DEFAULT NULL,
|
||||
profile_tag TEXT,
|
||||
kind TEXT NOT NULL,
|
||||
@ -49,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
|
||||
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
|
||||
|
||||
-- For faster retrieving by localpart.
|
||||
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
|
||||
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name);
|
||||
|
||||
-- Pushkey must be unique for a given user and app.
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name);
|
||||
`
|
||||
|
||||
const insertPusherSQL = "" +
|
||||
"INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
|
||||
"ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
|
||||
"INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" +
|
||||
"ON CONFLICT (app_id, pushkey, localpart, server_name) DO UPDATE SET session_id = $3, pushkey_ts_ms = $5, kind = $6, app_display_name = $8, device_display_name = $9, profile_tag = $10, lang = $11, data = $12"
|
||||
|
||||
const selectPushersSQL = "" +
|
||||
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
|
||||
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const deletePusherSQL = "" +
|
||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
|
||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4"
|
||||
|
||||
const deletePushersByAppIdAndPushKeySQL = "" +
|
||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
|
||||
@ -95,18 +97,19 @@ type pushersStatements struct {
|
||||
// Returns nil error success.
|
||||
func (s *pushersStatements) InsertPusher(
|
||||
ctx context.Context, txn *sql.Tx, session_id int64,
|
||||
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
|
||||
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
|
||||
logrus.Debugf("Created pusher %d", session_id)
|
||||
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *pushersStatements) SelectPushers(
|
||||
ctx context.Context, txn *sql.Tx, localpart string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) ([]api.Pusher, error) {
|
||||
pushers := []api.Pusher{}
|
||||
rows, err := s.selectPushersStmt.QueryContext(ctx, localpart)
|
||||
rows, err := s.selectPushersStmt.QueryContext(ctx, localpart, serverName)
|
||||
|
||||
if err != nil {
|
||||
return pushers, err
|
||||
@ -143,9 +146,10 @@ func (s *pushersStatements) SelectPushers(
|
||||
|
||||
// deletePusher removes a single pusher by pushkey and user localpart.
|
||||
func (s *pushersStatements) DeletePusher(
|
||||
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
|
||||
ctx context.Context, txn *sql.Tx, appid, pushkey,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart)
|
||||
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -15,6 +15,8 @@
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@ -41,18 +43,24 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
||||
Up: deltas.UpRenameTables,
|
||||
Down: deltas.DownRenameTables,
|
||||
})
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "userapi: server names",
|
||||
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||
return deltas.UpServerNames(ctx, txn, serverName)
|
||||
},
|
||||
})
|
||||
if err = m.Up(base.Context()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accountDataTable, err := NewSQLiteAccountDataTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)
|
||||
}
|
||||
accountsTable, err := NewSQLiteAccountsTable(db, serverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err)
|
||||
}
|
||||
accountDataTable, err := NewSQLiteAccountDataTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)
|
||||
}
|
||||
devicesTable, err := NewSQLiteDevicesTable(db, serverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err)
|
||||
@ -93,6 +101,18 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err)
|
||||
}
|
||||
|
||||
m = sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "userapi: server names populate",
|
||||
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||
return deltas.UpServerNamesPopulate(ctx, txn, serverName)
|
||||
},
|
||||
})
|
||||
if err = m.Up(base.Context()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &shared.Database{
|
||||
AccountDatas: accountDataTable,
|
||||
Accounts: accountsTable,
|
||||
|
@ -21,6 +21,7 @@ import (
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
)
|
||||
@ -34,21 +35,22 @@ CREATE TABLE IF NOT EXISTS userapi_threepids (
|
||||
medium TEXT NOT NULL DEFAULT 'email',
|
||||
-- The localpart of the Matrix user ID associated to this 3PID
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
|
||||
PRIMARY KEY(threepid, medium)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart);
|
||||
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart, server_name);
|
||||
`
|
||||
|
||||
const selectLocalpartForThreePIDSQL = "" +
|
||||
"SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
||||
"SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
||||
|
||||
const selectThreePIDsForLocalpartSQL = "" +
|
||||
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1"
|
||||
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const insertThreePIDSQL = "" +
|
||||
"INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO userapi_threepids (threepid, medium, localpart, server_name) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
const deleteThreePIDSQL = "" +
|
||||
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
||||
@ -79,19 +81,20 @@ func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
|
||||
|
||||
func (s *threepidStatements) SelectLocalpartForThreePID(
|
||||
ctx context.Context, txn *sql.Tx, threepid string, medium string,
|
||||
) (localpart string, err error) {
|
||||
) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
|
||||
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
|
||||
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", nil
|
||||
return "", "", nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (threepids []authtypes.ThreePID, err error) {
|
||||
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
|
||||
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -113,10 +116,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
||||
}
|
||||
|
||||
func (s *threepidStatements) InsertThreePID(
|
||||
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
|
||||
ctx context.Context, txn *sql.Tx, threepid, medium,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
|
||||
_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
|
||||
_, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -50,25 +50,25 @@ func Test_AccountData(t *testing.T) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
alice := test.NewUser(t)
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
room := test.NewRoom(t, alice)
|
||||
events := room.Events()
|
||||
|
||||
contentRoom := json.RawMessage(fmt.Sprintf(`{"event_id":"%s"}`, events[len(events)-1].EventID()))
|
||||
err = db.SaveAccountData(ctx, localpart, room.ID, "m.fully_read", contentRoom)
|
||||
err = db.SaveAccountData(ctx, localpart, domain, room.ID, "m.fully_read", contentRoom)
|
||||
assert.NoError(t, err, "unable to save account data")
|
||||
|
||||
contentGlobal := json.RawMessage(fmt.Sprintf(`{"recent_rooms":["%s"]}`, room.ID))
|
||||
err = db.SaveAccountData(ctx, localpart, "", "im.vector.setting.breadcrumbs", contentGlobal)
|
||||
err = db.SaveAccountData(ctx, localpart, domain, "", "im.vector.setting.breadcrumbs", contentGlobal)
|
||||
assert.NoError(t, err, "unable to save account data")
|
||||
|
||||
accountData, err := db.GetAccountDataByType(ctx, localpart, room.ID, "m.fully_read")
|
||||
accountData, err := db.GetAccountDataByType(ctx, localpart, domain, room.ID, "m.fully_read")
|
||||
assert.NoError(t, err, "unable to get account data by type")
|
||||
assert.Equal(t, contentRoom, accountData)
|
||||
|
||||
globalData, roomData, err := db.GetAccountData(ctx, localpart)
|
||||
globalData, roomData, err := db.GetAccountData(ctx, localpart, domain)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, contentRoom, roomData[room.ID]["m.fully_read"])
|
||||
assert.Equal(t, contentGlobal, globalData["im.vector.setting.breadcrumbs"])
|
||||
@ -81,78 +81,78 @@ func Test_Accounts(t *testing.T) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
alice := test.NewUser(t)
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
|
||||
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin)
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
// verify the newly create account is the same as returned by CreateAccount
|
||||
var accGet *api.Account
|
||||
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "testing")
|
||||
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "testing")
|
||||
assert.NoError(t, err, "failed to get account by password")
|
||||
assert.Equal(t, accAlice, accGet)
|
||||
accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart)
|
||||
accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "failed to get account by localpart")
|
||||
assert.Equal(t, accAlice, accGet)
|
||||
|
||||
// check account availability
|
||||
available, err := db.CheckAccountAvailability(ctx, aliceLocalpart)
|
||||
available, err := db.CheckAccountAvailability(ctx, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "failed to checkout account availability")
|
||||
assert.Equal(t, false, available)
|
||||
|
||||
available, err = db.CheckAccountAvailability(ctx, "unusedname")
|
||||
available, err = db.CheckAccountAvailability(ctx, "unusedname", aliceDomain)
|
||||
assert.NoError(t, err, "failed to checkout account availability")
|
||||
assert.Equal(t, true, available)
|
||||
|
||||
// get guest account numeric aliceLocalpart
|
||||
first, err := db.GetNewNumericLocalpart(ctx)
|
||||
first, err := db.GetNewNumericLocalpart(ctx, aliceDomain)
|
||||
assert.NoError(t, err, "failed to get new numeric localpart")
|
||||
// Create a new account to verify the numeric localpart is updated
|
||||
_, err = db.CreateAccount(ctx, "", "testing", "", api.AccountTypeGuest)
|
||||
_, err = db.CreateAccount(ctx, "", aliceDomain, "testing", "", api.AccountTypeGuest)
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
second, err := db.GetNewNumericLocalpart(ctx)
|
||||
second, err := db.GetNewNumericLocalpart(ctx, aliceDomain)
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, second, first)
|
||||
|
||||
// update password for alice
|
||||
err = db.SetPassword(ctx, aliceLocalpart, "newPassword")
|
||||
err = db.SetPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
|
||||
assert.NoError(t, err, "failed to update password")
|
||||
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
|
||||
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
|
||||
assert.NoError(t, err, "failed to get account by new password")
|
||||
assert.Equal(t, accAlice, accGet)
|
||||
|
||||
// deactivate account
|
||||
err = db.DeactivateAccount(ctx, aliceLocalpart)
|
||||
err = db.DeactivateAccount(ctx, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "failed to deactivate account")
|
||||
// This should fail now, as the account is deactivated
|
||||
_, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
|
||||
_, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
|
||||
assert.Error(t, err, "expected an error, got none")
|
||||
|
||||
_, err = db.GetAccountByLocalpart(ctx, "unusename")
|
||||
_, err = db.GetAccountByLocalpart(ctx, "unusename", aliceDomain)
|
||||
assert.Error(t, err, "expected an error for non existent localpart")
|
||||
|
||||
// create an empty localpart; this should never happen, but is required to test getting a numeric localpart
|
||||
// if there's already a user without a localpart in the database
|
||||
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeUser)
|
||||
_, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeUser)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// test getting a numeric localpart, with an existing user without a localpart
|
||||
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest)
|
||||
_, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeGuest)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create a user with a high numeric localpart, out of range for the Postgres integer (2147483647) type
|
||||
_, err = db.CreateAccount(ctx, "2147483650", "", "", api.AccountTypeUser)
|
||||
_, err = db.CreateAccount(ctx, "2147483650", aliceDomain, "", "", api.AccountTypeUser)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Now try to create a new guest user
|
||||
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest)
|
||||
_, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeGuest)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Devices(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
deviceID := util.RandomString(8)
|
||||
accessToken := util.RandomString(16)
|
||||
@ -161,10 +161,10 @@ func Test_Devices(t *testing.T) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
deviceWithID, err := db.CreateDevice(ctx, localpart, &deviceID, accessToken, nil, "", "")
|
||||
deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "")
|
||||
assert.NoError(t, err, "unable to create deviceWithoutID")
|
||||
|
||||
gotDevice, err := db.GetDeviceByID(ctx, localpart, deviceID)
|
||||
gotDevice, err := db.GetDeviceByID(ctx, localpart, domain, deviceID)
|
||||
assert.NoError(t, err, "unable to get device by id")
|
||||
assert.Equal(t, deviceWithID.ID, gotDevice.ID) // GetDeviceByID doesn't populate all fields
|
||||
|
||||
@ -174,14 +174,14 @@ func Test_Devices(t *testing.T) {
|
||||
|
||||
// create a device without existing device ID
|
||||
accessToken = util.RandomString(16)
|
||||
deviceWithoutID, err := db.CreateDevice(ctx, localpart, nil, accessToken, nil, "", "")
|
||||
deviceWithoutID, err := db.CreateDevice(ctx, localpart, domain, nil, accessToken, nil, "", "")
|
||||
assert.NoError(t, err, "unable to create deviceWithoutID")
|
||||
gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, deviceWithoutID.ID)
|
||||
gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, domain, deviceWithoutID.ID)
|
||||
assert.NoError(t, err, "unable to get device by id")
|
||||
assert.Equal(t, deviceWithoutID.ID, gotDeviceWithoutID.ID) // GetDeviceByID doesn't populate all fields
|
||||
|
||||
// Get devices
|
||||
devices, err := db.GetDevicesByLocalpart(ctx, localpart)
|
||||
devices, err := db.GetDevicesByLocalpart(ctx, localpart, domain)
|
||||
assert.NoError(t, err, "unable to get devices by localpart")
|
||||
assert.Equal(t, 2, len(devices))
|
||||
deviceIDs := make([]string, 0, len(devices))
|
||||
@ -195,15 +195,15 @@ func Test_Devices(t *testing.T) {
|
||||
|
||||
// Update device
|
||||
newName := "new display name"
|
||||
err = db.UpdateDevice(ctx, localpart, deviceWithID.ID, &newName)
|
||||
err = db.UpdateDevice(ctx, localpart, domain, deviceWithID.ID, &newName)
|
||||
assert.NoError(t, err, "unable to update device displayname")
|
||||
updatedAfterTimestamp := time.Now().Unix()
|
||||
err = db.UpdateDeviceLastSeen(ctx, localpart, deviceWithID.ID, "127.0.0.1", "Element Web")
|
||||
err = db.UpdateDeviceLastSeen(ctx, localpart, domain, deviceWithID.ID, "127.0.0.1", "Element Web")
|
||||
assert.NoError(t, err, "unable to update device last seen")
|
||||
|
||||
deviceWithID.DisplayName = newName
|
||||
deviceWithID.LastSeenIP = "127.0.0.1"
|
||||
gotDevice, err = db.GetDeviceByID(ctx, localpart, deviceWithID.ID)
|
||||
gotDevice, err = db.GetDeviceByID(ctx, localpart, domain, deviceWithID.ID)
|
||||
assert.NoError(t, err, "unable to get device by id")
|
||||
assert.Equal(t, 2, len(devices))
|
||||
assert.Equal(t, deviceWithID.DisplayName, gotDevice.DisplayName)
|
||||
@ -213,20 +213,20 @@ func Test_Devices(t *testing.T) {
|
||||
// create one more device and remove the devices step by step
|
||||
newDeviceID := util.RandomString(16)
|
||||
accessToken = util.RandomString(16)
|
||||
_, err = db.CreateDevice(ctx, localpart, &newDeviceID, accessToken, nil, "", "")
|
||||
_, err = db.CreateDevice(ctx, localpart, domain, &newDeviceID, accessToken, nil, "", "")
|
||||
assert.NoError(t, err, "unable to create new device")
|
||||
|
||||
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
|
||||
devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain)
|
||||
assert.NoError(t, err, "unable to get device by id")
|
||||
assert.Equal(t, 3, len(devices))
|
||||
|
||||
err = db.RemoveDevices(ctx, localpart, deviceIDs)
|
||||
err = db.RemoveDevices(ctx, localpart, domain, deviceIDs)
|
||||
assert.NoError(t, err, "unable to remove devices")
|
||||
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
|
||||
devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain)
|
||||
assert.NoError(t, err, "unable to get device by id")
|
||||
assert.Equal(t, 1, len(devices))
|
||||
|
||||
deleted, err := db.RemoveAllDevices(ctx, localpart, "")
|
||||
deleted, err := db.RemoveAllDevices(ctx, localpart, domain, "")
|
||||
assert.NoError(t, err, "unable to remove all devices")
|
||||
assert.Equal(t, 1, len(deleted))
|
||||
assert.Equal(t, newDeviceID, deleted[0].ID)
|
||||
@ -364,7 +364,7 @@ func Test_OpenID(t *testing.T) {
|
||||
|
||||
func Test_Profile(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
@ -372,30 +372,33 @@ func Test_Profile(t *testing.T) {
|
||||
defer close()
|
||||
|
||||
// create account, which also creates a profile
|
||||
_, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
|
||||
_, err = db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin)
|
||||
assert.NoError(t, err, "failed to create account")
|
||||
|
||||
gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart)
|
||||
gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "unable to get profile by localpart")
|
||||
wantProfile := &authtypes.Profile{Localpart: aliceLocalpart}
|
||||
wantProfile := &authtypes.Profile{
|
||||
Localpart: aliceLocalpart,
|
||||
ServerName: string(aliceDomain),
|
||||
}
|
||||
assert.Equal(t, wantProfile, gotProfile)
|
||||
|
||||
// set avatar & displayname
|
||||
wantProfile.DisplayName = "Alice"
|
||||
gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, "Alice")
|
||||
gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, aliceDomain, "Alice")
|
||||
assert.Equal(t, wantProfile, gotProfile)
|
||||
assert.NoError(t, err, "unable to set displayname")
|
||||
assert.True(t, changed)
|
||||
|
||||
wantProfile.AvatarURL = "mxc://aliceAvatar"
|
||||
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
|
||||
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, aliceDomain, "mxc://aliceAvatar")
|
||||
assert.NoError(t, err, "unable to set avatar url")
|
||||
assert.Equal(t, wantProfile, gotProfile)
|
||||
assert.True(t, changed)
|
||||
|
||||
// Setting the same avatar again doesn't change anything
|
||||
wantProfile.AvatarURL = "mxc://aliceAvatar"
|
||||
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
|
||||
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, aliceDomain, "mxc://aliceAvatar")
|
||||
assert.NoError(t, err, "unable to set avatar url")
|
||||
assert.Equal(t, wantProfile, gotProfile)
|
||||
assert.False(t, changed)
|
||||
@ -410,7 +413,7 @@ func Test_Profile(t *testing.T) {
|
||||
|
||||
func Test_Pusher(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
@ -432,11 +435,11 @@ func Test_Pusher(t *testing.T) {
|
||||
ProfileTag: util.RandomString(8),
|
||||
Language: util.RandomString(2),
|
||||
}
|
||||
err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart)
|
||||
err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "unable to upsert pusher")
|
||||
|
||||
// check it was actually persisted
|
||||
gotPushers, err = db.GetPushers(ctx, aliceLocalpart)
|
||||
gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "unable to get pushers")
|
||||
assert.Equal(t, i+1, len(gotPushers))
|
||||
assert.Equal(t, wantPusher, gotPushers[i])
|
||||
@ -444,16 +447,16 @@ func Test_Pusher(t *testing.T) {
|
||||
}
|
||||
|
||||
// remove single pusher
|
||||
err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart)
|
||||
err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "unable to remove pusher")
|
||||
gotPushers, err := db.GetPushers(ctx, aliceLocalpart)
|
||||
gotPushers, err := db.GetPushers(ctx, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "unable to get pushers")
|
||||
assert.Equal(t, 1, len(gotPushers))
|
||||
|
||||
// remove last pusher
|
||||
err = db.RemovePushers(ctx, appID, pushKeys[1])
|
||||
assert.NoError(t, err, "unable to remove pusher")
|
||||
gotPushers, err = db.GetPushers(ctx, aliceLocalpart)
|
||||
gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "unable to get pushers")
|
||||
assert.Equal(t, 0, len(gotPushers))
|
||||
})
|
||||
@ -461,7 +464,7 @@ func Test_Pusher(t *testing.T) {
|
||||
|
||||
func Test_ThreePID(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
@ -469,15 +472,16 @@ func Test_ThreePID(t *testing.T) {
|
||||
defer close()
|
||||
threePID := util.RandomString(8)
|
||||
medium := util.RandomString(8)
|
||||
err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, medium)
|
||||
err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, aliceDomain, medium)
|
||||
assert.NoError(t, err, "unable to save threepid association")
|
||||
|
||||
// get the stored threepid
|
||||
gotLocalpart, err := db.GetLocalpartForThreePID(ctx, threePID, medium)
|
||||
gotLocalpart, gotDomain, err := db.GetLocalpartForThreePID(ctx, threePID, medium)
|
||||
assert.NoError(t, err, "unable to get localpart for threepid")
|
||||
assert.Equal(t, aliceLocalpart, gotLocalpart)
|
||||
assert.Equal(t, aliceDomain, gotDomain)
|
||||
|
||||
threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart)
|
||||
threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "unable to get threepids for localpart")
|
||||
assert.Equal(t, 1, len(threepids))
|
||||
assert.Equal(t, authtypes.ThreePID{
|
||||
@ -490,7 +494,7 @@ func Test_ThreePID(t *testing.T) {
|
||||
assert.NoError(t, err, "unexpected error")
|
||||
|
||||
// verify it was deleted
|
||||
threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart)
|
||||
threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain)
|
||||
assert.NoError(t, err, "unable to get threepids for localpart")
|
||||
assert.Equal(t, 0, len(threepids))
|
||||
})
|
||||
@ -498,7 +502,7 @@ func Test_ThreePID(t *testing.T) {
|
||||
|
||||
func Test_Notification(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
assert.NoError(t, err)
|
||||
room := test.NewRoom(t, alice)
|
||||
room2 := test.NewRoom(t, alice)
|
||||
@ -526,34 +530,34 @@ func Test_Notification(t *testing.T) {
|
||||
RoomID: roomID,
|
||||
TS: gomatrixserverlib.AsTimestamp(ts),
|
||||
}
|
||||
err = db.InsertNotification(ctx, aliceLocalpart, eventID, uint64(i+1), nil, notification)
|
||||
err = db.InsertNotification(ctx, aliceLocalpart, aliceDomain, eventID, uint64(i+1), nil, notification)
|
||||
assert.NoError(t, err, "unable to insert notification")
|
||||
}
|
||||
|
||||
// get notifications
|
||||
count, err := db.GetNotificationCount(ctx, aliceLocalpart, tables.AllNotifications)
|
||||
count, err := db.GetNotificationCount(ctx, aliceLocalpart, aliceDomain, tables.AllNotifications)
|
||||
assert.NoError(t, err, "unable to get notification count")
|
||||
assert.Equal(t, int64(10), count)
|
||||
notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, 0, 15, tables.AllNotifications)
|
||||
notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, aliceDomain, 0, 15, tables.AllNotifications)
|
||||
assert.NoError(t, err, "unable to get notifications")
|
||||
assert.Equal(t, int64(10), count)
|
||||
assert.Equal(t, 10, len(notifs))
|
||||
// ... for a specific room
|
||||
total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
|
||||
total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID)
|
||||
assert.NoError(t, err, "unable to get notifications for room")
|
||||
assert.Equal(t, int64(4), total)
|
||||
|
||||
// mark notification as read
|
||||
affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, room2.ID, 7, true)
|
||||
affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, aliceDomain, room2.ID, 7, true)
|
||||
assert.NoError(t, err, "unable to set notifications read")
|
||||
assert.True(t, affected)
|
||||
|
||||
// this should delete 2 notifications
|
||||
affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, room2.ID, 8)
|
||||
affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, aliceDomain, room2.ID, 8)
|
||||
assert.NoError(t, err, "unable to set notifications read")
|
||||
assert.True(t, affected)
|
||||
|
||||
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
|
||||
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID)
|
||||
assert.NoError(t, err, "unable to get notifications for room")
|
||||
assert.Equal(t, int64(2), total)
|
||||
|
||||
@ -562,7 +566,7 @@ func Test_Notification(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// this should now return 0 notifications
|
||||
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
|
||||
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID)
|
||||
assert.NoError(t, err, "unable to get notifications for room")
|
||||
assert.Equal(t, int64(0), total)
|
||||
})
|
||||
|
@ -28,31 +28,31 @@ import (
|
||||
)
|
||||
|
||||
type AccountDataTable interface {
|
||||
InsertAccountData(ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage) error
|
||||
SelectAccountData(ctx context.Context, localpart string) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)
|
||||
SelectAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
|
||||
InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error
|
||||
SelectAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)
|
||||
SelectAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error)
|
||||
}
|
||||
|
||||
type AccountsTable interface {
|
||||
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
||||
UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error)
|
||||
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
||||
SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error)
|
||||
SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
||||
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error)
|
||||
InsertAccount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
||||
UpdatePassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, passwordHash string) (err error)
|
||||
DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error)
|
||||
SelectPasswordHash(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (hash string, err error)
|
||||
SelectAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error)
|
||||
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (id int64, err error)
|
||||
}
|
||||
|
||||
type DevicesTable interface {
|
||||
InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
|
||||
DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string) error
|
||||
DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, devices []string) error
|
||||
DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) error
|
||||
UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string) error
|
||||
InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
|
||||
DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||
DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error
|
||||
DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) error
|
||||
UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error
|
||||
SelectDeviceByToken(ctx context.Context, accessToken string) (*api.Device, error)
|
||||
SelectDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
|
||||
SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) ([]api.Device, error)
|
||||
SelectDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error)
|
||||
SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) ([]api.Device, error)
|
||||
SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
|
||||
UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error
|
||||
UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error
|
||||
}
|
||||
|
||||
type KeyBackupTable interface {
|
||||
@ -79,40 +79,40 @@ type LoginTokenTable interface {
|
||||
}
|
||||
|
||||
type OpenIDTable interface {
|
||||
InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, expiresAtMS int64) (err error)
|
||||
InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, serverName gomatrixserverlib.ServerName, expiresAtMS int64) (err error)
|
||||
SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
||||
}
|
||||
|
||||
type ProfileTable interface {
|
||||
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error
|
||||
SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
|
||||
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
|
||||
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, bool, error)
|
||||
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||
SelectProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error)
|
||||
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error)
|
||||
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error)
|
||||
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
||||
}
|
||||
|
||||
type ThreePIDTable interface {
|
||||
SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, err error)
|
||||
SelectThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
|
||||
InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error)
|
||||
SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error)
|
||||
SelectThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error)
|
||||
InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, serverName gomatrixserverlib.ServerName) (err error)
|
||||
DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error)
|
||||
}
|
||||
|
||||
type PusherTable interface {
|
||||
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error
|
||||
SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error)
|
||||
DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error
|
||||
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||
SelectPushers(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error)
|
||||
DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||
DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error
|
||||
}
|
||||
|
||||
type NotificationTable interface {
|
||||
Clean(ctx context.Context, txn *sql.Tx) error
|
||||
Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error
|
||||
DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error)
|
||||
UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error)
|
||||
Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
|
||||
SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error)
|
||||
SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error)
|
||||
Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error
|
||||
DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error)
|
||||
UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error)
|
||||
Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
|
||||
SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter NotificationFilter) (int64, error)
|
||||
SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error)
|
||||
}
|
||||
|
||||
type StatsTable interface {
|
||||
|
@ -79,6 +79,7 @@ func mustMakeAccountAndDevice(
|
||||
accDB tables.AccountsTable,
|
||||
devDB tables.DevicesTable,
|
||||
localpart string,
|
||||
serverName gomatrixserverlib.ServerName, // nolint:unparam
|
||||
accType api.AccountType,
|
||||
userAgent string,
|
||||
) {
|
||||
@ -89,11 +90,11 @@ func mustMakeAccountAndDevice(
|
||||
appServiceID = util.RandomString(16)
|
||||
}
|
||||
|
||||
_, err := accDB.InsertAccount(ctx, nil, localpart, "", appServiceID, accType)
|
||||
_, err := accDB.InsertAccount(ctx, nil, localpart, serverName, "", appServiceID, accType)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create account: %v", err)
|
||||
}
|
||||
_, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, util.RandomString(16), nil, "", userAgent)
|
||||
_, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, serverName, util.RandomString(16), nil, "", userAgent)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create device: %v", err)
|
||||
}
|
||||
@ -150,12 +151,12 @@ func Test_UserStatistics(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Want Users", func(t *testing.T) {
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", api.AccountTypeUser, "Element Android")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", api.AccountTypeUser, "Element iOS")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", api.AccountTypeUser, "Element web")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", api.AccountTypeGuest, "Element Electron")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", api.AccountTypeAdmin, "gecko")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", api.AccountTypeAppService, "gecko")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", "localhost", api.AccountTypeUser, "Element Android")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", "localhost", api.AccountTypeUser, "Element iOS")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", "localhost", api.AccountTypeUser, "Element web")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", "localhost", api.AccountTypeGuest, "Element Electron")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", "localhost", api.AccountTypeAdmin, "gecko")
|
||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", "localhost", api.AccountTypeAppService, "gecko")
|
||||
gotStats, _, err := statsDB.UserStatistics(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
|
@ -80,14 +80,14 @@ func TestQueryProfile(t *testing.T) {
|
||||
// only one DBType, since userapi.AddInternalRoutes complains about multiple prometheus counters added
|
||||
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, test.DBTypeSQLite)
|
||||
defer close()
|
||||
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", api.AccountTypeUser)
|
||||
_, err := accountDB.CreateAccount(context.TODO(), "alice", serverName, "foobar", "", api.AccountTypeUser)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make account: %s", err)
|
||||
}
|
||||
if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil {
|
||||
if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", serverName, aliceAvatarURL); err != nil {
|
||||
t.Fatalf("failed to set avatar url: %s", err)
|
||||
}
|
||||
if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil {
|
||||
if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", serverName, aliceDisplayName); err != nil {
|
||||
t.Fatalf("failed to set display name: %s", err)
|
||||
}
|
||||
|
||||
@ -164,7 +164,7 @@ func TestPasswordlessLoginFails(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
|
||||
defer close()
|
||||
_, err := accountDB.CreateAccount(ctx, "auser", "", "", api.AccountTypeAppService)
|
||||
_, err := accountDB.CreateAccount(ctx, "auser", serverName, "", "", api.AccountTypeAppService)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make account: %s", err)
|
||||
}
|
||||
@ -190,7 +190,7 @@ func TestLoginToken(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
|
||||
defer close()
|
||||
_, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", api.AccountTypeUser)
|
||||
_, err := accountDB.CreateAccount(ctx, "auser", serverName, "apassword", "", api.AccountTypeUser)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to make account: %s", err)
|
||||
}
|
||||
|
@ -2,10 +2,12 @@ package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/pushgateway"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@ -17,10 +19,10 @@ type PusherDevice struct {
|
||||
}
|
||||
|
||||
// GetPushDevices pushes to the configured devices of a local user.
|
||||
func GetPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
|
||||
pushers, err := db.GetPushers(ctx, localpart)
|
||||
func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
|
||||
pushers, err := db.GetPushers(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("db.GetPushers: %w", err)
|
||||
}
|
||||
|
||||
devices := make([]*PusherDevice, 0, len(pushers))
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"github.com/matrix-org/dendrite/internal/pushgateway"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@ -16,8 +17,8 @@ import (
|
||||
// a single goroutine is started when talking to the Push
|
||||
// gateways. There is no way to know when the background goroutine has
|
||||
// finished.
|
||||
func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, db storage.Database) error {
|
||||
pusherDevices, err := GetPushDevices(ctx, localpart, nil, db)
|
||||
func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.Database) error {
|
||||
pusherDevices, err := GetPushDevices(ctx, localpart, serverName, nil, db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -26,7 +27,7 @@ func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, loc
|
||||
return nil
|
||||
}
|
||||
|
||||
userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, tables.AllNotifications)
|
||||
userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, serverName, tables.AllNotifications)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user