fix: HEAD requests fetch and stream the full body (#89)
Fixes #70

## Why

Docker `HEAD` routes mapped to `handleProxy`, which ran a full `Fetch` + `io.Copy` — downloading the entire blob (and fetching upstream on a miss) only for net/http to discard the body. HEAD existence checks (manifests, blobs) are common.

## Changes

- Add `Engine.Head`: answers cached artifacts/indexes from store metadata (no blob download); on a miss it issues an upstream `HEAD` (with bearer-token handling) and never caches a body.
- Route `HEAD /v2/{remote}/*` to a dedicated `handleProxyHead` that writes headers only.
- Add e2e tests for HEAD on a blocklisted path (403) and an unknown remote (404).

## Note

`headUpstream` uses `http.DefaultClient` to build cleanly on master; it will pick up the shared timeout-configured client from #67 once that merges.

## Validation

- `make e2e` passes (includes new HEAD tests).

Reviewed-on: #89
Co-authored-by: Ben Vincent <ben@unkin.net>
Co-committed-by: Ben Vincent <ben@unkin.net>
This commit was merged in pull request #89.
This commit is contained in:
@@ -24,6 +24,39 @@ func TestProxyBlocklist(t *testing.T) {
|
|||||||
assertStatus(t, apiURL("/api/v1/remote/blocklist-test/malware.exe"), http.StatusForbidden)
|
assertStatus(t, apiURL("/api/v1/remote/blocklist-test/malware.exe"), http.StatusForbidden)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxyHeadBlocklist(t *testing.T) {
|
||||||
|
createRemote(t, `{
|
||||||
|
"name": "head-block-test",
|
||||||
|
"package_type": "generic",
|
||||||
|
"base_url": "https://example.com",
|
||||||
|
"blocklist": ["\\.exe$"],
|
||||||
|
"stale_on_error": true
|
||||||
|
}`)
|
||||||
|
defer deleteRemote(t, "head-block-test")
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodHead, apiURL("/v2/head-block-test/malware.exe"), nil)
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("HEAD: %v", err)
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusForbidden {
|
||||||
|
t.Fatalf("HEAD blocklisted path: got %d, want 403", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyHeadUnknownRemote(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest(http.MethodHead, apiURL("/v2/nonexistent/some/path"), nil)
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("HEAD: %v", err)
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusNotFound {
|
||||||
|
t.Fatalf("HEAD unknown remote: got %d, want 404", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestProxyPatterns(t *testing.T) {
|
func TestProxyPatterns(t *testing.T) {
|
||||||
createRemote(t, `{
|
createRemote(t, `{
|
||||||
"name": "patterns-test",
|
"name": "patterns-test",
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ func (h *ProxyHandler) DockerV2Routes() chi.Router {
|
|||||||
r.Get("/", h.handleDockerPing)
|
r.Get("/", h.handleDockerPing)
|
||||||
r.Head("/", h.handleDockerPing)
|
r.Head("/", h.handleDockerPing)
|
||||||
r.Get("/{remoteName}/*", h.handleProxy)
|
r.Get("/{remoteName}/*", h.handleProxy)
|
||||||
r.Head("/{remoteName}/*", h.handleProxy)
|
r.Head("/{remoteName}/*", h.handleProxyHead)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,6 +89,42 @@ func (h *ProxyHandler) handleProxy(w http.ResponseWriter, r *http.Request) {
|
|||||||
io.Copy(w, result.Reader)
|
io.Copy(w, result.Reader)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *ProxyHandler) handleProxyHead(w http.ResponseWriter, r *http.Request) {
|
||||||
|
remoteName := chi.URLParam(r, "remoteName")
|
||||||
|
path := chi.URLParam(r, "*")
|
||||||
|
|
||||||
|
remote, err := h.db.GetRemote(r.Context(), remoteName)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("remote %q not found", remoteName), http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
prov, err := provider.Get(remote.PackageType)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("no provider for %q", remote.PackageType), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.engine.Head(r.Context(), *remote, path, prov)
|
||||||
|
if err != nil {
|
||||||
|
var proxyErr *proxy.ProxyError
|
||||||
|
if errors.As(err, &proxyErr) {
|
||||||
|
http.Error(w, proxyErr.Message, proxyErr.Status)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
slog.Error("proxy head failed", "remote", remoteName, "path", path, "error", err)
|
||||||
|
http.Error(w, "bad gateway", http.StatusBadGateway)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", result.ContentType)
|
||||||
|
w.Header().Set("X-Artifact-Source", result.Source)
|
||||||
|
if result.Size > 0 {
|
||||||
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", result.Size))
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *ProxyHandler) handleVirtual(w http.ResponseWriter, r *http.Request) {
|
func (h *ProxyHandler) handleVirtual(w http.ResponseWriter, r *http.Request) {
|
||||||
virtualName := chi.URLParam(r, "virtualName")
|
virtualName := chi.URLParam(r, "virtualName")
|
||||||
path := chi.URLParam(r, "*")
|
path := chi.URLParam(r, "*")
|
||||||
|
|||||||
@@ -130,6 +130,87 @@ func (e *Engine) Fetch(ctx context.Context, remote models.Remote, path string, p
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HeadResult is the metadata-only answer to a HEAD request; it has no body
// reader.
type HeadResult struct {
	// ContentType is the media type to report for the artifact.
	ContentType string
	// Size is the artifact length in bytes.
	Size int64
	// Source says where the metadata came from: "cache" or "remote".
	Source string
}
|
||||||
|
|
||||||
|
// Head resolves artifact metadata without fetching or streaming the body.
|
||||||
|
// Cached artifacts/indexes are answered from the store metadata; on a miss it
|
||||||
|
// issues an upstream HEAD. It never downloads or caches the body.
|
||||||
|
func (e *Engine) Head(ctx context.Context, remote models.Remote, path string, prov provider.Provider) (*HeadResult, error) {
|
||||||
|
class := NewClassifier(prov).Classify(remote, path)
|
||||||
|
if class == ClassDenied {
|
||||||
|
return nil, &ProxyError{Status: http.StatusForbidden, Message: "access denied"}
|
||||||
|
}
|
||||||
|
|
||||||
|
if artifact, err := e.db.GetArtifact(ctx, remote.Name, path); err == nil && artifact != nil {
|
||||||
|
return &HeadResult{ContentType: artifact.ContentType, Size: artifact.SizeBytes, Source: "cache"}, nil
|
||||||
|
}
|
||||||
|
if info, err := e.store.Stat(ctx, storage.IndexKey(remote.Name, path)); err == nil {
|
||||||
|
return &HeadResult{ContentType: info.ContentType, Size: info.Size, Source: "cache"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.headUpstream(ctx, remote, path, prov)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) headUpstream(ctx context.Context, remote models.Remote, path string, prov provider.Provider) (*HeadResult, error) {
|
||||||
|
url := prov.UpstreamURL(remote, path)
|
||||||
|
|
||||||
|
authHeaders, err := prov.AuthHeaders(ctx, remote)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("auth headers: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
doHead := func(extra http.Header) (*http.Response, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create request: %w", err)
|
||||||
|
}
|
||||||
|
for k, vv := range authHeaders {
|
||||||
|
for _, v := range vv {
|
||||||
|
req.Header.Add(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for k, vv := range extra {
|
||||||
|
for _, v := range vv {
|
||||||
|
req.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return http.DefaultClient.Do(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := doHead(nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &UpstreamError{Err: err}
|
||||||
|
}
|
||||||
|
if resp.StatusCode == http.StatusUnauthorized {
|
||||||
|
resp.Body.Close()
|
||||||
|
token, terr := fetchBearerToken(ctx, resp.Header.Get("Www-Authenticate"), remote)
|
||||||
|
if terr == nil && token != "" {
|
||||||
|
resp, err = doHead(http.Header{"Authorization": []string{"Bearer " + token}})
|
||||||
|
if err != nil {
|
||||||
|
return nil, &UpstreamError{Err: err}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return nil, &ProxyError{Status: http.StatusUnauthorized, Message: "upstream returned 401"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, &ProxyError{Status: resp.StatusCode, Message: fmt.Sprintf("upstream returned %d", resp.StatusCode)}
|
||||||
|
}
|
||||||
|
|
||||||
|
contentType := prov.ContentType(path)
|
||||||
|
if ct := resp.Header.Get("Content-Type"); ct != "" {
|
||||||
|
contentType = ct
|
||||||
|
}
|
||||||
|
return &HeadResult{ContentType: contentType, Size: resp.ContentLength, Source: "remote"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (e *Engine) fetchFromUpstream(ctx context.Context, remote models.Remote, path string, prov provider.Provider, class Classification, ttl time.Duration, clientHeaders http.Header) (*FetchResult, error) {
|
func (e *Engine) fetchFromUpstream(ctx context.Context, remote models.Remote, path string, prov provider.Provider, class Classification, ttl time.Duration, clientHeaders http.Header) (*FetchResult, error) {
|
||||||
url := prov.UpstreamURL(remote, path)
|
url := prov.UpstreamURL(remote, path)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user