Make media repo use error rather than jsonErrorResponse (#235)

* Make media repo use error rather than jsonErrorResponse

* Update comments

* gb vendor github.com/pkg/errors

* Fixup error formats
This commit is contained in:
Erik Johnston 2017-09-19 11:40:21 +01:00 committed by GitHub
parent 856bc5b52e
commit 08b9940dde
14 changed files with 2041 additions and 141 deletions

11
.gitignore vendored
View File

@ -12,11 +12,12 @@ kafka.tgz
*.so *.so
# Folders # Folders
kafka /kafka
bin /bin
pkg /pkg
_obj /_obj
_test /_test
/vendor/bin
# Architecture specific extensions/prefixes # Architecture specific extensions/prefixes
*.[568vq] *.[568vq]

View File

@ -19,7 +19,6 @@ import (
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
) )
// FileSizeBytes is a file size in bytes // FileSizeBytes is a file size in bytes
@ -67,8 +66,8 @@ type RemoteRequestResult struct {
Cond *sync.Cond Cond *sync.Cond
// MediaMetadata of the requested file to avoid querying the database for every waiting routine // MediaMetadata of the requested file to avoid querying the database for every waiting routine
MediaMetadata *MediaMetadata MediaMetadata *MediaMetadata
// An error in util.JSONResponse form. nil in case of no error. // An error, nil in case of no error.
ErrorResponse *util.JSONResponse Error error
} }
// ActiveRemoteRequests is a lockable map of media URIs requested from remote homeservers // ActiveRemoteRequests is a lockable map of media URIs requested from remote homeservers

View File

@ -36,6 +36,7 @@ import (
"github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/pkg/errors"
) )
const mediaIDCharacters = "A-Za-z0-9_=-" const mediaIDCharacters = "A-Za-z0-9_=-"
@ -59,8 +60,18 @@ type downloadRequest struct {
// If they are present in the cache, they are served directly. // If they are present in the cache, they are served directly.
// If they are not present in the cache, they are obtained from the remote server and // If they are not present in the cache, they are obtained from the remote server and
// simultaneously served back to the client and written into the cache. // simultaneously served back to the client and written into the cache.
func Download(w http.ResponseWriter, req *http.Request, origin gomatrixserverlib.ServerName, mediaID types.MediaID, cfg *config.Dendrite, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration, isThumbnailRequest bool) { func Download(
r := &downloadRequest{ w http.ResponseWriter,
req *http.Request,
origin gomatrixserverlib.ServerName,
mediaID types.MediaID,
cfg *config.Dendrite,
db *storage.Database,
activeRemoteRequests *types.ActiveRemoteRequests,
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
isThumbnailRequest bool,
) {
dReq := &downloadRequest{
MediaMetadata: &types.MediaMetadata{ MediaMetadata: &types.MediaMetadata{
MediaID: mediaID, MediaID: mediaID,
Origin: origin, Origin: origin,
@ -72,7 +83,7 @@ func Download(w http.ResponseWriter, req *http.Request, origin gomatrixserverlib
}), }),
} }
if r.IsThumbnailRequest { if dReq.IsThumbnailRequest {
width, err := strconv.Atoi(req.FormValue("width")) width, err := strconv.Atoi(req.FormValue("width"))
if err != nil { if err != nil {
width = -1 width = -1
@ -81,36 +92,47 @@ func Download(w http.ResponseWriter, req *http.Request, origin gomatrixserverlib
if err != nil { if err != nil {
height = -1 height = -1
} }
r.ThumbnailSize = types.ThumbnailSize{ dReq.ThumbnailSize = types.ThumbnailSize{
Width: width, Width: width,
Height: height, Height: height,
ResizeMethod: strings.ToLower(req.FormValue("method")), ResizeMethod: strings.ToLower(req.FormValue("method")),
} }
r.Logger.WithFields(log.Fields{ dReq.Logger.WithFields(log.Fields{
"RequestedWidth": r.ThumbnailSize.Width, "RequestedWidth": dReq.ThumbnailSize.Width,
"RequestedHeight": r.ThumbnailSize.Height, "RequestedHeight": dReq.ThumbnailSize.Height,
"RequestedResizeMethod": r.ThumbnailSize.ResizeMethod, "RequestedResizeMethod": dReq.ThumbnailSize.ResizeMethod,
}) })
} }
// request validation // request validation
if req.Method != "GET" { if req.Method != "GET" {
r.jsonErrorResponse(w, util.JSONResponse{ dReq.jsonErrorResponse(w, util.JSONResponse{
Code: 405, Code: 405,
JSON: jsonerror.Unknown("request method must be GET"), JSON: jsonerror.Unknown("request method must be GET"),
}) })
return return
} }
if resErr := r.Validate(); resErr != nil { if resErr := dReq.Validate(); resErr != nil {
r.jsonErrorResponse(w, *resErr) dReq.jsonErrorResponse(w, *resErr)
return return
} }
if resErr := r.doDownload(w, cfg, db, activeRemoteRequests, activeThumbnailGeneration); resErr != nil { metadata, err := dReq.doDownload(w, cfg, db, activeRemoteRequests, activeThumbnailGeneration)
r.jsonErrorResponse(w, *resErr) if err != nil {
// TODO: Handle the fact we might have started writing the response
dReq.jsonErrorResponse(w, util.ErrorResponse(err))
return return
} }
if metadata == nil {
dReq.jsonErrorResponse(w, util.JSONResponse{
Code: 404,
JSON: jsonerror.NotFound("File not found"),
})
return
}
} }
func (r *downloadRequest) jsonErrorResponse(w http.ResponseWriter, res util.JSONResponse) { func (r *downloadRequest) jsonErrorResponse(w http.ResponseWriter, res util.JSONResponse) {
@ -167,26 +189,27 @@ func (r *downloadRequest) Validate() *util.JSONResponse {
return nil return nil
} }
func (r *downloadRequest) doDownload(w http.ResponseWriter, cfg *config.Dendrite, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration) *util.JSONResponse { func (r *downloadRequest) doDownload(
w http.ResponseWriter,
cfg *config.Dendrite,
db *storage.Database,
activeRemoteRequests *types.ActiveRemoteRequests,
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
) (*types.MediaMetadata, error) {
// check if we have a record of the media in our database // check if we have a record of the media in our database
mediaMetadata, err := db.GetMediaMetadata(r.MediaMetadata.MediaID, r.MediaMetadata.Origin) mediaMetadata, err := db.GetMediaMetadata(r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
if err != nil { if err != nil {
r.Logger.WithError(err).Error("Error querying the database.") return nil, errors.Wrap(err, "error querying the database")
resErr := jsonerror.InternalServerError()
return &resErr
} }
if mediaMetadata == nil { if mediaMetadata == nil {
if r.MediaMetadata.Origin == cfg.Matrix.ServerName { if r.MediaMetadata.Origin == cfg.Matrix.ServerName {
// If we do not have a record and the origin is local, the file is not found // If we do not have a record and the origin is local, the file is not found
return &util.JSONResponse{ return nil, nil
Code: 404,
JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)),
}
} }
// If we do not have a record and the origin is remote, we need to fetch it and respond with that file // If we do not have a record and the origin is remote, we need to fetch it and respond with that file
resErr := r.getRemoteFile(cfg, db, activeRemoteRequests, activeThumbnailGeneration) resErr := r.getRemoteFile(cfg, db, activeRemoteRequests, activeThumbnailGeneration)
if resErr != nil { if resErr != nil {
return resErr return nil, resErr
} }
} else { } else {
// If we have a record, we can respond from the local file // If we have a record, we can respond from the local file
@ -196,26 +219,28 @@ func (r *downloadRequest) doDownload(w http.ResponseWriter, cfg *config.Dendrite
} }
// respondFromLocalFile reads a file from local storage and writes it to the http.ResponseWriter // respondFromLocalFile reads a file from local storage and writes it to the http.ResponseWriter
// Returns a util.JSONResponse error in case of error // If no file was found then returns nil, nil
func (r *downloadRequest) respondFromLocalFile(w http.ResponseWriter, absBasePath config.Path, activeThumbnailGeneration *types.ActiveThumbnailGeneration, maxThumbnailGenerators int, db *storage.Database, dynamicThumbnails bool, thumbnailSizes []config.ThumbnailSize) *util.JSONResponse { func (r *downloadRequest) respondFromLocalFile(
w http.ResponseWriter,
absBasePath config.Path,
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
maxThumbnailGenerators int,
db *storage.Database,
dynamicThumbnails bool,
thumbnailSizes []config.ThumbnailSize,
) (*types.MediaMetadata, error) {
filePath, err := fileutils.GetPathFromBase64Hash(r.MediaMetadata.Base64Hash, absBasePath) filePath, err := fileutils.GetPathFromBase64Hash(r.MediaMetadata.Base64Hash, absBasePath)
if err != nil { if err != nil {
r.Logger.WithError(err).Error("Failed to get file path from metadata") return nil, errors.Wrap(err, "failed to get file path from metadata")
resErr := jsonerror.InternalServerError()
return &resErr
} }
file, err := os.Open(filePath) file, err := os.Open(filePath)
defer file.Close() defer file.Close()
if err != nil { if err != nil {
r.Logger.WithError(err).Error("Failed to open file") return nil, errors.Wrap(err, "failed to open file")
resErr := jsonerror.InternalServerError()
return &resErr
} }
stat, err := file.Stat() stat, err := file.Stat()
if err != nil { if err != nil {
r.Logger.WithError(err).Error("Failed to stat file") return nil, errors.Wrap(err, "failed to stat file")
resErr := jsonerror.InternalServerError()
return &resErr
} }
if r.MediaMetadata.FileSizeBytes > 0 && int64(r.MediaMetadata.FileSizeBytes) != stat.Size() { if r.MediaMetadata.FileSizeBytes > 0 && int64(r.MediaMetadata.FileSizeBytes) != stat.Size() {
@ -223,8 +248,7 @@ func (r *downloadRequest) respondFromLocalFile(w http.ResponseWriter, absBasePat
"fileSizeDatabase": r.MediaMetadata.FileSizeBytes, "fileSizeDatabase": r.MediaMetadata.FileSizeBytes,
"fileSizeDisk": stat.Size(), "fileSizeDisk": stat.Size(),
}).Warn("File size in database and on-disk differ.") }).Warn("File size in database and on-disk differ.")
resErr := jsonerror.InternalServerError() return nil, errors.New("file size in database and on-disk differ")
return &resErr
} }
var responseFile *os.File var responseFile *os.File
@ -235,7 +259,7 @@ func (r *downloadRequest) respondFromLocalFile(w http.ResponseWriter, absBasePat
defer thumbFile.Close() defer thumbFile.Close()
} }
if resErr != nil { if resErr != nil {
return resErr return nil, resErr
} }
if thumbFile == nil { if thumbFile == nil {
r.Logger.WithFields(log.Fields{ r.Logger.WithFields(log.Fields{
@ -271,37 +295,38 @@ func (r *downloadRequest) respondFromLocalFile(w http.ResponseWriter, absBasePat
" object-src 'self';" " object-src 'self';"
w.Header().Set("Content-Security-Policy", contentSecurityPolicy) w.Header().Set("Content-Security-Policy", contentSecurityPolicy)
if bytesResponded, err := io.Copy(w, responseFile); err != nil { if _, err := io.Copy(w, responseFile); err != nil {
r.Logger.WithError(err).Warn("Failed to copy from cache") return nil, errors.Wrap(err, "failed to copy from cache")
if bytesResponded == 0 {
resErr := jsonerror.InternalServerError()
return &resErr
}
// If we have written any data then we have already responded with 200 OK and all we can do is close the connection
return nil
} }
return nil return responseMetadata, nil
} }
// Note: Thumbnail generation may be ongoing asynchronously. // Note: Thumbnail generation may be ongoing asynchronously.
func (r *downloadRequest) getThumbnailFile(filePath types.Path, activeThumbnailGeneration *types.ActiveThumbnailGeneration, maxThumbnailGenerators int, db *storage.Database, dynamicThumbnails bool, thumbnailSizes []config.ThumbnailSize) (*os.File, *types.ThumbnailMetadata, *util.JSONResponse) { // If no thumbnail was found then returns nil, nil, nil
func (r *downloadRequest) getThumbnailFile(
filePath types.Path,
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
maxThumbnailGenerators int,
db *storage.Database,
dynamicThumbnails bool,
thumbnailSizes []config.ThumbnailSize,
) (*os.File, *types.ThumbnailMetadata, error) {
var thumbnail *types.ThumbnailMetadata var thumbnail *types.ThumbnailMetadata
var resErr *util.JSONResponse var err error
if dynamicThumbnails { if dynamicThumbnails {
thumbnail, resErr = r.generateThumbnail(filePath, r.ThumbnailSize, activeThumbnailGeneration, maxThumbnailGenerators, db) thumbnail, err = r.generateThumbnail(filePath, r.ThumbnailSize, activeThumbnailGeneration, maxThumbnailGenerators, db)
if resErr != nil { if err != nil {
return nil, nil, resErr return nil, nil, err
} }
} }
// If dynamicThumbnails is true but there are too many thumbnails being actively generated, we can fall back // If dynamicThumbnails is true but there are too many thumbnails being actively generated, we can fall back
// to trying to use a pre-generated thumbnail // to trying to use a pre-generated thumbnail
if thumbnail == nil { if thumbnail == nil {
thumbnails, err := db.GetThumbnails(r.MediaMetadata.MediaID, r.MediaMetadata.Origin) var thumbnails []*types.ThumbnailMetadata
thumbnails, err = db.GetThumbnails(r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
if err != nil { if err != nil {
r.Logger.WithError(err).Error("Error looking up thumbnails") return nil, nil, errors.Wrap(err, "error looking up thumbnails")
resErr := jsonerror.InternalServerError()
return nil, nil, &resErr
} }
// If we get a thumbnailSize, a pre-generated thumbnail would be best but it is not yet generated. // If we get a thumbnailSize, a pre-generated thumbnail would be best but it is not yet generated.
@ -316,9 +341,9 @@ func (r *downloadRequest) getThumbnailFile(filePath types.Path, activeThumbnailG
"Height": thumbnailSize.Height, "Height": thumbnailSize.Height,
"ResizeMethod": thumbnailSize.ResizeMethod, "ResizeMethod": thumbnailSize.ResizeMethod,
}).Info("Pre-generating thumbnail for immediate response.") }).Info("Pre-generating thumbnail for immediate response.")
thumbnail, resErr = r.generateThumbnail(filePath, *thumbnailSize, activeThumbnailGeneration, maxThumbnailGenerators, db) thumbnail, err = r.generateThumbnail(filePath, *thumbnailSize, activeThumbnailGeneration, maxThumbnailGenerators, db)
if resErr != nil { if err != nil {
return nil, nil, resErr return nil, nil, err
} }
} }
} }
@ -335,35 +360,36 @@ func (r *downloadRequest) getThumbnailFile(filePath types.Path, activeThumbnailG
thumbPath := string(thumbnailer.GetThumbnailPath(types.Path(filePath), thumbnail.ThumbnailSize)) thumbPath := string(thumbnailer.GetThumbnailPath(types.Path(filePath), thumbnail.ThumbnailSize))
thumbFile, err := os.Open(string(thumbPath)) thumbFile, err := os.Open(string(thumbPath))
if err != nil { if err != nil {
r.Logger.WithError(err).Warn("Failed to open file") thumbFile.Close()
resErr := jsonerror.InternalServerError() return nil, nil, errors.Wrap(err, "failed to open file")
return thumbFile, nil, &resErr
} }
thumbStat, err := thumbFile.Stat() thumbStat, err := thumbFile.Stat()
if err != nil { if err != nil {
r.Logger.WithError(err).Warn("Failed to stat file") thumbFile.Close()
resErr := jsonerror.InternalServerError() return nil, nil, errors.Wrap(err, "failed to stat file")
return thumbFile, nil, &resErr
} }
if types.FileSizeBytes(thumbStat.Size()) != thumbnail.MediaMetadata.FileSizeBytes { if types.FileSizeBytes(thumbStat.Size()) != thumbnail.MediaMetadata.FileSizeBytes {
r.Logger.WithError(err).Warn("Thumbnail file sizes on disk and in database differ") thumbFile.Close()
resErr := jsonerror.InternalServerError() return nil, nil, errors.New("thumbnail file sizes on disk and in database differ")
return thumbFile, nil, &resErr
} }
return thumbFile, thumbnail, nil return thumbFile, thumbnail, nil
} }
func (r *downloadRequest) generateThumbnail(filePath types.Path, thumbnailSize types.ThumbnailSize, activeThumbnailGeneration *types.ActiveThumbnailGeneration, maxThumbnailGenerators int, db *storage.Database) (*types.ThumbnailMetadata, *util.JSONResponse) { func (r *downloadRequest) generateThumbnail(
logger := r.Logger.WithFields(log.Fields{ filePath types.Path,
thumbnailSize types.ThumbnailSize,
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
maxThumbnailGenerators int,
db *storage.Database,
) (*types.ThumbnailMetadata, error) {
r.Logger.WithFields(log.Fields{
"Width": thumbnailSize.Width, "Width": thumbnailSize.Width,
"Height": thumbnailSize.Height, "Height": thumbnailSize.Height,
"ResizeMethod": thumbnailSize.ResizeMethod, "ResizeMethod": thumbnailSize.ResizeMethod,
}) })
busy, err := thumbnailer.GenerateThumbnail(filePath, thumbnailSize, r.MediaMetadata, activeThumbnailGeneration, maxThumbnailGenerators, db, r.Logger) busy, err := thumbnailer.GenerateThumbnail(filePath, thumbnailSize, r.MediaMetadata, activeThumbnailGeneration, maxThumbnailGenerators, db, r.Logger)
if err != nil { if err != nil {
logger.WithError(err).Error("Error creating thumbnail") return nil, errors.Wrap(err, "error creating thumbnail")
resErr := jsonerror.InternalServerError()
return nil, &resErr
} }
if busy { if busy {
return nil, nil return nil, nil
@ -371,9 +397,7 @@ func (r *downloadRequest) generateThumbnail(filePath types.Path, thumbnailSize t
var thumbnail *types.ThumbnailMetadata var thumbnail *types.ThumbnailMetadata
thumbnail, err = db.GetThumbnail(r.MediaMetadata.MediaID, r.MediaMetadata.Origin, thumbnailSize.Width, thumbnailSize.Height, thumbnailSize.ResizeMethod) thumbnail, err = db.GetThumbnail(r.MediaMetadata.MediaID, r.MediaMetadata.Origin, thumbnailSize.Width, thumbnailSize.Height, thumbnailSize.ResizeMethod)
if err != nil { if err != nil {
logger.WithError(err).Error("Error looking up thumbnails") return nil, errors.Wrap(err, "error looking up thumbnail")
resErr := jsonerror.InternalServerError()
return nil, &resErr
} }
return thumbnail, nil return thumbnail, nil
} }
@ -382,8 +406,12 @@ func (r *downloadRequest) generateThumbnail(filePath types.Path, thumbnailSize t
// A hash map of active remote requests to a struct containing a sync.Cond is used to only download remote files once, // A hash map of active remote requests to a struct containing a sync.Cond is used to only download remote files once,
// regardless of how many download requests are received. // regardless of how many download requests are received.
// Note: The named errorResponse return variable is used in a deferred broadcast of the metadata and error response to waiting goroutines. // Note: The named errorResponse return variable is used in a deferred broadcast of the metadata and error response to waiting goroutines.
// Returns a util.JSONResponse error in case of error func (r *downloadRequest) getRemoteFile(
func (r *downloadRequest) getRemoteFile(cfg *config.Dendrite, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration) (errorResponse *util.JSONResponse) { cfg *config.Dendrite,
db *storage.Database,
activeRemoteRequests *types.ActiveRemoteRequests,
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
) (errorResponse error) {
// Note: getMediaMetadataFromActiveRequest uses mutexes and conditions from activeRemoteRequests // Note: getMediaMetadataFromActiveRequest uses mutexes and conditions from activeRemoteRequests
mediaMetadata, resErr := r.getMediaMetadataFromActiveRequest(activeRemoteRequests) mediaMetadata, resErr := r.getMediaMetadataFromActiveRequest(activeRemoteRequests)
if resErr != nil { if resErr != nil {
@ -397,8 +425,7 @@ func (r *downloadRequest) getRemoteFile(cfg *config.Dendrite, db *storage.Databa
defer func() { defer func() {
// Note: errorResponse is the named return variable so we wrap this in a closure to re-evaluate the arguments at defer-time // Note: errorResponse is the named return variable so we wrap this in a closure to re-evaluate the arguments at defer-time
if err := recover(); err != nil { if err := recover(); err != nil {
resErr := jsonerror.InternalServerError() r.broadcastMediaMetadata(activeRemoteRequests, errors.New("paniced"))
r.broadcastMediaMetadata(activeRemoteRequests, &resErr)
panic(err) panic(err)
} }
r.broadcastMediaMetadata(activeRemoteRequests, errorResponse) r.broadcastMediaMetadata(activeRemoteRequests, errorResponse)
@ -407,26 +434,24 @@ func (r *downloadRequest) getRemoteFile(cfg *config.Dendrite, db *storage.Databa
// check if we have a record of the media in our database // check if we have a record of the media in our database
mediaMetadata, err := db.GetMediaMetadata(r.MediaMetadata.MediaID, r.MediaMetadata.Origin) mediaMetadata, err := db.GetMediaMetadata(r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
if err != nil { if err != nil {
r.Logger.WithError(err).Error("Error querying the database.") return errors.Wrap(err, "error querying the database.")
resErr := jsonerror.InternalServerError()
return &resErr
} }
if mediaMetadata == nil { if mediaMetadata == nil {
// If we do not have a record, we need to fetch the remote file first and then respond from the local file // If we do not have a record, we need to fetch the remote file first and then respond from the local file
resErr := r.fetchRemoteFileAndStoreMetadata(cfg.Media.AbsBasePath, *cfg.Media.MaxFileSizeBytes, db, cfg.Media.ThumbnailSizes, activeThumbnailGeneration, cfg.Media.MaxThumbnailGenerators) err := r.fetchRemoteFileAndStoreMetadata(cfg.Media.AbsBasePath, *cfg.Media.MaxFileSizeBytes, db, cfg.Media.ThumbnailSizes, activeThumbnailGeneration, cfg.Media.MaxThumbnailGenerators)
if resErr != nil { if err != nil {
return resErr return errors.Wrap(err, "error querying the database.")
} }
} else { } else {
// If we have a record, we can respond from the local file // If we have a record, we can respond from the local file
r.MediaMetadata = mediaMetadata r.MediaMetadata = mediaMetadata
} }
} }
return return nil
} }
func (r *downloadRequest) getMediaMetadataFromActiveRequest(activeRemoteRequests *types.ActiveRemoteRequests) (*types.MediaMetadata, *util.JSONResponse) { func (r *downloadRequest) getMediaMetadataFromActiveRequest(activeRemoteRequests *types.ActiveRemoteRequests) (*types.MediaMetadata, error) {
// Check if there is an active remote request for the file // Check if there is an active remote request for the file
mxcURL := "mxc://" + string(r.MediaMetadata.Origin) + "/" + string(r.MediaMetadata.MediaID) mxcURL := "mxc://" + string(r.MediaMetadata.Origin) + "/" + string(r.MediaMetadata.MediaID)
@ -438,15 +463,12 @@ func (r *downloadRequest) getMediaMetadataFromActiveRequest(activeRemoteRequests
// NOTE: Wait unlocks and locks again internally. There is still a deferred Unlock() that will unlock this. // NOTE: Wait unlocks and locks again internally. There is still a deferred Unlock() that will unlock this.
activeRemoteRequestResult.Cond.Wait() activeRemoteRequestResult.Cond.Wait()
if activeRemoteRequestResult.ErrorResponse != nil { if activeRemoteRequestResult.Error != nil {
return nil, activeRemoteRequestResult.ErrorResponse return nil, activeRemoteRequestResult.Error
} }
if activeRemoteRequestResult.MediaMetadata == nil { if activeRemoteRequestResult.MediaMetadata == nil {
return nil, &util.JSONResponse{ return nil, nil
Code: 404,
JSON: jsonerror.NotFound("File not found."),
}
} }
return activeRemoteRequestResult.MediaMetadata, nil return activeRemoteRequestResult.MediaMetadata, nil
@ -462,24 +484,31 @@ func (r *downloadRequest) getMediaMetadataFromActiveRequest(activeRemoteRequests
// broadcastMediaMetadata broadcasts the media metadata and error response to waiting goroutines // broadcastMediaMetadata broadcasts the media metadata and error response to waiting goroutines
// Only the owner of the activeRemoteRequestResult for this origin and media ID should call this function. // Only the owner of the activeRemoteRequestResult for this origin and media ID should call this function.
func (r *downloadRequest) broadcastMediaMetadata(activeRemoteRequests *types.ActiveRemoteRequests, errorResponse *util.JSONResponse) { func (r *downloadRequest) broadcastMediaMetadata(activeRemoteRequests *types.ActiveRemoteRequests, err error) {
activeRemoteRequests.Lock() activeRemoteRequests.Lock()
defer activeRemoteRequests.Unlock() defer activeRemoteRequests.Unlock()
mxcURL := "mxc://" + string(r.MediaMetadata.Origin) + "/" + string(r.MediaMetadata.MediaID) mxcURL := "mxc://" + string(r.MediaMetadata.Origin) + "/" + string(r.MediaMetadata.MediaID)
if activeRemoteRequestResult, ok := activeRemoteRequests.MXCToResult[mxcURL]; ok { if activeRemoteRequestResult, ok := activeRemoteRequests.MXCToResult[mxcURL]; ok {
r.Logger.Info("Signalling other goroutines waiting for this goroutine to fetch the file.") r.Logger.Info("Signalling other goroutines waiting for this goroutine to fetch the file.")
activeRemoteRequestResult.MediaMetadata = r.MediaMetadata activeRemoteRequestResult.MediaMetadata = r.MediaMetadata
activeRemoteRequestResult.ErrorResponse = errorResponse activeRemoteRequestResult.Error = err
activeRemoteRequestResult.Cond.Broadcast() activeRemoteRequestResult.Cond.Broadcast()
} }
delete(activeRemoteRequests.MXCToResult, mxcURL) delete(activeRemoteRequests.MXCToResult, mxcURL)
} }
// fetchRemoteFileAndStoreMetadata fetches the file from the remote server and stores its metadata in the database // fetchRemoteFileAndStoreMetadata fetches the file from the remote server and stores its metadata in the database
func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(absBasePath config.Path, maxFileSizeBytes config.FileSizeBytes, db *storage.Database, thumbnailSizes []config.ThumbnailSize, activeThumbnailGeneration *types.ActiveThumbnailGeneration, maxThumbnailGenerators int) *util.JSONResponse { func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
finalPath, duplicate, resErr := r.fetchRemoteFile(absBasePath, maxFileSizeBytes) absBasePath config.Path,
if resErr != nil { maxFileSizeBytes config.FileSizeBytes,
return resErr db *storage.Database,
thumbnailSizes []config.ThumbnailSize,
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
maxThumbnailGenerators int,
) error {
finalPath, duplicate, err := r.fetchRemoteFile(absBasePath, maxFileSizeBytes)
if err != nil {
return err
} }
r.Logger.WithFields(log.Fields{ r.Logger.WithFields(log.Fields{
@ -500,8 +529,7 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(absBasePath config.Pat
} }
// NOTE: It should really not be possible to fail the uniqueness test here so // NOTE: It should really not be possible to fail the uniqueness test here so
// there is no need to handle that separately // there is no need to handle that separately
resErr := jsonerror.InternalServerError() return errors.New("failed to store file metadata in DB")
return &resErr
} }
go func() { go func() {
@ -524,13 +552,17 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(absBasePath config.Pat
return nil return nil
} }
func (r *downloadRequest) fetchRemoteFile(absBasePath config.Path, maxFileSizeBytes config.FileSizeBytes) (types.Path, bool, *util.JSONResponse) { func (r *downloadRequest) fetchRemoteFile(absBasePath config.Path, maxFileSizeBytes config.FileSizeBytes) (types.Path, bool, error) {
r.Logger.Info("Fetching remote file") r.Logger.Info("Fetching remote file")
// create request for remote file // create request for remote file
resp, resErr := r.createRemoteRequest() resp, err := r.createRemoteRequest()
if resErr != nil { if err != nil {
return "", false, resErr return "", false, err
}
if resp == nil {
// Remote file not found
return "", false, nil
} }
defer resp.Body.Close() defer resp.Body.Close()
@ -538,16 +570,11 @@ func (r *downloadRequest) fetchRemoteFile(absBasePath config.Path, maxFileSizeBy
contentLength, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) contentLength, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
if err != nil { if err != nil {
r.Logger.WithError(err).Warn("Failed to parse content length") r.Logger.WithError(err).Warn("Failed to parse content length")
return "", false, &util.JSONResponse{ return "", false, errors.Wrap(err, "invalid response from remote server")
Code: 502,
JSON: jsonerror.Unknown("Invalid response from remote server"),
}
} }
if contentLength > int64(maxFileSizeBytes) { if contentLength > int64(maxFileSizeBytes) {
return "", false, &util.JSONResponse{ // TODO: Bubble up this as a 413
Code: 413, return "", false, fmt.Errorf("remote file is too large (%v > %v bytes)", contentLength, maxFileSizeBytes)
JSON: jsonerror.Unknown(fmt.Sprintf("Remote file is too large (%v > %v bytes)", contentLength, maxFileSizeBytes)),
}
} }
r.MediaMetadata.FileSizeBytes = types.FileSizeBytes(contentLength) r.MediaMetadata.FileSizeBytes = types.FileSizeBytes(contentLength)
r.MediaMetadata.ContentType = types.ContentType(resp.Header.Get("Content-Type")) r.MediaMetadata.ContentType = types.ContentType(resp.Header.Get("Content-Type"))
@ -568,10 +595,7 @@ func (r *downloadRequest) fetchRemoteFile(absBasePath config.Path, maxFileSizeBy
"MaxFileSizeBytes": maxFileSizeBytes, "MaxFileSizeBytes": maxFileSizeBytes,
}).Warn("Error while downloading file from remote server") }).Warn("Error while downloading file from remote server")
fileutils.RemoveDir(tmpDir, r.Logger) fileutils.RemoveDir(tmpDir, r.Logger)
return "", false, &util.JSONResponse{ return "", false, errors.New("file could not be downloaded from remote server")
Code: 502,
JSON: jsonerror.Unknown("File could not be downloaded from remote server"),
}
} }
r.Logger.Info("Remote file transferred") r.Logger.Info("Remote file transferred")
@ -585,9 +609,7 @@ func (r *downloadRequest) fetchRemoteFile(absBasePath config.Path, maxFileSizeBy
// The database is the source of truth so we need to have moved the file first // The database is the source of truth so we need to have moved the file first
finalPath, duplicate, err := fileutils.MoveFileWithHashCheck(tmpDir, r.MediaMetadata, absBasePath, r.Logger) finalPath, duplicate, err := fileutils.MoveFileWithHashCheck(tmpDir, r.MediaMetadata, absBasePath, r.Logger)
if err != nil { if err != nil {
r.Logger.WithError(err).Error("Failed to move file.") return "", false, errors.Wrap(err, "failed to move file")
resErr := jsonerror.InternalServerError()
return "", false, &resErr
} }
if duplicate { if duplicate {
r.Logger.WithField("dst", finalPath).Info("File was stored previously - discarding duplicate") r.Logger.WithField("dst", finalPath).Info("File was stored previously - discarding duplicate")
@ -597,32 +619,22 @@ func (r *downloadRequest) fetchRemoteFile(absBasePath config.Path, maxFileSizeBy
return types.Path(finalPath), duplicate, nil return types.Path(finalPath), duplicate, nil
} }
func (r *downloadRequest) createRemoteRequest() (*http.Response, *util.JSONResponse) { func (r *downloadRequest) createRemoteRequest() (*http.Response, error) {
matrixClient := gomatrixserverlib.NewClient() matrixClient := gomatrixserverlib.NewClient()
resp, err := matrixClient.CreateMediaDownloadRequest(r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID)) resp, err := matrixClient.CreateMediaDownloadRequest(r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID))
if err != nil { if err != nil {
r.Logger.WithError(err).Error("Failed to create download request") return nil, fmt.Errorf("file with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
return nil, &util.JSONResponse{
Code: 502,
JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)),
}
} }
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
if resp.StatusCode == 404 { if resp.StatusCode == 404 {
return nil, &util.JSONResponse{ return nil, nil
Code: 404,
JSON: jsonerror.NotFound(fmt.Sprintf("File with media ID %q does not exist", r.MediaMetadata.MediaID)),
}
} }
r.Logger.WithFields(log.Fields{ r.Logger.WithFields(log.Fields{
"StatusCode": resp.StatusCode, "StatusCode": resp.StatusCode,
}).Warn("Received error response") }).Warn("Received error response")
return nil, &util.JSONResponse{ return nil, fmt.Errorf("file with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
Code: 502,
JSON: jsonerror.Unknown(fmt.Sprintf("File with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)),
}
} }
return resp, nil return resp, nil

6
vendor/manifest vendored
View File

@ -170,6 +170,12 @@
"branch": "master", "branch": "master",
"path": "/xxHash32" "path": "/xxHash32"
}, },
{
"importpath": "github.com/pkg/errors",
"repository": "https://github.com/pkg/errors",
"revision": "c605e284fe17294bda444b34710735b29d1a9d90",
"branch": "master"
},
{ {
"importpath": "github.com/pmezard/go-difflib/difflib", "importpath": "github.com/pmezard/go-difflib/difflib",
"repository": "https://github.com/pmezard/go-difflib", "repository": "https://github.com/pmezard/go-difflib",

View File

@ -0,0 +1,23 @@
Copyright (c) 2015, Dave Cheney <dave@cheney.net>
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -0,0 +1,52 @@
# errors [![Travis-CI](https://travis-ci.org/pkg/errors.svg)](https://travis-ci.org/pkg/errors) [![AppVeyor](https://ci.appveyor.com/api/projects/status/b98mptawhudj53ep/branch/master?svg=true)](https://ci.appveyor.com/project/davecheney/errors/branch/master) [![GoDoc](https://godoc.org/github.com/pkg/errors?status.svg)](http://godoc.org/github.com/pkg/errors) [![Report card](https://goreportcard.com/badge/github.com/pkg/errors)](https://goreportcard.com/report/github.com/pkg/errors)
Package errors provides simple error handling primitives.
`go get github.com/pkg/errors`
The traditional error handling idiom in Go is roughly akin to
```go
if err != nil {
return err
}
```
which applied recursively up the call stack results in error reports without context or debugging information. The errors package allows programmers to add context to the failure path in their code in a way that does not destroy the original value of the error.
## Adding context to an error
The errors.Wrap function returns a new error that adds context to the original error. For example
```go
_, err := ioutil.ReadAll(r)
if err != nil {
return errors.Wrap(err, "read failed")
}
```
## Retrieving the cause of an error
Using `errors.Wrap` constructs a stack of errors, adding context to the preceding error. Depending on the nature of the error it may be necessary to reverse the operation of errors.Wrap to retrieve the original error for inspection. Any error value which implements this interface can be inspected by `errors.Cause`.
```go
type causer interface {
Cause() error
}
```
`errors.Cause` will recursively retrieve the topmost error which does not implement `causer`, which is assumed to be the original cause. For example:
```go
switch err := errors.Cause(err).(type) {
case *MyError:
// handle specifically
default:
// unknown error
}
```
[Read the package documentation for more information](https://godoc.org/github.com/pkg/errors).
## Contributing
We welcome pull requests, bug fixes and issue reports. With that said, the bar for adding new symbols to this package is intentionally set high.
Before proposing a change, please discuss your change by raising an issue.
## Licence
BSD-2-Clause

View File

@ -0,0 +1,32 @@
version: build-{build}.{branch}
clone_folder: C:\gopath\src\github.com\pkg\errors
shallow_clone: true # for startup speed
environment:
GOPATH: C:\gopath
platform:
- x64
# http://www.appveyor.com/docs/installed-software
install:
# some helpful output for debugging builds
- go version
- go env
# pre-installed MinGW at C:\MinGW is 32bit only
# but MSYS2 at C:\msys64 has mingw64
- set PATH=C:\msys64\mingw64\bin;%PATH%
- gcc --version
- g++ --version
build_script:
- go install -v ./...
test_script:
- set PATH=C:\gopath\bin;%PATH%
- go test -v ./...
#artifacts:
# - path: '%GOPATH%\bin\*.exe'
deploy: off

View File

@ -0,0 +1,63 @@
// +build go1.7
package errors
import (
"fmt"
"testing"
stderrors "errors"
)
func noErrors(at, depth int) error {
if at >= depth {
return stderrors.New("no error")
}
return noErrors(at+1, depth)
}
func yesErrors(at, depth int) error {
if at >= depth {
return New("ye error")
}
return yesErrors(at+1, depth)
}
// GlobalE is an exported global to store the result of benchmark results,
// preventing the compiler from optimising the benchmark functions away.
var GlobalE error
func BenchmarkErrors(b *testing.B) {
type run struct {
stack int
std bool
}
runs := []run{
{10, false},
{10, true},
{100, false},
{100, true},
{1000, false},
{1000, true},
}
for _, r := range runs {
part := "pkg/errors"
if r.std {
part = "errors"
}
name := fmt.Sprintf("%s-stack-%d", part, r.stack)
b.Run(name, func(b *testing.B) {
var err error
f := yesErrors
if r.std {
f = noErrors
}
b.ReportAllocs()
for i := 0; i < b.N; i++ {
err = f(0, r.stack)
}
b.StopTimer()
GlobalE = err
})
}
}

View File

@ -0,0 +1,269 @@
// Package errors provides simple error handling primitives.
//
// The traditional error handling idiom in Go is roughly akin to
//
// if err != nil {
// return err
// }
//
// which applied recursively up the call stack results in error reports
// without context or debugging information. The errors package allows
// programmers to add context to the failure path in their code in a way
// that does not destroy the original value of the error.
//
// Adding context to an error
//
// The errors.Wrap function returns a new error that adds context to the
// original error by recording a stack trace at the point Wrap is called,
// and the supplied message. For example
//
// _, err := ioutil.ReadAll(r)
// if err != nil {
// return errors.Wrap(err, "read failed")
// }
//
// If additional control is required the errors.WithStack and errors.WithMessage
// functions destructure errors.Wrap into its component operations of annotating
// an error with a stack trace and an a message, respectively.
//
// Retrieving the cause of an error
//
// Using errors.Wrap constructs a stack of errors, adding context to the
// preceding error. Depending on the nature of the error it may be necessary
// to reverse the operation of errors.Wrap to retrieve the original error
// for inspection. Any error value which implements this interface
//
// type causer interface {
// Cause() error
// }
//
// can be inspected by errors.Cause. errors.Cause will recursively retrieve
// the topmost error which does not implement causer, which is assumed to be
// the original cause. For example:
//
// switch err := errors.Cause(err).(type) {
// case *MyError:
// // handle specifically
// default:
// // unknown error
// }
//
// causer interface is not exported by this package, but is considered a part
// of stable public API.
//
// Formatted printing of errors
//
// All error values returned from this package implement fmt.Formatter and can
// be formatted by the fmt package. The following verbs are supported
//
// %s print the error. If the error has a Cause it will be
// printed recursively
// %v see %s
// %+v extended format. Each Frame of the error's StackTrace will
// be printed in detail.
//
// Retrieving the stack trace of an error or wrapper
//
// New, Errorf, Wrap, and Wrapf record a stack trace at the point they are
// invoked. This information can be retrieved with the following interface.
//
// type stackTracer interface {
// StackTrace() errors.StackTrace
// }
//
// Where errors.StackTrace is defined as
//
// type StackTrace []Frame
//
// The Frame type represents a call site in the stack trace. Frame supports
// the fmt.Formatter interface that can be used for printing information about
// the stack trace of this error. For example:
//
// if err, ok := err.(stackTracer); ok {
// for _, f := range err.StackTrace() {
// fmt.Printf("%+s:%d", f)
// }
// }
//
// stackTracer interface is not exported by this package, but is considered a part
// of stable public API.
//
// See the documentation for Frame.Format for more details.
package errors
import (
"fmt"
"io"
)
// New returns an error with the supplied message.
// New also records the stack trace at the point it was called.
func New(message string) error {
return &fundamental{
msg: message,
stack: callers(),
}
}
// Errorf formats according to a format specifier and returns the string
// as a value that satisfies error.
// Errorf also records the stack trace at the point it was called.
func Errorf(format string, args ...interface{}) error {
return &fundamental{
msg: fmt.Sprintf(format, args...),
stack: callers(),
}
}
// fundamental is an error that has a message and a stack, but no caller.
type fundamental struct {
msg string
*stack
}
func (f *fundamental) Error() string { return f.msg }
func (f *fundamental) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
if s.Flag('+') {
io.WriteString(s, f.msg)
f.stack.Format(s, verb)
return
}
fallthrough
case 's':
io.WriteString(s, f.msg)
case 'q':
fmt.Fprintf(s, "%q", f.msg)
}
}
// WithStack annotates err with a stack trace at the point WithStack was called.
// If err is nil, WithStack returns nil.
func WithStack(err error) error {
if err == nil {
return nil
}
return &withStack{
err,
callers(),
}
}
type withStack struct {
error
*stack
}
func (w *withStack) Cause() error { return w.error }
func (w *withStack) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
if s.Flag('+') {
fmt.Fprintf(s, "%+v", w.Cause())
w.stack.Format(s, verb)
return
}
fallthrough
case 's':
io.WriteString(s, w.Error())
case 'q':
fmt.Fprintf(s, "%q", w.Error())
}
}
// Wrap returns an error annotating err with a stack trace
// at the point Wrap is called, and the supplied message.
// If err is nil, Wrap returns nil.
func Wrap(err error, message string) error {
if err == nil {
return nil
}
err = &withMessage{
cause: err,
msg: message,
}
return &withStack{
err,
callers(),
}
}
// Wrapf returns an error annotating err with a stack trace
// at the point Wrapf is call, and the format specifier.
// If err is nil, Wrapf returns nil.
func Wrapf(err error, format string, args ...interface{}) error {
if err == nil {
return nil
}
err = &withMessage{
cause: err,
msg: fmt.Sprintf(format, args...),
}
return &withStack{
err,
callers(),
}
}
// WithMessage annotates err with a new message.
// If err is nil, WithMessage returns nil.
func WithMessage(err error, message string) error {
if err == nil {
return nil
}
return &withMessage{
cause: err,
msg: message,
}
}
type withMessage struct {
cause error
msg string
}
func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() }
func (w *withMessage) Cause() error { return w.cause }
func (w *withMessage) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
if s.Flag('+') {
fmt.Fprintf(s, "%+v\n", w.Cause())
io.WriteString(s, w.msg)
return
}
fallthrough
case 's', 'q':
io.WriteString(s, w.Error())
}
}
// Cause returns the underlying cause of the error, if possible.
// An error value has a cause if it implements the following
// interface:
//
// type causer interface {
// Cause() error
// }
//
// If the error does not implement Cause, the original error will
// be returned. If the error is nil, nil will be returned without further
// investigation.
func Cause(err error) error {
type causer interface {
Cause() error
}
for err != nil {
cause, ok := err.(causer)
if !ok {
break
}
err = cause.Cause()
}
return err
}

View File

@ -0,0 +1,225 @@
package errors
import (
"errors"
"fmt"
"io"
"reflect"
"testing"
)
func TestNew(t *testing.T) {
tests := []struct {
err string
want error
}{
{"", fmt.Errorf("")},
{"foo", fmt.Errorf("foo")},
{"foo", New("foo")},
{"string with format specifiers: %v", errors.New("string with format specifiers: %v")},
}
for _, tt := range tests {
got := New(tt.err)
if got.Error() != tt.want.Error() {
t.Errorf("New.Error(): got: %q, want %q", got, tt.want)
}
}
}
func TestWrapNil(t *testing.T) {
got := Wrap(nil, "no error")
if got != nil {
t.Errorf("Wrap(nil, \"no error\"): got %#v, expected nil", got)
}
}
func TestWrap(t *testing.T) {
tests := []struct {
err error
message string
want string
}{
{io.EOF, "read error", "read error: EOF"},
{Wrap(io.EOF, "read error"), "client error", "client error: read error: EOF"},
}
for _, tt := range tests {
got := Wrap(tt.err, tt.message).Error()
if got != tt.want {
t.Errorf("Wrap(%v, %q): got: %v, want %v", tt.err, tt.message, got, tt.want)
}
}
}
type nilError struct{}
func (nilError) Error() string { return "nil error" }
func TestCause(t *testing.T) {
x := New("error")
tests := []struct {
err error
want error
}{{
// nil error is nil
err: nil,
want: nil,
}, {
// explicit nil error is nil
err: (error)(nil),
want: nil,
}, {
// typed nil is nil
err: (*nilError)(nil),
want: (*nilError)(nil),
}, {
// uncaused error is unaffected
err: io.EOF,
want: io.EOF,
}, {
// caused error returns cause
err: Wrap(io.EOF, "ignored"),
want: io.EOF,
}, {
err: x, // return from errors.New
want: x,
}, {
WithMessage(nil, "whoops"),
nil,
}, {
WithMessage(io.EOF, "whoops"),
io.EOF,
}, {
WithStack(nil),
nil,
}, {
WithStack(io.EOF),
io.EOF,
}}
for i, tt := range tests {
got := Cause(tt.err)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("test %d: got %#v, want %#v", i+1, got, tt.want)
}
}
}
func TestWrapfNil(t *testing.T) {
got := Wrapf(nil, "no error")
if got != nil {
t.Errorf("Wrapf(nil, \"no error\"): got %#v, expected nil", got)
}
}
func TestWrapf(t *testing.T) {
tests := []struct {
err error
message string
want string
}{
{io.EOF, "read error", "read error: EOF"},
{Wrapf(io.EOF, "read error without format specifiers"), "client error", "client error: read error without format specifiers: EOF"},
{Wrapf(io.EOF, "read error with %d format specifier", 1), "client error", "client error: read error with 1 format specifier: EOF"},
}
for _, tt := range tests {
got := Wrapf(tt.err, tt.message).Error()
if got != tt.want {
t.Errorf("Wrapf(%v, %q): got: %v, want %v", tt.err, tt.message, got, tt.want)
}
}
}
func TestErrorf(t *testing.T) {
tests := []struct {
err error
want string
}{
{Errorf("read error without format specifiers"), "read error without format specifiers"},
{Errorf("read error with %d format specifier", 1), "read error with 1 format specifier"},
}
for _, tt := range tests {
got := tt.err.Error()
if got != tt.want {
t.Errorf("Errorf(%v): got: %q, want %q", tt.err, got, tt.want)
}
}
}
func TestWithStackNil(t *testing.T) {
got := WithStack(nil)
if got != nil {
t.Errorf("WithStack(nil): got %#v, expected nil", got)
}
}
func TestWithStack(t *testing.T) {
tests := []struct {
err error
want string
}{
{io.EOF, "EOF"},
{WithStack(io.EOF), "EOF"},
}
for _, tt := range tests {
got := WithStack(tt.err).Error()
if got != tt.want {
t.Errorf("WithStack(%v): got: %v, want %v", tt.err, got, tt.want)
}
}
}
func TestWithMessageNil(t *testing.T) {
got := WithMessage(nil, "no error")
if got != nil {
t.Errorf("WithMessage(nil, \"no error\"): got %#v, expected nil", got)
}
}
func TestWithMessage(t *testing.T) {
tests := []struct {
err error
message string
want string
}{
{io.EOF, "read error", "read error: EOF"},
{WithMessage(io.EOF, "read error"), "client error", "client error: read error: EOF"},
}
for _, tt := range tests {
got := WithMessage(tt.err, tt.message).Error()
if got != tt.want {
t.Errorf("WithMessage(%v, %q): got: %q, want %q", tt.err, tt.message, got, tt.want)
}
}
}
// errors.New, etc values are not expected to be compared by value
// but the change in errors#27 made them incomparable. Assert that
// various kinds of errors have a functional equality operator, even
// if the result of that equality is always false.
func TestErrorEquality(t *testing.T) {
vals := []error{
nil,
io.EOF,
errors.New("EOF"),
New("EOF"),
Errorf("EOF"),
Wrap(io.EOF, "EOF"),
Wrapf(io.EOF, "EOF%d", 2),
WithMessage(nil, "whoops"),
WithMessage(io.EOF, "whoops"),
WithStack(io.EOF),
WithStack(nil),
}
for i := range vals {
for j := range vals {
_ = vals[i] == vals[j] // mustn't panic
}
}
}

View File

@ -0,0 +1,205 @@
package errors_test
import (
"fmt"
"github.com/pkg/errors"
)
func ExampleNew() {
err := errors.New("whoops")
fmt.Println(err)
// Output: whoops
}
func ExampleNew_printf() {
err := errors.New("whoops")
fmt.Printf("%+v", err)
// Example output:
// whoops
// github.com/pkg/errors_test.ExampleNew_printf
// /home/dfc/src/github.com/pkg/errors/example_test.go:17
// testing.runExample
// /home/dfc/go/src/testing/example.go:114
// testing.RunExamples
// /home/dfc/go/src/testing/example.go:38
// testing.(*M).Run
// /home/dfc/go/src/testing/testing.go:744
// main.main
// /github.com/pkg/errors/_test/_testmain.go:106
// runtime.main
// /home/dfc/go/src/runtime/proc.go:183
// runtime.goexit
// /home/dfc/go/src/runtime/asm_amd64.s:2059
}
func ExampleWithMessage() {
cause := errors.New("whoops")
err := errors.WithMessage(cause, "oh noes")
fmt.Println(err)
// Output: oh noes: whoops
}
func ExampleWithStack() {
cause := errors.New("whoops")
err := errors.WithStack(cause)
fmt.Println(err)
// Output: whoops
}
func ExampleWithStack_printf() {
cause := errors.New("whoops")
err := errors.WithStack(cause)
fmt.Printf("%+v", err)
// Example Output:
// whoops
// github.com/pkg/errors_test.ExampleWithStack_printf
// /home/fabstu/go/src/github.com/pkg/errors/example_test.go:55
// testing.runExample
// /usr/lib/go/src/testing/example.go:114
// testing.RunExamples
// /usr/lib/go/src/testing/example.go:38
// testing.(*M).Run
// /usr/lib/go/src/testing/testing.go:744
// main.main
// github.com/pkg/errors/_test/_testmain.go:106
// runtime.main
// /usr/lib/go/src/runtime/proc.go:183
// runtime.goexit
// /usr/lib/go/src/runtime/asm_amd64.s:2086
// github.com/pkg/errors_test.ExampleWithStack_printf
// /home/fabstu/go/src/github.com/pkg/errors/example_test.go:56
// testing.runExample
// /usr/lib/go/src/testing/example.go:114
// testing.RunExamples
// /usr/lib/go/src/testing/example.go:38
// testing.(*M).Run
// /usr/lib/go/src/testing/testing.go:744
// main.main
// github.com/pkg/errors/_test/_testmain.go:106
// runtime.main
// /usr/lib/go/src/runtime/proc.go:183
// runtime.goexit
// /usr/lib/go/src/runtime/asm_amd64.s:2086
}
func ExampleWrap() {
cause := errors.New("whoops")
err := errors.Wrap(cause, "oh noes")
fmt.Println(err)
// Output: oh noes: whoops
}
func fn() error {
e1 := errors.New("error")
e2 := errors.Wrap(e1, "inner")
e3 := errors.Wrap(e2, "middle")
return errors.Wrap(e3, "outer")
}
func ExampleCause() {
err := fn()
fmt.Println(err)
fmt.Println(errors.Cause(err))
// Output: outer: middle: inner: error
// error
}
func ExampleWrap_extended() {
err := fn()
fmt.Printf("%+v\n", err)
// Example output:
// error
// github.com/pkg/errors_test.fn
// /home/dfc/src/github.com/pkg/errors/example_test.go:47
// github.com/pkg/errors_test.ExampleCause_printf
// /home/dfc/src/github.com/pkg/errors/example_test.go:63
// testing.runExample
// /home/dfc/go/src/testing/example.go:114
// testing.RunExamples
// /home/dfc/go/src/testing/example.go:38
// testing.(*M).Run
// /home/dfc/go/src/testing/testing.go:744
// main.main
// /github.com/pkg/errors/_test/_testmain.go:104
// runtime.main
// /home/dfc/go/src/runtime/proc.go:183
// runtime.goexit
// /home/dfc/go/src/runtime/asm_amd64.s:2059
// github.com/pkg/errors_test.fn
// /home/dfc/src/github.com/pkg/errors/example_test.go:48: inner
// github.com/pkg/errors_test.fn
// /home/dfc/src/github.com/pkg/errors/example_test.go:49: middle
// github.com/pkg/errors_test.fn
// /home/dfc/src/github.com/pkg/errors/example_test.go:50: outer
}
func ExampleWrapf() {
cause := errors.New("whoops")
err := errors.Wrapf(cause, "oh noes #%d", 2)
fmt.Println(err)
// Output: oh noes #2: whoops
}
func ExampleErrorf_extended() {
err := errors.Errorf("whoops: %s", "foo")
fmt.Printf("%+v", err)
// Example output:
// whoops: foo
// github.com/pkg/errors_test.ExampleErrorf
// /home/dfc/src/github.com/pkg/errors/example_test.go:101
// testing.runExample
// /home/dfc/go/src/testing/example.go:114
// testing.RunExamples
// /home/dfc/go/src/testing/example.go:38
// testing.(*M).Run
// /home/dfc/go/src/testing/testing.go:744
// main.main
// /github.com/pkg/errors/_test/_testmain.go:102
// runtime.main
// /home/dfc/go/src/runtime/proc.go:183
// runtime.goexit
// /home/dfc/go/src/runtime/asm_amd64.s:2059
}
func Example_stackTrace() {
type stackTracer interface {
StackTrace() errors.StackTrace
}
err, ok := errors.Cause(fn()).(stackTracer)
if !ok {
panic("oops, err does not implement stackTracer")
}
st := err.StackTrace()
fmt.Printf("%+v", st[0:2]) // top two frames
// Example output:
// github.com/pkg/errors_test.fn
// /home/dfc/src/github.com/pkg/errors/example_test.go:47
// github.com/pkg/errors_test.Example_stackTrace
// /home/dfc/src/github.com/pkg/errors/example_test.go:127
}
func ExampleCause_printf() {
err := errors.Wrap(func() error {
return func() error {
return errors.Errorf("hello %s", fmt.Sprintf("world"))
}()
}(), "failed")
fmt.Printf("%v", err)
// Output: failed: hello world
}

View File

@ -0,0 +1,535 @@
package errors
import (
"errors"
"fmt"
"io"
"regexp"
"strings"
"testing"
)
func TestFormatNew(t *testing.T) {
tests := []struct {
error
format string
want string
}{{
New("error"),
"%s",
"error",
}, {
New("error"),
"%v",
"error",
}, {
New("error"),
"%+v",
"error\n" +
"github.com/pkg/errors.TestFormatNew\n" +
"\t.+/github.com/pkg/errors/format_test.go:26",
}, {
New("error"),
"%q",
`"error"`,
}}
for i, tt := range tests {
testFormatRegexp(t, i, tt.error, tt.format, tt.want)
}
}
func TestFormatErrorf(t *testing.T) {
tests := []struct {
error
format string
want string
}{{
Errorf("%s", "error"),
"%s",
"error",
}, {
Errorf("%s", "error"),
"%v",
"error",
}, {
Errorf("%s", "error"),
"%+v",
"error\n" +
"github.com/pkg/errors.TestFormatErrorf\n" +
"\t.+/github.com/pkg/errors/format_test.go:56",
}}
for i, tt := range tests {
testFormatRegexp(t, i, tt.error, tt.format, tt.want)
}
}
func TestFormatWrap(t *testing.T) {
tests := []struct {
error
format string
want string
}{{
Wrap(New("error"), "error2"),
"%s",
"error2: error",
}, {
Wrap(New("error"), "error2"),
"%v",
"error2: error",
}, {
Wrap(New("error"), "error2"),
"%+v",
"error\n" +
"github.com/pkg/errors.TestFormatWrap\n" +
"\t.+/github.com/pkg/errors/format_test.go:82",
}, {
Wrap(io.EOF, "error"),
"%s",
"error: EOF",
}, {
Wrap(io.EOF, "error"),
"%v",
"error: EOF",
}, {
Wrap(io.EOF, "error"),
"%+v",
"EOF\n" +
"error\n" +
"github.com/pkg/errors.TestFormatWrap\n" +
"\t.+/github.com/pkg/errors/format_test.go:96",
}, {
Wrap(Wrap(io.EOF, "error1"), "error2"),
"%+v",
"EOF\n" +
"error1\n" +
"github.com/pkg/errors.TestFormatWrap\n" +
"\t.+/github.com/pkg/errors/format_test.go:103\n",
}, {
Wrap(New("error with space"), "context"),
"%q",
`"context: error with space"`,
}}
for i, tt := range tests {
testFormatRegexp(t, i, tt.error, tt.format, tt.want)
}
}
func TestFormatWrapf(t *testing.T) {
tests := []struct {
error
format string
want string
}{{
Wrapf(io.EOF, "error%d", 2),
"%s",
"error2: EOF",
}, {
Wrapf(io.EOF, "error%d", 2),
"%v",
"error2: EOF",
}, {
Wrapf(io.EOF, "error%d", 2),
"%+v",
"EOF\n" +
"error2\n" +
"github.com/pkg/errors.TestFormatWrapf\n" +
"\t.+/github.com/pkg/errors/format_test.go:134",
}, {
Wrapf(New("error"), "error%d", 2),
"%s",
"error2: error",
}, {
Wrapf(New("error"), "error%d", 2),
"%v",
"error2: error",
}, {
Wrapf(New("error"), "error%d", 2),
"%+v",
"error\n" +
"github.com/pkg/errors.TestFormatWrapf\n" +
"\t.+/github.com/pkg/errors/format_test.go:149",
}}
for i, tt := range tests {
testFormatRegexp(t, i, tt.error, tt.format, tt.want)
}
}
func TestFormatWithStack(t *testing.T) {
tests := []struct {
error
format string
want []string
}{{
WithStack(io.EOF),
"%s",
[]string{"EOF"},
}, {
WithStack(io.EOF),
"%v",
[]string{"EOF"},
}, {
WithStack(io.EOF),
"%+v",
[]string{"EOF",
"github.com/pkg/errors.TestFormatWithStack\n" +
"\t.+/github.com/pkg/errors/format_test.go:175"},
}, {
WithStack(New("error")),
"%s",
[]string{"error"},
}, {
WithStack(New("error")),
"%v",
[]string{"error"},
}, {
WithStack(New("error")),
"%+v",
[]string{"error",
"github.com/pkg/errors.TestFormatWithStack\n" +
"\t.+/github.com/pkg/errors/format_test.go:189",
"github.com/pkg/errors.TestFormatWithStack\n" +
"\t.+/github.com/pkg/errors/format_test.go:189"},
}, {
WithStack(WithStack(io.EOF)),
"%+v",
[]string{"EOF",
"github.com/pkg/errors.TestFormatWithStack\n" +
"\t.+/github.com/pkg/errors/format_test.go:197",
"github.com/pkg/errors.TestFormatWithStack\n" +
"\t.+/github.com/pkg/errors/format_test.go:197"},
}, {
WithStack(WithStack(Wrapf(io.EOF, "message"))),
"%+v",
[]string{"EOF",
"message",
"github.com/pkg/errors.TestFormatWithStack\n" +
"\t.+/github.com/pkg/errors/format_test.go:205",
"github.com/pkg/errors.TestFormatWithStack\n" +
"\t.+/github.com/pkg/errors/format_test.go:205",
"github.com/pkg/errors.TestFormatWithStack\n" +
"\t.+/github.com/pkg/errors/format_test.go:205"},
}, {
WithStack(Errorf("error%d", 1)),
"%+v",
[]string{"error1",
"github.com/pkg/errors.TestFormatWithStack\n" +
"\t.+/github.com/pkg/errors/format_test.go:216",
"github.com/pkg/errors.TestFormatWithStack\n" +
"\t.+/github.com/pkg/errors/format_test.go:216"},
}}
for i, tt := range tests {
testFormatCompleteCompare(t, i, tt.error, tt.format, tt.want, true)
}
}
func TestFormatWithMessage(t *testing.T) {
tests := []struct {
error
format string
want []string
}{{
WithMessage(New("error"), "error2"),
"%s",
[]string{"error2: error"},
}, {
WithMessage(New("error"), "error2"),
"%v",
[]string{"error2: error"},
}, {
WithMessage(New("error"), "error2"),
"%+v",
[]string{
"error",
"github.com/pkg/errors.TestFormatWithMessage\n" +
"\t.+/github.com/pkg/errors/format_test.go:244",
"error2"},
}, {
WithMessage(io.EOF, "addition1"),
"%s",
[]string{"addition1: EOF"},
}, {
WithMessage(io.EOF, "addition1"),
"%v",
[]string{"addition1: EOF"},
}, {
WithMessage(io.EOF, "addition1"),
"%+v",
[]string{"EOF", "addition1"},
}, {
WithMessage(WithMessage(io.EOF, "addition1"), "addition2"),
"%v",
[]string{"addition2: addition1: EOF"},
}, {
WithMessage(WithMessage(io.EOF, "addition1"), "addition2"),
"%+v",
[]string{"EOF", "addition1", "addition2"},
}, {
Wrap(WithMessage(io.EOF, "error1"), "error2"),
"%+v",
[]string{"EOF", "error1", "error2",
"github.com/pkg/errors.TestFormatWithMessage\n" +
"\t.+/github.com/pkg/errors/format_test.go:272"},
}, {
WithMessage(Errorf("error%d", 1), "error2"),
"%+v",
[]string{"error1",
"github.com/pkg/errors.TestFormatWithMessage\n" +
"\t.+/github.com/pkg/errors/format_test.go:278",
"error2"},
}, {
WithMessage(WithStack(io.EOF), "error"),
"%+v",
[]string{
"EOF",
"github.com/pkg/errors.TestFormatWithMessage\n" +
"\t.+/github.com/pkg/errors/format_test.go:285",
"error"},
}, {
WithMessage(Wrap(WithStack(io.EOF), "inside-error"), "outside-error"),
"%+v",
[]string{
"EOF",
"github.com/pkg/errors.TestFormatWithMessage\n" +
"\t.+/github.com/pkg/errors/format_test.go:293",
"inside-error",
"github.com/pkg/errors.TestFormatWithMessage\n" +
"\t.+/github.com/pkg/errors/format_test.go:293",
"outside-error"},
}}
for i, tt := range tests {
testFormatCompleteCompare(t, i, tt.error, tt.format, tt.want, true)
}
}
func TestFormatGeneric(t *testing.T) {
starts := []struct {
err error
want []string
}{
{New("new-error"), []string{
"new-error",
"github.com/pkg/errors.TestFormatGeneric\n" +
"\t.+/github.com/pkg/errors/format_test.go:315"},
}, {Errorf("errorf-error"), []string{
"errorf-error",
"github.com/pkg/errors.TestFormatGeneric\n" +
"\t.+/github.com/pkg/errors/format_test.go:319"},
}, {errors.New("errors-new-error"), []string{
"errors-new-error"},
},
}
wrappers := []wrapper{
{
func(err error) error { return WithMessage(err, "with-message") },
[]string{"with-message"},
}, {
func(err error) error { return WithStack(err) },
[]string{
"github.com/pkg/errors.(func·002|TestFormatGeneric.func2)\n\t" +
".+/github.com/pkg/errors/format_test.go:333",
},
}, {
func(err error) error { return Wrap(err, "wrap-error") },
[]string{
"wrap-error",
"github.com/pkg/errors.(func·003|TestFormatGeneric.func3)\n\t" +
".+/github.com/pkg/errors/format_test.go:339",
},
}, {
func(err error) error { return Wrapf(err, "wrapf-error%d", 1) },
[]string{
"wrapf-error1",
"github.com/pkg/errors.(func·004|TestFormatGeneric.func4)\n\t" +
".+/github.com/pkg/errors/format_test.go:346",
},
},
}
for s := range starts {
err := starts[s].err
want := starts[s].want
testFormatCompleteCompare(t, s, err, "%+v", want, false)
testGenericRecursive(t, err, want, wrappers, 3)
}
}
func testFormatRegexp(t *testing.T, n int, arg interface{}, format, want string) {
got := fmt.Sprintf(format, arg)
gotLines := strings.SplitN(got, "\n", -1)
wantLines := strings.SplitN(want, "\n", -1)
if len(wantLines) > len(gotLines) {
t.Errorf("test %d: wantLines(%d) > gotLines(%d):\n got: %q\nwant: %q", n+1, len(wantLines), len(gotLines), got, want)
return
}
for i, w := range wantLines {
match, err := regexp.MatchString(w, gotLines[i])
if err != nil {
t.Fatal(err)
}
if !match {
t.Errorf("test %d: line %d: fmt.Sprintf(%q, err):\n got: %q\nwant: %q", n+1, i+1, format, got, want)
}
}
}
var stackLineR = regexp.MustCompile(`\.`)
// parseBlocks parses input into a slice, where:
// - incase entry contains a newline, its a stacktrace
// - incase entry contains no newline, its a solo line.
//
// Detecting stack boundaries only works incase the WithStack-calls are
// to be found on the same line, thats why it is optionally here.
//
// Example use:
//
// for _, e := range blocks {
// if strings.ContainsAny(e, "\n") {
// // Match as stack
// } else {
// // Match as line
// }
// }
//
func parseBlocks(input string, detectStackboundaries bool) ([]string, error) {
var blocks []string
stack := ""
wasStack := false
lines := map[string]bool{} // already found lines
for _, l := range strings.Split(input, "\n") {
isStackLine := stackLineR.MatchString(l)
switch {
case !isStackLine && wasStack:
blocks = append(blocks, stack, l)
stack = ""
lines = map[string]bool{}
case isStackLine:
if wasStack {
// Detecting two stacks after another, possible cause lines match in
// our tests due to WithStack(WithStack(io.EOF)) on same line.
if detectStackboundaries {
if lines[l] {
if len(stack) == 0 {
return nil, errors.New("len of block must not be zero here")
}
blocks = append(blocks, stack)
stack = l
lines = map[string]bool{l: true}
continue
}
}
stack = stack + "\n" + l
} else {
stack = l
}
lines[l] = true
case !isStackLine && !wasStack:
blocks = append(blocks, l)
default:
return nil, errors.New("must not happen")
}
wasStack = isStackLine
}
// Use up stack
if stack != "" {
blocks = append(blocks, stack)
}
return blocks, nil
}
func testFormatCompleteCompare(t *testing.T, n int, arg interface{}, format string, want []string, detectStackBoundaries bool) {
gotStr := fmt.Sprintf(format, arg)
got, err := parseBlocks(gotStr, detectStackBoundaries)
if err != nil {
t.Fatal(err)
}
if len(got) != len(want) {
t.Fatalf("test %d: fmt.Sprintf(%s, err) -> wrong number of blocks: got(%d) want(%d)\n got: %s\nwant: %s\ngotStr: %q",
n+1, format, len(got), len(want), prettyBlocks(got), prettyBlocks(want), gotStr)
}
for i := range got {
if strings.ContainsAny(want[i], "\n") {
// Match as stack
match, err := regexp.MatchString(want[i], got[i])
if err != nil {
t.Fatal(err)
}
if !match {
t.Fatalf("test %d: block %d: fmt.Sprintf(%q, err):\ngot:\n%q\nwant:\n%q\nall-got:\n%s\nall-want:\n%s\n",
n+1, i+1, format, got[i], want[i], prettyBlocks(got), prettyBlocks(want))
}
} else {
// Match as message
if got[i] != want[i] {
t.Fatalf("test %d: fmt.Sprintf(%s, err) at block %d got != want:\n got: %q\nwant: %q", n+1, format, i+1, got[i], want[i])
}
}
}
}
type wrapper struct {
wrap func(err error) error
want []string
}
func prettyBlocks(blocks []string, prefix ...string) string {
var out []string
for _, b := range blocks {
out = append(out, fmt.Sprintf("%v", b))
}
return " " + strings.Join(out, "\n ")
}
func testGenericRecursive(t *testing.T, beforeErr error, beforeWant []string, list []wrapper, maxDepth int) {
if len(beforeWant) == 0 {
panic("beforeWant must not be empty")
}
for _, w := range list {
if len(w.want) == 0 {
panic("want must not be empty")
}
err := w.wrap(beforeErr)
// Copy required cause append(beforeWant, ..) modified beforeWant subtly.
beforeCopy := make([]string, len(beforeWant))
copy(beforeCopy, beforeWant)
beforeWant := beforeCopy
last := len(beforeWant) - 1
var want []string
// Merge two stacks behind each other.
if strings.ContainsAny(beforeWant[last], "\n") && strings.ContainsAny(w.want[0], "\n") {
want = append(beforeWant[:last], append([]string{beforeWant[last] + "((?s).*)" + w.want[0]}, w.want[1:]...)...)
} else {
want = append(beforeWant, w.want...)
}
testFormatCompleteCompare(t, maxDepth, err, "%+v", want, false)
if maxDepth > 0 {
testGenericRecursive(t, err, want, list, maxDepth-1)
}
}
}

View File

@ -0,0 +1,186 @@
package errors
import (
"fmt"
"io"
"path"
"runtime"
"strings"
)
// Frame represents a program counter inside a stack frame.
type Frame uintptr
// pc returns the program counter for this frame;
// multiple frames may have the same PC value.
func (f Frame) pc() uintptr { return uintptr(f) - 1 }
// file returns the full path to the file that contains the
// function for this Frame's pc.
func (f Frame) file() string {
fn := runtime.FuncForPC(f.pc())
if fn == nil {
return "unknown"
}
file, _ := fn.FileLine(f.pc())
return file
}
// line returns the line number of source code of the
// function for this Frame's pc.
func (f Frame) line() int {
fn := runtime.FuncForPC(f.pc())
if fn == nil {
return 0
}
_, line := fn.FileLine(f.pc())
return line
}
// Format formats the frame according to the fmt.Formatter interface.
//
// %s source file
// %d source line
// %n function name
// %v equivalent to %s:%d
//
// Format accepts flags that alter the printing of some verbs, as follows:
//
// %+s path of source file relative to the compile time GOPATH
// %+v equivalent to %+s:%d
func (f Frame) Format(s fmt.State, verb rune) {
switch verb {
case 's':
switch {
case s.Flag('+'):
pc := f.pc()
fn := runtime.FuncForPC(pc)
if fn == nil {
io.WriteString(s, "unknown")
} else {
file, _ := fn.FileLine(pc)
fmt.Fprintf(s, "%s\n\t%s", fn.Name(), file)
}
default:
io.WriteString(s, path.Base(f.file()))
}
case 'd':
fmt.Fprintf(s, "%d", f.line())
case 'n':
name := runtime.FuncForPC(f.pc()).Name()
io.WriteString(s, funcname(name))
case 'v':
f.Format(s, 's')
io.WriteString(s, ":")
f.Format(s, 'd')
}
}
// StackTrace is stack of Frames from innermost (newest) to outermost (oldest).
type StackTrace []Frame
// Format formats the stack of Frames according to the fmt.Formatter interface.
//
// %s lists source files for each Frame in the stack
// %v lists the source file and line number for each Frame in the stack
//
// Format accepts flags that alter the printing of some verbs, as follows:
//
// %+v Prints filename, function, and line number for each Frame in the stack.
func (st StackTrace) Format(s fmt.State, verb rune) {
switch verb {
case 'v':
switch {
case s.Flag('+'):
for _, f := range st {
fmt.Fprintf(s, "\n%+v", f)
}
case s.Flag('#'):
fmt.Fprintf(s, "%#v", []Frame(st))
default:
fmt.Fprintf(s, "%v", []Frame(st))
}
case 's':
fmt.Fprintf(s, "%s", []Frame(st))
}
}
// stack represents a stack of program counters.
type stack []uintptr
func (s *stack) Format(st fmt.State, verb rune) {
switch verb {
case 'v':
switch {
case st.Flag('+'):
for _, pc := range *s {
f := Frame(pc)
fmt.Fprintf(st, "\n%+v", f)
}
}
}
}
func (s *stack) StackTrace() StackTrace {
f := make([]Frame, len(*s))
for i := 0; i < len(f); i++ {
f[i] = Frame((*s)[i])
}
return f
}
func callers() *stack {
const depth = 32
var pcs [depth]uintptr
n := runtime.Callers(3, pcs[:])
var st stack = pcs[0:n]
return &st
}
// funcname removes the path prefix component of a function's name reported by func.Name().
func funcname(name string) string {
i := strings.LastIndex(name, "/")
name = name[i+1:]
i = strings.Index(name, ".")
return name[i+1:]
}
func trimGOPATH(name, file string) string {
// Here we want to get the source file path relative to the compile time
// GOPATH. As of Go 1.6.x there is no direct way to know the compiled
// GOPATH at runtime, but we can infer the number of path segments in the
// GOPATH. We note that fn.Name() returns the function name qualified by
// the import path, which does not include the GOPATH. Thus we can trim
// segments from the beginning of the file path until the number of path
// separators remaining is one more than the number of path separators in
// the function name. For example, given:
//
// GOPATH /home/user
// file /home/user/src/pkg/sub/file.go
// fn.Name() pkg/sub.Type.Method
//
// We want to produce:
//
// pkg/sub/file.go
//
// From this we can easily see that fn.Name() has one less path separator
// than our desired output. We count separators from the end of the file
// path until it finds two more than in the function name and then move
// one character forward to preserve the initial path segment without a
// leading separator.
const sep = "/"
goal := strings.Count(name, sep) + 2
i := len(file)
for n := 0; n < goal; n++ {
i = strings.LastIndex(file[:i], sep)
if i == -1 {
// not enough separators found, set i so that the slice expression
// below leaves file unmodified
i = -len(sep)
break
}
}
// get back to 0 or trim the leading separator
file = file[i+len(sep):]
return file
}

View File

@ -0,0 +1,292 @@
package errors
import (
"fmt"
"runtime"
"testing"
)
var initpc, _, _, _ = runtime.Caller(0)
func TestFrameLine(t *testing.T) {
var tests = []struct {
Frame
want int
}{{
Frame(initpc),
9,
}, {
func() Frame {
var pc, _, _, _ = runtime.Caller(0)
return Frame(pc)
}(),
20,
}, {
func() Frame {
var pc, _, _, _ = runtime.Caller(1)
return Frame(pc)
}(),
28,
}, {
Frame(0), // invalid PC
0,
}}
for _, tt := range tests {
got := tt.Frame.line()
want := tt.want
if want != got {
t.Errorf("Frame(%v): want: %v, got: %v", uintptr(tt.Frame), want, got)
}
}
}
type X struct{}
func (x X) val() Frame {
var pc, _, _, _ = runtime.Caller(0)
return Frame(pc)
}
func (x *X) ptr() Frame {
var pc, _, _, _ = runtime.Caller(0)
return Frame(pc)
}
func TestFrameFormat(t *testing.T) {
var tests = []struct {
Frame
format string
want string
}{{
Frame(initpc),
"%s",
"stack_test.go",
}, {
Frame(initpc),
"%+s",
"github.com/pkg/errors.init\n" +
"\t.+/github.com/pkg/errors/stack_test.go",
}, {
Frame(0),
"%s",
"unknown",
}, {
Frame(0),
"%+s",
"unknown",
}, {
Frame(initpc),
"%d",
"9",
}, {
Frame(0),
"%d",
"0",
}, {
Frame(initpc),
"%n",
"init",
}, {
func() Frame {
var x X
return x.ptr()
}(),
"%n",
`\(\*X\).ptr`,
}, {
func() Frame {
var x X
return x.val()
}(),
"%n",
"X.val",
}, {
Frame(0),
"%n",
"",
}, {
Frame(initpc),
"%v",
"stack_test.go:9",
}, {
Frame(initpc),
"%+v",
"github.com/pkg/errors.init\n" +
"\t.+/github.com/pkg/errors/stack_test.go:9",
}, {
Frame(0),
"%v",
"unknown:0",
}}
for i, tt := range tests {
testFormatRegexp(t, i, tt.Frame, tt.format, tt.want)
}
}
func TestFuncname(t *testing.T) {
tests := []struct {
name, want string
}{
{"", ""},
{"runtime.main", "main"},
{"github.com/pkg/errors.funcname", "funcname"},
{"funcname", "funcname"},
{"io.copyBuffer", "copyBuffer"},
{"main.(*R).Write", "(*R).Write"},
}
for _, tt := range tests {
got := funcname(tt.name)
want := tt.want
if got != want {
t.Errorf("funcname(%q): want: %q, got %q", tt.name, want, got)
}
}
}
func TestTrimGOPATH(t *testing.T) {
var tests = []struct {
Frame
want string
}{{
Frame(initpc),
"github.com/pkg/errors/stack_test.go",
}}
for i, tt := range tests {
pc := tt.Frame.pc()
fn := runtime.FuncForPC(pc)
file, _ := fn.FileLine(pc)
got := trimGOPATH(fn.Name(), file)
testFormatRegexp(t, i, got, "%s", tt.want)
}
}
func TestStackTrace(t *testing.T) {
tests := []struct {
err error
want []string
}{{
New("ooh"), []string{
"github.com/pkg/errors.TestStackTrace\n" +
"\t.+/github.com/pkg/errors/stack_test.go:172",
},
}, {
Wrap(New("ooh"), "ahh"), []string{
"github.com/pkg/errors.TestStackTrace\n" +
"\t.+/github.com/pkg/errors/stack_test.go:177", // this is the stack of Wrap, not New
},
}, {
Cause(Wrap(New("ooh"), "ahh")), []string{
"github.com/pkg/errors.TestStackTrace\n" +
"\t.+/github.com/pkg/errors/stack_test.go:182", // this is the stack of New
},
}, {
func() error { return New("ooh") }(), []string{
`github.com/pkg/errors.(func·009|TestStackTrace.func1)` +
"\n\t.+/github.com/pkg/errors/stack_test.go:187", // this is the stack of New
"github.com/pkg/errors.TestStackTrace\n" +
"\t.+/github.com/pkg/errors/stack_test.go:187", // this is the stack of New's caller
},
}, {
Cause(func() error {
return func() error {
return Errorf("hello %s", fmt.Sprintf("world"))
}()
}()), []string{
`github.com/pkg/errors.(func·010|TestStackTrace.func2.1)` +
"\n\t.+/github.com/pkg/errors/stack_test.go:196", // this is the stack of Errorf
`github.com/pkg/errors.(func·011|TestStackTrace.func2)` +
"\n\t.+/github.com/pkg/errors/stack_test.go:197", // this is the stack of Errorf's caller
"github.com/pkg/errors.TestStackTrace\n" +
"\t.+/github.com/pkg/errors/stack_test.go:198", // this is the stack of Errorf's caller's caller
},
}}
for i, tt := range tests {
x, ok := tt.err.(interface {
StackTrace() StackTrace
})
if !ok {
t.Errorf("expected %#v to implement StackTrace() StackTrace", tt.err)
continue
}
st := x.StackTrace()
for j, want := range tt.want {
testFormatRegexp(t, i, st[j], "%+v", want)
}
}
}
func stackTrace() StackTrace {
const depth = 8
var pcs [depth]uintptr
n := runtime.Callers(1, pcs[:])
var st stack = pcs[0:n]
return st.StackTrace()
}
func TestStackTraceFormat(t *testing.T) {
tests := []struct {
StackTrace
format string
want string
}{{
nil,
"%s",
`\[\]`,
}, {
nil,
"%v",
`\[\]`,
}, {
nil,
"%+v",
"",
}, {
nil,
"%#v",
`\[\]errors.Frame\(nil\)`,
}, {
make(StackTrace, 0),
"%s",
`\[\]`,
}, {
make(StackTrace, 0),
"%v",
`\[\]`,
}, {
make(StackTrace, 0),
"%+v",
"",
}, {
make(StackTrace, 0),
"%#v",
`\[\]errors.Frame{}`,
}, {
stackTrace()[:2],
"%s",
`\[stack_test.go stack_test.go\]`,
}, {
stackTrace()[:2],
"%v",
`\[stack_test.go:225 stack_test.go:272\]`,
}, {
stackTrace()[:2],
"%+v",
"\n" +
"github.com/pkg/errors.stackTrace\n" +
"\t.+/github.com/pkg/errors/stack_test.go:225\n" +
"github.com/pkg/errors.TestStackTraceFormat\n" +
"\t.+/github.com/pkg/errors/stack_test.go:276",
}, {
stackTrace()[:2],
"%#v",
`\[\]errors.Frame{stack_test.go:225, stack_test.go:284}`,
}}
for i, tt := range tests {
testFormatRegexp(t, i, tt.StackTrace, tt.format, tt.want)
}
}