// Source: commit 055a297f6e826a7dac9fabb55880cf3967af5fa4
// From: Simon Ser
// Date: Tue, 14 Jan 2020 17:52:14 +0100
// Subject: [PATCH] internal: add RawXMLValue to defer XML encoding/decoding
//
// The commit creates two new files under internal/: xml.go (RawXMLValue and
// its token reader, 109 lines) and xml_test.go (round-trip tests, 74 lines).
// Both parts are reproduced below with their line structure restored.

import (
	"bytes"
	"encoding/xml"
	"io"
	"testing"
)

// ---------------------------------------------------------------------------
// internal/xml.go
// ---------------------------------------------------------------------------

// RawXMLValue is a raw XML value. It implements xml.Unmarshaler and
// xml.Marshaler and can be used to delay XML decoding or precompute an XML
// encoding.
//
// A value is a small tree: an element node keeps its start element in tok and
// its content in children; every other token (character data, comments, ...)
// is a leaf node with no children. The zero value holds no XML at all: it
// marshals to nothing and its TokenReader is immediately exhausted.
type RawXMLValue struct {
	// tok is the start element of an element node, or the copied token of a
	// leaf node. It is nil for the zero value and is guaranteed not to be
	// xml.EndElement.
	tok xml.Token

	// children holds the content of an element node in document order.
	children []RawXMLValue
}

// UnmarshalXML implements xml.Unmarshaler.
//
// It records start as the value's own token and then reads tokens from d
// until the end element that closes start, rebuilding the nested structure.
// Any previous content of val is discarded. Errors returned by d.Token are
// passed through unchanged; in that case val may be partially populated.
func (val *RawXMLValue) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
	val.tok = start
	val.children = nil

	for {
		tok, err := d.Token()
		if err != nil {
			return err
		}
		switch tok := tok.(type) {
		case xml.StartElement:
			// Nested element: recurse, which consumes everything up to and
			// including the matching end element.
			child := RawXMLValue{}
			if err := child.UnmarshalXML(d, tok); err != nil {
				return err
			}
			val.children = append(val.children, child)
		case xml.EndElement:
			// Nested end elements are consumed by the recursive calls, so
			// this one closes start.
			return nil
		default:
			// The token is retained beyond the next d.Token call, so keep a
			// copy rather than the decoder's own value.
			val.children = append(val.children, RawXMLValue{tok: xml.CopyToken(tok)})
		}
	}
}

// MarshalXML implements xml.Marshaler.
//
// The start argument is ignored: an element node is written with the start
// element recorded by UnmarshalXML, followed by its children and the matching
// end element. A leaf node writes its single token. The zero value writes
// nothing and returns nil.
func (val *RawXMLValue) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
	switch tok := val.tok.(type) {
	case nil:
		// Zero value: nothing was decoded, so there is nothing to encode.
		return nil
	case xml.StartElement:
		if err := e.EncodeToken(tok); err != nil {
			return err
		}
		for i := range val.children {
			// TODO: find a sensible value for the start argument?
			if err := val.children[i].MarshalXML(e, xml.StartElement{}); err != nil {
				return err
			}
		}
		return e.EncodeToken(tok.End())
	case xml.EndElement:
		// UnmarshalXML never stores an end element, so reaching this branch
		// means the invariant documented on tok was broken.
		panic("unexpected end element")
	default:
		return e.EncodeToken(tok)
	}
}

var (
	_ xml.Marshaler   = (*RawXMLValue)(nil)
	_ xml.Unmarshaler = (*RawXMLValue)(nil)
)

// TokenReader returns a stream of tokens for the XML value.
//
// The stream yields the value's tokens in document order and then io.EOF.
// For the zero value it yields io.EOF right away.
func (val *RawXMLValue) TokenReader() xml.TokenReader {
	return &rawXMLValueReader{val: val}
}

// rawXMLValueReader walks a RawXMLValue tree and hands out its tokens one at
// a time. It implements xml.TokenReader.
type rawXMLValueReader struct {
	// val is the node being streamed.
	val *RawXMLValue

	// start is set once the node's start element has been returned; end is
	// set once the node's last token has been returned.
	start, end bool

	// child is the index of the child currently being streamed.
	child int

	// childReader streams val.children[child]; it is nil until that child is
	// first visited and is reset to nil when the child is exhausted.
	childReader xml.TokenReader
}

// Token returns the next token of the value, or io.EOF once every token has
// been returned.
func (tr *rawXMLValueReader) Token() (xml.Token, error) {
	// A nil tok means the value is empty; report EOF instead of handing a
	// nil token with a nil error to the caller.
	if tr.end || tr.val.tok == nil {
		return nil, io.EOF
	}

	start, ok := tr.val.tok.(xml.StartElement)
	if !ok {
		// Leaf node: its single token is the whole stream.
		tr.end = true
		return tr.val.tok, nil
	}

	if !tr.start {
		tr.start = true
		return start, nil
	}

	for tr.child < len(tr.val.children) {
		if tr.childReader == nil {
			tr.childReader = tr.val.children[tr.child].TokenReader()
		}

		tok, err := tr.childReader.Token()
		if err != io.EOF {
			return tok, err
		}

		// Current child is exhausted: move on to the next one.
		tr.childReader = nil
		tr.child++
	}

	// All children streamed: close the element.
	tr.end = true
	return start.End(), nil
}

// ---------------------------------------------------------------------------
// internal/xml_test.go
// ---------------------------------------------------------------------------

// rawXML is the document used by the round-trip tests. Its first line must be
// identical to xml.Header, because the tests compare xml.Header plus the
// encoder output against it. Indentation uses spaces only: the encoder
// escapes tab characters in character data, which would break the byte-wise
// comparison.
//
// NOTE(review): only the text nodes of this fixture survived in the collapsed
// source; the element names below are a reconstruction. Confirm them against
// the original commit.
const rawXML = `<?xml version="1.0" encoding="UTF-8"?>
<bookstore>
  <book>
    <title>Everyday Italian</title>
    <author>Giada De Laurentiis</author>
    <year>2005</year>
  </book>

  <book>
    <title>Harry Potter</title>
    <author>J K. Rowling</author>
    <year>2005</year>
  </book>
</bookstore>`

// TestRawXMLValue checks that decoding a document into a RawXMLValue and
// marshaling it again reproduces the input byte for byte.
func TestRawXMLValue(t *testing.T) {
	// TODO: test XML namespaces too

	var rawValue RawXMLValue
	if err := xml.Unmarshal([]byte(rawXML), &rawValue); err != nil {
		t.Fatalf("xml.Unmarshal() = %v", err)
	}

	b, err := xml.Marshal(&rawValue)
	if err != nil {
		t.Fatalf("xml.Marshal() = %v", err)
	}

	s := xml.Header + string(b)
	if s != rawXML {
		t.Errorf("input doesn't match output:\n%v\nvs.\n%v", rawXML, s)
	}
}

// TestRawXMLValue_TokenReader checks that feeding the tokens produced by
// TokenReader into an encoder reproduces the input byte for byte.
func TestRawXMLValue_TokenReader(t *testing.T) {
	var rawValue RawXMLValue
	if err := xml.Unmarshal([]byte(rawXML), &rawValue); err != nil {
		t.Fatalf("xml.Unmarshal() = %v", err)
	}

	tr := rawValue.TokenReader()

	var buf bytes.Buffer
	enc := xml.NewEncoder(&buf)
	for {
		tok, err := tr.Token()
		if err == io.EOF {
			break
		} else if err != nil {
			t.Fatalf("TokenReader.Token() = %v", err)
		}

		if err := enc.EncodeToken(tok); err != nil {
			t.Fatalf("Encoder.EncodeToken() = %v", err)
		}
	}
	if err := enc.Flush(); err != nil {
		t.Fatalf("Encoder.Flush() = %v", err)
	}

	s := xml.Header + buf.String()
	if s != rawXML {
		t.Errorf("input doesn't match output:\n%v\nvs.\n%v", rawXML, s)
	}
}

// TestRawXMLValue_ZeroValue checks that a RawXMLValue that was never decoded
// into marshals to nothing and yields an empty token stream.
func TestRawXMLValue_ZeroValue(t *testing.T) {
	var rawValue RawXMLValue

	b, err := xml.Marshal(&rawValue)
	if err != nil {
		t.Fatalf("xml.Marshal() = %v", err)
	}
	if len(b) != 0 {
		t.Errorf("xml.Marshal() = %q, want no output", b)
	}

	tok, err := rawValue.TokenReader().Token()
	if tok != nil || err != io.EOF {
		t.Errorf("TokenReader.Token() = %v, %v, want nil, io.EOF", tok, err)
	}
}