mirror of
https://github.com/1f349/dendrite.git
synced 2024-11-09 22:42:58 +00:00
Make /messages
filterable (#2347)
* Make /messages filterable Fix bug when determining if an event contains an URL * Add newly passing test * Fix test
This commit is contained in:
parent
ea92f80c12
commit
29f2168789
@ -262,12 +262,8 @@ func (r *messagesReq) retrieveEvents() (
|
||||
clientEvents []gomatrixserverlib.ClientEvent, start,
|
||||
end types.TopologyToken, err error,
|
||||
) {
|
||||
eventFilter := r.filter
|
||||
|
||||
// Retrieve the events from the local database.
|
||||
streamEvents, err := r.db.GetEventsInTopologicalRange(
|
||||
r.ctx, r.from, r.to, r.roomID, eventFilter.Limit, r.backwardOrdering,
|
||||
)
|
||||
streamEvents, err := r.db.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("GetEventsInRange: %w", err)
|
||||
return
|
||||
|
@ -105,7 +105,7 @@ type Database interface {
|
||||
// Returns an error if there was a problem communicating with the database.
|
||||
DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error)
|
||||
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last.
|
||||
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error)
|
||||
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error)
|
||||
// EventPositionInTopology returns the depth and stream position of the given event.
|
||||
EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)
|
||||
// BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events.
|
||||
|
@ -81,6 +81,15 @@ const insertEventSQL = "" +
|
||||
const selectEventsSQL = "" +
|
||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)"
|
||||
|
||||
const selectEventsWithFilterSQL = "" +
|
||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = ANY($1)" +
|
||||
" AND ( $2::text[] IS NULL OR sender = ANY($2) )" +
|
||||
" AND ( $3::text[] IS NULL OR NOT(sender = ANY($3)) )" +
|
||||
" AND ( $4::text[] IS NULL OR type LIKE ANY($4) )" +
|
||||
" AND ( $5::text[] IS NULL OR NOT(type LIKE ANY($5)) )" +
|
||||
" AND ( $6::bool IS NULL OR contains_url = $6 )" +
|
||||
" LIMIT $7"
|
||||
|
||||
const selectRecentEventsSQL = "" +
|
||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
|
||||
" WHERE room_id = $1 AND id > $2 AND id <= $3" +
|
||||
@ -153,6 +162,7 @@ const selectContextAfterEventSQL = "" +
|
||||
type outputRoomEventsStatements struct {
|
||||
insertEventStmt *sql.Stmt
|
||||
selectEventsStmt *sql.Stmt
|
||||
selectEventsWitFilterStmt *sql.Stmt
|
||||
selectMaxEventIDStmt *sql.Stmt
|
||||
selectRecentEventsStmt *sql.Stmt
|
||||
selectRecentEventsForSyncStmt *sql.Stmt
|
||||
@ -174,6 +184,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertEventStmt, insertEventSQL},
|
||||
{&s.selectEventsStmt, selectEventsSQL},
|
||||
{&s.selectEventsWitFilterStmt, selectEventsWithFilterSQL},
|
||||
{&s.selectMaxEventIDStmt, selectMaxEventIDSQL},
|
||||
{&s.selectRecentEventsStmt, selectRecentEventsSQL},
|
||||
{&s.selectRecentEventsForSyncStmt, selectRecentEventsForSyncSQL},
|
||||
@ -310,7 +321,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
|
||||
// Parse content as JSON and search for an "url" key
|
||||
containsURL := false
|
||||
var content map[string]interface{}
|
||||
if json.Unmarshal(event.Content(), &content) != nil {
|
||||
if json.Unmarshal(event.Content(), &content) == nil {
|
||||
// Set containsURL to true if url is present
|
||||
_, containsURL = content["url"]
|
||||
}
|
||||
@ -429,10 +440,29 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
|
||||
// selectEvents returns the events for the given event IDs. If an event is
|
||||
// missing from the database, it will be omitted.
|
||||
func (s *outputRoomEventsStatements) SelectEvents(
|
||||
ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool,
|
||||
ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool,
|
||||
) ([]types.StreamEvent, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectEventsStmt)
|
||||
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
||||
var (
|
||||
stmt *sql.Stmt
|
||||
rows *sql.Rows
|
||||
err error
|
||||
)
|
||||
if filter == nil {
|
||||
stmt = sqlutil.TxStmt(txn, s.selectEventsStmt)
|
||||
rows, err = stmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
||||
} else {
|
||||
senders, notSenders := getSendersRoomEventFilter(filter)
|
||||
stmt = sqlutil.TxStmt(txn, s.selectEventsWitFilterStmt)
|
||||
rows, err = stmt.QueryContext(ctx,
|
||||
pq.StringArray(eventIDs),
|
||||
pq.StringArray(senders),
|
||||
pq.StringArray(notSenders),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
|
||||
pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
|
||||
filter.ContainsURL,
|
||||
filter.Limit,
|
||||
)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -150,7 +150,7 @@ func (d *Database) RoomReceiptsAfter(ctx context.Context, roomIDs []string, stre
|
||||
// Returns an error if there was a problem talking with the database.
|
||||
// Does not include any transaction IDs in the returned events.
|
||||
func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||
streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, false)
|
||||
streamEvents, err := d.OutputEvents.SelectEvents(ctx, nil, eventIDs, nil, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -312,7 +312,7 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e
|
||||
|
||||
// Check if we have all of the event's previous events. If an event is
|
||||
// missing, add it to the room's backward extremities.
|
||||
prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), false)
|
||||
prevEvents, err := d.OutputEvents.SelectEvents(ctx, txn, ev.PrevEventIDs(), nil, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -429,7 +429,8 @@ func (d *Database) updateRoomState(
|
||||
func (d *Database) GetEventsInTopologicalRange(
|
||||
ctx context.Context,
|
||||
from, to *types.TopologyToken,
|
||||
roomID string, limit int,
|
||||
roomID string,
|
||||
filter *gomatrixserverlib.RoomEventFilter,
|
||||
backwardOrdering bool,
|
||||
) (events []types.StreamEvent, err error) {
|
||||
var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition
|
||||
@ -450,14 +451,14 @@ func (d *Database) GetEventsInTopologicalRange(
|
||||
// Select the event IDs from the defined range.
|
||||
var eIDs []string
|
||||
eIDs, err = d.Topology.SelectEventIDsInRange(
|
||||
ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, limit, !backwardOrdering,
|
||||
ctx, nil, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering,
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Retrieve the events' contents using their IDs.
|
||||
events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs, true)
|
||||
events, err = d.OutputEvents.SelectEvents(ctx, nil, eIDs, filter, true)
|
||||
return
|
||||
}
|
||||
|
||||
@ -619,7 +620,7 @@ func (d *Database) fetchMissingStateEvents(
|
||||
) ([]types.StreamEvent, error) {
|
||||
// Fetch from the events table first so we pick up the stream ID for the
|
||||
// event.
|
||||
events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, false)
|
||||
events, err := d.OutputEvents.SelectEvents(ctx, txn, eventIDs, nil, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -104,8 +104,7 @@ func (s *accountDataStatements) SelectAccountDataInRange(
|
||||
},
|
||||
filter.Senders, filter.NotSenders,
|
||||
filter.Types, filter.NotTypes,
|
||||
[]string{}, filter.Limit, FilterOrderAsc,
|
||||
)
|
||||
[]string{}, nil, filter.Limit, FilterOrderAsc)
|
||||
|
||||
rows, err := stmt.QueryContext(ctx, params...)
|
||||
if err != nil {
|
||||
|
@ -220,7 +220,7 @@ func (s *currentRoomStateStatements) SelectCurrentState(
|
||||
},
|
||||
stateFilter.Senders, stateFilter.NotSenders,
|
||||
stateFilter.Types, stateFilter.NotTypes,
|
||||
excludeEventIDs, stateFilter.Limit, FilterOrderNone,
|
||||
excludeEventIDs, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderNone,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
|
||||
|
@ -26,7 +26,7 @@ const (
|
||||
func prepareWithFilters(
|
||||
db *sql.DB, txn *sql.Tx, query string, params []interface{},
|
||||
senders, notsenders, types, nottypes *[]string, excludeEventIDs []string,
|
||||
limit int, order FilterOrder,
|
||||
containsURL *bool, limit int, order FilterOrder,
|
||||
) (*sql.Stmt, []interface{}, error) {
|
||||
offset := len(params)
|
||||
if senders != nil {
|
||||
@ -69,6 +69,9 @@ func prepareWithFilters(
|
||||
query += ` AND type NOT = ""`
|
||||
}
|
||||
}
|
||||
if containsURL != nil {
|
||||
query += fmt.Sprintf(" AND contains_url = %v", *containsURL)
|
||||
}
|
||||
if count := len(excludeEventIDs); count > 0 {
|
||||
query += " AND event_id NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
|
||||
for _, v := range excludeEventIDs {
|
||||
|
@ -168,7 +168,7 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
|
||||
s.db, txn, stmtSQL, inputParams,
|
||||
stateFilter.Senders, stateFilter.NotSenders,
|
||||
stateFilter.Types, stateFilter.NotTypes,
|
||||
nil, stateFilter.Limit, FilterOrderAsc,
|
||||
nil, stateFilter.ContainsURL, stateFilter.Limit, FilterOrderAsc,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err)
|
||||
@ -277,7 +277,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
|
||||
// Parse content as JSON and search for an "url" key
|
||||
containsURL := false
|
||||
var content map[string]interface{}
|
||||
if json.Unmarshal(event.Content(), &content) != nil {
|
||||
if json.Unmarshal(event.Content(), &content) == nil {
|
||||
// Set containsURL to true if url is present
|
||||
_, containsURL = content["url"]
|
||||
}
|
||||
@ -345,7 +345,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
|
||||
},
|
||||
eventFilter.Senders, eventFilter.NotSenders,
|
||||
eventFilter.Types, eventFilter.NotTypes,
|
||||
nil, eventFilter.Limit+1, FilterOrderDesc,
|
||||
nil, eventFilter.ContainsURL, eventFilter.Limit+1, FilterOrderDesc,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err)
|
||||
@ -393,7 +393,7 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
|
||||
},
|
||||
eventFilter.Senders, eventFilter.NotSenders,
|
||||
eventFilter.Types, eventFilter.NotTypes,
|
||||
nil, eventFilter.Limit, FilterOrderAsc,
|
||||
nil, eventFilter.ContainsURL, eventFilter.Limit, FilterOrderAsc,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
|
||||
@ -419,20 +419,27 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
|
||||
// selectEvents returns the events for the given event IDs. If an event is
|
||||
// missing from the database, it will be omitted.
|
||||
func (s *outputRoomEventsStatements) SelectEvents(
|
||||
ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool,
|
||||
ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool,
|
||||
) ([]types.StreamEvent, error) {
|
||||
iEventIDs := make([]interface{}, len(eventIDs))
|
||||
for i := range eventIDs {
|
||||
iEventIDs[i] = eventIDs[i]
|
||||
}
|
||||
selectSQL := strings.Replace(selectEventsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1)
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if txn != nil {
|
||||
rows, err = txn.QueryContext(ctx, selectSQL, iEventIDs...)
|
||||
} else {
|
||||
rows, err = s.db.QueryContext(ctx, selectSQL, iEventIDs...)
|
||||
|
||||
if filter == nil {
|
||||
filter = &gomatrixserverlib.RoomEventFilter{Limit: 20}
|
||||
}
|
||||
stmt, params, err := prepareWithFilters(
|
||||
s.db, txn, selectSQL, iEventIDs,
|
||||
filter.Senders, filter.NotSenders,
|
||||
filter.Types, filter.NotTypes,
|
||||
nil, filter.ContainsURL, filter.Limit, FilterOrderAsc,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows, err := stmt.QueryContext(ctx, params...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -527,7 +534,7 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
|
||||
},
|
||||
filter.Senders, filter.NotSenders,
|
||||
filter.Types, filter.NotTypes,
|
||||
nil, filter.Limit, FilterOrderDesc,
|
||||
nil, filter.ContainsURL, filter.Limit, FilterOrderDesc,
|
||||
)
|
||||
|
||||
rows, err := stmt.QueryContext(ctx, params...)
|
||||
@ -563,7 +570,7 @@ func (s *outputRoomEventsStatements) SelectContextAfterEvent(
|
||||
},
|
||||
filter.Senders, filter.NotSenders,
|
||||
filter.Types, filter.NotTypes,
|
||||
nil, filter.Limit, FilterOrderAsc,
|
||||
nil, filter.ContainsURL, filter.Limit, FilterOrderAsc,
|
||||
)
|
||||
|
||||
rows, err := stmt.QueryContext(ctx, params...)
|
||||
|
@ -180,7 +180,8 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
|
||||
to := types.TopologyToken{}
|
||||
|
||||
// backpaginate 5 messages starting at the latest position.
|
||||
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, 5, true)
|
||||
filter := &gomatrixserverlib.RoomEventFilter{Limit: 5}
|
||||
paginatedEvents, err := db.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true)
|
||||
if err != nil {
|
||||
t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err)
|
||||
}
|
||||
|
@ -59,7 +59,7 @@ type Events interface {
|
||||
SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
|
||||
// SelectEarlyEvents returns the earliest events in the given room.
|
||||
SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error)
|
||||
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool) ([]types.StreamEvent, error)
|
||||
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool) ([]types.StreamEvent, error)
|
||||
UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error
|
||||
// DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely.
|
||||
DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
func newOutputRoomEventsTable(t *testing.T, dbType test.DBType) (tables.Events, *sql.DB, func()) {
|
||||
@ -61,7 +62,7 @@ func TestOutputRoomEventsTable(t *testing.T) {
|
||||
wantEventIDs := []string{
|
||||
events[2].EventID(), events[0].EventID(), events[3].EventID(), events[1].EventID(),
|
||||
}
|
||||
gotEvents, err := tab.SelectEvents(ctx, txn, wantEventIDs, true)
|
||||
gotEvents, err := tab.SelectEvents(ctx, txn, wantEventIDs, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to SelectEvents: %s", err)
|
||||
}
|
||||
@ -73,6 +74,28 @@ func TestOutputRoomEventsTable(t *testing.T) {
|
||||
return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventIDs)
|
||||
}
|
||||
|
||||
// Test that contains_url is correctly populated
|
||||
urlEv := room.CreateEvent(t, alice, "m.text", map[string]interface{}{
|
||||
"body": "test.txt",
|
||||
"url": "mxc://test.txt",
|
||||
})
|
||||
if _, err = tab.InsertEvent(ctx, txn, urlEv, nil, nil, nil, false); err != nil {
|
||||
return fmt.Errorf("failed to InsertEvent: %s", err)
|
||||
}
|
||||
wantEventID := []string{urlEv.EventID()}
|
||||
t := true
|
||||
gotEvents, err = tab.SelectEvents(ctx, txn, wantEventID, &gomatrixserverlib.RoomEventFilter{Limit: 1, ContainsURL: &t}, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to SelectEvents: %s", err)
|
||||
}
|
||||
gotEventIDs = make([]string, len(gotEvents))
|
||||
for i := range gotEvents {
|
||||
gotEventIDs[i] = gotEvents[i].EventID()
|
||||
}
|
||||
if !reflect.DeepEqual(gotEventIDs, wantEventID) {
|
||||
return fmt.Errorf("SelectEvents\ngot %v\n want %v", gotEventIDs, wantEventID)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -699,3 +699,4 @@ Ignore invite in full sync
|
||||
Ignore invite in incremental sync
|
||||
A filtered timeline reaches its limit
|
||||
A change to displayname should not result in a full state sync
|
||||
Can fetch images in room
|
Loading…
Reference in New Issue
Block a user