Add context for clients

This commit is contained in:
Simon Ser
2023-12-13 14:37:38 +01:00
parent 0e58dbb003
commit 379a418130
5 changed files with 66 additions and 59 deletions

View File

@@ -113,6 +113,7 @@ func TestAddressBookDiscovery(t *testing.T) {
},
} {
t.Run(tc.name, func(t *testing.T) {
ctx := context.Background()
h := Handler{&testBackend{}, tc.prefix}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -135,21 +136,21 @@ func TestAddressBookDiscovery(t *testing.T) {
if err != nil {
t.Fatalf("error creating client: %s", err)
}
cup, err := client.FindCurrentUserPrincipal()
cup, err := client.FindCurrentUserPrincipal(ctx)
if err != nil {
t.Fatalf("error finding user principal url: %s", err)
}
if cup != tc.currentUserPrincipal {
t.Fatalf("Found current user principal URL '%s', expected '%s'", cup, tc.currentUserPrincipal)
}
hsp, err := client.FindAddressBookHomeSet(cup)
hsp, err := client.FindAddressBookHomeSet(ctx, cup)
if err != nil {
t.Fatalf("error finding home set path: %s", err)
}
if hsp != tc.homeSetPath {
t.Fatalf("Found home set path '%s', expected '%s'", hsp, tc.homeSetPath)
}
abs, err := client.FindAddressBooks(hsp)
abs, err := client.FindAddressBooks(ctx, hsp)
if err != nil {
t.Fatalf("error finding address books: %s", err)
}

View File

@@ -2,6 +2,7 @@ package carddav
import (
"bytes"
"context"
"fmt"
"mime"
"net"
@@ -18,9 +19,11 @@ import (
// Discover performs a DNS-based CardDAV service discovery as described in
// RFC 6352 section 11. It returns the URL to the CardDAV server.
func Discover(domain string) (string, error) {
func Discover(ctx context.Context, domain string) (string, error) {
var resolver net.Resolver
// Only lookup carddavs (not carddav), plaintext connections are insecure
_, addrs, err := net.LookupSRV("carddavs", "tcp", domain)
_, addrs, err := resolver.LookupSRV(ctx, "carddavs", "tcp", domain)
if dnsErr, ok := err.(*net.DNSError); ok {
if dnsErr.IsTemporary {
return "", err
@@ -69,8 +72,8 @@ func NewClient(c webdav.HTTPClient, endpoint string) (*Client, error) {
return &Client{wc, ic}, nil
}
func (c *Client) HasSupport() error {
classes, _, err := c.ic.Options("")
func (c *Client) HasSupport(ctx context.Context) error {
classes, _, err := c.ic.Options(ctx, "")
if err != nil {
return err
}
@@ -81,9 +84,9 @@ func (c *Client) HasSupport() error {
return nil
}
func (c *Client) FindAddressBookHomeSet(principal string) (string, error) {
func (c *Client) FindAddressBookHomeSet(ctx context.Context, principal string) (string, error) {
propfind := internal.NewPropNamePropFind(addressBookHomeSetName)
resp, err := c.ic.PropFindFlat(principal, propfind)
resp, err := c.ic.PropFindFlat(ctx, principal, propfind)
if err != nil {
return "", err
}
@@ -104,7 +107,7 @@ func decodeSupportedAddressData(supported *supportedAddressData) []AddressDataTy
return l
}
func (c *Client) FindAddressBooks(addressBookHomeSet string) ([]AddressBook, error) {
func (c *Client) FindAddressBooks(ctx context.Context, addressBookHomeSet string) ([]AddressBook, error) {
propfind := internal.NewPropNamePropFind(
internal.ResourceTypeName,
internal.DisplayNameName,
@@ -112,7 +115,7 @@ func (c *Client) FindAddressBooks(addressBookHomeSet string) ([]AddressBook, err
maxResourceSizeName,
supportedAddressDataName,
)
ms, err := c.ic.PropFind(addressBookHomeSet, internal.DepthOne, propfind)
ms, err := c.ic.PropFind(ctx, addressBookHomeSet, internal.DepthOne, propfind)
if err != nil {
return nil, err
}
@@ -271,7 +274,7 @@ func decodeAddressList(ms *internal.MultiStatus) ([]AddressObject, error) {
return addrs, nil
}
func (c *Client) QueryAddressBook(addressBook string, query *AddressBookQuery) ([]AddressObject, error) {
func (c *Client) QueryAddressBook(ctx context.Context, addressBook string, query *AddressBookQuery) ([]AddressObject, error) {
propReq, err := encodeAddressPropReq(&query.DataRequest)
if err != nil {
return nil, err
@@ -297,7 +300,7 @@ func (c *Client) QueryAddressBook(addressBook string, query *AddressBookQuery) (
req.Header.Add("Depth", "1")
ms, err := c.ic.DoMultiStatus(req)
ms, err := c.ic.DoMultiStatus(req.WithContext(ctx))
if err != nil {
return nil, err
}
@@ -305,7 +308,7 @@ func (c *Client) QueryAddressBook(addressBook string, query *AddressBookQuery) (
return decodeAddressList(ms)
}
func (c *Client) MultiGetAddressBook(path string, multiGet *AddressBookMultiGet) ([]AddressObject, error) {
func (c *Client) MultiGetAddressBook(ctx context.Context, path string, multiGet *AddressBookMultiGet) ([]AddressObject, error) {
propReq, err := encodeAddressPropReq(&multiGet.DataRequest)
if err != nil {
return nil, err
@@ -330,7 +333,7 @@ func (c *Client) MultiGetAddressBook(path string, multiGet *AddressBookMultiGet)
req.Header.Add("Depth", "1")
ms, err := c.ic.DoMultiStatus(req)
ms, err := c.ic.DoMultiStatus(req.WithContext(ctx))
if err != nil {
return nil, err
}
@@ -371,14 +374,14 @@ func populateAddressObject(ao *AddressObject, h http.Header) error {
return nil
}
func (c *Client) GetAddressObject(path string) (*AddressObject, error) {
func (c *Client) GetAddressObject(ctx context.Context, path string) (*AddressObject, error) {
req, err := c.ic.NewRequest(http.MethodGet, path, nil)
if err != nil {
return nil, err
}
req.Header.Set("Accept", vcard.MIMEType)
resp, err := c.ic.Do(req)
resp, err := c.ic.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}
@@ -407,7 +410,7 @@ func (c *Client) GetAddressObject(path string) (*AddressObject, error) {
return ao, nil
}
func (c *Client) PutAddressObject(path string, card vcard.Card) (*AddressObject, error) {
func (c *Client) PutAddressObject(ctx context.Context, path string, card vcard.Card) (*AddressObject, error) {
// TODO: add support for If-None-Match and If-Match
// TODO: some servers want a Content-Length header, so we can't stream the
@@ -432,7 +435,7 @@ func (c *Client) PutAddressObject(path string, card vcard.Card) (*AddressObject,
}
req.Header.Set("Content-Type", vcard.MIMEType)
resp, err := c.ic.Do(req)
resp, err := c.ic.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}
@@ -447,7 +450,7 @@ func (c *Client) PutAddressObject(path string, card vcard.Card) (*AddressObject,
// SyncCollection performs a collection synchronization operation on the
// specified resource, as defined in RFC 6578.
func (c *Client) SyncCollection(path string, query *SyncQuery) (*SyncResponse, error) {
func (c *Client) SyncCollection(ctx context.Context, path string, query *SyncQuery) (*SyncResponse, error) {
var limit *internal.Limit
if query.Limit > 0 {
limit = &internal.Limit{NResults: uint(query.Limit)}
@@ -458,7 +461,7 @@ func (c *Client) SyncCollection(path string, query *SyncQuery) (*SyncResponse, e
return nil, err
}
ms, err := c.ic.SyncCollection(path, query.SyncToken, internal.DepthOne, limit, propReq)
ms, err := c.ic.SyncCollection(ctx, path, query.SyncToken, internal.DepthOne, limit, propReq)
if err != nil {
return nil, err
}