From 6238999c08ce004810185f9adbe8d2cbcafc0699 Mon Sep 17 00:00:00 2001 From: MrMelon54 Date: Fri, 12 Jan 2024 01:16:12 +0000 Subject: [PATCH] Refactor entire hosts reading and writing code --- README.md | 2 +- go.mod | 11 +++- go.sum | 10 +++ hosts/Entry.go | 76 ---------------------- hosts/File.go | 122 ----------------------------------- hosts/comment-reader.go | 40 ++++++++++++ hosts/comment-reader_test.go | 49 ++++++++++++++ hosts/entry.go | 90 ++++++++++++++++++++++++++ hosts/entry_test.go | 72 +++++++++++++++++++++ hosts/file.go | 97 ++++++++++++++++++++++++++++ 10 files changed, 369 insertions(+), 200 deletions(-) delete mode 100644 hosts/Entry.go delete mode 100644 hosts/File.go create mode 100644 hosts/comment-reader.go create mode 100644 hosts/comment-reader_test.go create mode 100644 hosts/entry.go create mode 100644 hosts/entry_test.go create mode 100644 hosts/file.go diff --git a/README.md b/README.md index 1f00e57..fc96771 100644 --- a/README.md +++ b/README.md @@ -10,4 +10,4 @@ Maintainer: [Captain ALM](https://code.mrmelon54.com/alfred) License: -BSD 3-Clause \ No newline at end of file +BSD 3-Clause diff --git a/go.mod b/go.mod index 94b2a54..fb0d5fe 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,13 @@ module golang.captainalm.com/HostPersister go 1.18 -require github.com/joho/godotenv v1.5.1 +require ( + github.com/joho/godotenv v1.5.1 + github.com/stretchr/testify v1.8.4 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index d61b19e..26059cd 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,12 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/hosts/Entry.go b/hosts/Entry.go deleted file mode 100644 index e97583c..0000000 --- a/hosts/Entry.go +++ /dev/null @@ -1,76 +0,0 @@ -package hosts - -import "strings" - -func NewHostsEntry(lineIn string) Entry { - trLineIn := strings.ReplaceAll(strings.Trim(lineIn, "\r\n"), " ", " ") - lineSplt := strings.Split(trLineIn, " ") - if strings.HasPrefix(strings.TrimPrefix(trLineIn, " "), "#") { - return Entry{ - IPAddress: "", - Domains: nil, - comment: trLineIn, - } - } else if len(lineSplt) > 1 { - var theDomains []string - for i := 1; i < len(lineSplt); i++ { - if lineSplt[i] == "" { - continue - } - if strings.HasPrefix(lineSplt[i], "#") { - break - } - theDomains = append(theDomains, lineSplt[i]) - } - theComment := "" - theCommentStart := strings.Index(trLineIn, "#") - if theCommentStart > -1 { - theComment = trLineIn[theCommentStart:] - } - return Entry{ - IPAddress: lineSplt[0], - Domains: theDomains, - comment: theComment, - } - } else { - return Entry{ - IPAddress: "", - Domains: nil, - comment: "", - } - } -} - -type Entry struct { - IPAddress string - Domains []string - comment string -} - -func (e Entry) IsFilled() bool { - return e.IPAddress != "" && len(e.Domains) > 0 -} - -func (e Entry) HasDomain(domain string) bool { - if !e.IsFilled() { - return false - } - for _, c := range e.Domains { - if strings.EqualFold(c, domain) { - return true - } - } - return false -} - -func (e Entry) ToLine() string { - if e.IsFilled() { - toReturn := []string{e.IPAddress} - toReturn = append(toReturn, e.Domains...) - if e.comment != "" { - toReturn = append(toReturn, e.comment) - } - return strings.Join(toReturn, " ") - } - return e.comment -} diff --git a/hosts/File.go b/hosts/File.go deleted file mode 100644 index 141d7a3..0000000 --- a/hosts/File.go +++ /dev/null @@ -1,122 +0,0 @@ -package hosts - -import ( - "io" - "os" - "strings" -) - -const readBufferSize = 8192 - -func NewHostsFile(filePath string) (*File, error) { - theHostFile := &File{ - filePath: filePath, - } - err := theHostFile.ReadHostsFile() - if err == nil { - return theHostFile, nil - } - return nil, err -} - -type File struct { - filePath string - Entries []Entry - lineEnding string -} - -func (f *File) ReadHostsFile() error { - f.Entries = nil - theFile, err := os.Open(f.filePath) - if err != nil { - return err - } - defer theFile.Close() - var lenIn int - f.lineEnding = "" - theCBuffer := "" - theBuffer := make([]byte, readBufferSize) - for err == nil { - lenIn, err = theFile.Read(theBuffer) - if lenIn > 0 { - theCBuffer += string(theBuffer[:lenIn]) - if f.lineEnding == "" { - if strings.Contains(theCBuffer, "\r\n") { - f.lineEnding = "\r\n" - } else if strings.Contains(theCBuffer, "\r") { - f.lineEnding = "\r" - } else if strings.Contains(theCBuffer, "\n") { - f.lineEnding = "\n" - } - } - if f.lineEnding == "\r\n" { - strings.ReplaceAll(theCBuffer, "\r\n", "\n") - } else if f.lineEnding == "\r" { - strings.ReplaceAll(theCBuffer, "\r", "\n") - } - splt := strings.Split(theCBuffer, "\n") - for i := 0; i < len(splt)-1; i++ { - f.Entries = append(f.Entries, NewHostsEntry(splt[i])) - } - theCBuffer = splt[len(splt)-1] - } - } - if err != io.EOF { - return err - } - if theCBuffer != "" { - f.Entries = append(f.Entries, NewHostsEntry(theCBuffer)) - } - return nil -} - -func (f File) WriteHostsFile() error { - theFile, err := os.OpenFile(f.filePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) - if err != nil { - return err - } - defer theFile.Close() - for _, entry := range f.Entries { - _, err = theFile.WriteString(entry.ToLine() + f.lineEnding) - if err != nil { - return err - } - } - return nil -} - -func (f File) HasDomain(domain string) bool { - for _, entry := range f.Entries { - if entry.HasDomain(domain) { - return true - } - } - return false -} - -func (f File) indexDomainSingleton(domain string) int { - for i, entry := range f.Entries { - if len(entry.Domains) == 1 && entry.HasDomain(domain) { - return i - } - } - return -1 -} - -func (f File) HasDomainSingleton(domain string) bool { - return f.indexDomainSingleton(domain) > -1 -} - -func (f *File) OverwriteDomainSingleton(domain string, ipAddress string) { - idx := f.indexDomainSingleton(domain) - if idx == -1 { - f.Entries = append(f.Entries, Entry{ - IPAddress: ipAddress, - Domains: []string{domain}, - }) - } else { - theEntry := f.Entries[idx] - theEntry.IPAddress = ipAddress - f.Entries = append(append(f.Entries[:idx], theEntry), f.Entries[idx+1:]...) - } -} diff --git a/hosts/comment-reader.go b/hosts/comment-reader.go new file mode 100644 index 0000000..14673f4 --- /dev/null +++ b/hosts/comment-reader.go @@ -0,0 +1,40 @@ +package hosts + +import ( + "bytes" + "io" +) + +type CommentReader struct { + r io.Reader + over []byte + mark byte + hit bool +} + +var _ io.Reader = &CommentReader{} + +func NewCommentReader(r io.Reader, mark byte) *CommentReader { + return &CommentReader{r, nil, mark, false} +} + +func (c *CommentReader) Read(p []byte) (n int, err error) { + if c.hit { + return 0, io.EOF + } + n, err = c.r.Read(p) + if err != nil { + return + } + n2 := bytes.IndexByte(p[:n], c.mark) + if n2 != -1 { + c.over = p[n2:n] + n = n2 + err = io.EOF + } + return +} + +func (c *CommentReader) Comment() io.Reader { + return io.MultiReader(bytes.NewReader(c.over), c.r) +} diff --git a/hosts/comment-reader_test.go b/hosts/comment-reader_test.go new file mode 100644 index 0000000..61b784b --- /dev/null +++ b/hosts/comment-reader_test.go @@ -0,0 +1,49 @@ +package hosts + +import ( + "github.com/stretchr/testify/assert" + "io" + "strings" + "testing" +) + +func TestNewCommentReader(t *testing.T) { + t.Run("normal", func(t *testing.T) { + r := NewCommentReader(strings.NewReader("Hello world! # this is a comment"), '#') + all, err := io.ReadAll(r) + assert.NoError(t, err) + assert.Equal(t, "Hello world! ", string(all)) + all, err = io.ReadAll(r.Comment()) + assert.NoError(t, err) + assert.Equal(t, "# this is a comment", string(all)) + }) + t.Run("overflow start", func(t *testing.T) { + r := NewCommentReader(strings.NewReader(strings.Repeat("Hello world!", 2048)+" # this is a comment"), '#') + all, err := io.ReadAll(r) + assert.NoError(t, err) + assert.Equal(t, strings.Repeat("Hello world!", 2048)+" ", string(all)) + all, err = io.ReadAll(r.Comment()) + assert.NoError(t, err) + assert.Equal(t, "# this is a comment", string(all)) + }) + t.Run("overflow comment", func(t *testing.T) { + r := NewCommentReader(strings.NewReader(strings.Repeat("Hello world!", 2048)+" # "+strings.Repeat("this is a comment", 2048)), '#') + all, err := io.ReadAll(r) + assert.NoError(t, err) + assert.Equal(t, strings.Repeat("Hello world!", 2048)+" ", string(all)) + all, err = io.ReadAll(r.Comment()) + assert.NoError(t, err) + assert.Equal(t, "# "+strings.Repeat("this is a comment", 2048), string(all)) + }) +} + +func FuzzCommentReader(f *testing.F) { + f.Fuzz(func(t *testing.T, a string) { + r := NewCommentReader(strings.NewReader(a), '#') + b1, err := io.ReadAll(r) + assert.NoError(t, err) + b2, err := io.ReadAll(r.Comment()) + assert.NoError(t, err) + assert.Equal(t, string(b1)+string(b2), a) + }) +} diff --git a/hosts/entry.go b/hosts/entry.go new file mode 100644 index 0000000..1693499 --- /dev/null +++ b/hosts/entry.go @@ -0,0 +1,90 @@ +package hosts + +import ( + "bufio" + "errors" + "io" + "strings" +) + +var ErrInvalidEntry = errors.New("invalid entry") + +type Entry struct { + IPAddress string + Domains []string + Comment string +} + +func (e Entry) IsFilled() bool { + return e.IPAddress != "" && len(e.Domains) > 0 +} + +func (e Entry) HasDomain(domain string) bool { + if !e.IsFilled() { + return false + } + for _, c := range e.Domains { + if strings.EqualFold(c, domain) { + return true + } + } + return false +} + +func (e Entry) ToLine() string { + var b strings.Builder + filled := e.IsFilled() + if filled { + b.WriteString(e.IPAddress) + for _, i := range e.Domains { + b.WriteByte(' ') + b.WriteString(i) + } + } + if e.Comment != "" { + b.Grow(len(e.Comment) + 2) + if filled { + b.WriteByte(' ') + } + b.WriteByte('#') + b.WriteString(e.Comment) + } + return b.String() +} + +func ParseEntryString(line string) (Entry, error) { + return ParseEntry(strings.NewReader(line)) +} + +func ParseEntry(line io.Reader) (entry Entry, err error) { + cr := NewCommentReader(line, '#') + sc := bufio.NewScanner(cr) + sc.Split(bufio.ScanWords) + + isFirst := true + for sc.Scan() { + t := sc.Text() + if isFirst { + entry.IPAddress = t + isFirst = false + } else { + entry.Domains = append(entry.Domains, t) + } + } + err = sc.Err() + if err != nil { + return + } + + // invalid if the ip address is set but no domains are added + if entry.IPAddress != "" && len(entry.Domains) < 1 { + err = ErrInvalidEntry + return + } + var cAll []byte + cAll, err = io.ReadAll(cr.Comment()) + if len(cAll) >= 1 { + entry.Comment = string(cAll[1:]) + } + return +} diff --git a/hosts/entry_test.go b/hosts/entry_test.go new file mode 100644 index 0000000..fb642c8 --- /dev/null +++ b/hosts/entry_test.go @@ -0,0 +1,72 @@ +package hosts + +import ( + "github.com/stretchr/testify/assert" + "strings" + "testing" + "unicode" +) + +func TestParseEntry(t *testing.T) { + entry, err := ParseEntry(strings.NewReader("127.0.0.1 myself.local another.local # this is a test comment")) + assert.NoError(t, err) + assert.Equal(t, Entry{ + IPAddress: "127.0.0.1", + Domains: []string{"myself.local", "another.local"}, + Comment: " this is a test comment", + }, entry) + + entry, err = ParseEntry(strings.NewReader("127.0.0.1 test.local #another comment")) + assert.NoError(t, err) + assert.Equal(t, Entry{ + IPAddress: "127.0.0.1", + Domains: []string{"test.local"}, + Comment: "another comment", + }, entry) + + entry, err = ParseEntry(strings.NewReader("127.0.0.1 ")) + assert.EqualError(t, err, "invalid entry") +} + +func TestEntry_IsFilled(t *testing.T) { + assert.True(t, Entry{IPAddress: "127.0.0.1", Domains: []string{"myself.local"}}.IsFilled()) + assert.False(t, Entry{}.IsFilled()) + assert.False(t, Entry{IPAddress: "127.0.0.1"}.IsFilled()) +} + +func TestEntry_HasDomain(t *testing.T) { + assert.True(t, Entry{IPAddress: "127.0.0.1", Domains: []string{"myself.local", "another.local"}}.HasDomain("myself.local")) + assert.True(t, Entry{IPAddress: "127.0.0.1", Domains: []string{"myself.local", "another.local"}}.HasDomain("MYSELF.local")) + assert.False(t, Entry{IPAddress: "127.0.0.1", Domains: []string{"notme.local", "another.local"}}.HasDomain("myself.local")) +} + +func TestEntry_ToLine(t *testing.T) { + assert.Equal(t, "127.0.0.1 myself.local another.local # this is a test comment", Entry{IPAddress: "127.0.0.1", Domains: []string{"myself.local", "another.local"}, Comment: " this is a test comment"}.ToLine()) + assert.Equal(t, "# this is a test comment", Entry{IPAddress: "", Domains: nil, Comment: " this is a test comment"}.ToLine()) + assert.Equal(t, "127.0.0.1 myself.local another.local", Entry{IPAddress: "127.0.0.1", Domains: []string{"myself.local", "another.local"}}.ToLine()) +} + +func FuzzParseEntry(f *testing.F) { + f.Add("127.0.0.1", "myself.local", "another.local", "this is a test comment") + f.Fuzz(func(t *testing.T, a, b, c, d string) { + for _, i := range []string{a, b, c} { + i = strings.TrimSpace(i) + if i == "" || strings.ContainsFunc(i, func(r rune) bool { + return r == '#' || unicode.IsSpace(r) + }) { + t.Skip() + } + } + + entry := Entry{IPAddress: a, Domains: []string{b, c}, Comment: d} + e2, err := ParseEntryString(entry.ToLine()) + if err != nil { + t.Error(err) + } + assert.Equal(t, Entry{ + IPAddress: strings.TrimSpace(a), + Domains: []string{strings.TrimSpace(b), strings.TrimSpace(c)}, + Comment: d, + }, e2) + }) +} diff --git a/hosts/file.go b/hosts/file.go new file mode 100644 index 0000000..1974f2d --- /dev/null +++ b/hosts/file.go @@ -0,0 +1,97 @@ +package hosts + +import ( + "bufio" + "os" +) + +func NewHostsFile(filePath string) (*File, error) { + theHostFile := &File{filePath: filePath} + err := theHostFile.ReadHostsFile() + if err == nil { + return theHostFile, nil + } + return nil, err +} + +type File struct { + filePath string + Entries []Entry + LF string +} + +func (f *File) ReadHostsFile() error { + f.Entries = nil + theFile, err := os.Open(f.filePath) + if err != nil { + return err + } + defer theFile.Close() + + sc := bufio.NewScanner(theFile) + sc.Split(bufio.ScanLines) + for sc.Scan() { + t := sc.Text() + entry, err := ParseEntryString(t) + if err != nil { + return err + } + f.Entries = append(f.Entries, entry) + } + return sc.Err() +} + +func (f *File) WriteHostsFile() error { + // default LF to \n + if f.LF == "" { + f.LF = "\n" + } + theFile, err := os.OpenFile(f.filePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) + if err != nil { + return err + } + defer theFile.Close() + for _, entry := range f.Entries { + _, err = theFile.WriteString(entry.ToLine() + f.LF) + if err != nil { + return err + } + } + return nil +} + +func (f *File) HasDomain(domain string) bool { + for _, entry := range f.Entries { + if entry.HasDomain(domain) { + return true + } + } + return false +} + +func (f *File) indexDomainSingleton(domain string) int { + for i, entry := range f.Entries { + if len(entry.Domains) == 1 && entry.HasDomain(domain) { + return i + } + } + return -1 +} + +func (f *File) HasDomainSingleton(domain string) bool { + return f.indexDomainSingleton(domain) > -1 +} + +func (f *File) OverwriteDomainSingleton(domain string, ipAddress string) { + idx := f.indexDomainSingleton(domain) + if idx == -1 { + f.Entries = append(f.Entries, Entry{ + IPAddress: ipAddress, + Domains: []string{domain}, + }) + } else { + theEntry := f.Entries[idx] + theEntry.IPAddress = ipAddress + f.Entries = append(append(f.Entries[:idx], theEntry), f.Entries[idx+1:]...) + } +}