package auth import ( "context" "fmt" "net/http" "strings" "github.com/rs/zerolog/log" "git.sr.ht/~emersion/go-oauth2" ) type OAuth2Provider struct { metadata *oauth2.ServerMetadata clientID string clientSecret string } // Initializes a new OAuth 2.0 auth provider with the given connection string. func NewOAuth2(endpoint, clientID, clientSecret string) (AuthProvider, error) { metadata, err := oauth2.DiscoverServerMetadata(context.Background(), endpoint) if err != nil { return nil, fmt.Errorf("failed to fetch OAuth 2.0 server metadata: %v", err) } return &OAuth2Provider{ metadata: metadata, clientID: clientID, clientSecret: clientSecret, }, nil } func (prov *OAuth2Provider) Middleware() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { prov.doAuth(next, w, r) }) } } func (prov *OAuth2Provider) doAuth(next http.Handler, w http.ResponseWriter, r *http.Request) { auth := r.Header.Get("Authorization") authScheme, creds, _ := strings.Cut(auth, " ") var username, accessToken string switch strings.ToLower(authScheme) { case "bearer": accessToken = creds case "basic": username, accessToken, _ = r.BasicAuth() default: w.Header().Add("WWW-Authenticate", `Bearer, Basic realm="Please provide an OAuth access token", charset="UTF-8"`) http.Error(w, "HTTP auth is required", http.StatusUnauthorized) return } client := oauth2.Client{ Server: prov.metadata, ClientID: prov.clientID, ClientSecret: prov.clientSecret, } resp, err := client.Introspect(r.Context(), accessToken) if err != nil || !resp.Active { log.Debug().Err(err).Msg("auth error") http.Error(w, "Invalid access token", http.StatusUnauthorized) return } if username != "" && username != resp.Username { http.Error(w, "Invalid username", http.StatusUnauthorized) return } if resp.Username == "" { http.Error(w, "OAuth 2.0 server did not send username", http.StatusInternalServerError) return } authCtx := AuthContext{ AuthMethod: "oauth2", UserName: resp.Username, } ctx := NewContext(r.Context(), &authCtx) r = r.WithContext(ctx) next.ServeHTTP(w, r) }