User directory for nearby Pinecone peers (P2P demo) (#2311)

* User directory for nearby Pinecone peers

* Fix mux routing

* Use config to determine which server notices user to exclude
This commit is contained in:
Neil Alexander 2022-03-28 16:25:26 +01:00 committed by GitHub
parent 0692be44d9
commit 7972915806
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 226 additions and 37 deletions

View File

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conn" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conn"
"github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/rooms" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/rooms"
"github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users"
"github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing"
"github.com/matrix-org/dendrite/eduserver" "github.com/matrix-org/dendrite/eduserver"
"github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/eduserver/cache"
@ -326,6 +327,9 @@ func (m *DendriteMonolith) Start() {
// This is different to rsAPI which can be the http client which doesn't need this dependency // This is different to rsAPI which can be the http client which doesn't need this dependency
rsAPI.SetFederationAPI(fsAPI, keyRing) rsAPI.SetFederationAPI(fsAPI, keyRing)
userProvider := users.NewPineconeUserProvider(m.PineconeRouter, m.PineconeQUIC, m.userAPI, federation)
roomProvider := rooms.NewPineconeRoomProvider(m.PineconeRouter, m.PineconeQUIC, fsAPI, federation)
monolith := setup.Monolith{ monolith := setup.Monolith{
Config: base.Cfg, Config: base.Cfg,
AccountDB: accountDB, AccountDB: accountDB,
@ -339,7 +343,8 @@ func (m *DendriteMonolith) Start() {
RoomserverAPI: rsAPI, RoomserverAPI: rsAPI,
UserAPI: m.userAPI, UserAPI: m.userAPI,
KeyAPI: keyAPI, KeyAPI: keyAPI,
ExtPublicRoomsProvider: rooms.NewPineconeRoomProvider(m.PineconeRouter, m.PineconeQUIC, fsAPI, federation), ExtPublicRoomsProvider: roomProvider,
ExtUserDirectoryProvider: userProvider,
} }
monolith.AddAllPublicRoutes( monolith.AddAllPublicRoutes(
base.ProcessContext, base.ProcessContext,
@ -357,10 +362,12 @@ func (m *DendriteMonolith) Start() {
httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux)
pMux := mux.NewRouter().SkipClean(true).UseEncodedPath() pMux := mux.NewRouter().SkipClean(true).UseEncodedPath()
pMux.PathPrefix(users.PublicURL).HandlerFunc(userProvider.FederatedUserProfiles)
pMux.PathPrefix(httputil.PublicFederationPathPrefix).Handler(base.PublicFederationAPIMux) pMux.PathPrefix(httputil.PublicFederationPathPrefix).Handler(base.PublicFederationAPIMux)
pMux.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) pMux.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux)
pHTTP := m.PineconeQUIC.HTTP() pHTTP := m.PineconeQUIC.HTTP()
pHTTP.Mux().Handle(users.PublicURL, pMux)
pHTTP.Mux().Handle(httputil.PublicFederationPathPrefix, pMux) pHTTP.Mux().Handle(httputil.PublicFederationPathPrefix, pMux)
pHTTP.Mux().Handle(httputil.PublicMediaPathPrefix, pMux) pHTTP.Mux().Handle(httputil.PublicMediaPathPrefix, pMux)

View File

@ -17,6 +17,7 @@ package authtypes
// Profile represents the profile for a Matrix account. // Profile represents the profile for a Matrix account.
type Profile struct { type Profile struct {
Localpart string `json:"local_part"` Localpart string `json:"local_part"`
ServerName string `json:"server_name,omitempty"` // NOTSPEC: only set by Pinecone user provider
DisplayName string `json:"display_name"` DisplayName string `json:"display_name"`
AvatarURL string `json:"avatar_url"` AvatarURL string `json:"avatar_url"`
} }

View File

@ -45,6 +45,7 @@ func AddPublicRoutes(
transactionsCache *transactions.Cache, transactionsCache *transactions.Cache,
fsAPI federationAPI.FederationInternalAPI, fsAPI federationAPI.FederationInternalAPI,
userAPI userapi.UserInternalAPI, userAPI userapi.UserInternalAPI,
userDirectoryProvider userapi.UserDirectoryProvider,
keyAPI keyserverAPI.KeyInternalAPI, keyAPI keyserverAPI.KeyInternalAPI,
extRoomsProvider api.ExtraPublicRoomsProvider, extRoomsProvider api.ExtraPublicRoomsProvider,
mscCfg *config.MSCs, mscCfg *config.MSCs,
@ -58,7 +59,7 @@ func AddPublicRoutes(
routing.Setup( routing.Setup(
router, synapseAdminRouter, cfg, eduInputAPI, rsAPI, asAPI, router, synapseAdminRouter, cfg, eduInputAPI, rsAPI, asAPI,
userAPI, federation, userAPI, userDirectoryProvider, federation,
syncProducer, transactionsCache, fsAPI, keyAPI, syncProducer, transactionsCache, fsAPI, keyAPI,
extRoomsProvider, mscCfg, extRoomsProvider, mscCfg,
) )

View File

@ -51,6 +51,7 @@ func Setup(
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
asAPI appserviceAPI.AppServiceQueryAPI, asAPI appserviceAPI.AppServiceQueryAPI,
userAPI userapi.UserInternalAPI, userAPI userapi.UserInternalAPI,
userDirectoryProvider userapi.UserDirectoryProvider,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
syncProducer *producers.SyncAPIProducer, syncProducer *producers.SyncAPIProducer,
transactionsCache *transactions.Cache, transactionsCache *transactions.Cache,
@ -904,6 +905,7 @@ func Setup(
device, device,
userAPI, userAPI,
rsAPI, rsAPI,
userDirectoryProvider,
cfg.Matrix.ServerName, cfg.Matrix.ServerName,
postContent.SearchString, postContent.SearchString,
postContent.Limit, postContent.Limit,

View File

@ -16,6 +16,7 @@ package routing
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -35,6 +36,7 @@ func SearchUserDirectory(
device *userapi.Device, device *userapi.Device,
userAPI userapi.UserInternalAPI, userAPI userapi.UserInternalAPI,
rsAPI api.RoomserverInternalAPI, rsAPI api.RoomserverInternalAPI,
provider userapi.UserDirectoryProvider,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
searchString string, searchString string,
limit int, limit int,
@ -50,13 +52,12 @@ func SearchUserDirectory(
} }
// First start searching local users. // First start searching local users.
userReq := &userapi.QuerySearchProfilesRequest{ userReq := &userapi.QuerySearchProfilesRequest{
SearchString: searchString, SearchString: searchString,
Limit: limit, Limit: limit,
} }
userRes := &userapi.QuerySearchProfilesResponse{} userRes := &userapi.QuerySearchProfilesResponse{}
if err := userAPI.QuerySearchProfiles(ctx, userReq, userRes); err != nil { if err := provider.QuerySearchProfiles(ctx, userReq, userRes); err != nil {
errRes := util.ErrorResponse(fmt.Errorf("userAPI.QuerySearchProfiles: %w", err)) errRes := util.ErrorResponse(fmt.Errorf("userAPI.QuerySearchProfiles: %w", err))
return &errRes return &errRes
} }
@ -67,7 +68,12 @@ func SearchUserDirectory(
break break
} }
userID := fmt.Sprintf("@%s:%s", user.Localpart, serverName) var userID string
if user.ServerName != "" {
userID = fmt.Sprintf("@%s:%s", user.Localpart, user.ServerName)
} else {
userID = fmt.Sprintf("@%s:%s", user.Localpart, serverName)
}
if _, ok := results[userID]; !ok { if _, ok := results[userID]; !ok {
results[userID] = authtypes.FullyQualifiedProfile{ results[userID] = authtypes.FullyQualifiedProfile{
UserID: userID, UserID: userID,
@ -87,7 +93,7 @@ func SearchUserDirectory(
Limit: limit - len(results), Limit: limit - len(results),
} }
stateRes := &api.QueryKnownUsersResponse{} stateRes := &api.QueryKnownUsersResponse{}
if err := rsAPI.QueryKnownUsers(ctx, stateReq, stateRes); err != nil { if err := rsAPI.QueryKnownUsers(ctx, stateReq, stateRes); err != nil && err != sql.ErrNoRows {
errRes := util.ErrorResponse(fmt.Errorf("rsAPI.QueryKnownUsers: %w", err)) errRes := util.ErrorResponse(fmt.Errorf("rsAPI.QueryKnownUsers: %w", err))
return &errRes return &errRes
} }

View File

@ -35,6 +35,7 @@ import (
"github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conn" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conn"
"github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/embed" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/embed"
"github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/rooms" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/rooms"
"github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users"
"github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing"
"github.com/matrix-org/dendrite/eduserver" "github.com/matrix-org/dendrite/eduserver"
"github.com/matrix-org/dendrite/eduserver/cache" "github.com/matrix-org/dendrite/eduserver/cache"
@ -198,6 +199,9 @@ func main() {
rsComponent.SetFederationAPI(fsAPI, keyRing) rsComponent.SetFederationAPI(fsAPI, keyRing)
userProvider := users.NewPineconeUserProvider(pRouter, pQUIC, userAPI, federation)
roomProvider := rooms.NewPineconeRoomProvider(pRouter, pQUIC, fsAPI, federation)
monolith := setup.Monolith{ monolith := setup.Monolith{
Config: base.Cfg, Config: base.Cfg,
AccountDB: accountDB, AccountDB: accountDB,
@ -211,7 +215,8 @@ func main() {
RoomserverAPI: rsAPI, RoomserverAPI: rsAPI,
UserAPI: userAPI, UserAPI: userAPI,
KeyAPI: keyAPI, KeyAPI: keyAPI,
ExtPublicRoomsProvider: rooms.NewPineconeRoomProvider(pRouter, pQUIC, fsAPI, federation), ExtPublicRoomsProvider: roomProvider,
ExtUserDirectoryProvider: userProvider,
} }
monolith.AddAllPublicRoutes( monolith.AddAllPublicRoutes(
base.ProcessContext, base.ProcessContext,
@ -250,10 +255,12 @@ func main() {
embed.Embed(httpRouter, *instancePort, "Pinecone Demo") embed.Embed(httpRouter, *instancePort, "Pinecone Demo")
pMux := mux.NewRouter().SkipClean(true).UseEncodedPath() pMux := mux.NewRouter().SkipClean(true).UseEncodedPath()
pMux.PathPrefix(users.PublicURL).HandlerFunc(userProvider.FederatedUserProfiles)
pMux.PathPrefix(httputil.PublicFederationPathPrefix).Handler(base.PublicFederationAPIMux) pMux.PathPrefix(httputil.PublicFederationPathPrefix).Handler(base.PublicFederationAPIMux)
pMux.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) pMux.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux)
pHTTP := pQUIC.HTTP() pHTTP := pQUIC.HTTP()
pHTTP.Mux().Handle(users.PublicURL, pMux)
pHTTP.Mux().Handle(httputil.PublicFederationPathPrefix, pMux) pHTTP.Mux().Handle(httputil.PublicFederationPathPrefix, pMux)
pHTTP.Mux().Handle(httputil.PublicMediaPathPrefix, pMux) pHTTP.Mux().Handle(httputil.PublicMediaPathPrefix, pMux)

View File

@ -0,0 +1,145 @@
package users
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"sync"
"time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
clienthttputil "github.com/matrix-org/dendrite/clientapi/httputil"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
pineconeRouter "github.com/matrix-org/pinecone/router"
pineconeSessions "github.com/matrix-org/pinecone/sessions"
)
type PineconeUserProvider struct {
r *pineconeRouter.Router
s *pineconeSessions.Sessions
userAPI userapi.UserProfileAPI
fedClient *gomatrixserverlib.FederationClient
}
const PublicURL = "/_matrix/p2p/profiles"
func NewPineconeUserProvider(
r *pineconeRouter.Router,
s *pineconeSessions.Sessions,
userAPI userapi.UserProfileAPI,
fedClient *gomatrixserverlib.FederationClient,
) *PineconeUserProvider {
p := &PineconeUserProvider{
r: r,
s: s,
userAPI: userAPI,
fedClient: fedClient,
}
return p
}
func (p *PineconeUserProvider) FederatedUserProfiles(w http.ResponseWriter, r *http.Request) {
req := &userapi.QuerySearchProfilesRequest{Limit: 25}
res := &userapi.QuerySearchProfilesResponse{}
if err := clienthttputil.UnmarshalJSONRequest(r, &req); err != nil {
w.WriteHeader(400)
return
}
if err := p.userAPI.QuerySearchProfiles(r.Context(), req, res); err != nil {
w.WriteHeader(400)
return
}
j, err := json.Marshal(res)
if err != nil {
w.WriteHeader(400)
return
}
w.WriteHeader(200)
_, _ = w.Write(j)
}
func (p *PineconeUserProvider) QuerySearchProfiles(ctx context.Context, req *userapi.QuerySearchProfilesRequest, res *userapi.QuerySearchProfilesResponse) error {
list := map[string]struct{}{}
for _, k := range p.r.Peers() {
list[k.PublicKey] = struct{}{}
}
res.Profiles = bulkFetchUserDirectoriesFromServers(context.Background(), req, p.fedClient, list)
return nil
}
// bulkFetchUserDirectoriesFromServers fetches users from the list of homeservers.
// Returns a list of user profiles.
func bulkFetchUserDirectoriesFromServers(
ctx context.Context, req *userapi.QuerySearchProfilesRequest,
fedClient *gomatrixserverlib.FederationClient,
homeservers map[string]struct{},
) (profiles []authtypes.Profile) {
jsonBody, err := json.Marshal(req)
if err != nil {
return nil
}
limit := 200
// follow pipeline semantics, see https://blog.golang.org/pipelines for more info.
// goroutines send rooms to this channel
profileCh := make(chan authtypes.Profile, int(limit))
// signalling channel to tell goroutines to stop sending rooms and quit
done := make(chan bool)
// signalling to say when we can close the room channel
var wg sync.WaitGroup
wg.Add(len(homeservers))
// concurrently query for public rooms
reqctx, reqcancel := context.WithTimeout(ctx, time.Second*5)
for hs := range homeservers {
go func(homeserverDomain string) {
defer wg.Done()
util.GetLogger(reqctx).WithField("hs", homeserverDomain).Info("Querying HS for users")
jsonBodyReader := bytes.NewBuffer(jsonBody)
httpReq, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("matrix://%s%s", homeserverDomain, PublicURL), jsonBodyReader)
if err != nil {
util.GetLogger(reqctx).WithError(err).WithField("hs", homeserverDomain).Warn(
"bulkFetchUserDirectoriesFromServers: failed to create request",
)
}
res := &userapi.QuerySearchProfilesResponse{}
if err = fedClient.DoRequestAndParseResponse(reqctx, httpReq, res); err != nil {
util.GetLogger(reqctx).WithError(err).WithField("hs", homeserverDomain).Warn(
"bulkFetchUserDirectoriesFromServers: failed to query hs",
)
return
}
for _, profile := range res.Profiles {
profile.ServerName = homeserverDomain
// atomically send a room or stop
select {
case profileCh <- profile:
case <-done:
case <-reqctx.Done():
util.GetLogger(reqctx).WithError(err).WithField("hs", homeserverDomain).Info("Interrupted whilst sending profiles")
return
}
}
}(hs)
}
select {
case <-time.After(5 * time.Second):
default:
wg.Wait()
}
reqcancel()
close(done)
close(profileCh)
for profile := range profileCh {
profiles = append(profiles, profile)
}
return profiles
}

View File

@ -33,7 +33,7 @@ func ClientAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) {
clientapi.AddPublicRoutes( clientapi.AddPublicRoutes(
base.ProcessContext, base.PublicClientAPIMux, base.SynapseAdminMux, &base.Cfg.ClientAPI, base.ProcessContext, base.PublicClientAPIMux, base.SynapseAdminMux, &base.Cfg.ClientAPI,
federation, rsAPI, eduInputAPI, asQuery, transactions.New(), fsAPI, userAPI, federation, rsAPI, eduInputAPI, asQuery, transactions.New(), fsAPI, userAPI, userAPI,
keyAPI, nil, &cfg.MSCs, keyAPI, nil, &cfg.MSCs,
) )

View File

@ -289,6 +289,7 @@ func (b *BaseDendrite) CreateAccountsDB() userdb.Database {
b.Cfg.UserAPI.BCryptCost, b.Cfg.UserAPI.BCryptCost,
b.Cfg.UserAPI.OpenIDTokenLifetimeMS, b.Cfg.UserAPI.OpenIDTokenLifetimeMS,
userapi.DefaultLoginTokenLifetime, userapi.DefaultLoginTokenLifetime,
b.Cfg.Global.ServerNotices.LocalPart,
) )
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to connect to accounts db") logrus.WithError(err).Panicf("failed to connect to accounts db")

View File

@ -52,15 +52,20 @@ type Monolith struct {
// Optional // Optional
ExtPublicRoomsProvider api.ExtraPublicRoomsProvider ExtPublicRoomsProvider api.ExtraPublicRoomsProvider
ExtUserDirectoryProvider userapi.UserDirectoryProvider
} }
// AddAllPublicRoutes attaches all public paths to the given router // AddAllPublicRoutes attaches all public paths to the given router
func (m *Monolith) AddAllPublicRoutes(process *process.ProcessContext, csMux, ssMux, keyMux, wkMux, mediaMux, synapseMux *mux.Router) { func (m *Monolith) AddAllPublicRoutes(process *process.ProcessContext, csMux, ssMux, keyMux, wkMux, mediaMux, synapseMux *mux.Router) {
userDirectoryProvider := m.ExtUserDirectoryProvider
if userDirectoryProvider == nil {
userDirectoryProvider = m.UserAPI
}
clientapi.AddPublicRoutes( clientapi.AddPublicRoutes(
process, csMux, synapseMux, &m.Config.ClientAPI, process, csMux, synapseMux, &m.Config.ClientAPI,
m.FedClient, m.RoomserverAPI, m.FedClient, m.RoomserverAPI,
m.EDUInternalAPI, m.AppserviceAPI, transactions.New(), m.EDUInternalAPI, m.AppserviceAPI, transactions.New(),
m.FederationAPI, m.UserAPI, m.KeyAPI, m.FederationAPI, m.UserAPI, userDirectoryProvider, m.KeyAPI,
m.ExtPublicRoomsProvider, &m.Config.MSCs, m.ExtPublicRoomsProvider, &m.Config.MSCs,
) )
federationapi.AddPublicRoutes( federationapi.AddPublicRoutes(

View File

@ -54,6 +54,10 @@ type UserInternalAPI interface {
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
} }
type UserDirectoryProvider interface {
QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error
}
// UserProfileAPI provides functions for getting user profiles // UserProfileAPI provides functions for getting user profiles
type UserProfileAPI interface { type UserProfileAPI interface {
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error

View File

@ -53,6 +53,7 @@ const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" "SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
type profilesStatements struct { type profilesStatements struct {
serverNoticesLocalpart string
insertProfileStmt *sql.Stmt insertProfileStmt *sql.Stmt
selectProfileByLocalpartStmt *sql.Stmt selectProfileByLocalpartStmt *sql.Stmt
setAvatarURLStmt *sql.Stmt setAvatarURLStmt *sql.Stmt
@ -60,8 +61,10 @@ type profilesStatements struct {
selectProfilesBySearchStmt *sql.Stmt selectProfilesBySearchStmt *sql.Stmt
} }
func NewPostgresProfilesTable(db *sql.DB) (tables.ProfileTable, error) { func NewPostgresProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables.ProfileTable, error) {
s := &profilesStatements{} s := &profilesStatements{
serverNoticesLocalpart: serverNoticesLocalpart,
}
_, err := db.Exec(profilesSchema) _, err := db.Exec(profilesSchema)
if err != nil { if err != nil {
return nil, err return nil, err
@ -126,7 +129,9 @@ func (s *profilesStatements) SelectProfilesBySearch(
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil { if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
return nil, err return nil, err
} }
if profile.Localpart != s.serverNoticesLocalpart {
profiles = append(profiles, profile) profiles = append(profiles, profile)
} }
}
return profiles, nil return profiles, nil
} }

View File

@ -30,7 +30,7 @@ import (
) )
// NewDatabase creates a new accounts and profiles database // NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*shared.Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) {
db, err := sqlutil.Open(dbProperties) db, err := sqlutil.Open(dbProperties)
if err != nil { if err != nil {
return nil, err return nil, err
@ -77,7 +77,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil { if err != nil {
return nil, fmt.Errorf("NewPostgresOpenIDTable: %w", err) return nil, fmt.Errorf("NewPostgresOpenIDTable: %w", err)
} }
profilesTable, err := NewPostgresProfilesTable(db) profilesTable, err := NewPostgresProfilesTable(db, serverNoticesLocalpart)
if err != nil { if err != nil {
return nil, fmt.Errorf("NewPostgresProfilesTable: %w", err) return nil, fmt.Errorf("NewPostgresProfilesTable: %w", err)
} }

View File

@ -54,6 +54,7 @@ const selectProfilesBySearchSQL = "" +
type profilesStatements struct { type profilesStatements struct {
db *sql.DB db *sql.DB
serverNoticesLocalpart string
insertProfileStmt *sql.Stmt insertProfileStmt *sql.Stmt
selectProfileByLocalpartStmt *sql.Stmt selectProfileByLocalpartStmt *sql.Stmt
setAvatarURLStmt *sql.Stmt setAvatarURLStmt *sql.Stmt
@ -61,9 +62,10 @@ type profilesStatements struct {
selectProfilesBySearchStmt *sql.Stmt selectProfilesBySearchStmt *sql.Stmt
} }
func NewSQLiteProfilesTable(db *sql.DB) (tables.ProfileTable, error) { func NewSQLiteProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables.ProfileTable, error) {
s := &profilesStatements{ s := &profilesStatements{
db: db, db: db,
serverNoticesLocalpart: serverNoticesLocalpart,
} }
_, err := db.Exec(profilesSchema) _, err := db.Exec(profilesSchema)
if err != nil { if err != nil {
@ -131,7 +133,9 @@ func (s *profilesStatements) SelectProfilesBySearch(
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil { if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
return nil, err return nil, err
} }
if profile.Localpart != s.serverNoticesLocalpart {
profiles = append(profiles, profile) profiles = append(profiles, profile)
} }
}
return profiles, nil return profiles, nil
} }

View File

@ -31,7 +31,7 @@ import (
) )
// NewDatabase creates a new accounts and profiles database // NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*shared.Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) {
db, err := sqlutil.Open(dbProperties) db, err := sqlutil.Open(dbProperties)
if err != nil { if err != nil {
return nil, err return nil, err
@ -78,7 +78,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil { if err != nil {
return nil, fmt.Errorf("NewSQLiteOpenIDTable: %w", err) return nil, fmt.Errorf("NewSQLiteOpenIDTable: %w", err)
} }
profilesTable, err := NewSQLiteProfilesTable(db) profilesTable, err := NewSQLiteProfilesTable(db, serverNoticesLocalpart)
if err != nil { if err != nil {
return nil, fmt.Errorf("NewSQLiteProfilesTable: %w", err) return nil, fmt.Errorf("NewSQLiteProfilesTable: %w", err)
} }

View File

@ -30,12 +30,12 @@ import (
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets postgres connection parameters // and sets postgres connection parameters
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) {
switch { switch {
case dbProperties.ConnectionString.IsSQLite(): case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime) return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
case dbProperties.ConnectionString.IsPostgres(): case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime) return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
default: default:
return nil, fmt.Errorf("unexpected database type") return nil, fmt.Errorf("unexpected database type")
} }

View File

@ -29,10 +29,11 @@ func NewDatabase(
bcryptCost int, bcryptCost int,
openIDTokenLifetimeMS int64, openIDTokenLifetimeMS int64,
loginTokenLifetime time.Duration, loginTokenLifetime time.Duration,
serverNoticesLocalpart string,
) (Database, error) { ) (Database, error) {
switch { switch {
case dbProperties.ConnectionString.IsSQLite(): case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime) return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
case dbProperties.ConnectionString.IsPostgres(): case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation") return nil, fmt.Errorf("can't use Postgres implementation")
default: default:

View File

@ -52,7 +52,7 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, s
MaxOpenConnections: 1, MaxOpenConnections: 1,
MaxIdleConnections: 1, MaxIdleConnections: 1,
} }
accountDB, err := storage.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime) accountDB, err := storage.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "")
if err != nil { if err != nil {
t.Fatalf("failed to create account DB: %s", err) t.Fatalf("failed to create account DB: %s", err)
} }