Virtual hosting schema and logic changes (#2876)

Note that virtual users cannot federate correctly yet.
This commit is contained in:
Neil Alexander 2022-11-11 16:41:37 +00:00 committed by GitHub
parent e177e0ae73
commit 529df30b56
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
62 changed files with 1250 additions and 732 deletions

View File

@ -32,6 +32,7 @@ import (
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
// AddInternalRoutes registers HTTP handlers for internal API calls
@ -74,7 +75,7 @@ func NewInternalAPI(
// events to be sent out.
for _, appservice := range base.Cfg.Derived.ApplicationServices {
// Create bot account for this AS if it doesn't already exist
if err := generateAppServiceAccount(userAPI, appservice); err != nil {
if err := generateAppServiceAccount(userAPI, appservice, base.Cfg.Global.ServerName); err != nil {
logrus.WithFields(logrus.Fields{
"appservice": appservice.ID,
}).WithError(err).Panicf("failed to generate bot account for appservice")
@ -101,11 +102,13 @@ func NewInternalAPI(
func generateAppServiceAccount(
userAPI userapi.AppserviceUserAPI,
as config.ApplicationService,
serverName gomatrixserverlib.ServerName,
) error {
var accRes userapi.PerformAccountCreationResponse
err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{
AccountType: userapi.AccountTypeAppService,
Localpart: as.SenderLocalpart,
ServerName: serverName,
AppServiceID: as.ID,
OnConflict: userapi.ConflictUpdate,
}, &accRes)
@ -115,6 +118,7 @@ func generateAppServiceAccount(
var devRes userapi.PerformDeviceCreationResponse
err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{
Localpart: as.SenderLocalpart,
ServerName: serverName,
AccessToken: as.ASToken,
DeviceID: &as.SenderLocalpart,
DeviceDisplayName: &as.SenderLocalpart,

View File

@ -61,7 +61,7 @@ func (t *LoginTypePassword) LoginFromJSON(ctx context.Context, reqBytes []byte)
func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) {
r := req.(*PasswordRequest)
username := strings.ToLower(r.Username())
username := r.Username()
if username == "" {
return nil, &util.JSONResponse{
Code: http.StatusUnauthorized,
@ -74,32 +74,43 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login,
JSON: jsonerror.BadJSON("A password must be supplied."),
}
}
localpart, _, err := userutil.ParseUsernameParam(username, t.Config.Matrix)
localpart, domain, err := userutil.ParseUsernameParam(username, t.Config.Matrix)
if err != nil {
return nil, &util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: jsonerror.InvalidUsername(err.Error()),
}
}
if !t.Config.Matrix.IsLocalServerName(domain) {
return nil, &util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: jsonerror.InvalidUsername("The server name is not known."),
}
}
// Squash username to all lowercase letters
res := &api.QueryAccountByPasswordResponse{}
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{Localpart: strings.ToLower(localpart), PlaintextPassword: r.Password}, res)
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{
Localpart: strings.ToLower(localpart),
ServerName: domain,
PlaintextPassword: r.Password,
}, res)
if err != nil {
return nil, &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("unable to fetch account by password"),
JSON: jsonerror.Unknown("Unable to fetch account by password."),
}
}
if !res.Exists {
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{
Localpart: localpart,
ServerName: domain,
PlaintextPassword: r.Password,
}, res)
if err != nil {
return nil, &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("unable to fetch account by password"),
JSON: jsonerror.Unknown("Unable to fetch account by password."),
}
}
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows

View File

@ -102,6 +102,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
if err != nil {
return util.ErrorResponse(err)
}
serverName := cfg.Matrix.ServerName
localpart, ok := vars["localpart"]
if !ok {
return util.JSONResponse{
@ -109,6 +110,9 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
JSON: jsonerror.MissingArgument("Expecting user localpart."),
}
}
if l, s, err := gomatrixserverlib.SplitID('@', localpart); err == nil {
localpart, serverName = l, s
}
request := struct {
Password string `json:"password"`
}{}
@ -126,6 +130,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
}
updateReq := &userapi.PerformPasswordUpdateRequest{
Localpart: localpart,
ServerName: serverName,
Password: request.Password,
LogoutDevices: true,
}

View File

@ -100,6 +100,7 @@ func completeAuth(
DeviceID: login.DeviceID,
AccessToken: token,
Localpart: localpart,
ServerName: serverName,
IPAddr: ipAddr,
UserAgent: userAgent,
}, &performRes)

View File

@ -40,13 +40,14 @@ func GetNotifications(
}
var queryRes userapi.QueryNotificationsResponse
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
return jsonerror.InternalServerError()
}
err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{
Localpart: localpart,
ServerName: domain,
From: req.URL.Query().Get("from"),
Limit: int(limit),
Only: req.URL.Query().Get("only"),

View File

@ -86,7 +86,7 @@ func Password(
}
// Get the local part.
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
@ -95,6 +95,7 @@ func Password(
// Ask the user API to perform the password change.
passwordReq := &api.PerformPasswordUpdateRequest{
Localpart: localpart,
ServerName: domain,
Password: r.NewPassword,
}
passwordRes := &api.PerformPasswordUpdateResponse{}
@ -123,6 +124,7 @@ func Password(
pushersReq := &api.PerformPusherDeletionRequest{
Localpart: localpart,
ServerName: domain,
SessionID: device.SessionID,
}
if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil {

View File

@ -31,13 +31,14 @@ func GetPushers(
userAPI userapi.ClientUserAPI,
) util.JSONResponse {
var queryRes userapi.QueryPushersResponse
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
return jsonerror.InternalServerError()
}
err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{
Localpart: localpart,
ServerName: domain,
}, &queryRes)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed")
@ -59,7 +60,7 @@ func SetPusher(
req *http.Request, device *userapi.Device,
userAPI userapi.ClientUserAPI,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
return jsonerror.InternalServerError()
@ -93,6 +94,7 @@ func SetPusher(
}
body.Localpart = localpart
body.ServerName = domain
body.SessionID = device.SessionID
err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{})
if err != nil {

View File

@ -588,12 +588,15 @@ func Register(
}
// Auto generate a numeric username if r.Username is empty
if r.Username == "" {
res := &userapi.QueryNumericLocalpartResponse{}
if err := userAPI.QueryNumericLocalpart(req.Context(), res); err != nil {
nreq := &userapi.QueryNumericLocalpartRequest{
ServerName: cfg.Matrix.ServerName, // TODO: might not be right
}
nres := &userapi.QueryNumericLocalpartResponse{}
if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed")
return jsonerror.InternalServerError()
}
r.Username = strconv.FormatInt(res.ID, 10)
r.Username = strconv.FormatInt(nres.ID, 10)
}
// Is this an appservice registration? It will be if the access
@ -676,6 +679,7 @@ func handleGuestRegistration(
var devRes userapi.PerformDeviceCreationResponse
err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{
Localpart: res.Account.Localpart,
ServerName: res.Account.ServerName,
DeviceDisplayName: r.InitialDisplayName,
AccessToken: token,
IPAddr: req.RemoteAddr,

View File

@ -157,7 +157,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
dendriteAdminRouter.Handle("/admin/resetPassword/{localpart}",
dendriteAdminRouter.Handle("/admin/resetPassword/{userID}",
httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminResetPassword(req, cfg, device, userAPI)
}),

View File

@ -286,6 +286,7 @@ func getSenderDevice(
err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{
AccountType: userapi.AccountTypeUser,
Localpart: cfg.Matrix.ServerNotices.LocalPart,
ServerName: cfg.Matrix.ServerName,
OnConflict: userapi.ConflictUpdate,
}, &accRes)
if err != nil {
@ -296,6 +297,7 @@ func getSenderDevice(
avatarRes := &userapi.PerformSetAvatarURLResponse{}
if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart,
ServerName: cfg.Matrix.ServerName,
AvatarURL: cfg.Matrix.ServerNotices.AvatarURL,
}, avatarRes); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed")
@ -308,6 +310,7 @@ func getSenderDevice(
displayNameRes := &userapi.PerformUpdateDisplayNameResponse{}
if err = userAPI.SetDisplayName(ctx, &userapi.PerformUpdateDisplayNameRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart,
ServerName: cfg.Matrix.ServerName,
DisplayName: cfg.Matrix.ServerNotices.DisplayName,
}, displayNameRes); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.SetDisplayName failed")
@ -353,6 +356,7 @@ func getSenderDevice(
var devRes userapi.PerformDeviceCreationResponse
err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart,
ServerName: cfg.Matrix.ServerName,
DeviceDisplayName: &cfg.Matrix.ServerNotices.LocalPart,
AccessToken: token,
NoDeviceListUpdate: true,

View File

@ -136,7 +136,7 @@ func CheckAndSave3PIDAssociation(
}
// Save the association in the database
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
@ -145,6 +145,7 @@ func CheckAndSave3PIDAssociation(
if err = threePIDAPI.PerformSaveThreePIDAssociation(req.Context(), &api.PerformSaveThreePIDAssociationRequest{
ThreePID: address,
Localpart: localpart,
ServerName: domain,
Medium: medium,
}, &struct{}{}); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("threePIDAPI.PerformSaveThreePIDAssociation failed")
@ -161,7 +162,7 @@ func CheckAndSave3PIDAssociation(
func GetAssociated3PIDs(
req *http.Request, threepidAPI api.ClientUserAPI, device *api.Device,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
@ -170,6 +171,7 @@ func GetAssociated3PIDs(
res := &api.QueryThreePIDsForLocalpartResponse{}
err = threepidAPI.QueryThreePIDsForLocalpart(req.Context(), &api.QueryThreePIDsForLocalpartRequest{
Localpart: localpart,
ServerName: domain,
}, res)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("threepidAPI.QueryThreePIDsForLocalpart failed")

View File

@ -120,15 +120,23 @@ func NewInternalAPI(
js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
signingInfo := map[gomatrixserverlib.ServerName]*queue.SigningInfo{}
for _, serverName := range append(
[]gomatrixserverlib.ServerName{base.Cfg.Global.ServerName},
base.Cfg.Global.SecondaryServerNames...,
) {
signingInfo[serverName] = &queue.SigningInfo{
KeyID: cfg.Matrix.KeyID,
PrivateKey: cfg.Matrix.PrivateKey,
ServerName: serverName,
}
}
queues := queue.NewOutgoingQueues(
federationDB, base.ProcessContext,
cfg.Matrix.DisableFederation,
cfg.Matrix.ServerName, federation, rsAPI, &stats,
&queue.SigningInfo{
KeyID: cfg.Matrix.KeyID,
PrivateKey: cfg.Matrix.PrivateKey,
ServerName: cfg.Matrix.ServerName,
},
signingInfo,
)
rsConsumer := consumers.NewOutputRoomEventConsumer(

View File

@ -137,7 +137,7 @@ func (m *MockRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err
}
// Get the keys and JSON-ify them.
keys := routing.LocalKeys(s.config)
keys := routing.LocalKeys(s.config, gomatrixserverlib.ServerName(req.Host))
body, err := json.MarshalIndent(keys.JSON, "", " ")
if err != nil {
return nil, err

View File

@ -50,7 +50,7 @@ type destinationQueue struct {
queues *OutgoingQueues
db storage.Database
process *process.ProcessContext
signing *SigningInfo
signing map[gomatrixserverlib.ServerName]*SigningInfo
rsAPI api.FederationRoomserverAPI
client fedapi.FederationClient // federation client
origin gomatrixserverlib.ServerName // origin of requests

View File

@ -46,7 +46,7 @@ type OutgoingQueues struct {
origin gomatrixserverlib.ServerName
client fedapi.FederationClient
statistics *statistics.Statistics
signing *SigningInfo
signing map[gomatrixserverlib.ServerName]*SigningInfo
queuesMutex sync.Mutex // protects the below
queues map[gomatrixserverlib.ServerName]*destinationQueue
}
@ -91,7 +91,7 @@ func NewOutgoingQueues(
client fedapi.FederationClient,
rsAPI api.FederationRoomserverAPI,
statistics *statistics.Statistics,
signing *SigningInfo,
signing map[gomatrixserverlib.ServerName]*SigningInfo,
) *OutgoingQueues {
queues := &OutgoingQueues{
disabled: disabled,
@ -199,11 +199,10 @@ func (oqs *OutgoingQueues) SendEvent(
log.Trace("Federation is disabled, not sending event")
return nil
}
if origin != oqs.origin {
// TODO: Support virtual hosting; gh issue #577.
if _, ok := oqs.signing[origin]; !ok {
return fmt.Errorf(
"sendevent: unexpected server to send as: got %q expected %q",
origin, oqs.origin,
"sendevent: unexpected server to send as %q",
origin,
)
}
@ -214,7 +213,9 @@ func (oqs *OutgoingQueues) SendEvent(
destmap[d] = struct{}{}
}
delete(destmap, oqs.origin)
delete(destmap, oqs.signing.ServerName)
for local := range oqs.signing {
delete(destmap, local)
}
// Check if any of the destinations are prohibited by server ACLs.
for destination := range destmap {
@ -288,11 +289,10 @@ func (oqs *OutgoingQueues) SendEDU(
log.Trace("Federation is disabled, not sending EDU")
return nil
}
if origin != oqs.origin {
// TODO: Support virtual hosting; gh issue #577.
if _, ok := oqs.signing[origin]; !ok {
return fmt.Errorf(
"sendevent: unexpected server to send as: got %q expected %q",
origin, oqs.origin,
"sendevent: unexpected server to send as %q",
origin,
)
}
@ -303,7 +303,9 @@ func (oqs *OutgoingQueues) SendEDU(
destmap[d] = struct{}{}
}
delete(destmap, oqs.origin)
delete(destmap, oqs.signing.ServerName)
for local := range oqs.signing {
delete(destmap, local)
}
// There is absolutely no guarantee that the EDU will have a room_id
// field, as it is not required by the spec. However, if it *does*

View File

@ -350,10 +350,12 @@ func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T
}
rs := &stubFederationRoomServerAPI{}
stats := statistics.NewStatistics(db, failuresUntilBlacklist)
signingInfo := &SigningInfo{
signingInfo := map[gomatrixserverlib.ServerName]*SigningInfo{
"localhost": {
KeyID: "ed21019:auto",
PrivateKey: test.PrivateKeyA,
ServerName: "localhost",
},
}
queues := NewOutgoingQueues(db, processContext, false, "localhost", fc, rs, &stats, signingInfo)

View File

@ -16,6 +16,7 @@ package routing
import (
"encoding/json"
"fmt"
"net/http"
"time"
@ -134,18 +135,21 @@ func ClaimOneTimeKeys(
// LocalKeys returns the local keys for the server.
// See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys
func LocalKeys(cfg *config.FederationAPI) util.JSONResponse {
keys, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod))
func LocalKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) util.JSONResponse {
keys, err := localKeys(cfg, serverName, time.Now().Add(cfg.Matrix.KeyValidityPeriod))
if err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: keys}
}
func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) {
func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) {
var keys gomatrixserverlib.ServerKeys
if !cfg.Matrix.IsLocalServerName(serverName) {
return nil, fmt.Errorf("server name not known")
}
keys.ServerName = cfg.Matrix.ServerName
keys.ServerName = serverName
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(validUntil)
publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey)
@ -172,7 +176,7 @@ func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserver
}
keys.Raw, err = gomatrixserverlib.SignJSON(
string(cfg.Matrix.ServerName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, toSign,
string(serverName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, toSign,
)
if err != nil {
return nil, err
@ -186,6 +190,14 @@ func NotaryKeys(
fsAPI federationAPI.FederationInternalAPI,
req *gomatrixserverlib.PublicKeyNotaryLookupRequest,
) util.JSONResponse {
serverName := gomatrixserverlib.ServerName(httpReq.Host) // TODO: this is not ideal
if !cfg.Matrix.IsLocalServerName(serverName) {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("Server name not known"),
}
}
if req == nil {
req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{}
if reqErr := clienthttputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil {
@ -201,7 +213,7 @@ func NotaryKeys(
for serverName, kidToCriteria := range req.ServerKeys {
var keyList []gomatrixserverlib.ServerKeys
if serverName == cfg.Matrix.ServerName {
if k, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod)); err == nil {
if k, err := localKeys(cfg, serverName, time.Now().Add(cfg.Matrix.KeyValidityPeriod)); err == nil {
keyList = append(keyList, *k)
} else {
return util.ErrorResponse(err)

View File

@ -74,7 +74,7 @@ func Setup(
}
localKeys := httputil.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse {
return LocalKeys(cfg)
return LocalKeys(cfg, gomatrixserverlib.ServerName(req.Host))
})
notaryKeys := httputil.MakeExternalAPI("notarykeys", func(req *http.Request) util.JSONResponse {

View File

@ -33,12 +33,13 @@ import (
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/producers"
"github.com/matrix-org/dendrite/keyserver/storage"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
)
type KeyInternalAPI struct {
DB storage.Database
ThisServer gomatrixserverlib.ServerName
Cfg *config.KeyServer
FedClient fedsenderapi.KeyserverFederationAPI
UserAPI userapi.KeyserverUserAPI
Producer *producers.KeyChange
@ -95,8 +96,11 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
nested[userID] = val
domainToDeviceKeys[string(serverName)] = nested
}
for domain, local := range domainToDeviceKeys {
if !a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
continue
}
// claim local keys
if local, ok := domainToDeviceKeys[string(a.ThisServer)]; ok {
keys, err := a.DB.ClaimKeys(ctx, local)
if err != nil {
res.Error = &api.KeyError{
@ -117,7 +121,7 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON
}
}
delete(domainToDeviceKeys, string(a.ThisServer))
delete(domainToDeviceKeys, domain)
}
if len(domainToDeviceKeys) > 0 {
a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
@ -258,7 +262,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
}
domain := string(serverName)
// query local devices
if serverName == a.ThisServer {
if a.Cfg.Matrix.IsLocalServerName(serverName) {
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
if err != nil {
res.Error = &api.KeyError{
@ -437,13 +441,13 @@ func (a *KeyInternalAPI) queryRemoteKeys(
domains := map[string]struct{}{}
for domain := range domainToDeviceKeys {
if domain == string(a.ThisServer) {
if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
continue
}
domains[domain] = struct{}{}
}
for domain := range domainToCrossSigningKeys {
if domain == string(a.ThisServer) {
if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
continue
}
domains[domain] = struct{}{}
@ -689,7 +693,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
if err != nil {
continue // ignore invalid users
}
if serverName != a.ThisServer {
if !a.Cfg.Matrix.IsLocalServerName(serverName) {
continue // ignore remote users
}
if len(key.KeyJSON) == 0 {

View File

@ -54,7 +54,7 @@ func NewInternalAPI(
}
ap := &internal.KeyInternalAPI{
DB: db,
ThisServer: cfg.Matrix.ServerName,
Cfg: cfg,
FedClient: fedClient,
Producer: keyChangeProducer,
}

View File

@ -78,7 +78,7 @@ type ClientUserAPI interface {
QueryAcccessTokenAPI
LoginTokenInternalAPI
UserLoginAPI
QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error
QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
@ -336,6 +336,7 @@ type PerformAccountCreationResponse struct {
// PerformAccountCreationRequest is the request for PerformAccountCreation
type PerformPasswordUpdateRequest struct {
Localpart string // Required: The localpart for this account.
ServerName gomatrixserverlib.ServerName // Required: The domain for this account.
Password string // Required: The new password to set.
LogoutDevices bool // Optional: Whether to log out all user devices.
}
@ -519,6 +520,7 @@ const (
type QueryPushersRequest struct {
Localpart string
ServerName gomatrixserverlib.ServerName
}
type QueryPushersResponse struct {
@ -528,11 +530,13 @@ type QueryPushersResponse struct {
type PerformPusherSetRequest struct {
Pusher // Anonymous field because that's how clientapi unmarshals it.
Localpart string
ServerName gomatrixserverlib.ServerName
Append bool `json:"append"`
}
type PerformPusherDeletionRequest struct {
Localpart string
ServerName gomatrixserverlib.ServerName
SessionID int64
}
@ -572,6 +576,7 @@ type QueryPushRulesResponse struct {
type QueryNotificationsRequest struct {
Localpart string `json:"localpart"` // Required.
ServerName gomatrixserverlib.ServerName `json:"server_name"` // Required.
From string `json:"from,omitempty"`
Limit int `json:"limit,omitempty"`
Only string `json:"only,omitempty"`
@ -601,12 +606,17 @@ type PerformSetAvatarURLResponse struct {
Changed bool `json:"changed"`
}
type QueryNumericLocalpartRequest struct {
ServerName gomatrixserverlib.ServerName
}
type QueryNumericLocalpartResponse struct {
ID int64
}
type QueryAccountAvailabilityRequest struct {
Localpart string
ServerName gomatrixserverlib.ServerName
}
type QueryAccountAvailabilityResponse struct {
@ -614,7 +624,9 @@ type QueryAccountAvailabilityResponse struct {
}
type QueryAccountByPasswordRequest struct {
Localpart, PlaintextPassword string
Localpart string
ServerName gomatrixserverlib.ServerName
PlaintextPassword string
}
type QueryAccountByPasswordResponse struct {
@ -639,10 +651,12 @@ type QueryLocalpartForThreePIDRequest struct {
type QueryLocalpartForThreePIDResponse struct {
Localpart string
ServerName gomatrixserverlib.ServerName
}
type QueryThreePIDsForLocalpartRequest struct {
Localpart string
ServerName gomatrixserverlib.ServerName
}
type QueryThreePIDsForLocalpartResponse struct {
@ -652,5 +666,8 @@ type QueryThreePIDsForLocalpartResponse struct {
type PerformForgetThreePIDRequest QueryLocalpartForThreePIDRequest
type PerformSaveThreePIDAssociationRequest struct {
ThreePID, Localpart, Medium string
ThreePID string
Localpart string
ServerName gomatrixserverlib.ServerName
Medium string
}

View File

@ -156,8 +156,8 @@ func (t *UserInternalAPITrace) SetAvatarURL(ctx context.Context, req *PerformSet
return err
}
func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error {
err := t.Impl.QueryNumericLocalpart(ctx, res)
func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error {
err := t.Impl.QueryNumericLocalpart(ctx, req, res)
util.GetLogger(ctx).Infof("QueryNumericLocalpart req= res=%+v", js(res))
return err
}

View File

@ -104,7 +104,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats
return false
}
updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true)
updated, err := s.db.SetNotificationsRead(ctx, localpart, domain, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true)
if err != nil {
log.WithError(err).Error("userapi EDU consumer")
return false
@ -118,7 +118,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats
if !updated {
return true
}
if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, domain, s.db); err != nil {
log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed")
return false
}

View File

@ -192,25 +192,25 @@ func (s *OutputRoomEventConsumer) storeMessageStats(ctx context.Context, eventTy
func (s *OutputRoomEventConsumer) handleRoomUpgrade(ctx context.Context, oldRoomID, newRoomID string, localMembers []*localMembership, roomSize int) error {
for _, membership := range localMembers {
// Copy any existing push rules from old -> new room
if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil {
if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil {
return err
}
// preserve m.direct room state
if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, roomSize); err != nil {
if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain, roomSize); err != nil {
return err
}
// copy existing m.tag entries, if any
if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil {
if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil {
return err
}
}
return nil
}
func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string) error {
pushRules, err := s.db.QueryPushRules(ctx, localpart)
func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string, serverName gomatrixserverlib.ServerName) error {
pushRules, err := s.db.QueryPushRules(ctx, localpart, serverName)
if err != nil {
return fmt.Errorf("failed to query pushrules for user: %w", err)
}
@ -229,7 +229,7 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID,
if err != nil {
return err
}
if err = s.db.SaveAccountData(ctx, localpart, "", "m.push_rules", rules); err != nil {
if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.push_rules", rules); err != nil {
return fmt.Errorf("failed to update pushrules: %w", err)
}
}
@ -237,13 +237,13 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID,
}
// updateMDirect copies the "is_direct" flag from oldRoomID to newROomID
func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, roomSize int) error {
func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName, roomSize int) error {
// this is most likely not a DM, so skip updating m.direct state
if roomSize > 2 {
return nil
}
// Get direct message state
directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, "", "m.direct")
directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, serverName, "", "m.direct")
if err != nil {
return fmt.Errorf("failed to get m.direct from database: %w", err)
}
@ -267,7 +267,7 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID,
if err != nil {
return true
}
if err = s.db.SaveAccountData(ctx, localpart, "", "m.direct", data); err != nil {
if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.direct", data); err != nil {
return true
}
}
@ -279,15 +279,15 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID,
return nil
}
func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string) error {
tag, err := s.db.GetAccountDataByType(ctx, localpart, oldRoomID, "m.tag")
func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName) error {
tag, err := s.db.GetAccountDataByType(ctx, localpart, serverName, oldRoomID, "m.tag")
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
if tag == nil {
return nil
}
return s.db.SaveAccountData(ctx, localpart, newRoomID, "m.tag", tag)
return s.db.SaveAccountData(ctx, localpart, serverName, newRoomID, "m.tag", tag)
}
func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error {
@ -492,11 +492,11 @@ func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, er
func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string, streamPos uint64) error {
actions, err := s.evaluatePushRules(ctx, event, mem, roomSize)
if err != nil {
return err
return fmt.Errorf("s.evaluatePushRules: %w", err)
}
a, tweaks, err := pushrules.ActionsToTweaks(actions)
if err != nil {
return err
return fmt.Errorf("pushrules.ActionsToTweaks: %w", err)
}
// TODO: support coalescing.
if a != pushrules.NotifyAction && a != pushrules.CoalesceAction {
@ -508,9 +508,9 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
return nil
}
devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, tweaks)
devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, mem.Domain, tweaks)
if err != nil {
return err
return fmt.Errorf("s.localPushDevices: %w", err)
}
n := &api.Notification{
@ -527,18 +527,18 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
RoomID: event.RoomID(),
TS: gomatrixserverlib.AsTimestamp(time.Now()),
}
if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), streamPos, tweaks, n); err != nil {
return err
if err = s.db.InsertNotification(ctx, mem.Localpart, mem.Domain, event.EventID(), streamPos, tweaks, n); err != nil {
return fmt.Errorf("s.db.InsertNotification: %w", err)
}
if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil {
return err
return fmt.Errorf("s.syncProducer.GetAndSendNotificationData: %w", err)
}
// We do this after InsertNotification. Thus, this should always return >=1.
userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, tables.AllNotifications)
userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, mem.Domain, tables.AllNotifications)
if err != nil {
return err
return fmt.Errorf("s.db.GetNotificationCount: %w", err)
}
log.WithFields(log.Fields{
@ -589,7 +589,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
}
if len(rejected) > 0 {
s.deleteRejectedPushers(ctx, rejected, mem.Localpart)
s.deleteRejectedPushers(ctx, rejected, mem.Localpart, mem.Domain)
}
}()
@ -606,7 +606,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
}
// Get accountdata to check if the event.Sender() is ignored by mem.LocalPart
data, err := s.db.GetAccountDataByType(ctx, mem.Localpart, "", "m.ignored_user_list")
data, err := s.db.GetAccountDataByType(ctx, mem.Localpart, mem.Domain, "", "m.ignored_user_list")
if err != nil {
return nil, err
}
@ -621,7 +621,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
return nil, fmt.Errorf("user %s is ignored", sender)
}
}
ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart)
ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart, mem.Domain)
if err != nil {
return nil, err
}
@ -693,10 +693,10 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err
// localPushDevices pushes to the configured devices of a local
// user. The map keys are [url][format].
func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db)
func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
pusherDevices, err := util.GetPushDevices(ctx, localpart, serverName, tweaks, s.db)
if err != nil {
return nil, "", err
return nil, "", fmt.Errorf("util.GetPushDevices: %w", err)
}
var profileTag string
@ -791,7 +791,7 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *gomatri
}
// deleteRejectedPushers deletes the pushers associated with the given devices.
func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) {
func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string, serverName gomatrixserverlib.ServerName) {
log.WithFields(log.Fields{
"localpart": localpart,
"app_id0": devices[0].AppID,
@ -799,7 +799,7 @@ func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, dev
}).Warnf("Deleting pushers rejected by the HTTP push gateway")
for _, d := range devices {
if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart); err != nil {
if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart, serverName); err != nil {
log.WithFields(log.Fields{
"localpart": localpart,
}).WithError(err).Errorf("Unable to delete rejected pusher")

View File

@ -68,7 +68,7 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
if req.DataType == "" {
return fmt.Errorf("data type must not be empty")
}
if err := a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData); err != nil {
if err := a.DB.SaveAccountData(ctx, local, domain, req.RoomID, req.DataType, req.AccountData); err != nil {
util.GetLogger(ctx).WithError(err).Error("a.DB.SaveAccountData failed")
return fmt.Errorf("failed to save account data: %w", err)
}
@ -108,7 +108,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun
return nil
}
deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now())))
deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, domain, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now())))
if err != nil {
logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed")
return err
@ -124,7 +124,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun
return nil
}
if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, a.DB); err != nil {
if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, domain, a.DB); err != nil {
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed")
return err
}
@ -175,8 +175,10 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
if serverName == "" {
serverName = a.Config.Matrix.ServerName
}
// XXXX: Use the server name here
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
if !a.Config.Matrix.IsLocalServerName(serverName) {
return fmt.Errorf("server name %s is not local", serverName)
}
acc, err := a.DB.CreateAccount(ctx, req.Localpart, serverName, req.Password, req.AppServiceID, req.AccountType)
if err != nil {
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
switch req.OnConflict {
@ -215,8 +217,8 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
return nil
}
if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
return err
if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, serverName, req.Localpart); err != nil {
return fmt.Errorf("a.DB.SetDisplayName: %w", err)
}
postRegisterJoinRooms(a.Cfg, acc, a.RSAPI)
@ -227,11 +229,14 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
}
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
if !a.Config.Matrix.IsLocalServerName(req.ServerName) {
return fmt.Errorf("server name %s is not local", req.ServerName)
}
if err := a.DB.SetPassword(ctx, req.Localpart, req.ServerName, req.Password); err != nil {
return err
}
if req.LogoutDevices {
if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, ""); err != nil {
if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, req.ServerName, ""); err != nil {
return err
}
}
@ -244,14 +249,15 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
if serverName == "" {
serverName = a.Config.Matrix.ServerName
}
_ = serverName
// XXXX: Use the server name here
if !a.Config.Matrix.IsLocalServerName(serverName) {
return fmt.Errorf("server name %s is not local", serverName)
}
util.GetLogger(ctx).WithFields(logrus.Fields{
"localpart": req.Localpart,
"device_id": req.DeviceID,
"display_name": req.DeviceDisplayName,
}).Info("PerformDeviceCreation")
dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
dev, err := a.DB.CreateDevice(ctx, req.Localpart, serverName, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
if err != nil {
return err
}
@ -276,12 +282,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deletedDeviceIDs := req.DeviceIDs
if len(req.DeviceIDs) == 0 {
var devices []api.Device
devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
devices, err = a.DB.RemoveAllDevices(ctx, local, domain, req.ExceptDeviceID)
for _, d := range devices {
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
}
} else {
err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs)
err = a.DB.RemoveDevices(ctx, local, domain, req.DeviceIDs)
}
if err != nil {
return err
@ -335,23 +341,29 @@ func (a *UserInternalAPI) PerformLastSeenUpdate(
req *api.PerformLastSeenUpdateRequest,
res *api.PerformLastSeenUpdateResponse,
) error {
localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID)
localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil {
if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("server name %s is not local", domain)
}
if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, domain, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil {
return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err)
}
return nil
}
func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error {
localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID)
localpart, domain, err := gomatrixserverlib.SplitID('@', req.RequestingUserID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
return err
}
dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID)
if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("server name %s is not local", domain)
}
dev, err := a.DB.GetDeviceByID(ctx, localpart, domain, req.DeviceID)
if err == sql.ErrNoRows {
res.DeviceExists = false
return nil
@ -366,7 +378,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
return nil
}
err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
err = a.DB.UpdateDevice(ctx, localpart, domain, req.DeviceID, req.DisplayName)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
return err
@ -406,7 +418,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("cannot query profile of remote users (server name %s)", domain)
}
prof, err := a.DB.GetProfileByLocalpart(ctx, local)
prof, err := a.DB.GetProfileByLocalpart(ctx, local, domain)
if err != nil {
if err == sql.ErrNoRows {
return nil
@ -457,7 +469,7 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("cannot query devices of remote users (server name %s)", domain)
}
devs, err := a.DB.GetDevicesByLocalpart(ctx, local)
devs, err := a.DB.GetDevicesByLocalpart(ctx, local, domain)
if err != nil {
return err
}
@ -476,7 +488,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
}
if req.DataType != "" {
var data json.RawMessage
data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
data, err = a.DB.GetAccountDataByType(ctx, local, domain, req.RoomID, req.DataType)
if err != nil {
return err
}
@ -494,7 +506,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
}
return nil
}
global, rooms, err := a.DB.GetAccountData(ctx, local)
global, rooms, err := a.DB.GetAccountData(ctx, local, domain)
if err != nil {
return err
}
@ -527,7 +539,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
if !a.Config.Matrix.IsLocalServerName(domain) {
return nil
}
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart)
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart, domain)
if err != nil {
return err
}
@ -561,14 +573,14 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
AccountType: api.AccountTypeAppService,
}
localpart, _, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix)
localpart, domain, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix)
if err != nil {
return nil, err
}
if localpart != "" { // AS is masquerading as another user
// Verify that the user is registered
account, err := a.DB.GetAccountByLocalpart(ctx, localpart)
account, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain)
// Verify that the account exists and either appServiceID matches or
// it belongs to the appservice user namespaces
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
@ -620,7 +632,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
return err
}
err := a.DB.DeactivateAccount(ctx, req.Localpart)
err := a.DB.DeactivateAccount(ctx, req.Localpart, serverName)
res.AccountDeactivated = err == nil
return err
}
@ -783,7 +795,7 @@ func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.Query
if req.Only == "highlight" {
filter = tables.HighlightNotifications
}
notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, fromID, req.Limit, filter)
notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, req.ServerName, fromID, req.Limit, filter)
if err != nil {
return err
}
@ -811,23 +823,23 @@ func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.Perform
}
}
if req.Pusher.Kind == "" {
return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart)
return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart, req.ServerName)
}
if req.Pusher.PushKeyTS == 0 {
req.Pusher.PushKeyTS = int64(time.Now().Unix())
}
return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart)
return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart, req.ServerName)
}
func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error {
pushers, err := a.DB.GetPushers(ctx, req.Localpart)
pushers, err := a.DB.GetPushers(ctx, req.Localpart, req.ServerName)
if err != nil {
return err
}
for i := range pushers {
logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID)
if pushers[i].SessionID != req.SessionID {
err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart)
err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart, req.ServerName)
if err != nil {
return err
}
@ -838,7 +850,7 @@ func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.Pe
func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error {
var err error
res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart)
res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart, req.ServerName)
return err
}
@ -864,11 +876,11 @@ func (a *UserInternalAPI) PerformPushRulesPut(
}
func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID)
localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
return fmt.Errorf("failed to split user ID %q for push rules", req.UserID)
}
pushRules, err := a.DB.QueryPushRules(ctx, localpart)
pushRules, err := a.DB.QueryPushRules(ctx, localpart, domain)
if err != nil {
return fmt.Errorf("failed to query push rules: %w", err)
}
@ -877,14 +889,14 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush
}
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL)
profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.ServerName, req.AvatarURL)
res.Profile = profile
res.Changed = changed
return err
}
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error {
id, err := a.DB.GetNewNumericLocalpart(ctx)
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, req *api.QueryNumericLocalpartRequest, res *api.QueryNumericLocalpartResponse) error {
id, err := a.DB.GetNewNumericLocalpart(ctx, req.ServerName)
if err != nil {
return err
}
@ -894,12 +906,12 @@ func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.Qu
func (a *UserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error {
var err error
res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart)
res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart, req.ServerName)
return err
}
func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error {
acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.PlaintextPassword)
acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.ServerName, req.PlaintextPassword)
switch err {
case sql.ErrNoRows: // user does not exist
return nil
@ -915,23 +927,24 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q
}
func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error {
profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName)
profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.ServerName, req.DisplayName)
res.Profile = profile
res.Changed = changed
return err
}
func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {
localpart, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium)
localpart, domain, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium)
if err != nil {
return err
}
res.Localpart = localpart
res.ServerName = domain
return nil
}
func (a *UserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error {
r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart)
r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart, req.ServerName)
if err != nil {
return err
}
@ -944,7 +957,7 @@ func (a *UserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.Pe
}
func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error {
return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.Medium)
return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.ServerName, req.Medium)
}
const pushRulesAccountDataType = "m.push_rules"

View File

@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog
if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("cannot return a login token for a remote user (server name %s)", domain)
}
if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil {
if _, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain); err != nil {
res.Data = nil
if err == sql.ErrNoRows {
return nil

View File

@ -355,11 +355,12 @@ func (h *httpUserInternalAPI) SetAvatarURL(
func (h *httpUserInternalAPI) QueryNumericLocalpart(
ctx context.Context,
request *api.QueryNumericLocalpartRequest,
response *api.QueryNumericLocalpartResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath,
h.httpClient, ctx, &struct{}{}, response,
h.httpClient, ctx, request, response,
)
}

View File

@ -15,12 +15,9 @@
package inthttp
import (
"net/http"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
)
// nolint: gocyclo
@ -152,15 +149,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL),
)
// TODO: Look at the shape of this
internalAPIMux.Handle(QueryNumericLocalpartPath,
httputil.MakeInternalAPI("UserAPIQueryNumericLocalpart", func(req *http.Request) util.JSONResponse {
response := api.QueryNumericLocalpartResponse{}
if err := s.QueryNumericLocalpart(req.Context(), &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
internalAPIMux.Handle(
QueryNumericLocalpartPath,
httputil.MakeInternalRPCAPI("UserAPIQueryNumericLocalpart", s.QueryNumericLocalpart),
)
internalAPIMux.Handle(

View File

@ -61,12 +61,12 @@ func (p *SyncAPI) SendAccountData(userID string, data eventutil.AccountData) err
// GetAndSendNotificationData reads the database and sends data about unread
// notifications to the Sync API server.
func (p *SyncAPI) GetAndSendNotificationData(ctx context.Context, userID, roomID string) error {
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return err
}
ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, roomID)
ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, domain, roomID)
if err != nil {
return err
}

View File

@ -29,40 +29,40 @@ import (
)
type Profile interface {
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
GetProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, bool, error)
SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error)
SetDisplayName(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error)
}
type Account interface {
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists.
CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
GetNewNumericLocalpart(ctx context.Context) (int64, error)
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
DeactivateAccount(ctx context.Context, localpart string) (err error)
SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
CreateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
GetAccountByPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) (*api.Account, error)
GetNewNumericLocalpart(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error)
GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error)
DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error)
SetPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) error
}
type AccountData interface {
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
SaveAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error
GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
// GetAccountDataByType returns account data matching a given
// localpart, room ID and type.
// If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
QueryPushRules(ctx context.Context, localpart string) (*pushrules.AccountRuleSets, error)
GetAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error)
QueryPushRules(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*pushrules.AccountRuleSets, error)
}
type Device interface {
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
GetDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Device, error)
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
@ -70,12 +70,12 @@ type Device interface {
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error
CreateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error
UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error
RemoveDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
RemoveAllDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) (devices []api.Device, err error)
}
type KeyBackup interface {
@ -107,26 +107,26 @@ type OpenID interface {
}
type Pusher interface {
UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error
GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error)
RemovePusher(ctx context.Context, appid, pushkey, localpart string) error
UpsertPusher(ctx context.Context, p api.Pusher, localpart string, serverName gomatrixserverlib.ServerName) error
GetPushers(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error)
RemovePusher(ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error
RemovePushers(ctx context.Context, appid, pushkey string) error
}
type ThreePID interface {
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
SaveThreePIDAssociation(ctx context.Context, threepid, localpart string, serverName gomatrixserverlib.ServerName, medium string) (err error)
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error)
GetThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error)
}
type Notification interface {
InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error
DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error)
SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, read bool) (affected bool, err error)
GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error)
GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)
InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error
DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error)
SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, read bool) (affected bool, err error)
GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error)
GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error)
DeleteOldNotifications(ctx context.Context) error
}

View File

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
const accountDataSchema = `
@ -29,27 +30,28 @@ const accountDataSchema = `
CREATE TABLE IF NOT EXISTS userapi_account_datas (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
-- The room ID for this data (empty string if not specific to a room)
room_id TEXT,
-- The account data type
type TEXT NOT NULL,
-- The account data content
content TEXT NOT NULL,
PRIMARY KEY(localpart, room_id, type)
content TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS userapi_account_datas_idx ON userapi_account_datas(localpart, server_name, room_id, type);
`
const insertAccountDataSQL = `
INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content
INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5)
ON CONFLICT (localpart, server_name, room_id, type) DO UPDATE SET content = EXCLUDED.content
`
const selectAccountDataSQL = "" +
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1"
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2"
const selectAccountDataByTypeSQL = "" +
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3"
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND type = $4"
type accountDataStatements struct {
insertAccountDataStmt *sql.Stmt
@ -71,21 +73,24 @@ func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
}
func (s *accountDataStatements) InsertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
roomID, dataType string, content json.RawMessage,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
_, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content)
return
}
func (s *accountDataStatements) SelectAccountData(
ctx context.Context, localpart string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) (
/* global */ map[string]json.RawMessage,
/* rooms */ map[string]map[string]json.RawMessage,
error,
) {
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart, serverName)
if err != nil {
return nil, nil, err
}
@ -117,11 +122,13 @@ func (s *accountDataStatements) SelectAccountData(
}
func (s *accountDataStatements) SelectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
roomID, dataType string,
) (data json.RawMessage, err error) {
var bytes []byte
stmt := s.selectAccountDataByTypeStmt
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
if err = stmt.QueryRowContext(ctx, localpart, serverName, roomID, dataType).Scan(&bytes); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}

View File

@ -17,6 +17,7 @@ package postgres
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/matrix-org/gomatrixserverlib"
@ -34,7 +35,8 @@ const accountsSchema = `
-- Stores data about accounts.
CREATE TABLE IF NOT EXISTS userapi_accounts (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY,
localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
-- When this account was first created, as a unix timestamp (ms resolution).
created_ts BIGINT NOT NULL,
-- The password hash for this account. Can be NULL if this is a passwordless account.
@ -48,25 +50,27 @@ CREATE TABLE IF NOT EXISTS userapi_accounts (
-- TODO:
-- upgraded_ts, devices, any email reset stuff?
);
CREATE UNIQUE INDEX IF NOT EXISTS userapi_accounts_idx ON userapi_accounts(localpart, server_name);
`
const insertAccountSQL = "" +
"INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
"INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)"
const updatePasswordSQL = "" +
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3"
const deactivateAccountSQL = "" +
"UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1"
"UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1 AND server_name = $2"
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
"SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
const selectPasswordHashSQL = "" +
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = FALSE"
const selectNewNumericLocalpartSQL = "" +
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$'"
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$' AND server_name = $1"
type accountsStatements struct {
insertAccountStmt *sql.Stmt
@ -117,59 +121,62 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
// this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success.
func (s *accountsStatements) InsertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
hash, appserviceID string, accountType api.AccountType,
) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
var err error
if accountType != api.AccountTypeAppService {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
_, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType)
} else {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
_, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
}
if err != nil {
return nil, err
return nil, fmt.Errorf("insertAccountStmt: %w", err)
}
return &api.Account{
Localpart: localpart,
UserID: userutil.MakeUserID(localpart, s.serverName),
ServerName: s.serverName,
UserID: userutil.MakeUserID(localpart, serverName),
ServerName: serverName,
AppServiceID: appserviceID,
AccountType: accountType,
}, nil
}
func (s *accountsStatements) UpdatePassword(
ctx context.Context, localpart, passwordHash string,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
passwordHash string,
) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName)
return
}
func (s *accountsStatements) DeactivateAccount(
ctx context.Context, localpart string,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (err error) {
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName)
return
}
func (s *accountsStatements) SelectPasswordHash(
ctx context.Context, localpart string,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash)
return
}
func (s *accountsStatements) SelectAccountByLocalpart(
ctx context.Context, localpart string,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (*api.Account, error) {
var appserviceIDPtr sql.NullString
var acc api.Account
stmt := s.selectAccountByLocalpartStmt
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType)
err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve user from the db")
@ -180,19 +187,17 @@ func (s *accountsStatements) SelectAccountByLocalpart(
acc.AppServiceID = appserviceIDPtr.String
}
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
acc.ServerName = s.serverName
acc.UserID = userutil.MakeUserID(acc.Localpart, acc.ServerName)
return &acc, nil
}
func (s *accountsStatements) SelectNewNumericLocalpart(
ctx context.Context, txn *sql.Tx,
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
err = stmt.QueryRowContext(ctx).Scan(&id)
err = stmt.QueryRowContext(ctx, serverName).Scan(&id)
return id + 1, err
}

View File

@ -0,0 +1,81 @@
package deltas
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/matrix-org/gomatrixserverlib"
)
var serverNamesTables = []string{
"userapi_accounts",
"userapi_account_datas",
"userapi_devices",
"userapi_notifications",
"userapi_openid_tokens",
"userapi_profiles",
"userapi_pushers",
"userapi_threepids",
}
// These tables have a PRIMARY KEY constraint which we need to drop so
// that we can recreate a new unique index that contains the server name.
// If the new key doesn't exist (i.e. the database was created before the
// table rename migration) we'll try to drop the old one instead.
var serverNamesDropPK = map[string]string{
"userapi_accounts": "account_accounts",
"userapi_account_datas": "account_data",
"userapi_profiles": "account_profiles",
}
// These indices are out of date so let's drop them. They will get recreated
// automatically.
var serverNamesDropIndex = []string{
"userapi_pusher_localpart_idx",
"userapi_pusher_app_id_pushkey_localpart_idx",
}
// I know what you're thinking: you're wondering "why doesn't this use $1
// and pass variadic parameters to ExecContext?" — the answer is because
// PostgreSQL doesn't expect the table name to be specified as a substituted
// argument in that way so it results in a syntax error in the query.
func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
for _, table := range serverNamesTables {
q := fmt.Sprintf(
"ALTER TABLE IF EXISTS %s ADD COLUMN IF NOT EXISTS server_name TEXT NOT NULL DEFAULT '';",
pq.QuoteIdentifier(table),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("add server name to %q error: %w", table, err)
}
}
for newTable, oldTable := range serverNamesDropPK {
q := fmt.Sprintf(
"ALTER TABLE IF EXISTS %s DROP CONSTRAINT IF EXISTS %s;",
pq.QuoteIdentifier(newTable), pq.QuoteIdentifier(newTable+"_pkey"),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("drop new PK from %q error: %w", newTable, err)
}
q = fmt.Sprintf(
"ALTER TABLE IF EXISTS %s DROP CONSTRAINT IF EXISTS %s;",
pq.QuoteIdentifier(newTable), pq.QuoteIdentifier(oldTable+"_pkey"),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("drop old PK from %q error: %w", newTable, err)
}
}
for _, index := range serverNamesDropIndex {
q := fmt.Sprintf(
"DROP INDEX IF EXISTS %s;",
pq.QuoteIdentifier(index),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("drop index %q error: %w", index, err)
}
}
return nil
}

View File

@ -0,0 +1,28 @@
package deltas
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/matrix-org/gomatrixserverlib"
)
// I know what you're thinking: you're wondering "why doesn't this use $1
// and pass variadic parameters to ExecContext?" — the answer is because
// PostgreSQL doesn't expect the table name to be specified as a substituted
// argument in that way so it results in a syntax error in the query.
func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
for _, table := range serverNamesTables {
q := fmt.Sprintf(
"UPDATE %s SET server_name = %s WHERE server_name = '';",
pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("write server names to %q error: %w", table, err)
}
}
return nil
}

View File

@ -17,6 +17,7 @@ package postgres
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/lib/pq"
@ -50,6 +51,7 @@ CREATE TABLE IF NOT EXISTS userapi_devices (
-- as it is smaller, makes it clearer that we only manage devices for our own users, and may make
-- migration to different domain names easier.
localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
-- When this devices was first recognised on the network, as a unix timestamp (ms resolution).
created_ts BIGINT NOT NULL,
-- The display name, human friendlier than device_id and updatable
@ -65,39 +67,39 @@ CREATE TABLE IF NOT EXISTS userapi_devices (
);
-- Device IDs must be unique for a given user.
CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, device_id);
CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, server_name, device_id);
`
const insertDeviceSQL = "" +
"INSERT INTO userapi_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" +
"INSERT INTO userapi_devices(device_id, localpart, server_name, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" +
" RETURNING session_id"
const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1"
"SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" +
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2"
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3"
const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC"
const updateDeviceNameSQL = "" +
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4"
const deleteDeviceSQL = "" +
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2"
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3"
const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3"
const deleteDevicesSQL = "" +
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id = ANY($2)"
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = ANY($3)"
const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
"SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
const updateDeviceLastSeen = "" +
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6"
type devicesStatements struct {
insertDeviceStmt *sql.Stmt
@ -148,18 +150,19 @@ func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName
// Returns an error if the user already has a device with the given device ID.
// Returns the device on success.
func (s *devicesStatements) InsertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string, ipAddr, userAgent string,
ctx context.Context, txn *sql.Tx, id string,
localpart string, serverName gomatrixserverlib.ServerName,
accessToken string, displayName *string, ipAddr, userAgent string,
) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
return nil, err
if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
return nil, fmt.Errorf("insertDeviceStmt: %w", err)
}
return &api.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, s.serverName),
UserID: userutil.MakeUserID(localpart, serverName),
AccessToken: accessToken,
SessionID: sessionID,
LastSeenTS: createdTimeMS,
@ -170,38 +173,45 @@ func (s *devicesStatements) InsertDevice(
// deleteDevice removes a single device by id and user localpart.
func (s *devicesStatements) DeleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string,
ctx context.Context, txn *sql.Tx, id string,
localpart string, serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
_, err := stmt.ExecContext(ctx, id, localpart)
_, err := stmt.ExecContext(ctx, id, localpart, serverName)
return err
}
// deleteDevices removes a single or multiple devices by ids and user localpart.
// Returns an error if the execution failed.
func (s *devicesStatements) DeleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
devices []string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt)
_, err := stmt.ExecContext(ctx, localpart, pq.Array(devices))
_, err := stmt.ExecContext(ctx, localpart, serverName, pq.Array(devices))
return err
}
// deleteDevicesByLocalpart removes all devices for the
// given user localpart.
func (s *devicesStatements) DeleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
exceptDeviceID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
_, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID)
return err
}
func (s *devicesStatements) UpdateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
deviceID string, displayName *string,
) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
_, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID)
return err
}
@ -210,10 +220,11 @@ func (s *devicesStatements) SelectDeviceByToken(
) (*api.Device, error) {
var dev api.Device
var localpart string
var serverName gomatrixserverlib.ServerName
stmt := s.selectDeviceByTokenStmt
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName)
if err == nil {
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.UserID = userutil.MakeUserID(localpart, serverName)
dev.AccessToken = accessToken
}
return &dev, err
@ -222,16 +233,18 @@ func (s *devicesStatements) SelectDeviceByToken(
// selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID
func (s *devicesStatements) SelectDeviceByID(
ctx context.Context, localpart, deviceID string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
deviceID string,
) (*api.Device, error) {
var dev api.Device
var displayName, ip sql.NullString
var lastseenTS sql.NullInt64
stmt := s.selectDeviceByIDStmt
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip)
err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip)
if err == nil {
dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.UserID = userutil.MakeUserID(localpart, serverName)
if displayName.Valid {
dev.DisplayName = displayName.String
}
@ -254,10 +267,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
var devices []api.Device
var dev api.Device
var localpart string
var serverName gomatrixserverlib.ServerName
var lastseents sql.NullInt64
var displayName sql.NullString
for rows.Next() {
if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil {
return nil, err
}
if displayName.Valid {
@ -266,17 +280,19 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
if lastseents.Valid {
dev.LastSeenTS = lastseents.Int64
}
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.UserID = userutil.MakeUserID(localpart, serverName)
devices = append(devices, dev)
}
return devices, rows.Err()
}
func (s *devicesStatements) SelectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
exceptDeviceID string,
) ([]api.Device, error) {
devices := []api.Device{}
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID)
if err != nil {
return devices, err
@ -307,16 +323,16 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
dev.UserAgent = useragent.String
}
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.UserID = userutil.MakeUserID(localpart, serverName)
devices = append(devices, dev)
}
return devices, rows.Err()
}
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error {
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
lastSeenTs := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID)
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID)
return err
}

View File

@ -43,6 +43,7 @@ const notificationSchema = `
CREATE TABLE IF NOT EXISTS userapi_notifications (
id BIGSERIAL PRIMARY KEY,
localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
room_id TEXT NOT NULL,
event_id TEXT NOT NULL,
stream_pos BIGINT NOT NULL,
@ -52,33 +53,33 @@ CREATE TABLE IF NOT EXISTS userapi_notifications (
read BOOLEAN NOT NULL DEFAULT FALSE
);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, server_name, room_id, event_id);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, server_name, room_id, id);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, server_name, id);
`
const insertNotificationSQL = "" +
"INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
"INSERT INTO userapi_notifications (localpart, server_name, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
const deleteNotificationsUpToSQL = "" +
"DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
"DELETE FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND stream_pos <= $4"
const updateNotificationReadSQL = "" +
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND server_name = $3 AND room_id = $4 AND stream_pos <= $5 AND read <> $1"
const selectNotificationSQL = "" +
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
") AND NOT read ORDER BY localpart, id LIMIT $4"
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND id > $3 AND (" +
"(($4 & 1) <> 0 AND highlight) OR (($4 & 2) <> 0 AND NOT highlight)" +
") AND NOT read ORDER BY localpart, id LIMIT $5"
const selectNotificationCountSQL = "" +
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
"(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND (" +
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
") AND NOT read"
const selectRoomNotificationCountsSQL = "" +
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
"WHERE localpart = $1 AND room_id = $2 AND NOT read"
"WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND NOT read"
const cleanNotificationsSQL = "" +
"DELETE FROM userapi_notifications WHERE" +
@ -111,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
}
// Insert inserts a notification into the database.
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error {
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error {
roomID, tsMS := n.RoomID, n.TS
nn := *n
// Clears out fields that have their own columns to (1) shrink the
@ -122,13 +123,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
if err != nil {
return err
}
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, serverName, roomID, eventID, pos, tsMS, highlight, string(bs))
return err
}
// DeleteUpTo deletes all previous notifications, up to and including the event.
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos)
if err != nil {
return false, err
}
@ -141,8 +142,8 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
}
// UpdateRead updates the "read" value for an event.
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos)
if err != nil {
return false, err
}
@ -154,8 +155,8 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l
return nrows > 0, nil
}
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit)
if err != nil {
return nil, 0, err
@ -197,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local
return notifs, maxID, rows.Err()
}
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) {
err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count)
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) {
err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count)
return
}
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) {
err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight)
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) {
err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight)
return
}

View File

@ -3,6 +3,7 @@ package postgres
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
@ -18,16 +19,17 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
token TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID for this account
localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
-- When the token expires, as a unix timestamp (ms resolution).
token_expires_at_ms BIGINT NOT NULL
);
`
const insertOpenIDTokenSQL = "" +
"INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
"INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)"
const selectOpenIDTokenSQL = "" +
"SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
"SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
type openIDTokenStatements struct {
insertTokenStmt *sql.Stmt
@ -54,11 +56,11 @@ func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
func (s *openIDTokenStatements) InsertOpenIDToken(
ctx context.Context,
txn *sql.Tx,
token, localpart string,
token, localpart string, serverName gomatrixserverlib.ServerName,
expiresAtMS int64,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
_, err = stmt.ExecContext(ctx, token, localpart, serverName, expiresAtMS)
return
}
@ -69,10 +71,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
token string,
) (*api.OpenIDTokenAttributes, error) {
var openIDTokenAttrs api.OpenIDTokenAttributes
var localpart string
var serverName gomatrixserverlib.ServerName
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
&openIDTokenAttrs.UserID,
&localpart, &serverName,
&openIDTokenAttrs.ExpiresAtMS,
)
openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve token from the db")

View File

@ -23,42 +23,46 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
const profilesSchema = `
-- Stores data about accounts profiles.
CREATE TABLE IF NOT EXISTS userapi_profiles (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY,
localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
-- The display name for this account
display_name TEXT,
-- The URL of the avatar for this account
avatar_url TEXT
);
CREATE UNIQUE INDEX IF NOT EXISTS userapi_profiles_idx ON userapi_profiles(localpart, server_name);
`
const insertProfileSQL = "" +
"INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
"INSERT INTO userapi_profiles(localpart, server_name, display_name, avatar_url) VALUES ($1, $2, $3, $4)"
const selectProfileByLocalpartSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
"SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1 AND server_name = $2"
const setAvatarURLSQL = "" +
"UPDATE userapi_profiles AS new" +
" SET avatar_url = $1" +
" FROM userapi_profiles AS old" +
" WHERE new.localpart = $2" +
" WHERE new.localpart = $2 AND new.server_name = $3" +
" RETURNING new.display_name, old.avatar_url <> new.avatar_url"
const setDisplayNameSQL = "" +
"UPDATE userapi_profiles AS new" +
" SET display_name = $1" +
" FROM userapi_profiles AS old" +
" WHERE new.localpart = $2" +
" WHERE new.localpart = $2 AND new.server_name = $3" +
" RETURNING new.avatar_url, old.display_name <> new.display_name"
const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
"SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
type profilesStatements struct {
serverNoticesLocalpart string
@ -87,18 +91,20 @@ func NewPostgresProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables
}
func (s *profilesStatements) InsertProfile(
ctx context.Context, txn *sql.Tx, localpart string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
_, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "")
return
}
func (s *profilesStatements) SelectProfileByLocalpart(
ctx context.Context, localpart string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) (*authtypes.Profile, error) {
var profile authtypes.Profile
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan(
&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL,
)
if err != nil {
return nil, err
@ -107,28 +113,34 @@ func (s *profilesStatements) SelectProfileByLocalpart(
}
func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
avatarURL string,
) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{
Localpart: localpart,
ServerName: string(serverName),
AvatarURL: avatarURL,
}
var changed bool
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName, &changed)
err := stmt.QueryRowContext(ctx, avatarURL, localpart, serverName).Scan(&profile.DisplayName, &changed)
return profile, changed, err
}
func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
displayName string,
) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{
Localpart: localpart,
ServerName: string(serverName),
DisplayName: displayName,
}
var changed bool
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL, &changed)
err := stmt.QueryRowContext(ctx, displayName, localpart, serverName).Scan(&profile.AvatarURL, &changed)
return profile, changed, err
}
@ -146,7 +158,7 @@ func (s *profilesStatements) SelectProfilesBySearch(
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
for rows.Next() {
var profile authtypes.Profile
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
if err := rows.Scan(&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL); err != nil {
return nil, err
}
if profile.Localpart != s.serverNoticesLocalpart {

View File

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
@ -33,6 +34,7 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
id BIGSERIAL PRIMARY KEY,
-- The Matrix user ID localpart for this pusher
localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
session_id BIGINT DEFAULT NULL,
profile_tag TEXT,
kind TEXT NOT NULL,
@ -49,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
-- For faster retrieving by localpart.
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name);
-- Pushkey must be unique for a given user and app.
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name);
`
const insertPusherSQL = "" +
"INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
"ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
"INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" +
"ON CONFLICT (app_id, pushkey, localpart, server_name) DO UPDATE SET session_id = $3, pushkey_ts_ms = $5, kind = $6, app_display_name = $8, device_display_name = $9, profile_tag = $10, lang = $11, data = $12"
const selectPushersSQL = "" +
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2"
const deletePusherSQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4"
const deletePushersByAppIdAndPushKeySQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
@ -95,18 +97,19 @@ type pushersStatements struct {
// Returns nil error success.
func (s *pushersStatements) InsertPusher(
ctx context.Context, txn *sql.Tx, session_id int64,
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data,
localpart string, serverName gomatrixserverlib.ServerName,
) error {
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
logrus.Debugf("Created pusher %d", session_id)
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
return err
}
func (s *pushersStatements) SelectPushers(
ctx context.Context, txn *sql.Tx, localpart string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
) ([]api.Pusher, error) {
pushers := []api.Pusher{}
rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart)
rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart, serverName)
if err != nil {
return pushers, err
@ -143,9 +146,10 @@ func (s *pushersStatements) SelectPushers(
// deletePusher removes a single pusher by pushkey and user localpart.
func (s *pushersStatements) DeletePusher(
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
ctx context.Context, txn *sql.Tx, appid, pushkey,
localpart string, serverName gomatrixserverlib.ServerName,
) error {
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart)
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName)
return err
}

View File

@ -15,6 +15,8 @@
package postgres
import (
"context"
"database/sql"
"fmt"
"time"
@ -43,18 +45,24 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
Up: deltas.UpRenameTables,
Down: deltas.DownRenameTables,
})
m.AddMigrations(sqlutil.Migration{
Version: "userapi: server names",
Up: func(ctx context.Context, txn *sql.Tx) error {
return deltas.UpServerNames(ctx, txn, serverName)
},
})
if err = m.Up(base.Context()); err != nil {
return nil, err
}
accountDataTable, err := NewPostgresAccountDataTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)
}
accountsTable, err := NewPostgresAccountsTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err)
}
accountDataTable, err := NewPostgresAccountDataTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)
}
devicesTable, err := NewPostgresDevicesTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewPostgresDevicesTable: %w", err)
@ -95,6 +103,18 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil {
return nil, fmt.Errorf("NewPostgresStatsTable: %w", err)
}
m = sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "userapi: server names populate",
Up: func(ctx context.Context, txn *sql.Tx) error {
return deltas.UpServerNamesPopulate(ctx, txn, serverName)
},
})
if err = m.Up(base.Context()); err != nil {
return nil, err
}
return &shared.Database{
AccountDatas: accountDataTable,
Accounts: accountsTable,

View File

@ -20,6 +20,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
@ -33,21 +34,22 @@ CREATE TABLE IF NOT EXISTS userapi_threepids (
medium TEXT NOT NULL DEFAULT 'email',
-- The localpart of the Matrix user ID associated to this 3PID
localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
PRIMARY KEY(threepid, medium)
);
CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart);
CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart, server_name);
`
const selectLocalpartForThreePIDSQL = "" +
"SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
"SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
const selectThreePIDsForLocalpartSQL = "" +
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1"
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2"
const insertThreePIDSQL = "" +
"INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)"
"INSERT INTO userapi_threepids (threepid, medium, localpart, server_name) VALUES ($1, $2, $3, $4)"
const deleteThreePIDSQL = "" +
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
@ -75,19 +77,20 @@ func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
func (s *threepidStatements) SelectLocalpartForThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string,
) (localpart string, err error) {
) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName)
if err == sql.ErrNoRows {
return "", nil
return "", "", nil
}
return
}
func (s *threepidStatements) SelectThreePIDsForLocalpart(
ctx context.Context, localpart string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) (threepids []authtypes.ThreePID, err error) {
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName)
if err != nil {
return
}
@ -109,10 +112,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart(
}
func (s *threepidStatements) InsertThreePID(
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
ctx context.Context, txn *sql.Tx, threepid, medium,
localpart string, serverName gomatrixserverlib.ServerName,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
_, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName)
return
}

View File

@ -68,9 +68,10 @@ const (
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(
ctx context.Context, localpart, plaintextPassword string,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
plaintextPassword string,
) (*api.Account, error) {
hash, err := d.Accounts.SelectPasswordHash(ctx, localpart)
hash, err := d.Accounts.SelectPasswordHash(ctx, localpart, serverName)
if err != nil {
return nil, err
}
@ -80,24 +81,27 @@ func (d *Database) GetAccountByPassword(
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
return nil, err
}
return d.Accounts.SelectAccountByLocalpart(ctx, localpart)
return d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName)
}
// GetProfileByLocalpart returns the profile associated with the given localpart.
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart(
ctx context.Context, localpart string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) (*authtypes.Profile, error) {
return d.Profiles.SelectProfileByLocalpart(ctx, localpart)
return d.Profiles.SelectProfileByLocalpart(ctx, localpart, serverName)
}
// SetAvatarURL updates the avatar URL of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
avatarURL string,
) (profile *authtypes.Profile, changed bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, serverName, avatarURL)
return err
})
return
@ -106,10 +110,12 @@ func (d *Database) SetAvatarURL(
// SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
displayName string,
) (profile *authtypes.Profile, changed bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, serverName, displayName)
return err
})
return
@ -117,14 +123,15 @@ func (d *Database) SetDisplayName(
// SetPassword sets the account password to the given hash.
func (d *Database) SetPassword(
ctx context.Context, localpart, plaintextPassword string,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
plaintextPassword string,
) error {
hash, err := d.hashPassword(plaintextPassword)
if err != nil {
return err
}
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
return d.Accounts.UpdatePassword(ctx, localpart, hash)
return d.Accounts.UpdatePassword(ctx, localpart, serverName, hash)
})
}
@ -132,21 +139,22 @@ func (d *Database) SetPassword(
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists.
func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
plaintextPassword, appserviceID string, accountType api.AccountType,
) (acc *api.Account, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// For guest accounts, we create a new numeric local part
if accountType == api.AccountTypeGuest {
var numLocalpart int64
numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn)
numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn, serverName)
if err != nil {
return err
return fmt.Errorf("d.Accounts.SelectNewNumericLocalpart: %w", err)
}
localpart = strconv.FormatInt(numLocalpart, 10)
plaintextPassword = ""
appserviceID = ""
}
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType)
acc, err = d.createAccount(ctx, txn, localpart, serverName, plaintextPassword, appserviceID, accountType)
return err
})
return
@ -155,7 +163,9 @@ func (d *Database) CreateAccount(
// WARNING! This function assumes that the relevant mutexes have already
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
plaintextPassword, appserviceID string, accountType api.AccountType,
) (*api.Account, error) {
var err error
var account *api.Account
@ -167,28 +177,28 @@ func (d *Database) createAccount(
return nil, err
}
}
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, serverName, hash, appserviceID, accountType); err != nil {
return nil, sqlutil.ErrUserExists
}
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
return nil, err
if err = d.Profiles.InsertProfile(ctx, txn, localpart, serverName); err != nil {
return nil, fmt.Errorf("d.Profiles.InsertProfile: %w", err)
}
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
prbs, err := json.Marshal(pushRuleSets)
if err != nil {
return nil, err
return nil, fmt.Errorf("json.Marshal: %w", err)
}
if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
return nil, err
if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
return nil, fmt.Errorf("d.AccountDatas.InsertAccountData: %w", err)
}
return account, nil
}
func (d *Database) QueryPushRules(
ctx context.Context,
localpart string,
localpart string, serverName gomatrixserverlib.ServerName,
) (*pushrules.AccountRuleSets, error) {
data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, "", "m.push_rules")
data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, serverName, "", "m.push_rules")
if err != nil {
return nil, err
}
@ -196,13 +206,13 @@ func (d *Database) QueryPushRules(
// If we didn't find any default push rules then we should just generate some
// fresh ones.
if len(data) == 0 {
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
prbs, err := json.Marshal(pushRuleSets)
if err != nil {
return nil, fmt.Errorf("failed to marshal default push rules: %w", err)
}
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if dbErr := d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", prbs); dbErr != nil {
if dbErr := d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", prbs); dbErr != nil {
return fmt.Errorf("failed to save default push rules: %w", dbErr)
}
return nil
@ -225,22 +235,23 @@ func (d *Database) QueryPushRules(
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
roomID, dataType string, content json.RawMessage,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.AccountDatas.InsertAccountData(ctx, txn, localpart, roomID, dataType, content)
return d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, roomID, dataType, content)
})
}
// GetAccountData returns account data related to a given localpart
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
func (d *Database) GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (
global map[string]json.RawMessage,
rooms map[string]map[string]json.RawMessage,
err error,
) {
return d.AccountDatas.SelectAccountData(ctx, localpart)
return d.AccountDatas.SelectAccountData(ctx, localpart, serverName)
}
// GetAccountDataByType returns account data matching a given
@ -248,18 +259,19 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
// If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
roomID, dataType string,
) (data json.RawMessage, err error) {
return d.AccountDatas.SelectAccountDataByType(
ctx, localpart, roomID, dataType,
ctx, localpart, serverName, roomID, dataType,
)
}
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart(
ctx context.Context,
ctx context.Context, serverName gomatrixserverlib.ServerName,
) (int64, error) {
return d.Accounts.SelectNewNumericLocalpart(ctx, nil)
return d.Accounts.SelectNewNumericLocalpart(ctx, nil, serverName)
}
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
@ -276,10 +288,12 @@ var Err3PIDInUse = errors.New("this third-party identifier is already in use")
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database.
func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string,
ctx context.Context, threepid string,
localpart string, serverName gomatrixserverlib.ServerName,
medium string,
) (err error) {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
user, err := d.ThreePIDs.SelectLocalpartForThreePID(
user, _, err := d.ThreePIDs.SelectLocalpartForThreePID(
ctx, txn, threepid, medium,
)
if err != nil {
@ -290,7 +304,7 @@ func (d *Database) SaveThreePIDAssociation(
return Err3PIDInUse
}
return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart)
return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart, serverName)
})
}
@ -313,7 +327,7 @@ func (d *Database) RemoveThreePIDAssociation(
// Returns an error if there was a problem talking to the database.
func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium)
}
@ -322,16 +336,17 @@ func (d *Database) GetLocalpartForThreePID(
// If no association is known for this user, returns an empty slice.
// Returns an error if there was an issue talking to the database.
func (d *Database) GetThreePIDsForLocalpart(
ctx context.Context, localpart string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) (threepids []authtypes.ThreePID, err error) {
return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart)
return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart, serverName)
}
// CheckAccountAvailability checks if the username/localpart is already present
// in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
_, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart)
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error) {
_, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName)
if err == sql.ErrNoRows {
return true, nil
}
@ -341,12 +356,12 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin
// GetAccountByLocalpart returns the account associated with the given localpart.
// This function assumes the request is authenticated or the account data is used only internally.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (*api.Account, error) {
// try to get the account with lowercase localpart (majority)
acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart))
acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart), serverName)
if err == sql.ErrNoRows {
acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart) // try with localpart as passed by the request
acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName) // try with localpart as passed by the request
}
return acc, err
}
@ -359,20 +374,24 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi
}
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) {
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
return d.Accounts.DeactivateAccount(ctx, localpart)
return d.Accounts.DeactivateAccount(ctx, localpart, serverName)
})
}
// CreateOpenIDToken persists a new token that was issued for OpenID Connect
func (d *Database) CreateOpenIDToken(
ctx context.Context,
token, localpart string,
token, userID string,
) (int64, error) {
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return 0, nil
}
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, expiresAtMS)
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, domain, expiresAtMS)
})
return expiresAtMS, err
}
@ -539,16 +558,19 @@ func (d *Database) GetDeviceByAccessToken(
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
deviceID string,
) (*api.Device, error) {
return d.Devices.SelectDeviceByID(ctx, localpart, deviceID)
return d.Devices.SelectDeviceByID(ctx, localpart, serverName, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) ([]api.Device, error) {
return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, "")
return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, serverName, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
@ -562,18 +584,18 @@ func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]ap
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string, ipAddr, userAgent string,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart); err != nil {
if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil {
return err
}
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
return err
})
} else {
@ -588,7 +610,7 @@ func (d *Database) CreateDevice(
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error
dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
return err
})
if returnErr == nil {
@ -614,10 +636,12 @@ func generateDeviceID() (string, error) {
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
deviceID string, displayName *string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Devices.UpdateDeviceName(ctx, txn, localpart, deviceID, displayName)
return d.Devices.UpdateDeviceName(ctx, txn, localpart, serverName, deviceID, displayName)
})
}
@ -626,10 +650,12 @@ func (d *Database) UpdateDevice(
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
devices []string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.Devices.DeleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
if err := d.Devices.DeleteDevices(ctx, txn, localpart, serverName, devices); err != sql.ErrNoRows {
return err
}
return nil
@ -640,14 +666,16 @@ func (d *Database) RemoveDevices(
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart, exceptDeviceID string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
exceptDeviceID string,
) (devices []api.Device, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID)
if err != nil {
return err
}
if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil
@ -656,9 +684,9 @@ func (d *Database) RemoveAllDevices(
}
// UpdateDeviceLastSeen updates a last seen timestamp and the ip address.
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error {
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr, userAgent)
return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, serverName, deviceID, ipAddr, userAgent)
})
}
@ -706,38 +734,38 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (
return d.LoginTokens.SelectLoginToken(ctx, token)
}
func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error {
func (d *Database) InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
return d.Notifications.Insert(ctx, txn, localpart, serverName, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
})
}
func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) {
func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos)
affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, serverName, roomID, pos)
return err
})
return
}
func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, b bool) (affected bool, err error) {
func (d *Database) SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, b bool) (affected bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b)
affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, serverName, roomID, pos, b)
return err
})
return
}
func (d *Database) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
return d.Notifications.Select(ctx, nil, localpart, fromID, limit, filter)
func (d *Database) GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
return d.Notifications.Select(ctx, nil, localpart, serverName, fromID, limit, filter)
}
func (d *Database) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) {
return d.Notifications.SelectCount(ctx, nil, localpart, filter)
func (d *Database) GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error) {
return d.Notifications.SelectCount(ctx, nil, localpart, serverName, filter)
}
func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) {
return d.Notifications.SelectRoomCounts(ctx, nil, localpart, roomID)
func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error) {
return d.Notifications.SelectRoomCounts(ctx, nil, localpart, serverName, roomID)
}
func (d *Database) DeleteOldNotifications(ctx context.Context) error {
@ -747,7 +775,8 @@ func (d *Database) DeleteOldNotifications(ctx context.Context) error {
}
func (d *Database) UpsertPusher(
ctx context.Context, p api.Pusher, localpart string,
ctx context.Context, p api.Pusher,
localpart string, serverName gomatrixserverlib.ServerName,
) error {
data, err := json.Marshal(p.Data)
if err != nil {
@ -766,25 +795,26 @@ func (d *Database) UpsertPusher(
p.ProfileTag,
p.Language,
string(data),
localpart)
localpart,
serverName)
})
}
// GetPushers returns the pushers matching the given localpart.
func (d *Database) GetPushers(
ctx context.Context, localpart string,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) ([]api.Pusher, error) {
return d.Pushers.SelectPushers(ctx, nil, localpart)
return d.Pushers.SelectPushers(ctx, nil, localpart, serverName)
}
// RemovePusher deletes one pusher
// Invoked when `append` is true and `kind` is null in
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set
func (d *Database) RemovePusher(
ctx context.Context, appid, pushkey, localpart string,
ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart)
err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart, serverName)
if err == sql.ErrNoRows {
return nil
}

View File

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
const accountDataSchema = `
@ -28,27 +29,28 @@ const accountDataSchema = `
CREATE TABLE IF NOT EXISTS userapi_account_datas (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
-- The room ID for this data (empty string if not specific to a room)
room_id TEXT,
-- The account data type
type TEXT NOT NULL,
-- The account data content
content TEXT NOT NULL,
PRIMARY KEY(localpart, room_id, type)
content TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS userapi_account_datas_idx ON userapi_account_datas(localpart, server_name, room_id, type);
`
const insertAccountDataSQL = `
INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5)
ON CONFLICT (localpart, server_name, room_id, type) DO UPDATE SET content = $5
`
const selectAccountDataSQL = "" +
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1"
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2"
const selectAccountDataByTypeSQL = "" +
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3"
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND type = $4"
type accountDataStatements struct {
db *sql.DB
@ -73,20 +75,23 @@ func NewSQLiteAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
}
func (s *accountDataStatements) InsertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
roomID, dataType string, content json.RawMessage,
) error {
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, serverName, roomID, dataType, content)
return err
}
func (s *accountDataStatements) SelectAccountData(
ctx context.Context, localpart string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) (
/* global */ map[string]json.RawMessage,
/* rooms */ map[string]map[string]json.RawMessage,
error,
) {
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart, serverName)
if err != nil {
return nil, nil, err
}
@ -117,11 +122,13 @@ func (s *accountDataStatements) SelectAccountData(
}
func (s *accountDataStatements) SelectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
roomID, dataType string,
) (data json.RawMessage, err error) {
var bytes []byte
stmt := s.selectAccountDataByTypeStmt
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
if err = stmt.QueryRowContext(ctx, localpart, serverName, roomID, dataType).Scan(&bytes); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}

View File

@ -34,7 +34,8 @@ const accountsSchema = `
-- Stores data about accounts.
CREATE TABLE IF NOT EXISTS userapi_accounts (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY,
localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
-- When this account was first created, as a unix timestamp (ms resolution).
created_ts BIGINT NOT NULL,
-- The password hash for this account. Can be NULL if this is a passwordless account.
@ -48,25 +49,27 @@ CREATE TABLE IF NOT EXISTS userapi_accounts (
-- TODO:
-- upgraded_ts, devices, any email reset stuff?
);
CREATE UNIQUE INDEX IF NOT EXISTS userapi_accounts_idx ON userapi_accounts(localpart, server_name);
`
const insertAccountSQL = "" +
"INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
"INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)"
const updatePasswordSQL = "" +
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3"
const deactivateAccountSQL = "" +
"UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1"
"UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1 AND server_name = $2"
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
"SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
const selectPasswordHashSQL = "" +
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = 0"
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = 0"
const selectNewNumericLocalpartSQL = "" +
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0"
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0 AND server_name = $1"
type accountsStatements struct {
db *sql.DB
@ -119,16 +122,17 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
// this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success.
func (s *accountsStatements) InsertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName,
hash, appserviceID string, accountType api.AccountType,
) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt
var err error
if accountType != api.AccountTypeAppService {
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType)
} else {
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
}
if err != nil {
return nil, err
@ -136,42 +140,43 @@ func (s *accountsStatements) InsertAccount(
return &api.Account{
Localpart: localpart,
UserID: userutil.MakeUserID(localpart, s.serverName),
ServerName: s.serverName,
UserID: userutil.MakeUserID(localpart, serverName),
ServerName: serverName,
AppServiceID: appserviceID,
AccountType: accountType,
}, nil
}
func (s *accountsStatements) UpdatePassword(
ctx context.Context, localpart, passwordHash string,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
passwordHash string,
) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName)
return
}
func (s *accountsStatements) DeactivateAccount(
ctx context.Context, localpart string,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (err error) {
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName)
return
}
func (s *accountsStatements) SelectPasswordHash(
ctx context.Context, localpart string,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash)
return
}
func (s *accountsStatements) SelectAccountByLocalpart(
ctx context.Context, localpart string,
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (*api.Account, error) {
var appserviceIDPtr sql.NullString
var acc api.Account
stmt := s.selectAccountByLocalpartStmt
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType)
err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve user from the db")
@ -182,20 +187,18 @@ func (s *accountsStatements) SelectAccountByLocalpart(
acc.AppServiceID = appserviceIDPtr.String
}
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
acc.ServerName = s.serverName
acc.UserID = userutil.MakeUserID(acc.Localpart, acc.ServerName)
return &acc, nil
}
func (s *accountsStatements) SelectNewNumericLocalpart(
ctx context.Context, txn *sql.Tx,
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt
if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt)
}
err = stmt.QueryRowContext(ctx).Scan(&id)
err = stmt.QueryRowContext(ctx, serverName).Scan(&id)
if err == sql.ErrNoRows {
return 1, nil
}

View File

@ -11,6 +11,7 @@ func UpIsActive(ctx context.Context, tx *sql.Tx) error {
ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
CREATE TABLE userapi_accounts (
localpart TEXT NOT NULL PRIMARY KEY,
server_name TEXT NOT NULL,
created_ts BIGINT NOT NULL,
password_hash TEXT,
appservice_id TEXT,

View File

@ -14,6 +14,7 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
session_id INTEGER,
device_id TEXT ,
localpart TEXT ,
server_name TEXT NOT NULL,
created_ts BIGINT,
display_name TEXT,
last_seen_ts BIGINT,

View File

@ -12,6 +12,7 @@ func UpAddAccountType(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
CREATE TABLE userapi_accounts (
localpart TEXT NOT NULL PRIMARY KEY,
server_name TEXT NOT NULL,
created_ts BIGINT NOT NULL,
password_hash TEXT,
appservice_id TEXT,

View File

@ -0,0 +1,108 @@
package deltas
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/lib/pq"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
var serverNamesTables = []string{
"userapi_accounts",
"userapi_account_datas",
"userapi_devices",
"userapi_notifications",
"userapi_openid_tokens",
"userapi_profiles",
"userapi_pushers",
"userapi_threepids",
}
// These tables have a PRIMARY KEY constraint which we need to drop so
// that we can recreate a new unique index that contains the server name.
var serverNamesDropPK = []string{
"userapi_accounts",
"userapi_account_datas",
"userapi_profiles",
}
// These indices are out of date so let's drop them. They will get recreated
// automatically.
var serverNamesDropIndex = []string{
"userapi_pusher_localpart_idx",
"userapi_pusher_app_id_pushkey_localpart_idx",
}
// I know what you're thinking: you're wondering "why doesn't this use $1
// and pass variadic parameters to ExecContext?" — the answer is because
// PostgreSQL doesn't expect the table name to be specified as a substituted
// argument in that way so it results in a syntax error in the query.
func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
for _, table := range serverNamesTables {
q := fmt.Sprintf(
"SELECT COUNT(name) FROM sqlite_schema WHERE type='table' AND name=%s;",
pq.QuoteIdentifier(table),
)
var c int
if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 0 {
continue
}
q = fmt.Sprintf(
"SELECT COUNT(*) FROM pragma_table_info(%s) WHERE name='server_name'",
pq.QuoteIdentifier(table),
)
if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 1 {
logrus.Infof("Table %s already has column, skipping", table)
continue
}
if c == 0 {
q = fmt.Sprintf(
"ALTER TABLE %s ADD COLUMN server_name TEXT NOT NULL DEFAULT '';",
pq.QuoteIdentifier(table),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("add server name to %q error: %w", table, err)
}
}
}
for _, table := range serverNamesDropPK {
q := fmt.Sprintf(
"SELECT COUNT(name), sql FROM sqlite_schema WHERE type='table' AND name=%s;",
pq.QuoteIdentifier(table),
)
var c int
var sql string
if err := tx.QueryRowContext(ctx, q).Scan(&c, &sql); err != nil || c == 0 {
continue
}
q = fmt.Sprintf(`
%s; -- create temporary table
INSERT INTO %s SELECT * FROM %s; -- copy data
DROP TABLE %s; -- drop original table
ALTER TABLE %s RENAME TO %s; -- rename new table
`,
strings.Replace(sql, table, table+"_tmp", 1), // create temporary table
table+"_tmp", table, // copy data
table, // drop original table
table+"_tmp", table, // rename new table
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("drop PK from %q error: %w", table, err)
}
}
for _, index := range serverNamesDropIndex {
q := fmt.Sprintf(
"DROP INDEX IF EXISTS %s;",
pq.QuoteIdentifier(index),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("drop index %q error: %w", index, err)
}
}
return nil
}

View File

@ -0,0 +1,28 @@
package deltas
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/matrix-org/gomatrixserverlib"
)
// I know what you're thinking: you're wondering "why doesn't this use $1
// and pass variadic parameters to ExecContext?" — the answer is because
// PostgreSQL doesn't expect the table name to be specified as a substituted
// argument in that way so it results in a syntax error in the query.
func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
for _, table := range serverNamesTables {
q := fmt.Sprintf(
"UPDATE %s SET server_name = %s WHERE server_name = '';",
pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("write server names to %q error: %w", table, err)
}
}
return nil
}

View File

@ -40,49 +40,50 @@ CREATE TABLE IF NOT EXISTS userapi_devices (
session_id INTEGER,
device_id TEXT ,
localpart TEXT ,
server_name TEXT NOT NULL,
created_ts BIGINT,
display_name TEXT,
last_seen_ts BIGINT,
ip TEXT,
user_agent TEXT,
UNIQUE (localpart, device_id)
UNIQUE (localpart, server_name, device_id)
);
`
const insertDeviceSQL = "" +
"INSERT INTO userapi_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
"INSERT INTO userapi_devices (device_id, localpart, server_name, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"
const selectDevicesCountSQL = "" +
"SELECT COUNT(access_token) FROM userapi_devices"
const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1"
"SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" +
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2"
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3"
const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC"
const updateDeviceNameSQL = "" +
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4"
const deleteDeviceSQL = "" +
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2"
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3"
const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3"
const deleteDevicesSQL = "" +
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id IN ($2)"
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id IN ($3)"
const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
"SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
const updateDeviceLastSeen = "" +
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6"
type devicesStatements struct {
db *sql.DB
@ -135,8 +136,9 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
// Returns an error if the user already has a device with the given device ID.
// Returns the device on success.
func (s *devicesStatements) InsertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string, ipAddr, userAgent string,
ctx context.Context, txn *sql.Tx, id string,
localpart string, serverName gomatrixserverlib.ServerName,
accessToken string, displayName *string, ipAddr, userAgent string,
) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
@ -146,12 +148,12 @@ func (s *devicesStatements) InsertDevice(
return nil, err
}
sessionID++
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
return nil, err
}
return &api.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, s.serverName),
UserID: userutil.MakeUserID(localpart, serverName),
AccessToken: accessToken,
SessionID: sessionID,
LastSeenTS: createdTimeMS,
@ -161,44 +163,52 @@ func (s *devicesStatements) InsertDevice(
}
func (s *devicesStatements) DeleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string,
ctx context.Context, txn *sql.Tx, id string,
localpart string, serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
_, err := stmt.ExecContext(ctx, id, localpart)
_, err := stmt.ExecContext(ctx, id, localpart, serverName)
return err
}
func (s *devicesStatements) DeleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
devices []string,
) error {
orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1)
orig := strings.Replace(deleteDevicesSQL, "($3)", sqlutil.QueryVariadicOffset(len(devices), 2), 1)
prep, err := s.db.Prepare(orig)
if err != nil {
return err
}
stmt := sqlutil.TxStmt(txn, prep)
params := make([]interface{}, len(devices)+1)
params := make([]interface{}, len(devices)+2)
params[0] = localpart
params[1] = serverName
for i, v := range devices {
params[i+1] = v
params[i+2] = v
}
_, err = stmt.ExecContext(ctx, params...)
return err
}
func (s *devicesStatements) DeleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
exceptDeviceID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
_, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID)
return err
}
func (s *devicesStatements) UpdateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
deviceID string, displayName *string,
) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
_, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID)
return err
}
@ -207,10 +217,11 @@ func (s *devicesStatements) SelectDeviceByToken(
) (*api.Device, error) {
var dev api.Device
var localpart string
var serverName gomatrixserverlib.ServerName
stmt := s.selectDeviceByTokenStmt
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName)
if err == nil {
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.UserID = userutil.MakeUserID(localpart, serverName)
dev.AccessToken = accessToken
}
return &dev, err
@ -219,16 +230,18 @@ func (s *devicesStatements) SelectDeviceByToken(
// selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID
func (s *devicesStatements) SelectDeviceByID(
ctx context.Context, localpart, deviceID string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
deviceID string,
) (*api.Device, error) {
var dev api.Device
var displayName, ip sql.NullString
stmt := s.selectDeviceByIDStmt
var lastseenTS sql.NullInt64
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip)
err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip)
if err == nil {
dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.UserID = userutil.MakeUserID(localpart, serverName)
if displayName.Valid {
dev.DisplayName = displayName.String
}
@ -243,10 +256,12 @@ func (s *devicesStatements) SelectDeviceByID(
}
func (s *devicesStatements) SelectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
exceptDeviceID string,
) ([]api.Device, error) {
devices := []api.Device{}
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID)
if err != nil {
return devices, err
@ -276,7 +291,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
dev.UserAgent = useragent.String
}
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.UserID = userutil.MakeUserID(localpart, serverName)
devices = append(devices, dev)
}
@ -298,10 +313,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
var devices []api.Device
var dev api.Device
var localpart string
var serverName gomatrixserverlib.ServerName
var displayName sql.NullString
var lastseents sql.NullInt64
for rows.Next() {
if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil {
return nil, err
}
if displayName.Valid {
@ -310,15 +326,15 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
if lastseents.Valid {
dev.LastSeenTS = lastseents.Int64
}
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.UserID = userutil.MakeUserID(localpart, serverName)
devices = append(devices, dev)
}
return devices, rows.Err()
}
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error {
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
lastSeenTs := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID)
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID)
return err
}

View File

@ -43,6 +43,7 @@ const notificationSchema = `
CREATE TABLE IF NOT EXISTS userapi_notifications (
id INTEGER PRIMARY KEY AUTOINCREMENT,
localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
room_id TEXT NOT NULL,
event_id TEXT NOT NULL,
stream_pos BIGINT NOT NULL,
@ -52,33 +53,33 @@ CREATE TABLE IF NOT EXISTS userapi_notifications (
read BOOLEAN NOT NULL DEFAULT FALSE
);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, server_name, room_id, event_id);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, server_name, room_id, id);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, server_name, id);
`
const insertNotificationSQL = "" +
"INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
"INSERT INTO userapi_notifications (localpart, server_name, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
const deleteNotificationsUpToSQL = "" +
"DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
"DELETE FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND stream_pos <= $4"
const updateNotificationReadSQL = "" +
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND server_name = $3 AND room_id = $4 AND stream_pos <= $5 AND read <> $1"
const selectNotificationSQL = "" +
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
") AND NOT read ORDER BY localpart, id LIMIT $4"
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND id > $3 AND (" +
"(($4 & 1) <> 0 AND highlight) OR (($4 & 2) <> 0 AND NOT highlight)" +
") AND NOT read ORDER BY localpart, id LIMIT $5"
const selectNotificationCountSQL = "" +
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
"(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND (" +
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
") AND NOT read"
const selectRoomNotificationCountsSQL = "" +
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
"WHERE localpart = $1 AND room_id = $2 AND NOT read"
"WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND NOT read"
const cleanNotificationsSQL = "" +
"DELETE FROM userapi_notifications WHERE" +
@ -111,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
}
// Insert inserts a notification into the database.
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error {
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error {
roomID, tsMS := n.RoomID, n.TS
nn := *n
// Clears out fields that have their own columns to (1) shrink the
@ -122,13 +123,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
if err != nil {
return err
}
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, serverName, roomID, eventID, pos, tsMS, highlight, string(bs))
return err
}
// DeleteUpTo deletes all previous notifications, up to and including the event.
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos)
if err != nil {
return false, err
}
@ -141,8 +142,8 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
}
// UpdateRead updates the "read" value for an event.
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos)
if err != nil {
return false, err
}
@ -154,8 +155,8 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l
return nrows > 0, nil
}
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit)
if err != nil {
return nil, 0, err
@ -197,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local
return notifs, maxID, rows.Err()
}
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) {
err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count)
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) {
err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count)
return
}
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) {
err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight)
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) {
err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight)
return
}

View File

@ -3,6 +3,7 @@ package sqlite3
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
@ -18,16 +19,17 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
token TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID for this account
localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
-- When the token expires, as a unix timestamp (ms resolution).
token_expires_at_ms BIGINT NOT NULL
);
`
const insertOpenIDTokenSQL = "" +
"INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
"INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)"
const selectOpenIDTokenSQL = "" +
"SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
"SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
type openIDTokenStatements struct {
db *sql.DB
@ -56,11 +58,11 @@ func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (
func (s *openIDTokenStatements) InsertOpenIDToken(
ctx context.Context,
txn *sql.Tx,
token, localpart string,
token, localpart string, serverName gomatrixserverlib.ServerName,
expiresAtMS int64,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
_, err = stmt.ExecContext(ctx, token, localpart, serverName, expiresAtMS)
return
}
@ -71,10 +73,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
token string,
) (*api.OpenIDTokenAttributes, error) {
var openIDTokenAttrs api.OpenIDTokenAttributes
var localpart string
var serverName gomatrixserverlib.ServerName
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
&openIDTokenAttrs.UserID,
&localpart, &serverName,
&openIDTokenAttrs.ExpiresAtMS,
)
openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve token from the db")

View File

@ -23,36 +23,40 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
const profilesSchema = `
-- Stores data about accounts profiles.
CREATE TABLE IF NOT EXISTS userapi_profiles (
-- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY,
localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
-- The display name for this account
display_name TEXT,
-- The URL of the avatar for this account
avatar_url TEXT
);
CREATE UNIQUE INDEX IF NOT EXISTS userapi_profiles_idx ON userapi_profiles(localpart, server_name);
`
const insertProfileSQL = "" +
"INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
"INSERT INTO userapi_profiles(localpart, server_name, display_name, avatar_url) VALUES ($1, $2, $3, $4)"
const selectProfileByLocalpartSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
"SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1 AND server_name = $2"
const setAvatarURLSQL = "" +
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" +
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2 AND server_name = $3" +
" RETURNING display_name"
const setDisplayNameSQL = "" +
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" +
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2 AND server_name = $3" +
" RETURNING avatar_url"
const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
"SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
type profilesStatements struct {
db *sql.DB
@ -83,18 +87,20 @@ func NewSQLiteProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables.P
}
func (s *profilesStatements) InsertProfile(
ctx context.Context, txn *sql.Tx, localpart string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
) error {
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "")
return err
}
func (s *profilesStatements) SelectProfileByLocalpart(
ctx context.Context, localpart string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) (*authtypes.Profile, error) {
var profile authtypes.Profile
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan(
&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL,
)
if err != nil {
return nil, err
@ -103,13 +109,16 @@ func (s *profilesStatements) SelectProfileByLocalpart(
}
func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
avatarURL string,
) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{
Localpart: localpart,
ServerName: string(serverName),
AvatarURL: avatarURL,
}
old, err := s.SelectProfileByLocalpart(ctx, localpart)
old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName)
if err != nil {
return old, false, err
}
@ -117,18 +126,21 @@ func (s *profilesStatements) SetAvatarURL(
return old, false, nil
}
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
err = stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName)
err = stmt.QueryRowContext(ctx, avatarURL, localpart, serverName).Scan(&profile.DisplayName)
return profile, true, err
}
func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
displayName string,
) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{
Localpart: localpart,
ServerName: string(serverName),
DisplayName: displayName,
}
old, err := s.SelectProfileByLocalpart(ctx, localpart)
old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName)
if err != nil {
return old, false, err
}
@ -136,7 +148,7 @@ func (s *profilesStatements) SetDisplayName(
return old, false, nil
}
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
err = stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL)
err = stmt.QueryRowContext(ctx, displayName, localpart, serverName).Scan(&profile.AvatarURL)
return profile, true, err
}
@ -154,7 +166,7 @@ func (s *profilesStatements) SelectProfilesBySearch(
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
for rows.Next() {
var profile authtypes.Profile
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
if err := rows.Scan(&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL); err != nil {
return nil, err
}
if profile.Localpart != s.serverNoticesLocalpart {

View File

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
@ -33,6 +34,7 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
-- The Matrix user ID localpart for this pusher
localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
session_id BIGINT DEFAULT NULL,
profile_tag TEXT,
kind TEXT NOT NULL,
@ -49,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
-- For faster retrieving by localpart.
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name);
-- Pushkey must be unique for a given user and app.
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name);
`
const insertPusherSQL = "" +
"INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
"ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
"INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" +
"ON CONFLICT (app_id, pushkey, localpart, server_name) DO UPDATE SET session_id = $3, pushkey_ts_ms = $5, kind = $6, app_display_name = $8, device_display_name = $9, profile_tag = $10, lang = $11, data = $12"
const selectPushersSQL = "" +
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2"
const deletePusherSQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4"
const deletePushersByAppIdAndPushKeySQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
@ -95,18 +97,19 @@ type pushersStatements struct {
// Returns nil error success.
func (s *pushersStatements) InsertPusher(
ctx context.Context, txn *sql.Tx, session_id int64,
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data,
localpart string, serverName gomatrixserverlib.ServerName,
) error {
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
logrus.Debugf("Created pusher %d", session_id)
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
return err
}
func (s *pushersStatements) SelectPushers(
ctx context.Context, txn *sql.Tx, localpart string,
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
) ([]api.Pusher, error) {
pushers := []api.Pusher{}
rows, err := s.selectPushersStmt.QueryContext(ctx, localpart)
rows, err := s.selectPushersStmt.QueryContext(ctx, localpart, serverName)
if err != nil {
return pushers, err
@ -143,9 +146,10 @@ func (s *pushersStatements) SelectPushers(
// deletePusher removes a single pusher by pushkey and user localpart.
func (s *pushersStatements) DeletePusher(
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
ctx context.Context, txn *sql.Tx, appid, pushkey,
localpart string, serverName gomatrixserverlib.ServerName,
) error {
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart)
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName)
return err
}

View File

@ -15,6 +15,8 @@
package sqlite3
import (
"context"
"database/sql"
"fmt"
"time"
@ -41,18 +43,24 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
Up: deltas.UpRenameTables,
Down: deltas.DownRenameTables,
})
m.AddMigrations(sqlutil.Migration{
Version: "userapi: server names",
Up: func(ctx context.Context, txn *sql.Tx) error {
return deltas.UpServerNames(ctx, txn, serverName)
},
})
if err = m.Up(base.Context()); err != nil {
return nil, err
}
accountDataTable, err := NewSQLiteAccountDataTable(db)
if err != nil {
return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)
}
accountsTable, err := NewSQLiteAccountsTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err)
}
accountDataTable, err := NewSQLiteAccountDataTable(db)
if err != nil {
return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)
}
devicesTable, err := NewSQLiteDevicesTable(db, serverName)
if err != nil {
return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err)
@ -93,6 +101,18 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil {
return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err)
}
m = sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "userapi: server names populate",
Up: func(ctx context.Context, txn *sql.Tx) error {
return deltas.UpServerNamesPopulate(ctx, txn, serverName)
},
})
if err = m.Up(base.Context()); err != nil {
return nil, err
}
return &shared.Database{
AccountDatas: accountDataTable,
Accounts: accountsTable,

View File

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
@ -34,21 +35,22 @@ CREATE TABLE IF NOT EXISTS userapi_threepids (
medium TEXT NOT NULL DEFAULT 'email',
-- The localpart of the Matrix user ID associated to this 3PID
localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
PRIMARY KEY(threepid, medium)
);
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart);
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart, server_name);
`
const selectLocalpartForThreePIDSQL = "" +
"SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
"SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
const selectThreePIDsForLocalpartSQL = "" +
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1"
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2"
const insertThreePIDSQL = "" +
"INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)"
"INSERT INTO userapi_threepids (threepid, medium, localpart, server_name) VALUES ($1, $2, $3, $4)"
const deleteThreePIDSQL = "" +
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
@ -79,19 +81,20 @@ func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
func (s *threepidStatements) SelectLocalpartForThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string,
) (localpart string, err error) {
) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName)
if err == sql.ErrNoRows {
return "", nil
return "", "", nil
}
return
}
func (s *threepidStatements) SelectThreePIDsForLocalpart(
ctx context.Context, localpart string,
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) (threepids []authtypes.ThreePID, err error) {
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName)
if err != nil {
return
}
@ -113,10 +116,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart(
}
func (s *threepidStatements) InsertThreePID(
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
ctx context.Context, txn *sql.Tx, threepid, medium,
localpart string, serverName gomatrixserverlib.ServerName,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
_, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName)
return err
}

View File

@ -50,25 +50,25 @@ func Test_AccountData(t *testing.T) {
db, close := mustCreateDatabase(t, dbType)
defer close()
alice := test.NewUser(t)
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
room := test.NewRoom(t, alice)
events := room.Events()
contentRoom := json.RawMessage(fmt.Sprintf(`{"event_id":"%s"}`, events[len(events)-1].EventID()))
err = db.SaveAccountData(ctx, localpart, room.ID, "m.fully_read", contentRoom)
err = db.SaveAccountData(ctx, localpart, domain, room.ID, "m.fully_read", contentRoom)
assert.NoError(t, err, "unable to save account data")
contentGlobal := json.RawMessage(fmt.Sprintf(`{"recent_rooms":["%s"]}`, room.ID))
err = db.SaveAccountData(ctx, localpart, "", "im.vector.setting.breadcrumbs", contentGlobal)
err = db.SaveAccountData(ctx, localpart, domain, "", "im.vector.setting.breadcrumbs", contentGlobal)
assert.NoError(t, err, "unable to save account data")
accountData, err := db.GetAccountDataByType(ctx, localpart, room.ID, "m.fully_read")
accountData, err := db.GetAccountDataByType(ctx, localpart, domain, room.ID, "m.fully_read")
assert.NoError(t, err, "unable to get account data by type")
assert.Equal(t, contentRoom, accountData)
globalData, roomData, err := db.GetAccountData(ctx, localpart)
globalData, roomData, err := db.GetAccountData(ctx, localpart, domain)
assert.NoError(t, err)
assert.Equal(t, contentRoom, roomData[room.ID]["m.fully_read"])
assert.Equal(t, contentGlobal, globalData["im.vector.setting.breadcrumbs"])
@ -81,78 +81,78 @@ func Test_Accounts(t *testing.T) {
db, close := mustCreateDatabase(t, dbType)
defer close()
alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin)
assert.NoError(t, err, "failed to create account")
// verify the newly create account is the same as returned by CreateAccount
var accGet *api.Account
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "testing")
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "testing")
assert.NoError(t, err, "failed to get account by password")
assert.Equal(t, accAlice, accGet)
accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart)
accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "failed to get account by localpart")
assert.Equal(t, accAlice, accGet)
// check account availability
available, err := db.CheckAccountAvailability(ctx, aliceLocalpart)
available, err := db.CheckAccountAvailability(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "failed to checkout account availability")
assert.Equal(t, false, available)
available, err = db.CheckAccountAvailability(ctx, "unusedname")
available, err = db.CheckAccountAvailability(ctx, "unusedname", aliceDomain)
assert.NoError(t, err, "failed to checkout account availability")
assert.Equal(t, true, available)
// get guest account numeric aliceLocalpart
first, err := db.GetNewNumericLocalpart(ctx)
first, err := db.GetNewNumericLocalpart(ctx, aliceDomain)
assert.NoError(t, err, "failed to get new numeric localpart")
// Create a new account to verify the numeric localpart is updated
_, err = db.CreateAccount(ctx, "", "testing", "", api.AccountTypeGuest)
_, err = db.CreateAccount(ctx, "", aliceDomain, "testing", "", api.AccountTypeGuest)
assert.NoError(t, err, "failed to create account")
second, err := db.GetNewNumericLocalpart(ctx)
second, err := db.GetNewNumericLocalpart(ctx, aliceDomain)
assert.NoError(t, err)
assert.Greater(t, second, first)
// update password for alice
err = db.SetPassword(ctx, aliceLocalpart, "newPassword")
err = db.SetPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
assert.NoError(t, err, "failed to update password")
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
assert.NoError(t, err, "failed to get account by new password")
assert.Equal(t, accAlice, accGet)
// deactivate account
err = db.DeactivateAccount(ctx, aliceLocalpart)
err = db.DeactivateAccount(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "failed to deactivate account")
// This should fail now, as the account is deactivated
_, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
_, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
assert.Error(t, err, "expected an error, got none")
_, err = db.GetAccountByLocalpart(ctx, "unusename")
_, err = db.GetAccountByLocalpart(ctx, "unusename", aliceDomain)
assert.Error(t, err, "expected an error for non existent localpart")
// create an empty localpart; this should never happen, but is required to test getting a numeric localpart
// if there's already a user without a localpart in the database
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeUser)
_, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeUser)
assert.NoError(t, err)
// test getting a numeric localpart, with an existing user without a localpart
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest)
_, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeGuest)
assert.NoError(t, err)
// Create a user with a high numeric localpart, out of range for the Postgres integer (2147483647) type
_, err = db.CreateAccount(ctx, "2147483650", "", "", api.AccountTypeUser)
_, err = db.CreateAccount(ctx, "2147483650", aliceDomain, "", "", api.AccountTypeUser)
assert.NoError(t, err)
// Now try to create a new guest user
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest)
_, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeGuest)
assert.NoError(t, err)
})
}
func Test_Devices(t *testing.T) {
alice := test.NewUser(t)
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
deviceID := util.RandomString(8)
accessToken := util.RandomString(16)
@ -161,10 +161,10 @@ func Test_Devices(t *testing.T) {
db, close := mustCreateDatabase(t, dbType)
defer close()
deviceWithID, err := db.CreateDevice(ctx, localpart, &deviceID, accessToken, nil, "", "")
deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "")
assert.NoError(t, err, "unable to create deviceWithoutID")
gotDevice, err := db.GetDeviceByID(ctx, localpart, deviceID)
gotDevice, err := db.GetDeviceByID(ctx, localpart, domain, deviceID)
assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, deviceWithID.ID, gotDevice.ID) // GetDeviceByID doesn't populate all fields
@ -174,14 +174,14 @@ func Test_Devices(t *testing.T) {
// create a device without existing device ID
accessToken = util.RandomString(16)
deviceWithoutID, err := db.CreateDevice(ctx, localpart, nil, accessToken, nil, "", "")
deviceWithoutID, err := db.CreateDevice(ctx, localpart, domain, nil, accessToken, nil, "", "")
assert.NoError(t, err, "unable to create deviceWithoutID")
gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, deviceWithoutID.ID)
gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, domain, deviceWithoutID.ID)
assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, deviceWithoutID.ID, gotDeviceWithoutID.ID) // GetDeviceByID doesn't populate all fields
// Get devices
devices, err := db.GetDevicesByLocalpart(ctx, localpart)
devices, err := db.GetDevicesByLocalpart(ctx, localpart, domain)
assert.NoError(t, err, "unable to get devices by localpart")
assert.Equal(t, 2, len(devices))
deviceIDs := make([]string, 0, len(devices))
@ -195,15 +195,15 @@ func Test_Devices(t *testing.T) {
// Update device
newName := "new display name"
err = db.UpdateDevice(ctx, localpart, deviceWithID.ID, &newName)
err = db.UpdateDevice(ctx, localpart, domain, deviceWithID.ID, &newName)
assert.NoError(t, err, "unable to update device displayname")
updatedAfterTimestamp := time.Now().Unix()
err = db.UpdateDeviceLastSeen(ctx, localpart, deviceWithID.ID, "127.0.0.1", "Element Web")
err = db.UpdateDeviceLastSeen(ctx, localpart, domain, deviceWithID.ID, "127.0.0.1", "Element Web")
assert.NoError(t, err, "unable to update device last seen")
deviceWithID.DisplayName = newName
deviceWithID.LastSeenIP = "127.0.0.1"
gotDevice, err = db.GetDeviceByID(ctx, localpart, deviceWithID.ID)
gotDevice, err = db.GetDeviceByID(ctx, localpart, domain, deviceWithID.ID)
assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, 2, len(devices))
assert.Equal(t, deviceWithID.DisplayName, gotDevice.DisplayName)
@ -213,20 +213,20 @@ func Test_Devices(t *testing.T) {
// create one more device and remove the devices step by step
newDeviceID := util.RandomString(16)
accessToken = util.RandomString(16)
_, err = db.CreateDevice(ctx, localpart, &newDeviceID, accessToken, nil, "", "")
_, err = db.CreateDevice(ctx, localpart, domain, &newDeviceID, accessToken, nil, "", "")
assert.NoError(t, err, "unable to create new device")
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain)
assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, 3, len(devices))
err = db.RemoveDevices(ctx, localpart, deviceIDs)
err = db.RemoveDevices(ctx, localpart, domain, deviceIDs)
assert.NoError(t, err, "unable to remove devices")
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain)
assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, 1, len(devices))
deleted, err := db.RemoveAllDevices(ctx, localpart, "")
deleted, err := db.RemoveAllDevices(ctx, localpart, domain, "")
assert.NoError(t, err, "unable to remove all devices")
assert.Equal(t, 1, len(deleted))
assert.Equal(t, newDeviceID, deleted[0].ID)
@ -364,7 +364,7 @@ func Test_OpenID(t *testing.T) {
func Test_Profile(t *testing.T) {
alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@ -372,30 +372,33 @@ func Test_Profile(t *testing.T) {
defer close()
// create account, which also creates a profile
_, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
_, err = db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin)
assert.NoError(t, err, "failed to create account")
gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart)
gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get profile by localpart")
wantProfile := &authtypes.Profile{Localpart: aliceLocalpart}
wantProfile := &authtypes.Profile{
Localpart: aliceLocalpart,
ServerName: string(aliceDomain),
}
assert.Equal(t, wantProfile, gotProfile)
// set avatar & displayname
wantProfile.DisplayName = "Alice"
gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, "Alice")
gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, aliceDomain, "Alice")
assert.Equal(t, wantProfile, gotProfile)
assert.NoError(t, err, "unable to set displayname")
assert.True(t, changed)
wantProfile.AvatarURL = "mxc://aliceAvatar"
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, aliceDomain, "mxc://aliceAvatar")
assert.NoError(t, err, "unable to set avatar url")
assert.Equal(t, wantProfile, gotProfile)
assert.True(t, changed)
// Setting the same avatar again doesn't change anything
wantProfile.AvatarURL = "mxc://aliceAvatar"
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, aliceDomain, "mxc://aliceAvatar")
assert.NoError(t, err, "unable to set avatar url")
assert.Equal(t, wantProfile, gotProfile)
assert.False(t, changed)
@ -410,7 +413,7 @@ func Test_Profile(t *testing.T) {
func Test_Pusher(t *testing.T) {
alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@ -432,11 +435,11 @@ func Test_Pusher(t *testing.T) {
ProfileTag: util.RandomString(8),
Language: util.RandomString(2),
}
err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart)
err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to upsert pusher")
// check it was actually persisted
gotPushers, err = db.GetPushers(ctx, aliceLocalpart)
gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get pushers")
assert.Equal(t, i+1, len(gotPushers))
assert.Equal(t, wantPusher, gotPushers[i])
@ -444,16 +447,16 @@ func Test_Pusher(t *testing.T) {
}
// remove single pusher
err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart)
err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to remove pusher")
gotPushers, err := db.GetPushers(ctx, aliceLocalpart)
gotPushers, err := db.GetPushers(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get pushers")
assert.Equal(t, 1, len(gotPushers))
// remove last pusher
err = db.RemovePushers(ctx, appID, pushKeys[1])
assert.NoError(t, err, "unable to remove pusher")
gotPushers, err = db.GetPushers(ctx, aliceLocalpart)
gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get pushers")
assert.Equal(t, 0, len(gotPushers))
})
@ -461,7 +464,7 @@ func Test_Pusher(t *testing.T) {
func Test_ThreePID(t *testing.T) {
alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@ -469,15 +472,16 @@ func Test_ThreePID(t *testing.T) {
defer close()
threePID := util.RandomString(8)
medium := util.RandomString(8)
err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, medium)
err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, aliceDomain, medium)
assert.NoError(t, err, "unable to save threepid association")
// get the stored threepid
gotLocalpart, err := db.GetLocalpartForThreePID(ctx, threePID, medium)
gotLocalpart, gotDomain, err := db.GetLocalpartForThreePID(ctx, threePID, medium)
assert.NoError(t, err, "unable to get localpart for threepid")
assert.Equal(t, aliceLocalpart, gotLocalpart)
assert.Equal(t, aliceDomain, gotDomain)
threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart)
threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get threepids for localpart")
assert.Equal(t, 1, len(threepids))
assert.Equal(t, authtypes.ThreePID{
@ -490,7 +494,7 @@ func Test_ThreePID(t *testing.T) {
assert.NoError(t, err, "unexpected error")
// verify it was deleted
threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart)
threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get threepids for localpart")
assert.Equal(t, 0, len(threepids))
})
@ -498,7 +502,7 @@ func Test_ThreePID(t *testing.T) {
func Test_Notification(t *testing.T) {
alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
room := test.NewRoom(t, alice)
room2 := test.NewRoom(t, alice)
@ -526,34 +530,34 @@ func Test_Notification(t *testing.T) {
RoomID: roomID,
TS: gomatrixserverlib.AsTimestamp(ts),
}
err = db.InsertNotification(ctx, aliceLocalpart, eventID, uint64(i+1), nil, notification)
err = db.InsertNotification(ctx, aliceLocalpart, aliceDomain, eventID, uint64(i+1), nil, notification)
assert.NoError(t, err, "unable to insert notification")
}
// get notifications
count, err := db.GetNotificationCount(ctx, aliceLocalpart, tables.AllNotifications)
count, err := db.GetNotificationCount(ctx, aliceLocalpart, aliceDomain, tables.AllNotifications)
assert.NoError(t, err, "unable to get notification count")
assert.Equal(t, int64(10), count)
notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, 0, 15, tables.AllNotifications)
notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, aliceDomain, 0, 15, tables.AllNotifications)
assert.NoError(t, err, "unable to get notifications")
assert.Equal(t, int64(10), count)
assert.Equal(t, 10, len(notifs))
// ... for a specific room
total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID)
assert.NoError(t, err, "unable to get notifications for room")
assert.Equal(t, int64(4), total)
// mark notification as read
affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, room2.ID, 7, true)
affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, aliceDomain, room2.ID, 7, true)
assert.NoError(t, err, "unable to set notifications read")
assert.True(t, affected)
// this should delete 2 notifications
affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, room2.ID, 8)
affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, aliceDomain, room2.ID, 8)
assert.NoError(t, err, "unable to set notifications read")
assert.True(t, affected)
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID)
assert.NoError(t, err, "unable to get notifications for room")
assert.Equal(t, int64(2), total)
@ -562,7 +566,7 @@ func Test_Notification(t *testing.T) {
assert.NoError(t, err)
// this should now return 0 notifications
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID)
assert.NoError(t, err, "unable to get notifications for room")
assert.Equal(t, int64(0), total)
})

View File

@ -28,31 +28,31 @@ import (
)
type AccountDataTable interface {
InsertAccountData(ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage) error
SelectAccountData(ctx context.Context, localpart string) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)
SelectAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error
SelectAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)
SelectAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error)
}
type AccountsTable interface {
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error)
DeactivateAccount(ctx context.Context, localpart string) (err error)
SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error)
SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error)
InsertAccount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
UpdatePassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, passwordHash string) (err error)
DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error)
SelectPasswordHash(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (hash string, err error)
SelectAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error)
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (id int64, err error)
}
type DevicesTable interface {
InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string) error
DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, devices []string) error
DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) error
UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string) error
InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName) error
DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error
DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) error
UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error
SelectDeviceByToken(ctx context.Context, accessToken string) (*api.Device, error)
SelectDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) ([]api.Device, error)
SelectDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error)
SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) ([]api.Device, error)
SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error
UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error
}
type KeyBackupTable interface {
@ -79,40 +79,40 @@ type LoginTokenTable interface {
}
type OpenIDTable interface {
InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, expiresAtMS int64) (err error)
InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, serverName gomatrixserverlib.ServerName, expiresAtMS int64) (err error)
SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
}
type ProfileTable interface {
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error
SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, bool, error)
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) error
SelectProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error)
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error)
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error)
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
}
type ThreePIDTable interface {
SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, err error)
SelectThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error)
SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error)
SelectThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error)
InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, serverName gomatrixserverlib.ServerName) (err error)
DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error)
}
type PusherTable interface {
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error
SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error)
DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, serverName gomatrixserverlib.ServerName) error
SelectPushers(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error)
DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error
DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error
}
type NotificationTable interface {
Clean(ctx context.Context, txn *sql.Tx) error
Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error
DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error)
UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error)
Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error)
SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error)
Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error
DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error)
UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error)
Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter NotificationFilter) (int64, error)
SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error)
}
type StatsTable interface {

View File

@ -79,6 +79,7 @@ func mustMakeAccountAndDevice(
accDB tables.AccountsTable,
devDB tables.DevicesTable,
localpart string,
serverName gomatrixserverlib.ServerName, // nolint:unparam
accType api.AccountType,
userAgent string,
) {
@ -89,11 +90,11 @@ func mustMakeAccountAndDevice(
appServiceID = util.RandomString(16)
}
_, err := accDB.InsertAccount(ctx, nil, localpart, "", appServiceID, accType)
_, err := accDB.InsertAccount(ctx, nil, localpart, serverName, "", appServiceID, accType)
if err != nil {
t.Fatalf("unable to create account: %v", err)
}
_, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, util.RandomString(16), nil, "", userAgent)
_, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, serverName, util.RandomString(16), nil, "", userAgent)
if err != nil {
t.Fatalf("unable to create device: %v", err)
}
@ -150,12 +151,12 @@ func Test_UserStatistics(t *testing.T) {
})
t.Run("Want Users", func(t *testing.T) {
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", api.AccountTypeUser, "Element Android")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", api.AccountTypeUser, "Element iOS")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", api.AccountTypeUser, "Element web")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", api.AccountTypeGuest, "Element Electron")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", api.AccountTypeAdmin, "gecko")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", api.AccountTypeAppService, "gecko")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", "localhost", api.AccountTypeUser, "Element Android")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", "localhost", api.AccountTypeUser, "Element iOS")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", "localhost", api.AccountTypeUser, "Element web")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", "localhost", api.AccountTypeGuest, "Element Electron")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", "localhost", api.AccountTypeAdmin, "gecko")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", "localhost", api.AccountTypeAppService, "gecko")
gotStats, _, err := statsDB.UserStatistics(ctx, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)

View File

@ -80,14 +80,14 @@ func TestQueryProfile(t *testing.T) {
// only one DBType, since userapi.AddInternalRoutes complains about multiple prometheus counters added
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, test.DBTypeSQLite)
defer close()
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", api.AccountTypeUser)
_, err := accountDB.CreateAccount(context.TODO(), "alice", serverName, "foobar", "", api.AccountTypeUser)
if err != nil {
t.Fatalf("failed to make account: %s", err)
}
if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil {
if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", serverName, aliceAvatarURL); err != nil {
t.Fatalf("failed to set avatar url: %s", err)
}
if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil {
if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", serverName, aliceDisplayName); err != nil {
t.Fatalf("failed to set display name: %s", err)
}
@ -164,7 +164,7 @@ func TestPasswordlessLoginFails(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
defer close()
_, err := accountDB.CreateAccount(ctx, "auser", "", "", api.AccountTypeAppService)
_, err := accountDB.CreateAccount(ctx, "auser", serverName, "", "", api.AccountTypeAppService)
if err != nil {
t.Fatalf("failed to make account: %s", err)
}
@ -190,7 +190,7 @@ func TestLoginToken(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
defer close()
_, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", api.AccountTypeUser)
_, err := accountDB.CreateAccount(ctx, "auser", serverName, "apassword", "", api.AccountTypeUser)
if err != nil {
t.Fatalf("failed to make account: %s", err)
}

View File

@ -2,10 +2,12 @@ package util
import (
"context"
"fmt"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
@ -17,10 +19,10 @@ type PusherDevice struct {
}
// GetPushDevices pushes to the configured devices of a local user.
func GetPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
pushers, err := db.GetPushers(ctx, localpart)
func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
pushers, err := db.GetPushers(ctx, localpart, serverName)
if err != nil {
return nil, err
return nil, fmt.Errorf("db.GetPushers: %w", err)
}
devices := make([]*PusherDevice, 0, len(pushers))

View File

@ -8,6 +8,7 @@ import (
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
@ -16,8 +17,8 @@ import (
// a single goroutine is started when talking to the Push
// gateways. There is no way to know when the background goroutine has
// finished.
func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, db storage.Database) error {
pusherDevices, err := GetPushDevices(ctx, localpart, nil, db)
func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.Database) error {
pusherDevices, err := GetPushDevices(ctx, localpart, serverName, nil, db)
if err != nil {
return err
}
@ -26,7 +27,7 @@ func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, loc
return nil
}
userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, tables.AllNotifications)
userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, serverName, tables.AllNotifications)
if err != nil {
return err
}