diff --git a/router/manager.go b/router/manager.go index af11ed6..29e7549 100644 --- a/router/manager.go +++ b/router/manager.go @@ -146,18 +146,11 @@ func (m *Manager) GetAllRoutes(hosts []string) ([]target.RouteWithActive, error) return []target.RouteWithActive{}, nil } - var searchString strings.Builder - searchString.WriteString("WHERE ") - for i := range hosts { - if i != 0 { - searchString.WriteString(" OR ") - } - searchString.WriteString("source LIKE ?") - } + searchString, hostArgs := generateRouteAndRedirectSearch(hosts) s := make([]target.RouteWithActive, 0) - query, err := m.db.Query(`SELECT source, destination, flags, active FROM routes `+searchString.String(), hosts) + query, err := m.db.Query(`SELECT source, destination, flags, active FROM routes `+searchString, hostArgs) if err != nil { return nil, err } @@ -167,7 +160,14 @@ func (m *Manager) GetAllRoutes(hosts []string) ([]target.RouteWithActive, error) if query.Scan(&a.Src, &a.Dst, &a.Flags, &a.Active) != nil { return nil, err } - s = append(s, a) + + for _, i := range hosts { + // if this is never true then the domain was mistakenly grabbed from the database + if a.OnDomain(i) { + s = append(s, a) + break + } + } } return s, nil @@ -188,18 +188,11 @@ func (m *Manager) GetAllRedirects(hosts []string) ([]target.RedirectWithActive, return []target.RedirectWithActive{}, nil } - var searchString strings.Builder - searchString.WriteString("WHERE ") - for i := range hosts { - if i != 0 { - searchString.WriteString(" OR ") - } - searchString.WriteString("source LIKE ?") - } + searchString, hostArgs := generateRouteAndRedirectSearch(hosts) s := make([]target.RedirectWithActive, 0) - query, err := m.db.Query(`SELECT source, destination, flags, code, active FROM redirects `+searchString.String(), hosts) + query, err := m.db.Query(`SELECT source, destination, flags, code, active FROM redirects `+searchString, hostArgs) if err != nil { return nil, err } @@ -209,7 +202,14 @@ func (m *Manager) GetAllRedirects(hosts []string) ([]target.RedirectWithActive, if query.Scan(&a.Src, &a.Dst, &a.Flags, &a.Code, &a.Active) != nil { return nil, err } - s = append(s, a) + + for _, i := range hosts { + // if this is never true then the domain was mistakenly grabbed from the database + if a.OnDomain(i) { + s = append(s, a) + break + } + } } return s, nil @@ -224,3 +224,24 @@ func (m *Manager) DeleteRedirect(source string) error { _, err := m.db.Exec(`UPDATE redirects SET active = 0 WHERE source = ?`, source) return err } + +func generateRouteAndRedirectSearch(hosts []string) (string, []string) { + var searchString strings.Builder + searchString.WriteString("WHERE ") + + hostArgs := make([]string, len(hosts)*2) + for i := range hosts { + if i != 0 { + searchString.WriteString(" OR ") + } + // these like checks are not perfect but do reduce load on the database + searchString.WriteString("source LIKE '%' + ? + '/%'") + searchString.WriteString(" OR source LIKE '%' + ?") + + // loads the hostname into even and odd args + hostArgs[i*2] = hosts[i] + hostArgs[i*2+1] = hosts[i] + } + + return searchString.String(), hostArgs +} diff --git a/target/redirect.go b/target/redirect.go index 12340d1..e3f7677 100644 --- a/target/redirect.go +++ b/target/redirect.go @@ -23,6 +23,17 @@ type RedirectWithActive struct { Active bool `json:"active"` } +func (r Redirect) OnDomain(domain string) bool { + // if there is no / then the first part is still the domain + domainPart, _, _ := strings.Cut(r.Src, "/") + if domainPart == domain { + return true + } + + // domainPart could start with a subdomain + return strings.HasSuffix(domainPart, "."+domain) +} + func (r Redirect) HasFlag(flag Flags) bool { return r.Flags&flag != 0 } diff --git a/target/route.go b/target/route.go index f37ecc2..a79c9d4 100644 --- a/target/route.go +++ b/target/route.go @@ -49,6 +49,17 @@ type RouteWithActive struct { Active bool `json:"active"` } +func (r Route) OnDomain(domain string) bool { + // if there is no / then the first part is still the domain + domainPart, _, _ := strings.Cut(r.Src, "/") + if domainPart == domain { + return true + } + + // domainPart could start with a subdomain + return strings.HasSuffix(domainPart, "."+domain) +} + func (r Route) HasFlag(flag Flags) bool { return r.Flags&flag != 0 }