mirror of
https://github.com/1f349/dendrite.git
synced 2024-11-22 19:51:39 +00:00
Add device display names (#319)
This commit is contained in:
parent
8720570bb0
commit
bad701c703
@ -40,7 +40,9 @@ CREATE TABLE IF NOT EXISTS device_devices (
|
|||||||
-- migration to different domain names easier.
|
-- migration to different domain names easier.
|
||||||
localpart TEXT NOT NULL,
|
localpart TEXT NOT NULL,
|
||||||
-- When this devices was first recognised on the network, as a unix timestamp (ms resolution).
|
-- When this devices was first recognised on the network, as a unix timestamp (ms resolution).
|
||||||
created_ts BIGINT NOT NULL
|
created_ts BIGINT NOT NULL,
|
||||||
|
-- The display name, human friendlier than device_id and updatable
|
||||||
|
display_name TEXT
|
||||||
-- TODO: device keys, device display names, last used ts and IP address?, token restrictions (if 3rd-party OAuth app)
|
-- TODO: device keys, device display names, last used ts and IP address?, token restrictions (if 3rd-party OAuth app)
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -49,16 +51,19 @@ CREATE UNIQUE INDEX IF NOT EXISTS device_localpart_id_idx ON device_devices(loca
|
|||||||
`
|
`
|
||||||
|
|
||||||
const insertDeviceSQL = "" +
|
const insertDeviceSQL = "" +
|
||||||
"INSERT INTO device_devices(device_id, localpart, access_token, created_ts) VALUES ($1, $2, $3, $4)"
|
"INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5)"
|
||||||
|
|
||||||
const selectDeviceByTokenSQL = "" +
|
const selectDeviceByTokenSQL = "" +
|
||||||
"SELECT device_id, localpart FROM device_devices WHERE access_token = $1"
|
"SELECT device_id, localpart, display_name FROM device_devices WHERE access_token = $1"
|
||||||
|
|
||||||
const selectDeviceByIDSQL = "" +
|
const selectDeviceByIDSQL = "" +
|
||||||
"SELECT created_ts FROM device_devices WHERE localpart = $1 and device_id = $2"
|
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
|
||||||
|
|
||||||
const selectDevicesByLocalpartSQL = "" +
|
const selectDevicesByLocalpartSQL = "" +
|
||||||
"SELECT device_id FROM device_devices WHERE localpart = $1"
|
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
|
||||||
|
|
||||||
|
const updateDeviceNameSQL = "" +
|
||||||
|
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||||
|
|
||||||
const deleteDeviceSQL = "" +
|
const deleteDeviceSQL = "" +
|
||||||
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
|
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
|
||||||
@ -66,13 +71,12 @@ const deleteDeviceSQL = "" +
|
|||||||
const deleteDevicesByLocalpartSQL = "" +
|
const deleteDevicesByLocalpartSQL = "" +
|
||||||
"DELETE FROM device_devices WHERE localpart = $1"
|
"DELETE FROM device_devices WHERE localpart = $1"
|
||||||
|
|
||||||
// TODO: List devices?
|
|
||||||
|
|
||||||
type devicesStatements struct {
|
type devicesStatements struct {
|
||||||
insertDeviceStmt *sql.Stmt
|
insertDeviceStmt *sql.Stmt
|
||||||
selectDeviceByTokenStmt *sql.Stmt
|
selectDeviceByTokenStmt *sql.Stmt
|
||||||
selectDeviceByIDStmt *sql.Stmt
|
selectDeviceByIDStmt *sql.Stmt
|
||||||
selectDevicesByLocalpartStmt *sql.Stmt
|
selectDevicesByLocalpartStmt *sql.Stmt
|
||||||
|
updateDeviceNameStmt *sql.Stmt
|
||||||
deleteDeviceStmt *sql.Stmt
|
deleteDeviceStmt *sql.Stmt
|
||||||
deleteDevicesByLocalpartStmt *sql.Stmt
|
deleteDevicesByLocalpartStmt *sql.Stmt
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
@ -95,6 +99,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
|
|||||||
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
|
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
|
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -110,10 +117,11 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
|
|||||||
// Returns the device on success.
|
// Returns the device on success.
|
||||||
func (s *devicesStatements) insertDevice(
|
func (s *devicesStatements) insertDevice(
|
||||||
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
||||||
|
displayName *string,
|
||||||
) (*authtypes.Device, error) {
|
) (*authtypes.Device, error) {
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
stmt := common.TxStmt(txn, s.insertDeviceStmt)
|
stmt := common.TxStmt(txn, s.insertDeviceStmt)
|
||||||
if _, err := stmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS); err != nil {
|
if _, err := stmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &authtypes.Device{
|
return &authtypes.Device{
|
||||||
@ -139,6 +147,14 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *devicesStatements) updateDeviceName(
|
||||||
|
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
||||||
|
) error {
|
||||||
|
stmt := common.TxStmt(txn, s.updateDeviceNameStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) selectDeviceByToken(
|
func (s *devicesStatements) selectDeviceByToken(
|
||||||
ctx context.Context, accessToken string,
|
ctx context.Context, accessToken string,
|
||||||
) (*authtypes.Device, error) {
|
) (*authtypes.Device, error) {
|
||||||
|
@ -75,6 +75,7 @@ func (d *Database) GetDevicesByLocalpart(
|
|||||||
// Returns the device on success.
|
// Returns the device on success.
|
||||||
func (d *Database) CreateDevice(
|
func (d *Database) CreateDevice(
|
||||||
ctx context.Context, localpart string, deviceID *string, accessToken string,
|
ctx context.Context, localpart string, deviceID *string, accessToken string,
|
||||||
|
displayName *string,
|
||||||
) (dev *authtypes.Device, returnErr error) {
|
) (dev *authtypes.Device, returnErr error) {
|
||||||
if deviceID != nil {
|
if deviceID != nil {
|
||||||
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
@ -84,7 +85,7 @@ func (d *Database) CreateDevice(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken)
|
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
@ -99,7 +100,7 @@ func (d *Database) CreateDevice(
|
|||||||
|
|
||||||
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
var err error
|
var err error
|
||||||
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken)
|
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
if returnErr == nil {
|
if returnErr == nil {
|
||||||
@ -110,6 +111,16 @@ func (d *Database) CreateDevice(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateDevice updates the given device with the display name.
|
||||||
|
// Returns SQL error if there are problems and nil on success.
|
||||||
|
func (d *Database) UpdateDevice(
|
||||||
|
ctx context.Context, localpart, deviceID string, displayName *string,
|
||||||
|
) error {
|
||||||
|
return common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// RemoveDevice revokes a device by deleting the entry in the database
|
// RemoveDevice revokes a device by deleting the entry in the database
|
||||||
// matching with the given device ID and user ID localpart
|
// matching with the given device ID and user ID localpart
|
||||||
// If the device doesn't exist, it will not return an error
|
// If the device doesn't exist, it will not return an error
|
||||||
|
@ -16,6 +16,7 @@ package routing
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
@ -35,6 +36,10 @@ type devicesJSON struct {
|
|||||||
Devices []deviceJSON `json:"devices"`
|
Devices []deviceJSON `json:"devices"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type deviceUpdateJSON struct {
|
||||||
|
DisplayName *string `json:"display_name"`
|
||||||
|
}
|
||||||
|
|
||||||
// GetDeviceByID handles /device/{deviceID}
|
// GetDeviceByID handles /device/{deviceID}
|
||||||
func GetDeviceByID(
|
func GetDeviceByID(
|
||||||
req *http.Request, deviceDB *devices.Database, device *authtypes.Device,
|
req *http.Request, deviceDB *devices.Database, device *authtypes.Device,
|
||||||
@ -95,3 +100,56 @@ func GetDevicesByLocalpart(
|
|||||||
JSON: res,
|
JSON: res,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateDeviceByID handles PUT on /devices/{deviceID}
|
||||||
|
func UpdateDeviceByID(
|
||||||
|
req *http.Request, deviceDB *devices.Database, device *authtypes.Device,
|
||||||
|
deviceID string,
|
||||||
|
) util.JSONResponse {
|
||||||
|
if req.Method != "PUT" {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 405,
|
||||||
|
JSON: jsonerror.NotFound("Bad Method"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return httputil.LogThenError(req, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := req.Context()
|
||||||
|
dev, err := deviceDB.GetDeviceByID(ctx, localpart, deviceID)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 404,
|
||||||
|
JSON: jsonerror.NotFound("Unknown device"),
|
||||||
|
}
|
||||||
|
} else if err != nil {
|
||||||
|
return httputil.LogThenError(req, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if dev.UserID != device.UserID {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 403,
|
||||||
|
JSON: jsonerror.Forbidden("device not owned by current user"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
defer req.Body.Close() // nolint: errcheck
|
||||||
|
|
||||||
|
payload := deviceUpdateJSON{}
|
||||||
|
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&payload); err != nil {
|
||||||
|
return httputil.LogThenError(req, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := deviceDB.UpdateDevice(ctx, localpart, deviceID, payload.DisplayName); err != nil {
|
||||||
|
return httputil.LogThenError(req, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 200,
|
||||||
|
JSON: struct{}{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -40,6 +40,7 @@ type flow struct {
|
|||||||
type passwordRequest struct {
|
type passwordRequest struct {
|
||||||
User string `json:"user"`
|
User string `json:"user"`
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
|
InitialDisplayName *string `json:"initial_device_display_name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type loginResponse struct {
|
type loginResponse struct {
|
||||||
@ -119,7 +120,7 @@ func Login(
|
|||||||
|
|
||||||
// TODO: Use the device ID in the request
|
// TODO: Use the device ID in the request
|
||||||
dev, err := deviceDB.CreateDevice(
|
dev, err := deviceDB.CreateDevice(
|
||||||
req.Context(), acc.Localpart, nil, token,
|
req.Context(), acc.Localpart, nil, token, r.InitialDisplayName,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
|
@ -60,6 +60,8 @@ type registerRequest struct {
|
|||||||
Admin bool `json:"admin"`
|
Admin bool `json:"admin"`
|
||||||
// user-interactive auth params
|
// user-interactive auth params
|
||||||
Auth authDict `json:"auth"`
|
Auth authDict `json:"auth"`
|
||||||
|
|
||||||
|
InitialDisplayName *string `json:"initial_device_display_name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type authDict struct {
|
type authDict struct {
|
||||||
@ -210,10 +212,10 @@ func Register(
|
|||||||
return util.MessageResponse(403, "HMAC incorrect")
|
return util.MessageResponse(403, "HMAC incorrect")
|
||||||
}
|
}
|
||||||
|
|
||||||
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password)
|
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, r.InitialDisplayName)
|
||||||
case authtypes.LoginTypeDummy:
|
case authtypes.LoginTypeDummy:
|
||||||
// there is nothing to do
|
// there is nothing to do
|
||||||
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password)
|
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, r.InitialDisplayName)
|
||||||
default:
|
default:
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 501,
|
Code: 501,
|
||||||
@ -270,10 +272,10 @@ func LegacyRegister(
|
|||||||
return util.MessageResponse(403, "HMAC incorrect")
|
return util.MessageResponse(403, "HMAC incorrect")
|
||||||
}
|
}
|
||||||
|
|
||||||
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password)
|
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, nil)
|
||||||
case authtypes.LoginTypeDummy:
|
case authtypes.LoginTypeDummy:
|
||||||
// there is nothing to do
|
// there is nothing to do
|
||||||
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password)
|
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, nil)
|
||||||
default:
|
default:
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 501,
|
Code: 501,
|
||||||
@ -287,6 +289,7 @@ func completeRegistration(
|
|||||||
accountDB *accounts.Database,
|
accountDB *accounts.Database,
|
||||||
deviceDB *devices.Database,
|
deviceDB *devices.Database,
|
||||||
username, password string,
|
username, password string,
|
||||||
|
displayName *string,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
if username == "" {
|
if username == "" {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
@ -318,7 +321,7 @@ func completeRegistration(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// // TODO: Use the device ID in the request.
|
// // TODO: Use the device ID in the request.
|
||||||
dev, err := deviceDB.CreateDevice(ctx, username, nil, token)
|
dev, err := deviceDB.CreateDevice(ctx, username, nil, token, displayName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 500,
|
Code: 500,
|
||||||
|
@ -364,6 +364,13 @@ func Setup(
|
|||||||
}),
|
}),
|
||||||
).Methods("GET")
|
).Methods("GET")
|
||||||
|
|
||||||
|
r0mux.Handle("/devices/{deviceID}",
|
||||||
|
common.MakeAuthAPI("device_data", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
|
||||||
|
vars := mux.Vars(req)
|
||||||
|
return UpdateDeviceByID(req, deviceDB, device, vars["deviceID"])
|
||||||
|
}),
|
||||||
|
).Methods("PUT", "OPTIONS")
|
||||||
|
|
||||||
// Stub implementations for sytest
|
// Stub implementations for sytest
|
||||||
r0mux.Handle("/events",
|
r0mux.Handle("/events",
|
||||||
common.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse {
|
common.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse {
|
||||||
|
@ -87,7 +87,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
device, err := deviceDB.CreateDevice(
|
device, err := deviceDB.CreateDevice(
|
||||||
context.Background(), *username, nil, *accessToken,
|
context.Background(), *username, nil, *accessToken, nil,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err.Error())
|
fmt.Println(err.Error())
|
||||||
|
Loading…
Reference in New Issue
Block a user