Add context for clients

This commit is contained in:
Simon Ser 2023-12-13 14:37:38 +01:00
parent 0e58dbb003
commit 379a418130
5 changed files with 66 additions and 59 deletions

View File

@ -2,6 +2,7 @@ package caldav
import (
"bytes"
"context"
"fmt"
"mime"
"net/http"
@ -34,9 +35,9 @@ func NewClient(c webdav.HTTPClient, endpoint string) (*Client, error) {
return &Client{wc, ic}, nil
}
func (c *Client) FindCalendarHomeSet(principal string) (string, error) {
func (c *Client) FindCalendarHomeSet(ctx context.Context, principal string) (string, error) {
propfind := internal.NewPropNamePropFind(calendarHomeSetName)
resp, err := c.ic.PropFindFlat(principal, propfind)
resp, err := c.ic.PropFindFlat(ctx, principal, propfind)
if err != nil {
return "", err
}
@ -49,7 +50,7 @@ func (c *Client) FindCalendarHomeSet(principal string) (string, error) {
return prop.Href.Path, nil
}
func (c *Client) FindCalendars(calendarHomeSet string) ([]Calendar, error) {
func (c *Client) FindCalendars(ctx context.Context, calendarHomeSet string) ([]Calendar, error) {
propfind := internal.NewPropNamePropFind(
internal.ResourceTypeName,
internal.DisplayNameName,
@ -57,7 +58,7 @@ func (c *Client) FindCalendars(calendarHomeSet string) ([]Calendar, error) {
maxResourceSizeName,
supportedCalendarComponentSetName,
)
ms, err := c.ic.PropFind(calendarHomeSet, internal.DepthOne, propfind)
ms, err := c.ic.PropFind(ctx, calendarHomeSet, internal.DepthOne, propfind)
if err != nil {
return nil, err
}
@ -214,7 +215,7 @@ func decodeCalendarObjectList(ms *internal.MultiStatus) ([]CalendarObject, error
return addrs, nil
}
func (c *Client) QueryCalendar(calendar string, query *CalendarQuery) ([]CalendarObject, error) {
func (c *Client) QueryCalendar(ctx context.Context, calendar string, query *CalendarQuery) ([]CalendarObject, error) {
propReq, err := encodeCalendarReq(&query.CompRequest)
if err != nil {
return nil, err
@ -228,7 +229,7 @@ func (c *Client) QueryCalendar(calendar string, query *CalendarQuery) ([]Calenda
}
req.Header.Add("Depth", "1")
ms, err := c.ic.DoMultiStatus(req)
ms, err := c.ic.DoMultiStatus(req.WithContext(ctx))
if err != nil {
return nil, err
}
@ -236,7 +237,7 @@ func (c *Client) QueryCalendar(calendar string, query *CalendarQuery) ([]Calenda
return decodeCalendarObjectList(ms)
}
func (c *Client) MultiGetCalendar(path string, multiGet *CalendarMultiGet) ([]CalendarObject, error) {
func (c *Client) MultiGetCalendar(ctx context.Context, path string, multiGet *CalendarMultiGet) ([]CalendarObject, error) {
propReq, err := encodeCalendarReq(&multiGet.CompRequest)
if err != nil {
return nil, err
@ -260,7 +261,7 @@ func (c *Client) MultiGetCalendar(path string, multiGet *CalendarMultiGet) ([]Ca
}
req.Header.Add("Depth", "1")
ms, err := c.ic.DoMultiStatus(req)
ms, err := c.ic.DoMultiStatus(req.WithContext(ctx))
if err != nil {
return nil, err
}
@ -301,14 +302,14 @@ func populateCalendarObject(co *CalendarObject, h http.Header) error {
return nil
}
func (c *Client) GetCalendarObject(path string) (*CalendarObject, error) {
func (c *Client) GetCalendarObject(ctx context.Context, path string) (*CalendarObject, error) {
req, err := c.ic.NewRequest(http.MethodGet, path, nil)
if err != nil {
return nil, err
}
req.Header.Set("Accept", ical.MIMEType)
resp, err := c.ic.Do(req)
resp, err := c.ic.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}
@ -337,7 +338,7 @@ func (c *Client) GetCalendarObject(path string) (*CalendarObject, error) {
return co, nil
}
func (c *Client) PutCalendarObject(path string, cal *ical.Calendar) (*CalendarObject, error) {
func (c *Client) PutCalendarObject(ctx context.Context, path string, cal *ical.Calendar) (*CalendarObject, error) {
// TODO: add support for If-None-Match and If-Match
// TODO: some servers want a Content-Length header, so we can't stream the
@ -355,7 +356,7 @@ func (c *Client) PutCalendarObject(path string, cal *ical.Calendar) (*CalendarOb
}
req.Header.Set("Content-Type", ical.MIMEType)
resp, err := c.ic.Do(req)
resp, err := c.ic.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}

View File

@ -113,6 +113,7 @@ func TestAddressBookDiscovery(t *testing.T) {
},
} {
t.Run(tc.name, func(t *testing.T) {
ctx := context.Background()
h := Handler{&testBackend{}, tc.prefix}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -135,21 +136,21 @@ func TestAddressBookDiscovery(t *testing.T) {
if err != nil {
t.Fatalf("error creating client: %s", err)
}
cup, err := client.FindCurrentUserPrincipal()
cup, err := client.FindCurrentUserPrincipal(ctx)
if err != nil {
t.Fatalf("error finding user principal url: %s", err)
}
if cup != tc.currentUserPrincipal {
t.Fatalf("Found current user principal URL '%s', expected '%s'", cup, tc.currentUserPrincipal)
}
hsp, err := client.FindAddressBookHomeSet(cup)
hsp, err := client.FindAddressBookHomeSet(ctx, cup)
if err != nil {
t.Fatalf("error finding home set path: %s", err)
}
if hsp != tc.homeSetPath {
t.Fatalf("Found home set path '%s', expected '%s'", hsp, tc.homeSetPath)
}
abs, err := client.FindAddressBooks(hsp)
abs, err := client.FindAddressBooks(ctx, hsp)
if err != nil {
t.Fatalf("error finding address books: %s", err)
}

View File

@ -2,6 +2,7 @@ package carddav
import (
"bytes"
"context"
"fmt"
"mime"
"net"
@ -18,9 +19,11 @@ import (
// Discover performs a DNS-based CardDAV service discovery as described in
// RFC 6352 section 11. It returns the URL to the CardDAV server.
func Discover(domain string) (string, error) {
func Discover(ctx context.Context, domain string) (string, error) {
var resolver net.Resolver
// Only lookup carddavs (not carddav), plaintext connections are insecure
_, addrs, err := net.LookupSRV("carddavs", "tcp", domain)
_, addrs, err := resolver.LookupSRV(ctx, "carddavs", "tcp", domain)
if dnsErr, ok := err.(*net.DNSError); ok {
if dnsErr.IsTemporary {
return "", err
@ -69,8 +72,8 @@ func NewClient(c webdav.HTTPClient, endpoint string) (*Client, error) {
return &Client{wc, ic}, nil
}
func (c *Client) HasSupport() error {
classes, _, err := c.ic.Options("")
func (c *Client) HasSupport(ctx context.Context) error {
classes, _, err := c.ic.Options(ctx, "")
if err != nil {
return err
}
@ -81,9 +84,9 @@ func (c *Client) HasSupport() error {
return nil
}
func (c *Client) FindAddressBookHomeSet(principal string) (string, error) {
func (c *Client) FindAddressBookHomeSet(ctx context.Context, principal string) (string, error) {
propfind := internal.NewPropNamePropFind(addressBookHomeSetName)
resp, err := c.ic.PropFindFlat(principal, propfind)
resp, err := c.ic.PropFindFlat(ctx, principal, propfind)
if err != nil {
return "", err
}
@ -104,7 +107,7 @@ func decodeSupportedAddressData(supported *supportedAddressData) []AddressDataTy
return l
}
func (c *Client) FindAddressBooks(addressBookHomeSet string) ([]AddressBook, error) {
func (c *Client) FindAddressBooks(ctx context.Context, addressBookHomeSet string) ([]AddressBook, error) {
propfind := internal.NewPropNamePropFind(
internal.ResourceTypeName,
internal.DisplayNameName,
@ -112,7 +115,7 @@ func (c *Client) FindAddressBooks(addressBookHomeSet string) ([]AddressBook, err
maxResourceSizeName,
supportedAddressDataName,
)
ms, err := c.ic.PropFind(addressBookHomeSet, internal.DepthOne, propfind)
ms, err := c.ic.PropFind(ctx, addressBookHomeSet, internal.DepthOne, propfind)
if err != nil {
return nil, err
}
@ -271,7 +274,7 @@ func decodeAddressList(ms *internal.MultiStatus) ([]AddressObject, error) {
return addrs, nil
}
func (c *Client) QueryAddressBook(addressBook string, query *AddressBookQuery) ([]AddressObject, error) {
func (c *Client) QueryAddressBook(ctx context.Context, addressBook string, query *AddressBookQuery) ([]AddressObject, error) {
propReq, err := encodeAddressPropReq(&query.DataRequest)
if err != nil {
return nil, err
@ -297,7 +300,7 @@ func (c *Client) QueryAddressBook(addressBook string, query *AddressBookQuery) (
req.Header.Add("Depth", "1")
ms, err := c.ic.DoMultiStatus(req)
ms, err := c.ic.DoMultiStatus(req.WithContext(ctx))
if err != nil {
return nil, err
}
@ -305,7 +308,7 @@ func (c *Client) QueryAddressBook(addressBook string, query *AddressBookQuery) (
return decodeAddressList(ms)
}
func (c *Client) MultiGetAddressBook(path string, multiGet *AddressBookMultiGet) ([]AddressObject, error) {
func (c *Client) MultiGetAddressBook(ctx context.Context, path string, multiGet *AddressBookMultiGet) ([]AddressObject, error) {
propReq, err := encodeAddressPropReq(&multiGet.DataRequest)
if err != nil {
return nil, err
@ -330,7 +333,7 @@ func (c *Client) MultiGetAddressBook(path string, multiGet *AddressBookMultiGet)
req.Header.Add("Depth", "1")
ms, err := c.ic.DoMultiStatus(req)
ms, err := c.ic.DoMultiStatus(req.WithContext(ctx))
if err != nil {
return nil, err
}
@ -371,14 +374,14 @@ func populateAddressObject(ao *AddressObject, h http.Header) error {
return nil
}
func (c *Client) GetAddressObject(path string) (*AddressObject, error) {
func (c *Client) GetAddressObject(ctx context.Context, path string) (*AddressObject, error) {
req, err := c.ic.NewRequest(http.MethodGet, path, nil)
if err != nil {
return nil, err
}
req.Header.Set("Accept", vcard.MIMEType)
resp, err := c.ic.Do(req)
resp, err := c.ic.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}
@ -407,7 +410,7 @@ func (c *Client) GetAddressObject(path string) (*AddressObject, error) {
return ao, nil
}
func (c *Client) PutAddressObject(path string, card vcard.Card) (*AddressObject, error) {
func (c *Client) PutAddressObject(ctx context.Context, path string, card vcard.Card) (*AddressObject, error) {
// TODO: add support for If-None-Match and If-Match
// TODO: some servers want a Content-Length header, so we can't stream the
@ -432,7 +435,7 @@ func (c *Client) PutAddressObject(path string, card vcard.Card) (*AddressObject,
}
req.Header.Set("Content-Type", vcard.MIMEType)
resp, err := c.ic.Do(req)
resp, err := c.ic.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}
@ -447,7 +450,7 @@ func (c *Client) PutAddressObject(path string, card vcard.Card) (*AddressObject,
// SyncCollection performs a collection synchronization operation on the
// specified resource, as defined in RFC 6578.
func (c *Client) SyncCollection(path string, query *SyncQuery) (*SyncResponse, error) {
func (c *Client) SyncCollection(ctx context.Context, path string, query *SyncQuery) (*SyncResponse, error) {
var limit *internal.Limit
if query.Limit > 0 {
limit = &internal.Limit{NResults: uint(query.Limit)}
@ -458,7 +461,7 @@ func (c *Client) SyncCollection(path string, query *SyncQuery) (*SyncResponse, e
return nil, err
}
ms, err := c.ic.SyncCollection(path, query.SyncToken, internal.DepthOne, limit, propReq)
ms, err := c.ic.SyncCollection(ctx, path, query.SyncToken, internal.DepthOne, limit, propReq)
if err != nil {
return nil, err
}

View File

@ -1,6 +1,7 @@
package webdav
import (
"context"
"fmt"
"io"
"net/http"
@ -47,12 +48,12 @@ func NewClient(c HTTPClient, endpoint string) (*Client, error) {
return &Client{ic}, nil
}
func (c *Client) FindCurrentUserPrincipal() (string, error) {
func (c *Client) FindCurrentUserPrincipal(ctx context.Context) (string, error) {
propfind := internal.NewPropNamePropFind(internal.CurrentUserPrincipalName)
// TODO: consider retrying on the root URI "/" if this fails, as suggested
// by the RFC?
resp, err := c.ic.PropFindFlat("", propfind)
resp, err := c.ic.PropFindFlat(ctx, "", propfind)
if err != nil {
return "", err
}
@ -121,21 +122,21 @@ func fileInfoFromResponse(resp *internal.Response) (*FileInfo, error) {
return fi, nil
}
func (c *Client) Stat(name string) (*FileInfo, error) {
resp, err := c.ic.PropFindFlat(name, fileInfoPropFind)
func (c *Client) Stat(ctx context.Context, name string) (*FileInfo, error) {
resp, err := c.ic.PropFindFlat(ctx, name, fileInfoPropFind)
if err != nil {
return nil, err
}
return fileInfoFromResponse(resp)
}
func (c *Client) Open(name string) (io.ReadCloser, error) {
func (c *Client) Open(ctx context.Context, name string) (io.ReadCloser, error) {
req, err := c.ic.NewRequest(http.MethodGet, name, nil)
if err != nil {
return nil, err
}
resp, err := c.ic.Do(req)
resp, err := c.ic.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}
@ -143,13 +144,13 @@ func (c *Client) Open(name string) (io.ReadCloser, error) {
return resp.Body, nil
}
func (c *Client) Readdir(name string, recursive bool) ([]FileInfo, error) {
func (c *Client) Readdir(ctx context.Context, name string, recursive bool) ([]FileInfo, error) {
depth := internal.DepthOne
if recursive {
depth = internal.DepthInfinity
}
ms, err := c.ic.PropFind(name, depth, fileInfoPropFind)
ms, err := c.ic.PropFind(ctx, name, depth, fileInfoPropFind)
if err != nil {
return nil, err
}
@ -182,7 +183,7 @@ func (fw *fileWriter) Close() error {
return <-fw.done
}
func (c *Client) Create(name string) (io.WriteCloser, error) {
func (c *Client) Create(ctx context.Context, name string) (io.WriteCloser, error) {
pr, pw := io.Pipe()
req, err := c.ic.NewRequest(http.MethodPut, name, pr)
@ -193,7 +194,7 @@ func (c *Client) Create(name string) (io.WriteCloser, error) {
done := make(chan error, 1)
go func() {
resp, err := c.ic.Do(req)
resp, err := c.ic.Do(req.WithContext(ctx))
if err != nil {
done <- err
return
@ -205,13 +206,13 @@ func (c *Client) Create(name string) (io.WriteCloser, error) {
return &fileWriter{pw, done}, nil
}
func (c *Client) RemoveAll(name string) error {
func (c *Client) RemoveAll(ctx context.Context, name string) error {
req, err := c.ic.NewRequest(http.MethodDelete, name, nil)
if err != nil {
return err
}
resp, err := c.ic.Do(req)
resp, err := c.ic.Do(req.WithContext(ctx))
if err != nil {
return err
}
@ -219,13 +220,13 @@ func (c *Client) RemoveAll(name string) error {
return nil
}
func (c *Client) Mkdir(name string) error {
func (c *Client) Mkdir(ctx context.Context, name string) error {
req, err := c.ic.NewRequest("MKCOL", name, nil)
if err != nil {
return err
}
resp, err := c.ic.Do(req)
resp, err := c.ic.Do(req.WithContext(ctx))
if err != nil {
return err
}
@ -233,7 +234,7 @@ func (c *Client) Mkdir(name string) error {
return nil
}
func (c *Client) CopyAll(name, dest string, overwrite bool) error {
func (c *Client) CopyAll(ctx context.Context, name, dest string, overwrite bool) error {
req, err := c.ic.NewRequest("COPY", name, nil)
if err != nil {
return err
@ -242,7 +243,7 @@ func (c *Client) CopyAll(name, dest string, overwrite bool) error {
req.Header.Set("Destination", c.ic.ResolveHref(dest).String())
req.Header.Set("Overwrite", internal.FormatOverwrite(overwrite))
resp, err := c.ic.Do(req)
resp, err := c.ic.Do(req.WithContext(ctx))
if err != nil {
return err
}
@ -250,7 +251,7 @@ func (c *Client) CopyAll(name, dest string, overwrite bool) error {
return nil
}
func (c *Client) MoveAll(name, dest string, overwrite bool) error {
func (c *Client) MoveAll(ctx context.Context, name, dest string, overwrite bool) error {
req, err := c.ic.NewRequest("MOVE", name, nil)
if err != nil {
return err
@ -259,7 +260,7 @@ func (c *Client) MoveAll(name, dest string, overwrite bool) error {
req.Header.Set("Destination", c.ic.ResolveHref(dest).String())
req.Header.Set("Overwrite", internal.FormatOverwrite(overwrite))
resp, err := c.ic.Do(req)
resp, err := c.ic.Do(req.WithContext(ctx))
if err != nil {
return err
}

View File

@ -2,6 +2,7 @@ package internal
import (
"bytes"
"context"
"encoding/xml"
"fmt"
"io"
@ -131,7 +132,7 @@ func (c *Client) DoMultiStatus(req *http.Request) (*MultiStatus, error) {
return &ms, nil
}
func (c *Client) PropFind(path string, depth Depth, propfind *PropFind) (*MultiStatus, error) {
func (c *Client) PropFind(ctx context.Context, path string, depth Depth, propfind *PropFind) (*MultiStatus, error) {
req, err := c.NewXMLRequest("PROPFIND", path, propfind)
if err != nil {
return nil, err
@ -139,12 +140,12 @@ func (c *Client) PropFind(path string, depth Depth, propfind *PropFind) (*MultiS
req.Header.Add("Depth", depth.String())
return c.DoMultiStatus(req)
return c.DoMultiStatus(req.WithContext(ctx))
}
// PropfindFlat performs a PROPFIND request with a zero depth.
func (c *Client) PropFindFlat(path string, propfind *PropFind) (*Response, error) {
ms, err := c.PropFind(path, DepthZero, propfind)
func (c *Client) PropFindFlat(ctx context.Context, path string, propfind *PropFind) (*Response, error) {
ms, err := c.PropFind(ctx, path, DepthZero, propfind)
if err != nil {
return nil, err
}
@ -174,13 +175,13 @@ func parseCommaSeparatedSet(values []string, upper bool) map[string]bool {
return m
}
func (c *Client) Options(path string) (classes map[string]bool, methods map[string]bool, err error) {
func (c *Client) Options(ctx context.Context, path string) (classes map[string]bool, methods map[string]bool, err error) {
req, err := c.NewRequest(http.MethodOptions, path, nil)
if err != nil {
return nil, nil, err
}
resp, err := c.Do(req)
resp, err := c.Do(req.WithContext(ctx))
if err != nil {
return nil, nil, err
}
@ -196,7 +197,7 @@ func (c *Client) Options(path string) (classes map[string]bool, methods map[stri
}
// SyncCollection perform a `sync-collection` REPORT operation on a resource
func (c *Client) SyncCollection(path, syncToken string, level Depth, limit *Limit, prop *Prop) (*MultiStatus, error) {
func (c *Client) SyncCollection(ctx context.Context, path, syncToken string, level Depth, limit *Limit, prop *Prop) (*MultiStatus, error) {
q := SyncCollectionQuery{
SyncToken: syncToken,
SyncLevel: level.String(),
@ -209,7 +210,7 @@ func (c *Client) SyncCollection(path, syncToken string, level Depth, limit *Limi
return nil, err
}
ms, err := c.DoMultiStatus(req)
ms, err := c.DoMultiStatus(req.WithContext(ctx))
if err != nil {
return nil, err
}