Peeking via MSC2753 (#1370)

Initial implementation of MSC2753, as tested by https://github.com/matrix-org/sytest/pull/944.
Doesn't yet handle unpeeks, peeked EDUs, or history viz changing during a peek - these will follow.
https://github.com/matrix-org/dendrite/pull/1370 has full details.
This commit is contained in:
Matthew Hodgson 2020-09-10 14:39:18 +01:00 committed by GitHub
parent 35564dd73c
commit 39507bacc3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 1209 additions and 59 deletions

View File

@ -81,6 +81,18 @@ matrixdotorg/sytest-dendrite:latest tests/50federation/40devicelists.pl
``` ```
See [sytest.md](docs/sytest.md) for the full description of these flags. See [sytest.md](docs/sytest.md) for the full description of these flags.
You can try running sytest outside of docker for faster runs, but the dependencies can be temperamental
and we recommend using docker where possible.
```
cd sytest
export PERL5LIB=$HOME/lib/perl5
export PERL_MB_OPT=--install_base=$HOME
export PERL_MM_OPT=INSTALL_BASE=$HOME
./install-deps.pl
./run-tests.pl -I Dendrite::Monolith -d $PATH_TO_DENDRITE_BINARIES
```
Sometimes Sytest is testing the wrong thing or is flakey, so it will need to be patched. Sometimes Sytest is testing the wrong thing or is flakey, so it will need to be patched.
Ask on `#dendrite-dev:matrix.org` if you think this is the case for you and we'll be happy to help. Ask on `#dendrite-dev:matrix.org` if you think this is the case for you and we'll be happy to help.

View File

@ -52,7 +52,7 @@ func JoinRoomByIDOrAlias(
} }
} }
// If content was provided in the request then incude that // If content was provided in the request then include that
// in the request. It'll get used as a part of the membership // in the request. It'll get used as a part of the membership
// event content. // event content.
_ = httputil.UnmarshalJSONRequest(req, &joinReq.Content) _ = httputil.UnmarshalJSONRequest(req, &joinReq.Content)

View File

@ -0,0 +1,79 @@
// Copyright 2020 New Vector Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"net/http"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
func PeekRoomByIDOrAlias(
req *http.Request,
device *api.Device,
rsAPI roomserverAPI.RoomserverInternalAPI,
accountDB accounts.Database,
roomIDOrAlias string,
) util.JSONResponse {
// if this is a remote roomIDOrAlias, we have to ask the roomserver (or federation sender?) to
// to call /peek and /state on the remote server.
// TODO: in future we could skip this if we know we're already participating in the room,
// but this is fiddly in case we stop participating in the room.
// then we create a local peek.
peekReq := roomserverAPI.PerformPeekRequest{
RoomIDOrAlias: roomIDOrAlias,
UserID: device.UserID,
DeviceID: device.ID,
}
peekRes := roomserverAPI.PerformPeekResponse{}
// Check to see if any ?server_name= query parameters were
// given in the request.
if serverNames, ok := req.URL.Query()["server_name"]; ok {
for _, serverName := range serverNames {
peekReq.ServerNames = append(
peekReq.ServerNames,
gomatrixserverlib.ServerName(serverName),
)
}
}
// Ask the roomserver to perform the peek.
rsAPI.PerformPeek(req.Context(), &peekReq, &peekRes)
if peekRes.Error != nil {
return peekRes.Error.JSONResponse()
}
// if this user is already joined to the room, we let them peek anyway
// (given they might be about to part the room, and it makes things less fiddly)
// Peeking stops if none of the devices who started peeking have been
// /syncing for a while, or if everyone who was peeking calls /leave
// (or /unpeek with a server_name param? or DELETE /peek?)
// on the peeked room.
return util.JSONResponse{
Code: http.StatusOK,
// TODO: Put the response struct somewhere internal.
JSON: struct {
RoomID string `json:"room_id"`
}{peekRes.RoomID},
}
}

View File

@ -103,6 +103,17 @@ func Setup(
) )
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/peek/{roomIDOrAlias}",
httputil.MakeAuthAPI(gomatrixserverlib.Peek, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return PeekRoomByIDOrAlias(
req, device, rsAPI, accountDB, vars["roomIDOrAlias"],
)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/joined_rooms", r0mux.Handle("/joined_rooms",
httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetJoinedRooms(req, device, rsAPI) return GetJoinedRooms(req, device, rsAPI)

19
docs/peeking.md Normal file
View File

@ -0,0 +1,19 @@
## Peeking
Peeking is implemented as per [MSC2753](https://github.com/matrix-org/matrix-doc/pull/2753).
Implementationwise, this means:
* Users call `/peek` and `/unpeek` on the clientapi from a given device.
* The clientapi delegates these via HTTP to the roomserver, which coordinates peeking in general for a given room
* The roomserver writes an NewPeek event into the kafka log headed to the syncserver
* The syncserver tracks the existence of the local peek in its DB, and then starts waking up the peeking devices for the room in question, putting it in the `peek` section of the /sync response.
Questions (given this is [my](https://github.com/ara4n) first time hacking on Dendrite):
* The whole clientapi -> roomserver -> syncapi flow to initiate a peek seems very indirect. Is there a reason not to just let syncapi itself host the implementation of `/peek`?
In future, peeking over federation will be added as per [MSC2444](https://github.com/matrix-org/matrix-doc/pull/2444).
* The `roomserver` will kick the `federationsender` much as it does for a federated `/join` in order to trigger a federated `/peek`
* The `federationsender` tracks the existence of the remote peek in question
* The `federationsender` regularly renews the remote peek as long as there are still peeking devices syncing for it.
* TBD: how do we tell if there are no devices currently syncing for a given peeked room? The syncserver needs to tell the roomserver
somehow who then needs to warn the federationsender.

View File

@ -112,6 +112,13 @@ func (t *testRoomserverAPI) PerformJoin(
) { ) {
} }
func (t *testRoomserverAPI) PerformPeek(
ctx context.Context,
req *api.PerformPeekRequest,
res *api.PerformPeekResponse,
) {
}
func (t *testRoomserverAPI) PerformPublish( func (t *testRoomserverAPI) PerformPublish(
ctx context.Context, ctx context.Context,
req *api.PerformPublishRequest, req *api.PerformPublishRequest,

View File

@ -138,7 +138,7 @@ func (d *Database) StoreJSON(
var err error var err error
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
nid, err = d.FederationSenderQueueJSON.InsertQueueJSON(ctx, txn, js) nid, err = d.FederationSenderQueueJSON.InsertQueueJSON(ctx, txn, js)
return nil return err
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("d.insertQueueJSON: %w", err) return nil, fmt.Errorf("d.insertQueueJSON: %w", err)

View File

@ -74,7 +74,7 @@ func (in *traceInterceptor) RowsNext(c context.Context, rows driver.Rows, dest [
b := strings.Builder{} b := strings.Builder{}
for i, val := range dest { for i, val := range dest {
b.WriteString(fmt.Sprintf("%v", val)) b.WriteString(fmt.Sprintf("%q", val))
if i+1 <= len(dest)-1 { if i+1 <= len(dest)-1 {
b.WriteString(" | ") b.WriteString(" | ")
} }

View File

@ -41,7 +41,7 @@ func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID str
func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) { func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) {
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys) counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys)
return nil return err
}) })
return return
} }

View File

@ -36,6 +36,12 @@ type RoomserverInternalAPI interface {
res *PerformLeaveResponse, res *PerformLeaveResponse,
) error ) error
PerformPeek(
ctx context.Context,
req *PerformPeekRequest,
res *PerformPeekResponse,
)
PerformPublish( PerformPublish(
ctx context.Context, ctx context.Context,
req *PerformPublishRequest, req *PerformPublishRequest,

View File

@ -38,6 +38,15 @@ func (t *RoomserverInternalAPITrace) PerformInvite(
return t.Impl.PerformInvite(ctx, req, res) return t.Impl.PerformInvite(ctx, req, res)
} }
func (t *RoomserverInternalAPITrace) PerformPeek(
ctx context.Context,
req *PerformPeekRequest,
res *PerformPeekResponse,
) {
t.Impl.PerformPeek(ctx, req, res)
util.GetLogger(ctx).Infof("PerformPeek req=%+v res=%+v", js(req), js(res))
}
func (t *RoomserverInternalAPITrace) PerformJoin( func (t *RoomserverInternalAPITrace) PerformJoin(
ctx context.Context, ctx context.Context,
req *PerformJoinRequest, req *PerformJoinRequest,

View File

@ -46,6 +46,9 @@ const (
// - Redact the event and set the corresponding `unsigned` fields to indicate it as redacted. // - Redact the event and set the corresponding `unsigned` fields to indicate it as redacted.
// - Replace the event in the database. // - Replace the event in the database.
OutputTypeRedactedEvent OutputType = "redacted_event" OutputTypeRedactedEvent OutputType = "redacted_event"
// OutputTypeNewPeek indicates that the kafka event is an OutputNewPeek
OutputTypeNewPeek OutputType = "new_peek"
) )
// An OutputEvent is an entry in the roomserver output kafka log. // An OutputEvent is an entry in the roomserver output kafka log.
@ -59,8 +62,10 @@ type OutputEvent struct {
NewInviteEvent *OutputNewInviteEvent `json:"new_invite_event,omitempty"` NewInviteEvent *OutputNewInviteEvent `json:"new_invite_event,omitempty"`
// The content of event with type OutputTypeRetireInviteEvent // The content of event with type OutputTypeRetireInviteEvent
RetireInviteEvent *OutputRetireInviteEvent `json:"retire_invite_event,omitempty"` RetireInviteEvent *OutputRetireInviteEvent `json:"retire_invite_event,omitempty"`
// The content of event with type OutputTypeRedactedEvent // The content of event with type OutputTypeRedactedEvent
RedactedEvent *OutputRedactedEvent `json:"redacted_event,omitempty"` RedactedEvent *OutputRedactedEvent `json:"redacted_event,omitempty"`
// The content of event with type OutputTypeNewPeek
NewPeek *OutputNewPeek `json:"new_peek,omitempty"`
} }
// An OutputNewRoomEvent is written when the roomserver receives a new event. // An OutputNewRoomEvent is written when the roomserver receives a new event.
@ -195,3 +200,11 @@ type OutputRedactedEvent struct {
// The value of `unsigned.redacted_because` - the redaction event itself // The value of `unsigned.redacted_because` - the redaction event itself
RedactedBecause gomatrixserverlib.HeaderedEvent RedactedBecause gomatrixserverlib.HeaderedEvent
} }
// An OutputNewPeek is written whenever a user starts peeking into a room
// using a given device.
type OutputNewPeek struct {
RoomID string
UserID string
DeviceID string
}

View File

@ -108,6 +108,20 @@ type PerformInviteResponse struct {
Error *PerformError Error *PerformError
} }
type PerformPeekRequest struct {
RoomIDOrAlias string `json:"room_id_or_alias"`
UserID string `json:"user_id"`
DeviceID string `json:"device_id"`
ServerNames []gomatrixserverlib.ServerName `json:"server_names"`
}
type PerformPeekResponse struct {
// The room ID, populated on success.
RoomID string `json:"room_id"`
// If non-nil, the join request failed. Contains more information why it failed.
Error *PerformError
}
// PerformBackfillRequest is a request to PerformBackfill. // PerformBackfillRequest is a request to PerformBackfill.
type PerformBackfillRequest struct { type PerformBackfillRequest struct {
// The room to backfill // The room to backfill

View File

@ -22,6 +22,7 @@ type RoomserverInternalAPI struct {
*query.Queryer *query.Queryer
*perform.Inviter *perform.Inviter
*perform.Joiner *perform.Joiner
*perform.Peeker
*perform.Leaver *perform.Leaver
*perform.Publisher *perform.Publisher
*perform.Backfiller *perform.Backfiller
@ -83,6 +84,13 @@ func (r *RoomserverInternalAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSen
FSAPI: r.fsAPI, FSAPI: r.fsAPI,
Inputer: r.Inputer, Inputer: r.Inputer,
} }
r.Peeker = &perform.Peeker{
ServerName: r.Cfg.Matrix.ServerName,
Cfg: r.Cfg,
DB: r.DB,
FSAPI: r.fsAPI,
Inputer: r.Inputer,
}
r.Leaver = &perform.Leaver{ r.Leaver = &perform.Leaver{
Cfg: r.Cfg, Cfg: r.Cfg,
DB: r.DB, DB: r.DB,

View File

@ -0,0 +1,206 @@
// Copyright 2020 New Vector Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package perform
import (
"context"
"encoding/json"
"fmt"
"strings"
fsAPI "github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/input"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
)
type Peeker struct {
ServerName gomatrixserverlib.ServerName
Cfg *config.RoomServer
FSAPI fsAPI.FederationSenderInternalAPI
DB storage.Database
Inputer *input.Inputer
}
// PerformPeek handles peeking into matrix rooms, including over federation by talking to the federationsender.
func (r *Peeker) PerformPeek(
ctx context.Context,
req *api.PerformPeekRequest,
res *api.PerformPeekResponse,
) {
roomID, err := r.performPeek(ctx, req)
if err != nil {
perr, ok := err.(*api.PerformError)
if ok {
res.Error = perr
} else {
res.Error = &api.PerformError{
Msg: err.Error(),
}
}
}
res.RoomID = roomID
}
func (r *Peeker) performPeek(
ctx context.Context,
req *api.PerformPeekRequest,
) (string, error) {
// FIXME: there's way too much duplication with performJoin
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
return "", &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID),
}
}
if domain != r.Cfg.Matrix.ServerName {
return "", &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID),
}
}
if strings.HasPrefix(req.RoomIDOrAlias, "!") {
return r.performPeekRoomByID(ctx, req)
}
if strings.HasPrefix(req.RoomIDOrAlias, "#") {
return r.performPeekRoomByAlias(ctx, req)
}
return "", &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Room ID or alias %q is invalid", req.RoomIDOrAlias),
}
}
func (r *Peeker) performPeekRoomByAlias(
ctx context.Context,
req *api.PerformPeekRequest,
) (string, error) {
// Get the domain part of the room alias.
_, domain, err := gomatrixserverlib.SplitID('#', req.RoomIDOrAlias)
if err != nil {
return "", fmt.Errorf("Alias %q is not in the correct format", req.RoomIDOrAlias)
}
req.ServerNames = append(req.ServerNames, domain)
// Check if this alias matches our own server configuration. If it
// doesn't then we'll need to try a federated peek.
var roomID string
if domain != r.Cfg.Matrix.ServerName {
// The alias isn't owned by us, so we will need to try peeking using
// a remote server.
dirReq := fsAPI.PerformDirectoryLookupRequest{
RoomAlias: req.RoomIDOrAlias, // the room alias to lookup
ServerName: domain, // the server to ask
}
dirRes := fsAPI.PerformDirectoryLookupResponse{}
err = r.FSAPI.PerformDirectoryLookup(ctx, &dirReq, &dirRes)
if err != nil {
logrus.WithError(err).Errorf("error looking up alias %q", req.RoomIDOrAlias)
return "", fmt.Errorf("Looking up alias %q over federation failed: %w", req.RoomIDOrAlias, err)
}
roomID = dirRes.RoomID
req.ServerNames = append(req.ServerNames, dirRes.ServerNames...)
} else {
// Otherwise, look up if we know this room alias locally.
roomID, err = r.DB.GetRoomIDForAlias(ctx, req.RoomIDOrAlias)
if err != nil {
return "", fmt.Errorf("Lookup room alias %q failed: %w", req.RoomIDOrAlias, err)
}
}
// If the room ID is empty then we failed to look up the alias.
if roomID == "" {
return "", fmt.Errorf("Alias %q not found", req.RoomIDOrAlias)
}
// If we do, then pluck out the room ID and continue the peek.
req.RoomIDOrAlias = roomID
return r.performPeekRoomByID(ctx, req)
}
func (r *Peeker) performPeekRoomByID(
ctx context.Context,
req *api.PerformPeekRequest,
) (roomID string, err error) {
roomID = req.RoomIDOrAlias
// Get the domain part of the room ID.
_, domain, err := gomatrixserverlib.SplitID('!', roomID)
if err != nil {
return "", &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: fmt.Sprintf("Room ID %q is invalid: %s", roomID, err),
}
}
// If the server name in the room ID isn't ours then it's a
// possible candidate for finding the room via federation. Add
// it to the list of servers to try.
if domain != r.Cfg.Matrix.ServerName {
req.ServerNames = append(req.ServerNames, domain)
}
// If this room isn't world_readable, we reject.
// XXX: would be nicer to call this with NIDs
// XXX: we should probably factor out history_visibility checks into a common utility method somewhere
// which handles the default value etc.
var worldReadable = false
ev, _ := r.DB.GetStateEvent(ctx, roomID, "m.room.history_visibility", "")
if ev != nil {
content := map[string]string{}
if err = json.Unmarshal(ev.Content(), &content); err != nil {
util.GetLogger(ctx).WithError(err).Error("json.Unmarshal for history visibility failed")
return
}
if visibility, ok := content["history_visibility"]; ok {
worldReadable = visibility == "world_readable"
}
}
if !worldReadable {
return "", &api.PerformError{
Code: api.PerformErrorNotAllowed,
Msg: "Room is not world-readable",
}
}
// TODO: handle federated peeks
err = r.Inputer.WriteOutputEvents(roomID, []api.OutputEvent{
{
Type: api.OutputTypeNewPeek,
NewPeek: &api.OutputNewPeek{
RoomID: roomID,
UserID: req.UserID,
DeviceID: req.DeviceID,
},
},
})
if err != nil {
return
}
// By this point, if req.RoomIDOrAlias contained an alias, then
// it will have been overwritten with a room ID by performPeekRoomByAlias.
// We should now include this in the response so that the CS API can
// return the right room ID.
return roomID, nil
}

View File

@ -26,6 +26,7 @@ const (
// Perform operations // Perform operations
RoomserverPerformInvitePath = "/roomserver/performInvite" RoomserverPerformInvitePath = "/roomserver/performInvite"
RoomserverPerformPeekPath = "/roomserver/performPeek"
RoomserverPerformJoinPath = "/roomserver/performJoin" RoomserverPerformJoinPath = "/roomserver/performJoin"
RoomserverPerformLeavePath = "/roomserver/performLeave" RoomserverPerformLeavePath = "/roomserver/performLeave"
RoomserverPerformBackfillPath = "/roomserver/performBackfill" RoomserverPerformBackfillPath = "/roomserver/performBackfill"
@ -185,6 +186,23 @@ func (h *httpRoomserverInternalAPI) PerformJoin(
} }
} }
func (h *httpRoomserverInternalAPI) PerformPeek(
ctx context.Context,
request *api.PerformPeekRequest,
response *api.PerformPeekResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPeek")
defer span.Finish()
apiURL := h.roomserverURL + RoomserverPerformPeekPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.PerformError{
Msg: fmt.Sprintf("failed to communicate with roomserver: %s", err),
}
}
}
func (h *httpRoomserverInternalAPI) PerformLeave( func (h *httpRoomserverInternalAPI) PerformLeave(
ctx context.Context, ctx context.Context,
request *api.PerformLeaveRequest, request *api.PerformLeaveRequest,

View File

@ -63,6 +63,17 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(RoomserverPerformPeekPath,
httputil.MakeInternalAPI("performPeek", func(req *http.Request) util.JSONResponse {
var request api.PerformPeekRequest
var response api.PerformPeekResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
r.PerformPeek(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(RoomserverPerformPublishPath, internalAPIMux.Handle(RoomserverPerformPublishPath,
httputil.MakeInternalAPI("performPublish", func(req *http.Request) util.JSONResponse { httputil.MakeInternalAPI("performPublish", func(req *http.Request) util.JSONResponse {
var request api.PerformPublishRequest var request api.PerformPublishRequest

View File

@ -359,7 +359,7 @@ func (d *Database) MembershipUpdater(
var updater *MembershipUpdater var updater *MembershipUpdater
_ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { _ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
updater, err = NewMembershipUpdater(ctx, d, txn, roomID, targetUserID, targetLocal, roomVersion) updater, err = NewMembershipUpdater(ctx, d, txn, roomID, targetUserID, targetLocal, roomVersion)
return nil return err
}) })
return updater, err return updater, err
} }
@ -374,7 +374,7 @@ func (d *Database) GetLatestEventsForUpdate(
var updater *LatestEventsUpdater var updater *LatestEventsUpdater
_ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { _ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
updater, err = NewLatestEventsUpdater(ctx, d, txn, roomInfo) updater, err = NewLatestEventsUpdater(ctx, d, txn, roomInfo)
return nil return err
}) })
return updater, err return updater, err
} }

View File

@ -17,6 +17,7 @@ package consumers
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"github.com/Shopify/sarama" "github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
@ -26,11 +27,13 @@ import (
"github.com/matrix-org/dendrite/syncapi/sync" "github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// OutputRoomEventConsumer consumes events that originated in the room server. // OutputRoomEventConsumer consumes events that originated in the room server.
type OutputRoomEventConsumer struct { type OutputRoomEventConsumer struct {
cfg *config.SyncAPI
rsAPI api.RoomserverInternalAPI rsAPI api.RoomserverInternalAPI
rsConsumer *internal.ContinualConsumer rsConsumer *internal.ContinualConsumer
db storage.Database db storage.Database
@ -55,6 +58,7 @@ func NewOutputRoomEventConsumer(
PartitionStore: store, PartitionStore: store,
} }
s := &OutputRoomEventConsumer{ s := &OutputRoomEventConsumer{
cfg: cfg,
rsConsumer: &consumer, rsConsumer: &consumer,
db: store, db: store,
notifier: n, notifier: n,
@ -100,6 +104,8 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
return s.onNewInviteEvent(context.TODO(), *output.NewInviteEvent) return s.onNewInviteEvent(context.TODO(), *output.NewInviteEvent)
case api.OutputTypeRetireInviteEvent: case api.OutputTypeRetireInviteEvent:
return s.onRetireInviteEvent(context.TODO(), *output.RetireInviteEvent) return s.onRetireInviteEvent(context.TODO(), *output.RetireInviteEvent)
case api.OutputTypeNewPeek:
return s.onNewPeek(context.TODO(), *output.NewPeek)
case api.OutputTypeRedactedEvent: case api.OutputTypeRedactedEvent:
return s.onRedactEvent(context.TODO(), *output.RedactedEvent) return s.onRedactEvent(context.TODO(), *output.RedactedEvent)
default: default:
@ -162,6 +168,12 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
}).Panicf("roomserver output log: write event failure") }).Panicf("roomserver output log: write event failure")
return nil return nil
} }
if pduPos, err = s.notifyJoinedPeeks(ctx, &ev, pduPos); err != nil {
logrus.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos)
return err
}
s.notifier.OnNewEvent(&ev, "", nil, types.NewStreamToken(pduPos, 0, nil)) s.notifier.OnNewEvent(&ev, "", nil, types.NewStreamToken(pduPos, 0, nil))
s.notifyKeyChanges(&ev) s.notifyKeyChanges(&ev)
@ -184,6 +196,37 @@ func (s *OutputRoomEventConsumer) notifyKeyChanges(ev *gomatrixserverlib.Headere
} }
} }
func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, sp types.StreamPosition) (types.StreamPosition, error) {
if ev.Type() != gomatrixserverlib.MRoomMember {
return sp, nil
}
membership, err := ev.Membership()
if err != nil {
return sp, fmt.Errorf("ev.Membership: %w", err)
}
// TODO: check that it's a join and not a profile change (means unmarshalling prev_content)
if membership == gomatrixserverlib.Join {
// check it's a local join
_, domain, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
if err != nil {
return sp, fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
if domain != s.cfg.Matrix.ServerName {
return sp, nil
}
// cancel any peeks for it
peekSP, peekErr := s.db.DeletePeeks(ctx, ev.RoomID(), *ev.StateKey())
if peekErr != nil {
return sp, fmt.Errorf("s.db.DeletePeeks: %w", peekErr)
}
if peekSP > 0 {
sp = peekSP
}
}
return sp, nil
}
func (s *OutputRoomEventConsumer) onNewInviteEvent( func (s *OutputRoomEventConsumer) onNewInviteEvent(
ctx context.Context, msg api.OutputNewInviteEvent, ctx context.Context, msg api.OutputNewInviteEvent,
) error { ) error {
@ -219,6 +262,26 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
return nil return nil
} }
func (s *OutputRoomEventConsumer) onNewPeek(
ctx context.Context, msg api.OutputNewPeek,
) error {
sp, err := s.db.AddPeek(ctx, msg.RoomID, msg.UserID, msg.DeviceID)
if err != nil {
// panic rather than continue with an inconsistent database
log.WithFields(log.Fields{
log.ErrorKey: err,
}).Panicf("roomserver output log: write peek failure")
return nil
}
// tell the notifier about the new peek so it knows to wake up new devices
s.notifier.OnNewPeek(msg.RoomID, msg.UserID, msg.DeviceID)
// we need to wake up the users who might need to now be peeking into this room,
// so we send in a dummy event to trigger a wakeup
s.notifier.OnNewEvent(nil, msg.RoomID, nil, types.NewStreamToken(sp, 0, nil))
return nil
}
func (s *OutputRoomEventConsumer) updateStateEvent(event gomatrixserverlib.HeaderedEvent) (gomatrixserverlib.HeaderedEvent, error) { func (s *OutputRoomEventConsumer) updateStateEvent(event gomatrixserverlib.HeaderedEvent) (gomatrixserverlib.HeaderedEvent, error) {
if event.StateKey() == nil { if event.StateKey() == nil {
return event, nil return event, nil

View File

@ -30,6 +30,8 @@ type Database interface {
internal.PartitionStorer internal.PartitionStorer
// AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs. // AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs.
AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error)
// AllPeekingDevicesInRooms returns a map of room ID to a list of all peeking devices.
AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error)
// Events lookups a list of event by their event ID. // Events lookups a list of event by their event ID.
// Returns a list of events matching the requested IDs found in the database. // Returns a list of events matching the requested IDs found in the database.
// If an event is not found in the database then it will be omitted from the list. // If an event is not found in the database then it will be omitted from the list.
@ -81,6 +83,12 @@ type Database interface {
// RetireInviteEvent removes an old invite event from the database. Returns the new position of the retired invite. // RetireInviteEvent removes an old invite event from the database. Returns the new position of the retired invite.
// Returns an error if there was a problem communicating with the database. // Returns an error if there was a problem communicating with the database.
RetireInviteEvent(ctx context.Context, inviteEventID string) (types.StreamPosition, error) RetireInviteEvent(ctx context.Context, inviteEventID string) (types.StreamPosition, error)
// AddPeek adds a new peek to our DB for a given room by a given user's device.
// Returns an error if there was a problem communicating with the database.
AddPeek(ctx context.Context, RoomID, UserID, DeviceID string) (types.StreamPosition, error)
// DeletePeek deletes all peeks for a given room by a given user
// Returns an error if there was a problem communicating with the database.
DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error)
// SetTypingTimeoutCallback sets a callback function that is called right after // SetTypingTimeoutCallback sets a callback function that is called right after
// a user is removed from the typing user list due to timeout. // a user is removed from the typing user list due to timeout.
SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn)

View File

@ -0,0 +1,186 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
)
const peeksSchema = `
CREATE TABLE IF NOT EXISTS syncapi_peeks (
id BIGINT DEFAULT nextval('syncapi_stream_id'),
room_id TEXT NOT NULL,
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
deleted BOOL NOT NULL DEFAULT false,
-- When the peek was created in UNIX epoch ms.
creation_ts BIGINT NOT NULL,
UNIQUE(room_id, user_id, device_id)
);
CREATE INDEX IF NOT EXISTS syncapi_peeks_room_id_idx ON syncapi_peeks(room_id);
CREATE INDEX IF NOT EXISTS syncapi_peeks_user_id_device_id_idx ON syncapi_peeks(user_id, device_id);
`
const insertPeekSQL = "" +
"INSERT INTO syncapi_peeks" +
" (room_id, user_id, device_id, creation_ts)" +
" VALUES ($1, $2, $3, $4)" +
" ON CONFLICT (room_id, user_id, device_id) DO UPDATE SET deleted=false, creation_ts=$4" +
" RETURNING id"
const deletePeekSQL = "" +
"UPDATE syncapi_peeks SET deleted=true, id=nextval('syncapi_stream_id') WHERE room_id = $1 AND user_id = $2 AND device_id = $3 RETURNING id"
const deletePeeksSQL = "" +
"UPDATE syncapi_peeks SET deleted=true, id=nextval('syncapi_stream_id') WHERE room_id = $1 AND user_id = $2 RETURNING id"
// we care about all the peeks which were created in this range, deleted in this range,
// or were created before this range but haven't been deleted yet.
const selectPeeksInRangeSQL = "" +
"SELECT room_id, deleted, (id > $3 AND id <= $4) AS changed FROM syncapi_peeks WHERE user_id = $1 AND device_id = $2 AND ((id <= $3 AND NOT deleted) OR (id > $3 AND id <= $4))"
const selectPeekingDevicesSQL = "" +
"SELECT room_id, user_id, device_id FROM syncapi_peeks WHERE deleted=false"
const selectMaxPeekIDSQL = "" +
"SELECT MAX(id) FROM syncapi_peeks"
type peekStatements struct {
db *sql.DB
insertPeekStmt *sql.Stmt
deletePeekStmt *sql.Stmt
deletePeeksStmt *sql.Stmt
selectPeeksInRangeStmt *sql.Stmt
selectPeekingDevicesStmt *sql.Stmt
selectMaxPeekIDStmt *sql.Stmt
}
func NewPostgresPeeksTable(db *sql.DB) (tables.Peeks, error) {
_, err := db.Exec(peeksSchema)
if err != nil {
return nil, err
}
s := &peekStatements{
db: db,
}
if s.insertPeekStmt, err = db.Prepare(insertPeekSQL); err != nil {
return nil, err
}
if s.deletePeekStmt, err = db.Prepare(deletePeekSQL); err != nil {
return nil, err
}
if s.deletePeeksStmt, err = db.Prepare(deletePeeksSQL); err != nil {
return nil, err
}
if s.selectPeeksInRangeStmt, err = db.Prepare(selectPeeksInRangeSQL); err != nil {
return nil, err
}
if s.selectPeekingDevicesStmt, err = db.Prepare(selectPeekingDevicesSQL); err != nil {
return nil, err
}
if s.selectMaxPeekIDStmt, err = db.Prepare(selectMaxPeekIDSQL); err != nil {
return nil, err
}
return s, nil
}
func (s *peekStatements) InsertPeek(
ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string,
) (streamPos types.StreamPosition, err error) {
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
stmt := sqlutil.TxStmt(txn, s.insertPeekStmt)
err = stmt.QueryRowContext(ctx, roomID, userID, deviceID, nowMilli).Scan(&streamPos)
return
}
func (s *peekStatements) DeletePeek(
ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string,
) (streamPos types.StreamPosition, err error) {
stmt := sqlutil.TxStmt(txn, s.deletePeekStmt)
err = stmt.QueryRowContext(ctx, roomID, userID, deviceID).Scan(&streamPos)
return
}
func (s *peekStatements) DeletePeeks(
ctx context.Context, txn *sql.Tx, roomID, userID string,
) (streamPos types.StreamPosition, err error) {
stmt := sqlutil.TxStmt(txn, s.deletePeeksStmt)
err = stmt.QueryRowContext(ctx, roomID, userID).Scan(&streamPos)
return
}
func (s *peekStatements) SelectPeeksInRange(
ctx context.Context, txn *sql.Tx, userID, deviceID string, r types.Range,
) (peeks []types.Peek, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectPeeksInRangeStmt).QueryContext(ctx, userID, deviceID, r.Low(), r.High())
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectPeeksInRange: rows.close() failed")
for rows.Next() {
peek := types.Peek{}
var changed bool
if err = rows.Scan(&peek.RoomID, &peek.Deleted, &changed); err != nil {
return
}
peek.New = changed && !peek.Deleted
peeks = append(peeks, peek)
}
return peeks, rows.Err()
}
func (s *peekStatements) SelectPeekingDevices(
ctx context.Context,
) (peekingDevices map[string][]types.PeekingDevice, err error) {
rows, err := s.selectPeekingDevicesStmt.QueryContext(ctx)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectPeekingDevices: rows.close() failed")
result := make(map[string][]types.PeekingDevice)
for rows.Next() {
var roomID, userID, deviceID string
if err := rows.Scan(&roomID, &userID, &deviceID); err != nil {
return nil, err
}
devices := result[roomID]
devices = append(devices, types.PeekingDevice{UserID: userID, DeviceID: deviceID})
result[roomID] = devices
}
return result, nil
}
func (s *peekStatements) SelectMaxPeekID(
ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
var nullableID sql.NullInt64
stmt := sqlutil.TxStmt(txn, s.selectMaxPeekIDStmt)
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
if nullableID.Valid {
id = nullableID.Int64
}
return
}

View File

@ -62,6 +62,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
if err != nil { if err != nil {
return nil, err return nil, err
} }
peeks, err := NewPostgresPeeksTable(d.db)
if err != nil {
return nil, err
}
topology, err := NewPostgresTopologyTable(d.db) topology, err := NewPostgresTopologyTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
@ -82,6 +86,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
DB: d.db, DB: d.db,
Writer: d.writer, Writer: d.writer,
Invites: invites, Invites: invites,
Peeks: peeks,
AccountData: accountData, AccountData: accountData,
OutputEvents: events, OutputEvents: events,
Topology: topology, Topology: topology,

View File

@ -30,7 +30,7 @@ import (
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite // Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite
@ -39,6 +39,7 @@ type Database struct {
DB *sql.DB DB *sql.DB
Writer sqlutil.Writer Writer sqlutil.Writer
Invites tables.Invites Invites tables.Invites
Peeks tables.Peeks
AccountData tables.AccountData AccountData tables.AccountData
OutputEvents tables.Events OutputEvents tables.Events
Topology tables.Topology Topology tables.Topology
@ -120,6 +121,10 @@ func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]stri
return d.CurrentRoomState.SelectJoinedUsers(ctx) return d.CurrentRoomState.SelectJoinedUsers(ctx)
} }
func (d *Database) AllPeekingDevicesInRooms(ctx context.Context) (map[string][]types.PeekingDevice, error) {
return d.Peeks.SelectPeekingDevices(ctx)
}
func (d *Database) GetStateEvent( func (d *Database) GetStateEvent(
ctx context.Context, roomID, evType, stateKey string, ctx context.Context, roomID, evType, stateKey string,
) (*gomatrixserverlib.HeaderedEvent, error) { ) (*gomatrixserverlib.HeaderedEvent, error) {
@ -141,7 +146,7 @@ func (d *Database) AddInviteEvent(
) (sp types.StreamPosition, err error) { ) (sp types.StreamPosition, err error) {
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent) sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent)
return nil return err
}) })
return return
} }
@ -153,11 +158,41 @@ func (d *Database) RetireInviteEvent(
) (sp types.StreamPosition, err error) { ) (sp types.StreamPosition, err error) {
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
sp, err = d.Invites.DeleteInviteEvent(ctx, txn, inviteEventID) sp, err = d.Invites.DeleteInviteEvent(ctx, txn, inviteEventID)
return nil return err
}) })
return return
} }
// AddPeek tracks the fact that a user has started peeking.
// If the peek was successfully stored this returns the stream ID it was stored at.
// Returns an error if there was a problem communicating with the database.
func (d *Database) AddPeek(
ctx context.Context, roomID, userID, deviceID string,
) (sp types.StreamPosition, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
sp, err = d.Peeks.InsertPeek(ctx, txn, roomID, userID, deviceID)
return err
})
return
}
// DeletePeeks tracks the fact that a user has stopped peeking from all devices
// If the peeks was successfully deleted this returns the stream ID it was stored at.
// Returns an error if there was a problem communicating with the database.
func (d *Database) DeletePeeks(
ctx context.Context, roomID, userID string,
) (sp types.StreamPosition, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
sp, err = d.Peeks.DeletePeeks(ctx, txn, roomID, userID)
return err
})
if err == sql.ErrNoRows {
sp = 0
err = nil
}
return
}
// GetAccountDataInRange returns all account data for a given user inserted or // GetAccountDataInRange returns all account data for a given user inserted or
// updated between two given positions // updated between two given positions
// Returns a map following the format data[roomID] = []dataTypes // Returns a map following the format data[roomID] = []dataTypes
@ -196,7 +231,7 @@ func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.Strea
"transaction_id", in[i].TransactionID.TransactionID, "transaction_id", in[i].TransactionID.TransactionID,
) )
if err != nil { if err != nil {
logrus.WithFields(logrus.Fields{ log.WithFields(log.Fields{
"event_id": out[i].EventID(), "event_id": out[i].EventID(),
}).WithError(err).Warnf("Failed to add transaction ID to event") }).WithError(err).Warnf("Failed to add transaction ID to event")
} }
@ -389,7 +424,6 @@ func (d *Database) EventPositionInTopology(
func (d *Database) syncPositionTx( func (d *Database) syncPositionTx(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
) (sp types.StreamingToken, err error) { ) (sp types.StreamingToken, err error) {
maxEventID, err := d.OutputEvents.SelectMaxEventID(ctx, txn) maxEventID, err := d.OutputEvents.SelectMaxEventID(ctx, txn)
if err != nil { if err != nil {
return sp, err return sp, err
@ -408,6 +442,13 @@ func (d *Database) syncPositionTx(
if maxInviteID > maxEventID { if maxInviteID > maxEventID {
maxEventID = maxInviteID maxEventID = maxInviteID
} }
maxPeekID, err := d.Peeks.SelectMaxPeekID(ctx, txn)
if err != nil {
return sp, err
}
if maxPeekID > maxEventID {
maxEventID = maxPeekID
}
sp = types.NewStreamToken(types.StreamPosition(maxEventID), types.StreamPosition(d.EDUCache.GetLatestSyncPosition()), nil) sp = types.NewStreamToken(types.StreamPosition(maxEventID), types.StreamPosition(d.EDUCache.GetLatestSyncPosition()), nil)
return return
} }
@ -566,6 +607,8 @@ func (d *Database) IncrementalSync(
} }
} }
// TODO: handle EDUs in peeked rooms
err = d.addEDUDeltaToResponse( err = d.addEDUDeltaToResponse(
fromPos, toPos, joinedRoomIDs, res, fromPos, toPos, joinedRoomIDs, res,
) )
@ -582,7 +625,7 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda
return err return err
} }
if len(redactedEvents) == 0 { if len(redactedEvents) == 0 {
logrus.WithField("event_id", redactedEventID).WithField("redaction_event", redactedBecause.EventID()).Warnf("missing redacted event for redaction") log.WithField("event_id", redactedEventID).WithField("redaction_event", redactedBecause.EventID()).Warnf("missing redacted event for redaction")
return nil return nil
} }
eventToRedact := redactedEvents[0].Unwrap() eventToRedact := redactedEvents[0].Unwrap()
@ -604,7 +647,7 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda
// nolint:nakedret // nolint:nakedret
func (d *Database) getResponseWithPDUsForCompleteSync( func (d *Database) getResponseWithPDUsForCompleteSync(
ctx context.Context, res *types.Response, ctx context.Context, res *types.Response,
userID string, userID string, deviceID string,
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
) ( ) (
toPos types.StreamingToken, toPos types.StreamingToken,
@ -644,46 +687,32 @@ func (d *Database) getResponseWithPDUsForCompleteSync(
// Build up a /sync response. Add joined rooms. // Build up a /sync response. Add joined rooms.
for _, roomID := range joinedRoomIDs { for _, roomID := range joinedRoomIDs {
var stateEvents []gomatrixserverlib.HeaderedEvent var jr *types.JoinResponse
stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, &stateFilter) jr, err = d.getJoinResponseForCompleteSync(
if err != nil { ctx, txn, roomID, r, &stateFilter, numRecentEventsPerRoom,
return
}
// TODO: When filters are added, we may need to call this multiple times to get enough events.
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
var recentStreamEvents []types.StreamEvent
var limited bool
recentStreamEvents, limited, err = d.OutputEvents.SelectRecentEvents(
ctx, txn, roomID, r, numRecentEventsPerRoom, true, true,
) )
if err != nil { if err != nil {
return return
} }
res.Rooms.Join[roomID] = *jr
}
// Retrieve the backward topology position, i.e. the position of the // Add peeked rooms.
// oldest event in the room's topology. peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, deviceID, r)
var prevBatchStr string if err != nil {
if len(recentStreamEvents) > 0 { return
var backwardTopologyPos, backwardStreamPos types.StreamPosition }
backwardTopologyPos, backwardStreamPos, err = d.Topology.SelectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID()) for _, peek := range peeks {
if !peek.Deleted {
var jr *types.JoinResponse
jr, err = d.getJoinResponseForCompleteSync(
ctx, txn, peek.RoomID, r, &stateFilter, numRecentEventsPerRoom,
)
if err != nil { if err != nil {
return return
} }
prevBatch := types.NewTopologyToken(backwardTopologyPos, backwardStreamPos) res.Rooms.Peek[peek.RoomID] = *jr
prevBatch.Decrement()
prevBatchStr = prevBatch.String()
} }
// We don't include a device here as we don't need to send down
// transaction IDs for complete syncs
recentEvents := d.StreamEventsToEvents(nil, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents)
jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = prevBatchStr
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Join[roomID] = *jr
} }
if err = d.addInvitesToResponse(ctx, txn, userID, r, res); err != nil { if err = d.addInvitesToResponse(ctx, txn, userID, r, res); err != nil {
@ -694,17 +723,68 @@ func (d *Database) getResponseWithPDUsForCompleteSync(
return //res, toPos, joinedRoomIDs, err return //res, toPos, joinedRoomIDs, err
} }
func (d *Database) getJoinResponseForCompleteSync(
ctx context.Context, txn *sql.Tx,
roomID string,
r types.Range,
stateFilter *gomatrixserverlib.StateFilter,
numRecentEventsPerRoom int,
) (jr *types.JoinResponse, err error) {
var stateEvents []gomatrixserverlib.HeaderedEvent
stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, stateFilter)
if err != nil {
return
}
// TODO: When filters are added, we may need to call this multiple times to get enough events.
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
var recentStreamEvents []types.StreamEvent
var limited bool
recentStreamEvents, limited, err = d.OutputEvents.SelectRecentEvents(
ctx, txn, roomID, r, numRecentEventsPerRoom, true, true,
)
if err != nil {
return
}
// Retrieve the backward topology position, i.e. the position of the
// oldest event in the room's topology.
var prevBatchStr string
if len(recentStreamEvents) > 0 {
var backwardTopologyPos, backwardStreamPos types.StreamPosition
backwardTopologyPos, backwardStreamPos, err = d.Topology.SelectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID())
if err != nil {
return
}
prevBatch := types.NewTopologyToken(backwardTopologyPos, backwardStreamPos)
prevBatch.Decrement()
prevBatchStr = prevBatch.String()
}
// We don't include a device here as we don't need to send down
// transaction IDs for complete syncs
recentEvents := d.StreamEventsToEvents(nil, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents)
jr = types.NewJoinResponse()
jr.Timeline.PrevBatch = prevBatchStr
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync)
return jr, nil
}
func (d *Database) CompleteSync( func (d *Database) CompleteSync(
ctx context.Context, res *types.Response, ctx context.Context, res *types.Response,
device userapi.Device, numRecentEventsPerRoom int, device userapi.Device, numRecentEventsPerRoom int,
) (*types.Response, error) { ) (*types.Response, error) {
toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync( toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync(
ctx, res, device.UserID, numRecentEventsPerRoom, ctx, res, device.UserID, device.ID, numRecentEventsPerRoom,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("d.getResponseWithPDUsForCompleteSync: %w", err) return nil, fmt.Errorf("d.getResponseWithPDUsForCompleteSync: %w", err)
} }
// TODO: handle EDUs in peeked rooms
// Use a zero value SyncPosition for fromPos so all EDU states are added. // Use a zero value SyncPosition for fromPos so all EDU states are added.
err = d.addEDUDeltaToResponse( err = d.addEDUDeltaToResponse(
types.NewStreamToken(0, 0, nil), toPos, joinedRoomIDs, res, types.NewStreamToken(0, 0, nil), toPos, joinedRoomIDs, res,
@ -803,6 +883,12 @@ func (d *Database) addRoomDeltaToResponse(
return err return err
} }
// XXX: should we ever get this far if we have no recent events or state in this room?
// in practice we do for peeks, but possibly not joins?
if len(recentEvents) == 0 && len(delta.stateEvents) == 0 {
return nil
}
switch delta.membership { switch delta.membership {
case gomatrixserverlib.Join: case gomatrixserverlib.Join:
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
@ -812,6 +898,14 @@ func (d *Database) addRoomDeltaToResponse(
jr.Timeline.Limited = limited jr.Timeline.Limited = limited
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Join[delta.roomID] = *jr res.Rooms.Join[delta.roomID] = *jr
case gomatrixserverlib.Peek:
jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = prevBatch.String()
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
res.Rooms.Peek[delta.roomID] = *jr
case gomatrixserverlib.Leave: case gomatrixserverlib.Leave:
fallthrough // transitions to leave are the same as ban fallthrough // transitions to leave are the same as ban
case gomatrixserverlib.Ban: case gomatrixserverlib.Ban:
@ -918,6 +1012,7 @@ func (d *Database) fetchMissingStateEvents(
// exclusive of oldPos, inclusive of newPos, for the rooms in which // exclusive of oldPos, inclusive of newPos, for the rooms in which
// the user has new membership events. // the user has new membership events.
// A list of joined room IDs is also returned in case the caller needs it. // A list of joined room IDs is also returned in case the caller needs it.
// nolint:gocyclo
func (d *Database) getStateDeltas( func (d *Database) getStateDeltas(
ctx context.Context, device *userapi.Device, txn *sql.Tx, ctx context.Context, device *userapi.Device, txn *sql.Tx,
r types.Range, userID string, r types.Range, userID string,
@ -933,7 +1028,7 @@ func (d *Database) getStateDeltas(
// - Get all CURRENTLY joined rooms, and add them to 'joined' block. // - Get all CURRENTLY joined rooms, and add them to 'joined' block.
var deltas []stateDelta var deltas []stateDelta
// get all the state events ever between these two positions // get all the state events ever (i.e. for all available rooms) between these two positions
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter) stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -943,6 +1038,34 @@ func (d *Database) getStateDeltas(
return nil, nil, err return nil, nil, err
} }
// find out which rooms this user is peeking, if any.
// We do this before joins so any peeks get overwritten
peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r)
if err != nil {
return nil, nil, err
}
// add peek blocks
for _, peek := range peeks {
if peek.New {
// send full room state down instead of a delta
var s []types.StreamEvent
s, err = d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter)
if err != nil {
return nil, nil, err
}
state[peek.RoomID] = s
}
if !peek.Deleted {
deltas = append(deltas, stateDelta{
membership: gomatrixserverlib.Peek,
stateEvents: d.StreamEventsToEvents(device, state[peek.RoomID]),
roomID: peek.RoomID,
})
}
}
// handle newly joined rooms and non-joined rooms
for roomID, stateStreamEvents := range state { for roomID, stateStreamEvents := range state {
for _, ev := range stateStreamEvents { for _, ev := range stateStreamEvents {
// TODO: Currently this will incorrectly add rooms which were ALREADY joined but they sent another no-op join event. // TODO: Currently this will incorrectly add rooms which were ALREADY joined but they sent another no-op join event.
@ -993,6 +1116,7 @@ func (d *Database) getStateDeltas(
// requests with full_state=true. // requests with full_state=true.
// Fetches full state for all joined rooms and uses selectStateInRange to get // Fetches full state for all joined rooms and uses selectStateInRange to get
// updates for other rooms. // updates for other rooms.
// nolint:gocyclo
func (d *Database) getStateDeltasForFullStateSync( func (d *Database) getStateDeltasForFullStateSync(
ctx context.Context, device *userapi.Device, txn *sql.Tx, ctx context.Context, device *userapi.Device, txn *sql.Tx,
r types.Range, userID string, r types.Range, userID string,
@ -1001,6 +1125,26 @@ func (d *Database) getStateDeltasForFullStateSync(
// Use a reasonable initial capacity // Use a reasonable initial capacity
deltas := make(map[string]stateDelta) deltas := make(map[string]stateDelta)
peeks, err := d.Peeks.SelectPeeksInRange(ctx, txn, userID, device.ID, r)
if err != nil {
return nil, nil, err
}
// Add full states for all peeking rooms
for _, peek := range peeks {
if !peek.Deleted {
s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, peek.RoomID, stateFilter)
if stateErr != nil {
return nil, nil, stateErr
}
deltas[peek.RoomID] = stateDelta{
membership: gomatrixserverlib.Peek,
stateEvents: d.StreamEventsToEvents(device, s),
roomID: peek.RoomID,
}
}
}
// Get all the state events ever between these two positions // Get all the state events ever between these two positions
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter) stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter)
if err != nil { if err != nil {

View File

@ -0,0 +1,206 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
)
const peeksSchema = `
CREATE TABLE IF NOT EXISTS syncapi_peeks (
id INTEGER,
room_id TEXT NOT NULL,
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
deleted BOOL NOT NULL DEFAULT false,
-- When the peek was created in UNIX epoch ms.
creation_ts INTEGER NOT NULL,
UNIQUE(room_id, user_id, device_id)
);
CREATE INDEX IF NOT EXISTS syncapi_peeks_room_id_idx ON syncapi_peeks(room_id);
CREATE INDEX IF NOT EXISTS syncapi_peeks_user_id_device_id_idx ON syncapi_peeks(user_id, device_id);
`
const insertPeekSQL = "" +
"INSERT OR REPLACE INTO syncapi_peeks" +
" (id, room_id, user_id, device_id, creation_ts, deleted)" +
" VALUES ($1, $2, $3, $4, $5, false)"
const deletePeekSQL = "" +
"UPDATE syncapi_peeks SET deleted=true, id=$1 WHERE room_id = $2 AND user_id = $3 AND device_id = $4"
const deletePeeksSQL = "" +
"UPDATE syncapi_peeks SET deleted=true, id=$1 WHERE room_id = $2 AND user_id = $3"
// we care about all the peeks which were created in this range, deleted in this range,
// or were created before this range but haven't been deleted yet.
// BEWARE: sqlite chokes on out of order substitution strings.
const selectPeeksInRangeSQL = "" +
"SELECT id, room_id, deleted FROM syncapi_peeks WHERE user_id = $1 AND device_id = $2 AND ((id <= $3 AND NOT deleted=true) OR (id > $3 AND id <= $4))"
const selectPeekingDevicesSQL = "" +
"SELECT room_id, user_id, device_id FROM syncapi_peeks WHERE deleted=false"
const selectMaxPeekIDSQL = "" +
"SELECT MAX(id) FROM syncapi_peeks"
type peekStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
insertPeekStmt *sql.Stmt
deletePeekStmt *sql.Stmt
deletePeeksStmt *sql.Stmt
selectPeeksInRangeStmt *sql.Stmt
selectPeekingDevicesStmt *sql.Stmt
selectMaxPeekIDStmt *sql.Stmt
}
func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks, error) {
_, err := db.Exec(peeksSchema)
if err != nil {
return nil, err
}
s := &peekStatements{
db: db,
streamIDStatements: streamID,
}
if s.insertPeekStmt, err = db.Prepare(insertPeekSQL); err != nil {
return nil, err
}
if s.deletePeekStmt, err = db.Prepare(deletePeekSQL); err != nil {
return nil, err
}
if s.deletePeeksStmt, err = db.Prepare(deletePeeksSQL); err != nil {
return nil, err
}
if s.selectPeeksInRangeStmt, err = db.Prepare(selectPeeksInRangeSQL); err != nil {
return nil, err
}
if s.selectPeekingDevicesStmt, err = db.Prepare(selectPeekingDevicesSQL); err != nil {
return nil, err
}
if s.selectMaxPeekIDStmt, err = db.Prepare(selectMaxPeekIDSQL); err != nil {
return nil, err
}
return s, nil
}
func (s *peekStatements) InsertPeek(
ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string,
) (streamPos types.StreamPosition, err error) {
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
if err != nil {
return
}
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
_, err = sqlutil.TxStmt(txn, s.insertPeekStmt).ExecContext(ctx, streamPos, roomID, userID, deviceID, nowMilli)
return
}
func (s *peekStatements) DeletePeek(
ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string,
) (streamPos types.StreamPosition, err error) {
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
if err != nil {
return
}
_, err = sqlutil.TxStmt(txn, s.deletePeekStmt).ExecContext(ctx, streamPos, roomID, userID, deviceID)
return
}
func (s *peekStatements) DeletePeeks(
ctx context.Context, txn *sql.Tx, roomID, userID string,
) (types.StreamPosition, error) {
streamPos, err := s.streamIDStatements.nextStreamID(ctx, txn)
if err != nil {
return 0, err
}
result, err := sqlutil.TxStmt(txn, s.deletePeeksStmt).ExecContext(ctx, streamPos, roomID, userID)
if err != nil {
return 0, err
}
numAffected, err := result.RowsAffected()
if err != nil {
return 0, err
}
if numAffected == 0 {
return 0, sql.ErrNoRows
}
return streamPos, nil
}
func (s *peekStatements) SelectPeeksInRange(
ctx context.Context, txn *sql.Tx, userID, deviceID string, r types.Range,
) (peeks []types.Peek, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectPeeksInRangeStmt).QueryContext(ctx, userID, deviceID, r.Low(), r.High())
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectPeeksInRange: rows.close() failed")
for rows.Next() {
peek := types.Peek{}
var id types.StreamPosition
if err = rows.Scan(&id, &peek.RoomID, &peek.Deleted); err != nil {
return
}
peek.New = (id > r.Low() && id <= r.High()) && !peek.Deleted
peeks = append(peeks, peek)
}
return peeks, rows.Err()
}
func (s *peekStatements) SelectPeekingDevices(
ctx context.Context,
) (peekingDevices map[string][]types.PeekingDevice, err error) {
rows, err := s.selectPeekingDevicesStmt.QueryContext(ctx)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectPeekingDevices: rows.close() failed")
result := make(map[string][]types.PeekingDevice)
for rows.Next() {
var roomID, userID, deviceID string
if err := rows.Scan(&roomID, &userID, &deviceID); err != nil {
return nil, err
}
devices := result[roomID]
devices = append(devices, types.PeekingDevice{UserID: userID, DeviceID: deviceID})
result[roomID] = devices
}
return result, nil
}
func (s *peekStatements) SelectMaxPeekID(
ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
var nullableID sql.NullInt64
stmt := sqlutil.TxStmt(txn, s.selectMaxPeekIDStmt)
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
if nullableID.Valid {
id = nullableID.Int64
}
return
}

View File

@ -75,6 +75,10 @@ func (d *SyncServerDatasource) prepare() (err error) {
if err != nil { if err != nil {
return err return err
} }
peeks, err := NewSqlitePeeksTable(d.db, &d.streamID)
if err != nil {
return err
}
topology, err := NewSqliteTopologyTable(d.db) topology, err := NewSqliteTopologyTable(d.db)
if err != nil { if err != nil {
return err return err
@ -95,6 +99,7 @@ func (d *SyncServerDatasource) prepare() (err error) {
DB: d.db, DB: d.db,
Writer: d.writer, Writer: d.writer,
Invites: invites, Invites: invites,
Peeks: peeks,
AccountData: accountData, AccountData: accountData,
OutputEvents: events, OutputEvents: events,
BackwardExtremities: bwExtrem, BackwardExtremities: bwExtrem,

View File

@ -39,6 +39,15 @@ type Invites interface {
SelectMaxInviteID(ctx context.Context, txn *sql.Tx) (id int64, err error) SelectMaxInviteID(ctx context.Context, txn *sql.Tx) (id int64, err error)
} }
type Peeks interface {
InsertPeek(ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string) (streamPos types.StreamPosition, err error)
DeletePeek(ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string) (streamPos types.StreamPosition, err error)
DeletePeeks(ctx context.Context, txn *sql.Tx, roomID, userID string) (streamPos types.StreamPosition, err error)
SelectPeeksInRange(ctxt context.Context, txn *sql.Tx, userID, deviceID string, r types.Range) (peeks []types.Peek, err error)
SelectPeekingDevices(ctxt context.Context) (peekingDevices map[string][]types.PeekingDevice, err error)
SelectMaxPeekID(ctx context.Context, txn *sql.Tx) (id int64, err error)
}
type Events interface { type Events interface {
SelectStateInRange(ctx context.Context, txn *sql.Tx, r types.Range, stateFilter *gomatrixserverlib.StateFilter) (map[string]map[string]bool, map[string]types.StreamEvent, error) SelectStateInRange(ctx context.Context, txn *sql.Tx, r types.Range, stateFilter *gomatrixserverlib.StateFilter) (map[string]map[string]bool, map[string]types.StreamEvent, error)
SelectMaxEventID(ctx context.Context, txn *sql.Tx) (id int64, err error) SelectMaxEventID(ctx context.Context, txn *sql.Tx) (id int64, err error)

View File

@ -33,6 +33,8 @@ import (
type Notifier struct { type Notifier struct {
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine // A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
roomIDToJoinedUsers map[string]userIDSet roomIDToJoinedUsers map[string]userIDSet
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
roomIDToPeekingDevices map[string]peekingDeviceSet
// Protects currPos and userStreams. // Protects currPos and userStreams.
streamLock *sync.Mutex streamLock *sync.Mutex
// The latest sync position // The latest sync position
@ -48,11 +50,12 @@ type Notifier struct {
// the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase). // the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase).
func NewNotifier(pos types.StreamingToken) *Notifier { func NewNotifier(pos types.StreamingToken) *Notifier {
return &Notifier{ return &Notifier{
currPos: pos, currPos: pos,
roomIDToJoinedUsers: make(map[string]userIDSet), roomIDToJoinedUsers: make(map[string]userIDSet),
userDeviceStreams: make(map[string]map[string]*UserDeviceStream), roomIDToPeekingDevices: make(map[string]peekingDeviceSet),
streamLock: &sync.Mutex{}, userDeviceStreams: make(map[string]map[string]*UserDeviceStream),
lastCleanUpTime: time.Now(), streamLock: &sync.Mutex{},
lastCleanUpTime: time.Now(),
} }
} }
@ -82,6 +85,8 @@ func (n *Notifier) OnNewEvent(
if ev != nil { if ev != nil {
// Map this event's room_id to a list of joined users, and wake them up. // Map this event's room_id to a list of joined users, and wake them up.
usersToNotify := n.joinedUsers(ev.RoomID()) usersToNotify := n.joinedUsers(ev.RoomID())
// Map this event's room_id to a list of peeking devices, and wake them up.
peekingDevicesToNotify := n.PeekingDevices(ev.RoomID())
// If this is an invite, also add in the invitee to this list. // If this is an invite, also add in the invitee to this list.
if ev.Type() == "m.room.member" && ev.StateKey() != nil { if ev.Type() == "m.room.member" && ev.StateKey() != nil {
targetUserID := *ev.StateKey() targetUserID := *ev.StateKey()
@ -108,11 +113,11 @@ func (n *Notifier) OnNewEvent(
} }
} }
n.wakeupUsers(usersToNotify, latestPos) n.wakeupUsers(usersToNotify, peekingDevicesToNotify, latestPos)
} else if roomID != "" { } else if roomID != "" {
n.wakeupUsers(n.joinedUsers(roomID), latestPos) n.wakeupUsers(n.joinedUsers(roomID), n.PeekingDevices(roomID), latestPos)
} else if len(userIDs) > 0 { } else if len(userIDs) > 0 {
n.wakeupUsers(userIDs, latestPos) n.wakeupUsers(userIDs, nil, latestPos)
} else { } else {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"posUpdate": posUpdate.String, "posUpdate": posUpdate.String,
@ -120,6 +125,18 @@ func (n *Notifier) OnNewEvent(
} }
} }
func (n *Notifier) OnNewPeek(
roomID, userID, deviceID string,
) {
n.streamLock.Lock()
defer n.streamLock.Unlock()
n.addPeekingDevice(roomID, userID, deviceID)
// we don't wake up devices here given the roomserver consumer will do this shortly afterwards
// by calling OnNewEvent.
}
func (n *Notifier) OnNewSendToDevice( func (n *Notifier) OnNewSendToDevice(
userID string, deviceIDs []string, userID string, deviceIDs []string,
posUpdate types.StreamingToken, posUpdate types.StreamingToken,
@ -139,7 +156,7 @@ func (n *Notifier) OnNewKeyChange(
defer n.streamLock.Unlock() defer n.streamLock.Unlock()
latestPos := n.currPos.WithUpdates(posUpdate) latestPos := n.currPos.WithUpdates(posUpdate)
n.currPos = latestPos n.currPos = latestPos
n.wakeupUsers([]string{wakeUserID}, latestPos) n.wakeupUsers([]string{wakeUserID}, nil, latestPos)
} }
// GetListener returns a UserStreamListener that can be used to wait for // GetListener returns a UserStreamListener that can be used to wait for
@ -169,6 +186,13 @@ func (n *Notifier) Load(ctx context.Context, db storage.Database) error {
return err return err
} }
n.setUsersJoinedToRooms(roomToUsers) n.setUsersJoinedToRooms(roomToUsers)
roomToPeekingDevices, err := db.AllPeekingDevicesInRooms(ctx)
if err != nil {
return err
}
n.setPeekingDevices(roomToPeekingDevices)
return nil return nil
} }
@ -195,9 +219,24 @@ func (n *Notifier) setUsersJoinedToRooms(roomIDToUserIDs map[string][]string) {
} }
} }
// setPeekingDevices marks the given devices as peeking in the given rooms, such that new events from
// these rooms will wake the given devices' /sync requests. This should be called prior to ANY calls to
// OnNewEvent (eg on startup) to prevent racing.
func (n *Notifier) setPeekingDevices(roomIDToPeekingDevices map[string][]types.PeekingDevice) {
// This is just the bulk form of addPeekingDevice
for roomID, peekingDevices := range roomIDToPeekingDevices {
if _, ok := n.roomIDToPeekingDevices[roomID]; !ok {
n.roomIDToPeekingDevices[roomID] = make(peekingDeviceSet)
}
for _, peekingDevice := range peekingDevices {
n.roomIDToPeekingDevices[roomID].add(peekingDevice)
}
}
}
// wakeupUsers will wake up the sync strems for all of the devices for all of the // wakeupUsers will wake up the sync strems for all of the devices for all of the
// specified user IDs. // specified user IDs, and also the specified peekingDevices
func (n *Notifier) wakeupUsers(userIDs []string, newPos types.StreamingToken) { func (n *Notifier) wakeupUsers(userIDs []string, peekingDevices []types.PeekingDevice, newPos types.StreamingToken) {
for _, userID := range userIDs { for _, userID := range userIDs {
for _, stream := range n.fetchUserStreams(userID) { for _, stream := range n.fetchUserStreams(userID) {
if stream == nil { if stream == nil {
@ -206,6 +245,13 @@ func (n *Notifier) wakeupUsers(userIDs []string, newPos types.StreamingToken) {
stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream
} }
} }
for _, peekingDevice := range peekingDevices {
// TODO: don't bother waking up for devices whose users we already woke up
if stream := n.fetchUserDeviceStream(peekingDevice.UserID, peekingDevice.DeviceID, false); stream != nil {
stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream
}
}
} }
// wakeupUserDevice will wake up the sync stream for a specific user device. Other // wakeupUserDevice will wake up the sync stream for a specific user device. Other
@ -284,6 +330,32 @@ func (n *Notifier) joinedUsers(roomID string) (userIDs []string) {
return n.roomIDToJoinedUsers[roomID].values() return n.roomIDToJoinedUsers[roomID].values()
} }
// Not thread-safe: must be called on the OnNewEvent goroutine only
func (n *Notifier) addPeekingDevice(roomID, userID, deviceID string) {
if _, ok := n.roomIDToPeekingDevices[roomID]; !ok {
n.roomIDToPeekingDevices[roomID] = make(peekingDeviceSet)
}
n.roomIDToPeekingDevices[roomID].add(types.PeekingDevice{UserID: userID, DeviceID: deviceID})
}
// Not thread-safe: must be called on the OnNewEvent goroutine only
// nolint:unused
func (n *Notifier) removePeekingDevice(roomID, userID, deviceID string) {
if _, ok := n.roomIDToPeekingDevices[roomID]; !ok {
n.roomIDToPeekingDevices[roomID] = make(peekingDeviceSet)
}
// XXX: is this going to work as a key?
n.roomIDToPeekingDevices[roomID].remove(types.PeekingDevice{UserID: userID, DeviceID: deviceID})
}
// Not thread-safe: must be called on the OnNewEvent goroutine only
func (n *Notifier) PeekingDevices(roomID string) (peekingDevices []types.PeekingDevice) {
if _, ok := n.roomIDToPeekingDevices[roomID]; !ok {
return
}
return n.roomIDToPeekingDevices[roomID].values()
}
// removeEmptyUserStreams iterates through the user stream map and removes any // removeEmptyUserStreams iterates through the user stream map and removes any
// that have been empty for a certain amount of time. This is a crude way of // that have been empty for a certain amount of time. This is a crude way of
// ensuring that the userStreams map doesn't grow forver. // ensuring that the userStreams map doesn't grow forver.
@ -329,3 +401,23 @@ func (s userIDSet) values() (vals []string) {
} }
return return
} }
// A set of PeekingDevices, similar to userIDSet
type peekingDeviceSet map[types.PeekingDevice]bool
func (s peekingDeviceSet) add(d types.PeekingDevice) {
s[d] = true
}
// nolint:unused
func (s peekingDeviceSet) remove(d types.PeekingDevice) {
delete(s, d)
}
func (s peekingDeviceSet) values() (vals []types.PeekingDevice) {
for d := range s {
vals = append(vals, d)
}
return
}

View File

@ -388,6 +388,7 @@ type Response struct {
} `json:"presence,omitempty"` } `json:"presence,omitempty"`
Rooms struct { Rooms struct {
Join map[string]JoinResponse `json:"join"` Join map[string]JoinResponse `json:"join"`
Peek map[string]JoinResponse `json:"peek"`
Invite map[string]InviteResponse `json:"invite"` Invite map[string]InviteResponse `json:"invite"`
Leave map[string]LeaveResponse `json:"leave"` Leave map[string]LeaveResponse `json:"leave"`
} `json:"rooms"` } `json:"rooms"`
@ -407,6 +408,7 @@ func NewResponse() *Response {
// Pre-initialise the maps. Synapse will return {} even if there are no rooms under a specific section, // Pre-initialise the maps. Synapse will return {} even if there are no rooms under a specific section,
// so let's do the same thing. Bonus: this means we can't get dreaded 'assignment to entry in nil map' errors. // so let's do the same thing. Bonus: this means we can't get dreaded 'assignment to entry in nil map' errors.
res.Rooms.Join = make(map[string]JoinResponse) res.Rooms.Join = make(map[string]JoinResponse)
res.Rooms.Peek = make(map[string]JoinResponse)
res.Rooms.Invite = make(map[string]InviteResponse) res.Rooms.Invite = make(map[string]InviteResponse)
res.Rooms.Leave = make(map[string]LeaveResponse) res.Rooms.Leave = make(map[string]LeaveResponse)
@ -433,7 +435,7 @@ func (r *Response) IsEmpty() bool {
len(r.ToDevice.Events) == 0 len(r.ToDevice.Events) == 0
} }
// JoinResponse represents a /sync response for a room which is under the 'join' key. // JoinResponse represents a /sync response for a room which is under the 'join' or 'peek' key.
type JoinResponse struct { type JoinResponse struct {
State struct { State struct {
Events []gomatrixserverlib.ClientEvent `json:"events"` Events []gomatrixserverlib.ClientEvent `json:"events"`
@ -507,3 +509,14 @@ type SendToDeviceEvent struct {
DeviceID string DeviceID string
SentByToken *StreamingToken SentByToken *StreamingToken
} }
type PeekingDevice struct {
UserID string
DeviceID string
}
type Peek struct {
RoomID string
New bool
Deleted bool
}

View File

@ -465,3 +465,9 @@ After changing password, can log in with new password
After changing password, existing session still works After changing password, existing session still works
After changing password, different sessions can optionally be kept After changing password, different sessions can optionally be kept
After changing password, a different session no longer works by default After changing password, a different session no longer works by default
Local users can peek into world_readable rooms by room ID
We can't peek into rooms with shared history_visibility
We can't peek into rooms with invited history_visibility
We can't peek into rooms with joined history_visibility
Local users can peek by room alias
Peeked rooms only turn up in the sync for the device who peeked them