mirror of
https://github.com/1f349/dendrite.git
synced 2024-11-10 06:53:00 +00:00
5caae6f3a0
* bugfix: fix panic on new invite events from sytest I'm unsure why the previous code didn't work, but it's clearer, quicker and easier to read the `LastInsertID()` way. Previously, the code would panic as the SELECT would fail to find the last inserted row ID. * sqlite: Fix UNIQUE violations and close more cursors - Add missing `defer rows.Close()` - Do not have the state block NID as a PRIMARY KEY else it breaks for blocks with >1 state event in them. Instead, rejig the queries so we can still have monotonically increasing integers without using AUTOINCREMENT (which mandates PRIMARY KEY). * sqlite: Add missing variadic function * Use LastInsertId because empirically it works over the SELECT form (though I don't know why that is) * sqlite: Fix invite table by using the global stream pos rather than one specific to invites If we don't use the global, clients don't get notified about any invites because the position is too low. * linting: shadowing * sqlite: do not use last rowid, we already know the stream pos! * sqlite: Fix account data table in syncapi by commiting insert txns! * sqlite: Fix failing federation invite Was failing with 'database is locked' due to multiple write txns being taken out. * sqlite: Ensure we return exactly the number of events found in the database Previously we would return exactly the number of *requested* events, which meant that several zero-initialised events would bubble through the system, failing at JSON serialisation time. * sqlite: let's just ignore the problem for now.... * linting
882 lines
26 KiB
Go
882 lines
26 KiB
Go
// Copyright 2017-2018 New Vector Ltd
|
|
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package sqlite3
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"net/url"
|
|
|
|
"github.com/matrix-org/dendrite/common"
|
|
"github.com/matrix-org/dendrite/roomserver/api"
|
|
"github.com/matrix-org/dendrite/roomserver/types"
|
|
"github.com/matrix-org/gomatrixserverlib"
|
|
_ "github.com/mattn/go-sqlite3"
|
|
)
|
|
|
|
// A Database is the SQLite-backed roomserver store.
// Every exported method runs its queries through the prepared statements held
// in statements against the connection pool held in db.
type Database struct {
	// statements holds the prepared SQL statements; they are prepared in Open.
	statements statements
	// db is the *sql.DB opened in Open. Most methods wrap their queries in a
	// transaction taken from it via common.WithTransaction.
	db *sql.DB
}
|
|
|
|
// Open a postgres database.
|
|
func Open(dataSourceName string) (*Database, error) {
|
|
var d Database
|
|
uri, err := url.Parse(dataSourceName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var cs string
|
|
if uri.Opaque != "" { // file:filename.db
|
|
cs = uri.Opaque
|
|
} else if uri.Path != "" { // file:///path/to/filename.db
|
|
cs = uri.Path
|
|
} else {
|
|
return nil, errors.New("no filename or path in connect string")
|
|
}
|
|
if d.db, err = sql.Open("sqlite3", cs); err != nil {
|
|
return nil, err
|
|
}
|
|
//d.db.Exec("PRAGMA journal_mode=WAL;")
|
|
//d.db.Exec("PRAGMA read_uncommitted = true;")
|
|
|
|
// FIXME: We are leaking connections somewhere. Setting this to 2 will eventually
|
|
// cause the roomserver to be unresponsive to new events because something will
|
|
// acquire the global mutex and never unlock it because it is waiting for a connection
|
|
// which it will never obtain.
|
|
d.db.SetMaxOpenConns(20)
|
|
if err = d.statements.prepare(d.db); err != nil {
|
|
return nil, err
|
|
}
|
|
return &d, nil
|
|
}
|
|
|
|
// StoreEvent implements input.EventDatabase
|
|
func (d *Database) StoreEvent(
|
|
ctx context.Context, event gomatrixserverlib.Event,
|
|
txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID,
|
|
) (types.RoomNID, types.StateAtEvent, error) {
|
|
var (
|
|
roomNID types.RoomNID
|
|
eventTypeNID types.EventTypeNID
|
|
eventStateKeyNID types.EventStateKeyNID
|
|
eventNID types.EventNID
|
|
stateNID types.StateSnapshotNID
|
|
err error
|
|
)
|
|
|
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
if txnAndSessionID != nil {
|
|
if err = d.statements.insertTransaction(
|
|
ctx, txn, txnAndSessionID.TransactionID,
|
|
txnAndSessionID.SessionID, event.Sender(), event.EventID(),
|
|
); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID()); err != nil {
|
|
return err
|
|
}
|
|
|
|
if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil {
|
|
return err
|
|
}
|
|
|
|
eventStateKey := event.StateKey()
|
|
// Assigned a numeric ID for the state_key if there is one present.
|
|
// Otherwise set the numeric ID for the state_key to 0.
|
|
if eventStateKey != nil {
|
|
if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if eventNID, stateNID, err = d.statements.insertEvent(
|
|
ctx,
|
|
txn,
|
|
roomNID,
|
|
eventTypeNID,
|
|
eventStateKeyNID,
|
|
event.EventID(),
|
|
event.EventReference().EventSHA256,
|
|
authEventNIDs,
|
|
event.Depth(),
|
|
); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
// We've already inserted the event so select the numeric event ID
|
|
eventNID, stateNID, err = d.statements.selectEvent(ctx, txn, event.EventID())
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if err = d.statements.insertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return 0, types.StateAtEvent{}, err
|
|
}
|
|
|
|
return roomNID, types.StateAtEvent{
|
|
BeforeStateSnapshotNID: stateNID,
|
|
StateEntry: types.StateEntry{
|
|
StateKeyTuple: types.StateKeyTuple{
|
|
EventTypeNID: eventTypeNID,
|
|
EventStateKeyNID: eventStateKeyNID,
|
|
},
|
|
EventNID: eventNID,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
func (d *Database) assignRoomNID(
|
|
ctx context.Context, txn *sql.Tx, roomID string,
|
|
) (roomNID types.RoomNID, err error) {
|
|
// Check if we already have a numeric ID in the database.
|
|
roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID)
|
|
if err == sql.ErrNoRows {
|
|
// We don't have a numeric ID so insert one into the database.
|
|
roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID)
|
|
if err == nil {
|
|
// Now get the numeric ID back out of the database
|
|
roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func (d *Database) assignEventTypeNID(
|
|
ctx context.Context, txn *sql.Tx, eventType string,
|
|
) (eventTypeNID types.EventTypeNID, err error) {
|
|
// Check if we already have a numeric ID in the database.
|
|
eventTypeNID, err = d.statements.selectEventTypeNID(ctx, txn, eventType)
|
|
if err == sql.ErrNoRows {
|
|
// We don't have a numeric ID so insert one into the database.
|
|
eventTypeNID, err = d.statements.insertEventTypeNID(ctx, txn, eventType)
|
|
if err == sql.ErrNoRows {
|
|
// We raced with another insert so run the select again.
|
|
eventTypeNID, err = d.statements.selectEventTypeNID(ctx, txn, eventType)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func (d *Database) assignStateKeyNID(
|
|
ctx context.Context, txn *sql.Tx, eventStateKey string,
|
|
) (eventStateKeyNID types.EventStateKeyNID, err error) {
|
|
// Check if we already have a numeric ID in the database.
|
|
eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey)
|
|
if err == sql.ErrNoRows {
|
|
// We don't have a numeric ID so insert one into the database.
|
|
eventStateKeyNID, err = d.statements.insertEventStateKeyNID(ctx, txn, eventStateKey)
|
|
if err == sql.ErrNoRows {
|
|
// We raced with another insert so run the select again.
|
|
eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// StateEntriesForEventIDs implements input.EventDatabase
|
|
func (d *Database) StateEntriesForEventIDs(
|
|
ctx context.Context, eventIDs []string,
|
|
) (se []types.StateEntry, err error) {
|
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
se, err = d.statements.bulkSelectStateEventByID(ctx, txn, eventIDs)
|
|
return err
|
|
})
|
|
return
|
|
}
|
|
|
|
// EventTypeNIDs implements state.RoomStateDatabase
|
|
func (d *Database) EventTypeNIDs(
|
|
ctx context.Context, eventTypes []string,
|
|
) (etnids map[string]types.EventTypeNID, err error) {
|
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
etnids, err = d.statements.bulkSelectEventTypeNID(ctx, txn, eventTypes)
|
|
return err
|
|
})
|
|
return
|
|
}
|
|
|
|
// EventStateKeyNIDs implements state.RoomStateDatabase
|
|
func (d *Database) EventStateKeyNIDs(
|
|
ctx context.Context, eventStateKeys []string,
|
|
) (esknids map[string]types.EventStateKeyNID, err error) {
|
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
esknids, err = d.statements.bulkSelectEventStateKeyNID(ctx, txn, eventStateKeys)
|
|
return err
|
|
})
|
|
return
|
|
}
|
|
|
|
// EventStateKeys implements query.RoomserverQueryAPIDatabase
|
|
func (d *Database) EventStateKeys(
|
|
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
|
|
) (out map[types.EventStateKeyNID]string, err error) {
|
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
out, err = d.statements.bulkSelectEventStateKey(ctx, txn, eventStateKeyNIDs)
|
|
return err
|
|
})
|
|
return
|
|
}
|
|
|
|
// EventNIDs implements query.RoomserverQueryAPIDatabase
|
|
func (d *Database) EventNIDs(
|
|
ctx context.Context, eventIDs []string,
|
|
) (out map[string]types.EventNID, err error) {
|
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
out, err = d.statements.bulkSelectEventNID(ctx, txn, eventIDs)
|
|
return err
|
|
})
|
|
return
|
|
}
|
|
|
|
// Events implements input.EventDatabase
|
|
func (d *Database) Events(
|
|
ctx context.Context, eventNIDs []types.EventNID,
|
|
) ([]types.Event, error) {
|
|
var eventJSONs []eventJSONPair
|
|
var err error
|
|
var results []types.Event
|
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
eventJSONs, err = d.statements.bulkSelectEventJSON(ctx, txn, eventNIDs)
|
|
if err != nil || len(eventJSONs) == 0 {
|
|
return nil
|
|
}
|
|
results = make([]types.Event, len(eventJSONs))
|
|
for i, eventJSON := range eventJSONs {
|
|
result := &results[i]
|
|
result.EventNID = eventJSON.EventNID
|
|
// TODO: Use NewEventFromTrustedJSON for efficiency
|
|
result.Event, err = gomatrixserverlib.NewEventFromUntrustedJSON(eventJSON.EventJSON)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return []types.Event{}, err
|
|
}
|
|
return results, nil
|
|
}
|
|
|
|
// AddState implements input.EventDatabase
|
|
func (d *Database) AddState(
|
|
ctx context.Context,
|
|
roomNID types.RoomNID,
|
|
stateBlockNIDs []types.StateBlockNID,
|
|
state []types.StateEntry,
|
|
) (stateNID types.StateSnapshotNID, err error) {
|
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
if len(state) > 0 {
|
|
var stateBlockNID types.StateBlockNID
|
|
stateBlockNID, err = d.statements.bulkInsertStateData(ctx, txn, state)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID)
|
|
}
|
|
stateNID, err = d.statements.insertState(ctx, txn, roomNID, stateBlockNIDs)
|
|
return err
|
|
})
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return
|
|
}
|
|
|
|
// SetState implements input.EventDatabase
|
|
func (d *Database) SetState(
|
|
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
|
|
) error {
|
|
e := common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
return d.statements.updateEventState(ctx, txn, eventNID, stateNID)
|
|
})
|
|
return e
|
|
}
|
|
|
|
// StateAtEventIDs implements input.EventDatabase
|
|
func (d *Database) StateAtEventIDs(
|
|
ctx context.Context, eventIDs []string,
|
|
) (se []types.StateAtEvent, err error) {
|
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
se, err = d.statements.bulkSelectStateAtEventByID(ctx, txn, eventIDs)
|
|
return err
|
|
})
|
|
return
|
|
}
|
|
|
|
// StateBlockNIDs implements state.RoomStateDatabase
|
|
func (d *Database) StateBlockNIDs(
|
|
ctx context.Context, stateNIDs []types.StateSnapshotNID,
|
|
) (sl []types.StateBlockNIDList, err error) {
|
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
sl, err = d.statements.bulkSelectStateBlockNIDs(ctx, txn, stateNIDs)
|
|
return err
|
|
})
|
|
return
|
|
}
|
|
|
|
// StateEntries implements state.RoomStateDatabase
|
|
func (d *Database) StateEntries(
|
|
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
|
|
) (sel []types.StateEntryList, err error) {
|
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
sel, err = d.statements.bulkSelectStateBlockEntries(ctx, txn, stateBlockNIDs)
|
|
return err
|
|
})
|
|
return
|
|
}
|
|
|
|
// SnapshotNIDFromEventID implements state.RoomStateDatabase
|
|
func (d *Database) SnapshotNIDFromEventID(
|
|
ctx context.Context, eventID string,
|
|
) (stateNID types.StateSnapshotNID, err error) {
|
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
_, stateNID, err = d.statements.selectEvent(ctx, txn, eventID)
|
|
return err
|
|
})
|
|
return
|
|
}
|
|
|
|
// EventIDs implements input.RoomEventDatabase
|
|
func (d *Database) EventIDs(
|
|
ctx context.Context, eventNIDs []types.EventNID,
|
|
) (out map[types.EventNID]string, err error) {
|
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
out, err = d.statements.bulkSelectEventID(ctx, txn, eventNIDs)
|
|
return err
|
|
})
|
|
return
|
|
}
|
|
|
|
// GetLatestEventsForUpdate implements input.EventDatabase
|
|
func (d *Database) GetLatestEventsForUpdate(
|
|
ctx context.Context, roomNID types.RoomNID,
|
|
) (types.RoomRecentEventsUpdater, error) {
|
|
txn, err := d.db.Begin()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
|
|
d.statements.selectLatestEventsNIDsForUpdate(ctx, txn, roomNID)
|
|
if err != nil {
|
|
txn.Rollback() // nolint: errcheck
|
|
return nil, err
|
|
}
|
|
stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(ctx, txn, eventNIDs)
|
|
if err != nil {
|
|
txn.Rollback() // nolint: errcheck
|
|
return nil, err
|
|
}
|
|
var lastEventIDSent string
|
|
if lastEventNIDSent != 0 {
|
|
lastEventIDSent, err = d.statements.selectEventID(ctx, txn, lastEventNIDSent)
|
|
if err != nil {
|
|
txn.Rollback() // nolint: errcheck
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// FIXME: we probably want to support long-lived txns in sqlite somehow, but we don't because we get
|
|
// 'database is locked' errors caused by multiple write txns (one being the long-lived txn created here)
|
|
// so for now let's not use a long-lived txn at all, and just commit it here and set the txn to nil so
|
|
// we fail fast if someone tries to use the underlying txn object.
|
|
err = txn.Commit()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &roomRecentEventsUpdater{
|
|
transaction{ctx, nil}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
|
|
}, nil
|
|
}
|
|
|
|
// GetTransactionEventID implements input.EventDatabase
|
|
func (d *Database) GetTransactionEventID(
|
|
ctx context.Context, transactionID string,
|
|
sessionID int64, userID string,
|
|
) (string, error) {
|
|
eventID, err := d.statements.selectTransactionEventID(ctx, nil, transactionID, sessionID, userID)
|
|
if err == sql.ErrNoRows {
|
|
return "", nil
|
|
}
|
|
return eventID, err
|
|
}
|
|
|
|
// roomRecentEventsUpdater implements types.RoomRecentEventsUpdater for SQLite.
// GetLatestEventsForUpdate builds it from values read in a transaction that is
// committed before the updater is returned. Its write methods therefore each
// open their own transaction via common.WithTransaction instead of sharing one.
type roomRecentEventsUpdater struct {
	// transaction carries the request context; GetLatestEventsForUpdate leaves
	// its txn nil.
	transaction
	// d is the database the updater's own transactions are opened against.
	d *Database
	// roomNID is the room whose latest events are being updated.
	roomNID types.RoomNID
	// The following are snapshots taken when the updater was created and are
	// returned unchanged by LatestEvents, LastEventIDSent and
	// CurrentStateSnapshotNID.
	latestEvents            []types.StateAtEventAndReference
	lastEventIDSent         string
	currentStateSnapshotNID types.StateSnapshotNID
}
|
|
|
|
// LatestEvents implements types.RoomRecentEventsUpdater
|
|
func (u *roomRecentEventsUpdater) LatestEvents() []types.StateAtEventAndReference {
|
|
return u.latestEvents
|
|
}
|
|
|
|
// LastEventIDSent implements types.RoomRecentEventsUpdater
|
|
func (u *roomRecentEventsUpdater) LastEventIDSent() string {
|
|
return u.lastEventIDSent
|
|
}
|
|
|
|
// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater
|
|
func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
|
|
return u.currentStateSnapshotNID
|
|
}
|
|
|
|
// StorePreviousEvents implements types.RoomRecentEventsUpdater
|
|
func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
|
|
err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
|
|
for _, ref := range previousEventReferences {
|
|
if err := u.d.statements.insertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
return err
|
|
}
|
|
|
|
// IsReferenced implements types.RoomRecentEventsUpdater
|
|
func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (res bool, err error) {
|
|
err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
|
|
err := u.d.statements.selectPreviousEventExists(u.ctx, txn, eventReference.EventID, eventReference.EventSHA256)
|
|
if err == nil {
|
|
res = true
|
|
err = nil
|
|
}
|
|
if err == sql.ErrNoRows {
|
|
res = false
|
|
err = nil
|
|
}
|
|
return err
|
|
})
|
|
return
|
|
}
|
|
|
|
// SetLatestEvents implements types.RoomRecentEventsUpdater
|
|
func (u *roomRecentEventsUpdater) SetLatestEvents(
|
|
roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
|
|
currentStateSnapshotNID types.StateSnapshotNID,
|
|
) error {
|
|
err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
|
|
eventNIDs := make([]types.EventNID, len(latest))
|
|
for i := range latest {
|
|
eventNIDs[i] = latest[i].EventNID
|
|
}
|
|
return u.d.statements.updateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID)
|
|
})
|
|
return err
|
|
}
|
|
|
|
// HasEventBeenSent implements types.RoomRecentEventsUpdater
|
|
func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (res bool, err error) {
|
|
err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
|
|
res, err = u.d.statements.selectEventSentToOutput(u.ctx, txn, eventNID)
|
|
return err
|
|
})
|
|
return
|
|
}
|
|
|
|
// MarkEventAsSent implements types.RoomRecentEventsUpdater
|
|
func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
|
|
err := common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
|
|
return u.d.statements.updateEventSentToOutput(u.ctx, txn, eventNID)
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (mu types.MembershipUpdater, err error) {
|
|
err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
|
|
mu, err = u.d.membershipUpdaterTxn(u.ctx, txn, u.roomNID, targetUserNID)
|
|
return err
|
|
})
|
|
return
|
|
}
|
|
|
|
// RoomNID implements query.RoomserverQueryAPIDB
|
|
func (d *Database) RoomNID(ctx context.Context, roomID string) (roomNID types.RoomNID, err error) {
|
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID)
|
|
if err == sql.ErrNoRows {
|
|
roomNID = 0
|
|
err = nil
|
|
}
|
|
return err
|
|
})
|
|
return
|
|
}
|
|
|
|
// LatestEventIDs implements query.RoomserverQueryAPIDatabase
|
|
func (d *Database) LatestEventIDs(
|
|
ctx context.Context, roomNID types.RoomNID,
|
|
) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) {
|
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
var eventNIDs []types.EventNID
|
|
eventNIDs, currentStateSnapshotNID, err = d.statements.selectLatestEventNIDs(ctx, txn, roomNID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
references, err = d.statements.bulkSelectEventReference(ctx, txn, eventNIDs)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
depth, err = d.statements.selectMaxEventDepth(ctx, txn, eventNIDs)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
return
|
|
}
|
|
|
|
// GetInvitesForUser implements query.RoomserverQueryAPIDatabase
|
|
func (d *Database) GetInvitesForUser(
|
|
ctx context.Context,
|
|
roomNID types.RoomNID,
|
|
targetUserNID types.EventStateKeyNID,
|
|
) (senderUserIDs []types.EventStateKeyNID, err error) {
|
|
return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID)
|
|
}
|
|
|
|
// SetRoomAlias implements alias.RoomserverAliasAPIDB
|
|
func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error {
|
|
return d.statements.insertRoomAlias(ctx, nil, alias, roomID, creatorUserID)
|
|
}
|
|
|
|
// GetRoomIDForAlias implements alias.RoomserverAliasAPIDB
|
|
func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) {
|
|
return d.statements.selectRoomIDFromAlias(ctx, nil, alias)
|
|
}
|
|
|
|
// GetAliasesForRoomID implements alias.RoomserverAliasAPIDB
|
|
func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) {
|
|
return d.statements.selectAliasesFromRoomID(ctx, nil, roomID)
|
|
}
|
|
|
|
// GetCreatorIDForAlias implements alias.RoomserverAliasAPIDB
|
|
func (d *Database) GetCreatorIDForAlias(
|
|
ctx context.Context, alias string,
|
|
) (string, error) {
|
|
return d.statements.selectCreatorIDFromAlias(ctx, nil, alias)
|
|
}
|
|
|
|
// RemoveRoomAlias implements alias.RoomserverAliasAPIDB
|
|
func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
|
|
return d.statements.deleteRoomAlias(ctx, nil, alias)
|
|
}
|
|
|
|
// StateEntriesForTuples implements state.RoomStateDatabase
|
|
func (d *Database) StateEntriesForTuples(
|
|
ctx context.Context,
|
|
stateBlockNIDs []types.StateBlockNID,
|
|
stateKeyTuples []types.StateKeyTuple,
|
|
) ([]types.StateEntryList, error) {
|
|
return d.statements.bulkSelectFilteredStateBlockEntries(
|
|
ctx, nil, stateBlockNIDs, stateKeyTuples,
|
|
)
|
|
}
|
|
|
|
// MembershipUpdater implements input.RoomEventDatabase
|
|
func (d *Database) MembershipUpdater(
|
|
ctx context.Context, roomID, targetUserID string,
|
|
) (updater types.MembershipUpdater, err error) {
|
|
var txn *sql.Tx
|
|
txn, err = d.db.Begin()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
succeeded := false
|
|
defer func() {
|
|
if !succeeded {
|
|
txn.Rollback() // nolint: errcheck
|
|
} else {
|
|
// TODO: We should be holding open this transaction but we cannot have
|
|
// multiple write transactions on sqlite. The code will perform additional
|
|
// write transactions independent of this one which will consistently cause
|
|
// 'database is locked' errors. For now, we'll break up the transaction and
|
|
// hope we don't race too catastrophically. Long term, we should be able to
|
|
// thread in txn objects where appropriate (either at the interface level or
|
|
// bring matrix business logic into the storage layer).
|
|
txerr := txn.Commit()
|
|
if err == nil && txerr != nil {
|
|
err = txerr
|
|
}
|
|
}
|
|
}()
|
|
|
|
roomNID, err := d.assignRoomNID(ctx, txn, roomID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
targetUserNID, err := d.assignStateKeyNID(ctx, txn, targetUserID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
updater, err = d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
succeeded = true
|
|
return updater, nil
|
|
}
|
|
|
|
// membershipUpdater implements types.MembershipUpdater for one (room, target
// user) pair. membershipUpdaterTxn builds it with a nil txn, so each Set*
// method opens its own transaction via common.WithTransaction.
type membershipUpdater struct {
	// transaction carries the request context; its txn is nil.
	transaction
	// d is the database the updater's own transactions are opened against.
	d *Database
	// roomNID and targetUserNID identify the membership being updated.
	roomNID       types.RoomNID
	targetUserNID types.EventStateKeyNID
	// membership is the state read when the updater was created; the Set*
	// methods compare against it but do not update it.
	membership membershipState
}
|
|
|
|
func (d *Database) membershipUpdaterTxn(
|
|
ctx context.Context,
|
|
txn *sql.Tx,
|
|
roomNID types.RoomNID,
|
|
targetUserNID types.EventStateKeyNID,
|
|
) (types.MembershipUpdater, error) {
|
|
|
|
if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &membershipUpdater{
|
|
// purposefully set the txn to nil so if we try to use it we panic and fail fast
|
|
transaction{ctx, nil}, d, roomNID, targetUserNID, membership,
|
|
}, nil
|
|
}
|
|
|
|
// IsInvite implements types.MembershipUpdater
|
|
func (u *membershipUpdater) IsInvite() bool {
|
|
return u.membership == membershipStateInvite
|
|
}
|
|
|
|
// IsJoin implements types.MembershipUpdater
|
|
func (u *membershipUpdater) IsJoin() bool {
|
|
return u.membership == membershipStateJoin
|
|
}
|
|
|
|
// IsLeave implements types.MembershipUpdater
|
|
func (u *membershipUpdater) IsLeave() bool {
|
|
return u.membership == membershipStateLeaveOrBan
|
|
}
|
|
|
|
// SetToInvite implements types.MembershipUpdater
|
|
func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (inserted bool, err error) {
|
|
err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
|
|
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, event.Sender())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
inserted, err = u.d.statements.insertInviteEvent(
|
|
u.ctx, txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if u.membership != membershipStateInvite {
|
|
if err = u.d.statements.updateMembership(
|
|
u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0,
|
|
); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
return
|
|
}
|
|
|
|
// SetToJoin implements types.MembershipUpdater
|
|
func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) (inviteEventIDs []string, err error) {
|
|
err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
|
|
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, senderUserID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// If this is a join event update, there is no invite to update
|
|
if !isUpdate {
|
|
inviteEventIDs, err = u.d.statements.updateInviteRetired(
|
|
u.ctx, txn, u.roomNID, u.targetUserNID,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Look up the NID of the new join event
|
|
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if u.membership != membershipStateJoin || isUpdate {
|
|
if err = u.d.statements.updateMembership(
|
|
u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID,
|
|
membershipStateJoin, nIDs[eventID],
|
|
); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
|
|
return
|
|
}
|
|
|
|
// SetToLeave implements types.MembershipUpdater
|
|
func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) (inviteEventIDs []string, err error) {
|
|
err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
|
|
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, txn, senderUserID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
inviteEventIDs, err = u.d.statements.updateInviteRetired(
|
|
u.ctx, txn, u.roomNID, u.targetUserNID,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Look up the NID of the new leave event
|
|
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if u.membership != membershipStateLeaveOrBan {
|
|
if err = u.d.statements.updateMembership(
|
|
u.ctx, txn, u.roomNID, u.targetUserNID, senderUserNID,
|
|
membershipStateLeaveOrBan, nIDs[eventID],
|
|
); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
return
|
|
}
|
|
|
|
// GetMembership implements query.RoomserverQueryAPIDB
|
|
func (d *Database) GetMembership(
|
|
ctx context.Context, roomNID types.RoomNID, requestSenderUserID string,
|
|
) (membershipEventNID types.EventNID, stillInRoom bool, err error) {
|
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
requestSenderUserNID, err := d.assignStateKeyNID(ctx, txn, requestSenderUserID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
membershipEventNID, _, err =
|
|
d.statements.selectMembershipFromRoomAndTarget(
|
|
ctx, txn, roomNID, requestSenderUserNID,
|
|
)
|
|
if err == sql.ErrNoRows {
|
|
// The user has never been a member of that room
|
|
return nil
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
stillInRoom = true
|
|
return nil
|
|
})
|
|
|
|
return
|
|
}
|
|
|
|
// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB
|
|
func (d *Database) GetMembershipEventNIDsForRoom(
|
|
ctx context.Context, roomNID types.RoomNID, joinOnly bool,
|
|
) (eventNIDs []types.EventNID, err error) {
|
|
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
if joinOnly {
|
|
eventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership(
|
|
ctx, txn, roomNID, membershipStateJoin,
|
|
)
|
|
return nil
|
|
}
|
|
|
|
eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID)
|
|
return nil
|
|
})
|
|
return
|
|
}
|
|
|
|
// EventsFromIDs implements query.RoomserverQueryAPIEventDB
|
|
func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
|
|
nidMap, err := d.EventNIDs(ctx, eventIDs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var nids []types.EventNID
|
|
for _, nid := range nidMap {
|
|
nids = append(nids, nid)
|
|
}
|
|
|
|
return d.Events(ctx, nids)
|
|
}
|
|
|
|
func (d *Database) GetRoomVersionForRoom(
|
|
ctx context.Context, roomNID types.RoomNID,
|
|
) (int64, error) {
|
|
return d.statements.selectRoomVersionForRoomNID(
|
|
ctx, nil, roomNID,
|
|
)
|
|
}
|
|
|
|
// transaction pairs a context with an optional *sql.Tx. When txn is nil,
// Commit and Rollback do nothing and return nil.
type transaction struct {
	ctx context.Context
	txn *sql.Tx
}

// Commit implements types.Transaction. It commits the wrapped transaction, or
// returns nil immediately when there is none.
func (t *transaction) Commit() error {
	if t.txn != nil {
		return t.txn.Commit()
	}
	return nil
}

// Rollback implements types.Transaction. It rolls back the wrapped
// transaction, or returns nil immediately when there is none.
func (t *transaction) Rollback() error {
	if t.txn != nil {
		return t.txn.Rollback()
	}
	return nil
}
|