Refactor entire hosts reading and writing code #1

Open
melon wants to merge 2 commits from melon/HostPersister:master into master
10 changed files with 369 additions and 200 deletions
Showing only changes of commit 6238999c08 - Show all commits

View File

@ -10,4 +10,4 @@ Maintainer:
[Captain ALM](https://code.mrmelon54.com/alfred)
License:
BSD 3-Clause
BSD 3-Clause

11
go.mod
View File

@ -2,4 +2,13 @@ module golang.captainalm.com/HostPersister
go 1.18
require github.com/joho/godotenv v1.5.1
require (
github.com/joho/godotenv v1.5.1
github.com/stretchr/testify v1.8.4
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

10
go.sum
View File

@ -1,2 +1,12 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -1,76 +0,0 @@
package hosts
import "strings"
func NewHostsEntry(lineIn string) Entry {
trLineIn := strings.ReplaceAll(strings.Trim(lineIn, "\r\n"), " ", " ")
lineSplt := strings.Split(trLineIn, " ")
if strings.HasPrefix(strings.TrimPrefix(trLineIn, " "), "#") {
return Entry{
IPAddress: "",
Domains: nil,
comment: trLineIn,
}
} else if len(lineSplt) > 1 {
var theDomains []string
for i := 1; i < len(lineSplt); i++ {
if lineSplt[i] == "" {
continue
}
if strings.HasPrefix(lineSplt[i], "#") {
break
}
theDomains = append(theDomains, lineSplt[i])
}
theComment := ""
theCommentStart := strings.Index(trLineIn, "#")
if theCommentStart > -1 {
theComment = trLineIn[theCommentStart:]
}
return Entry{
IPAddress: lineSplt[0],
Domains: theDomains,
comment: theComment,
}
} else {
return Entry{
IPAddress: "",
Domains: nil,
comment: "",
}
}
}
type Entry struct {
IPAddress string
Domains []string
comment string
}
func (e Entry) IsFilled() bool {
return e.IPAddress != "" && len(e.Domains) > 0
}
func (e Entry) HasDomain(domain string) bool {
if !e.IsFilled() {
return false
}
for _, c := range e.Domains {
if strings.EqualFold(c, domain) {
return true
}
}
return false
}
func (e Entry) ToLine() string {
if e.IsFilled() {
toReturn := []string{e.IPAddress}
toReturn = append(toReturn, e.Domains...)
if e.comment != "" {
toReturn = append(toReturn, e.comment)
}
return strings.Join(toReturn, " ")
}
return e.comment
}

View File

@ -1,122 +0,0 @@
package hosts
import (
"io"
"os"
"strings"
)
const readBufferSize = 8192
func NewHostsFile(filePath string) (*File, error) {
theHostFile := &File{
filePath: filePath,
}
err := theHostFile.ReadHostsFile()
if err == nil {
return theHostFile, nil
}
return nil, err
}
type File struct {
filePath string
Entries []Entry
lineEnding string
}
func (f *File) ReadHostsFile() error {
f.Entries = nil
theFile, err := os.Open(f.filePath)
if err != nil {
return err
}
defer theFile.Close()
var lenIn int
f.lineEnding = ""
theCBuffer := ""
theBuffer := make([]byte, readBufferSize)
for err == nil {
lenIn, err = theFile.Read(theBuffer)
if lenIn > 0 {
theCBuffer += string(theBuffer[:lenIn])
if f.lineEnding == "" {
if strings.Contains(theCBuffer, "\r\n") {
f.lineEnding = "\r\n"
} else if strings.Contains(theCBuffer, "\r") {
f.lineEnding = "\r"
} else if strings.Contains(theCBuffer, "\n") {
f.lineEnding = "\n"
}
}
if f.lineEnding == "\r\n" {
strings.ReplaceAll(theCBuffer, "\r\n", "\n")
} else if f.lineEnding == "\r" {
strings.ReplaceAll(theCBuffer, "\r", "\n")
}
splt := strings.Split(theCBuffer, "\n")
for i := 0; i < len(splt)-1; i++ {
f.Entries = append(f.Entries, NewHostsEntry(splt[i]))
}
theCBuffer = splt[len(splt)-1]
}
}
if err != io.EOF {
return err
}
if theCBuffer != "" {
f.Entries = append(f.Entries, NewHostsEntry(theCBuffer))
}
return nil
}
func (f File) WriteHostsFile() error {
theFile, err := os.OpenFile(f.filePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
if err != nil {
return err
}
defer theFile.Close()
for _, entry := range f.Entries {
_, err = theFile.WriteString(entry.ToLine() + f.lineEnding)
if err != nil {
return err
}
}
return nil
}
func (f File) HasDomain(domain string) bool {
for _, entry := range f.Entries {
if entry.HasDomain(domain) {
return true
}
}
return false
}
func (f File) indexDomainSingleton(domain string) int {
for i, entry := range f.Entries {
if len(entry.Domains) == 1 && entry.HasDomain(domain) {
return i
}
}
return -1
}
func (f File) HasDomainSingleton(domain string) bool {
return f.indexDomainSingleton(domain) > -1
}
func (f *File) OverwriteDomainSingleton(domain string, ipAddress string) {
idx := f.indexDomainSingleton(domain)
if idx == -1 {
f.Entries = append(f.Entries, Entry{
IPAddress: ipAddress,
Domains: []string{domain},
})
} else {
theEntry := f.Entries[idx]
theEntry.IPAddress = ipAddress
f.Entries = append(append(f.Entries[:idx], theEntry), f.Entries[idx+1:]...)
}
}

40
hosts/comment-reader.go Normal file
View File

@ -0,0 +1,40 @@
package hosts
import (
"bytes"
"io"
)
type CommentReader struct {
r io.Reader
over []byte
mark byte
hit bool
alfred marked this conversation as resolved
Review

Whats this used for (hit)?
It's only checked in one place.

Whats this used for (hit)? It's only checked in one place.
}
var _ io.Reader = &CommentReader{}
func NewCommentReader(r io.Reader, mark byte) *CommentReader {
return &CommentReader{r, nil, mark, false}
}
func (c *CommentReader) Read(p []byte) (n int, err error) {
if c.hit {
return 0, io.EOF
}
n, err = c.r.Read(p)
if err != nil {
return
}
n2 := bytes.IndexByte(p[:n], c.mark)
if n2 != -1 {
c.over = p[n2:n]
n = n2
err = io.EOF
}
return
}
func (c *CommentReader) Comment() io.Reader {
return io.MultiReader(bytes.NewReader(c.over), c.r)
}

View File

@ -0,0 +1,49 @@
package hosts
import (
"github.com/stretchr/testify/assert"
"io"
"strings"
"testing"
)
func TestNewCommentReader(t *testing.T) {
t.Run("normal", func(t *testing.T) {
r := NewCommentReader(strings.NewReader("Hello world! # this is a comment"), '#')
all, err := io.ReadAll(r)
assert.NoError(t, err)
assert.Equal(t, "Hello world! ", string(all))
all, err = io.ReadAll(r.Comment())
assert.NoError(t, err)
assert.Equal(t, "# this is a comment", string(all))
})
t.Run("overflow start", func(t *testing.T) {
r := NewCommentReader(strings.NewReader(strings.Repeat("Hello world!", 2048)+" # this is a comment"), '#')
all, err := io.ReadAll(r)
assert.NoError(t, err)
assert.Equal(t, strings.Repeat("Hello world!", 2048)+" ", string(all))
all, err = io.ReadAll(r.Comment())
assert.NoError(t, err)
assert.Equal(t, "# this is a comment", string(all))
})
t.Run("overflow comment", func(t *testing.T) {
r := NewCommentReader(strings.NewReader(strings.Repeat("Hello world!", 2048)+" # "+strings.Repeat("this is a comment", 2048)), '#')
all, err := io.ReadAll(r)
assert.NoError(t, err)
assert.Equal(t, strings.Repeat("Hello world!", 2048)+" ", string(all))
all, err = io.ReadAll(r.Comment())
assert.NoError(t, err)
assert.Equal(t, "# "+strings.Repeat("this is a comment", 2048), string(all))
})
}
func FuzzCommentReader(f *testing.F) {
f.Fuzz(func(t *testing.T, a string) {
r := NewCommentReader(strings.NewReader(a), '#')
b1, err := io.ReadAll(r)
assert.NoError(t, err)
b2, err := io.ReadAll(r.Comment())
assert.NoError(t, err)
assert.Equal(t, string(b1)+string(b2), a)
})
}

90
hosts/entry.go Normal file
View File

@ -0,0 +1,90 @@
package hosts
import (
"bufio"
"errors"
"io"
"strings"
)
var ErrInvalidEntry = errors.New("invalid entry")
type Entry struct {
IPAddress string
Domains []string
Comment string
}
func (e Entry) IsFilled() bool {
return e.IPAddress != "" && len(e.Domains) > 0
}
func (e Entry) HasDomain(domain string) bool {
if !e.IsFilled() {
return false
}
for _, c := range e.Domains {
if strings.EqualFold(c, domain) {
return true
}
}
return false
}
func (e Entry) ToLine() string {
var b strings.Builder
filled := e.IsFilled()
if filled {
b.WriteString(e.IPAddress)
for _, i := range e.Domains {
b.WriteByte(' ')
b.WriteString(i)
}
}
if e.Comment != "" {
b.Grow(len(e.Comment) + 2)
if filled {
b.WriteByte(' ')
}
b.WriteByte('#')
b.WriteString(e.Comment)
}
return b.String()
}
func ParseEntryString(line string) (Entry, error) {
return ParseEntry(strings.NewReader(line))
}
func ParseEntry(line io.Reader) (entry Entry, err error) {
cr := NewCommentReader(line, '#')
sc := bufio.NewScanner(cr)
Review

Support a larger buffer size if needed (Extra command line arg and/or env variable).

https://stackoverflow.com/questions/39859222/golang-how-to-overcome-scan-buffer-limit-from-bufio

Support a larger buffer size if needed (Extra command line arg and/or env variable). https://stackoverflow.com/questions/39859222/golang-how-to-overcome-scan-buffer-limit-from-bufio
sc.Split(bufio.ScanWords)
isFirst := true
for sc.Scan() {
t := sc.Text()
if isFirst {
entry.IPAddress = t
isFirst = false
} else {
entry.Domains = append(entry.Domains, t)
}
}
err = sc.Err()
if err != nil {
return
}
// invalid if the ip address is set but no domains are added
if entry.IPAddress != "" && len(entry.Domains) < 1 {
err = ErrInvalidEntry
return
}
var cAll []byte
cAll, err = io.ReadAll(cr.Comment())
if len(cAll) >= 1 {
entry.Comment = string(cAll[1:])
}
return
}

72
hosts/entry_test.go Normal file
View File

@ -0,0 +1,72 @@
package hosts
import (
"github.com/stretchr/testify/assert"
"strings"
"testing"
"unicode"
)
func TestParseEntry(t *testing.T) {
entry, err := ParseEntry(strings.NewReader("127.0.0.1 myself.local another.local # this is a test comment"))
assert.NoError(t, err)
assert.Equal(t, Entry{
IPAddress: "127.0.0.1",
Domains: []string{"myself.local", "another.local"},
Comment: " this is a test comment",
}, entry)
entry, err = ParseEntry(strings.NewReader("127.0.0.1 test.local #another comment"))
assert.NoError(t, err)
assert.Equal(t, Entry{
IPAddress: "127.0.0.1",
Domains: []string{"test.local"},
Comment: "another comment",
}, entry)
entry, err = ParseEntry(strings.NewReader("127.0.0.1 "))
assert.EqualError(t, err, "invalid entry")
}
func TestEntry_IsFilled(t *testing.T) {
assert.True(t, Entry{IPAddress: "127.0.0.1", Domains: []string{"myself.local"}}.IsFilled())
assert.False(t, Entry{}.IsFilled())
assert.False(t, Entry{IPAddress: "127.0.0.1"}.IsFilled())
}
func TestEntry_HasDomain(t *testing.T) {
assert.True(t, Entry{IPAddress: "127.0.0.1", Domains: []string{"myself.local", "another.local"}}.HasDomain("myself.local"))
assert.True(t, Entry{IPAddress: "127.0.0.1", Domains: []string{"myself.local", "another.local"}}.HasDomain("MYSELF.local"))
assert.False(t, Entry{IPAddress: "127.0.0.1", Domains: []string{"notme.local", "another.local"}}.HasDomain("myself.local"))
}
func TestEntry_ToLine(t *testing.T) {
assert.Equal(t, "127.0.0.1 myself.local another.local # this is a test comment", Entry{IPAddress: "127.0.0.1", Domains: []string{"myself.local", "another.local"}, Comment: " this is a test comment"}.ToLine())
assert.Equal(t, "# this is a test comment", Entry{IPAddress: "", Domains: nil, Comment: " this is a test comment"}.ToLine())
assert.Equal(t, "127.0.0.1 myself.local another.local", Entry{IPAddress: "127.0.0.1", Domains: []string{"myself.local", "another.local"}}.ToLine())
}
func FuzzParseEntry(f *testing.F) {
f.Add("127.0.0.1", "myself.local", "another.local", "this is a test comment")
f.Fuzz(func(t *testing.T, a, b, c, d string) {
for _, i := range []string{a, b, c} {
i = strings.TrimSpace(i)
if i == "" || strings.ContainsFunc(i, func(r rune) bool {
return r == '#' || unicode.IsSpace(r)
}) {
t.Skip()
}
}
entry := Entry{IPAddress: a, Domains: []string{b, c}, Comment: d}
e2, err := ParseEntryString(entry.ToLine())
if err != nil {
t.Error(err)
}
assert.Equal(t, Entry{
IPAddress: strings.TrimSpace(a),
Domains: []string{strings.TrimSpace(b), strings.TrimSpace(c)},
Comment: d,
}, e2)
})
}

97
hosts/file.go Normal file
View File

@ -0,0 +1,97 @@
package hosts
import (
"bufio"
"os"
)
func NewHostsFile(filePath string) (*File, error) {
theHostFile := &File{filePath: filePath}
err := theHostFile.ReadHostsFile()
if err == nil {
return theHostFile, nil
}
return nil, err
}
type File struct {
filePath string
Entries []Entry
LF string
}
func (f *File) ReadHostsFile() error {
f.Entries = nil
theFile, err := os.Open(f.filePath)
if err != nil {
return err
}
defer theFile.Close()
sc := bufio.NewScanner(theFile)
Review

Support a larger buffer size if needed (Extra command line arg and/or env variable).

https://stackoverflow.com/questions/39859222/golang-how-to-overcome-scan-buffer-limit-from-bufio

Support a larger buffer size if needed (Extra command line arg and/or env variable). https://stackoverflow.com/questions/39859222/golang-how-to-overcome-scan-buffer-limit-from-bufio
sc.Split(bufio.ScanLines)
for sc.Scan() {
t := sc.Text()
entry, err := ParseEntryString(t)
if err != nil {
return err
}
f.Entries = append(f.Entries, entry)
}
return sc.Err()
}
func (f *File) WriteHostsFile() error {
// default LF to \n
if f.LF == "" {
f.LF = "\n"
}
theFile, err := os.OpenFile(f.filePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
if err != nil {
return err
}
defer theFile.Close()
for _, entry := range f.Entries {
_, err = theFile.WriteString(entry.ToLine() + f.LF)
if err != nil {
return err
}
}
return nil
}
func (f *File) HasDomain(domain string) bool {
for _, entry := range f.Entries {
if entry.HasDomain(domain) {
return true
}
}
return false
}
func (f *File) indexDomainSingleton(domain string) int {
for i, entry := range f.Entries {
if len(entry.Domains) == 1 && entry.HasDomain(domain) {
return i
}
}
return -1
}
func (f *File) HasDomainSingleton(domain string) bool {
return f.indexDomainSingleton(domain) > -1
}
func (f *File) OverwriteDomainSingleton(domain string, ipAddress string) {
idx := f.indexDomainSingleton(domain)
if idx == -1 {
f.Entries = append(f.Entries, Entry{
IPAddress: ipAddress,
Domains: []string{domain},
})
} else {
theEntry := f.Entries[idx]
theEntry.IPAddress = ipAddress
f.Entries = append(append(f.Entries[:idx], theEntry), f.Entries[idx+1:]...)
}
}