WIP: Refactor entire hosts reading and writing code #1

Draft
melon wants to merge 2 commits from melon/HostPersister:master into master
10 changed files with 368 additions and 200 deletions

View File

@ -10,4 +10,4 @@ Maintainer:
[Captain ALM](https://code.mrmelon54.com/alfred) [Captain ALM](https://code.mrmelon54.com/alfred)
License: License:
BSD 3-Clause BSD 3-Clause

11
go.mod
View File

@ -2,4 +2,13 @@ module golang.captainalm.com/HostPersister
go 1.18 go 1.18
require github.com/joho/godotenv v1.5.1 require (
github.com/joho/godotenv v1.5.1
github.com/stretchr/testify v1.8.4
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

10
go.sum
View File

@ -1,2 +1,12 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -1,76 +0,0 @@
package hosts
import "strings"
func NewHostsEntry(lineIn string) Entry {
trLineIn := strings.ReplaceAll(strings.Trim(lineIn, "\r\n"), " ", " ")
lineSplt := strings.Split(trLineIn, " ")
if strings.HasPrefix(strings.TrimPrefix(trLineIn, " "), "#") {
return Entry{
IPAddress: "",
Domains: nil,
comment: trLineIn,
}
} else if len(lineSplt) > 1 {
var theDomains []string
for i := 1; i < len(lineSplt); i++ {
if lineSplt[i] == "" {
continue
}
if strings.HasPrefix(lineSplt[i], "#") {
break
}
theDomains = append(theDomains, lineSplt[i])
}
theComment := ""
theCommentStart := strings.Index(trLineIn, "#")
if theCommentStart > -1 {
theComment = trLineIn[theCommentStart:]
}
return Entry{
IPAddress: lineSplt[0],
Domains: theDomains,
comment: theComment,
}
} else {
return Entry{
IPAddress: "",
Domains: nil,
comment: "",
}
}
}
type Entry struct {
IPAddress string
Domains []string
comment string
}
func (e Entry) IsFilled() bool {
return e.IPAddress != "" && len(e.Domains) > 0
}
func (e Entry) HasDomain(domain string) bool {
if !e.IsFilled() {
return false
}
for _, c := range e.Domains {
if strings.EqualFold(c, domain) {
return true
}
}
return false
}
func (e Entry) ToLine() string {
if e.IsFilled() {
toReturn := []string{e.IPAddress}
toReturn = append(toReturn, e.Domains...)
if e.comment != "" {
toReturn = append(toReturn, e.comment)
}
return strings.Join(toReturn, " ")
}
return e.comment
}

View File

@ -1,122 +0,0 @@
package hosts
import (
"io"
"os"
"strings"
)
const readBufferSize = 8192
func NewHostsFile(filePath string) (*File, error) {
theHostFile := &File{
filePath: filePath,
}
err := theHostFile.ReadHostsFile()
if err == nil {
return theHostFile, nil
}
return nil, err
}
type File struct {
filePath string
Entries []Entry
lineEnding string
}
func (f *File) ReadHostsFile() error {
f.Entries = nil
theFile, err := os.Open(f.filePath)
if err != nil {
return err
}
defer theFile.Close()
var lenIn int
f.lineEnding = ""
theCBuffer := ""
theBuffer := make([]byte, readBufferSize)
for err == nil {
lenIn, err = theFile.Read(theBuffer)
if lenIn > 0 {
theCBuffer += string(theBuffer[:lenIn])
if f.lineEnding == "" {
if strings.Contains(theCBuffer, "\r\n") {
f.lineEnding = "\r\n"
} else if strings.Contains(theCBuffer, "\r") {
f.lineEnding = "\r"
} else if strings.Contains(theCBuffer, "\n") {
f.lineEnding = "\n"
}
}
if f.lineEnding == "\r\n" {
strings.ReplaceAll(theCBuffer, "\r\n", "\n")
} else if f.lineEnding == "\r" {
strings.ReplaceAll(theCBuffer, "\r", "\n")
}
splt := strings.Split(theCBuffer, "\n")
for i := 0; i < len(splt)-1; i++ {
f.Entries = append(f.Entries, NewHostsEntry(splt[i]))
}
theCBuffer = splt[len(splt)-1]
}
}
if err != io.EOF {
return err
}
if theCBuffer != "" {
f.Entries = append(f.Entries, NewHostsEntry(theCBuffer))
}
return nil
}
func (f File) WriteHostsFile() error {
theFile, err := os.OpenFile(f.filePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
if err != nil {
return err
}
defer theFile.Close()
for _, entry := range f.Entries {
_, err = theFile.WriteString(entry.ToLine() + f.lineEnding)
if err != nil {
return err
}
}
return nil
}
func (f File) HasDomain(domain string) bool {
for _, entry := range f.Entries {
if entry.HasDomain(domain) {
return true
}
}
return false
}
func (f File) indexDomainSingleton(domain string) int {
for i, entry := range f.Entries {
if len(entry.Domains) == 1 && entry.HasDomain(domain) {
return i
}
}
return -1
}
func (f File) HasDomainSingleton(domain string) bool {
return f.indexDomainSingleton(domain) > -1
}
func (f *File) OverwriteDomainSingleton(domain string, ipAddress string) {
idx := f.indexDomainSingleton(domain)
if idx == -1 {
f.Entries = append(f.Entries, Entry{
IPAddress: ipAddress,
Domains: []string{domain},
})
} else {
theEntry := f.Entries[idx]
theEntry.IPAddress = ipAddress
f.Entries = append(append(f.Entries[:idx], theEntry), f.Entries[idx+1:]...)
}
}

39
hosts/comment-reader.go Normal file
View File

@ -0,0 +1,39 @@
package hosts
import (
"bytes"
"io"
)
type CommentReader struct {
r io.Reader
over []byte
mark byte
}
alfred marked this conversation as resolved
Review

Whats this used for (hit)?
It's only checked in one place.

Whats this used for (hit)? It's only checked in one place.
var _ io.Reader = &CommentReader{}
func NewCommentReader(r io.Reader, mark byte) *CommentReader {
return &CommentReader{r, nil, mark}
}
func (c *CommentReader) Read(p []byte) (n int, err error) {
if c.over != nil {
return 0, io.EOF
}
n, err = c.r.Read(p)
if err != nil {
return
}
n2 := bytes.IndexByte(p[:n], c.mark)
if n2 != -1 {
c.over = p[n2:n]
n = n2
err = io.EOF
}
return
}
func (c *CommentReader) Comment() io.Reader {
return io.MultiReader(bytes.NewReader(c.over), c.r)
}

View File

@ -0,0 +1,49 @@
package hosts
import (
"github.com/stretchr/testify/assert"
"io"
"strings"
"testing"
)
func TestNewCommentReader(t *testing.T) {
t.Run("normal", func(t *testing.T) {
r := NewCommentReader(strings.NewReader("Hello world! # this is a comment"), '#')
all, err := io.ReadAll(r)
assert.NoError(t, err)
assert.Equal(t, "Hello world! ", string(all))
all, err = io.ReadAll(r.Comment())
assert.NoError(t, err)
assert.Equal(t, "# this is a comment", string(all))
})
t.Run("overflow start", func(t *testing.T) {
r := NewCommentReader(strings.NewReader(strings.Repeat("Hello world!", 2048)+" # this is a comment"), '#')
all, err := io.ReadAll(r)
assert.NoError(t, err)
assert.Equal(t, strings.Repeat("Hello world!", 2048)+" ", string(all))
all, err = io.ReadAll(r.Comment())
assert.NoError(t, err)
assert.Equal(t, "# this is a comment", string(all))
})
t.Run("overflow comment", func(t *testing.T) {
r := NewCommentReader(strings.NewReader(strings.Repeat("Hello world!", 2048)+" # "+strings.Repeat("this is a comment", 2048)), '#')
all, err := io.ReadAll(r)
assert.NoError(t, err)
assert.Equal(t, strings.Repeat("Hello world!", 2048)+" ", string(all))
all, err = io.ReadAll(r.Comment())
assert.NoError(t, err)
assert.Equal(t, "# "+strings.Repeat("this is a comment", 2048), string(all))
})
}
func FuzzCommentReader(f *testing.F) {
f.Fuzz(func(t *testing.T, a string) {
r := NewCommentReader(strings.NewReader(a), '#')
b1, err := io.ReadAll(r)
assert.NoError(t, err)
b2, err := io.ReadAll(r.Comment())
assert.NoError(t, err)
assert.Equal(t, string(b1)+string(b2), a)
})
}

90
hosts/entry.go Normal file
View File

@ -0,0 +1,90 @@
package hosts
import (
"bufio"
"errors"
"io"
"strings"
)
var ErrInvalidEntry = errors.New("invalid entry")
type Entry struct {
IPAddress string
Domains []string
Comment string
}
func (e Entry) IsFilled() bool {
return e.IPAddress != "" && len(e.Domains) > 0
}
func (e Entry) HasDomain(domain string) bool {
if !e.IsFilled() {
return false
}
for _, c := range e.Domains {
if strings.EqualFold(c, domain) {
return true
}
}
return false
}
func (e Entry) ToLine() string {
var b strings.Builder
filled := e.IsFilled()
if filled {
b.WriteString(e.IPAddress)
for _, i := range e.Domains {
b.WriteByte(' ')
b.WriteString(i)
}
}
if e.Comment != "" {
b.Grow(len(e.Comment) + 2)
if filled {
b.WriteByte(' ')
}
b.WriteByte('#')
b.WriteString(e.Comment)
}
return b.String()
}
func ParseEntryString(line string) (Entry, error) {
return ParseEntry(strings.NewReader(line))
}
func ParseEntry(line io.Reader) (entry Entry, err error) {
cr := NewCommentReader(line, '#')
sc := bufio.NewScanner(cr)
Review

Support a larger buffer size if needed (Extra command line arg and/or env variable).

https://stackoverflow.com/questions/39859222/golang-how-to-overcome-scan-buffer-limit-from-bufio

Support a larger buffer size if needed (Extra command line arg and/or env variable). https://stackoverflow.com/questions/39859222/golang-how-to-overcome-scan-buffer-limit-from-bufio
sc.Split(bufio.ScanWords)
isFirst := true
for sc.Scan() {
t := sc.Text()
if isFirst {
entry.IPAddress = t
isFirst = false
} else {
entry.Domains = append(entry.Domains, t)
}
}
err = sc.Err()
if err != nil {
return
}
// invalid if the ip address is set but no domains are added
if entry.IPAddress != "" && len(entry.Domains) < 1 {
err = ErrInvalidEntry
return
}
var cAll []byte
cAll, err = io.ReadAll(cr.Comment())
if len(cAll) >= 1 {
entry.Comment = string(cAll[1:])
}
return
}

72
hosts/entry_test.go Normal file
View File

@ -0,0 +1,72 @@
package hosts
import (
"github.com/stretchr/testify/assert"
"strings"
"testing"
"unicode"
)
func TestParseEntry(t *testing.T) {
entry, err := ParseEntry(strings.NewReader("127.0.0.1 myself.local another.local # this is a test comment"))
assert.NoError(t, err)
assert.Equal(t, Entry{
IPAddress: "127.0.0.1",
Domains: []string{"myself.local", "another.local"},
Comment: " this is a test comment",
}, entry)
entry, err = ParseEntry(strings.NewReader("127.0.0.1 test.local #another comment"))
assert.NoError(t, err)
assert.Equal(t, Entry{
IPAddress: "127.0.0.1",
Domains: []string{"test.local"},
Comment: "another comment",
}, entry)
entry, err = ParseEntry(strings.NewReader("127.0.0.1 "))
assert.EqualError(t, err, "invalid entry")
}
func TestEntry_IsFilled(t *testing.T) {
assert.True(t, Entry{IPAddress: "127.0.0.1", Domains: []string{"myself.local"}}.IsFilled())
assert.False(t, Entry{}.IsFilled())
assert.False(t, Entry{IPAddress: "127.0.0.1"}.IsFilled())
}
func TestEntry_HasDomain(t *testing.T) {
assert.True(t, Entry{IPAddress: "127.0.0.1", Domains: []string{"myself.local", "another.local"}}.HasDomain("myself.local"))
assert.True(t, Entry{IPAddress: "127.0.0.1", Domains: []string{"myself.local", "another.local"}}.HasDomain("MYSELF.local"))
assert.False(t, Entry{IPAddress: "127.0.0.1", Domains: []string{"notme.local", "another.local"}}.HasDomain("myself.local"))
}
func TestEntry_ToLine(t *testing.T) {
assert.Equal(t, "127.0.0.1 myself.local another.local # this is a test comment", Entry{IPAddress: "127.0.0.1", Domains: []string{"myself.local", "another.local"}, Comment: " this is a test comment"}.ToLine())
assert.Equal(t, "# this is a test comment", Entry{IPAddress: "", Domains: nil, Comment: " this is a test comment"}.ToLine())
assert.Equal(t, "127.0.0.1 myself.local another.local", Entry{IPAddress: "127.0.0.1", Domains: []string{"myself.local", "another.local"}}.ToLine())
}
func FuzzParseEntry(f *testing.F) {
f.Add("127.0.0.1", "myself.local", "another.local", "this is a test comment")
f.Fuzz(func(t *testing.T, a, b, c, d string) {
for _, i := range []string{a, b, c} {
i = strings.TrimSpace(i)
if i == "" || strings.ContainsFunc(i, func(r rune) bool {
return r == '#' || unicode.IsSpace(r)
}) {
t.Skip()
}
}
entry := Entry{IPAddress: a, Domains: []string{b, c}, Comment: d}
e2, err := ParseEntryString(entry.ToLine())
if err != nil {
t.Error(err)
}
assert.Equal(t, Entry{
IPAddress: strings.TrimSpace(a),
Domains: []string{strings.TrimSpace(b), strings.TrimSpace(c)},
Comment: d,
}, e2)
})
}

97
hosts/file.go Normal file
View File

@ -0,0 +1,97 @@
package hosts
import (
"bufio"
"os"
)
func NewHostsFile(filePath string) (*File, error) {
theHostFile := &File{filePath: filePath}
err := theHostFile.ReadHostsFile()
if err == nil {
return theHostFile, nil
}
return nil, err
}
type File struct {
filePath string
Entries []Entry
LF string
}
func (f *File) ReadHostsFile() error {
f.Entries = nil
theFile, err := os.Open(f.filePath)
if err != nil {
return err
}
defer theFile.Close()
sc := bufio.NewScanner(theFile)
Review

Support a larger buffer size if needed (Extra command line arg and/or env variable).

https://stackoverflow.com/questions/39859222/golang-how-to-overcome-scan-buffer-limit-from-bufio

Support a larger buffer size if needed (Extra command line arg and/or env variable). https://stackoverflow.com/questions/39859222/golang-how-to-overcome-scan-buffer-limit-from-bufio
sc.Split(bufio.ScanLines)
for sc.Scan() {
t := sc.Text()
entry, err := ParseEntryString(t)
if err != nil {
return err
}
f.Entries = append(f.Entries, entry)
}
return sc.Err()
}
func (f *File) WriteHostsFile() error {
// default LF to \n
if f.LF == "" {
f.LF = "\n"
}
theFile, err := os.OpenFile(f.filePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
if err != nil {
return err
}
defer theFile.Close()
for _, entry := range f.Entries {
_, err = theFile.WriteString(entry.ToLine() + f.LF)
if err != nil {
return err
}
}
return nil
}
func (f *File) HasDomain(domain string) bool {
for _, entry := range f.Entries {
if entry.HasDomain(domain) {
return true
}
}
return false
}
func (f *File) indexDomainSingleton(domain string) int {
for i, entry := range f.Entries {
if len(entry.Domains) == 1 && entry.HasDomain(domain) {
return i
}
}
return -1
}
func (f *File) HasDomainSingleton(domain string) bool {
return f.indexDomainSingleton(domain) > -1
}
func (f *File) OverwriteDomainSingleton(domain string, ipAddress string) {
idx := f.indexDomainSingleton(domain)
if idx == -1 {
f.Entries = append(f.Entries, Entry{
IPAddress: ipAddress,
Domains: []string{domain},
})
} else {
theEntry := f.Entries[idx]
theEntry.IPAddress = ipAddress
f.Entries = append(append(f.Entries[:idx], theEntry), f.Entries[idx+1:]...)
}
}