202 lines
5.7 KiB
Go
202 lines
5.7 KiB
Go
package discord
|
|
|
|
import (
|
|
"code.mrmelon54.com/melon/tools/utils"
|
|
"context"
|
|
"embed"
|
|
_ "embed"
|
|
"fmt"
|
|
"github.com/bwmarrin/discordgo"
|
|
"github.com/google/uuid"
|
|
"github.com/gorilla/mux"
|
|
"golang.org/x/oauth2"
|
|
"html/template"
|
|
"net/http"
|
|
"os"
|
|
)
|
|
|
|
var (
|
|
//go:embed pages/index.go.html
|
|
indexTemplate string
|
|
//go:embed pages/assets/icon
|
|
iconFiles embed.FS
|
|
)
|
|
|
|
type Module struct {
|
|
sessionWrapper func(cb func(http.ResponseWriter, *http.Request, *utils.State)) func(rw http.ResponseWriter, req *http.Request)
|
|
oauthClient *oauth2.Config
|
|
}
|
|
|
|
// discordKeyType is the package-private type of the keys this module uses
// when storing values in a utils.State. Presumably a distinct type keeps
// these keys from colliding with other modules' keys — TODO confirm against
// utils.State.
type discordKeyType int

// Keys for values this module stores in (or reads from) the session state.
const (
	// KeyOauthClient holds the *discordgo.Session created after login.
	KeyOauthClient discordKeyType = iota
	// KeyUser is read by loginPage as a *string; it is not written in this file.
	KeyUser
	// KeyState holds the uuid.UUID used as the OAuth2 "state" parameter.
	KeyState
	// KeyAccessToken holds the OAuth2 access token string.
	KeyAccessToken
	// KeyRefreshToken holds the OAuth2 refresh token string.
	KeyRefreshToken
)
|
|
|
|
func New() *Module {
|
|
return &Module{}
|
|
}
|
|
|
|
func (m *Module) GetName() string { return "Discord" }
|
|
func (m *Module) GetEndpoint() string { return "/discord" }
|
|
|
|
func (m *Module) SetupModule(router *mux.Router, f func(cb func(http.ResponseWriter, *http.Request, *utils.State)) func(rw http.ResponseWriter, req *http.Request)) {
|
|
m.sessionWrapper = f
|
|
m.oauthClient = &oauth2.Config{
|
|
ClientID: os.Getenv("DISCORD_CLIENT_ID"),
|
|
ClientSecret: os.Getenv("DISCORD_CLIENT_SECRET"),
|
|
Scopes: []string{"identify", "guilds", "connections", "email"},
|
|
Endpoint: oauth2.Endpoint{
|
|
AuthURL: "https://discord.com/oauth2/authorize",
|
|
TokenURL: "https://discord.com/api/oauth2/token",
|
|
},
|
|
RedirectURL: os.Getenv("DISCORD_REDIRECT_URL"),
|
|
}
|
|
router.HandleFunc("/", m.getClient(m.homepage))
|
|
router.HandleFunc("/login", m.sessionWrapper(m.loginPage))
|
|
router.PathPrefix("/assets/icon/{name}.svg").HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
|
vars := mux.Vars(req)
|
|
b, err := iconFiles.ReadFile("pages/assets/icon/" + vars["name"] + ".svg")
|
|
if err != nil {
|
|
rw.WriteHeader(http.StatusNotFound)
|
|
} else {
|
|
rw.Header().Set("Content-Type", "image/svg+xml")
|
|
rw.WriteHeader(http.StatusOK)
|
|
_, _ = rw.Write(b)
|
|
}
|
|
})
|
|
}
|
|
|
|
func (m *Module) getClient(cb func(http.ResponseWriter, *http.Request, *utils.State, *discordgo.Session)) func(rw http.ResponseWriter, req *http.Request) {
|
|
return m.sessionWrapper(func(rw http.ResponseWriter, req *http.Request, state *utils.State) {
|
|
if v, ok := utils.GetStateValue[*discordgo.Session](state, KeyOauthClient); ok {
|
|
cb(rw, req, state, v)
|
|
return
|
|
}
|
|
http.Redirect(rw, req, "/discord/login", http.StatusTemporaryRedirect)
|
|
})
|
|
}
|
|
|
|
func (m *Module) homepage(rw http.ResponseWriter, _ *http.Request, state *utils.State, discordClient *discordgo.Session) {
|
|
myUser, err := discordClient.User("@me")
|
|
if err != nil {
|
|
state.Del(KeyOauthClient)
|
|
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
myGuilds, err := discordClient.UserGuilds(200, "", "")
|
|
if err != nil {
|
|
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
myConns, err := discordClient.UserConnections()
|
|
if err != nil {
|
|
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
tmp, err := template.New("homepage").Funcs(template.FuncMap{
|
|
"checkFlag": func(a discordgo.UserFlags, b int) bool { return (int(a) & b) != 0 },
|
|
"connectedLink": connectedLinkFunc,
|
|
}).Parse(indexTemplate)
|
|
if err != nil {
|
|
fmt.Println("Template parse error:", err)
|
|
return
|
|
}
|
|
|
|
guildIcons := make([]string, len(myGuilds))
|
|
for i, j := range myGuilds {
|
|
var a discordgo.Guild
|
|
a.ID = j.ID
|
|
a.Icon = j.Icon
|
|
guildIcons[i] = a.IconURL()
|
|
}
|
|
|
|
err = tmp.Execute(rw, struct {
|
|
User *discordgo.User
|
|
UserAccent string
|
|
Avatar string
|
|
Banner string
|
|
Guilds []*discordgo.UserGuild
|
|
GuildIcons []string
|
|
Connections []*discordgo.UserConnection
|
|
}{
|
|
User: myUser,
|
|
UserAccent: fmt.Sprintf("#%06x", myUser.AccentColor),
|
|
Avatar: myUser.AvatarURL("256"),
|
|
Banner: myUser.BannerURL("256"),
|
|
Guilds: myGuilds,
|
|
GuildIcons: guildIcons,
|
|
Connections: myConns,
|
|
})
|
|
if err != nil {
|
|
fmt.Println("Template execute error:", err)
|
|
return
|
|
}
|
|
}
|
|
|
|
func (m *Module) loginPage(rw http.ResponseWriter, req *http.Request, state *utils.State) {
|
|
if myUser, ok := utils.GetStateValue[*string](state, KeyUser); ok {
|
|
if myUser != nil {
|
|
http.Redirect(rw, req, "/discord", http.StatusTemporaryRedirect)
|
|
return
|
|
}
|
|
}
|
|
|
|
if flowState, ok := utils.GetStateValue[uuid.UUID](state, KeyState); ok {
|
|
q := req.URL.Query()
|
|
if q.Has("code") && q.Has("state") {
|
|
if q.Get("state") == flowState.String() {
|
|
exchange, err := m.oauthClient.Exchange(context.Background(), q.Get("code"))
|
|
if err != nil {
|
|
fmt.Println("Exchange token error:", err)
|
|
return
|
|
}
|
|
c, err := discordgo.New("Bearer " + exchange.AccessToken)
|
|
if err != nil {
|
|
fmt.Println("Create client error:", err)
|
|
return
|
|
}
|
|
state.Put(KeyOauthClient, c)
|
|
state.Put(KeyAccessToken, exchange.AccessToken)
|
|
state.Put(KeyRefreshToken, exchange.RefreshToken)
|
|
http.Redirect(rw, req, "/discord", http.StatusTemporaryRedirect)
|
|
return
|
|
}
|
|
http.Error(rw, "OAuth flow state doesn't match\n", http.StatusBadRequest)
|
|
return
|
|
}
|
|
}
|
|
|
|
flowState := uuid.New()
|
|
state.Put(KeyState, flowState)
|
|
|
|
http.Redirect(rw, req, m.oauthClient.AuthCodeURL(flowState.String(), oauth2.AccessTypeOffline), http.StatusTemporaryRedirect)
|
|
}
|
|
|
|
func connectedLinkFunc(a *discordgo.UserConnection) string {
|
|
switch a.Type {
|
|
case "domain":
|
|
return "https://" + a.Name
|
|
case "github":
|
|
return "https://github.com/" + a.Name
|
|
case "reddit":
|
|
return "https://www.reddit.com/u/" + a.Name
|
|
case "spotify":
|
|
return "https://open.spotify.com/user/" + a.ID
|
|
case "steam":
|
|
return "https://steamcommunity.com/profiles/" + a.ID
|
|
case "twitch":
|
|
return "https://www.twitch.tv/" + a.Name
|
|
case "twitter":
|
|
return "https://twitter.com/" + a.Name
|
|
case "youtube":
|
|
return "https://www.youtube.com/channel/" + a.ID
|
|
}
|
|
return ""
|
|
}
|