tokidoki/auth/oauth2.go
2024-02-21 17:07:08 +01:00

88 lines
2.2 KiB
Go

package auth
import (
"context"
"fmt"
"net/http"
"strings"
"github.com/rs/zerolog/log"
"git.sr.ht/~emersion/go-oauth2"
)
// OAuth2Provider authenticates HTTP requests by validating access tokens
// against an OAuth 2.0 server.
type OAuth2Provider struct {
	// metadata is the server metadata obtained through discovery when the
	// provider is created.
	metadata *oauth2.ServerMetadata

	// clientID and clientSecret are the client credentials handed to the
	// OAuth 2.0 client when a token is introspected.
	clientID     string
	clientSecret string
}
// NewOAuth2 creates an OAuth 2.0 auth provider.
//
// endpoint is handed to oauth2.DiscoverServerMetadata to fetch the server
// metadata; clientID and clientSecret are stored on the provider and later
// used as the client credentials for token introspection.
//
// Discovery runs with context.Background(), so callers cannot cancel it.
// If discovery fails, a nil provider is returned together with an error that
// wraps the underlying cause (inspectable via errors.Is / errors.As).
func NewOAuth2(endpoint, clientID, clientSecret string) (AuthProvider, error) {
	metadata, err := oauth2.DiscoverServerMetadata(context.Background(), endpoint)
	if err != nil {
		// %w instead of %v: same message, but the cause stays unwrappable.
		return nil, fmt.Errorf("failed to fetch OAuth 2.0 server metadata: %w", err)
	}

	return &OAuth2Provider{
		metadata:     metadata,
		clientID:     clientID,
		clientSecret: clientSecret,
	}, nil
}
// Middleware returns an HTTP middleware that runs the provider's
// authentication step for every request before the wrapped handler.
func (prov *OAuth2Provider) Middleware() func(http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		// doAuth decides whether next is invoked at all.
		authenticate := func(w http.ResponseWriter, r *http.Request) {
			prov.doAuth(next, w, r)
		}
		return http.HandlerFunc(authenticate)
	}
	return wrap
}
// doAuth authenticates a single request against the OAuth 2.0 server and, on
// success, forwards it to next with an AuthContext attached to the request
// context.
//
// The access token is read from the Authorization header, either as a Bearer
// token or as the password of HTTP Basic credentials, and is validated with
// an introspection request. When Basic credentials carry a non-empty
// username, it must match the username reported by introspection.
//
// Responses written on failure:
//   - 401 with a WWW-Authenticate challenge when credentials are missing,
//     malformed, rejected by introspection, or the username does not match;
//   - 500 when the introspection response carries no username.
func (prov *OAuth2Provider) doAuth(next http.Handler,
	w http.ResponseWriter, r *http.Request) {
	// Every 401 carries the challenge so clients know which schemes to
	// retry with (a 401 without WWW-Authenticate violates RFC 9110).
	unauthorized := func(msg string) {
		w.Header().Add("WWW-Authenticate", `Bearer, Basic realm="Please provide an OAuth access token", charset="UTF-8"`)
		http.Error(w, msg, http.StatusUnauthorized)
	}

	auth := r.Header.Get("Authorization")
	authScheme, creds, _ := strings.Cut(auth, " ")

	var username, accessToken string
	switch strings.ToLower(authScheme) {
	case "bearer":
		accessToken = strings.TrimSpace(creds)
	case "basic":
		var ok bool
		username, accessToken, ok = r.BasicAuth()
		if !ok {
			// Undecodable base64 or missing ':' separator.
			unauthorized("Malformed Basic credentials")
			return
		}
	default:
		unauthorized("HTTP auth is required")
		return
	}

	// Do not spend an introspection round-trip on an empty token.
	if accessToken == "" {
		unauthorized("HTTP auth is required")
		return
	}

	client := oauth2.Client{
		Server:       prov.metadata,
		ClientID:     prov.clientID,
		ClientSecret: prov.clientSecret,
	}
	resp, err := client.Introspect(r.Context(), accessToken)
	if err != nil || !resp.Active {
		// err is nil here when the token was merely reported inactive.
		log.Debug().Err(err).Msg("auth error")
		unauthorized("Invalid access token")
		return
	}

	if username != "" && username != resp.Username {
		unauthorized("Invalid username")
		return
	}
	if resp.Username == "" {
		http.Error(w, "OAuth 2.0 server did not send username", http.StatusInternalServerError)
		return
	}

	authCtx := AuthContext{
		AuthMethod: "oauth2",
		UserName:   resp.Username,
	}
	ctx := NewContext(r.Context(), &authCtx)
	next.ServeHTTP(w, r.WithContext(ctx))
}