Refactor entire hosts reading and writing code
This commit is contained in:
parent
f6d0ea83e4
commit
6238999c08
11
go.mod
11
go.mod
@ -2,4 +2,13 @@ module golang.captainalm.com/HostPersister
|
||||
|
||||
go 1.18
|
||||
|
||||
require github.com/joho/godotenv v1.5.1
|
||||
require (
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/stretchr/testify v1.8.4
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
10
go.sum
10
go.sum
@ -1,2 +1,12 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
@ -1,76 +0,0 @@
|
||||
package hosts
|
||||
|
||||
import "strings"
|
||||
|
||||
func NewHostsEntry(lineIn string) Entry {
|
||||
trLineIn := strings.ReplaceAll(strings.Trim(lineIn, "\r\n"), " ", " ")
|
||||
lineSplt := strings.Split(trLineIn, " ")
|
||||
if strings.HasPrefix(strings.TrimPrefix(trLineIn, " "), "#") {
|
||||
return Entry{
|
||||
IPAddress: "",
|
||||
Domains: nil,
|
||||
comment: trLineIn,
|
||||
}
|
||||
} else if len(lineSplt) > 1 {
|
||||
var theDomains []string
|
||||
for i := 1; i < len(lineSplt); i++ {
|
||||
if lineSplt[i] == "" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(lineSplt[i], "#") {
|
||||
break
|
||||
}
|
||||
theDomains = append(theDomains, lineSplt[i])
|
||||
}
|
||||
theComment := ""
|
||||
theCommentStart := strings.Index(trLineIn, "#")
|
||||
if theCommentStart > -1 {
|
||||
theComment = trLineIn[theCommentStart:]
|
||||
}
|
||||
return Entry{
|
||||
IPAddress: lineSplt[0],
|
||||
Domains: theDomains,
|
||||
comment: theComment,
|
||||
}
|
||||
} else {
|
||||
return Entry{
|
||||
IPAddress: "",
|
||||
Domains: nil,
|
||||
comment: "",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type Entry struct {
|
||||
IPAddress string
|
||||
Domains []string
|
||||
comment string
|
||||
}
|
||||
|
||||
func (e Entry) IsFilled() bool {
|
||||
return e.IPAddress != "" && len(e.Domains) > 0
|
||||
}
|
||||
|
||||
func (e Entry) HasDomain(domain string) bool {
|
||||
if !e.IsFilled() {
|
||||
return false
|
||||
}
|
||||
for _, c := range e.Domains {
|
||||
if strings.EqualFold(c, domain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (e Entry) ToLine() string {
|
||||
if e.IsFilled() {
|
||||
toReturn := []string{e.IPAddress}
|
||||
toReturn = append(toReturn, e.Domains...)
|
||||
if e.comment != "" {
|
||||
toReturn = append(toReturn, e.comment)
|
||||
}
|
||||
return strings.Join(toReturn, " ")
|
||||
}
|
||||
return e.comment
|
||||
}
|
122
hosts/File.go
122
hosts/File.go
@ -1,122 +0,0 @@
|
||||
package hosts
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const readBufferSize = 8192
|
||||
|
||||
func NewHostsFile(filePath string) (*File, error) {
|
||||
theHostFile := &File{
|
||||
filePath: filePath,
|
||||
}
|
||||
err := theHostFile.ReadHostsFile()
|
||||
if err == nil {
|
||||
return theHostFile, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
type File struct {
|
||||
filePath string
|
||||
Entries []Entry
|
||||
lineEnding string
|
||||
}
|
||||
|
||||
func (f *File) ReadHostsFile() error {
|
||||
f.Entries = nil
|
||||
theFile, err := os.Open(f.filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer theFile.Close()
|
||||
var lenIn int
|
||||
f.lineEnding = ""
|
||||
theCBuffer := ""
|
||||
theBuffer := make([]byte, readBufferSize)
|
||||
for err == nil {
|
||||
lenIn, err = theFile.Read(theBuffer)
|
||||
if lenIn > 0 {
|
||||
theCBuffer += string(theBuffer[:lenIn])
|
||||
if f.lineEnding == "" {
|
||||
if strings.Contains(theCBuffer, "\r\n") {
|
||||
f.lineEnding = "\r\n"
|
||||
} else if strings.Contains(theCBuffer, "\r") {
|
||||
f.lineEnding = "\r"
|
||||
} else if strings.Contains(theCBuffer, "\n") {
|
||||
f.lineEnding = "\n"
|
||||
}
|
||||
}
|
||||
if f.lineEnding == "\r\n" {
|
||||
strings.ReplaceAll(theCBuffer, "\r\n", "\n")
|
||||
} else if f.lineEnding == "\r" {
|
||||
strings.ReplaceAll(theCBuffer, "\r", "\n")
|
||||
}
|
||||
splt := strings.Split(theCBuffer, "\n")
|
||||
for i := 0; i < len(splt)-1; i++ {
|
||||
f.Entries = append(f.Entries, NewHostsEntry(splt[i]))
|
||||
}
|
||||
theCBuffer = splt[len(splt)-1]
|
||||
}
|
||||
}
|
||||
if err != io.EOF {
|
||||
return err
|
||||
}
|
||||
if theCBuffer != "" {
|
||||
f.Entries = append(f.Entries, NewHostsEntry(theCBuffer))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f File) WriteHostsFile() error {
|
||||
theFile, err := os.OpenFile(f.filePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer theFile.Close()
|
||||
for _, entry := range f.Entries {
|
||||
_, err = theFile.WriteString(entry.ToLine() + f.lineEnding)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f File) HasDomain(domain string) bool {
|
||||
for _, entry := range f.Entries {
|
||||
if entry.HasDomain(domain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (f File) indexDomainSingleton(domain string) int {
|
||||
for i, entry := range f.Entries {
|
||||
if len(entry.Domains) == 1 && entry.HasDomain(domain) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func (f File) HasDomainSingleton(domain string) bool {
|
||||
return f.indexDomainSingleton(domain) > -1
|
||||
}
|
||||
|
||||
func (f *File) OverwriteDomainSingleton(domain string, ipAddress string) {
|
||||
idx := f.indexDomainSingleton(domain)
|
||||
if idx == -1 {
|
||||
f.Entries = append(f.Entries, Entry{
|
||||
IPAddress: ipAddress,
|
||||
Domains: []string{domain},
|
||||
})
|
||||
} else {
|
||||
theEntry := f.Entries[idx]
|
||||
theEntry.IPAddress = ipAddress
|
||||
f.Entries = append(append(f.Entries[:idx], theEntry), f.Entries[idx+1:]...)
|
||||
}
|
||||
}
|
40
hosts/comment-reader.go
Normal file
40
hosts/comment-reader.go
Normal file
@ -0,0 +1,40 @@
|
||||
package hosts
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
)
|
||||
|
||||
type CommentReader struct {
|
||||
r io.Reader
|
||||
over []byte
|
||||
mark byte
|
||||
hit bool
|
||||
}
|
||||
|
||||
var _ io.Reader = &CommentReader{}
|
||||
|
||||
func NewCommentReader(r io.Reader, mark byte) *CommentReader {
|
||||
return &CommentReader{r, nil, mark, false}
|
||||
}
|
||||
|
||||
func (c *CommentReader) Read(p []byte) (n int, err error) {
|
||||
if c.hit {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n, err = c.r.Read(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n2 := bytes.IndexByte(p[:n], c.mark)
|
||||
if n2 != -1 {
|
||||
c.over = p[n2:n]
|
||||
n = n2
|
||||
err = io.EOF
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *CommentReader) Comment() io.Reader {
|
||||
return io.MultiReader(bytes.NewReader(c.over), c.r)
|
||||
}
|
49
hosts/comment-reader_test.go
Normal file
49
hosts/comment-reader_test.go
Normal file
@ -0,0 +1,49 @@
|
||||
package hosts
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewCommentReader(t *testing.T) {
|
||||
t.Run("normal", func(t *testing.T) {
|
||||
r := NewCommentReader(strings.NewReader("Hello world! # this is a comment"), '#')
|
||||
all, err := io.ReadAll(r)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Hello world! ", string(all))
|
||||
all, err = io.ReadAll(r.Comment())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "# this is a comment", string(all))
|
||||
})
|
||||
t.Run("overflow start", func(t *testing.T) {
|
||||
r := NewCommentReader(strings.NewReader(strings.Repeat("Hello world!", 2048)+" # this is a comment"), '#')
|
||||
all, err := io.ReadAll(r)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, strings.Repeat("Hello world!", 2048)+" ", string(all))
|
||||
all, err = io.ReadAll(r.Comment())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "# this is a comment", string(all))
|
||||
})
|
||||
t.Run("overflow comment", func(t *testing.T) {
|
||||
r := NewCommentReader(strings.NewReader(strings.Repeat("Hello world!", 2048)+" # "+strings.Repeat("this is a comment", 2048)), '#')
|
||||
all, err := io.ReadAll(r)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, strings.Repeat("Hello world!", 2048)+" ", string(all))
|
||||
all, err = io.ReadAll(r.Comment())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "# "+strings.Repeat("this is a comment", 2048), string(all))
|
||||
})
|
||||
}
|
||||
|
||||
func FuzzCommentReader(f *testing.F) {
|
||||
f.Fuzz(func(t *testing.T, a string) {
|
||||
r := NewCommentReader(strings.NewReader(a), '#')
|
||||
b1, err := io.ReadAll(r)
|
||||
assert.NoError(t, err)
|
||||
b2, err := io.ReadAll(r.Comment())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, string(b1)+string(b2), a)
|
||||
})
|
||||
}
|
90
hosts/entry.go
Normal file
90
hosts/entry.go
Normal file
@ -0,0 +1,90 @@
|
||||
package hosts
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var ErrInvalidEntry = errors.New("invalid entry")
|
||||
|
||||
type Entry struct {
|
||||
IPAddress string
|
||||
Domains []string
|
||||
Comment string
|
||||
}
|
||||
|
||||
func (e Entry) IsFilled() bool {
|
||||
return e.IPAddress != "" && len(e.Domains) > 0
|
||||
}
|
||||
|
||||
func (e Entry) HasDomain(domain string) bool {
|
||||
if !e.IsFilled() {
|
||||
return false
|
||||
}
|
||||
for _, c := range e.Domains {
|
||||
if strings.EqualFold(c, domain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (e Entry) ToLine() string {
|
||||
var b strings.Builder
|
||||
filled := e.IsFilled()
|
||||
if filled {
|
||||
b.WriteString(e.IPAddress)
|
||||
for _, i := range e.Domains {
|
||||
b.WriteByte(' ')
|
||||
b.WriteString(i)
|
||||
}
|
||||
}
|
||||
if e.Comment != "" {
|
||||
b.Grow(len(e.Comment) + 2)
|
||||
if filled {
|
||||
b.WriteByte(' ')
|
||||
}
|
||||
b.WriteByte('#')
|
||||
b.WriteString(e.Comment)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func ParseEntryString(line string) (Entry, error) {
|
||||
return ParseEntry(strings.NewReader(line))
|
||||
}
|
||||
|
||||
func ParseEntry(line io.Reader) (entry Entry, err error) {
|
||||
cr := NewCommentReader(line, '#')
|
||||
sc := bufio.NewScanner(cr)
|
||||
sc.Split(bufio.ScanWords)
|
||||
|
||||
isFirst := true
|
||||
for sc.Scan() {
|
||||
t := sc.Text()
|
||||
if isFirst {
|
||||
entry.IPAddress = t
|
||||
isFirst = false
|
||||
} else {
|
||||
entry.Domains = append(entry.Domains, t)
|
||||
}
|
||||
}
|
||||
err = sc.Err()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// invalid if the ip address is set but no domains are added
|
||||
if entry.IPAddress != "" && len(entry.Domains) < 1 {
|
||||
err = ErrInvalidEntry
|
||||
return
|
||||
}
|
||||
var cAll []byte
|
||||
cAll, err = io.ReadAll(cr.Comment())
|
||||
if len(cAll) >= 1 {
|
||||
entry.Comment = string(cAll[1:])
|
||||
}
|
||||
return
|
||||
}
|
72
hosts/entry_test.go
Normal file
72
hosts/entry_test.go
Normal file
@ -0,0 +1,72 @@
|
||||
package hosts
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"strings"
|
||||
"testing"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
func TestParseEntry(t *testing.T) {
|
||||
entry, err := ParseEntry(strings.NewReader("127.0.0.1 myself.local another.local # this is a test comment"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, Entry{
|
||||
IPAddress: "127.0.0.1",
|
||||
Domains: []string{"myself.local", "another.local"},
|
||||
Comment: " this is a test comment",
|
||||
}, entry)
|
||||
|
||||
entry, err = ParseEntry(strings.NewReader("127.0.0.1 test.local #another comment"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, Entry{
|
||||
IPAddress: "127.0.0.1",
|
||||
Domains: []string{"test.local"},
|
||||
Comment: "another comment",
|
||||
}, entry)
|
||||
|
||||
entry, err = ParseEntry(strings.NewReader("127.0.0.1 "))
|
||||
assert.EqualError(t, err, "invalid entry")
|
||||
}
|
||||
|
||||
func TestEntry_IsFilled(t *testing.T) {
|
||||
assert.True(t, Entry{IPAddress: "127.0.0.1", Domains: []string{"myself.local"}}.IsFilled())
|
||||
assert.False(t, Entry{}.IsFilled())
|
||||
assert.False(t, Entry{IPAddress: "127.0.0.1"}.IsFilled())
|
||||
}
|
||||
|
||||
func TestEntry_HasDomain(t *testing.T) {
|
||||
assert.True(t, Entry{IPAddress: "127.0.0.1", Domains: []string{"myself.local", "another.local"}}.HasDomain("myself.local"))
|
||||
assert.True(t, Entry{IPAddress: "127.0.0.1", Domains: []string{"myself.local", "another.local"}}.HasDomain("MYSELF.local"))
|
||||
assert.False(t, Entry{IPAddress: "127.0.0.1", Domains: []string{"notme.local", "another.local"}}.HasDomain("myself.local"))
|
||||
}
|
||||
|
||||
func TestEntry_ToLine(t *testing.T) {
|
||||
assert.Equal(t, "127.0.0.1 myself.local another.local # this is a test comment", Entry{IPAddress: "127.0.0.1", Domains: []string{"myself.local", "another.local"}, Comment: " this is a test comment"}.ToLine())
|
||||
assert.Equal(t, "# this is a test comment", Entry{IPAddress: "", Domains: nil, Comment: " this is a test comment"}.ToLine())
|
||||
assert.Equal(t, "127.0.0.1 myself.local another.local", Entry{IPAddress: "127.0.0.1", Domains: []string{"myself.local", "another.local"}}.ToLine())
|
||||
}
|
||||
|
||||
func FuzzParseEntry(f *testing.F) {
|
||||
f.Add("127.0.0.1", "myself.local", "another.local", "this is a test comment")
|
||||
f.Fuzz(func(t *testing.T, a, b, c, d string) {
|
||||
for _, i := range []string{a, b, c} {
|
||||
i = strings.TrimSpace(i)
|
||||
if i == "" || strings.ContainsFunc(i, func(r rune) bool {
|
||||
return r == '#' || unicode.IsSpace(r)
|
||||
}) {
|
||||
t.Skip()
|
||||
}
|
||||
}
|
||||
|
||||
entry := Entry{IPAddress: a, Domains: []string{b, c}, Comment: d}
|
||||
e2, err := ParseEntryString(entry.ToLine())
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.Equal(t, Entry{
|
||||
IPAddress: strings.TrimSpace(a),
|
||||
Domains: []string{strings.TrimSpace(b), strings.TrimSpace(c)},
|
||||
Comment: d,
|
||||
}, e2)
|
||||
})
|
||||
}
|
97
hosts/file.go
Normal file
97
hosts/file.go
Normal file
@ -0,0 +1,97 @@
|
||||
package hosts
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"os"
|
||||
)
|
||||
|
||||
func NewHostsFile(filePath string) (*File, error) {
|
||||
theHostFile := &File{filePath: filePath}
|
||||
err := theHostFile.ReadHostsFile()
|
||||
if err == nil {
|
||||
return theHostFile, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
type File struct {
|
||||
filePath string
|
||||
Entries []Entry
|
||||
LF string
|
||||
}
|
||||
|
||||
func (f *File) ReadHostsFile() error {
|
||||
f.Entries = nil
|
||||
theFile, err := os.Open(f.filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer theFile.Close()
|
||||
|
||||
sc := bufio.NewScanner(theFile)
|
||||
sc.Split(bufio.ScanLines)
|
||||
for sc.Scan() {
|
||||
t := sc.Text()
|
||||
entry, err := ParseEntryString(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.Entries = append(f.Entries, entry)
|
||||
}
|
||||
return sc.Err()
|
||||
}
|
||||
|
||||
func (f *File) WriteHostsFile() error {
|
||||
// default LF to \n
|
||||
if f.LF == "" {
|
||||
f.LF = "\n"
|
||||
}
|
||||
theFile, err := os.OpenFile(f.filePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer theFile.Close()
|
||||
for _, entry := range f.Entries {
|
||||
_, err = theFile.WriteString(entry.ToLine() + f.LF)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *File) HasDomain(domain string) bool {
|
||||
for _, entry := range f.Entries {
|
||||
if entry.HasDomain(domain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (f *File) indexDomainSingleton(domain string) int {
|
||||
for i, entry := range f.Entries {
|
||||
if len(entry.Domains) == 1 && entry.HasDomain(domain) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func (f *File) HasDomainSingleton(domain string) bool {
|
||||
return f.indexDomainSingleton(domain) > -1
|
||||
}
|
||||
|
||||
func (f *File) OverwriteDomainSingleton(domain string, ipAddress string) {
|
||||
idx := f.indexDomainSingleton(domain)
|
||||
if idx == -1 {
|
||||
f.Entries = append(f.Entries, Entry{
|
||||
IPAddress: ipAddress,
|
||||
Domains: []string{domain},
|
||||
})
|
||||
} else {
|
||||
theEntry := f.Entries[idx]
|
||||
theEntry.IPAddress = ipAddress
|
||||
f.Entries = append(append(f.Entries[:idx], theEntry), f.Entries[idx+1:]...)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user