mirror of
https://github.com/1f349/lotus.git
synced 2025-01-12 10:16:38 +00:00
113 lines
2.3 KiB
Go
113 lines
2.3 KiB
Go
|
package map_provider
|
||
|
|
||
|
import (
|
||
|
"database/sql"
|
||
|
configParser "github.com/1f349/lotus/postfix-config/config-parser"
|
||
|
"github.com/go-sql-driver/mysql"
|
||
|
_ "github.com/go-sql-driver/mysql"
|
||
|
"io"
|
||
|
"os"
|
||
|
"regexp"
|
||
|
"strings"
|
||
|
)
|
||
|
|
||
|
// checkUatD matches keys of the form user@domain: exactly one '@' with at
// least one character on each side of it.
var checkUatD = regexp.MustCompile(`^[^@]+@[^@]+$`)
|
||
|
|
||
|
type MySql struct {
|
||
|
r io.Reader
|
||
|
db *sql.DB
|
||
|
query *PreparedQuery
|
||
|
}
|
||
|
|
||
|
// Compile-time check that *MySql satisfies MapProvider.
var _ MapProvider = (*MySql)(nil)
|
||
|
|
||
|
func NewMySqlMapProvider(filename string) (*MySql, error) {
|
||
|
open, err := os.Open(filename)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return &MySql{r: open}, nil
|
||
|
}
|
||
|
|
||
|
func (m *MySql) Load() error {
|
||
|
p := configParser.NewConfigParser(m.r)
|
||
|
c := mysql.NewConfig()
|
||
|
var q string
|
||
|
for p.Scan() {
|
||
|
k, v := p.Pair()
|
||
|
switch k {
|
||
|
case "user":
|
||
|
c.User = v
|
||
|
case "password":
|
||
|
c.Passwd = v
|
||
|
case "hosts":
|
||
|
c.Net = "tcp"
|
||
|
c.Addr = v
|
||
|
case "dbname":
|
||
|
c.DBName = v
|
||
|
case "query":
|
||
|
q = v
|
||
|
}
|
||
|
}
|
||
|
if err := p.Err(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
q2, err := NewPreparedQuery(q)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
m.query = q2
|
||
|
|
||
|
// try opening connection
|
||
|
db, err := sql.Open("mysql", c.FormatDSN())
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
m.db = db
|
||
|
|
||
|
return db.Ping()
|
||
|
}
|
||
|
|
||
|
func (m *MySql) Find(name string) (string, bool) {
|
||
|
format, err := m.query.Format(genQueryArgs(name))
|
||
|
return format, err == nil
|
||
|
}
|
||
|
|
||
|
// genQueryArgs converts an input key into the % encoded parameters
|
||
|
//
|
||
|
// %s - full input key
|
||
|
// %u - user part of user@domain or full input key
|
||
|
// %d - domain part of user@domain or missing parameter
|
||
|
// %[1-9] - replaced with the most significant component of the input key's domain
|
||
|
// for `user@mail.example.com` %1 = com, %2 = example, %3 = mail
|
||
|
// otherwise they are missing parameters
|
||
|
func genQueryArgs(name string) map[byte]string {
|
||
|
args := make(map[byte]string)
|
||
|
args['s'] = name
|
||
|
args['u'] = name
|
||
|
if checkUatD.MatchString(name) {
|
||
|
n := strings.IndexByte(name, '@')
|
||
|
args['u'] = name[:n]
|
||
|
args['d'] = name[n+1:]
|
||
|
|
||
|
genDomainArgs(args, name[n+1:])
|
||
|
}
|
||
|
return args
|
||
|
}
|
||
|
|
||
|
// genDomainArgs stores the dot-separated components of the domain s in args
// under the keys '1' to '9', most significant component first.
//
// For `mail.example.com` this sets %1 = com, %2 = example, %3 = mail. A
// domain without any dot is stored whole as %1. Components beyond the ninth
// are dropped because only %1-%9 exist, and positions the domain does not
// have are left unset so they remain missing parameters. An empty s sets
// nothing.
func genDomainArgs(args map[byte]string, s string) {
	if s == "" {
		return
	}
	rest := s
	for key := byte('1'); key <= '9'; key++ {
		n := strings.LastIndexByte(rest, '.')
		// When no dot is left n is -1, so rest[n+1:] is the whole of rest:
		// the final, left-most component.
		args[key] = rest[n+1:]
		if n == -1 {
			return
		}
		// Drop the component just stored together with its separating dot.
		rest = rest[:n]
	}
}
|