Add context for clients

This commit is contained in:
Simon Ser
2023-12-13 14:37:38 +01:00
parent 0e58dbb003
commit 379a418130
5 changed files with 66 additions and 59 deletions

View File

@@ -1,6 +1,7 @@
package webdav
import (
"context"
"fmt"
"io"
"net/http"
@@ -47,12 +48,12 @@ func NewClient(c HTTPClient, endpoint string) (*Client, error) {
return &Client{ic}, nil
}
func (c *Client) FindCurrentUserPrincipal() (string, error) {
func (c *Client) FindCurrentUserPrincipal(ctx context.Context) (string, error) {
propfind := internal.NewPropNamePropFind(internal.CurrentUserPrincipalName)
// TODO: consider retrying on the root URI "/" if this fails, as suggested
// by the RFC?
resp, err := c.ic.PropFindFlat("", propfind)
resp, err := c.ic.PropFindFlat(ctx, "", propfind)
if err != nil {
return "", err
}
@@ -121,21 +122,21 @@ func fileInfoFromResponse(resp *internal.Response) (*FileInfo, error) {
return fi, nil
}
func (c *Client) Stat(name string) (*FileInfo, error) {
resp, err := c.ic.PropFindFlat(name, fileInfoPropFind)
func (c *Client) Stat(ctx context.Context, name string) (*FileInfo, error) {
resp, err := c.ic.PropFindFlat(ctx, name, fileInfoPropFind)
if err != nil {
return nil, err
}
return fileInfoFromResponse(resp)
}
func (c *Client) Open(name string) (io.ReadCloser, error) {
func (c *Client) Open(ctx context.Context, name string) (io.ReadCloser, error) {
req, err := c.ic.NewRequest(http.MethodGet, name, nil)
if err != nil {
return nil, err
}
resp, err := c.ic.Do(req)
resp, err := c.ic.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}
@@ -143,13 +144,13 @@ func (c *Client) Open(name string) (io.ReadCloser, error) {
return resp.Body, nil
}
func (c *Client) Readdir(name string, recursive bool) ([]FileInfo, error) {
func (c *Client) Readdir(ctx context.Context, name string, recursive bool) ([]FileInfo, error) {
depth := internal.DepthOne
if recursive {
depth = internal.DepthInfinity
}
ms, err := c.ic.PropFind(name, depth, fileInfoPropFind)
ms, err := c.ic.PropFind(ctx, name, depth, fileInfoPropFind)
if err != nil {
return nil, err
}
@@ -182,7 +183,7 @@ func (fw *fileWriter) Close() error {
return <-fw.done
}
func (c *Client) Create(name string) (io.WriteCloser, error) {
func (c *Client) Create(ctx context.Context, name string) (io.WriteCloser, error) {
pr, pw := io.Pipe()
req, err := c.ic.NewRequest(http.MethodPut, name, pr)
@@ -193,7 +194,7 @@ func (c *Client) Create(name string) (io.WriteCloser, error) {
done := make(chan error, 1)
go func() {
resp, err := c.ic.Do(req)
resp, err := c.ic.Do(req.WithContext(ctx))
if err != nil {
done <- err
return
@@ -205,13 +206,13 @@ func (c *Client) Create(name string) (io.WriteCloser, error) {
return &fileWriter{pw, done}, nil
}
func (c *Client) RemoveAll(name string) error {
func (c *Client) RemoveAll(ctx context.Context, name string) error {
req, err := c.ic.NewRequest(http.MethodDelete, name, nil)
if err != nil {
return err
}
resp, err := c.ic.Do(req)
resp, err := c.ic.Do(req.WithContext(ctx))
if err != nil {
return err
}
@@ -219,13 +220,13 @@ func (c *Client) RemoveAll(name string) error {
return nil
}
func (c *Client) Mkdir(name string) error {
func (c *Client) Mkdir(ctx context.Context, name string) error {
req, err := c.ic.NewRequest("MKCOL", name, nil)
if err != nil {
return err
}
resp, err := c.ic.Do(req)
resp, err := c.ic.Do(req.WithContext(ctx))
if err != nil {
return err
}
@@ -233,7 +234,7 @@ func (c *Client) Mkdir(name string) error {
return nil
}
func (c *Client) CopyAll(name, dest string, overwrite bool) error {
func (c *Client) CopyAll(ctx context.Context, name, dest string, overwrite bool) error {
req, err := c.ic.NewRequest("COPY", name, nil)
if err != nil {
return err
@@ -242,7 +243,7 @@ func (c *Client) CopyAll(name, dest string, overwrite bool) error {
req.Header.Set("Destination", c.ic.ResolveHref(dest).String())
req.Header.Set("Overwrite", internal.FormatOverwrite(overwrite))
resp, err := c.ic.Do(req)
resp, err := c.ic.Do(req.WithContext(ctx))
if err != nil {
return err
}
@@ -250,7 +251,7 @@ func (c *Client) CopyAll(name, dest string, overwrite bool) error {
return nil
}
func (c *Client) MoveAll(name, dest string, overwrite bool) error {
func (c *Client) MoveAll(ctx context.Context, name, dest string, overwrite bool) error {
req, err := c.ic.NewRequest("MOVE", name, nil)
if err != nil {
return err
@@ -259,7 +260,7 @@ func (c *Client) MoveAll(name, dest string, overwrite bool) error {
req.Header.Set("Destination", c.ic.ResolveHref(dest).String())
req.Header.Set("Overwrite", internal.FormatOverwrite(overwrite))
resp, err := c.ic.Do(req)
resp, err := c.ic.Do(req.WithContext(ctx))
if err != nil {
return err
}