mediaapi/storage: Don't leak sql.ErrNoRows out of storage package

This commit is contained in:
Robert Swain 2017-05-31 14:29:28 +02:00
parent a936ad5063
commit a45f008c12
2 changed files with 18 additions and 46 deletions

View File

@ -50,7 +50,11 @@ func (d *Database) StoreMediaMetadata(mediaMetadata *types.MediaMetadata) error
// GetMediaMetadata returns metadata about media stored on this server. // GetMediaMetadata returns metadata about media stored on this server.
// The media could have been uploaded to this server or fetched from another server and cached here. // The media could have been uploaded to this server or fetched from another server and cached here.
// Returns sql.ErrNoRows if there is no metadata associated with this media. // Returns nil metadata if there is no metadata associated with this media.
func (d *Database) GetMediaMetadata(mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) { func (d *Database) GetMediaMetadata(mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) {
return d.statements.selectMedia(mediaID, mediaOrigin) mediaMetadata, err := d.statements.selectMedia(mediaID, mediaOrigin)
if err != nil && err == sql.ErrNoRows {
return nil, nil
}
return mediaMetadata, err
} }

View File

@ -15,7 +15,6 @@
package writers package writers
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -39,13 +38,6 @@ const mediaIDCharacters = "A-Za-z0-9_=-"
// Note: unfortunately regex.MustCompile() cannot be assigned to a const // Note: unfortunately regex.MustCompile() cannot be assigned to a const
var mediaIDRegex = regexp.MustCompile("[" + mediaIDCharacters + "]+") var mediaIDRegex = regexp.MustCompile("[" + mediaIDCharacters + "]+")
// Error types used by downloadRequest.getMediaMetadata
// FIXME: make into types
var (
errDBQuery = fmt.Errorf("error querying database for media")
errDBNotFound = fmt.Errorf("media not found")
)
// downloadRequest metadata included in or derivable from an download request // downloadRequest metadata included in or derivable from an download request
// https://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-media-r0-download-servername-mediaid // https://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-media-r0-download-servername-mediaid
type downloadRequest struct { type downloadRequest struct {
@ -121,15 +113,17 @@ func (r *downloadRequest) Validate() *util.JSONResponse {
func (r *downloadRequest) doDownload(w http.ResponseWriter, cfg *config.MediaAPI, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) *util.JSONResponse { func (r *downloadRequest) doDownload(w http.ResponseWriter, cfg *config.MediaAPI, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests) *util.JSONResponse {
// check if we have a record of the media in our database // check if we have a record of the media in our database
mediaMetadata, err := r.getMediaMetadata(db) mediaMetadata, err := db.GetMediaMetadata(r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
if err == nil { if err != nil {
// If we have a record, we can respond from the local file r.Logger.WithError(err).Error("Error querying the database.")
r.MediaMetadata = mediaMetadata return &util.JSONResponse{
return r.respondFromLocalFile(w, cfg.AbsBasePath) Code: 500,
} else if err == errDBNotFound { JSON: jsonerror.InternalServerError(),
}
}
if mediaMetadata == nil {
if r.MediaMetadata.Origin == cfg.ServerName { if r.MediaMetadata.Origin == cfg.ServerName {
// If we do not have a record and the origin is local, the file is not found // If we do not have a record and the origin is local, the file is not found
r.Logger.WithError(err).Warn("Failed to look up file in database")
return &util.JSONResponse{ return &util.JSONResponse{
Code: 404, Code: 404,
JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)), JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)),
@ -141,35 +135,9 @@ func (r *downloadRequest) doDownload(w http.ResponseWriter, cfg *config.MediaAPI
JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)), JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)),
} }
} }
// Another error from the database // If we have a record, we can respond from the local file
r.Logger.WithError(err).WithFields(log.Fields{ r.MediaMetadata = mediaMetadata
"MediaID": r.MediaMetadata.MediaID, return r.respondFromLocalFile(w, cfg.AbsBasePath)
"Origin": r.MediaMetadata.Origin,
}).Error("Error querying the database.")
return &util.JSONResponse{
Code: 500,
JSON: jsonerror.InternalServerError(),
}
}
// getMediaMetadata queries the database for media metadata
func (r *downloadRequest) getMediaMetadata(db *storage.Database) (*types.MediaMetadata, error) {
mediaMetadata, err := db.GetMediaMetadata(r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
if err != nil {
if err == sql.ErrNoRows {
r.Logger.WithFields(log.Fields{
"Origin": r.MediaMetadata.Origin,
"MediaID": r.MediaMetadata.MediaID,
}).Info("Media not found in database.")
return nil, errDBNotFound
}
r.Logger.WithError(err).WithFields(log.Fields{
"Origin": r.MediaMetadata.Origin,
"MediaID": r.MediaMetadata.MediaID,
}).Error("Error querying database for media.")
return nil, errDBQuery
}
return mediaMetadata, nil
} }
// respondFromLocalFile reads a file from local storage and writes it to the http.ResponseWriter // respondFromLocalFile reads a file from local storage and writes it to the http.ResponseWriter