diff --git a/clientapi/auth/storage/accounts/interface.go b/clientapi/auth/storage/accounts/interface.go index 9f6e3e1e..a5052b04 100644 --- a/clientapi/auth/storage/accounts/interface.go +++ b/clientapi/auth/storage/accounts/interface.go @@ -33,6 +33,7 @@ type Database interface { CreateGuestAccount(ctx context.Context) (*authtypes.Account, error) UpdateMemberships(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error) + GetRoomIDsByLocalPart(ctx context.Context, localpart string) ([]string, error) GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error) SaveAccountData(ctx context.Context, localpart, roomID, dataType, content string) error GetAccountData(ctx context.Context, localpart string) (global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error) diff --git a/clientapi/auth/storage/accounts/postgres/membership_table.go b/clientapi/auth/storage/accounts/postgres/membership_table.go index 27570b67..04e9095e 100644 --- a/clientapi/auth/storage/accounts/postgres/membership_table.go +++ b/clientapi/auth/storage/accounts/postgres/membership_table.go @@ -53,6 +53,9 @@ const selectMembershipsByLocalpartSQL = "" + const selectMembershipInRoomByLocalpartSQL = "" + "SELECT event_id FROM account_memberships WHERE localpart = $1 AND room_id = $2" +const selectRoomIDsByLocalPartSQL = "" + + "SELECT room_id FROM account_memberships WHERE localpart = $1" + const deleteMembershipsByEventIDsSQL = "" + "DELETE FROM account_memberships WHERE event_id = ANY($1)" @@ -61,6 +64,7 @@ type membershipStatements struct { insertMembershipStmt *sql.Stmt selectMembershipInRoomByLocalpartStmt *sql.Stmt selectMembershipsByLocalpartStmt *sql.Stmt + selectRoomIDsByLocalPartStmt *sql.Stmt } func (s *membershipStatements) prepare(db *sql.DB) (err error) { @@ -80,6 +84,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil { return } + if s.selectRoomIDsByLocalPartStmt, err = db.Prepare(selectRoomIDsByLocalPartSQL); err != nil { + return + } return } @@ -131,3 +138,23 @@ func (s *membershipStatements) selectMembershipsByLocalpart( } return memberships, rows.Err() } + +func (s *membershipStatements) selectRoomIDsByLocalPart( + ctx context.Context, localPart string, +) ([]string, error) { + stmt := s.selectRoomIDsByLocalPartStmt + rows, err := stmt.QueryContext(ctx, localPart) + if err != nil { + return nil, err + } + roomIDs := []string{} + defer rows.Close() // nolint: errcheck + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + roomIDs = append(roomIDs, roomID) + } + return roomIDs, rows.Err() +} diff --git a/clientapi/auth/storage/accounts/postgres/storage.go b/clientapi/auth/storage/accounts/postgres/storage.go index 8115dca4..4a0a2060 100644 --- a/clientapi/auth/storage/accounts/postgres/storage.go +++ b/clientapi/auth/storage/accounts/postgres/storage.go @@ -234,6 +234,16 @@ func (d *Database) GetMembershipInRoomByLocalpart( return d.memberships.selectMembershipInRoomByLocalpart(ctx, localpart, roomID) } +// GetRoomIDsByLocalPart returns an array containing the room ids of all +// the rooms a user matching a given localpart is a member of +// If no membership match the given localpart, returns an empty array +// If there was an issue during the retrieval, returns the SQL error +func (d *Database) GetRoomIDsByLocalPart( + ctx context.Context, localpart string, +) ([]string, error) { + return d.memberships.selectRoomIDsByLocalPart(ctx, localpart) +} + // GetMembershipsByLocalpart returns an array containing the memberships for all // the rooms a user matching a given localpart is a member of // If no membership match the given localpart, returns an empty array diff --git a/clientapi/auth/storage/accounts/sqlite3/membership_table.go b/clientapi/auth/storage/accounts/sqlite3/membership_table.go index b4bff633..bd9838b6 100644 --- a/clientapi/auth/storage/accounts/sqlite3/membership_table.go +++ b/clientapi/auth/storage/accounts/sqlite3/membership_table.go @@ -51,6 +51,9 @@ const selectMembershipsByLocalpartSQL = "" + const selectMembershipInRoomByLocalpartSQL = "" + "SELECT event_id FROM account_memberships WHERE localpart = $1 AND room_id = $2" +const selectRoomIDsByLocalPartSQL = "" + + "SELECT room_id FROM account_memberships WHERE localpart = $1" + const deleteMembershipsByEventIDsSQL = "" + "DELETE FROM account_memberships WHERE event_id IN ($1)" @@ -58,6 +61,7 @@ type membershipStatements struct { insertMembershipStmt *sql.Stmt selectMembershipInRoomByLocalpartStmt *sql.Stmt selectMembershipsByLocalpartStmt *sql.Stmt + selectRoomIDsByLocalPartStmt *sql.Stmt } func (s *membershipStatements) prepare(db *sql.DB) (err error) { @@ -74,6 +78,9 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) { if s.selectMembershipsByLocalpartStmt, err = db.Prepare(selectMembershipsByLocalpartSQL); err != nil { return } + if s.selectRoomIDsByLocalPartStmt, err = db.Prepare(selectRoomIDsByLocalPartSQL); err != nil { + return + } return } @@ -130,3 +137,22 @@ func (s *membershipStatements) selectMembershipsByLocalpart( return } +func (s *membershipStatements) selectRoomIDsByLocalPart( + ctx context.Context, localPart string, +) ([]string, error) { + stmt := s.selectRoomIDsByLocalPartStmt + rows, err := stmt.QueryContext(ctx, localPart) + if err != nil { + return nil, err + } + roomIDs := []string{} + defer rows.Close() // nolint: errcheck + for rows.Next() { + var roomID string + if err = rows.Scan(&roomID); err != nil { + return nil, err + } + roomIDs = append(roomIDs, roomID) + } + return roomIDs, rows.Err() +} diff --git a/clientapi/auth/storage/accounts/sqlite3/storage.go b/clientapi/auth/storage/accounts/sqlite3/storage.go index 9124640c..bfb7b4ea 100644 --- a/clientapi/auth/storage/accounts/sqlite3/storage.go +++ b/clientapi/auth/storage/accounts/sqlite3/storage.go @@ -253,6 +253,16 @@ func (d *Database) GetMembershipsByLocalpart( return d.memberships.selectMembershipsByLocalpart(ctx, localpart) } +// GetRoomIDsByLocalPart returns an array containing the room ids of all +// the rooms a user matching a given localpart is a member of +// If no membership match the given localpart, returns an empty array +// If there was an issue during the retrieval, returns the SQL error +func (d *Database) GetRoomIDsByLocalPart( + ctx context.Context, localpart string, +) ([]string, error) { + return d.memberships.selectRoomIDsByLocalPart(ctx, localpart) +} + // newMembership saves a new membership in the database. // If the event isn't a valid m.room.member event with type `join`, does nothing. // If an error occurred, returns the SQL error diff --git a/clientapi/routing/memberships.go b/clientapi/routing/memberships.go index a6899eeb..0b846e5e 100644 --- a/clientapi/routing/memberships.go +++ b/clientapi/routing/memberships.go @@ -17,6 +17,8 @@ package routing import ( "net/http" + "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/common/config" @@ -25,10 +27,14 @@ import ( "github.com/matrix-org/util" ) -type response struct { +type getMembershipResponse struct { Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` } +type getJoinedRoomsResponse struct { + JoinedRooms []string `json:"joined_rooms"` +} + // GetMemberships implements GET /rooms/{roomId}/members func GetMemberships( req *http.Request, device *authtypes.Device, roomID string, joinedOnly bool, @@ -55,6 +61,27 @@ func GetMemberships( return util.JSONResponse{ Code: http.StatusOK, - JSON: response{queryRes.JoinEvents}, + JSON: getMembershipResponse{queryRes.JoinEvents}, + } +} + +func GetJoinedRooms( + req *http.Request, + device *authtypes.Device, + accountsDB accounts.Database, +) util.JSONResponse { + localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") + return jsonerror.InternalServerError() + } + joinedRooms, err := accountsDB.GetRoomIDsByLocalPart(req.Context(), localpart) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("accountsDB.GetRoomIDsByLocalPart failed") + return jsonerror.InternalServerError() + } + return util.JSONResponse{ + Code: http.StatusOK, + JSON: getJoinedRoomsResponse{joinedRooms}, } } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index f0841b79..47b7b267 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -105,6 +105,12 @@ func Setup( ) }), ).Methods(http.MethodPost, http.MethodOptions) + r0mux.Handle("/joined_rooms", + common.MakeAuthAPI("joined_rooms", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + return GetJoinedRooms(req, device, accountDB) + }), + ).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/rooms/{roomID}/{membership:(?:join|kick|ban|unban|leave|invite)}", common.MakeAuthAPI("membership", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse { vars, err := common.URLDecodeMapValues(mux.Vars(req)) diff --git a/go.mod b/go.mod index 82474944..25debd74 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/tidwall/pretty v1.0.1 // indirect github.com/uber/jaeger-client-go v2.22.1+incompatible github.com/uber/jaeger-lib v2.2.0+incompatible + go.uber.org/atomic v1.6.0 // indirect golang.org/x/crypto v0.0.0-20200115085410-6d4e4cb37c7d gopkg.in/Shopify/sarama.v1 v1.20.1 gopkg.in/h2non/bimg.v1 v1.0.18 diff --git a/sytest-whitelist b/sytest-whitelist index cac71828..03e11b83 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -218,3 +218,5 @@ Push rules come down in an initial /sync Regular users can add and delete aliases in the default room configuration Regular users can add and delete aliases when m.room.aliases is restricted GET /r0/capabilities is not public +GET /joined_rooms lists newly-created room +/joined_rooms returns only joined rooms