Replace postfix config parser with query command

This commit is contained in:
Melon 2023-08-23 16:17:01 +01:00
parent d6048ccc6e
commit 9b87a8a857
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
42 changed files with 455 additions and 781 deletions

View File

@ -1,7 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="SqlDialectMappings"> <component name="SqlDialectMappings">
<file url="file://$PROJECT_DIR$/postfix-config/map-provider/mysql-prepared-query_test.go" dialect="GenericSQL" />
<file url="PROJECT" dialect="MySQL" /> <file url="PROJECT" dialect="MySQL" />
</component> </component>
</project> </project>

View File

@ -3,13 +3,12 @@ package api
import ( import (
"encoding/json" "encoding/json"
"github.com/1f349/lotus/imap" "github.com/1f349/lotus/imap"
"github.com/1f349/lotus/smtp"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"net/http" "net/http"
"time" "time"
) )
func SetupApiServer(listen string, auth func(callback AuthCallback) httprouter.Handle, send *smtp.Smtp, recv *imap.Imap) *http.Server { func SetupApiServer(listen string, auth func(callback AuthCallback) httprouter.Handle, send Smtp, recv Imap) *http.Server {
r := httprouter.New() r := httprouter.New()
// === ACCOUNT === // === ACCOUNT ===
@ -18,39 +17,7 @@ func SetupApiServer(listen string, auth func(callback AuthCallback) httprouter.H
})) }))
// === SMTP === // === SMTP ===
r.POST("/message", auth(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) { r.POST("/message", auth(MessageSender(send)))
// check body exists
if req.Body == nil {
rw.WriteHeader(http.StatusBadRequest)
return
}
// parse json body
var j smtp.Json
err := json.NewDecoder(req.Body).Decode(&j)
if err != nil {
rw.WriteHeader(http.StatusBadRequest)
return
}
// TODO(melon): add alias support
if j.From == b.Subject {
}
mail, err := j.PrepareMail()
if err != nil {
rw.WriteHeader(http.StatusBadRequest)
return
}
if send.Send(mail) != nil {
rw.WriteHeader(http.StatusInternalServerError)
return
}
rw.WriteHeader(http.StatusAccepted)
}))
// === IMAP === // === IMAP ===
type statusJson struct { type statusJson struct {
@ -126,7 +93,7 @@ func apiError(rw http.ResponseWriter, code int, m string) {
type IcCallback[T any] func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, cli *imap.Client, t T) type IcCallback[T any] func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, cli *imap.Client, t T)
func imapClient[T any](recv *imap.Imap, cb IcCallback[T]) AuthCallback { func imapClient[T any](recv Imap, cb IcCallback[T]) AuthCallback {
return func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) { return func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
if req.Body == nil { if req.Body == nil {
rw.WriteHeader(http.StatusBadRequest) rw.WriteHeader(http.StatusBadRequest)

14
api/interfaces.go Normal file
View File

@ -0,0 +1,14 @@
package api
import (
"github.com/1f349/lotus/imap"
"github.com/1f349/lotus/smtp"
)
type Smtp interface {
Send(mail *smtp.Mail) error
}
type Imap interface {
MakeClient(user string) (*imap.Client, error)
}

71
api/send-message.go Normal file
View File

@ -0,0 +1,71 @@
package api
import (
"encoding/json"
"errors"
postfixLookup "github.com/1f349/lotus/postfix-lookup"
"github.com/1f349/lotus/smtp"
"github.com/julienschmidt/httprouter"
"net/http"
"time"
)
var defaultPostfixLookup = postfixLookup.NewPostfixLookup().Lookup
var timeNow = time.Now
// MessageSender is the internal handler for `POST /message` requests
// the access token is already validated at this point
func MessageSender(send Smtp) func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
return func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, b AuthClaims) {
// check body exists
if req.Body == nil {
apiError(rw, http.StatusBadRequest, "Missing request body")
return
}
// parse json body
var j smtp.Json
err := json.NewDecoder(req.Body).Decode(&j)
if err != nil {
apiError(rw, http.StatusBadRequest, "Invalid JSON body")
return
}
// prepare the mail for sending
mail, err := j.PrepareMail(timeNow())
if err != nil {
apiError(rw, http.StatusBadRequest, "Invalid mail message")
return
}
// this looks up the underlying account for the sender alias
println(mail.From)
lookup, err := defaultPostfixLookup(mail.From)
// the alias does not exist
if errors.Is(err, postfixLookup.ErrInvalidAlias) {
apiError(rw, http.StatusBadRequest, "Invalid sender alias")
return
}
// the alias lookup failed to run
if err != nil {
apiError(rw, http.StatusInternalServerError, "Sender alias lookup failed")
return
}
// the alias does not match the logged-in user
if lookup != b.Subject {
apiError(rw, http.StatusBadRequest, "User does not own sender alias")
return
}
// try sending the mail
if send.Send(mail) != nil {
apiError(rw, http.StatusInternalServerError, "Failed to send mail")
return
}
rw.WriteHeader(http.StatusAccepted)
}
}

240
api/send-message_test.go Normal file
View File

@ -0,0 +1,240 @@
package api
import (
"bytes"
"encoding/json"
"errors"
"fmt"
postfixLookup "github.com/1f349/lotus/postfix-lookup"
"github.com/1f349/lotus/smtp"
"github.com/MrMelon54/mjwt/auth"
"github.com/MrMelon54/mjwt/claims"
"github.com/golang-jwt/jwt/v4"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"io"
"net/http"
"net/http/httptest"
"slices"
"strings"
"testing"
"time"
)
func init() {
defaultPostfixLookup = func(key string) (string, error) {
switch key {
case "noreply@example.com", "admin@example.com":
return "admin@example.com", nil
case "user@example.com":
return "user@example.com", nil
}
return "", postfixLookup.ErrInvalidAlias
}
timeNow = func() time.Time {
return time.Date(2000, time.January, 1, 0, 0, 0, 0, time.UTC)
}
}
type fakeSmtp struct {
from string
deliver []string
body []byte
}
func (f *fakeSmtp) Send(mail *smtp.Mail) error {
if mail.From != f.from {
return fmt.Errorf("test fail: invalid from address")
}
if !slices.Equal(mail.Deliver, f.deliver) {
return fmt.Errorf("test fail: invalid deliver slice")
}
if !slices.Equal(mail.Body, f.body) {
return fmt.Errorf("test fail: invalid message body")
}
return nil
}
type fakeFailedSmtp struct{}
func (f *fakeFailedSmtp) Send(mail *smtp.Mail) error {
return errors.New("sending failed")
}
var messageSenderTestData = []struct {
req func() (*http.Request, error)
smtp Smtp
claims AuthClaims
status int
output string
}{
{
req: func() (*http.Request, error) {
return http.NewRequest(http.MethodPost, "https://api.example.com/v1/mail/message", nil)
},
smtp: &fakeSmtp{},
claims: makeFakeAuthClaims("admin@example.com"),
status: http.StatusBadRequest,
output: "Missing request body",
},
{
req: func() (*http.Request, error) {
return http.NewRequest(http.MethodPost, "https://api.example.com/v1/mail/message", strings.NewReader(`{`))
},
smtp: &fakeSmtp{},
claims: makeFakeAuthClaims("admin@example.com"),
status: http.StatusBadRequest,
output: "Invalid JSON body",
},
{
req: func() (*http.Request, error) {
j, err := json.Marshal(smtp.Json{
From: "noreply2@example.com",
ReplyTo: "admin@example.com",
To: "user@example.com",
Subject: "Test Subject",
BodyType: "plain",
Body: "Plain text",
})
if err != nil {
return nil, err
}
return http.NewRequest(http.MethodPost, "https://api.example.com/v1/mail/message", bytes.NewReader(j))
},
smtp: &fakeSmtp{},
claims: makeFakeAuthClaims("admin@example.com"),
status: http.StatusBadRequest,
output: "Invalid sender alias",
},
{
req: func() (*http.Request, error) {
j, err := json.Marshal(smtp.Json{
From: "user@example.com",
ReplyTo: "admin@example.com",
To: "user@example.com",
Subject: "Test Subject",
BodyType: "plain",
Body: "Plain text",
})
if err != nil {
return nil, err
}
return http.NewRequest(http.MethodPost, "https://api.example.com/v1/mail/message", bytes.NewReader(j))
},
smtp: &fakeSmtp{},
claims: makeFakeAuthClaims("admin@example.com"),
status: http.StatusBadRequest,
output: "User does not own sender alias",
},
{
req: func() (*http.Request, error) {
j, err := json.Marshal(smtp.Json{
From: "noreply@example.com, user2@example.com",
ReplyTo: "admin@example.com",
To: "user@example.com",
Subject: "Test Subject",
BodyType: "plain",
Body: "Plain text",
})
if err != nil {
return nil, err
}
return http.NewRequest(http.MethodPost, "https://api.example.com/v1/mail/message", bytes.NewReader(j))
},
smtp: &fakeSmtp{},
claims: makeFakeAuthClaims("admin@example.com"),
status: http.StatusBadRequest,
output: "Invalid mail message",
},
{
req: func() (*http.Request, error) {
j, err := json.Marshal(smtp.Json{
From: "noreply@example.com",
ReplyTo: "admin@example.com",
To: "user@example.com",
Subject: "Test Subject",
BodyType: "plain",
Body: "Plain text",
})
if err != nil {
return nil, err
}
return http.NewRequest(http.MethodPost, "https://api.example.com/v1/mail/message", bytes.NewReader(j))
},
smtp: &fakeFailedSmtp{},
claims: makeFakeAuthClaims("admin@example.com"),
status: http.StatusInternalServerError,
output: "Failed to send mail",
},
{
req: func() (*http.Request, error) {
j, err := json.Marshal(smtp.Json{
From: "noreply@example.com",
ReplyTo: "admin@example.com",
To: "user@example.com",
Cc: "user2@example.com",
Bcc: "user3@example.com",
Subject: "Test Subject",
BodyType: "plain",
Body: "Some plain text",
})
if err != nil {
return nil, err
}
return http.NewRequest(http.MethodPost, "https://api.example.com/v1/mail/message", bytes.NewReader(j))
},
smtp: &fakeSmtp{
from: "noreply@example.com",
deliver: []string{"user@example.com", "user2@example.com", "user3@example.com"},
body: []byte("Mime-Version: 1.0\r\n" +
"Content-Type: text/plain; charset=utf-8\r\n" +
"Cc: <user2@example.com>\r\n" +
"To: <user@example.com>\r\n" +
"Reply-To: <admin@example.com>\r\n" +
"From: <noreply@example.com>\r\n" +
"Subject: Test Subject\r\n" +
"Date: Sat, 01 Jan 2000 00:00:00 +0000\r\n" +
"\r\n" +
"Some plain text"),
},
claims: makeFakeAuthClaims("admin@example.com"),
status: http.StatusAccepted,
output: "",
},
}
func makeFakeAuthClaims(subject string) AuthClaims {
return struct {
jwt.RegisteredClaims
ClaimType string
Claims auth.AccessTokenClaims
}{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "Test",
Subject: subject,
Audience: jwt.ClaimStrings{"mail.example.com"},
},
ClaimType: "access-token",
Claims: auth.AccessTokenClaims{Perms: claims.NewPermStorage()},
}
}
func TestMessageSender(t *testing.T) {
for _, i := range messageSenderTestData {
rec := httptest.NewRecorder()
req, err := i.req()
assert.NoError(t, err)
MessageSender(i.smtp)(rec, req, httprouter.Params{}, i.claims)
res := rec.Result()
assert.Equal(t, i.status, res.StatusCode)
assert.NotNil(t, res.Body)
all, err := io.ReadAll(res.Body)
assert.NoError(t, err)
if i.output == "" {
assert.Equal(t, "", string(all))
} else {
assert.Equal(t, "{\"error\":\""+i.output+"\"}\n", string(all))
}
}
}

View File

@ -1,57 +0,0 @@
package comma_list_scanner
import (
"bufio"
"bytes"
"fmt"
"io"
)
type CommaListScanner struct {
r *bufio.Scanner
text string
err error
}
func NewCommaListScanner(r io.Reader) *CommaListScanner {
s := bufio.NewScanner(r)
s.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
println("data", fmt.Sprintf("%s", data))
println("index", bytes.IndexAny(data, " ,"))
if i := bytes.IndexAny(data, " ,"); i >= 0 {
// consume all spaces after the comma
j := i + 1
for j < len(data) && data[j] == ' ' {
j++
}
return j, bytes.TrimSpace(data[0:i]), nil
}
// If we're at EOF, we have a final non-terminated line. Return it.
if atEOF {
return len(data), bytes.TrimSpace(data), nil
}
// Request more data.
return 0, nil, nil
})
return &CommaListScanner{r: s}
}
func (c *CommaListScanner) Scan() bool {
if c.r.Scan() {
c.text = c.r.Text()
return true
}
c.err = c.r.Err()
return false
}
func (c *CommaListScanner) Text() string {
return c.text
}
func (c *CommaListScanner) Err() error {
return c.err
}

View File

@ -1,45 +0,0 @@
package comma_list_scanner
import (
"github.com/stretchr/testify/assert"
"strings"
"testing"
)
var testCommaList = []string{
"hello, wow-this-is-cool, amazing",
"hello, wow-this-is-cool",
"hello, wow-this-is-cool, ",
"hello, wow-this-is-cool,",
",hello, wow-this-is-cool",
",hello, wow-this-is-cool,",
"hello, wow-this-is-cool,,,",
}
func TestNewCommaListScanner(t *testing.T) {
for _, i := range testCommaList {
t.Run(i, func(t *testing.T) {
// use comma list scanner
s := NewCommaListScanner(strings.NewReader(i))
n := strings.Count(i, ",")
a := make([]string, 0, n+1)
for s.Scan() {
a = append(a, s.Text())
}
assert.NoError(t, s.Err())
// test against splitting and trimming strings
b := strings.Split(i, ",")
for i := 0; i < len(b); i++ {
c := strings.TrimSpace(b[i])
if c == "" {
b = append(b[0:i], b[i+1:]...)
i--
} else {
b[i] = c
}
}
assert.Equal(t, b, a)
})
}
}

View File

@ -1,49 +0,0 @@
package config_parser
import (
"bufio"
"errors"
"io"
"strings"
)
var ErrInvalidConfigLine = errors.New("invalid config line")
type ConfigParser struct {
s *bufio.Scanner
pair [2]string
err error
}
func NewConfigParser(r io.Reader) *ConfigParser {
return &ConfigParser{s: bufio.NewScanner(r)}
}
func (c *ConfigParser) Scan() bool {
scanAgain:
if !c.s.Scan() {
return false
}
text := strings.TrimSpace(c.s.Text())
if text == "" || strings.HasPrefix(text, "#") {
goto scanAgain
}
n := strings.IndexByte(text, '=')
if n < 2 || n+2 >= len(text) || text[n-1] != ' ' || text[n+1] != ' ' {
c.err = ErrInvalidConfigLine
return false
}
c.pair = [2]string{text[:n-1], text[n+2:]}
return true
}
func (c *ConfigParser) Pair() (string, string) {
return strings.TrimSpace(c.pair[0]), strings.TrimSpace(c.pair[1])
}
func (c *ConfigParser) Err() error {
if c.err != nil {
return c.err
}
return c.s.Err()
}

View File

@ -1,39 +0,0 @@
package config_parser
import (
"github.com/stretchr/testify/assert"
"strings"
"testing"
)
var configParserData = []struct {
Input string
Values [][2]string
}{
{
"a = a",
[][2]string{{"a", "a"}},
},
{
" a = a ",
[][2]string{{"a", "a"}},
},
{
" # this is a comment\n a = a, b\nb = c, d",
[][2]string{{"a", "a, b"}, {"b", "c, d"}},
},
}
func TestConfigParser(t *testing.T) {
for _, i := range configParserData {
t.Run(i.Input, func(t *testing.T) {
a := NewConfigParser(strings.NewReader(i.Input))
n := 0
for a.Scan() {
assert.False(t, n >= len(i.Values))
assert.Equal(t, i.Values[n], a.pair)
n++
}
})
}
}

View File

@ -1,43 +0,0 @@
package postfix_config
import mapProvider "github.com/1f349/lotus/postfix-config/map-provider"
type Config struct {
// same
VirtualMailboxDomains mapProvider.MapProvider
VirtualAliasMaps mapProvider.MapProvider
VirtualMailboxMaps mapProvider.MapProvider
AliasMaps mapProvider.MapProvider
LocalRecipientMaps mapProvider.MapProvider
SmtpdSenderLoginMaps mapProvider.MapProvider
}
var parseProviderData = map[string]string{
"virtual_mailbox_domains": "comma",
"virtual_alias_maps": "comma",
"virtual_mailbox_maps": "comma",
"alias_maps": "comma",
"local_recipient_maps": "comma",
"smtpd_sender_login_maps": "union",
}
func (c *Config) ParseProvider(k string) string {
return parseProviderData[k]
}
func (c *Config) SetKey(k string, m mapProvider.MapProvider) {
switch k {
case "virtual_mailbox_domains":
c.VirtualMailboxDomains = m
case "virtual_alias_maps":
c.VirtualAliasMaps = m
case "virtual_mailbox_maps":
c.VirtualMailboxMaps = m
case "alias_maps":
c.AliasMaps = m
case "local_recipient_maps":
c.LocalRecipientMaps = m
case "smtpd_sender_login_maps":
c.SmtpdSenderLoginMaps = m
}
}

View File

@ -1,108 +0,0 @@
package postfix_config
import (
"errors"
"fmt"
commaListScanner "github.com/1f349/lotus/postfix-config/comma-list-scanner"
configParser "github.com/1f349/lotus/postfix-config/config-parser"
mapProvider "github.com/1f349/lotus/postfix-config/map-provider"
"io"
"path/filepath"
"strings"
)
type Decoder struct {
r *configParser.ConfigParser
temp map[string]string
basePath string
}
func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: configParser.NewConfigParser(r)}
}
func (d *Decoder) Load() error {
for d.r.Scan() {
k, v := d.r.Pair()
d.temp[k] = v
}
if err := d.r.Err(); err != nil {
return err
}
switch d.value.ParseProvider(k) {
case "comma":
m := mapProvider.SequenceMapProvider{}
s := commaListScanner.NewCommaListScanner(strings.NewReader(v))
for s.Scan() {
a := s.Text()
println("a", a)
if strings.HasPrefix(a, "$") {
m = append(m, &mapProvider.Variable{Name: a[1:]})
continue
}
v2, err := d.createValue(a)
if err != nil {
return err
}
m = append(m, v2)
}
if err := s.Err(); err != nil {
return err
}
d.value.SetKey(k, m)
case "union":
if !strings.HasPrefix(v, "unionmap:{") || !strings.HasSuffix(v, "}") {
return errors.New("key requires a union map")
}
v = v[len("unionmap:{") : len(v)-1]
m := mapProvider.SequenceMapProvider{}
s := commaListScanner.NewCommaListScanner(strings.NewReader(v))
for s.Scan() {
a := s.Text()
v2, err := d.createValue(a)
if err != nil {
return err
}
m = append(m, v2)
}
default:
return fmt.Errorf("key '%s' has no defined parse provider", k)
}
}
return d.r.Err()
}
func (d *Decoder) createValue(a string) (mapProvider.MapProvider, error) {
n := strings.IndexByte(a, ':')
if n == -1 {
return nil, fmt.Errorf("missing prefix")
}
namespace := a[:n]
value := a[n+1:]
switch namespace {
case "mysql":
if !filepath.IsAbs(value) {
value = filepath.Join(d.basePath, value)
}
provider, err := mapProvider.NewMySqlMapProvider(value)
if err != nil {
return nil, err
}
return provider, nil
case "hash":
if !filepath.IsAbs(value) {
value = filepath.Join(d.basePath, value)
}
provider, err := mapProvider.NewHashMapProvider(value)
if err != nil {
return nil, err
}
return provider, nil
}
return nil, errors.New("invalid provider namespace")
}

View File

@ -1,28 +0,0 @@
package postfix_config
import (
"bytes"
_ "embed"
configParser "github.com/1f349/lotus/postfix-config/config-parser"
"github.com/stretchr/testify/assert"
"os"
"path/filepath"
"testing"
)
//go:embed example.cf
var exampleConfig []byte
func TestDecoder_Load(t *testing.T) {
// get working directory
wd, err := os.Getwd()
assert.NoError(t, err)
// read example config
b := bytes.NewReader(exampleConfig)
d := &Decoder{
r: configParser.NewConfigParser(b),
basePath: filepath.Join(wd, "test-data"),
}
assert.NoError(t, d.Load())
}

View File

@ -1,10 +0,0 @@
# this only contains the relevant config properties
#recipient_delimiter = +
virtual_mailbox_domains = mysql:mysql_virtual_domains_maps.cf
virtual_alias_maps = mysql:mysql_virtual_alias_maps.cf, mysql:mysql_virtual_alias_wildcard_maps.cf, mysql:mysql_virtual_alias_domain_maps.cf, mysql:mysql_virtual_alias_user_maps.cf, mysql:mysql_virtual_alias_userdomain_maps.cf, mysql:mysql_virtual_alias_domain_catchall_maps.cf, mysql:mysql_virtual_alias_user_catchall_maps.cf
virtual_mailbox_maps = mysql:mysql_virtual_mailbox_maps.cf, mysql:mysql_virtual_alias_domain_mailbox_maps.cf, mysql:mysql_virtual_alias_user_mailbox_maps.cf, mysql:mysql_virtual_alias_userdomain_mailbox_maps.cf
alias_maps = hash:aliases.txt $virtual_alias_maps
local_recipient_maps = $virtual_mailbox_maps $alias_maps
smtpd_sender_login_maps = unionmap:{ hash:aliases.txt, mysql:mysql_sender_alias_maps.cf, mysql:mysql_virtual_alias_maps.cf, mysql:mysql_virtual_alias_domain_maps.cf, mysql:mysql_virtual_alias_user_maps.cf, mysql:mysql_virtual_alias_userdomain_maps.cf }

View File

@ -1,48 +0,0 @@
package map_provider
import (
"bufio"
"io"
"os"
"strings"
)
type Hash struct {
r io.Reader
v map[string]string
}
var _ MapProvider = &Hash{}
func NewHashMapProvider(filename string) (*Hash, error) {
open, err := os.Open(filename)
if err != nil {
return nil, err
}
return &Hash{open, make(map[string]string)}, nil
}
func (h *Hash) Load() error {
scanner := bufio.NewScanner(h.r)
scanner.Split(bufio.ScanLines)
for scanner.Scan() {
text := strings.TrimSpace(scanner.Text())
if strings.HasPrefix(text, "#") {
continue
}
n := strings.IndexByte(text, ':')
key := strings.TrimSpace(text[:n])
values := strings.Split(text[n+1:], ",")
for _, i := range values {
k := strings.TrimSpace(i)
h.v[k] = key
}
}
return scanner.Err()
}
func (h *Hash) Find(name string) (string, bool) {
v, ok := h.v[name]
return v, ok
}

View File

@ -1,3 +0,0 @@
# See man 5 aliases for format
postmaster: root
test: this, is, an, example

View File

@ -1,23 +0,0 @@
package map_provider
import (
"bytes"
_ "embed"
"github.com/stretchr/testify/assert"
"testing"
)
//go:embed hash_example.txt
var hashExample []byte
func TestHash_Load(t *testing.T) {
h := &Hash{r: bytes.NewReader(hashExample), v: make(map[string]string)}
assert.NoError(t, h.Load())
assert.Equal(t, map[string]string{
"root": "postmaster",
"this": "test",
"is": "test",
"an": "test",
"example": "test",
}, h.v)
}

View File

@ -1,23 +0,0 @@
package map_provider
// MapProvider is an interface to allow looking up mapped values from variables,
// hash files or mysql queries.
type MapProvider interface {
Find(name string) (string, bool)
}
// SequenceMapProvider calls Find against each provider in a slice and outputs
// the true mapped value of the input. first mapped value found. If the input was
// not found then "", false is returned.
type SequenceMapProvider []MapProvider
func (s SequenceMapProvider) Find(name string) (string, bool) {
for _, i := range s {
if find, ok := i.Find(name); ok {
return find, true
}
}
return "", false
}
var _ MapProvider = SequenceMapProvider{}

View File

@ -1,76 +0,0 @@
package map_provider
import (
"errors"
"sort"
"strings"
"unicode"
)
var (
ErrMissingArgument = errors.New("missing argument")
ErrInvalidRawQuery = errors.New("invalid raw query")
)
type PreparedQuery struct {
raw string
params map[int]byte
}
func NewPreparedQuery(raw string) (*PreparedQuery, error) {
var s strings.Builder
origin := 0
params := make(map[int]byte)
for {
n := strings.IndexByte(raw[origin:], '%')
if n == -1 {
break
}
n += origin
if n+1 == len(raw) {
return nil, ErrInvalidRawQuery
}
s.WriteString(raw[origin:n])
if raw[n+1] == '%' {
s.WriteByte('%')
origin = n + 1
continue
}
params[s.Len()] = toLower(raw[n+1])
origin = n + 2
}
s.WriteString(raw[origin:])
return &PreparedQuery{
raw: s.String(),
params: params,
}, nil
}
func (p *PreparedQuery) Format(args map[byte]string) (string, error) {
var s strings.Builder
keys := make([]int, 0, len(p.params))
for k := range p.params {
keys = append(keys, k)
}
sort.Ints(keys)
origin := 0
for _, k := range keys {
r, ok := args[p.params[k]]
if !ok {
return "", ErrMissingArgument
}
// write up to and including the next parameter
s.WriteString(p.raw[origin:k])
s.WriteString(strings.ReplaceAll(r, "'", ""))
origin = k
}
// write the rest of the query
s.WriteString(p.raw[origin:])
return s.String(), nil
}
func toLower(a byte) byte {
return byte(unicode.ToLower(rune(a)))
}

View File

@ -1,40 +0,0 @@
package map_provider
import (
"github.com/stretchr/testify/assert"
"testing"
)
const (
testQuery = "SELECT aliasMap.goto FROM aliasMap,aliasdomainMap WHERE aliasdomainMap.domain='%d' AND aliasMap.address = CONCAT('%u', '@', aliasdomainMap.goto) AND aliasMap.active > 0 AND aliasdomainMap.active > 0"
testQueryRaw = "SELECT aliasMap.goto FROM aliasMap,aliasdomainMap WHERE aliasdomainMap.domain='' AND aliasMap.address = CONCAT('', '@', aliasdomainMap.goto) AND aliasMap.active > 0 AND aliasdomainMap.active > 0"
testQueryFormat = "SELECT aliasMap.goto FROM aliasMap,aliasdomainMap WHERE aliasdomainMap.domain='example.com' AND aliasMap.address = CONCAT('test', '@', aliasdomainMap.goto) AND aliasMap.active > 0 AND aliasdomainMap.active > 0"
)
func TestNewPreparedQuery(t *testing.T) {
query, err := NewPreparedQuery(testQuery)
assert.NoError(t, err)
assert.Equal(t, PreparedQuery{
raw: testQueryRaw,
params: map[int]byte{
79: 'd',
112: 'u',
},
}, *query)
}
func TestPreparedQuery_Format(t *testing.T) {
query := &PreparedQuery{
raw: testQueryRaw,
params: map[int]byte{
79: 'd',
112: 'u',
},
}
format, err := query.Format(map[byte]string{
'd': "example.com",
'u': "test",
})
assert.NoError(t, err)
assert.Equal(t, testQueryFormat, format)
}

View File

@ -1,112 +0,0 @@
package map_provider
import (
"database/sql"
configParser "github.com/1f349/lotus/postfix-config/config-parser"
"github.com/go-sql-driver/mysql"
_ "github.com/go-sql-driver/mysql"
"io"
"os"
"regexp"
"strings"
)
var checkUatD = regexp.MustCompile("^[^@]+@[^@]+$")
type MySql struct {
r io.Reader
db *sql.DB
query *PreparedQuery
}
var _ MapProvider = &MySql{}
func NewMySqlMapProvider(filename string) (*MySql, error) {
open, err := os.Open(filename)
if err != nil {
return nil, err
}
return &MySql{r: open}, nil
}
func (m *MySql) Load() error {
p := configParser.NewConfigParser(m.r)
c := mysql.NewConfig()
var q string
for p.Scan() {
k, v := p.Pair()
switch k {
case "user":
c.User = v
case "password":
c.Passwd = v
case "hosts":
c.Net = "tcp"
c.Addr = v
case "dbname":
c.DBName = v
case "query":
q = v
}
}
if err := p.Err(); err != nil {
return err
}
q2, err := NewPreparedQuery(q)
if err != nil {
return err
}
m.query = q2
// try opening connection
db, err := sql.Open("mysql", c.FormatDSN())
if err != nil {
return err
}
m.db = db
return db.Ping()
}
func (m *MySql) Find(name string) (string, bool) {
format, err := m.query.Format(genQueryArgs(name))
return format, err == nil
}
// genQueryArgs converts an input key into the % encoded parameters
//
// %s - full input key
// %u - user part of user@domain or full input key
// %d - domain part of user@domain or missing parameter
// %[1-9] - replaced with the most significant component of the input key's domain
// for `user@mail.example.com` %1 = com, %2 = example, %3 = mail
// otherwise they are missing parameters
func genQueryArgs(name string) map[byte]string {
args := make(map[byte]string)
args['s'] = name
args['u'] = name
if checkUatD.MatchString(name) {
n := strings.IndexByte(name, '@')
args['u'] = name[:n]
args['d'] = name[n+1:]
genDomainArgs(args, name[n+1:])
}
return args
}
// genDomainArgs replaces with the most significant component of the input key's
// domain for `user@mail.example.com` %1 = com, %2 = example, %3 = mail,
// otherwise they are missing parameters
func genDomainArgs(args map[byte]string, s string) {
i, l := byte(1), len(s)
for {
n := strings.LastIndexByte(s, '.')
if n == -1 {
break
}
args[(i + '0')] = s[n+1 : l]
l = n
}
}

View File

@ -1,5 +0,0 @@
user = example
password = 1234
hosts = 127.0.0.1
dbname = mail
query = SELECT aliasMap.goto FROM aliasMap,aliasdomainMap WHERE aliasdomainMap.domain='%d' AND aliasMap.address = CONCAT('%u', '@', aliasdomainMap.goto) AND aliasMap.active > 0 AND aliasdomainMap.active > 0

View File

@ -1,12 +0,0 @@
package map_provider
type Variable struct {
Name string
Value MapProvider
}
func (v *Variable) Find(name string) (string, bool) {
return v.Value.Find(name)
}
var _ MapProvider = &Variable{}

View File

@ -1 +0,0 @@
a: a b

5
postfix-lookup/lookup.sh Normal file
View File

@ -0,0 +1,5 @@
#!/bin/bash
virtual_alias_maps=$(postconf -h virtual_alias_maps | tr ',' '\n')
alias_to_lookup="$1"
result=$(echo "$virtual_alias_maps" | xargs -I {} postmap -q "$alias_to_lookup" {})
echo "result=$result"

View File

@ -0,0 +1,47 @@
package postfix_lookup
import (
"bufio"
"bytes"
_ "embed"
"errors"
"os/exec"
"strings"
)
var ErrInvalidAlias = errors.New("invalid alias")
//go:embed lookup.sh
var lookupScript string
type PostfixLookup struct {
execCmd func(key string) ([]byte, error)
}
func NewPostfixLookup() *PostfixLookup {
return &PostfixLookup{
execCmd: func(key string) ([]byte, error) {
return exec.Command("bash", "-c", lookupScript, "--", key).Output()
},
}
}
func (d *PostfixLookup) Lookup(key string) (string, error) {
output, err := d.execCmd(key)
if err != nil {
return "", err
}
s := bufio.NewScanner(bytes.NewReader(output))
for s.Scan() {
a := s.Text()
n := strings.IndexByte(a, '=')
if n != -1 && a[:n] == "result" {
return a[n+1:], nil
}
}
if err := s.Err(); err != nil {
return "", err
}
return "", ErrInvalidAlias
}

View File

@ -0,0 +1,45 @@
package postfix_lookup
import (
_ "embed"
"github.com/stretchr/testify/assert"
"strings"
"testing"
)
var postfixLookupData = []struct {
Input string
Output string
}{
{"hi@example.com", "admin@example.com"},
{"test@example.com", "admin@example.com"},
{"user@example.org", "admin@example.org"},
{"user@example.net", ""},
}
func TestDecoder_Load(t *testing.T) {
p := &PostfixLookup{execCmd: func(key string) ([]byte, error) {
n := strings.IndexByte(key, '@')
if n == -1 {
return []byte{}, nil
}
addr := key[n+1:]
switch addr {
case "example.com", "example.org":
return []byte("result=admin@" + addr + "\nadmin@" + addr + "\n"), nil
}
return []byte{}, nil
}}
for _, i := range postfixLookupData {
t.Run(i.Input, func(t *testing.T) {
lookup, err := p.Lookup(i.Input)
if i.Output == "" && err == nil {
t.Fatal("expected error for empty output test case")
}
if i.Output != "" && err != nil {
t.Fatal("expected no error for non-empty output test case")
}
assert.Equal(t, i.Output, lookup)
})
}
}

View File

@ -22,7 +22,7 @@ type Json struct {
Bcc string `json:"bcc"` Bcc string `json:"bcc"`
Subject string `json:"subject"` Subject string `json:"subject"`
BodyType string `json:"body_type"` BodyType string `json:"body_type"`
Body string `json:"body"` Body string `json:"Body"`
} }
func (s Json) parseAddresses() (addrFrom, addrReplyTo, addrTo, addrCc, addrBcc []*mail.Address, err error) { func (s Json) parseAddresses() (addrFrom, addrReplyTo, addrTo, addrCc, addrBcc []*mail.Address, err error) {
@ -31,23 +31,31 @@ func (s Json) parseAddresses() (addrFrom, addrReplyTo, addrTo, addrCc, addrBcc [
if err != nil { if err != nil {
return return
} }
if s.ReplyTo != "" {
addrReplyTo, err = mail.ParseAddressList(s.ReplyTo) addrReplyTo, err = mail.ParseAddressList(s.ReplyTo)
if err != nil { if err != nil {
return return
} }
}
if s.To != "" {
addrTo, err = mail.ParseAddressList(s.To) addrTo, err = mail.ParseAddressList(s.To)
if err != nil { if err != nil {
return return
} }
}
if s.Cc != "" {
addrCc, err = mail.ParseAddressList(s.Cc) addrCc, err = mail.ParseAddressList(s.Cc)
if err != nil { if err != nil {
return return
} }
}
if s.Bcc != "" {
addrBcc, err = mail.ParseAddressList(s.Bcc) addrBcc, err = mail.ParseAddressList(s.Bcc)
}
return return
} }
func (s Json) PrepareMail() (*Mail, error) { func (s Json) PrepareMail(now time.Time) (*Mail, error) {
// parse addresses from json data // parse addresses from json data
addrFrom, addrReplyTo, addrTo, addrCc, addrBcc, err := s.parseAddresses() addrFrom, addrReplyTo, addrTo, addrCc, addrBcc, err := s.parseAddresses()
if err != nil { if err != nil {
@ -64,7 +72,7 @@ func (s Json) PrepareMail() (*Mail, error) {
// set base headers // set base headers
var h mail.Header var h mail.Header
h.SetDate(time.Now()) h.SetDate(now)
h.SetSubject(s.Subject) h.SetSubject(s.Subject)
h.SetAddressList("From", addrFrom) h.SetAddressList("From", addrFrom)
h.SetAddressList("Reply-To", addrReplyTo) h.SetAddressList("Reply-To", addrReplyTo)
@ -87,8 +95,8 @@ func (s Json) PrepareMail() (*Mail, error) {
} }
m := &Mail{ m := &Mail{
from: from, From: from,
deliver: CreateSenderSlice(addrTo, addrCc, addrBcc), Deliver: CreateSenderSlice(addrTo, addrCc, addrBcc),
} }
out := new(bytes.Buffer) out := new(bytes.Buffer)
@ -96,6 +104,6 @@ func (s Json) PrepareMail() (*Mail, error) {
return nil, err return nil, err
} }
m.body = out.Bytes() m.Body = out.Bytes()
return m, nil return m, nil
} }

View File

@ -11,9 +11,9 @@ type Smtp struct {
} }
type Mail struct { type Mail struct {
from string From string
deliver []string Deliver []string
body []byte Body []byte
} }
var defaultDialer = smtp.Dial var defaultDialer = smtp.Dial
@ -26,10 +26,10 @@ func (s *Smtp) Send(mail *Mail) error {
} }
// use a reader to send bytes // use a reader to send bytes
r := bytes.NewReader(mail.body) r := bytes.NewReader(mail.Body)
// send mail // send mail
return smtpClient.SendMail(mail.from, mail.deliver, r) return smtpClient.SendMail(mail.From, mail.Deliver, r)
} }
func CreateSenderSlice(to, cc, bcc []*mail.Address) []string { func CreateSenderSlice(to, cc, bcc []*mail.Address) []string {

View File

@ -51,7 +51,7 @@ func TestSmtp_Send(t *testing.T) {
} }
s := &Smtp{Server: "localhost:25"} s := &Smtp{Server: "localhost:25"}
err := s.Send(&Mail{from: "test@localhost", deliver: []string{"a@localhost", "b@localhost"}, body: sendTestMessage}) err := s.Send(&Mail{From: "test@localhost", Deliver: []string{"a@localhost", "b@localhost"}, Body: sendTestMessage})
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, []byte("MAIL test@localhost\n"), <-serverData) assert.Equal(t, []byte("MAIL test@localhost\n"), <-serverData)
assert.Equal(t, []byte("RCPT a@localhost\n"), <-serverData) assert.Equal(t, []byte("RCPT a@localhost\n"), <-serverData)