diff --git a/e2e/proxy_test.go b/e2e/proxy_test.go index cc647a4..c3275a6 100644 --- a/e2e/proxy_test.go +++ b/e2e/proxy_test.go @@ -24,6 +24,39 @@ func TestProxyBlocklist(t *testing.T) { assertStatus(t, apiURL("/api/v1/remote/blocklist-test/malware.exe"), http.StatusForbidden) } +func TestProxyHeadBlocklist(t *testing.T) { + createRemote(t, `{ + "name": "head-block-test", + "package_type": "generic", + "base_url": "https://example.com", + "blocklist": ["\\.exe$"], + "stale_on_error": true + }`) + defer deleteRemote(t, "head-block-test") + + req, _ := http.NewRequest(http.MethodHead, apiURL("/v2/head-block-test/malware.exe"), nil) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("HEAD: %v", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusForbidden { + t.Fatalf("HEAD blocklisted path: got %d, want 403", resp.StatusCode) + } +} + +func TestProxyHeadUnknownRemote(t *testing.T) { + req, _ := http.NewRequest(http.MethodHead, apiURL("/v2/nonexistent/some/path"), nil) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("HEAD: %v", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("HEAD unknown remote: got %d, want 404", resp.StatusCode) + } +} + func TestProxyPatterns(t *testing.T) { createRemote(t, `{ "name": "patterns-test", diff --git a/internal/api/v1/proxy.go b/internal/api/v1/proxy.go index e91920d..bea9c4d 100644 --- a/internal/api/v1/proxy.go +++ b/internal/api/v1/proxy.go @@ -42,7 +42,7 @@ func (h *ProxyHandler) DockerV2Routes() chi.Router { r.Get("/", h.handleDockerPing) r.Head("/", h.handleDockerPing) r.Get("/{remoteName}/*", h.handleProxy) - r.Head("/{remoteName}/*", h.handleProxy) + r.Head("/{remoteName}/*", h.handleProxyHead) return r } @@ -89,6 +89,42 @@ func (h *ProxyHandler) handleProxy(w http.ResponseWriter, r *http.Request) { io.Copy(w, result.Reader) } +func (h *ProxyHandler) handleProxyHead(w http.ResponseWriter, r *http.Request) { + remoteName := chi.URLParam(r, "remoteName") + path := chi.URLParam(r, "*") + + remote, err := h.db.GetRemote(r.Context(), remoteName) + if err != nil { + http.Error(w, fmt.Sprintf("remote %q not found", remoteName), http.StatusNotFound) + return + } + + prov, err := provider.Get(remote.PackageType) + if err != nil { + http.Error(w, fmt.Sprintf("no provider for %q", remote.PackageType), http.StatusInternalServerError) + return + } + + result, err := h.engine.Head(r.Context(), *remote, path, prov) + if err != nil { + var proxyErr *proxy.ProxyError + if errors.As(err, &proxyErr) { + http.Error(w, proxyErr.Message, proxyErr.Status) + return + } + slog.Error("proxy head failed", "remote", remoteName, "path", path, "error", err) + http.Error(w, "bad gateway", http.StatusBadGateway) + return + } + + w.Header().Set("Content-Type", result.ContentType) + w.Header().Set("X-Artifact-Source", result.Source) + if result.Size > 0 { + w.Header().Set("Content-Length", fmt.Sprintf("%d", result.Size)) + } + w.WriteHeader(http.StatusOK) +} + func (h *ProxyHandler) handleVirtual(w http.ResponseWriter, r *http.Request) { virtualName := chi.URLParam(r, "virtualName") path := chi.URLParam(r, "*") diff --git a/internal/proxy/engine.go b/internal/proxy/engine.go index 59e9bb5..afaf82d 100644 --- a/internal/proxy/engine.go +++ b/internal/proxy/engine.go @@ -130,6 +130,87 @@ func (e *Engine) Fetch(ctx context.Context, remote models.Remote, path string, p return result, nil } +// HeadResult carries artifact metadata for a HEAD request. There is no body. +type HeadResult struct { + ContentType string + Size int64 + Source string // "cache" or "remote" +} + +// Head resolves artifact metadata without fetching or streaming the body. +// Cached artifacts/indexes are answered from the store metadata; on a miss it +// issues an upstream HEAD. It never downloads or caches the body. +func (e *Engine) Head(ctx context.Context, remote models.Remote, path string, prov provider.Provider) (*HeadResult, error) { + class := NewClassifier(prov).Classify(remote, path) + if class == ClassDenied { + return nil, &ProxyError{Status: http.StatusForbidden, Message: "access denied"} + } + + if artifact, err := e.db.GetArtifact(ctx, remote.Name, path); err == nil && artifact != nil { + return &HeadResult{ContentType: artifact.ContentType, Size: artifact.SizeBytes, Source: "cache"}, nil + } + if info, err := e.store.Stat(ctx, storage.IndexKey(remote.Name, path)); err == nil { + return &HeadResult{ContentType: info.ContentType, Size: info.Size, Source: "cache"}, nil + } + + return e.headUpstream(ctx, remote, path, prov) +} + +func (e *Engine) headUpstream(ctx context.Context, remote models.Remote, path string, prov provider.Provider) (*HeadResult, error) { + url := prov.UpstreamURL(remote, path) + + authHeaders, err := prov.AuthHeaders(ctx, remote) + if err != nil { + return nil, fmt.Errorf("auth headers: %w", err) + } + + doHead := func(extra http.Header) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + for k, vv := range authHeaders { + for _, v := range vv { + req.Header.Add(k, v) + } + } + for k, vv := range extra { + for _, v := range vv { + req.Header.Set(k, v) + } + } + return http.DefaultClient.Do(req) + } + + resp, err := doHead(nil) + if err != nil { + return nil, &UpstreamError{Err: err} + } + if resp.StatusCode == http.StatusUnauthorized { + resp.Body.Close() + token, terr := fetchBearerToken(ctx, resp.Header.Get("Www-Authenticate"), remote) + if terr == nil && token != "" { + resp, err = doHead(http.Header{"Authorization": []string{"Bearer " + token}}) + if err != nil { + return nil, &UpstreamError{Err: err} + } + } else { + return nil, &ProxyError{Status: http.StatusUnauthorized, Message: "upstream returned 401"} + } + } + resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, &ProxyError{Status: resp.StatusCode, Message: fmt.Sprintf("upstream returned %d", resp.StatusCode)} + } + + contentType := prov.ContentType(path) + if ct := resp.Header.Get("Content-Type"); ct != "" { + contentType = ct + } + return &HeadResult{ContentType: contentType, Size: resp.ContentLength, Source: "remote"}, nil +} + func (e *Engine) fetchFromUpstream(ctx context.Context, remote models.Remote, path string, prov provider.Provider, class Classification, ttl time.Duration, clientHeaders http.Header) (*FetchResult, error) { url := prov.UpstreamURL(remote, path)