fix: HEAD requests fetch and stream the full body #89
@@ -24,6 +24,39 @@ func TestProxyBlocklist(t *testing.T) {
|
||||
assertStatus(t, apiURL("/api/v1/remote/blocklist-test/malware.exe"), http.StatusForbidden)
|
||||
}
|
||||
|
||||
func TestProxyHeadBlocklist(t *testing.T) {
|
||||
createRemote(t, `{
|
||||
"name": "head-block-test",
|
||||
"package_type": "generic",
|
||||
"base_url": "https://example.com",
|
||||
"blocklist": ["\\.exe$"],
|
||||
"stale_on_error": true
|
||||
}`)
|
||||
defer deleteRemote(t, "head-block-test")
|
||||
|
||||
req, _ := http.NewRequest(http.MethodHead, apiURL("/v2/head-block-test/malware.exe"), nil)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("HEAD: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Fatalf("HEAD blocklisted path: got %d, want 403", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyHeadUnknownRemote(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodHead, apiURL("/v2/nonexistent/some/path"), nil)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("HEAD: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Fatalf("HEAD unknown remote: got %d, want 404", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyPatterns(t *testing.T) {
|
||||
createRemote(t, `{
|
||||
"name": "patterns-test",
|
||||
|
||||
@@ -42,7 +42,7 @@ func (h *ProxyHandler) DockerV2Routes() chi.Router {
|
||||
r.Get("/", h.handleDockerPing)
|
||||
r.Head("/", h.handleDockerPing)
|
||||
r.Get("/{remoteName}/*", h.handleProxy)
|
||||
r.Head("/{remoteName}/*", h.handleProxy)
|
||||
r.Head("/{remoteName}/*", h.handleProxyHead)
|
||||
return r
|
||||
}
|
||||
|
||||
@@ -89,6 +89,42 @@ func (h *ProxyHandler) handleProxy(w http.ResponseWriter, r *http.Request) {
|
||||
io.Copy(w, result.Reader)
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) handleProxyHead(w http.ResponseWriter, r *http.Request) {
|
||||
remoteName := chi.URLParam(r, "remoteName")
|
||||
path := chi.URLParam(r, "*")
|
||||
|
||||
remote, err := h.db.GetRemote(r.Context(), remoteName)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("remote %q not found", remoteName), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
prov, err := provider.Get(remote.PackageType)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("no provider for %q", remote.PackageType), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.engine.Head(r.Context(), *remote, path, prov)
|
||||
if err != nil {
|
||||
var proxyErr *proxy.ProxyError
|
||||
if errors.As(err, &proxyErr) {
|
||||
http.Error(w, proxyErr.Message, proxyErr.Status)
|
||||
return
|
||||
}
|
||||
slog.Error("proxy head failed", "remote", remoteName, "path", path, "error", err)
|
||||
http.Error(w, "bad gateway", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", result.ContentType)
|
||||
w.Header().Set("X-Artifact-Source", result.Source)
|
||||
if result.Size > 0 {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", result.Size))
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (h *ProxyHandler) handleVirtual(w http.ResponseWriter, r *http.Request) {
|
||||
virtualName := chi.URLParam(r, "virtualName")
|
||||
path := chi.URLParam(r, "*")
|
||||
|
||||
@@ -131,6 +131,87 @@ func (e *Engine) Fetch(ctx context.Context, remote models.Remote, path string, p
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// HeadResult carries artifact metadata for a HEAD request. There is no body.
|
||||
type HeadResult struct {
|
||||
ContentType string
|
||||
Size int64
|
||||
Source string // "cache" or "remote"
|
||||
}
|
||||
|
||||
// Head resolves artifact metadata without fetching or streaming the body.
|
||||
// Cached artifacts/indexes are answered from the store metadata; on a miss it
|
||||
// issues an upstream HEAD. It never downloads or caches the body.
|
||||
func (e *Engine) Head(ctx context.Context, remote models.Remote, path string, prov provider.Provider) (*HeadResult, error) {
|
||||
class := NewClassifier(prov).Classify(remote, path)
|
||||
if class == ClassDenied {
|
||||
return nil, &ProxyError{Status: http.StatusForbidden, Message: "access denied"}
|
||||
}
|
||||
|
||||
if artifact, err := e.db.GetArtifact(ctx, remote.Name, path); err == nil && artifact != nil {
|
||||
return &HeadResult{ContentType: artifact.ContentType, Size: artifact.SizeBytes, Source: "cache"}, nil
|
||||
}
|
||||
if info, err := e.store.Stat(ctx, storage.IndexKey(remote.Name, path)); err == nil {
|
||||
return &HeadResult{ContentType: info.ContentType, Size: info.Size, Source: "cache"}, nil
|
||||
}
|
||||
|
||||
return e.headUpstream(ctx, remote, path, prov)
|
||||
}
|
||||
|
||||
func (e *Engine) headUpstream(ctx context.Context, remote models.Remote, path string, prov provider.Provider) (*HeadResult, error) {
|
||||
url := prov.UpstreamURL(remote, path)
|
||||
|
||||
authHeaders, err := prov.AuthHeaders(ctx, remote)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auth headers: %w", err)
|
||||
}
|
||||
|
||||
doHead := func(extra http.Header) (*http.Response, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
for k, vv := range authHeaders {
|
||||
for _, v := range vv {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
}
|
||||
for k, vv := range extra {
|
||||
for _, v := range vv {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
}
|
||||
return http.DefaultClient.Do(req)
|
||||
}
|
||||
|
||||
resp, err := doHead(nil)
|
||||
if err != nil {
|
||||
return nil, &UpstreamError{Err: err}
|
||||
}
|
||||
if resp.StatusCode == http.StatusUnauthorized {
|
||||
resp.Body.Close()
|
||||
token, terr := fetchBearerToken(ctx, resp.Header.Get("Www-Authenticate"), remote)
|
||||
if terr == nil && token != "" {
|
||||
resp, err = doHead(http.Header{"Authorization": []string{"Bearer " + token}})
|
||||
if err != nil {
|
||||
return nil, &UpstreamError{Err: err}
|
||||
}
|
||||
} else {
|
||||
return nil, &ProxyError{Status: http.StatusUnauthorized, Message: "upstream returned 401"}
|
||||
}
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, &ProxyError{Status: resp.StatusCode, Message: fmt.Sprintf("upstream returned %d", resp.StatusCode)}
|
||||
}
|
||||
|
||||
contentType := prov.ContentType(path)
|
||||
if ct := resp.Header.Get("Content-Type"); ct != "" {
|
||||
contentType = ct
|
||||
}
|
||||
return &HeadResult{ContentType: contentType, Size: resp.ContentLength, Source: "remote"}, nil
|
||||
}
|
||||
|
||||
func (e *Engine) fetchFromUpstream(ctx context.Context, remote models.Remote, path string, prov provider.Provider, class Classification, ttl time.Duration, clientHeaders http.Header) (*FetchResult, error) {
|
||||
url := prov.UpstreamURL(remote, path)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user