User directory (#1225)

* User directory

* Fix syncapi unit test

* Make user directory only show remote users you know about from your joined rooms

* Update sytest-whitelist

* Review comments
This commit is contained in:
Neil Alexander 2020-07-28 10:53:17 +01:00 committed by GitHub
parent c632867135
commit acc8e80a51
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 402 additions and 8 deletions

View File

@ -16,7 +16,14 @@ package authtypes
// Profile represents the profile for a Matrix account.
type Profile struct {
Localpart string
DisplayName string
AvatarURL string
Localpart string `json:"local_part"`
DisplayName string `json:"display_name"`
AvatarURL string `json:"avatar_url"`
}
// FullyQualifiedProfile represents the profile for a Matrix account.
type FullyQualifiedProfile struct {
UserID string `json:"user_id"`
DisplayName string `json:"display_name"`
AvatarURL string `json:"avatar_url"`
}

View File

@ -574,6 +574,27 @@ func Setup(
}),
).Methods(http.MethodGet)
r0mux.Handle("/user_directory/search",
httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
postContent := struct {
SearchString string `json:"search_term"`
Limit int `json:"limit"`
}{}
if err := json.NewDecoder(req.Body).Decode(&postContent); err != nil {
return util.ErrorResponse(err)
}
return *SearchUserDirectory(
req.Context(),
device,
userAPI,
stateAPI,
cfg.Matrix.ServerName,
postContent.SearchString,
postContent.Limit,
)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/members",
httputil.MakeAuthAPI("rooms_members", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))

View File

@ -0,0 +1,115 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"context"
"fmt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
type UserDirectoryResponse struct {
Results []authtypes.FullyQualifiedProfile `json:"results"`
Limited bool `json:"limited"`
}
func SearchUserDirectory(
ctx context.Context,
device *userapi.Device,
userAPI userapi.UserInternalAPI,
stateAPI currentstateAPI.CurrentStateInternalAPI,
serverName gomatrixserverlib.ServerName,
searchString string,
limit int,
) *util.JSONResponse {
if limit < 10 {
limit = 10
}
results := map[string]authtypes.FullyQualifiedProfile{}
response := &UserDirectoryResponse{
Results: []authtypes.FullyQualifiedProfile{},
Limited: false,
}
// First start searching local users.
userReq := &userapi.QuerySearchProfilesRequest{
SearchString: searchString,
Limit: limit,
}
userRes := &userapi.QuerySearchProfilesResponse{}
if err := userAPI.QuerySearchProfiles(ctx, userReq, userRes); err != nil {
errRes := util.ErrorResponse(fmt.Errorf("userAPI.QuerySearchProfiles: %w", err))
return &errRes
}
for _, user := range userRes.Profiles {
if len(results) == limit {
response.Limited = true
break
}
userID := fmt.Sprintf("@%s:%s", user.Localpart, serverName)
if _, ok := results[userID]; !ok {
results[userID] = authtypes.FullyQualifiedProfile{
UserID: userID,
DisplayName: user.DisplayName,
AvatarURL: user.AvatarURL,
}
}
}
// Then, if we have enough room left in the response,
// start searching for known users from joined rooms.
if len(results) <= limit {
stateReq := &currentstateAPI.QueryKnownUsersRequest{
UserID: device.UserID,
SearchString: searchString,
Limit: limit - len(results),
}
stateRes := &currentstateAPI.QueryKnownUsersResponse{}
if err := stateAPI.QueryKnownUsers(ctx, stateReq, stateRes); err != nil {
errRes := util.ErrorResponse(fmt.Errorf("stateAPI.QueryKnownUsers: %w", err))
return &errRes
}
for _, user := range stateRes.Users {
if len(results) == limit {
response.Limited = true
break
}
if _, ok := results[user.UserID]; !ok {
results[user.UserID] = user
}
}
}
for _, result := range results {
response.Results = append(response.Results, result)
}
return &util.JSONResponse{
Code: 200,
JSON: response,
}
}

View File

@ -20,6 +20,7 @@ import (
"fmt"
"strings"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/gomatrixserverlib"
)
@ -33,6 +34,8 @@ type CurrentStateInternalAPI interface {
QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error
// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error
// QueryKnownUsers returns a list of users that we know about from our joined rooms.
QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error
}
type QuerySharedUsersRequest struct {
@ -88,6 +91,16 @@ type QueryCurrentStateResponse struct {
StateEvents map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent
}
type QueryKnownUsersRequest struct {
UserID string `json:"user_id"`
SearchString string `json:"search_string"`
Limit int `json:"limit"`
}
type QueryKnownUsersResponse struct {
Users []authtypes.FullyQualifiedProfile `json:"profiles"`
}
// MarshalJSON stringifies the StateKeyTuple keys so they can be sent over the wire in HTTP API mode.
func (r *QueryCurrentStateResponse) MarshalJSON() ([]byte, error) {
se := make(map[string]*gomatrixserverlib.HeaderedEvent, len(r.StateEvents))

View File

@ -17,6 +17,7 @@ package internal
import (
"context"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/currentstateserver/api"
"github.com/matrix-org/dendrite/currentstateserver/storage"
"github.com/matrix-org/gomatrixserverlib"
@ -49,6 +50,19 @@ func (a *CurrentStateInternalAPI) QueryRoomsForUser(ctx context.Context, req *ap
return nil
}
func (a *CurrentStateInternalAPI) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error {
users, err := a.DB.GetKnownUsers(ctx, req.UserID, req.SearchString, req.Limit)
if err != nil {
return err
}
for _, user := range users {
res.Users = append(res.Users, authtypes.FullyQualifiedProfile{
UserID: user,
})
}
return nil
}
func (a *CurrentStateInternalAPI) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error {
events, err := a.DB.GetBulkStateContent(ctx, req.RoomIDs, req.StateTuples, req.AllowWildcards)
if err != nil {

View File

@ -30,6 +30,7 @@ const (
QueryRoomsForUserPath = "/currentstateserver/queryRoomsForUser"
QueryBulkStateContentPath = "/currentstateserver/queryBulkStateContent"
QuerySharedUsersPath = "/currentstateserver/querySharedUsers"
QueryKnownUsersPath = "/currentstateserver/queryKnownUsers"
)
// NewCurrentStateAPIClient creates a CurrentStateInternalAPI implemented by talking to a HTTP POST API.
@ -97,3 +98,13 @@ func (h *httpCurrentStateInternalAPI) QuerySharedUsers(
apiURL := h.apiURL + QuerySharedUsersPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpCurrentStateInternalAPI) QueryKnownUsers(
ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKnownUsers")
defer span.Finish()
apiURL := h.apiURL + QueryKnownUsersPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}

View File

@ -77,4 +77,17 @@ func AddRoutes(internalAPIMux *mux.Router, intAPI api.CurrentStateInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QuerySharedUsersPath,
httputil.MakeInternalAPI("queryKnownUsers", func(req *http.Request) util.JSONResponse {
request := api.QueryKnownUsersRequest{}
response := api.QueryKnownUsersResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.QueryKnownUsers(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
}

View File

@ -39,4 +39,6 @@ type Database interface {
RedactEvent(ctx context.Context, redactedEventID string, redactedBecause gomatrixserverlib.HeaderedEvent) error
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error)
// GetKnownUsers searches all users that userID knows about.
GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error)
}

View File

@ -18,6 +18,7 @@ import (
"context"
"database/sql"
"encoding/json"
"fmt"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/currentstateserver/storage/tables"
@ -81,6 +82,14 @@ const selectJoinedUsersSetForRoomsSQL = "" +
"SELECT state_key, COUNT(room_id) FROM currentstate_current_room_state WHERE room_id = ANY($1) AND" +
" type = 'm.room.member' and content_value = 'join' GROUP BY state_key"
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
// joined to. Since this information is used to populate the user directory, we will
// only return users that the user would ordinarily be able to see anyway.
const selectKnownUsersSQL = "" +
"SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id = ANY(" +
" SELECT DISTINCT room_id FROM currentstate_current_room_state WHERE state_key=$1 AND TYPE='m.room.member' AND content_value='join'" +
") AND TYPE='m.room.member' AND content_value='join' AND state_key LIKE $2 LIMIT $3"
type currentRoomStateStatements struct {
upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt
@ -90,6 +99,7 @@ type currentRoomStateStatements struct {
selectBulkStateContentStmt *sql.Stmt
selectBulkStateContentWildStmt *sql.Stmt
selectJoinedUsersSetForRoomsStmt *sql.Stmt
selectKnownUsersStmt *sql.Stmt
}
func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
@ -122,6 +132,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil {
return nil, err
}
if s.selectKnownUsersStmt, err = db.Prepare(selectKnownUsersSQL); err != nil {
return nil, err
}
return s, nil
}
@ -295,3 +308,20 @@ func (s *currentRoomStateStatements) SelectBulkStateContent(
}
return strippedEvents, rows.Err()
}
func (s *currentRoomStateStatements) SelectKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) {
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
if err != nil {
return nil, err
}
result := []string{}
defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed")
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return nil, err
}
result = append(result, userID)
}
return result, rows.Err()
}

View File

@ -89,3 +89,7 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
return d.CurrentRoomState.SelectJoinedUsersSetForRooms(ctx, roomIDs)
}
func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) {
return d.CurrentRoomState.SelectKnownUsers(ctx, userID, searchString, limit)
}

View File

@ -18,6 +18,7 @@ import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"github.com/matrix-org/dendrite/currentstateserver/storage/tables"
@ -69,6 +70,14 @@ const selectBulkStateContentWildSQL = "" +
const selectJoinedUsersSetForRoomsSQL = "" +
"SELECT state_key, COUNT(room_id) FROM currentstate_current_room_state WHERE room_id IN ($1) AND type = 'm.room.member' and content_value = 'join' GROUP BY state_key"
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
// joined to. Since this information is used to populate the user directory, we will
// only return users that the user would ordinarily be able to see anyway.
const selectKnownUsersSQL = "" +
"SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id IN (" +
" SELECT DISTINCT room_id FROM currentstate_current_room_state WHERE state_key=$1 AND TYPE='m.room.member' AND content_value='join'" +
") AND TYPE='m.room.member' AND content_value='join' AND state_key LIKE $2 LIMIT $3"
type currentRoomStateStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
@ -77,6 +86,7 @@ type currentRoomStateStatements struct {
selectRoomIDsWithMembershipStmt *sql.Stmt
selectStateEventStmt *sql.Stmt
selectJoinedUsersSetForRoomsStmt *sql.Stmt
selectKnownUsersStmt *sql.Stmt
}
func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
@ -103,6 +113,9 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error)
if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil {
return nil, err
}
if s.selectKnownUsersStmt, err = db.Prepare(selectKnownUsersSQL); err != nil {
return nil, err
}
return s, nil
}
@ -315,3 +328,20 @@ func (s *currentRoomStateStatements) SelectBulkStateContent(
}
return strippedEvents, rows.Err()
}
func (s *currentRoomStateStatements) SelectKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) {
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
if err != nil {
return nil, err
}
result := []string{}
defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed")
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return nil, err
}
result = append(result, userID)
}
return result, rows.Err()
}

View File

@ -39,6 +39,8 @@ type CurrentRoomState interface {
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the
// counts of how many rooms they are joined.
SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) (map[string]int, error)
// SelectKnownUsers searches all users that userID knows about.
SelectKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error)
}
// StrippedEvent represents a stripped event for returning extracted content values.

View File

@ -23,6 +23,10 @@ func (s *mockCurrentStateAPI) QueryCurrentState(ctx context.Context, req *api.Qu
return nil
}
func (s *mockCurrentStateAPI) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error {
return nil
}
// QueryRoomsForUser retrieves a list of room IDs matching the given query.
func (s *mockCurrentStateAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error {
return nil

View File

@ -415,4 +415,8 @@ We don't send redundant membership state across incremental syncs by default
Typing notifications don't leak
Users cannot kick users from a room they are not in
Users cannot kick users who have already left a room
User appears in user directory
User directory correctly update on display name change
User in shared private room does appear in user directory
User in dir while user still shares private rooms
Can get 'm.room.name' state for a departed room (SPEC-216)

View File

@ -18,6 +18,7 @@ import (
"context"
"encoding/json"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/gomatrixserverlib"
)
@ -31,6 +32,7 @@ type UserInternalAPI interface {
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error
QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error
}
// InputAccountDataRequest is the request for InputAccountData
@ -112,6 +114,20 @@ type QueryProfileResponse struct {
AvatarURL string
}
// QuerySearchProfilesRequest is the request for QueryProfile
type QuerySearchProfilesRequest struct {
// The search string to match
SearchString string
// How many results to return
Limit int
}
// QuerySearchProfilesResponse is the response for QuerySearchProfilesRequest
type QuerySearchProfilesResponse struct {
// Profiles matching the search
Profiles []authtypes.Profile
}
// PerformAccountCreationRequest is the request for PerformAccountCreation
type PerformAccountCreationRequest struct {
AccountType AccountType // Required: whether this is a guest or user account

View File

@ -125,6 +125,15 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
return nil
}
func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error {
profiles, err := a.AccountDB.SearchProfiles(ctx, req.SearchString, req.Limit)
if err != nil {
return err
}
res.Profiles = profiles
return nil
}
func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error {
devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs)
if err != nil {

View File

@ -31,11 +31,12 @@ const (
PerformDeviceCreationPath = "/userapi/performDeviceCreation"
PerformAccountCreationPath = "/userapi/performAccountCreation"
QueryProfilePath = "/userapi/queryProfile"
QueryAccessTokenPath = "/userapi/queryAccessToken"
QueryDevicesPath = "/userapi/queryDevices"
QueryAccountDataPath = "/userapi/queryAccountData"
QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
QueryProfilePath = "/userapi/queryProfile"
QueryAccessTokenPath = "/userapi/queryAccessToken"
QueryDevicesPath = "/userapi/queryDevices"
QueryAccountDataPath = "/userapi/queryAccountData"
QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
QuerySearchProfilesPath = "/userapi/querySearchProfiles"
)
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
@ -141,3 +142,11 @@ func (h *httpUserInternalAPI) QueryAccountData(ctx context.Context, req *api.Que
apiURL := h.apiURL + QueryAccountDataPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySearchProfiles")
defer span.Finish()
apiURL := h.apiURL + QuerySearchProfilesPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}

View File

@ -117,4 +117,17 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryDeviceInfosPath,
httputil.MakeInternalAPI("querySearchProfiles", func(req *http.Request) util.JSONResponse {
request := api.QuerySearchProfilesRequest{}
response := api.QuerySearchProfilesResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QuerySearchProfiles(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
}

View File

@ -49,6 +49,7 @@ type Database interface {
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
}
// Err3PIDInUse is the error returned when trying to save an association involving

View File

@ -17,8 +17,10 @@ package postgres
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal"
)
const profilesSchema = `
@ -45,11 +47,15 @@ const setAvatarURLSQL = "" +
const setDisplayNameSQL = "" +
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
type profilesStatements struct {
insertProfileStmt *sql.Stmt
selectProfileByLocalpartStmt *sql.Stmt
setAvatarURLStmt *sql.Stmt
setDisplayNameStmt *sql.Stmt
selectProfilesBySearchStmt *sql.Stmt
}
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
@ -69,6 +75,9 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil {
return
}
if s.selectProfilesBySearchStmt, err = db.Prepare(selectProfilesBySearchSQL); err != nil {
return
}
return
}
@ -105,3 +114,25 @@ func (s *profilesStatements) setDisplayName(
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
return
}
func (s *profilesStatements) selectProfilesBySearch(
ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) {
var profiles []authtypes.Profile
// The fmt.Sprintf directive below is building a parameter for the
// "LIKE" condition in the SQL query. %% escapes the % char, so the
// statement in the end will look like "LIKE %searchString%".
rows, err := s.selectProfilesBySearchStmt.QueryContext(ctx, fmt.Sprintf("%%%s%%", searchString), limit)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
for rows.Next() {
var profile authtypes.Profile
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
return nil, err
}
profiles = append(profiles, profile)
}
return profiles, nil
}

View File

@ -298,3 +298,10 @@ func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
) (*api.Account, error) {
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// SearchProfiles returns all profiles where the provided localpart or display name
// match any part of the profiles in the database.
func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) {
return d.profiles.selectProfilesBySearch(ctx, searchString, limit)
}

View File

@ -17,8 +17,10 @@ package sqlite3
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
@ -46,6 +48,9 @@ const setAvatarURLSQL = "" +
const setDisplayNameSQL = "" +
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM account_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
type profilesStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
@ -53,6 +58,7 @@ type profilesStatements struct {
selectProfileByLocalpartStmt *sql.Stmt
setAvatarURLStmt *sql.Stmt
setDisplayNameStmt *sql.Stmt
selectProfilesBySearchStmt *sql.Stmt
}
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
@ -74,6 +80,9 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
if s.setDisplayNameStmt, err = db.Prepare(setDisplayNameSQL); err != nil {
return
}
if s.selectProfilesBySearchStmt, err = db.Prepare(selectProfilesBySearchSQL); err != nil {
return
}
return
}
@ -112,3 +121,25 @@ func (s *profilesStatements) setDisplayName(
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
return
}
func (s *profilesStatements) selectProfilesBySearch(
ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) {
var profiles []authtypes.Profile
// The fmt.Sprintf directive below is building a parameter for the
// "LIKE" condition in the SQL query. %% escapes the % char, so the
// statement in the end will look like "LIKE %searchString%".
rows, err := s.selectProfilesBySearchStmt.QueryContext(ctx, fmt.Sprintf("%%%s%%", searchString), limit)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
for rows.Next() {
var profile authtypes.Profile
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
return nil, err
}
profiles = append(profiles, profile)
}
return profiles, nil
}

View File

@ -343,3 +343,10 @@ func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
) (*api.Account, error) {
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// SearchProfiles returns all profiles where the provided localpart or display name
// match any part of the profiles in the database.
func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) {
return d.profiles.selectProfilesBySearch(ctx, searchString, limit)
}