diff --git a/clientapi/auth/storage/accounts/filter_table.go b/clientapi/auth/storage/accounts/filter_table.go index 81bae454..2b07ef17 100644 --- a/clientapi/auth/storage/accounts/filter_table.go +++ b/clientapi/auth/storage/accounts/filter_table.go @@ -17,6 +17,7 @@ package accounts import ( "context" "database/sql" + "encoding/json" "github.com/matrix-org/gomatrixserverlib" ) @@ -71,25 +72,44 @@ func (s *filterStatements) prepare(db *sql.DB) (err error) { func (s *filterStatements) selectFilter( ctx context.Context, localpart string, filterID string, -) (filter []byte, err error) { - err = s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filter) - return +) (*gomatrixserverlib.Filter, error) { + // Retrieve filter from database (stored as canonical JSON) + var filterData []byte + err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData) + if err != nil { + return nil, err + } + + // Unmarshal JSON into Filter struct + var filter gomatrixserverlib.Filter + if err = json.Unmarshal(filterData, &filter); err != nil { + return nil, err + } + return &filter, nil } func (s *filterStatements) insertFilter( - ctx context.Context, filter []byte, localpart string, + ctx context.Context, filter *gomatrixserverlib.Filter, localpart string, ) (filterID string, err error) { var existingFilterID string - // This can result in a race condition when two clients try to insert the - // same filter and localpart at the same time, however this is not a - // problem as both calls will result in the same filterID - filterJSON, err := gomatrixserverlib.CanonicalJSON(filter) + // Serialise json + filterJSON, err := json.Marshal(filter) + if err != nil { + return "", err + } + // Remove whitespaces and sort JSON data + // needed to prevent from inserting the same filter multiple times + filterJSON, err = gomatrixserverlib.CanonicalJSON(filterJSON) if err != nil { return "", err } - // Check if filter already exists in the database + // Check if filter already exists in the database using its localpart and content + // + // This can result in a race condition when two clients try to insert the + // same filter and localpart at the same time, however this is not a + // problem as both calls will result in the same filterID err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, localpart, filterJSON).Scan(&existingFilterID) if err != nil && err != sql.ErrNoRows { diff --git a/clientapi/auth/storage/accounts/storage.go b/clientapi/auth/storage/accounts/storage.go index 27c0a176..5c8ffffe 100644 --- a/clientapi/auth/storage/accounts/storage.go +++ b/clientapi/auth/storage/accounts/storage.go @@ -344,11 +344,11 @@ func (d *Database) GetThreePIDsForLocalpart( } // GetFilter looks up the filter associated with a given local user and filter ID. -// Returns a filter represented as a byte slice. Otherwise returns an error if -// no such filter exists or if there was an error talking to the database. +// Returns a filter structure. Otherwise returns an error if no such filter exists +// or if there was an error talking to the database. func (d *Database) GetFilter( ctx context.Context, localpart string, filterID string, -) ([]byte, error) { +) (*gomatrixserverlib.Filter, error) { return d.filter.selectFilter(ctx, localpart, filterID) } @@ -356,7 +356,7 @@ func (d *Database) GetFilter( // Returns the filterID as a string. Otherwise returns an error if something // goes wrong. func (d *Database) PutFilter( - ctx context.Context, localpart string, filter []byte, + ctx context.Context, localpart string, filter *gomatrixserverlib.Filter, ) (string, error) { return d.filter.insertFilter(ctx, filter, localpart) } diff --git a/clientapi/routing/filter.go b/clientapi/routing/filter.go index 291a165b..eec501ff 100644 --- a/clientapi/routing/filter.go +++ b/clientapi/routing/filter.go @@ -17,13 +17,10 @@ package routing import ( "net/http" - "encoding/json" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" - "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -43,7 +40,7 @@ func GetFilter( return httputil.LogThenError(req, err) } - res, err := accountDB.GetFilter(req.Context(), localpart, filterID) + filter, err := accountDB.GetFilter(req.Context(), localpart, filterID) if err != nil { //TODO better error handling. This error message is *probably* right, // but if there are obscure db errors, this will also be returned, @@ -53,11 +50,6 @@ func GetFilter( JSON: jsonerror.NotFound("No such filter"), } } - filter := gomatrix.Filter{} - err = json.Unmarshal(res, &filter) - if err != nil { - return httputil.LogThenError(req, err) - } return util.JSONResponse{ Code: http.StatusOK, @@ -85,21 +77,21 @@ func PutFilter( return httputil.LogThenError(req, err) } - var filter gomatrix.Filter + var filter gomatrixserverlib.Filter if reqErr := httputil.UnmarshalJSONRequest(req, &filter); reqErr != nil { return *reqErr } - filterArray, err := json.Marshal(filter) - if err != nil { + // Validate generates a user-friendly error + if err = filter.Validate(); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Filter is malformed"), + JSON: jsonerror.BadJSON("Invalid filter: " + err.Error()), } } - filterID, err := accountDB.PutFilter(req.Context(), localpart, filterArray) + filterID, err := accountDB.PutFilter(req.Context(), localpart, &filter) if err != nil { return httputil.LogThenError(req, err) }