Event relations (#2790)

This adds support for tracking `m.relates_to`, as well as adding support
for the various `/room/{roomID}/relations/...` endpoints to the CS API.
This commit is contained in:
Neil Alexander 2022-10-13 14:50:52 +01:00 committed by GitHub
parent 3c1474f68f
commit 23a3e04579
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 943 additions and 51 deletions

View File

@ -1,66 +1,85 @@
# Sample Caddyfile for using Caddy in front of Dendrite. # Sample Caddyfile for using Caddy in front of Dendrite
#
# Customize email address and domain names.
# Optional settings commented out.
#
# BE SURE YOUR DOMAINS ARE POINTED AT YOUR SERVER FIRST.
# Documentation: https://caddyserver.com/docs/
#
# Bonus tip: If your IP address changes, use Caddy's
# dynamic DNS plugin to update your DNS records to
# point to your new IP automatically:
# https://github.com/mholt/caddy-dynamicdns
# #
# Customize email address and domain names
# Optional settings commented out
#
# BE SURE YOUR DOMAINS ARE POINTED AT YOUR SERVER FIRST
# Documentation: <https://caddyserver.com/docs/>
#
# Bonus tip: If your IP address changes, use Caddy's
# dynamic DNS plugin to update your DNS records to
# point to your new IP automatically
# <https://github.com/mholt/caddy-dynamicdns>
#
# Global options block # Global options block
{ {
# In case there is a problem with your certificates. # In case there is a problem with your certificates.
# email example@example.com # email example@example.com
# Turn off the admin endpoint if you don't need graceful config # Turn off the admin endpoint if you don't need graceful config
# changes and/or are running untrusted code on your machine. # changes and/or are running untrusted code on your machine.
# admin off # admin off
# Enable this if your clients don't send ServerName in TLS handshakes. # Enable this if your clients don't send ServerName in TLS handshakes.
# default_sni example.com # default_sni example.com
# Enable debug mode for verbose logging. # Enable debug mode for verbose logging.
# debug # debug
# Use Let's Encrypt's staging endpoint for testing. # Use Let's Encrypt's staging endpoint for testing.
# acme_ca https://acme-staging-v02.api.letsencrypt.org/directory # acme_ca https://acme-staging-v02.api.letsencrypt.org/directory
# If you're port-forwarding HTTP/HTTPS ports from 80/443 to something # If you're port-forwarding HTTP/HTTPS ports from 80/443 to something
# else, enable these and put the alternate port numbers here. # else, enable these and put the alternate port numbers here.
# http_port 8080 # http_port 8080
# https_port 8443 # https_port 8443
} }
# The server name of your matrix homeserver. This example shows # The server name of your matrix homeserver. This example shows
# "well-known delegation" from the registered domain to a subdomain,
# "well-known delegation" from the registered domain to a subdomain
# which is only needed if your server_name doesn't match your Matrix # which is only needed if your server_name doesn't match your Matrix
# homeserver URL (i.e. you can show users a vanity domain that looks # homeserver URL (i.e. you can show users a vanity domain that looks
# nice and is easy to remember but still have your Matrix server on # nice and is easy to remember but still have your Matrix server on
# its own subdomain or hosted service).
# its own subdomain or hosted service)
example.com { example.com {
header /.well-known/matrix/* Content-Type application/json header /.well-known/matrix/*Content-Type application/json
header /.well-known/matrix/* Access-Control-Allow-Origin * header /.well-known/matrix/* Access-Control-Allow-Origin *
respond /.well-known/matrix/server `{"m.server": "matrix.example.com:443"}` respond /.well-known/matrix/server `{"m.server": "matrix.example.com:443"}`
respond /.well-known/matrix/client `{"m.homeserver": {"base_url": "https://matrix.example.com"}}` respond /.well-known/matrix/client `{"m.homeserver": {"base_url": "https://matrix.example.com"}}`
} }
# The actual domain name whereby your Matrix server is accessed. # The actual domain name whereby your Matrix server is accessed
matrix.example.com { matrix.example.com {
# Change the end of each reverse_proxy line to the correct # Change the end of each reverse_proxy line to the correct
# address for your various services. # address for your various services.
@sync_api { @sync_api {
path_regexp /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/(messages|context/.*?|event/.*?))$ path_regexp /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/(messages|context/.*?|relations/.*?|event/.*?))$
} }
reverse_proxy @sync_api sync_api:8073 reverse_proxy @sync_api sync_api:8073
reverse_proxy /_matrix/client* client_api:8071 reverse_proxy /_matrix/client* client_api:8071
reverse_proxy /_matrix/federation* federation_api:8071 reverse_proxy /_matrix/federation* federation_api:8071
reverse_proxy /_matrix/key* federation_api:8071 reverse_proxy /_matrix/key* federation_api:8071
reverse_proxy /_matrix/media* media_api:8071 reverse_proxy /_matrix/media* media_api:8071
} }

View File

@ -20,8 +20,11 @@ VirtualHost {
# /_matrix/client/.*/rooms/{roomId}/messages # /_matrix/client/.*/rooms/{roomId}/messages
# /_matrix/client/.*/rooms/{roomId}/context/{eventID} # /_matrix/client/.*/rooms/{roomId}/context/{eventID}
# /_matrix/client/.*/rooms/{roomId}/event/{eventID} # /_matrix/client/.*/rooms/{roomId}/event/{eventID}
# /_matrix/client/.*/rooms/{roomId}/relations/{eventID}
# /_matrix/client/.*/rooms/{roomId}/relations/{eventID}/{relType}
# /_matrix/client/.*/rooms/{roomId}/relations/{eventID}/{relType}/{eventType}
# to sync_api # to sync_api
ReverseProxy = /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/(messages|context/.*?|event/.*?))$ http://localhost:8073 600 ReverseProxy = /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/(messages|context/.*?|relations/.*?|event/.*?))$ http://localhost:8073 600
ReverseProxy = /_matrix/client http://localhost:8071 600 ReverseProxy = /_matrix/client http://localhost:8071 600
ReverseProxy = /_matrix/federation http://localhost:8072 600 ReverseProxy = /_matrix/federation http://localhost:8072 600
ReverseProxy = /_matrix/key http://localhost:8072 600 ReverseProxy = /_matrix/key http://localhost:8072 600

View File

@ -30,8 +30,11 @@ server {
# /_matrix/client/.*/rooms/{roomId}/messages # /_matrix/client/.*/rooms/{roomId}/messages
# /_matrix/client/.*/rooms/{roomId}/context/{eventID} # /_matrix/client/.*/rooms/{roomId}/context/{eventID}
# /_matrix/client/.*/rooms/{roomId}/event/{eventID} # /_matrix/client/.*/rooms/{roomId}/event/{eventID}
# /_matrix/client/.*/rooms/{roomId}/relations/{eventID}
# /_matrix/client/.*/rooms/{roomId}/relations/{eventID}/{relType}
# /_matrix/client/.*/rooms/{roomId}/relations/{eventID}/{relType}/{eventType}
# to sync_api # to sync_api
location ~ /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/(messages|context/.*?|event/.*?))$ { location ~ /_matrix/client/.*?/(sync|user/.*?/filter/?.*|keys/changes|rooms/.*?/(messages|context/.*?|relations/.*?|event/.*?))$ {
proxy_pass http://sync_api:8073; proxy_pass http://sync_api:8073;
} }

2
go.mod
View File

@ -22,7 +22,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
github.com/matrix-org/gomatrixserverlib v0.0.0-20220929190355-91d455cd3621 github.com/matrix-org/gomatrixserverlib v0.0.0-20221011115330-49fa704b9a64
github.com/matrix-org/pinecone v0.0.0-20220929155234-2ce51dd4a42c github.com/matrix-org/pinecone v0.0.0-20220929155234-2ce51dd4a42c
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.15 github.com/mattn/go-sqlite3 v1.14.15

4
go.sum
View File

@ -384,8 +384,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220929190355-91d455cd3621 h1:a8IaoSPDxevkgXnOUrtIW9AqVNvXBJAG0gtnX687S7g= github.com/matrix-org/gomatrixserverlib v0.0.0-20221011115330-49fa704b9a64 h1:QJmfAPC3P0ZHJzYD/QtbNc5EztKlK1ipRWP5SO/m4jw=
github.com/matrix-org/gomatrixserverlib v0.0.0-20220929190355-91d455cd3621/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/gomatrixserverlib v0.0.0-20221011115330-49fa704b9a64/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4=
github.com/matrix-org/pinecone v0.0.0-20220929155234-2ce51dd4a42c h1:iCHLYwwlPsf4TYFrvhKdhQoAM2lXzcmDZYqwBNWcnVk= github.com/matrix-org/pinecone v0.0.0-20220929155234-2ce51dd4a42c h1:iCHLYwwlPsf4TYFrvhKdhQoAM2lXzcmDZYqwBNWcnVk=
github.com/matrix-org/pinecone v0.0.0-20220929155234-2ce51dd4a42c/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k= github.com/matrix-org/pinecone v0.0.0-20220929155234-2ce51dd4a42c/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=

View File

@ -148,6 +148,16 @@ func (s *OutputRoomEventConsumer) onRedactEvent(
log.WithError(err).Error("RedactEvent error'd") log.WithError(err).Error("RedactEvent error'd")
return err return err
} }
if err = s.db.RedactRelations(ctx, msg.RedactedBecause.RoomID(), msg.RedactedEventID); err != nil {
log.WithFields(log.Fields{
"room_id": msg.RedactedBecause.RoomID(),
"event_id": msg.RedactedBecause.EventID(),
"redacted_event_id": msg.RedactedEventID,
}).WithError(err).Warn("Failed to redact relations")
return err
}
// fake a room event so we notify clients about the redaction, as if it were // fake a room event so we notify clients about the redaction, as if it were
// a normal event. // a normal event.
return s.onNewRoomEvent(ctx, api.OutputNewRoomEvent{ return s.onNewRoomEvent(ctx, api.OutputNewRoomEvent{
@ -271,6 +281,14 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
return err return err
} }
if err = s.db.UpdateRelations(ctx, ev); err != nil {
log.WithFields(log.Fields{
"event_id": ev.EventID(),
"type": ev.Type(),
}).WithError(err).Warn("Failed to update relations")
return err
}
s.pduStream.Advance(pduPos) s.pduStream.Advance(pduPos)
s.notifier.OnNewEvent(ev, ev.RoomID(), nil, types.StreamingToken{PDUPosition: pduPos}) s.notifier.OnNewEvent(ev, ev.RoomID(), nil, types.StreamingToken{PDUPosition: pduPos})
@ -315,6 +333,15 @@ func (s *OutputRoomEventConsumer) onOldRoomEvent(
}).WithError(err).Warn("failed to index fulltext element") }).WithError(err).Warn("failed to index fulltext element")
} }
if err = s.db.UpdateRelations(ctx, ev); err != nil {
log.WithFields(log.Fields{
"room_id": ev.RoomID(),
"event_id": ev.EventID(),
"type": ev.Type(),
}).WithError(err).Warn("Failed to update relations")
return err
}
if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil { if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil {
log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos) log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos)
return err return err

View File

@ -0,0 +1,124 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"net/http"
"strconv"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/internal"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
)
type RelationsResponse struct {
Chunk []gomatrixserverlib.ClientEvent `json:"chunk"`
NextBatch string `json:"next_batch,omitempty"`
PrevBatch string `json:"prev_batch,omitempty"`
}
// nolint:gocyclo
func Relations(
req *http.Request, device *userapi.Device,
syncDB storage.Database,
rsAPI api.SyncRoomserverAPI,
roomID, eventID, relType, eventType string,
) util.JSONResponse {
var err error
var from, to types.StreamPosition
var limit int
dir := req.URL.Query().Get("dir")
if f := req.URL.Query().Get("from"); f != "" {
if from, err = types.NewStreamPositionFromString(f); err != nil {
return util.ErrorResponse(err)
}
}
if t := req.URL.Query().Get("to"); t != "" {
if to, err = types.NewStreamPositionFromString(t); err != nil {
return util.ErrorResponse(err)
}
}
if l := req.URL.Query().Get("limit"); l != "" {
if limit, err = strconv.Atoi(l); err != nil {
return util.ErrorResponse(err)
}
}
if limit == 0 || limit > 50 {
limit = 50
}
if dir == "" {
dir = "b"
}
if dir != "b" && dir != "f" {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.MissingArgument("Bad or missing dir query parameter (should be either 'b' or 'f')"),
}
}
snapshot, err := syncDB.NewDatabaseSnapshot(req.Context())
if err != nil {
logrus.WithError(err).Error("Failed to get snapshot for relations")
return jsonerror.InternalServerError()
}
var succeeded bool
defer sqlutil.EndTransactionWithCheck(snapshot, &succeeded, &err)
res := &RelationsResponse{
Chunk: []gomatrixserverlib.ClientEvent{},
}
var events []types.StreamEvent
events, res.PrevBatch, res.NextBatch, err = snapshot.RelationsFor(
req.Context(), roomID, eventID, relType, eventType, from, to, dir == "b", limit,
)
if err != nil {
return util.ErrorResponse(err)
}
headeredEvents := make([]*gomatrixserverlib.HeaderedEvent, 0, len(events))
for _, event := range events {
headeredEvents = append(headeredEvents, event.HeaderedEvent)
}
// Apply history visibility to the result events.
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(req.Context(), snapshot, rsAPI, headeredEvents, nil, device.UserID, "relations")
if err != nil {
return util.ErrorResponse(err)
}
// Convert the events into client events, and optionally filter based on the event
// type if it was specified.
res.Chunk = make([]gomatrixserverlib.ClientEvent, 0, len(filteredEvents))
for _, event := range filteredEvents {
res.Chunk = append(
res.Chunk,
gomatrixserverlib.ToClientEvent(event.Event, gomatrixserverlib.FormatAll),
)
}
succeeded = true
return util.JSONResponse{
Code: http.StatusOK,
JSON: res,
}
}

View File

@ -45,6 +45,7 @@ func Setup(
lazyLoadCache caching.LazyLoadCache, lazyLoadCache caching.LazyLoadCache,
fts *fulltext.Search, fts *fulltext.Search,
) { ) {
v1unstablemux := csMux.PathPrefix("/{apiversion:(?:v1|unstable)}/").Subrouter()
v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter() v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter()
// TODO: Add AS support for all handlers below. // TODO: Add AS support for all handlers below.
@ -110,6 +111,48 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}",
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return Relations(
req, device, syncDB, rsAPI,
vars["roomId"], vars["eventId"], "", "",
)
}),
).Methods(http.MethodGet, http.MethodOptions)
v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}/{relType}",
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return Relations(
req, device, syncDB, rsAPI,
vars["roomId"], vars["eventId"], vars["relType"], "",
)
}),
).Methods(http.MethodGet, http.MethodOptions)
v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}",
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return Relations(
req, device, syncDB, rsAPI,
vars["roomId"], vars["eventId"], vars["relType"], vars["eventType"],
)
}),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/search", v3mux.Handle("/search",
httputil.MakeAuthAPI("search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if !cfg.Fulltext.Enabled { if !cfg.Fulltext.Enabled {

View File

@ -38,6 +38,7 @@ type DatabaseTransaction interface {
MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForRelations(ctx context.Context) (types.StreamPosition, error)
CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error)
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
@ -107,6 +108,7 @@ type DatabaseTransaction interface {
GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error) GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error)
GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error)
PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error)
RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (events []types.StreamEvent, prevBatch, nextBatch string, err error)
} }
type Database interface { type Database interface {
@ -174,6 +176,8 @@ type Database interface {
StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error)
UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error UpdateIgnoresForUser(ctx context.Context, userID string, ignores *types.IgnoredUsers) error
ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error) ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error)
UpdateRelations(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error
RedactRelations(ctx context.Context, roomID, redactedEventID string) error
} }
type Presence interface { type Presence interface {

View File

@ -0,0 +1,158 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
)
const relationsSchema = `
CREATE SEQUENCE IF NOT EXISTS syncapi_relation_id;
CREATE TABLE IF NOT EXISTS syncapi_relations (
id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_relation_id'),
room_id TEXT NOT NULL,
event_id TEXT NOT NULL,
child_event_id TEXT NOT NULL,
child_event_type TEXT NOT NULL,
rel_type TEXT NOT NULL,
CONSTRAINT syncapi_relations_unique UNIQUE (room_id, event_id, child_event_id, rel_type)
);
`
const insertRelationSQL = "" +
"INSERT INTO syncapi_relations (" +
" room_id, event_id, child_event_id, child_event_type, rel_type" +
") VALUES ($1, $2, $3, $4, $5) " +
" ON CONFLICT DO NOTHING"
const deleteRelationSQL = "" +
"DELETE FROM syncapi_relations WHERE room_id = $1 AND child_event_id = $2"
const selectRelationsInRangeAscSQL = "" +
"SELECT id, child_event_id, rel_type FROM syncapi_relations" +
" WHERE room_id = $1 AND event_id = $2" +
" AND ( $3 = '' OR rel_type = $3 )" +
" AND ( $4 = '' OR child_event_type = $4 )" +
" AND id > $5 AND id <= $6" +
" ORDER BY id ASC LIMIT $7"
const selectRelationsInRangeDescSQL = "" +
"SELECT id, child_event_id, rel_type FROM syncapi_relations" +
" WHERE room_id = $1 AND event_id = $2" +
" AND ( $3 = '' OR rel_type = $3 )" +
" AND ( $4 = '' OR child_event_type = $4 )" +
" AND id >= $5 AND id < $6" +
" ORDER BY id DESC LIMIT $7"
const selectMaxRelationIDSQL = "" +
"SELECT COALESCE(MAX(id), 0) FROM syncapi_relations"
type relationsStatements struct {
insertRelationStmt *sql.Stmt
selectRelationsInRangeAscStmt *sql.Stmt
selectRelationsInRangeDescStmt *sql.Stmt
deleteRelationStmt *sql.Stmt
selectMaxRelationIDStmt *sql.Stmt
}
func NewPostgresRelationsTable(db *sql.DB) (tables.Relations, error) {
s := &relationsStatements{}
_, err := db.Exec(relationsSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertRelationStmt, insertRelationSQL},
{&s.selectRelationsInRangeAscStmt, selectRelationsInRangeAscSQL},
{&s.selectRelationsInRangeDescStmt, selectRelationsInRangeDescSQL},
{&s.deleteRelationStmt, deleteRelationSQL},
{&s.selectMaxRelationIDStmt, selectMaxRelationIDSQL},
}.Prepare(db)
}
func (s *relationsStatements) InsertRelation(
ctx context.Context, txn *sql.Tx, roomID, eventID, childEventID, childEventType, relType string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.insertRelationStmt).ExecContext(
ctx, roomID, eventID, childEventID, childEventType, relType,
)
return
}
func (s *relationsStatements) DeleteRelation(
ctx context.Context, txn *sql.Tx, roomID, childEventID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteRelationStmt)
_, err := stmt.ExecContext(
ctx, roomID, childEventID,
)
return err
}
// SelectRelationsInRange returns a map rel_type -> []child_event_id
func (s *relationsStatements) SelectRelationsInRange(
ctx context.Context, txn *sql.Tx, roomID, eventID, relType, eventType string,
r types.Range, limit int,
) (map[string][]types.RelationEntry, types.StreamPosition, error) {
var lastPos types.StreamPosition
var stmt *sql.Stmt
if r.Backwards {
stmt = sqlutil.TxStmt(txn, s.selectRelationsInRangeDescStmt)
} else {
stmt = sqlutil.TxStmt(txn, s.selectRelationsInRangeAscStmt)
}
rows, err := stmt.QueryContext(ctx, roomID, eventID, relType, eventType, r.Low(), r.High(), limit)
if err != nil {
return nil, lastPos, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRelationsInRange: rows.close() failed")
result := map[string][]types.RelationEntry{}
var (
id types.StreamPosition
childEventID string
relationType string
)
for rows.Next() {
if err = rows.Scan(&id, &childEventID, &relationType); err != nil {
return nil, lastPos, err
}
if id > lastPos {
lastPos = id
}
result[relationType] = append(result[relationType], types.RelationEntry{
Position: id,
EventID: childEventID,
})
}
if lastPos == 0 {
lastPos = r.To
}
return result, lastPos, rows.Err()
}
func (s *relationsStatements) SelectMaxRelationID(
ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
stmt := sqlutil.TxStmt(txn, s.selectMaxRelationIDStmt)
err = stmt.QueryRowContext(ctx).Scan(&id)
return
}

View File

@ -98,6 +98,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
relations, err := NewPostgresRelationsTable(d.db)
if err != nil {
return nil, err
}
// apply migrations which need multiple tables // apply migrations which need multiple tables
m := sqlutil.NewMigrator(d.db) m := sqlutil.NewMigrator(d.db)
@ -129,6 +133,7 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions)
NotificationData: notificationData, NotificationData: notificationData,
Ignores: ignores, Ignores: ignores,
Presence: presence, Presence: presence,
Relations: relations,
} }
return &d, nil return &d, nil
} }

View File

@ -53,6 +53,7 @@ type Database struct {
NotificationData tables.NotificationData NotificationData tables.NotificationData
Ignores tables.Ignores Ignores tables.Ignores
Presence tables.Presence Presence tables.Presence
Relations tables.Relations
} }
func (d *Database) NewDatabaseSnapshot(ctx context.Context) (*DatabaseTransaction, error) { func (d *Database) NewDatabaseSnapshot(ctx context.Context) (*DatabaseTransaction, error) {
@ -579,10 +580,40 @@ func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID s
return d.Memberships.SelectMembershipForUser(ctx, nil, roomID, userID, pos) return d.Memberships.SelectMembershipForUser(ctx, nil, roomID, userID, pos)
} }
func (s *Database) ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error) { func (d *Database) ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error) {
return s.OutputEvents.ReIndex(ctx, nil, limit, afterID, []string{ return d.OutputEvents.ReIndex(ctx, nil, limit, afterID, []string{
gomatrixserverlib.MRoomName, gomatrixserverlib.MRoomName,
gomatrixserverlib.MRoomTopic, gomatrixserverlib.MRoomTopic,
"m.room.message", "m.room.message",
}) })
} }
func (d *Database) UpdateRelations(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error {
var content gomatrixserverlib.RelationContent
if err := json.Unmarshal(event.Content(), &content); err != nil {
return fmt.Errorf("json.Unmarshal: %w", err)
}
switch {
case content.Relations == nil:
return nil
case content.Relations.EventID == "":
return nil
case content.Relations.RelationType == "":
return nil
case event.Type() == gomatrixserverlib.MRoomRedaction:
return nil
default:
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Relations.InsertRelation(
ctx, txn, event.RoomID(), content.Relations.EventID,
event.EventID(), event.Type(), content.Relations.RelationType,
)
})
}
}
func (d *Database) RedactRelations(ctx context.Context, roomID, redactedEventID string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Relations.DeleteRelation(ctx, txn, roomID, redactedEventID)
})
}

View File

@ -589,3 +589,84 @@ func (d *DatabaseTransaction) PresenceAfter(ctx context.Context, after types.Str
func (d *DatabaseTransaction) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) { func (d *DatabaseTransaction) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) {
return d.Presence.GetMaxPresenceID(ctx, d.txn) return d.Presence.GetMaxPresenceID(ctx, d.txn)
} }
func (d *DatabaseTransaction) MaxStreamPositionForRelations(ctx context.Context) (types.StreamPosition, error) {
id, err := d.Relations.SelectMaxRelationID(ctx, d.txn)
return types.StreamPosition(id), err
}
func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (
events []types.StreamEvent, prevBatch, nextBatch string, err error,
) {
r := types.Range{
From: from,
To: to,
Backwards: backwards,
}
if r.Backwards && r.From == 0 {
// If we're working backwards (dir=b) and there's no ?from= specified then
// we will automatically want to work backwards from the current position,
// so find out what that is.
if r.From, err = d.MaxStreamPositionForRelations(ctx); err != nil {
return nil, "", "", fmt.Errorf("d.MaxStreamPositionForRelations: %w", err)
}
// The result normally isn't inclusive of the event *at* the ?from=
// position, so add 1 here so that we include the most recent relation.
r.From++
} else if !r.Backwards && r.To == 0 {
// If we're working forwards (dir=f) and there's no ?to= specified then
// we will automatically want to work forwards towards the current position,
// so find out what that is.
if r.To, err = d.MaxStreamPositionForRelations(ctx); err != nil {
return nil, "", "", fmt.Errorf("d.MaxStreamPositionForRelations: %w", err)
}
}
// First look up any relations from the database. We add one to the limit here
// so that we can tell if we're overflowing, as we will only set the "next_batch"
// in the response if we are.
relations, _, err := d.Relations.SelectRelationsInRange(ctx, d.txn, roomID, eventID, relType, eventType, r, limit+1)
if err != nil {
return nil, "", "", fmt.Errorf("d.Relations.SelectRelationsInRange: %w", err)
}
// If we specified a relation type then just get those results, otherwise collate
// them from all of the returned relation types.
entries := []types.RelationEntry{}
if relType != "" {
entries = relations[relType]
} else {
for _, e := range relations {
entries = append(entries, e...)
}
}
// If there were no entries returned, there were no relations, so stop at this point.
if len(entries) == 0 {
return nil, "", "", nil
}
// Otherwise, let's try and work out what sensible prev_batch and next_batch values
// could be. We've requested an extra event by adding one to the limit already so
// that we can determine whether or not to provide a "next_batch", so trim off that
// event off the end if needs be.
if len(entries) > limit {
entries = entries[:len(entries)-1]
nextBatch = fmt.Sprintf("%d", entries[len(entries)-1].Position)
}
// TODO: set prevBatch? doesn't seem to affect the tests...
// Extract all of the event IDs from the relation entries so that we can pull the
// events out of the database. Then go and fetch the events.
eventIDs := make([]string, 0, len(entries))
for _, entry := range entries {
eventIDs = append(eventIDs, entry.EventID)
}
events, err = d.OutputEvents.SelectEvents(ctx, d.txn, eventIDs, nil, true)
if err != nil {
return nil, "", "", fmt.Errorf("d.OutputEvents.SelectEvents: %w", err)
}
return events, prevBatch, nextBatch, nil
}

View File

@ -0,0 +1,163 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
)
const relationsSchema = `
CREATE TABLE IF NOT EXISTS syncapi_relations (
id BIGINT PRIMARY KEY,
room_id TEXT NOT NULL,
event_id TEXT NOT NULL,
child_event_id TEXT NOT NULL,
child_event_type TEXT NOT NULL,
rel_type TEXT NOT NULL,
UNIQUE (room_id, event_id, child_event_id, rel_type)
);
`
const insertRelationSQL = "" +
"INSERT INTO syncapi_relations (" +
" id, room_id, event_id, child_event_id, child_event_type, rel_type" +
") VALUES ($1, $2, $3, $4, $5, $6) " +
" ON CONFLICT DO NOTHING"
const deleteRelationSQL = "" +
"DELETE FROM syncapi_relations WHERE room_id = $1 AND child_event_id = $2"
const selectRelationsInRangeAscSQL = "" +
"SELECT id, child_event_id, rel_type FROM syncapi_relations" +
" WHERE room_id = $1 AND event_id = $2" +
" AND ( $3 = '' OR rel_type = $3 )" +
" AND ( $4 = '' OR child_event_type = $4 )" +
" AND id > $5 AND id <= $6" +
" ORDER BY id ASC LIMIT $7"
const selectRelationsInRangeDescSQL = "" +
"SELECT id, child_event_id, rel_type FROM syncapi_relations" +
" WHERE room_id = $1 AND event_id = $2" +
" AND ( $3 = '' OR rel_type = $3 )" +
" AND ( $4 = '' OR child_event_type = $4 )" +
" AND id >= $5 AND id < $6" +
" ORDER BY id DESC LIMIT $7"
const selectMaxRelationIDSQL = "" +
"SELECT COALESCE(MAX(id), 0) FROM syncapi_relations"
type relationsStatements struct {
streamIDStatements *StreamIDStatements
insertRelationStmt *sql.Stmt
selectRelationsInRangeAscStmt *sql.Stmt
selectRelationsInRangeDescStmt *sql.Stmt
deleteRelationStmt *sql.Stmt
selectMaxRelationIDStmt *sql.Stmt
}
func NewSqliteRelationsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Relations, error) {
s := &relationsStatements{
streamIDStatements: streamID,
}
_, err := db.Exec(relationsSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertRelationStmt, insertRelationSQL},
{&s.selectRelationsInRangeAscStmt, selectRelationsInRangeAscSQL},
{&s.selectRelationsInRangeDescStmt, selectRelationsInRangeDescSQL},
{&s.deleteRelationStmt, deleteRelationSQL},
{&s.selectMaxRelationIDStmt, selectMaxRelationIDSQL},
}.Prepare(db)
}
func (s *relationsStatements) InsertRelation(
ctx context.Context, txn *sql.Tx, roomID, eventID, childEventID, childEventType, relType string,
) (err error) {
var streamPos types.StreamPosition
if streamPos, err = s.streamIDStatements.nextRelationID(ctx, txn); err != nil {
return
}
_, err = sqlutil.TxStmt(txn, s.insertRelationStmt).ExecContext(
ctx, streamPos, roomID, eventID, childEventID, childEventType, relType,
)
return
}
func (s *relationsStatements) DeleteRelation(
ctx context.Context, txn *sql.Tx, roomID, childEventID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteRelationStmt)
_, err := stmt.ExecContext(
ctx, roomID, childEventID,
)
return err
}
// SelectRelationsInRange returns a map rel_type -> []child_event_id
func (s *relationsStatements) SelectRelationsInRange(
ctx context.Context, txn *sql.Tx, roomID, eventID, relType, eventType string,
r types.Range, limit int,
) (map[string][]types.RelationEntry, types.StreamPosition, error) {
var lastPos types.StreamPosition
var stmt *sql.Stmt
if r.Backwards {
stmt = sqlutil.TxStmt(txn, s.selectRelationsInRangeDescStmt)
} else {
stmt = sqlutil.TxStmt(txn, s.selectRelationsInRangeAscStmt)
}
rows, err := stmt.QueryContext(ctx, roomID, eventID, relType, eventType, r.Low(), r.High(), limit)
if err != nil {
return nil, lastPos, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRelationsInRange: rows.close() failed")
result := map[string][]types.RelationEntry{}
var (
id types.StreamPosition
childEventID string
relationType string
)
for rows.Next() {
if err = rows.Scan(&id, &childEventID, &relationType); err != nil {
return nil, lastPos, err
}
if id > lastPos {
lastPos = id
}
result[relationType] = append(result[relationType], types.RelationEntry{
Position: id,
EventID: childEventID,
})
}
if lastPos == 0 {
lastPos = r.To
}
return result, lastPos, rows.Err()
}
func (s *relationsStatements) SelectMaxRelationID(
ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
stmt := sqlutil.TxStmt(txn, s.selectMaxRelationIDStmt)
err = stmt.QueryRowContext(ctx).Scan(&id)
return
}

View File

@ -28,6 +28,8 @@ INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("presence", 0)
ON CONFLICT DO NOTHING; ON CONFLICT DO NOTHING;
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("notification", 0) INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("notification", 0)
ON CONFLICT DO NOTHING; ON CONFLICT DO NOTHING;
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("relation", 0)
ON CONFLICT DO NOTHING;
` `
const increaseStreamIDStmt = "" + const increaseStreamIDStmt = "" +
@ -86,3 +88,9 @@ func (s *StreamIDStatements) nextNotificationID(ctx context.Context, txn *sql.Tx
err = increaseStmt.QueryRowContext(ctx, "notification").Scan(&pos) err = increaseStmt.QueryRowContext(ctx, "notification").Scan(&pos)
return return
} }
func (s *StreamIDStatements) nextRelationID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "relation").Scan(&pos)
return
}

View File

@ -123,6 +123,10 @@ func (d *SyncServerDatasource) prepare(ctx context.Context) (err error) {
if err != nil { if err != nil {
return err return err
} }
relations, err := NewSqliteRelationsTable(d.db, &d.streamID)
if err != nil {
return err
}
// apply migrations which need multiple tables // apply migrations which need multiple tables
m := sqlutil.NewMigrator(d.db) m := sqlutil.NewMigrator(d.db)
@ -153,6 +157,7 @@ func (d *SyncServerDatasource) prepare(ctx context.Context) (err error) {
NotificationData: notificationData, NotificationData: notificationData,
Ignores: ignores, Ignores: ignores,
Presence: presence, Presence: presence,
Relations: relations,
} }
return nil return nil
} }

View File

@ -206,3 +206,22 @@ type Presence interface {
GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error)
GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (presences map[string]*types.PresenceInternal, err error) GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (presences map[string]*types.PresenceInternal, err error)
} }
type Relations interface {
// Inserts a relation which refers from the child event ID to the event ID in the given room.
// If the relation already exists then this function will do nothing and return no error.
InsertRelation(ctx context.Context, txn *sql.Tx, roomID, eventID, childEventID, childEventType, relType string) (err error)
// Deletes a relation which already exists as the result of an event redaction. If the relation
// does not exist then this function will do nothing and return no error.
DeleteRelation(ctx context.Context, txn *sql.Tx, roomID, childEventID string) error
// SelectRelationsInRange will return relations grouped by relation type within the given range.
// The map is relType -> []entry. If a relType parameter is specified then the results will only
// contain relations of that type, otherwise if "" is specified then all relations in the range
// will be returned, inclusive of the "to" position but excluding the "from" position. The stream
// position returned is the maximum position of the returned results.
SelectRelationsInRange(ctx context.Context, txn *sql.Tx, roomID, eventID, relType, eventType string, r types.Range, limit int) (map[string][]types.RelationEntry, types.StreamPosition, error)
// SelectMaxRelationID returns the maximum ID of all relations, used to determine what the boundaries
// should be if there are no boundaries supplied (i.e. we want to work backwards but don't have a
// "from" or want to work forwards and don't have a "to").
SelectMaxRelationID(ctx context.Context, txn *sql.Tx) (id int64, err error)
}

View File

@ -0,0 +1,186 @@
package tables_test
import (
"context"
"database/sql"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/postgres"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/test"
)
func newRelationsTable(t *testing.T, dbType test.DBType) (tables.Relations, *sql.DB, func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
if err != nil {
t.Fatalf("failed to open db: %s", err)
}
var tab tables.Relations
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresRelationsTable(db)
case test.DBTypeSQLite:
var stream sqlite3.StreamIDStatements
if err = stream.Prepare(db); err != nil {
t.Fatalf("failed to prepare stream stmts: %s", err)
}
tab, err = sqlite3.NewSqliteRelationsTable(db, &stream)
}
if err != nil {
t.Fatalf("failed to make new table: %s", err)
}
return tab, db, close
}
func compareRelationsToExpected(t *testing.T, tab tables.Relations, r types.Range, expected []types.RelationEntry) {
ctx := context.Background()
relations, _, err := tab.SelectRelationsInRange(ctx, nil, roomID, "a", "", "", r, 50)
if err != nil {
t.Fatal(err)
}
if len(relations[relType]) != len(expected) {
t.Fatalf("incorrect number of values returned for range %v (got %d, want %d)", r, len(relations[relType]), len(expected))
}
for i := 0; i < len(relations[relType]); i++ {
got := relations[relType][i]
want := expected[i]
if got != want {
t.Fatalf("range %v position %d should have been %q but got %q", r, i, got, want)
}
}
}
const roomID = "!roomid:server"
const childType = "m.room.something"
const relType = "m.reaction"
func TestRelationsTable(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, _, close := newRelationsTable(t, dbType)
defer close()
// Insert some relations
for _, child := range []string{"b", "c", "d"} {
if err := tab.InsertRelation(ctx, nil, roomID, "a", child, childType, relType); err != nil {
t.Fatal(err)
}
}
// Check the max position, we've inserted three things so it
// should be 3
if max, err := tab.SelectMaxRelationID(ctx, nil); err != nil {
t.Fatal(err)
} else if max != 3 {
t.Fatalf("max position should have been 3 but got %d", max)
}
// Query some ranges for "a"
for r, expected := range map[types.Range][]types.RelationEntry{
{From: 0, To: 10, Backwards: false}: {
{Position: 1, EventID: "b"},
{Position: 2, EventID: "c"},
{Position: 3, EventID: "d"},
},
{From: 1, To: 2, Backwards: false}: {
{Position: 2, EventID: "c"},
},
{From: 1, To: 3, Backwards: false}: {
{Position: 2, EventID: "c"},
{Position: 3, EventID: "d"},
},
{From: 10, To: 0, Backwards: true}: {
{Position: 3, EventID: "d"},
{Position: 2, EventID: "c"},
{Position: 1, EventID: "b"},
},
{From: 3, To: 1, Backwards: true}: {
{Position: 2, EventID: "c"},
{Position: 1, EventID: "b"},
},
} {
compareRelationsToExpected(t, tab, r, expected)
}
// Now delete one of the relations
if err := tab.DeleteRelation(ctx, nil, roomID, "c"); err != nil {
t.Fatal(err)
}
// Query some more ranges for "a"
for r, expected := range map[types.Range][]types.RelationEntry{
{From: 0, To: 10, Backwards: false}: {
{Position: 1, EventID: "b"},
{Position: 3, EventID: "d"},
},
{From: 1, To: 2, Backwards: false}: {},
{From: 1, To: 3, Backwards: false}: {
{Position: 3, EventID: "d"},
},
{From: 10, To: 0, Backwards: true}: {
{Position: 3, EventID: "d"},
{Position: 1, EventID: "b"},
},
{From: 3, To: 1, Backwards: true}: {
{Position: 1, EventID: "b"},
},
} {
compareRelationsToExpected(t, tab, r, expected)
}
// Insert some new relations
for _, child := range []string{"e", "f", "g", "h"} {
if err := tab.InsertRelation(ctx, nil, roomID, "a", child, childType, relType); err != nil {
t.Fatal(err)
}
}
// Check the max position, we've inserted four things so it
// should now be 7
if max, err := tab.SelectMaxRelationID(ctx, nil); err != nil {
t.Fatal(err)
} else if max != 7 {
t.Fatalf("max position should have been 3 but got %d", max)
}
// Query last set of ranges for "a"
for r, expected := range map[types.Range][]types.RelationEntry{
{From: 0, To: 10, Backwards: false}: {
{Position: 1, EventID: "b"},
{Position: 3, EventID: "d"},
{Position: 4, EventID: "e"},
{Position: 5, EventID: "f"},
{Position: 6, EventID: "g"},
{Position: 7, EventID: "h"},
},
{From: 1, To: 2, Backwards: false}: {},
{From: 1, To: 3, Backwards: false}: {
{Position: 3, EventID: "d"},
},
{From: 10, To: 0, Backwards: true}: {
{Position: 7, EventID: "h"},
{Position: 6, EventID: "g"},
{Position: 5, EventID: "f"},
{Position: 4, EventID: "e"},
{Position: 3, EventID: "d"},
{Position: 1, EventID: "b"},
},
{From: 6, To: 3, Backwards: true}: {
{Position: 5, EventID: "f"},
{Position: 4, EventID: "e"},
{Position: 3, EventID: "d"},
},
} {
compareRelationsToExpected(t, tab, r, expected)
}
})
}

View File

@ -47,6 +47,14 @@ type StateDelta struct {
// StreamPosition represents the offset in the sync stream a client is at. // StreamPosition represents the offset in the sync stream a client is at.
type StreamPosition int64 type StreamPosition int64
func NewStreamPositionFromString(s string) (StreamPosition, error) {
n, err := strconv.Atoi(s)
if err != nil {
return 0, err
}
return StreamPosition(n), nil
}
// StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event. // StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event.
type StreamEvent struct { type StreamEvent struct {
*gomatrixserverlib.HeaderedEvent *gomatrixserverlib.HeaderedEvent
@ -599,3 +607,8 @@ type OutputSendToDeviceEvent struct {
type IgnoredUsers struct { type IgnoredUsers struct {
List map[string]interface{} `json:"ignored_users"` List map[string]interface{} `json:"ignored_users"`
} }
type RelationEntry struct {
Position StreamPosition
EventID string
}