diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/account_data_table.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/account_data_table.go new file mode 100644 index 00000000..63e84a66 --- /dev/null +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/account_data_table.go @@ -0,0 +1,109 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package accounts + +import ( + "database/sql" + + "github.com/matrix-org/gomatrixserverlib" +) + +const accountDataSchema = ` +-- Stores data about accounts data. +CREATE TABLE IF NOT EXISTS account_data ( + -- The Matrix user ID localpart for this account + localpart TEXT NOT NULL, + -- The room ID for this data (empty string if not specific to a room) + room_id TEXT, + -- The account data type + type TEXT NOT NULL, + -- The account data content + content TEXT NOT NULL, + + PRIMARY KEY(localpart, room_id, type) +); +` + +const insertAccountDataSQL = ` + INSERT INTO account_data(localpart, room_id, type, content) VALUES($1, $2, $3, $4) + ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content +` + +const selectAccountDataSQL = "" + + "SELECT room_id, type, content FROM account_data WHERE localpart = $1" + +const deleteAccountDataSQL = "" + + "DELETE FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3" + +type accountDataStatements struct { + insertAccountDataStmt *sql.Stmt + selectAccountDataStmt *sql.Stmt +} + +func (s *accountDataStatements) prepare(db *sql.DB) (err error) { + _, err = db.Exec(accountDataSchema) + if err != nil { + return + } + if s.insertAccountDataStmt, err = db.Prepare(insertAccountDataSQL); err != nil { + return + } + if s.selectAccountDataStmt, err = db.Prepare(selectAccountDataSQL); err != nil { + return + } + return +} + +func (s *accountDataStatements) insertAccountData(localpart string, roomID string, dataType string, content string) (err error) { + _, err = s.insertAccountDataStmt.Exec(localpart, roomID, dataType, content) + return +} + +func (s *accountDataStatements) selectAccountData(localpart string) ( + global []gomatrixserverlib.ClientEvent, + rooms map[string][]gomatrixserverlib.ClientEvent, + err error, +) { + rows, err := s.selectAccountDataStmt.Query(localpart) + if err != nil { + return + } + + global = []gomatrixserverlib.ClientEvent{} + rooms = make(map[string][]gomatrixserverlib.ClientEvent) + + for rows.Next() { + var roomID string + var dataType string + var content []byte + + if err = rows.Scan(&roomID, &dataType, &content); err != nil && err != sql.ErrNoRows { + return + } + + ac := gomatrixserverlib.ClientEvent{ + Type: dataType, + Content: content, + } + + if len(roomID) > 0 { + rooms[roomID] = append(rooms[roomID], ac) + } else { + global = append(global, ac) + } + } + + return +} diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go index fcada6d8..a7b2c786 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/storage.go @@ -27,12 +27,13 @@ import ( // Database represents an account database type Database struct { - db *sql.DB - partitions common.PartitionOffsetStatements - accounts accountsStatements - profiles profilesStatements - memberships membershipStatements - serverName gomatrixserverlib.ServerName + db *sql.DB + partitions common.PartitionOffsetStatements + accounts accountsStatements + profiles profilesStatements + memberships membershipStatements + accountDatas accountDataStatements + serverName gomatrixserverlib.ServerName } // NewDatabase creates a new accounts and profiles database @@ -58,7 +59,11 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) if err = m.prepare(db); err != nil { return nil, err } - return &Database{db, partitions, a, p, m, serverName}, nil + ac := accountDataStatements{} + if err = ac.prepare(db); err != nil { + return nil, err + } + return &Database{db, partitions, a, p, m, ac, serverName}, nil } // GetAccountByPassword returns the account associated with the given localpart and password. @@ -199,6 +204,26 @@ func (d *Database) newMembership(ev gomatrixserverlib.Event, txn *sql.Tx) error return nil } +// SaveAccountData saves new account data for a given user and a given room. +// If the account data is not specific to a room, the room ID should be an empty string +// If an account data already exists for a given set (user, room, data type), it will +// update the corresponding row with the new content +// Returns a SQL error if there was an issue with the insertion/update +func (d *Database) SaveAccountData(localpart string, roomID string, dataType string, content string) error { + return d.accountDatas.insertAccountData(localpart, roomID, dataType, content) +} + +// GetAccountData returns account data related to a given localpart +// If no account data could be found, returns an empty arrays +// Returns an error if there was an issue with the retrieval +func (d *Database) GetAccountData(localpart string) ( + global []gomatrixserverlib.ClientEvent, + rooms map[string][]gomatrixserverlib.ClientEvent, + err error, +) { + return d.accountDatas.selectAccountData(localpart) +} + func hashPassword(plaintext string) (hash string, err error) { hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), bcrypt.DefaultCost) return string(hashBytes), err diff --git a/src/github.com/matrix-org/dendrite/clientapi/readers/account_data.go b/src/github.com/matrix-org/dendrite/clientapi/readers/account_data.go new file mode 100644 index 00000000..ca2c2232 --- /dev/null +++ b/src/github.com/matrix-org/dendrite/clientapi/readers/account_data.go @@ -0,0 +1,69 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package readers + +import ( + "io/ioutil" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/util" +) + +// SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type} +func SaveAccountData( + req *http.Request, accountDB *accounts.Database, device *authtypes.Device, + userID string, roomID string, dataType string, +) util.JSONResponse { + if req.Method != "PUT" { + return util.JSONResponse{ + Code: 405, + JSON: jsonerror.NotFound("Bad method"), + } + } + + if userID != device.UserID { + return util.JSONResponse{ + Code: 403, + JSON: jsonerror.Forbidden("userID does not match the current user"), + } + } + + localpart, _, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return httputil.LogThenError(req, err) + } + + defer req.Body.Close() + + body, err := ioutil.ReadAll(req.Body) + if err != nil { + return httputil.LogThenError(req, err) + } + + if err := accountDB.SaveAccountData(localpart, roomID, dataType, string(body)); err != nil { + return httputil.LogThenError(req, err) + } + + return util.JSONResponse{ + Code: 200, + JSON: struct{}{}, + } +} diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go b/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go index 93ce9e1e..be587f1d 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go +++ b/src/github.com/matrix-org/dendrite/clientapi/routing/routing.go @@ -274,9 +274,16 @@ func Setup( ) r0mux.Handle("/user/{userID}/account_data/{type}", - common.MakeAPI("user_account_data", func(req *http.Request) util.JSONResponse { - // TODO: Set and get the account_data - return util.JSONResponse{Code: 200, JSON: struct{}{}} + common.MakeAuthAPI("user_account_data", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + vars := mux.Vars(req) + return readers.SaveAccountData(req, accountDB, device, vars["userID"], "", vars["type"]) + }), + ) + + r0mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}", + common.MakeAuthAPI("user_account_data", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse { + vars := mux.Vars(req) + return readers.SaveAccountData(req, accountDB, device, vars["userID"], vars["roomID"], vars["type"]) }), ) diff --git a/src/github.com/matrix-org/dendrite/cmd/dendrite-sync-api-server/main.go b/src/github.com/matrix-org/dendrite/cmd/dendrite-sync-api-server/main.go index a8984c1d..c7870684 100644 --- a/src/github.com/matrix-org/dendrite/cmd/dendrite-sync-api-server/main.go +++ b/src/github.com/matrix-org/dendrite/cmd/dendrite-sync-api-server/main.go @@ -19,6 +19,7 @@ import ( "net/http" "os" + "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/common/config" @@ -58,6 +59,11 @@ func main() { log.Panicf("startup: failed to create device database with data source %s : %s", cfg.Database.Device, err) } + adb, err := accounts.NewDatabase(string(cfg.Database.Account), cfg.Matrix.ServerName) + if err != nil { + log.Panicf("startup: failed to create account database with data source %s : %s", cfg.Database.Account, err) + } + pos, err := db.SyncStreamPosition() if err != nil { log.Panicf("startup: failed to get latest sync stream position : %s", err) @@ -76,6 +82,6 @@ func main() { } log.Info("Starting sync server on ", cfg.Listen.SyncAPI) - routing.SetupSyncServerListeners(http.DefaultServeMux, http.DefaultClient, sync.NewRequestPool(db, n), deviceDB) + routing.SetupSyncServerListeners(http.DefaultServeMux, http.DefaultClient, sync.NewRequestPool(db, n, adb), deviceDB) log.Fatal(http.ListenAndServe(string(cfg.Listen.SyncAPI), nil)) } diff --git a/src/github.com/matrix-org/dendrite/syncapi/sync/requestpool.go b/src/github.com/matrix-org/dendrite/syncapi/sync/requestpool.go index 08bad334..953e5f4f 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/sync/requestpool.go +++ b/src/github.com/matrix-org/dendrite/syncapi/sync/requestpool.go @@ -20,22 +20,25 @@ import ( log "github.com/Sirupsen/logrus" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) // RequestPool manages HTTP long-poll connections for /sync type RequestPool struct { - db *storage.SyncServerDatabase - notifier *Notifier + db *storage.SyncServerDatabase + accountDB *accounts.Database + notifier *Notifier } // NewRequestPool makes a new RequestPool -func NewRequestPool(db *storage.SyncServerDatabase, n *Notifier) *RequestPool { - return &RequestPool{db, n} +func NewRequestPool(db *storage.SyncServerDatabase, n *Notifier, adb *accounts.Database) *RequestPool { + return &RequestPool{db, adb, n} } // OnIncomingSyncRequest is called when a client makes a /sync request. This function MUST be @@ -77,9 +80,14 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype if err != nil { res = httputil.LogThenError(req, err) } else { - res = util.JSONResponse{ - Code: 200, - JSON: syncData, + syncData, err = rp.appendAccountData(syncData, device.UserID) + if err != nil { + res = httputil.LogThenError(req, err) + } else { + res = util.JSONResponse{ + Code: 200, + JSON: syncData, + } } } done <- res @@ -104,3 +112,27 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, currentPos types.Stre } return rp.db.IncrementalSync(req.userID, req.since, currentPos, req.limit) } + +func (rp *RequestPool) appendAccountData(data *types.Response, userID string) (*types.Response, error) { + // TODO: We currently send all account data on every sync response, we should instead send data + // that has changed on incremental sync responses + localpart, _, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return nil, err + } + + global, rooms, err := rp.accountDB.GetAccountData(localpart) + if err != nil { + return nil, err + } + data.AccountData.Events = global + + for r, j := range data.Rooms.Join { + if len(rooms[r]) > 0 { + j.AccountData.Events = rooms[r] + data.Rooms.Join[r] = j + } + } + + return data, nil +}