diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6abb452..08c41e5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -11,5 +11,5 @@ jobs: with: go-version: ${{ matrix.go-version }} - uses: actions/checkout@v3 - - run: go build ./cmd/primrose/ + - run: go build ./cmd/lotus/ - run: go test ./... diff --git a/.idea/primrose.iml b/.idea/lotus.iml similarity index 100% rename from .idea/primrose.iml rename to .idea/lotus.iml diff --git a/.idea/modules.xml b/.idea/modules.xml index e5ec2e0..4b3d167 100644 --- a/.idea/modules.xml +++ b/.idea/modules.xml @@ -2,7 +2,7 @@ - + \ No newline at end of file diff --git a/.idea/sqldialects.xml b/.idea/sqldialects.xml new file mode 100644 index 0000000..b7d9ab9 --- /dev/null +++ b/.idea/sqldialects.xml @@ -0,0 +1,7 @@ + + + + + + + \ No newline at end of file diff --git a/api/api.go b/api/api.go index 4b18891..bdbadba 100644 --- a/api/api.go +++ b/api/api.go @@ -2,22 +2,23 @@ package api import ( "encoding/json" - "github.com/1f349/primrose/imap" - "github.com/1f349/primrose/smtp" + "github.com/1f349/lotus/imap" + "github.com/1f349/lotus/smtp" "github.com/julienschmidt/httprouter" "net/http" "time" ) -type Conf struct { - Listen string `yaml:"listen"` -} - -func SetupApiServer(conf Conf, send *smtp.Smtp, recv *imap.Imap) *http.Server { +func SetupApiServer(listen string, auth func(callback AuthCallback) httprouter.Handle, send *smtp.Smtp, recv *imap.Imap) *http.Server { r := httprouter.New() - // smtp - r.POST("/message", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { + // === ACCOUNT === + r.GET("/account", auth(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) { + // TODO(melon): find users aliases and other account data + })) + + // === SMTP === + r.POST("/message", auth(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) { // check body exists if req.Body == nil { rw.WriteHeader(http.StatusBadRequest) @@ -32,6 +33,11 @@ func SetupApiServer(conf Conf, send *smtp.Smtp, recv *imap.Imap) *http.Server { return } + // TODO(melon): add alias support + if j.From == b.Subject { + + } + mail, err := j.PrepareMail() if err != nil { rw.WriteHeader(http.StatusBadRequest) @@ -44,10 +50,63 @@ func SetupApiServer(conf Conf, send *smtp.Smtp, recv *imap.Imap) *http.Server { } rw.WriteHeader(http.StatusAccepted) - }) + })) + + // === IMAP === + type statusJson struct { + Folder string `json:"folder"` + } + r.GET("/status", auth(imapClient(recv, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, cli *imap.Client, t statusJson) { + status, err := cli.Status(t.Folder) + if err != nil { + rw.WriteHeader(http.StatusForbidden) + return + } + _ = json.NewEncoder(rw).Encode(status) + }))) + r.GET("/list-messages", auth(imapClient(recv, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, cli *imap.Client, t statusJson) { + messages, err := cli.Fetch(t.Folder, 1, 100, 100) + if err != nil { + rw.WriteHeader(http.StatusForbidden) + return + } + _ = json.NewEncoder(rw).Encode(messages) + }))) + r.GET("/search-messages", auth(imapClient(recv, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, cli *imap.Client, t statusJson) { + status, err := cli.Status(t.Folder) + if err != nil { + rw.WriteHeader(http.StatusForbidden) + return + } + _ = json.NewEncoder(rw).Encode(status) + }))) + r.POST("/create-message", auth(imapClient(recv, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, cli *imap.Client, t statusJson) { + status, err := cli.Status(t.Folder) + if err != nil { + rw.WriteHeader(http.StatusForbidden) + return + } + _ = json.NewEncoder(rw).Encode(status) + }))) + r.POST("/update-messages-flags", auth(imapClient(recv, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, cli *imap.Client, t statusJson) { + status, err := cli.Status(t.Folder) + if err != nil { + rw.WriteHeader(http.StatusForbidden) + return + } + _ = json.NewEncoder(rw).Encode(status) + }))) + r.POST("/copy-messages", auth(imapClient(recv, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, cli *imap.Client, t statusJson) { + status, err := cli.Status(t.Folder) + if err != nil { + rw.WriteHeader(http.StatusForbidden) + return + } + _ = json.NewEncoder(rw).Encode(status) + }))) return &http.Server{ - Addr: conf.Listen, + Addr: listen, Handler: r, ReadTimeout: time.Minute, ReadHeaderTimeout: time.Minute, @@ -56,3 +115,33 @@ func SetupApiServer(conf Conf, send *smtp.Smtp, recv *imap.Imap) *http.Server { MaxHeaderBytes: 2500, } } + +// apiError outputs a generic JSON error message +func apiError(rw http.ResponseWriter, code int, m string) { + rw.WriteHeader(code) + _ = json.NewEncoder(rw).Encode(map[string]string{ + "error": m, + }) +} + +type IcCallback[T any] func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, cli *imap.Client, t T) + +func imapClient[T any](recv *imap.Imap, cb IcCallback[T]) AuthCallback { + return func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) { + if req.Body == nil { + rw.WriteHeader(http.StatusBadRequest) + return + } + var t T + if json.NewDecoder(req.Body).Decode(&t) != nil { + rw.WriteHeader(http.StatusBadRequest) + return + } + cli, err := recv.MakeClient(b.Subject) + if err != nil { + rw.WriteHeader(http.StatusInternalServerError) + return + } + cb(rw, req, params, cli, t) + } +} diff --git a/api/auth.go b/api/auth.go new file mode 100644 index 0000000..ae0ea29 --- /dev/null +++ b/api/auth.go @@ -0,0 +1,61 @@ +package api + +import ( + "crypto/subtle" + "github.com/1f349/violet/utils" + "github.com/MrMelon54/mjwt" + "github.com/MrMelon54/mjwt/auth" + "github.com/julienschmidt/httprouter" + "net/http" +) + +type AuthClaims mjwt.BaseTypeClaims[auth.AccessTokenClaims] + +type AuthCallback func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) + +type authChecker struct { + verify mjwt.Verifier + aud string + cb AuthCallback +} + +func (a *authChecker) Handle(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { + // Get bearer token + bearer := utils.GetBearer(req) + if bearer == "" { + apiError(rw, http.StatusForbidden, "Missing bearer token") + return + } + + // Read claims from mjwt + _, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](a.verify, bearer) + if err != nil { + apiError(rw, http.StatusForbidden, "Invalid token") + return + } + + var validAud bool + for _, i := range b.Audience { + if subtle.ConstantTimeCompare([]byte(i), []byte(a.aud)) == 1 { + validAud = true + } + } + if !validAud { + apiError(rw, http.StatusForbidden, "Invalid audience claim") + return + } + + a.cb(rw, req, params, AuthClaims(b)) +} + +// CheckAuth validates the bearer token against a mjwt.Verifier and returns an +// error message or continues to the next handler +func CheckAuth(verify mjwt.Verifier, aud string) func(cb AuthCallback) httprouter.Handle { + return func(cb AuthCallback) httprouter.Handle { + return (&authChecker{ + verify: verify, + aud: aud, + cb: cb, + }).Handle + } +} diff --git a/cmd/lotus/conf.go b/cmd/lotus/conf.go new file mode 100644 index 0000000..dd2daee --- /dev/null +++ b/cmd/lotus/conf.go @@ -0,0 +1,13 @@ +package main + +import ( + "github.com/1f349/lotus/imap" + "github.com/1f349/lotus/smtp" +) + +type Conf struct { + Listen string `yaml:"listen"` + Audience string `yaml:"audience"` + Smtp *smtp.Smtp `yaml:"smtp"` + Imap *imap.Imap `yaml:"imap"` +} diff --git a/cmd/lotus/main.go b/cmd/lotus/main.go new file mode 100644 index 0000000..d89f017 --- /dev/null +++ b/cmd/lotus/main.go @@ -0,0 +1,59 @@ +package main + +import ( + "flag" + "github.com/1f349/lotus/api" + "github.com/1f349/violet/utils" + exitReload "github.com/MrMelon54/exit-reload" + "github.com/MrMelon54/mjwt" + "gopkg.in/yaml.v3" + "log" + "os" + "path/filepath" +) + +var configPath string + +func main() { + flag.StringVar(&configPath, "conf", "", "/path/to/config.yml : path to the config file") + flag.Parse() + + if configPath == "" { + log.Println("[Lotus] Error: config flag is missing") + return + } + + openConf, err := os.Open(configPath) + if err != nil { + if os.IsNotExist(err) { + log.Println("[Lotus] Error: missing config file") + } else { + log.Println("[Lotus] Error: open config file: ", err) + } + return + } + + var conf Conf + err = yaml.NewDecoder(openConf).Decode(&conf) + if err != nil { + log.Println("[Lotus] Error: invalid config file: ", err) + return + } + + wd := filepath.Dir(configPath) + + verify, err := mjwt.NewMJwtVerifierFromFile(filepath.Join(wd, "signer.public.pem")) + if err != nil { + log.Fatalf("[Lotus] Failed to load MJWT verifier public key from file '%s': %s", filepath.Join(wd, "signer.public.pem"), err) + } + + userAuth := api.CheckAuth(verify, conf.Audience) + srv := api.SetupApiServer(conf.Listen, userAuth, conf.Smtp, conf.Imap) + log.Printf("[Lotus] Starting API server on: '%s'\n", srv.Addr) + go utils.RunBackgroundHttp("Lotus", srv) + + exitReload.ExitReload("Lotus", func() {}, func() { + // stop server + srv.Close() + }) +} diff --git a/cmd/primrose/conf.go b/cmd/primrose/conf.go deleted file mode 100644 index b829463..0000000 --- a/cmd/primrose/conf.go +++ /dev/null @@ -1,13 +0,0 @@ -package main - -import ( - "github.com/1f349/primrose/api" - "github.com/1f349/primrose/imap" - "github.com/1f349/primrose/smtp" -) - -type Conf struct { - Smtp smtp.Smtp `yaml:"smtp"` - Imap imap.Imap `yaml:"imap"` - Api api.Conf `yaml:"api"` -} diff --git a/cmd/primrose/main.go b/cmd/primrose/main.go deleted file mode 100644 index 914319c..0000000 --- a/cmd/primrose/main.go +++ /dev/null @@ -1,5 +0,0 @@ -package main - -func main() { - // TODO(Melon): write start up code -} diff --git a/go.mod b/go.mod index 3e845db..43e6ae7 100644 --- a/go.mod +++ b/go.mod @@ -1,20 +1,28 @@ -module github.com/1f349/primrose +module github.com/1f349/lotus go 1.21.0 require ( + github.com/1f349/violet v0.0.7 + github.com/MrMelon54/exit-reload v0.0.1 + github.com/MrMelon54/mjwt v0.1.1 github.com/emersion/go-imap v1.2.1 github.com/emersion/go-message v0.16.0 github.com/emersion/go-sasl v0.0.0-20220912192320-0145f2c60ead github.com/emersion/go-smtp v0.17.0 + github.com/go-sql-driver/mysql v1.7.1 + github.com/hydrogen18/memlistener v1.0.0 github.com/julienschmidt/httprouter v1.3.0 github.com/stretchr/testify v1.8.4 + gopkg.in/yaml.v3 v3.0.1 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/emersion/go-textwrapper v0.0.0-20200911093747-65d896831594 // indirect + github.com/golang-jwt/jwt/v4 v4.5.0 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/text v0.12.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 3ffd958..c3e80ca 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,10 @@ +github.com/1f349/violet v0.0.7 h1:FxCAIVjzUzkgGfhGMX7FcvGj+kaJky45PnLfqKNgA8M= +github.com/1f349/violet v0.0.7/go.mod h1:YfKZX9p55Uot8iSDnbqQbAgU717H0rFNo8ieu2wbxI4= +github.com/MrMelon54/exit-reload v0.0.1 h1:sxHa59tNEQMcikwuX2+93lw6Vi1+R7oCRF8a0C3alXc= +github.com/MrMelon54/exit-reload v0.0.1/go.mod h1:PLiSfmUzwdpTTQP3BBfUPhkqPwaIZjx0DuXBnM76Bug= +github.com/MrMelon54/mjwt v0.1.1 h1:m+aTpxbhQCrOPKHN170DQMFR5r938LkviU38unob5Jw= +github.com/MrMelon54/mjwt v0.1.1/go.mod h1:oYrDBWK09Hju98xb+bRQ0wy+RuAzacxYvKYOZchR2Tk= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/emersion/go-imap v1.2.1 h1:+s9ZjMEjOB8NzZMVTM3cCenz2JrQIGGo5j1df19WjTA= @@ -12,8 +19,20 @@ github.com/emersion/go-smtp v0.17.0 h1:tq90evlrcyqRfE6DSXaWVH54oX6OuZOQECEmhWBME github.com/emersion/go-smtp v0.17.0/go.mod h1:qm27SGYgoIPRot6ubfQ/GpiPy/g3PaZAVRxiO/sDUgQ= github.com/emersion/go-textwrapper v0.0.0-20200911093747-65d896831594 h1:IbFBtwoTQyw0fIM5xv1HF+Y+3ZijDR839WMulgxCcUY= github.com/emersion/go-textwrapper v0.0.0-20200911093747-65d896831594/go.mod h1:aqO8z8wPrjkscevZJFVE1wXJrLpC5LtJG7fqLOsPb2U= +github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= +github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= +github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/hydrogen18/memlistener v1.0.0 h1:JR7eDj8HD6eXrc5fWLbSUnfcQFL06PYvCc0DKQnWfaU= +github.com/hydrogen18/memlistener v1.0.0/go.mod h1:qEIFzExnS6016fRpRfxrExeVn2gbClQA99gQhnIcdhE= github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= @@ -23,7 +42,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.12.0 h1:k+n5B8goJNdU7hSvEtMUz3d1Q6D/XW4COJSJR6fN0mc= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/imap/client.go b/imap/client.go index 776933a..d91d5d3 100644 --- a/imap/client.go +++ b/imap/client.go @@ -3,6 +3,7 @@ package imap import ( "github.com/emersion/go-imap" "github.com/emersion/go-imap/client" + "time" ) var imapStatusFlags = []imap.StatusItem{ @@ -17,9 +18,20 @@ type Client struct { ic *client.Client } -func (c *Client) Status(folder string) (*imap.MailboxStatus, error) { - mbox, err := c.ic.Status(folder, imapStatusFlags) - return mbox, err +func (c *Client) Append(name string, flags []string, date time.Time, msg imap.Literal) error { + return c.ic.Append(name, flags, date, msg) +} + +func (c *Client) Copy(seqset *imap.SeqSet, dest string) error { + return c.ic.Copy(seqset, dest) +} + +func (c *Client) Create(name string) error { + return c.ic.Create(name) +} + +func (c *Client) Delete(name string) error { + return c.ic.Delete(name) } func (c *Client) Fetch(folder string, start, end, limit uint32) ([]*imap.Message, error) { @@ -45,12 +57,50 @@ func (c *Client) Fetch(folder string, start, end, limit uint32) ([]*imap.Message done <- c.ic.Fetch(seqSet, []imap.FetchItem{imap.FetchEnvelope}, messages) }() - outMsg := make([]*imap.Message, 0, limit) + out := make([]*imap.Message, 0, limit) for msg := range messages { - outMsg = append(outMsg, msg) + out = append(out, msg) } if err := <-done; err != nil { return nil, err } - return outMsg, nil + return out, nil +} + +func (c *Client) List(ref, name string) ([]*imap.MailboxInfo, error) { + infos := make(chan *imap.MailboxInfo, 1) + done := make(chan error, 1) + go func() { + done <- c.ic.List(ref, name, infos) + }() + + out := make([]*imap.MailboxInfo, 0) + for info := range infos { + out = append(out, info) + } + if err := <-done; err != nil { + return nil, err + } + return out, nil +} + +func (c *Client) Move(seqset *imap.SeqSet, dest string) error { + return c.ic.Move(seqset, dest) +} + +func (c *Client) Noop() error { + return c.ic.Noop() +} + +func (c *Client) Rename(existingName, newName string) error { + return c.ic.Rename(existingName, newName) +} + +func (c *Client) Search(criteria *imap.SearchCriteria) ([]uint32, error) { + return c.ic.Search(criteria) +} + +func (c *Client) Status(name string) (*imap.MailboxStatus, error) { + mbox, err := c.ic.Status(name, imapStatusFlags) + return mbox, err } diff --git a/imap/fake/backend.go b/imap/fake/backend.go new file mode 100644 index 0000000..c954320 --- /dev/null +++ b/imap/fake/backend.go @@ -0,0 +1,20 @@ +package fake + +import ( + "fmt" + "github.com/emersion/go-imap" + "github.com/emersion/go-imap/backend" +) + +type Backend struct { + Debug chan []byte + Username string + Password string +} + +func (i *Backend) Login(connInfo *imap.ConnInfo, username, password string) (backend.User, error) { + if username != i.Username || password != i.Password { + return nil, fmt.Errorf("invalid user") + } + return &User{i.Debug, username}, nil +} diff --git a/imap/fake/mailbox.go b/imap/fake/mailbox.go new file mode 100644 index 0000000..69c8b7b --- /dev/null +++ b/imap/fake/mailbox.go @@ -0,0 +1,63 @@ +package fake + +import ( + "fmt" + "github.com/emersion/go-imap" + "time" +) + +type Mailbox struct { + Debug chan []byte + ImapName string +} + +func (m *Mailbox) Name() string { + return m.ImapName +} + +func (m *Mailbox) Info() (*imap.MailboxInfo, error) { + return &imap.MailboxInfo{ + Attributes: []string{imap.UnmarkedAttr, imap.HasNoChildrenAttr}, + Delimiter: "/", + Name: m.ImapName, + }, nil +} + +func (m *Mailbox) Status(items []imap.StatusItem) (*imap.MailboxStatus, error) { + return &imap.MailboxStatus{ + Name: m.ImapName, + Messages: 1, + }, nil +} + +func (m *Mailbox) SetSubscribed(subscribed bool) error { + return fmt.Errorf("failed to subscribe") +} + +func (m *Mailbox) Check() error { + return nil +} + +func (m *Mailbox) ListMessages(uid bool, seqset *imap.SeqSet, items []imap.FetchItem, ch chan<- *imap.Message) error { + return fmt.Errorf("failed to list messages") +} + +func (m *Mailbox) SearchMessages(uid bool, criteria *imap.SearchCriteria) ([]uint32, error) { + return nil, fmt.Errorf("failed to search messages") +} + +func (m *Mailbox) CreateMessage(flags []string, date time.Time, body imap.Literal) error { + return fmt.Errorf("failed to create message") +} + +func (m *Mailbox) UpdateMessagesFlags(uid bool, seqset *imap.SeqSet, operation imap.FlagsOp, flags []string) error { + return fmt.Errorf("failed to update message flags") +} + +func (m *Mailbox) CopyMessages(uid bool, seqset *imap.SeqSet, dest string) error { + return fmt.Errorf("failed to copy messages") +} + +func (m *Mailbox) Expunge() error { + return fmt.Errorf("failed to expunge") +} diff --git a/imap/fake/user.go b/imap/fake/user.go new file mode 100644 index 0000000..e3c38ff --- /dev/null +++ b/imap/fake/user.go @@ -0,0 +1,39 @@ +package fake + +import ( + "fmt" + "github.com/emersion/go-imap/backend" +) + +type User struct { + Debug chan []byte + ImapUser string +} + +func (i *User) Username() string { + return i.ImapUser +} + +func (i *User) ListMailboxes(subscribed bool) ([]backend.Mailbox, error) { + return []backend.Mailbox{}, nil +} + +func (i *User) GetMailbox(name string) (backend.Mailbox, error) { + return &Mailbox{i.Debug, name}, nil +} + +func (i *User) CreateMailbox(name string) error { + return fmt.Errorf("failed to create mailbox") +} + +func (i *User) DeleteMailbox(name string) error { + return fmt.Errorf("failed to delete mailbox") +} + +func (i *User) RenameMailbox(existingName, newName string) error { + return fmt.Errorf("failed to rename mailbox") +} + +func (i *User) Logout() error { + return nil +} diff --git a/imap/imap.go b/imap/imap.go index 2cb6549..ecf8689 100644 --- a/imap/imap.go +++ b/imap/imap.go @@ -13,9 +13,11 @@ type Imap struct { Separator string `yaml:"separator"` } +var defaultDialer = client.Dial + func (i *Imap) MakeClient(user string) (*Client, error) { // dial imap server - imapClient, err := client.Dial(i.Server) + imapClient, err := defaultDialer(i.Server) if err != nil { return nil, err } diff --git a/imap/imap_test.go b/imap/imap_test.go new file mode 100644 index 0000000..7c0011d --- /dev/null +++ b/imap/imap_test.go @@ -0,0 +1,36 @@ +package imap + +import ( + "github.com/1f349/lotus/imap/fake" + "github.com/emersion/go-imap/client" + "github.com/emersion/go-imap/server" + "github.com/hydrogen18/memlistener" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestImap_MakeClient(t *testing.T) { + listener := memlistener.NewMemoryListener() + serverData := make(chan []byte, 4) + srv := server.New(&fake.Backend{Debug: serverData, Username: "a@localhost*master@localhost", Password: "1234"}) + srv.AllowInsecureAuth = true + go func() { + _ = srv.Serve(listener) + }() + + defaultDialer = func(addr string) (*client.Client, error) { + dial, err := listener.Dial("", "") + if err != nil { + return nil, err + } + return client.New(dial) + } + + i := &Imap{Server: "localhost", Username: "master@localhost", Password: "1234", Separator: "*"} + cli, err := i.MakeClient("a@localhost") + assert.NoError(t, err) + status, err := cli.Status("INBOX") + assert.NoError(t, err) + assert.Equal(t, "INBOX", status.Name) + assert.Equal(t, uint32(1), status.Messages) +} diff --git a/postfix-config/comma-list-scanner/comma-list-scanner.go b/postfix-config/comma-list-scanner/comma-list-scanner.go new file mode 100644 index 0000000..461eb99 --- /dev/null +++ b/postfix-config/comma-list-scanner/comma-list-scanner.go @@ -0,0 +1,44 @@ +package comma_list_scanner + +import ( + "bufio" + "bytes" + "io" +) + +type CommaListScanner struct { + r *bufio.Scanner + text string + err error +} + +func NewCommaListScanner(r io.Reader) *CommaListScanner { + s := bufio.NewScanner(r) + s.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := bytes.IndexByte(data, ','); i >= 0 { + return i + 1, bytes.TrimSpace(data[0:i]), nil + } + // If we're at EOF, we have a final non-terminated line. Return it. + if atEOF { + return len(data), bytes.TrimSpace(data), nil + } + // Request more data. + return 0, nil, nil + }) + return &CommaListScanner{r: s} +} + +func (c *CommaListScanner) Scan() bool { + if c.r.Scan() { + c.text = c.r.Text() + } + c.err = c.r.Err() + return false +} + +func (c *CommaListScanner) Text() string { + return c.text +} diff --git a/postfix-config/comma-list-scanner/comma-list-scanner_test.go b/postfix-config/comma-list-scanner/comma-list-scanner_test.go new file mode 100644 index 0000000..7e8fa3b --- /dev/null +++ b/postfix-config/comma-list-scanner/comma-list-scanner_test.go @@ -0,0 +1,27 @@ +package comma_list_scanner + +import ( + "github.com/stretchr/testify/assert" + "strings" + "testing" +) + +var testCommaList = []struct { + text string + out []string +}{ + {"hello, wow this is cool, amazing", []string{"hello", "wow this is cool", "amazing"}}, + {"hello, wow this is cool, amazing", []string{"hello", "wow this is cool", "amazing"}}, +} + +func TestNewCommaListScanner(t *testing.T) { + for _, i := range testCommaList { + t.Run(i.text, func(t *testing.T) { + s := NewCommaListScanner(strings.NewReader(i.text)) + n := 0 + for s.Scan() { + assert.Equal(t, i.out[n], s.Text()) + } + }) + } +} diff --git a/postfix-config/config-parser/config-parser.go b/postfix-config/config-parser/config-parser.go new file mode 100644 index 0000000..f349ea2 --- /dev/null +++ b/postfix-config/config-parser/config-parser.go @@ -0,0 +1,49 @@ +package config_parser + +import ( + "bufio" + "errors" + "io" + "strings" +) + +var ErrInvalidConfigLine = errors.New("invalid config line") + +type ConfigParser struct { + s *bufio.Scanner + pair [2]string + err error +} + +func NewConfigParser(r io.Reader) *ConfigParser { + return &ConfigParser{s: bufio.NewScanner(r)} +} + +func (c *ConfigParser) Scan() bool { +scanAgain: + if !c.s.Scan() { + return false + } + text := strings.TrimSpace(c.s.Text()) + if text == "" || strings.HasPrefix(text, "#") { + goto scanAgain + } + n := strings.IndexByte(text, '=') + if n < 2 || n+2 >= len(text) || text[n-1] != ' ' || text[n+1] != ' ' { + c.err = ErrInvalidConfigLine + return false + } + c.pair = [2]string{text[:n-1], text[n+2:]} + return true +} + +func (c *ConfigParser) Pair() (string, string) { + return c.pair[0], c.pair[1] +} + +func (c *ConfigParser) Err() error { + if c.err != nil { + return c.err + } + return c.s.Err() +} diff --git a/postfix-config/config-parser/config-parser_test.go b/postfix-config/config-parser/config-parser_test.go new file mode 100644 index 0000000..b1762e1 --- /dev/null +++ b/postfix-config/config-parser/config-parser_test.go @@ -0,0 +1,39 @@ +package config_parser + +import ( + "github.com/stretchr/testify/assert" + "strings" + "testing" +) + +var configParserData = []struct { + Input string + Values [][2]string +}{ + { + "a = a", + [][2]string{{"a", "a"}}, + }, + { + " a = a ", + [][2]string{{"a", "a"}}, + }, + { + " # this is a comment\n a = a, b\nb = c, d", + [][2]string{{"a", "a, b"}, {"b", "c, d"}}, + }, +} + +func TestConfigParser(t *testing.T) { + for _, i := range configParserData { + t.Run(i.Input, func(t *testing.T) { + a := NewConfigParser(strings.NewReader(i.Input)) + n := 0 + for a.Scan() { + assert.False(t, n >= len(i.Values)) + assert.Equal(t, i.Values[n], a.pair) + n++ + } + }) + } +} diff --git a/postfix-config/config.go b/postfix-config/config.go new file mode 100644 index 0000000..e36c104 --- /dev/null +++ b/postfix-config/config.go @@ -0,0 +1,38 @@ +package postfix_config + +import mapProvider "github.com/1f349/lotus/postfix-config/map-provider" + +type Config struct { + // same + VirtualMailboxDomains mapProvider.MapProvider + VirtualAliasMaps mapProvider.MapProvider + VirtualMailboxMaps mapProvider.MapProvider + AliasMaps mapProvider.MapProvider + LocalRecipientMaps mapProvider.MapProvider + SmtpdSenderLoginMaps string // TODO(melon): union map? +} + +func (c *Config) SetKey(k string, m mapProvider.MapProvider) { + switch k { + case "virtual_mailbox_domains": + c.VirtualMailboxDomains = m + case "virtual_alias_maps": + c.VirtualAliasMaps = m + case "virtual_mailbox_maps": + c.VirtualMailboxMaps = m + case "alias_maps": + c.AliasMaps = m + case "local_recipient_maps": + c.LocalRecipientMaps = m + case "smtpd_sender_login_maps": + c.SmtpdSenderLoginMaps = "" + } +} + +func (c *Config) NeedsMapProvider(k string) bool { + switch k { + case "virtual_mailbox_domains", "virtual_alias_maps", "virtual_mailbox_maps", "alias_maps", "local_recipient_maps", "smtpd_sender_login_maps": + return true + } + return false +} diff --git a/postfix-config/decoder.go b/postfix-config/decoder.go new file mode 100644 index 0000000..4dfd143 --- /dev/null +++ b/postfix-config/decoder.go @@ -0,0 +1,49 @@ +package postfix_config + +import ( + "bufio" + "fmt" + configParser "github.com/1f349/lotus/postfix-config/config-parser" + mapProvider "github.com/1f349/lotus/postfix-config/map-provider" + "io" + "strings" +) + +type Decoder struct { + r *configParser.ConfigParser + v *Config + t map[string]string +} + +func NewDecoder(r io.Reader) *Decoder { + return &Decoder{r: configParser.NewConfigParser(r)} +} + +func (d *Decoder) Load() error { + d.v = &Config{} + for d.r.Scan() { + k, v := d.r.Pair() + if d.v.NeedsMapProvider(k) { + m := mapProvider.SequenceMapProvider{} + + s := bufio.NewScanner(strings.NewReader(v)) + s.Split(bufio.ScanWords) + for s.Scan() { + a := s.Text() + println("a", a) + if strings.HasPrefix(a, "$") { + // is variable + } + n := strings.IndexByte(a, ':') + if n == -1 { + return fmt.Errorf("missing prefix") + } + } + if err := s.Err(); err != nil { + return err + } + d.v.SetKey(k, m) + } + } + return d.r.Err() +} diff --git a/postfix-config/decoder_test.go b/postfix-config/decoder_test.go new file mode 100644 index 0000000..e646731 --- /dev/null +++ b/postfix-config/decoder_test.go @@ -0,0 +1,18 @@ +package postfix_config + +import ( + "bytes" + _ "embed" + configParser "github.com/1f349/lotus/postfix-config/config-parser" + "github.com/stretchr/testify/assert" + "testing" +) + +//go:embed example.cf +var exampleConfig []byte + +func TestDecoder_Load(t *testing.T) { + b := bytes.NewReader(exampleConfig) + d := &Decoder{r: configParser.NewConfigParser(b)} + assert.NoError(t, d.Load()) +} diff --git a/postfix-config/example.cf b/postfix-config/example.cf new file mode 100644 index 0000000..c16fd44 --- /dev/null +++ b/postfix-config/example.cf @@ -0,0 +1,10 @@ +# this only contains the relevant config properties + +recipient_delimiter = + + +virtual_mailbox_domains = mysql:/etc/postfix/sql/mysql_virtual_domains_maps.cf +virtual_alias_maps = mysql:/etc/postfix/sql/mysql_virtual_alias_maps.cf, mysql:/etc/postfix/sql/mysql_virtual_alias_wildcard_maps.cf, mysql:/etc/postfix/sql/mysql_virtual_alias_domain_maps.cf, mysql:/etc/postfix/sql/mysql_virtual_alias_user_maps.cf, mysql:/etc/postfix/sql/mysql_virtual_alias_userdomain_maps.cf, mysql:/etc/postfix/sql/mysql_virtual_alias_domain_catchall_maps.cf, mysql:/etc/postfix/sql/mysql_virtual_alias_user_catchall_maps.cf +virtual_mailbox_maps = mysql:/etc/postfix/sql/mysql_virtual_mailbox_maps.cf, mysql:/etc/postfix/sql/mysql_virtual_alias_domain_mailbox_maps.cf, mysql:/etc/postfix/sql/mysql_virtual_alias_user_mailbox_maps.cf, mysql:/etc/postfix/sql/mysql_virtual_alias_userdomain_mailbox_maps.cf +alias_maps = hash:/etc/aliases $virtual_alias_maps +local_recipient_maps = $virtual_mailbox_maps $alias_maps +smtpd_sender_login_maps = unionmap:{ hash:/etc/aliases, mysql:/etc/postfix/sql/mysql_sender_alias_maps.cf, mysql:/etc/postfix/sql/mysql_virtual_alias_maps.cf, mysql:/etc/postfix/sql/mysql_virtual_alias_domain_maps.cf, mysql:/etc/postfix/sql/mysql_virtual_alias_user_maps.cf, mysql:/etc/postfix/sql/mysql_virtual_alias_userdomain_maps.cf } diff --git a/postfix-config/map-provider/hash.go b/postfix-config/map-provider/hash.go new file mode 100644 index 0000000..025e82a --- /dev/null +++ b/postfix-config/map-provider/hash.go @@ -0,0 +1,48 @@ +package map_provider + +import ( + "bufio" + "io" + "os" + "strings" +) + +type Hash struct { + r io.Reader + v map[string]string +} + +var _ MapProvider = &Hash{} + +func NewHashMapProvider(filename string) (*Hash, error) { + open, err := os.Open(filename) + if err != nil { + return nil, err + } + return &Hash{open, make(map[string]string)}, nil +} + +func (h *Hash) Load() error { + scanner := bufio.NewScanner(h.r) + scanner.Split(bufio.ScanLines) + for scanner.Scan() { + text := strings.TrimSpace(scanner.Text()) + if strings.HasPrefix(text, "#") { + continue + } + + n := strings.IndexByte(text, ':') + key := strings.TrimSpace(text[:n]) + values := strings.Split(text[n+1:], ",") + for _, i := range values { + k := strings.TrimSpace(i) + h.v[k] = key + } + } + return scanner.Err() +} + +func (h *Hash) Find(name string) (string, bool) { + v, ok := h.v[name] + return v, ok +} diff --git a/postfix-config/map-provider/hash_example.txt b/postfix-config/map-provider/hash_example.txt new file mode 100644 index 0000000..b8c3f94 --- /dev/null +++ b/postfix-config/map-provider/hash_example.txt @@ -0,0 +1,3 @@ +# See man 5 aliases for format +postmaster: root +test: this, is, an, example diff --git a/postfix-config/map-provider/hash_test.go b/postfix-config/map-provider/hash_test.go new file mode 100644 index 0000000..846b1b1 --- /dev/null +++ b/postfix-config/map-provider/hash_test.go @@ -0,0 +1,23 @@ +package map_provider + +import ( + "bytes" + _ "embed" + "github.com/stretchr/testify/assert" + "testing" +) + +//go:embed hash_example.txt +var hashExample []byte + +func TestHash_Load(t *testing.T) { + h := &Hash{r: bytes.NewReader(hashExample), v: make(map[string]string)} + assert.NoError(t, h.Load()) + assert.Equal(t, map[string]string{ + "root": "postmaster", + "this": "test", + "is": "test", + "an": "test", + "example": "test", + }, h.v) +} diff --git a/postfix-config/map-provider/map-provider.go b/postfix-config/map-provider/map-provider.go new file mode 100644 index 0000000..be220db --- /dev/null +++ b/postfix-config/map-provider/map-provider.go @@ -0,0 +1,23 @@ +package map_provider + +// MapProvider is an interface to allow looking up mapped values from variables, +// hash files or mysql queries. +type MapProvider interface { + Find(name string) (string, bool) +} + +// SequenceMapProvider calls Find against each provider in a slice and outputs +// the true mapped value of the input. first mapped value found. If the input was +// not found then "", false is returned. +type SequenceMapProvider []MapProvider + +func (s SequenceMapProvider) Find(name string) (string, bool) { + for _, i := range s { + if find, ok := i.Find(name); ok { + return find, true + } + } + return "", false +} + +var _ MapProvider = SequenceMapProvider{} diff --git a/postfix-config/map-provider/mysql-prepared-query.go b/postfix-config/map-provider/mysql-prepared-query.go new file mode 100644 index 0000000..757fd48 --- /dev/null +++ b/postfix-config/map-provider/mysql-prepared-query.go @@ -0,0 +1,76 @@ +package map_provider + +import ( + "errors" + "sort" + "strings" + "unicode" +) + +var ( + ErrMissingArgument = errors.New("missing argument") + ErrInvalidRawQuery = errors.New("invalid raw query") +) + +type PreparedQuery struct { + raw string + params map[int]byte +} + +func NewPreparedQuery(raw string) (*PreparedQuery, error) { + var s strings.Builder + origin := 0 + params := make(map[int]byte) + for { + n := strings.IndexByte(raw[origin:], '%') + if n == -1 { + break + } + n += origin + if n+1 == len(raw) { + return nil, ErrInvalidRawQuery + } + s.WriteString(raw[origin:n]) + if raw[n+1] == '%' { + s.WriteByte('%') + origin = n + 1 + continue + } + params[s.Len()] = toLower(raw[n+1]) + origin = n + 2 + } + s.WriteString(raw[origin:]) + return &PreparedQuery{ + raw: s.String(), + params: params, + }, nil +} + +func (p *PreparedQuery) Format(args map[byte]string) (string, error) { + var s strings.Builder + keys := make([]int, 0, len(p.params)) + for k := range p.params { + keys = append(keys, k) + } + sort.Ints(keys) + origin := 0 + for _, k := range keys { + r, ok := args[p.params[k]] + if !ok { + return "", ErrMissingArgument + } + + // write up to and including the next parameter + s.WriteString(p.raw[origin:k]) + s.WriteString(strings.ReplaceAll(r, "'", "")) + origin = k + } + + // write the rest of the query + s.WriteString(p.raw[origin:]) + return s.String(), nil +} + +func toLower(a byte) byte { + return byte(unicode.ToLower(rune(a))) +} diff --git a/postfix-config/map-provider/mysql-prepared-query_test.go b/postfix-config/map-provider/mysql-prepared-query_test.go new file mode 100644 index 0000000..e94fc5f --- /dev/null +++ b/postfix-config/map-provider/mysql-prepared-query_test.go @@ -0,0 +1,40 @@ +package map_provider + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +const ( + testQuery = "SELECT aliasMap.goto FROM aliasMap,aliasdomainMap WHERE aliasdomainMap.domain='%d' AND aliasMap.address = CONCAT('%u', '@', aliasdomainMap.goto) AND aliasMap.active > 0 AND aliasdomainMap.active > 0" + testQueryRaw = "SELECT aliasMap.goto FROM aliasMap,aliasdomainMap WHERE aliasdomainMap.domain='' AND aliasMap.address = CONCAT('', '@', aliasdomainMap.goto) AND aliasMap.active > 0 AND aliasdomainMap.active > 0" + testQueryFormat = "SELECT aliasMap.goto FROM aliasMap,aliasdomainMap WHERE aliasdomainMap.domain='example.com' AND aliasMap.address = CONCAT('test', '@', aliasdomainMap.goto) AND aliasMap.active > 0 AND aliasdomainMap.active > 0" +) + +func TestNewPreparedQuery(t *testing.T) { + query, err := NewPreparedQuery(testQuery) + assert.NoError(t, err) + assert.Equal(t, PreparedQuery{ + raw: testQueryRaw, + params: map[int]byte{ + 79: 'd', + 112: 'u', + }, + }, *query) +} + +func TestPreparedQuery_Format(t *testing.T) { + query := &PreparedQuery{ + raw: testQueryRaw, + params: map[int]byte{ + 79: 'd', + 112: 'u', + }, + } + format, err := query.Format(map[byte]string{ + 'd': "example.com", + 'u': "test", + }) + assert.NoError(t, err) + assert.Equal(t, testQueryFormat, format) +} diff --git a/postfix-config/map-provider/mysql.go b/postfix-config/map-provider/mysql.go new file mode 100644 index 0000000..937fcef --- /dev/null +++ b/postfix-config/map-provider/mysql.go @@ -0,0 +1,112 @@ +package map_provider + +import ( + "database/sql" + configParser "github.com/1f349/lotus/postfix-config/config-parser" + "github.com/go-sql-driver/mysql" + _ "github.com/go-sql-driver/mysql" + "io" + "os" + "regexp" + "strings" +) + +var checkUatD = regexp.MustCompile("^[^@]+@[^@]+$") + +type MySql struct { + r io.Reader + db *sql.DB + query *PreparedQuery +} + +var _ MapProvider = &MySql{} + +func NewMySqlMapProvider(filename string) (*MySql, error) { + open, err := os.Open(filename) + if err != nil { + return nil, err + } + return &MySql{r: open}, nil +} + +func (m *MySql) Load() error { + p := configParser.NewConfigParser(m.r) + c := mysql.NewConfig() + var q string + for p.Scan() { + k, v := p.Pair() + switch k { + case "user": + c.User = v + case "password": + c.Passwd = v + case "hosts": + c.Net = "tcp" + c.Addr = v + case "dbname": + c.DBName = v + case "query": + q = v + } + } + if err := p.Err(); err != nil { + return err + } + + q2, err := NewPreparedQuery(q) + if err != nil { + return err + } + m.query = q2 + + // try opening connection + db, err := sql.Open("mysql", c.FormatDSN()) + if err != nil { + return err + } + m.db = db + + return db.Ping() +} + +func (m *MySql) Find(name string) (string, bool) { + format, err := m.query.Format(genQueryArgs(name)) + return format, err == nil +} + +// genQueryArgs converts an input key into the % encoded parameters +// +// %s - full input key +// %u - user part of user@domain or full input key +// %d - domain part of user@domain or missing parameter +// %[1-9] - replaced with the most significant component of the input key's domain +// for `user@mail.example.com` %1 = com, %2 = example, %3 = mail +// otherwise they are missing parameters +func genQueryArgs(name string) map[byte]string { + args := make(map[byte]string) + args['s'] = name + args['u'] = name + if checkUatD.MatchString(name) { + n := strings.IndexByte(name, '@') + args['u'] = name[:n] + args['d'] = name[n+1:] + + genDomainArgs(args, name[n+1:]) + } + return args +} + +// genDomainArgs replaces with the most significant component of the input key's +// domain for `user@mail.example.com` %1 = com, %2 = example, %3 = mail, +// otherwise they are missing parameters +func genDomainArgs(args map[byte]string, s string) { + i, l := byte(1), len(s) + for { + n := strings.LastIndexByte(s, '.') + if n == -1 { + break + } + args[(i + '0')] = s[n+1 : l] + l = n + } +} diff --git a/postfix-config/map-provider/mysql_example.cf b/postfix-config/map-provider/mysql_example.cf new file mode 100644 index 0000000..e22ea38 --- /dev/null +++ b/postfix-config/map-provider/mysql_example.cf @@ -0,0 +1,5 @@ +user = example +password = 1234 +hosts = 127.0.0.1 +dbname = mail +query = SELECT aliasMap.goto FROM aliasMap,aliasdomainMap WHERE aliasdomainMap.domain='%d' AND aliasMap.address = CONCAT('%u', '@', aliasdomainMap.goto) AND aliasMap.active > 0 AND aliasdomainMap.active > 0 diff --git a/postfix-config/map-provider/variable.go b/postfix-config/map-provider/variable.go new file mode 100644 index 0000000..a0dc17f --- /dev/null +++ b/postfix-config/map-provider/variable.go @@ -0,0 +1,12 @@ +package map_provider + +type Variable struct { + Name string + Value MapProvider +} + +func (v *Variable) Find(name string) (string, bool) { + return v.Value.Find(name) +} + +var _ MapProvider = &Variable{} diff --git a/smtp/fake/fake-smtp.go b/smtp/fake/fake-smtp.go new file mode 100644 index 0000000..49f6d95 --- /dev/null +++ b/smtp/fake/fake-smtp.go @@ -0,0 +1,45 @@ +package fake + +import ( + "github.com/emersion/go-smtp" + "io" + "log" +) + +type SmtpBackend struct { + Debug chan []byte +} + +func (f *SmtpBackend) NewSession(c *smtp.Conn) (smtp.Session, error) { + return &SmtpSession{f.Debug}, nil +} + +type SmtpSession struct { + Debug chan []byte +} + +func (f *SmtpSession) Reset() {} + +func (f *SmtpSession) Logout() error { return nil } + +func (f *SmtpSession) AuthPlain(username, password string) error { return nil } + +func (f *SmtpSession) Mail(from string, opts *smtp.MailOptions) error { + log.Println("MAIL " + from) + f.Debug <- []byte("MAIL " + from + "\n") + return nil +} + +func (f *SmtpSession) Rcpt(to string) error { + f.Debug <- []byte("RCPT " + to + "\n") + return nil +} + +func (f *SmtpSession) Data(r io.Reader) error { + all, err := io.ReadAll(r) + if err != nil { + return err + } + f.Debug <- all + return nil +} diff --git a/smtp/smtp.go b/smtp/smtp.go index 4201314..20e3bc7 100644 --- a/smtp/smtp.go +++ b/smtp/smtp.go @@ -16,9 +16,11 @@ type Mail struct { body []byte } +var defaultDialer = smtp.Dial + func (s *Smtp) Send(mail *Mail) error { // dial smtp server - smtpClient, err := smtp.Dial(s.Server) + smtpClient, err := defaultDialer(s.Server) if err != nil { return err } diff --git a/smtp/smtp_test.go b/smtp/smtp_test.go index 96adbff..3e1fdb5 100644 --- a/smtp/smtp_test.go +++ b/smtp/smtp_test.go @@ -1,11 +1,64 @@ package smtp import ( + "bytes" + "github.com/1f349/lotus/smtp/fake" + "github.com/emersion/go-message" "github.com/emersion/go-message/mail" + "github.com/emersion/go-smtp" + "github.com/hydrogen18/memlistener" "github.com/stretchr/testify/assert" + "log" + "strings" "testing" + "time" ) +var sendTestMessage []byte + +func init() { + var h mail.Header + h.SetDate(time.Date(2000, time.January, 1, 0, 0, 0, 0, time.Local)) + h.SetSubject("Happy Millennium") + h.SetAddressList("From", []*mail.Address{{Name: "Test", Address: "test@localhost"}}) + h.SetAddressList("To", []*mail.Address{{Name: "A", Address: "a@localhost"}}) + h.Set("Content-Type", "text/plain; charset=utf-8") + entity, err := message.New(h.Header, strings.NewReader("Thanks")) + if err != nil { + log.Fatal(err) + } + out := new(bytes.Buffer) + if entity.WriteTo(out) != nil { + log.Fatal(err) + } + sendTestMessage = out.Bytes() +} + +func TestSmtp_Send(t *testing.T) { + listener := memlistener.NewMemoryListener() + serverData := make(chan []byte, 4) + server := smtp.NewServer(&fake.SmtpBackend{Debug: serverData}) + go func() { + _ = server.Serve(listener) + }() + + defaultDialer = func(addr string) (*smtp.Client, error) { + dial, err := listener.Dial("", "") + if err != nil { + return nil, err + } + return smtp.NewClient(dial, "localhost") + } + + s := &Smtp{Server: "localhost:25"} + err := s.Send(&Mail{from: "test@localhost", deliver: []string{"a@localhost", "b@localhost"}, body: sendTestMessage}) + assert.NoError(t, err) + assert.Equal(t, []byte("MAIL test@localhost\n"), <-serverData) + assert.Equal(t, []byte("RCPT a@localhost\n"), <-serverData) + assert.Equal(t, []byte("RCPT b@localhost\n"), <-serverData) + assert.Equal(t, append(sendTestMessage, '\r', '\n'), <-serverData) +} + func TestCreateSenderSlice(t *testing.T) { a := []*mail.Address{{Address: "a@example.com"}, {Address: "b@example.com"}} b := []*mail.Address{{Address: "a@example.com"}, {Address: "c@example.com"}}