diff --git a/mediaapi/routing/upload.go b/mediaapi/routing/upload.go index 9f35e90c..b74d1732 100644 --- a/mediaapi/routing/upload.go +++ b/mediaapi/routing/upload.go @@ -16,7 +16,8 @@ package routing import ( "context" - "encoding/base64" + "crypto/rand" + "encoding/hex" "fmt" "io" "net/http" @@ -93,6 +94,33 @@ func parseAndValidateRequest(req *http.Request, cfg *config.MediaAPI) (*uploadRe return r, nil } +func (r *uploadRequest) generateMediaID(ctx context.Context, db storage.Database) (types.MediaID, error) { + for { + // First try generating a meda ID. We'll do this by + // generating some random bytes and then hex-encoding. + mediaIDBytes := make([]byte, 32) + _, err := rand.Read(mediaIDBytes) + if err != nil { + return "", fmt.Errorf("rand.Read: %w", err) + } + mediaID := types.MediaID(hex.EncodeToString(mediaIDBytes)) + // Then we will check if this media ID already exists in + // our database. If it does then we had best generate a + // new one. + existingMetadata, err := db.GetMediaMetadata(ctx, mediaID, r.MediaMetadata.Origin) + if err != nil { + return "", fmt.Errorf("db.GetMediaMetadata: %w", err) + } + if existingMetadata != nil { + // The media ID was already used - repeat the process + // and generate a new one instead. + continue + } + // The media ID was not already used - let's return that. + return mediaID, nil + } +} + func (r *uploadRequest) doUpload( ctx context.Context, reqReader io.Reader, @@ -122,14 +150,53 @@ func (r *uploadRequest) doUpload( } } - r.MediaMetadata.FileSizeBytes = bytesWritten - r.MediaMetadata.Base64Hash = hash - r.MediaMetadata.MediaID = types.MediaID(base64.RawURLEncoding.EncodeToString( - []byte(string(r.MediaMetadata.UploadName) + string(r.MediaMetadata.Base64Hash)), - )) + // Look up the media by the file hash. If we already have the file but under a + // different media ID then we won't upload the file again - instead we'll just + // add a new metadata entry that refers to the same file. + existingMetadata, err := db.GetMediaMetadataByHash( + ctx, hash, r.MediaMetadata.Origin, + ) + if err != nil { + r.Logger.WithError(err).Error("Error querying the database by hash.") + resErr := jsonerror.InternalServerError() + return &resErr + } + if existingMetadata != nil { + // The file already exists. Make a new media ID up for it. + mediaID, merr := r.generateMediaID(ctx, db) + if merr != nil { + r.Logger.WithError(merr).Error("Failed to generate media ID for existing file") + resErr := jsonerror.InternalServerError() + return &resErr + } - r.Logger = r.Logger.WithField("MediaID", r.MediaMetadata.MediaID) + // Then amend the upload metadata. + r.MediaMetadata = &types.MediaMetadata{ + MediaID: mediaID, + Origin: r.MediaMetadata.Origin, + ContentType: r.MediaMetadata.ContentType, + FileSizeBytes: r.MediaMetadata.FileSizeBytes, + CreationTimestamp: r.MediaMetadata.CreationTimestamp, + UploadName: r.MediaMetadata.UploadName, + Base64Hash: hash, + UserID: r.MediaMetadata.UserID, + } + // Clean up the uploaded temporary file. + fileutils.RemoveDir(tmpDir, r.Logger) + } else { + // The file doesn't exist. Update the request metadata. + r.MediaMetadata.FileSizeBytes = bytesWritten + r.MediaMetadata.Base64Hash = hash + r.MediaMetadata.MediaID, err = r.generateMediaID(ctx, db) + if err != nil { + r.Logger.WithError(err).Error("Failed to generate media ID for new upload") + resErr := jsonerror.InternalServerError() + return &resErr + } + } + + r.Logger = r.Logger.WithField("media_id", r.MediaMetadata.MediaID) r.Logger.WithFields(log.Fields{ "Base64Hash": r.MediaMetadata.Base64Hash, "UploadName": r.MediaMetadata.UploadName, @@ -137,27 +204,6 @@ func (r *uploadRequest) doUpload( "ContentType": r.MediaMetadata.ContentType, }).Info("File uploaded") - // check if we already have a record of the media in our database and if so, we can remove the temporary directory - mediaMetadata, err := db.GetMediaMetadata( - ctx, r.MediaMetadata.MediaID, r.MediaMetadata.Origin, - ) - if err != nil { - r.Logger.WithError(err).Error("Error querying the database.") - resErr := jsonerror.InternalServerError() - return &resErr - } - - if mediaMetadata != nil { - r.MediaMetadata = mediaMetadata - fileutils.RemoveDir(tmpDir, r.Logger) - return &util.JSONResponse{ - Code: http.StatusOK, - JSON: uploadResponse{ - ContentURI: fmt.Sprintf("mxc://%s/%s", cfg.Matrix.ServerName, r.MediaMetadata.MediaID), - }, - } - } - return r.storeFileAndMetadata( ctx, tmpDir, cfg.AbsBasePath, db, cfg.ThumbnailSizes, activeThumbnailGeneration, cfg.MaxThumbnailGenerators, diff --git a/mediaapi/storage/interface.go b/mediaapi/storage/interface.go index 672e8ef5..84319971 100644 --- a/mediaapi/storage/interface.go +++ b/mediaapi/storage/interface.go @@ -24,6 +24,7 @@ import ( type Database interface { StoreMediaMetadata(ctx context.Context, mediaMetadata *types.MediaMetadata) error GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) + GetMediaMetadataByHash(ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) StoreThumbnail(ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata) error GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error) GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) ([]*types.ThumbnailMetadata, error) diff --git a/mediaapi/storage/postgres/media_repository_table.go b/mediaapi/storage/postgres/media_repository_table.go index e975530a..1d3264ca 100644 --- a/mediaapi/storage/postgres/media_repository_table.go +++ b/mediaapi/storage/postgres/media_repository_table.go @@ -59,9 +59,14 @@ const selectMediaSQL = ` SELECT content_type, file_size_bytes, creation_ts, upload_name, base64hash, user_id FROM mediaapi_media_repository WHERE media_id = $1 AND media_origin = $2 ` +const selectMediaByHashSQL = ` +SELECT content_type, file_size_bytes, creation_ts, upload_name, media_id, user_id FROM mediaapi_media_repository WHERE base64hash = $1 AND media_origin = $2 +` + type mediaStatements struct { - insertMediaStmt *sql.Stmt - selectMediaStmt *sql.Stmt + insertMediaStmt *sql.Stmt + selectMediaStmt *sql.Stmt + selectMediaByHashStmt *sql.Stmt } func (s *mediaStatements) prepare(db *sql.DB) (err error) { @@ -73,6 +78,7 @@ func (s *mediaStatements) prepare(db *sql.DB) (err error) { return statementList{ {&s.insertMediaStmt, insertMediaSQL}, {&s.selectMediaStmt, selectMediaSQL}, + {&s.selectMediaByHashStmt, selectMediaByHashSQL}, }.prepare(db) } @@ -113,3 +119,23 @@ func (s *mediaStatements) selectMedia( ) return &mediaMetadata, err } + +func (s *mediaStatements) selectMediaByHash( + ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, +) (*types.MediaMetadata, error) { + mediaMetadata := types.MediaMetadata{ + Base64Hash: mediaHash, + Origin: mediaOrigin, + } + err := s.selectMediaStmt.QueryRowContext( + ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin, + ).Scan( + &mediaMetadata.ContentType, + &mediaMetadata.FileSizeBytes, + &mediaMetadata.CreationTimestamp, + &mediaMetadata.UploadName, + &mediaMetadata.MediaID, + &mediaMetadata.UserID, + ) + return &mediaMetadata, err +} diff --git a/mediaapi/storage/postgres/storage.go b/mediaapi/storage/postgres/storage.go index 756913d3..f89501de 100644 --- a/mediaapi/storage/postgres/storage.go +++ b/mediaapi/storage/postgres/storage.go @@ -67,6 +67,19 @@ func (d *Database) GetMediaMetadata( return mediaMetadata, err } +// GetMediaMetadataByHash returns metadata about media stored on this server. +// The media could have been uploaded to this server or fetched from another server and cached here. +// Returns nil metadata if there is no metadata associated with this media. +func (d *Database) GetMediaMetadataByHash( + ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, +) (*types.MediaMetadata, error) { + mediaMetadata, err := d.statements.media.selectMediaByHash(ctx, mediaHash, mediaOrigin) + if err != nil && err == sql.ErrNoRows { + return nil, nil + } + return mediaMetadata, err +} + // StoreThumbnail inserts the metadata about the thumbnail into the database. // Returns an error if the combination of MediaID and Origin are not unique in the table. func (d *Database) StoreThumbnail( diff --git a/mediaapi/storage/sqlite3/media_repository_table.go b/mediaapi/storage/sqlite3/media_repository_table.go index dcc1b41e..bcef609d 100644 --- a/mediaapi/storage/sqlite3/media_repository_table.go +++ b/mediaapi/storage/sqlite3/media_repository_table.go @@ -60,11 +60,16 @@ const selectMediaSQL = ` SELECT content_type, file_size_bytes, creation_ts, upload_name, base64hash, user_id FROM mediaapi_media_repository WHERE media_id = $1 AND media_origin = $2 ` +const selectMediaByHashSQL = ` +SELECT content_type, file_size_bytes, creation_ts, upload_name, media_id, user_id FROM mediaapi_media_repository WHERE base64hash = $1 AND media_origin = $2 +` + type mediaStatements struct { - db *sql.DB - writer sqlutil.Writer - insertMediaStmt *sql.Stmt - selectMediaStmt *sql.Stmt + db *sql.DB + writer sqlutil.Writer + insertMediaStmt *sql.Stmt + selectMediaStmt *sql.Stmt + selectMediaByHashStmt *sql.Stmt } func (s *mediaStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { @@ -79,6 +84,7 @@ func (s *mediaStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) return statementList{ {&s.insertMediaStmt, insertMediaSQL}, {&s.selectMediaStmt, selectMediaSQL}, + {&s.selectMediaByHashStmt, selectMediaByHashSQL}, }.prepare(db) } @@ -122,3 +128,23 @@ func (s *mediaStatements) selectMedia( ) return &mediaMetadata, err } + +func (s *mediaStatements) selectMediaByHash( + ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, +) (*types.MediaMetadata, error) { + mediaMetadata := types.MediaMetadata{ + Base64Hash: mediaHash, + Origin: mediaOrigin, + } + err := s.selectMediaStmt.QueryRowContext( + ctx, mediaMetadata.Base64Hash, mediaMetadata.Origin, + ).Scan( + &mediaMetadata.ContentType, + &mediaMetadata.FileSizeBytes, + &mediaMetadata.CreationTimestamp, + &mediaMetadata.UploadName, + &mediaMetadata.MediaID, + &mediaMetadata.UserID, + ) + return &mediaMetadata, err +} diff --git a/mediaapi/storage/sqlite3/storage.go b/mediaapi/storage/sqlite3/storage.go index d5c3031e..9e510fa3 100644 --- a/mediaapi/storage/sqlite3/storage.go +++ b/mediaapi/storage/sqlite3/storage.go @@ -70,6 +70,19 @@ func (d *Database) GetMediaMetadata( return mediaMetadata, err } +// GetMediaMetadataByHash returns metadata about media stored on this server. +// The media could have been uploaded to this server or fetched from another server and cached here. +// Returns nil metadata if there is no metadata associated with this media. +func (d *Database) GetMediaMetadataByHash( + ctx context.Context, mediaHash types.Base64Hash, mediaOrigin gomatrixserverlib.ServerName, +) (*types.MediaMetadata, error) { + mediaMetadata, err := d.statements.media.selectMediaByHash(ctx, mediaHash, mediaOrigin) + if err != nil && err == sql.ErrNoRows { + return nil, nil + } + return mediaMetadata, err +} + // StoreThumbnail inserts the metadata about the thumbnail into the database. // Returns an error if the combination of MediaID and Origin are not unique in the table. func (d *Database) StoreThumbnail(