Associate transactions with session IDs instead of device IDs (#789)

This commit is contained in:
Alex Chen 2019-08-24 00:55:40 +08:00 committed by GitHub
parent 5eb63f1d1e
commit 43308d2f3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 55 additions and 39 deletions

View File

@ -21,5 +21,9 @@ type Device struct {
// The access_token granted to this device.
// This uniquely identifies the device from all other devices and clients.
AccessToken string
// The unique ID of the session identified by the access token.
// Can be used as a secure substitution in places where data needs to be
// associated with access tokens.
SessionID int64
// TODO: display name, last used timestamp, keys, etc
}

View File

@ -27,11 +27,19 @@ import (
)
const devicesSchema = `
-- This sequence is used for automatic allocation of session_id.
CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
-- Stores data about devices.
CREATE TABLE IF NOT EXISTS device_devices (
-- The access token granted to this device. This has to be the primary key
-- so we can distinguish which device is making a given request.
access_token TEXT NOT NULL PRIMARY KEY,
-- The auto-allocated unique ID of the session identified by the access token.
-- This can be used as a secure substitution of the access token in situations
-- where data is associated with access tokens (e.g. transaction storage),
-- so we don't have to store users' access tokens everywhere.
session_id BIGINT NOT NULL DEFAULT nextval('device_session_id_seq'),
-- The device identifier. This only needs to uniquely identify a device for a given user, not globally.
-- access_tokens will be clobbered based on the device ID for a user.
device_id TEXT NOT NULL,
@ -51,10 +59,11 @@ CREATE UNIQUE INDEX IF NOT EXISTS device_localpart_id_idx ON device_devices(loca
`
const insertDeviceSQL = "" +
"INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5)"
"INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5)" +
" RETURNING session_id"
const selectDeviceByTokenSQL = "" +
"SELECT device_id, localpart FROM device_devices WHERE access_token = $1"
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
@ -120,14 +129,16 @@ func (s *devicesStatements) insertDevice(
displayName *string,
) (*authtypes.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
stmt := common.TxStmt(txn, s.insertDeviceStmt)
if _, err := stmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName); err != nil {
if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName).Scan(&sessionID); err != nil {
return nil, err
}
return &authtypes.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, s.serverName),
AccessToken: accessToken,
SessionID: sessionID,
}, nil
}
@ -161,7 +172,7 @@ func (s *devicesStatements) selectDeviceByToken(
var dev authtypes.Device
var localpart string
stmt := s.selectDeviceByTokenStmt
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.ID, &localpart)
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
if err == nil {
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.AccessToken = accessToken

View File

@ -60,18 +60,18 @@ func SendEvent(
return *resErr
}
var txnAndDeviceID *api.TransactionID
var txnAndSessionID *api.TransactionID
if txnID != nil {
txnAndDeviceID = &api.TransactionID{
txnAndSessionID = &api.TransactionID{
TransactionID: *txnID,
DeviceID: device.ID,
SessionID: device.SessionID,
}
}
// pass the new event to the roomserver and receive the correct event ID
// event ID in case of duplicate transaction is discarded
eventID, err := producer.SendEvents(
req.Context(), []gomatrixserverlib.Event{*e}, cfg.Matrix.ServerName, txnAndDeviceID,
req.Context(), []gomatrixserverlib.Event{*e}, cfg.Matrix.ServerName, txnAndSessionID,
)
if err != nil {
return httputil.LogThenError(req, err)

View File

@ -75,9 +75,9 @@ type InputRoomEvent struct {
}
// TransactionID contains the transaction ID sent by a client when sending an
// event, along with the ID of that device.
// event, along with the ID of the client session.
type TransactionID struct {
DeviceID string `json:"device_id"`
SessionID int64 `json:"session_id"`
TransactionID string `json:"id"`
}

View File

@ -32,7 +32,7 @@ type RoomEventDatabase interface {
StoreEvent(
ctx context.Context,
event gomatrixserverlib.Event,
txnAndDeviceID *api.TransactionID,
txnAndSessionID *api.TransactionID,
authEventNIDs []types.EventNID,
) (types.RoomNID, types.StateAtEvent, error)
// Look up the state entries for a list of string event IDs
@ -67,7 +67,7 @@ type RoomEventDatabase interface {
// Returns an empty string if no such event exists.
GetTransactionEventID(
ctx context.Context, transactionID string,
deviceID string, userID string,
sessionID int64, userID string,
) (string, error)
}
@ -100,7 +100,7 @@ func processRoomEvent(
if input.TransactionID != nil {
tdID := input.TransactionID
eventID, err = db.GetTransactionEventID(
ctx, tdID.TransactionID, tdID.DeviceID, input.Event.Sender(),
ctx, tdID.TransactionID, tdID.SessionID, input.Event.Sender(),
)
// On error OR event with the transaction already processed/processesing
if err != nil || eventID != "" {

View File

@ -47,7 +47,7 @@ func Open(dataSourceName string) (*Database, error) {
// StoreEvent implements input.EventDatabase
func (d *Database) StoreEvent(
ctx context.Context, event gomatrixserverlib.Event,
txnAndDeviceID *api.TransactionID, authEventNIDs []types.EventNID,
txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID,
) (types.RoomNID, types.StateAtEvent, error) {
var (
roomNID types.RoomNID
@ -58,10 +58,10 @@ func (d *Database) StoreEvent(
err error
)
if txnAndDeviceID != nil {
if txnAndSessionID != nil {
if err = d.statements.insertTransaction(
ctx, txnAndDeviceID.TransactionID,
txnAndDeviceID.DeviceID, event.Sender(), event.EventID(),
ctx, txnAndSessionID.TransactionID,
txnAndSessionID.SessionID, event.Sender(), event.EventID(),
); err != nil {
return 0, types.StateAtEvent{}, err
}
@ -322,9 +322,9 @@ func (d *Database) GetLatestEventsForUpdate(
// GetTransactionEventID implements input.EventDatabase
func (d *Database) GetTransactionEventID(
ctx context.Context, transactionID string,
deviceID string, userID string,
sessionID int64, userID string,
) (string, error) {
eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, deviceID, userID)
eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, sessionID, userID)
if err == sql.ErrNoRows {
return "", nil
}

View File

@ -23,8 +23,8 @@ const transactionsSchema = `
CREATE TABLE IF NOT EXISTS roomserver_transactions (
-- The transaction ID of the event.
transaction_id TEXT NOT NULL,
-- The device ID of the originating transaction.
device_id TEXT NOT NULL,
-- The session ID of the originating transaction.
session_id BIGINT NOT NULL,
-- User ID of the sender who authored the event
user_id TEXT NOT NULL,
-- Event ID corresponding to the transaction
@ -32,16 +32,16 @@ CREATE TABLE IF NOT EXISTS roomserver_transactions (
event_id TEXT NOT NULL,
-- A transaction ID is unique for a user and device
-- This automatically creates an index.
PRIMARY KEY (transaction_id, device_id, user_id)
PRIMARY KEY (transaction_id, session_id, user_id)
);
`
const insertTransactionSQL = "" +
"INSERT INTO roomserver_transactions (transaction_id, device_id, user_id, event_id)" +
"INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id)" +
" VALUES ($1, $2, $3, $4)"
const selectTransactionEventIDSQL = "" +
"SELECT event_id FROM roomserver_transactions" +
" WHERE transaction_id = $1 AND device_id = $2 AND user_id = $3"
" WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3"
type transactionStatements struct {
insertTransactionStmt *sql.Stmt
@ -63,12 +63,12 @@ func (s *transactionStatements) prepare(db *sql.DB) (err error) {
func (s *transactionStatements) insertTransaction(
ctx context.Context,
transactionID string,
deviceID string,
sessionID int64,
userID string,
eventID string,
) (err error) {
_, err = s.insertTransactionStmt.ExecContext(
ctx, transactionID, deviceID, userID, eventID,
ctx, transactionID, sessionID, userID, eventID,
)
return
}
@ -76,11 +76,11 @@ func (s *transactionStatements) insertTransaction(
func (s *transactionStatements) selectTransactionEventID(
ctx context.Context,
transactionID string,
deviceID string,
sessionID int64,
userID string,
) (eventID string, err error) {
err = s.selectTransactionEventIDStmt.QueryRowContext(
ctx, transactionID, deviceID, userID,
ctx, transactionID, sessionID, userID,
).Scan(&eventID)
return
}

View File

@ -54,7 +54,7 @@ CREATE TABLE IF NOT EXISTS syncapi_output_room_events (
-- if there is no delta.
add_state_ids TEXT[],
remove_state_ids TEXT[],
device_id TEXT, -- The local device that sent the event, if any
session_id BIGINT, -- The client session that sent the event, if any
transaction_id TEXT -- The transaction id used to send the event, if any
);
-- for event selection
@ -63,14 +63,14 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_output_room_ev
const insertEventSQL = "" +
"INSERT INTO syncapi_output_room_events (" +
"room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, device_id, transaction_id" +
"room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id" +
") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id"
const selectEventsSQL = "" +
"SELECT id, event_json FROM syncapi_output_room_events WHERE event_id = ANY($1)"
const selectRecentEventsSQL = "" +
"SELECT id, event_json, device_id, transaction_id FROM syncapi_output_room_events" +
"SELECT id, event_json, session_id, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id DESC LIMIT $4"
@ -221,9 +221,10 @@ func (s *outputRoomEventsStatements) insertEvent(
event *gomatrixserverlib.Event, addState, removeState []string,
transactionID *api.TransactionID,
) (streamPos int64, err error) {
var deviceID, txnID *string
var txnID *string
var sessionID *int64
if transactionID != nil {
deviceID = &transactionID.DeviceID
sessionID = &transactionID.SessionID
txnID = &transactionID.TransactionID
}
@ -246,7 +247,7 @@ func (s *outputRoomEventsStatements) insertEvent(
containsURL,
pq.StringArray(addState),
pq.StringArray(removeState),
deviceID,
sessionID,
txnID,
).Scan(&streamPos)
return
@ -296,11 +297,11 @@ func rowsToStreamEvents(rows *sql.Rows) ([]streamEvent, error) {
var (
streamPos int64
eventBytes []byte
deviceID *string
sessionID *int64
txnID *string
transactionID *api.TransactionID
)
if err := rows.Scan(&streamPos, &eventBytes, &deviceID, &txnID); err != nil {
if err := rows.Scan(&streamPos, &eventBytes, &sessionID, &txnID); err != nil {
return nil, err
}
// TODO: Handle redacted events
@ -309,9 +310,9 @@ func rowsToStreamEvents(rows *sql.Rows) ([]streamEvent, error) {
return nil, err
}
if deviceID != nil && txnID != nil {
if sessionID != nil && txnID != nil {
transactionID = &api.TransactionID{
DeviceID: *deviceID,
SessionID: *sessionID,
TransactionID: *txnID,
}
}

View File

@ -893,7 +893,7 @@ func streamEventsToEvents(device *authtypes.Device, in []streamEvent) []gomatrix
for i := 0; i < len(in); i++ {
out[i] = in[i].Event
if device != nil && in[i].transactionID != nil {
if device.UserID == in[i].Sender() && device.ID == in[i].transactionID.DeviceID {
if device.UserID == in[i].Sender() && device.SessionID == in[i].transactionID.SessionID {
err := out[i].SetUnsignedField(
"transaction_id", in[i].transactionID.TransactionID,
)