diff --git a/caldav/client.go b/caldav/client.go index 24a28c0..d0e2cbb 100644 --- a/caldav/client.go +++ b/caldav/client.go @@ -2,6 +2,7 @@ package caldav import ( "bytes" + "context" "fmt" "mime" "net/http" @@ -34,9 +35,9 @@ func NewClient(c webdav.HTTPClient, endpoint string) (*Client, error) { return &Client{wc, ic}, nil } -func (c *Client) FindCalendarHomeSet(principal string) (string, error) { +func (c *Client) FindCalendarHomeSet(ctx context.Context, principal string) (string, error) { propfind := internal.NewPropNamePropFind(calendarHomeSetName) - resp, err := c.ic.PropFindFlat(principal, propfind) + resp, err := c.ic.PropFindFlat(ctx, principal, propfind) if err != nil { return "", err } @@ -49,7 +50,7 @@ func (c *Client) FindCalendarHomeSet(principal string) (string, error) { return prop.Href.Path, nil } -func (c *Client) FindCalendars(calendarHomeSet string) ([]Calendar, error) { +func (c *Client) FindCalendars(ctx context.Context, calendarHomeSet string) ([]Calendar, error) { propfind := internal.NewPropNamePropFind( internal.ResourceTypeName, internal.DisplayNameName, @@ -57,7 +58,7 @@ func (c *Client) FindCalendars(calendarHomeSet string) ([]Calendar, error) { maxResourceSizeName, supportedCalendarComponentSetName, ) - ms, err := c.ic.PropFind(calendarHomeSet, internal.DepthOne, propfind) + ms, err := c.ic.PropFind(ctx, calendarHomeSet, internal.DepthOne, propfind) if err != nil { return nil, err } @@ -214,7 +215,7 @@ func decodeCalendarObjectList(ms *internal.MultiStatus) ([]CalendarObject, error return addrs, nil } -func (c *Client) QueryCalendar(calendar string, query *CalendarQuery) ([]CalendarObject, error) { +func (c *Client) QueryCalendar(ctx context.Context, calendar string, query *CalendarQuery) ([]CalendarObject, error) { propReq, err := encodeCalendarReq(&query.CompRequest) if err != nil { return nil, err @@ -228,7 +229,7 @@ func (c *Client) QueryCalendar(calendar string, query *CalendarQuery) ([]Calenda } req.Header.Add("Depth", "1") - ms, err := c.ic.DoMultiStatus(req) + ms, err := c.ic.DoMultiStatus(req.WithContext(ctx)) if err != nil { return nil, err } @@ -236,7 +237,7 @@ func (c *Client) QueryCalendar(calendar string, query *CalendarQuery) ([]Calenda return decodeCalendarObjectList(ms) } -func (c *Client) MultiGetCalendar(path string, multiGet *CalendarMultiGet) ([]CalendarObject, error) { +func (c *Client) MultiGetCalendar(ctx context.Context, path string, multiGet *CalendarMultiGet) ([]CalendarObject, error) { propReq, err := encodeCalendarReq(&multiGet.CompRequest) if err != nil { return nil, err @@ -260,7 +261,7 @@ func (c *Client) MultiGetCalendar(path string, multiGet *CalendarMultiGet) ([]Ca } req.Header.Add("Depth", "1") - ms, err := c.ic.DoMultiStatus(req) + ms, err := c.ic.DoMultiStatus(req.WithContext(ctx)) if err != nil { return nil, err } @@ -301,14 +302,14 @@ func populateCalendarObject(co *CalendarObject, h http.Header) error { return nil } -func (c *Client) GetCalendarObject(path string) (*CalendarObject, error) { +func (c *Client) GetCalendarObject(ctx context.Context, path string) (*CalendarObject, error) { req, err := c.ic.NewRequest(http.MethodGet, path, nil) if err != nil { return nil, err } req.Header.Set("Accept", ical.MIMEType) - resp, err := c.ic.Do(req) + resp, err := c.ic.Do(req.WithContext(ctx)) if err != nil { return nil, err } @@ -337,7 +338,7 @@ func (c *Client) GetCalendarObject(path string) (*CalendarObject, error) { return co, nil } -func (c *Client) PutCalendarObject(path string, cal *ical.Calendar) (*CalendarObject, error) { +func (c *Client) PutCalendarObject(ctx context.Context, path string, cal *ical.Calendar) (*CalendarObject, error) { // TODO: add support for If-None-Match and If-Match // TODO: some servers want a Content-Length header, so we can't stream the @@ -355,7 +356,7 @@ func (c *Client) PutCalendarObject(path string, cal *ical.Calendar) (*CalendarOb } req.Header.Set("Content-Type", ical.MIMEType) - resp, err := c.ic.Do(req) + resp, err := c.ic.Do(req.WithContext(ctx)) if err != nil { return nil, err } diff --git a/carddav/carddav_test.go b/carddav/carddav_test.go index 3a4b87f..392dcff 100644 --- a/carddav/carddav_test.go +++ b/carddav/carddav_test.go @@ -113,6 +113,7 @@ func TestAddressBookDiscovery(t *testing.T) { }, } { t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() h := Handler{&testBackend{}, tc.prefix} ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -135,21 +136,21 @@ func TestAddressBookDiscovery(t *testing.T) { if err != nil { t.Fatalf("error creating client: %s", err) } - cup, err := client.FindCurrentUserPrincipal() + cup, err := client.FindCurrentUserPrincipal(ctx) if err != nil { t.Fatalf("error finding user principal url: %s", err) } if cup != tc.currentUserPrincipal { t.Fatalf("Found current user principal URL '%s', expected '%s'", cup, tc.currentUserPrincipal) } - hsp, err := client.FindAddressBookHomeSet(cup) + hsp, err := client.FindAddressBookHomeSet(ctx, cup) if err != nil { t.Fatalf("error finding home set path: %s", err) } if hsp != tc.homeSetPath { t.Fatalf("Found home set path '%s', expected '%s'", hsp, tc.homeSetPath) } - abs, err := client.FindAddressBooks(hsp) + abs, err := client.FindAddressBooks(ctx, hsp) if err != nil { t.Fatalf("error finding address books: %s", err) } diff --git a/carddav/client.go b/carddav/client.go index 0f35180..859022a 100644 --- a/carddav/client.go +++ b/carddav/client.go @@ -2,6 +2,7 @@ package carddav import ( "bytes" + "context" "fmt" "mime" "net" @@ -18,9 +19,11 @@ import ( // Discover performs a DNS-based CardDAV service discovery as described in // RFC 6352 section 11. It returns the URL to the CardDAV server. -func Discover(domain string) (string, error) { +func Discover(ctx context.Context, domain string) (string, error) { + var resolver net.Resolver + // Only lookup carddavs (not carddav), plaintext connections are insecure - _, addrs, err := net.LookupSRV("carddavs", "tcp", domain) + _, addrs, err := resolver.LookupSRV(ctx, "carddavs", "tcp", domain) if dnsErr, ok := err.(*net.DNSError); ok { if dnsErr.IsTemporary { return "", err @@ -69,8 +72,8 @@ func NewClient(c webdav.HTTPClient, endpoint string) (*Client, error) { return &Client{wc, ic}, nil } -func (c *Client) HasSupport() error { - classes, _, err := c.ic.Options("") +func (c *Client) HasSupport(ctx context.Context) error { + classes, _, err := c.ic.Options(ctx, "") if err != nil { return err } @@ -81,9 +84,9 @@ func (c *Client) HasSupport() error { return nil } -func (c *Client) FindAddressBookHomeSet(principal string) (string, error) { +func (c *Client) FindAddressBookHomeSet(ctx context.Context, principal string) (string, error) { propfind := internal.NewPropNamePropFind(addressBookHomeSetName) - resp, err := c.ic.PropFindFlat(principal, propfind) + resp, err := c.ic.PropFindFlat(ctx, principal, propfind) if err != nil { return "", err } @@ -104,7 +107,7 @@ func decodeSupportedAddressData(supported *supportedAddressData) []AddressDataTy return l } -func (c *Client) FindAddressBooks(addressBookHomeSet string) ([]AddressBook, error) { +func (c *Client) FindAddressBooks(ctx context.Context, addressBookHomeSet string) ([]AddressBook, error) { propfind := internal.NewPropNamePropFind( internal.ResourceTypeName, internal.DisplayNameName, @@ -112,7 +115,7 @@ func (c *Client) FindAddressBooks(addressBookHomeSet string) ([]AddressBook, err maxResourceSizeName, supportedAddressDataName, ) - ms, err := c.ic.PropFind(addressBookHomeSet, internal.DepthOne, propfind) + ms, err := c.ic.PropFind(ctx, addressBookHomeSet, internal.DepthOne, propfind) if err != nil { return nil, err } @@ -271,7 +274,7 @@ func decodeAddressList(ms *internal.MultiStatus) ([]AddressObject, error) { return addrs, nil } -func (c *Client) QueryAddressBook(addressBook string, query *AddressBookQuery) ([]AddressObject, error) { +func (c *Client) QueryAddressBook(ctx context.Context, addressBook string, query *AddressBookQuery) ([]AddressObject, error) { propReq, err := encodeAddressPropReq(&query.DataRequest) if err != nil { return nil, err @@ -297,7 +300,7 @@ func (c *Client) QueryAddressBook(addressBook string, query *AddressBookQuery) ( req.Header.Add("Depth", "1") - ms, err := c.ic.DoMultiStatus(req) + ms, err := c.ic.DoMultiStatus(req.WithContext(ctx)) if err != nil { return nil, err } @@ -305,7 +308,7 @@ func (c *Client) QueryAddressBook(addressBook string, query *AddressBookQuery) ( return decodeAddressList(ms) } -func (c *Client) MultiGetAddressBook(path string, multiGet *AddressBookMultiGet) ([]AddressObject, error) { +func (c *Client) MultiGetAddressBook(ctx context.Context, path string, multiGet *AddressBookMultiGet) ([]AddressObject, error) { propReq, err := encodeAddressPropReq(&multiGet.DataRequest) if err != nil { return nil, err @@ -330,7 +333,7 @@ func (c *Client) MultiGetAddressBook(path string, multiGet *AddressBookMultiGet) req.Header.Add("Depth", "1") - ms, err := c.ic.DoMultiStatus(req) + ms, err := c.ic.DoMultiStatus(req.WithContext(ctx)) if err != nil { return nil, err } @@ -371,14 +374,14 @@ func populateAddressObject(ao *AddressObject, h http.Header) error { return nil } -func (c *Client) GetAddressObject(path string) (*AddressObject, error) { +func (c *Client) GetAddressObject(ctx context.Context, path string) (*AddressObject, error) { req, err := c.ic.NewRequest(http.MethodGet, path, nil) if err != nil { return nil, err } req.Header.Set("Accept", vcard.MIMEType) - resp, err := c.ic.Do(req) + resp, err := c.ic.Do(req.WithContext(ctx)) if err != nil { return nil, err } @@ -407,7 +410,7 @@ func (c *Client) GetAddressObject(path string) (*AddressObject, error) { return ao, nil } -func (c *Client) PutAddressObject(path string, card vcard.Card) (*AddressObject, error) { +func (c *Client) PutAddressObject(ctx context.Context, path string, card vcard.Card) (*AddressObject, error) { // TODO: add support for If-None-Match and If-Match // TODO: some servers want a Content-Length header, so we can't stream the @@ -432,7 +435,7 @@ func (c *Client) PutAddressObject(path string, card vcard.Card) (*AddressObject, } req.Header.Set("Content-Type", vcard.MIMEType) - resp, err := c.ic.Do(req) + resp, err := c.ic.Do(req.WithContext(ctx)) if err != nil { return nil, err } @@ -447,7 +450,7 @@ func (c *Client) PutAddressObject(path string, card vcard.Card) (*AddressObject, // SyncCollection performs a collection synchronization operation on the // specified resource, as defined in RFC 6578. -func (c *Client) SyncCollection(path string, query *SyncQuery) (*SyncResponse, error) { +func (c *Client) SyncCollection(ctx context.Context, path string, query *SyncQuery) (*SyncResponse, error) { var limit *internal.Limit if query.Limit > 0 { limit = &internal.Limit{NResults: uint(query.Limit)} @@ -458,7 +461,7 @@ func (c *Client) SyncCollection(path string, query *SyncQuery) (*SyncResponse, e return nil, err } - ms, err := c.ic.SyncCollection(path, query.SyncToken, internal.DepthOne, limit, propReq) + ms, err := c.ic.SyncCollection(ctx, path, query.SyncToken, internal.DepthOne, limit, propReq) if err != nil { return nil, err } diff --git a/client.go b/client.go index 14b0a3a..e9cf259 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,7 @@ package webdav import ( + "context" "fmt" "io" "net/http" @@ -47,12 +48,12 @@ func NewClient(c HTTPClient, endpoint string) (*Client, error) { return &Client{ic}, nil } -func (c *Client) FindCurrentUserPrincipal() (string, error) { +func (c *Client) FindCurrentUserPrincipal(ctx context.Context) (string, error) { propfind := internal.NewPropNamePropFind(internal.CurrentUserPrincipalName) // TODO: consider retrying on the root URI "/" if this fails, as suggested // by the RFC? - resp, err := c.ic.PropFindFlat("", propfind) + resp, err := c.ic.PropFindFlat(ctx, "", propfind) if err != nil { return "", err } @@ -121,21 +122,21 @@ func fileInfoFromResponse(resp *internal.Response) (*FileInfo, error) { return fi, nil } -func (c *Client) Stat(name string) (*FileInfo, error) { - resp, err := c.ic.PropFindFlat(name, fileInfoPropFind) +func (c *Client) Stat(ctx context.Context, name string) (*FileInfo, error) { + resp, err := c.ic.PropFindFlat(ctx, name, fileInfoPropFind) if err != nil { return nil, err } return fileInfoFromResponse(resp) } -func (c *Client) Open(name string) (io.ReadCloser, error) { +func (c *Client) Open(ctx context.Context, name string) (io.ReadCloser, error) { req, err := c.ic.NewRequest(http.MethodGet, name, nil) if err != nil { return nil, err } - resp, err := c.ic.Do(req) + resp, err := c.ic.Do(req.WithContext(ctx)) if err != nil { return nil, err } @@ -143,13 +144,13 @@ func (c *Client) Open(name string) (io.ReadCloser, error) { return resp.Body, nil } -func (c *Client) Readdir(name string, recursive bool) ([]FileInfo, error) { +func (c *Client) Readdir(ctx context.Context, name string, recursive bool) ([]FileInfo, error) { depth := internal.DepthOne if recursive { depth = internal.DepthInfinity } - ms, err := c.ic.PropFind(name, depth, fileInfoPropFind) + ms, err := c.ic.PropFind(ctx, name, depth, fileInfoPropFind) if err != nil { return nil, err } @@ -182,7 +183,7 @@ func (fw *fileWriter) Close() error { return <-fw.done } -func (c *Client) Create(name string) (io.WriteCloser, error) { +func (c *Client) Create(ctx context.Context, name string) (io.WriteCloser, error) { pr, pw := io.Pipe() req, err := c.ic.NewRequest(http.MethodPut, name, pr) @@ -193,7 +194,7 @@ func (c *Client) Create(name string) (io.WriteCloser, error) { done := make(chan error, 1) go func() { - resp, err := c.ic.Do(req) + resp, err := c.ic.Do(req.WithContext(ctx)) if err != nil { done <- err return @@ -205,13 +206,13 @@ func (c *Client) Create(name string) (io.WriteCloser, error) { return &fileWriter{pw, done}, nil } -func (c *Client) RemoveAll(name string) error { +func (c *Client) RemoveAll(ctx context.Context, name string) error { req, err := c.ic.NewRequest(http.MethodDelete, name, nil) if err != nil { return err } - resp, err := c.ic.Do(req) + resp, err := c.ic.Do(req.WithContext(ctx)) if err != nil { return err } @@ -219,13 +220,13 @@ func (c *Client) RemoveAll(name string) error { return nil } -func (c *Client) Mkdir(name string) error { +func (c *Client) Mkdir(ctx context.Context, name string) error { req, err := c.ic.NewRequest("MKCOL", name, nil) if err != nil { return err } - resp, err := c.ic.Do(req) + resp, err := c.ic.Do(req.WithContext(ctx)) if err != nil { return err } @@ -233,7 +234,7 @@ func (c *Client) Mkdir(name string) error { return nil } -func (c *Client) CopyAll(name, dest string, overwrite bool) error { +func (c *Client) CopyAll(ctx context.Context, name, dest string, overwrite bool) error { req, err := c.ic.NewRequest("COPY", name, nil) if err != nil { return err @@ -242,7 +243,7 @@ func (c *Client) CopyAll(name, dest string, overwrite bool) error { req.Header.Set("Destination", c.ic.ResolveHref(dest).String()) req.Header.Set("Overwrite", internal.FormatOverwrite(overwrite)) - resp, err := c.ic.Do(req) + resp, err := c.ic.Do(req.WithContext(ctx)) if err != nil { return err } @@ -250,7 +251,7 @@ func (c *Client) CopyAll(name, dest string, overwrite bool) error { return nil } -func (c *Client) MoveAll(name, dest string, overwrite bool) error { +func (c *Client) MoveAll(ctx context.Context, name, dest string, overwrite bool) error { req, err := c.ic.NewRequest("MOVE", name, nil) if err != nil { return err @@ -259,7 +260,7 @@ func (c *Client) MoveAll(name, dest string, overwrite bool) error { req.Header.Set("Destination", c.ic.ResolveHref(dest).String()) req.Header.Set("Overwrite", internal.FormatOverwrite(overwrite)) - resp, err := c.ic.Do(req) + resp, err := c.ic.Do(req.WithContext(ctx)) if err != nil { return err } diff --git a/internal/client.go b/internal/client.go index 76c98bf..39449ef 100644 --- a/internal/client.go +++ b/internal/client.go @@ -2,6 +2,7 @@ package internal import ( "bytes" + "context" "encoding/xml" "fmt" "io" @@ -131,7 +132,7 @@ func (c *Client) DoMultiStatus(req *http.Request) (*MultiStatus, error) { return &ms, nil } -func (c *Client) PropFind(path string, depth Depth, propfind *PropFind) (*MultiStatus, error) { +func (c *Client) PropFind(ctx context.Context, path string, depth Depth, propfind *PropFind) (*MultiStatus, error) { req, err := c.NewXMLRequest("PROPFIND", path, propfind) if err != nil { return nil, err @@ -139,12 +140,12 @@ func (c *Client) PropFind(path string, depth Depth, propfind *PropFind) (*MultiS req.Header.Add("Depth", depth.String()) - return c.DoMultiStatus(req) + return c.DoMultiStatus(req.WithContext(ctx)) } // PropfindFlat performs a PROPFIND request with a zero depth. -func (c *Client) PropFindFlat(path string, propfind *PropFind) (*Response, error) { - ms, err := c.PropFind(path, DepthZero, propfind) +func (c *Client) PropFindFlat(ctx context.Context, path string, propfind *PropFind) (*Response, error) { + ms, err := c.PropFind(ctx, path, DepthZero, propfind) if err != nil { return nil, err } @@ -174,13 +175,13 @@ func parseCommaSeparatedSet(values []string, upper bool) map[string]bool { return m } -func (c *Client) Options(path string) (classes map[string]bool, methods map[string]bool, err error) { +func (c *Client) Options(ctx context.Context, path string) (classes map[string]bool, methods map[string]bool, err error) { req, err := c.NewRequest(http.MethodOptions, path, nil) if err != nil { return nil, nil, err } - resp, err := c.Do(req) + resp, err := c.Do(req.WithContext(ctx)) if err != nil { return nil, nil, err } @@ -196,7 +197,7 @@ func (c *Client) Options(path string) (classes map[string]bool, methods map[stri } // SyncCollection perform a `sync-collection` REPORT operation on a resource -func (c *Client) SyncCollection(path, syncToken string, level Depth, limit *Limit, prop *Prop) (*MultiStatus, error) { +func (c *Client) SyncCollection(ctx context.Context, path, syncToken string, level Depth, limit *Limit, prop *Prop) (*MultiStatus, error) { q := SyncCollectionQuery{ SyncToken: syncToken, SyncLevel: level.String(), @@ -209,7 +210,7 @@ func (c *Client) SyncCollection(path, syncToken string, level Depth, limit *Limi return nil, err } - ms, err := c.DoMultiStatus(req) + ms, err := c.DoMultiStatus(req.WithContext(ctx)) if err != nil { return nil, err }