Refactor media storage layer, add tests (#2352)

* Refactor mediaapi storage layer

* Verify filetype before trying to create thumbnails

* Add media api storage tests

* Fix returned values
This commit is contained in:
Till 2022-04-14 14:32:48 +02:00 committed by GitHub
parent 3a5e9a0f28
commit 3ddbffd59e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 417 additions and 416 deletions

View File

@ -32,7 +32,7 @@ func AddPublicRoutes(
userAPI userapi.UserInternalAPI,
client *gomatrixserverlib.Client,
) {
mediaDB, err := storage.Open(&cfg.Database)
mediaDB, err := storage.NewMediaAPIDatasource(&cfg.Database)
if err != nil {
logrus.WithError(err).Panicf("failed to connect to media db")
}

View File

@ -22,6 +22,7 @@ import (
"io"
"net/http"
"net/url"
"os"
"path"
"strings"
@ -311,6 +312,26 @@ func (r *uploadRequest) storeFileAndMetadata(
}
go func() {
file, err := os.Open(string(finalPath))
if err != nil {
r.Logger.WithError(err).Error("unable to open file")
return
}
defer file.Close() // nolint: errcheck
// http.DetectContentType only needs 512 bytes
buf := make([]byte, 512)
_, err = file.Read(buf)
if err != nil {
r.Logger.WithError(err).Error("unable to read file")
return
}
// Check if we need to generate thumbnails
fileType := http.DetectContentType(buf)
if !strings.HasPrefix(fileType, "image") {
r.Logger.WithField("contentType", fileType).Debugf("uploaded file is not an image or can not be thumbnailed, not generating thumbnails")
return
}
busy, err := thumbnailer.GenerateThumbnails(
context.Background(), finalPath, thumbnailSizes, r.MediaMetadata,
activeThumbnailGeneration, maxThumbnailGenerators, db, r.Logger,

View File

@ -51,7 +51,7 @@ func Test_uploadRequest_doUpload(t *testing.T) {
_ = os.Mkdir(testdataPath, os.ModePerm)
defer fileutils.RemoveDir(types.Path(testdataPath), nil)
db, err := storage.Open(&config.DatabaseOptions{
db, err := storage.NewMediaAPIDatasource(&config.DatabaseOptions{
ConnectionString: "file::memory:?cache=shared",
MaxOpenConnections: 100,
MaxIdleConnections: 2,

View File

@ -22,9 +22,17 @@ import (
)
type Database interface {
MediaRepository
Thumbnails
}
type MediaRepository interface {
StoreMediaMetadata(ctx context.Context, mediaMetadata *types.MediaMetadata) error
GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error)
GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error)
}
type Thumbnails interface {
StoreThumbnail(ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata) error
GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error)
GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) ([]*types.ThumbnailMetadata, error)

View File

@ -20,6 +20,8 @@ import (
"database/sql"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
@ -69,24 +71,25 @@ type mediaStatements struct {
selectMediaByHashStmt *sql.Stmt
}
func (s *mediaStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(mediaSchema)
func NewPostgresMediaRepositoryTable(db *sql.DB) (tables.MediaRepository, error) {
s := &mediaStatements{}
_, err := db.Exec(mediaSchema)
if err != nil {
return
return nil, err
}
return statementList{
return s, sqlutil.StatementList{
{&s.insertMediaStmt, insertMediaSQL},
{&s.selectMediaStmt, selectMediaSQL},
{&s.selectMediaByHashStmt, selectMediaByHashSQL},
}.prepare(db)
}.Prepare(db)
}
func (s *mediaStatements) insertMedia(
ctx context.Context, mediaMetadata *types.MediaMetadata,
func (s *mediaStatements) InsertMedia(
ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata,
) error {
mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
_, err := s.insertMediaStmt.ExecContext(
mediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
_, err := sqlutil.TxStmtContext(ctx, txn, s.insertMediaStmt).ExecContext(
ctx,
mediaMetadata.MediaID,
mediaMetadata.Origin,
@ -100,14 +103,14 @@ func (s *mediaStatements) insertMedia(
return err
}
func (s *mediaStatements) selectMedia(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
func (s *mediaStatements) SelectMedia(
ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata := types.MediaMetadata{
MediaID: mediaID,
Origin: mediaOrigin,
}
err := s.selectMediaStmt.QueryRowContext(
err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaStmt).QueryRowContext(
ctx, mediaMetadata.MediaID, mediaMetadata.Origin,
).Scan(
&mediaMetadata.ContentType,
@ -120,14 +123,14 @@ func (s *mediaStatements) selectMedia(
return &mediaMetadata, err
}
func (s *mediaStatements) selectMediaByHash(
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
func (s *mediaStatements) SelectMediaByHash(
ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata := types.MediaMetadata{
Base64Hash: mediaHash,
Origin: mediaOrigin,
}
err := s.selectMediaStmt.QueryRowContext(
err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaByHashStmt).QueryRowContext(
ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin,
).Scan(
&mediaMetadata.ContentType,

View File

@ -0,0 +1,46 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
// Import the postgres database driver.
_ "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/shared"
"github.com/matrix-org/dendrite/setup/config"
)
// NewDatabase opens a postgres database.
func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
mediaRepo, err := NewPostgresMediaRepositoryTable(db)
if err != nil {
return nil, err
}
thumbnails, err := NewPostgresThumbnailsTable(db)
if err != nil {
return nil, err
}
return &shared.Database{
MediaRepository: mediaRepo,
Thumbnails: thumbnails,
DB: db,
Writer: sqlutil.NewExclusiveWriter(),
}, nil
}

View File

@ -1,38 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// FIXME: This should be made internal!
package postgres
import (
"database/sql"
)
// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement.
type statementList []struct {
statement **sql.Stmt
sql string
}
// prepare the SQL for each statement in the list and assign the result to the prepared statement.
func (s statementList) prepare(db *sql.DB) (err error) {
for _, statement := range s {
if *statement.statement, err = db.Prepare(statement.sql); err != nil {
return
}
}
return
}

View File

@ -1,36 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"database/sql"
)
type statements struct {
media mediaStatements
thumbnail thumbnailStatements
}
func (s *statements) prepare(db *sql.DB) (err error) {
if err = s.media.prepare(db); err != nil {
return
}
if err = s.thumbnail.prepare(db); err != nil {
return
}
return
}

View File

@ -21,6 +21,8 @@ import (
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
@ -63,7 +65,7 @@ SELECT content_type, file_size_bytes, creation_ts FROM mediaapi_thumbnail WHERE
// Note: this selects all thumbnails for a media_origin and media_id
const selectThumbnailsSQL = `
SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2
SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 ORDER BY creation_ts ASC
`
type thumbnailStatements struct {
@ -72,24 +74,25 @@ type thumbnailStatements struct {
selectThumbnailsStmt *sql.Stmt
}
func (s *thumbnailStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(thumbnailSchema)
func NewPostgresThumbnailsTable(db *sql.DB) (tables.Thumbnails, error) {
s := &thumbnailStatements{}
_, err := db.Exec(thumbnailSchema)
if err != nil {
return
return nil, err
}
return statementList{
return s, sqlutil.StatementList{
{&s.insertThumbnailStmt, insertThumbnailSQL},
{&s.selectThumbnailStmt, selectThumbnailSQL},
{&s.selectThumbnailsStmt, selectThumbnailsSQL},
}.prepare(db)
}.Prepare(db)
}
func (s *thumbnailStatements) insertThumbnail(
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
func (s *thumbnailStatements) InsertThumbnail(
ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata,
) error {
thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
_, err := s.insertThumbnailStmt.ExecContext(
thumbnailMetadata.MediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
_, err := sqlutil.TxStmtContext(ctx, txn, s.insertThumbnailStmt).ExecContext(
ctx,
thumbnailMetadata.MediaMetadata.MediaID,
thumbnailMetadata.MediaMetadata.Origin,
@ -103,8 +106,9 @@ func (s *thumbnailStatements) insertThumbnail(
return err
}
func (s *thumbnailStatements) selectThumbnail(
func (s *thumbnailStatements) SelectThumbnail(
ctx context.Context,
txn *sql.Tx,
mediaID types.MediaID,
mediaOrigin gomatrixserverlib.ServerName,
width, height int,
@ -121,7 +125,7 @@ func (s *thumbnailStatements) selectThumbnail(
ResizeMethod: resizeMethod,
},
}
err := s.selectThumbnailStmt.QueryRowContext(
err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailStmt).QueryRowContext(
ctx,
thumbnailMetadata.MediaMetadata.MediaID,
thumbnailMetadata.MediaMetadata.Origin,
@ -136,10 +140,10 @@ func (s *thumbnailStatements) selectThumbnail(
return &thumbnailMetadata, err
}
func (s *thumbnailStatements) selectThumbnails(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
func (s *thumbnailStatements) SelectThumbnails(
ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) ([]*types.ThumbnailMetadata, error) {
rows, err := s.selectThumbnailsStmt.QueryContext(
rows, err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailsStmt).QueryContext(
ctx, mediaID, mediaOrigin,
)
if err != nil {

View File

@ -1,5 +1,4 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -13,54 +12,38 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
package shared
import (
"context"
"database/sql"
// Import the postgres database driver.
_ "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
// Database is used to store metadata about a repository of media files.
type Database struct {
statements statements
db *sql.DB
}
// Open opens a postgres database.
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
var d Database
var err error
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
if err = d.statements.prepare(d.db); err != nil {
return nil, err
}
return &d, nil
DB *sql.DB
Writer sqlutil.Writer
MediaRepository tables.MediaRepository
Thumbnails tables.Thumbnails
}
// StoreMediaMetadata inserts the metadata about the uploaded media into the database.
// Returns an error if the combination of MediaID and Origin are not unique in the table.
func (d *Database) StoreMediaMetadata(
ctx context.Context, mediaMetadata *types.MediaMetadata,
) error {
return d.statements.media.insertMedia(ctx, mediaMetadata)
func (d Database) StoreMediaMetadata(ctx context.Context, mediaMetadata *types.MediaMetadata) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.MediaRepository.InsertMedia(ctx, txn, mediaMetadata)
})
}
// GetMediaMetadata returns metadata about media stored on this server.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there is no metadata associated with this media.
func (d *Database) GetMediaMetadata(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata, err := d.statements.media.selectMedia(ctx, mediaID, mediaOrigin)
func (d Database) GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) {
mediaMetadata, err := d.MediaRepository.SelectMedia(ctx, nil, mediaID, mediaOrigin)
if err != nil && err == sql.ErrNoRows {
return nil, nil
}
@ -70,10 +53,8 @@ func (d *Database) GetMediaMetadata(
// GetMediaMetadataByHash returns metadata about media stored on this server.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there is no metadata associated with this media.
func (d *Database) GetMediaMetadataByHash(
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata, err := d.statements.media.selectMediaByHash(ctx, mediaHash, mediaOrigin)
func (d Database) GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) {
mediaMetadata, err := d.MediaRepository.SelectMediaByHash(ctx, nil, mediaHash, mediaOrigin)
if err != nil && err == sql.ErrNoRows {
return nil, nil
}
@ -82,40 +63,36 @@ func (d *Database) GetMediaMetadataByHash(
// StoreThumbnail inserts the metadata about the thumbnail into the database.
// Returns an error if the combination of MediaID and Origin are not unique in the table.
func (d *Database) StoreThumbnail(
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
) error {
return d.statements.thumbnail.insertThumbnail(ctx, thumbnailMetadata)
func (d Database) StoreThumbnail(ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Thumbnails.InsertThumbnail(ctx, txn, thumbnailMetadata)
})
}
// GetThumbnail returns metadata about a specific thumbnail.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there is no metadata associated with this thumbnail.
func (d *Database) GetThumbnail(
ctx context.Context,
mediaID types.MediaID,
mediaOrigin gomatrixserverlib.ServerName,
width, height int,
resizeMethod string,
) (*types.ThumbnailMetadata, error) {
thumbnailMetadata, err := d.statements.thumbnail.selectThumbnail(
ctx, mediaID, mediaOrigin, width, height, resizeMethod,
)
if err != nil && err == sql.ErrNoRows {
return nil, nil
func (d Database) GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error) {
metadata, err := d.Thumbnails.SelectThumbnail(ctx, nil, mediaID, mediaOrigin, width, height, resizeMethod)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
return thumbnailMetadata, err
return metadata, err
}
// GetThumbnails returns metadata about all thumbnails for a specific media stored on this server.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there are no thumbnails associated with this media.
func (d *Database) GetThumbnails(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) ([]*types.ThumbnailMetadata, error) {
thumbnails, err := d.statements.thumbnail.selectThumbnails(ctx, mediaID, mediaOrigin)
if err != nil && err == sql.ErrNoRows {
return nil, nil
func (d Database) GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) ([]*types.ThumbnailMetadata, error) {
metadatas, err := d.Thumbnails.SelectThumbnails(ctx, nil, mediaID, mediaOrigin)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
return thumbnails, err
return metadatas, err
}

View File

@ -21,6 +21,7 @@ import (
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
@ -66,57 +67,53 @@ SELECT content_type, file_size_bytes, creation_ts, upload_name, media_id, user_i
type mediaStatements struct {
db *sql.DB
writer sqlutil.Writer
insertMediaStmt *sql.Stmt
selectMediaStmt *sql.Stmt
selectMediaByHashStmt *sql.Stmt
}
func (s *mediaStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db
s.writer = writer
_, err = db.Exec(mediaSchema)
func NewSQLiteMediaRepositoryTable(db *sql.DB) (tables.MediaRepository, error) {
s := &mediaStatements{
db: db,
}
_, err := db.Exec(mediaSchema)
if err != nil {
return
return nil, err
}
return statementList{
return s, sqlutil.StatementList{
{&s.insertMediaStmt, insertMediaSQL},
{&s.selectMediaStmt, selectMediaSQL},
{&s.selectMediaByHashStmt, selectMediaByHashSQL},
}.prepare(db)
}.Prepare(db)
}
func (s *mediaStatements) insertMedia(
ctx context.Context, mediaMetadata *types.MediaMetadata,
func (s *mediaStatements) InsertMedia(
ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata,
) error {
mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertMediaStmt)
_, err := stmt.ExecContext(
ctx,
mediaMetadata.MediaID,
mediaMetadata.Origin,
mediaMetadata.ContentType,
mediaMetadata.FileSizeBytes,
mediaMetadata.CreationTimestamp,
mediaMetadata.UploadName,
mediaMetadata.Base64Hash,
mediaMetadata.UserID,
)
return err
})
mediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
_, err := sqlutil.TxStmtContext(ctx, txn, s.insertMediaStmt).ExecContext(
ctx,
mediaMetadata.MediaID,
mediaMetadata.Origin,
mediaMetadata.ContentType,
mediaMetadata.FileSizeBytes,
mediaMetadata.CreationTimestamp,
mediaMetadata.UploadName,
mediaMetadata.Base64Hash,
mediaMetadata.UserID,
)
return err
}
func (s *mediaStatements) selectMedia(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
func (s *mediaStatements) SelectMedia(
ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata := types.MediaMetadata{
MediaID: mediaID,
Origin: mediaOrigin,
}
err := s.selectMediaStmt.QueryRowContext(
err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaStmt).QueryRowContext(
ctx, mediaMetadata.MediaID, mediaMetadata.Origin,
).Scan(
&mediaMetadata.ContentType,
@ -129,14 +126,14 @@ func (s *mediaStatements) selectMedia(
return &mediaMetadata, err
}
func (s *mediaStatements) selectMediaByHash(
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
func (s *mediaStatements) SelectMediaByHash(
ctx context.Context, txn *sql.Tx, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata := types.MediaMetadata{
Base64Hash: mediaHash,
Origin: mediaOrigin,
}
err := s.selectMediaStmt.QueryRowContext(
err := sqlutil.TxStmtContext(ctx, txn, s.selectMediaByHashStmt).QueryRowContext(
ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin,
).Scan(
&mediaMetadata.ContentType,

View File

@ -16,23 +16,30 @@
package sqlite3
import (
"database/sql"
// Import the postgres database driver.
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/shared"
"github.com/matrix-org/dendrite/setup/config"
)
type statements struct {
media mediaStatements
thumbnail thumbnailStatements
}
func (s *statements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
if err = s.media.prepare(db, writer); err != nil {
return
// NewDatabase opens a SQLIte database.
func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
}
if err = s.thumbnail.prepare(db, writer); err != nil {
return
mediaRepo, err := NewSQLiteMediaRepositoryTable(db)
if err != nil {
return nil, err
}
return
thumbnails, err := NewSQLiteThumbnailsTable(db)
if err != nil {
return nil, err
}
return &shared.Database{
MediaRepository: mediaRepo,
Thumbnails: thumbnails,
DB: db,
Writer: sqlutil.NewExclusiveWriter(),
}, nil
}

View File

@ -1,38 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// FIXME: This should be made internal!
package sqlite3
import (
"database/sql"
)
// a statementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement.
type statementList []struct {
statement **sql.Stmt
sql string
}
// prepare the SQL for each statement in the list and assign the result to the prepared statement.
func (s statementList) prepare(db *sql.DB) (err error) {
for _, statement := range s {
if *statement.statement, err = db.Prepare(statement.sql); err != nil {
return
}
}
return
}

View File

@ -1,123 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
// Import the postgres database driver.
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
// Database is used to store metadata about a repository of media files.
type Database struct {
statements statements
db *sql.DB
writer sqlutil.Writer
}
// Open opens a postgres database.
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
d := Database{
writer: sqlutil.NewExclusiveWriter(),
}
var err error
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
if err = d.statements.prepare(d.db, d.writer); err != nil {
return nil, err
}
return &d, nil
}
// StoreMediaMetadata inserts the metadata about the uploaded media into the database.
// Returns an error if the combination of MediaID and Origin are not unique in the table.
func (d *Database) StoreMediaMetadata(
ctx context.Context, mediaMetadata *types.MediaMetadata,
) error {
return d.statements.media.insertMedia(ctx, mediaMetadata)
}
// GetMediaMetadata returns metadata about media stored on this server.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there is no metadata associated with this media.
func (d *Database) GetMediaMetadata(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata, err := d.statements.media.selectMedia(ctx, mediaID, mediaOrigin)
if err != nil && err == sql.ErrNoRows {
return nil, nil
}
return mediaMetadata, err
}
// GetMediaMetadataByHash returns metadata about media stored on this server.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there is no metadata associated with this media.
func (d *Database) GetMediaMetadataByHash(
ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error) {
mediaMetadata, err := d.statements.media.selectMediaByHash(ctx, mediaHash, mediaOrigin)
if err != nil && err == sql.ErrNoRows {
return nil, nil
}
return mediaMetadata, err
}
// StoreThumbnail inserts the metadata about the thumbnail into the database.
// Returns an error if the combination of MediaID and Origin are not unique in the table.
func (d *Database) StoreThumbnail(
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
) error {
return d.statements.thumbnail.insertThumbnail(ctx, thumbnailMetadata)
}
// GetThumbnail returns metadata about a specific thumbnail.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there is no metadata associated with this thumbnail.
func (d *Database) GetThumbnail(
ctx context.Context,
mediaID types.MediaID,
mediaOrigin gomatrixserverlib.ServerName,
width, height int,
resizeMethod string,
) (*types.ThumbnailMetadata, error) {
thumbnailMetadata, err := d.statements.thumbnail.selectThumbnail(
ctx, mediaID, mediaOrigin, width, height, resizeMethod,
)
if err != nil && err == sql.ErrNoRows {
return nil, nil
}
return thumbnailMetadata, err
}
// GetThumbnails returns metadata about all thumbnails for a specific media stored on this server.
// The media could have been uploaded to this server or fetched from another server and cached here.
// Returns nil metadata if there are no thumbnails associated with this media.
func (d *Database) GetThumbnails(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
) ([]*types.ThumbnailMetadata, error) {
thumbnails, err := d.statements.thumbnail.selectThumbnails(ctx, mediaID, mediaOrigin)
if err != nil && err == sql.ErrNoRows {
return nil, nil
}
return thumbnails, err
}

View File

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/storage/tables"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
@ -54,55 +55,48 @@ SELECT content_type, file_size_bytes, creation_ts FROM mediaapi_thumbnail WHERE
// Note: this selects all thumbnails for a media_origin and media_id
const selectThumbnailsSQL = `
SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2
SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method FROM mediaapi_thumbnail WHERE media_id = $1 AND media_origin = $2 ORDER BY creation_ts ASC
`
type thumbnailStatements struct {
db *sql.DB
writer sqlutil.Writer
insertThumbnailStmt *sql.Stmt
selectThumbnailStmt *sql.Stmt
selectThumbnailsStmt *sql.Stmt
}
func (s *thumbnailStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
_, err = db.Exec(thumbnailSchema)
func NewSQLiteThumbnailsTable(db *sql.DB) (tables.Thumbnails, error) {
s := &thumbnailStatements{}
_, err := db.Exec(thumbnailSchema)
if err != nil {
return
return nil, err
}
s.db = db
s.writer = writer
return statementList{
return s, sqlutil.StatementList{
{&s.insertThumbnailStmt, insertThumbnailSQL},
{&s.selectThumbnailStmt, selectThumbnailSQL},
{&s.selectThumbnailsStmt, selectThumbnailsSQL},
}.prepare(db)
}.Prepare(db)
}
func (s *thumbnailStatements) insertThumbnail(
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
) error {
thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertThumbnailStmt)
_, err := stmt.ExecContext(
ctx,
thumbnailMetadata.MediaMetadata.MediaID,
thumbnailMetadata.MediaMetadata.Origin,
thumbnailMetadata.MediaMetadata.ContentType,
thumbnailMetadata.MediaMetadata.FileSizeBytes,
thumbnailMetadata.MediaMetadata.CreationTimestamp,
thumbnailMetadata.ThumbnailSize.Width,
thumbnailMetadata.ThumbnailSize.Height,
thumbnailMetadata.ThumbnailSize.ResizeMethod,
)
return err
})
func (s *thumbnailStatements) InsertThumbnail(ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata) error {
thumbnailMetadata.MediaMetadata.CreationTimestamp = gomatrixserverlib.AsTimestamp(time.Now())
_, err := sqlutil.TxStmtContext(ctx, txn, s.insertThumbnailStmt).ExecContext(
ctx,
thumbnailMetadata.MediaMetadata.MediaID,
thumbnailMetadata.MediaMetadata.Origin,
thumbnailMetadata.MediaMetadata.ContentType,
thumbnailMetadata.MediaMetadata.FileSizeBytes,
thumbnailMetadata.MediaMetadata.CreationTimestamp,
thumbnailMetadata.ThumbnailSize.Width,
thumbnailMetadata.ThumbnailSize.Height,
thumbnailMetadata.ThumbnailSize.ResizeMethod,
)
return err
}
func (s *thumbnailStatements) selectThumbnail(
func (s *thumbnailStatements) SelectThumbnail(
ctx context.Context,
txn *sql.Tx,
mediaID types.MediaID,
mediaOrigin gomatrixserverlib.ServerName,
width, height int,
@ -119,7 +113,7 @@ func (s *thumbnailStatements) selectThumbnail(
ResizeMethod: resizeMethod,
},
}
err := s.selectThumbnailStmt.QueryRowContext(
err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailStmt).QueryRowContext(
ctx,
thumbnailMetadata.MediaMetadata.MediaID,
thumbnailMetadata.MediaMetadata.Origin,
@ -134,10 +128,11 @@ func (s *thumbnailStatements) selectThumbnail(
return &thumbnailMetadata, err
}
func (s *thumbnailStatements) selectThumbnails(
ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
func (s *thumbnailStatements) SelectThumbnails(
ctx context.Context, txn *sql.Tx, mediaID types.MediaID,
mediaOrigin gomatrixserverlib.ServerName,
) ([]*types.ThumbnailMetadata, error) {
rows, err := s.selectThumbnailsStmt.QueryContext(
rows, err := sqlutil.TxStmtContext(ctx, txn, s.selectThumbnailsStmt).QueryContext(
ctx, mediaID, mediaOrigin,
)
if err != nil {

View File

@ -25,13 +25,13 @@ import (
"github.com/matrix-org/dendrite/setup/config"
)
// Open opens a postgres database.
func Open(dbProperties *config.DatabaseOptions) (Database, error) {
// NewMediaAPIDatasource opens a database connection.
func NewMediaAPIDatasource(dbProperties *config.DatabaseOptions) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.Open(dbProperties)
return sqlite3.NewDatabase(dbProperties)
case dbProperties.ConnectionString.IsPostgres():
return postgres.Open(dbProperties)
return postgres.NewDatabase(dbProperties)
default:
return nil, fmt.Errorf("unexpected database type")
}

View File

@ -0,0 +1,135 @@
package storage_test
import (
"context"
"reflect"
"testing"
"github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
)
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewMediaAPIDatasource(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
})
if err != nil {
t.Fatalf("NewSyncServerDatasource returned %s", err)
}
return db, close
}
func TestMediaRepository(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
ctx := context.Background()
t.Run("can insert media & query media", func(t *testing.T) {
metadata := &types.MediaMetadata{
MediaID: "testing",
Origin: "localhost",
ContentType: "image/png",
FileSizeBytes: 10,
UploadName: "upload test",
Base64Hash: "dGVzdGluZw==",
UserID: "@alice:localhost",
}
if err := db.StoreMediaMetadata(ctx, metadata); err != nil {
t.Fatalf("unable to store media metadata: %v", err)
}
// query by media id
gotMetadata, err := db.GetMediaMetadata(ctx, metadata.MediaID, metadata.Origin)
if err != nil {
t.Fatalf("unable to query media metadata: %v", err)
}
if !reflect.DeepEqual(metadata, gotMetadata) {
t.Fatalf("expected metadata %+v, got %v", metadata, gotMetadata)
}
// query by media hash
gotMetadata, err = db.GetMediaMetadataByHash(ctx, metadata.Base64Hash, metadata.Origin)
if err != nil {
t.Fatalf("unable to query media metadata by hash: %v", err)
}
if !reflect.DeepEqual(metadata, gotMetadata) {
t.Fatalf("expected metadata %+v, got %v", metadata, gotMetadata)
}
})
})
}
func TestThumbnailsStorage(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
ctx := context.Background()
t.Run("can insert thumbnails & query media", func(t *testing.T) {
thumbnails := []*types.ThumbnailMetadata{
{
MediaMetadata: &types.MediaMetadata{
MediaID: "testing",
Origin: "localhost",
ContentType: "image/png",
FileSizeBytes: 6,
},
ThumbnailSize: types.ThumbnailSize{
Width: 5,
Height: 5,
ResizeMethod: types.Crop,
},
},
{
MediaMetadata: &types.MediaMetadata{
MediaID: "testing",
Origin: "localhost",
ContentType: "image/png",
FileSizeBytes: 7,
},
ThumbnailSize: types.ThumbnailSize{
Width: 1,
Height: 1,
ResizeMethod: types.Scale,
},
},
}
for i := range thumbnails {
if err := db.StoreThumbnail(ctx, thumbnails[i]); err != nil {
t.Fatalf("unable to store thumbnail metadata: %v", err)
}
}
// query by single thumbnail
gotMetadata, err := db.GetThumbnail(ctx,
thumbnails[0].MediaMetadata.MediaID,
thumbnails[0].MediaMetadata.Origin,
thumbnails[0].ThumbnailSize.Width, thumbnails[0].ThumbnailSize.Height,
thumbnails[0].ThumbnailSize.ResizeMethod,
)
if err != nil {
t.Fatalf("unable to query thumbnail metadata: %v", err)
}
if !reflect.DeepEqual(thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata) {
t.Fatalf("expected metadata %+v, got %+v", thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata)
}
if !reflect.DeepEqual(thumbnails[0].ThumbnailSize, gotMetadata.ThumbnailSize) {
t.Fatalf("expected metadata %+v, got %+v", thumbnails[0].MediaMetadata, gotMetadata.MediaMetadata)
}
// query by all thumbnails
gotMediadatas, err := db.GetThumbnails(ctx, thumbnails[0].MediaMetadata.MediaID, thumbnails[0].MediaMetadata.Origin)
if err != nil {
t.Fatalf("unable to query media metadata by hash: %v", err)
}
if len(gotMediadatas) != len(thumbnails) {
t.Fatalf("expected %d stored thumbnail metadata, got %d", len(thumbnails), len(gotMediadatas))
}
for i := range gotMediadatas {
if !reflect.DeepEqual(thumbnails[i].MediaMetadata, gotMediadatas[i].MediaMetadata) {
t.Fatalf("expected metadata %+v, got %v", thumbnails[i].MediaMetadata, gotMediadatas[i].MediaMetadata)
}
if !reflect.DeepEqual(thumbnails[i].ThumbnailSize, gotMediadatas[i].ThumbnailSize) {
t.Fatalf("expected metadata %+v, got %v", thumbnails[i].ThumbnailSize, gotMediadatas[i].ThumbnailSize)
}
}
})
})
}

View File

@ -22,10 +22,10 @@ import (
)
// Open opens a postgres database.
func Open(dbProperties *config.DatabaseOptions) (Database, error) {
func NewMediaAPIDatasource(dbProperties *config.DatabaseOptions) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.Open(dbProperties)
return sqlite3.NewDatabase(dbProperties)
case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation")
default:

View File

@ -0,0 +1,46 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tables
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
type Thumbnails interface {
InsertThumbnail(ctx context.Context, txn *sql.Tx, thumbnailMetadata *types.ThumbnailMetadata) error
SelectThumbnail(
ctx context.Context, txn *sql.Tx,
mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName,
width, height int,
resizeMethod string,
) (*types.ThumbnailMetadata, error)
SelectThumbnails(
ctx context.Context, txn *sql.Tx, mediaID types.MediaID,
mediaOrigin gomatrixserverlib.ServerName,
) ([]*types.ThumbnailMetadata, error)
}
type MediaRepository interface {
InsertMedia(ctx context.Context, txn *sql.Tx, mediaMetadata *types.MediaMetadata) error
SelectMedia(ctx context.Context, txn *sql.Tx, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error)
SelectMediaByHash(
ctx context.Context, txn *sql.Tx,
mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName,
) (*types.MediaMetadata, error)
}

View File

@ -45,16 +45,13 @@ type RequestMethod string
// MatrixUserID is a Matrix user ID string in the form @user:domain e.g. @alice:matrix.org
type MatrixUserID string
// UnixMs is the milliseconds since the Unix epoch
type UnixMs int64
// MediaMetadata is metadata associated with a media file
type MediaMetadata struct {
MediaID MediaID
Origin gomatrixserverlib.ServerName
ContentType ContentType
FileSizeBytes FileSizeBytes
CreationTimestamp UnixMs
CreationTimestamp gomatrixserverlib.Timestamp
UploadName Filename
Base64Hash Base64Hash
UserID MatrixUserID