diff --git a/internal/api/v1/proxy.go b/internal/api/v1/proxy.go index 9092bfa..e91920d 100644 --- a/internal/api/v1/proxy.go +++ b/internal/api/v1/proxy.go @@ -67,7 +67,7 @@ func (h *ProxyHandler) handleProxy(w http.ResponseWriter, r *http.Request) { return } - result, err := h.engine.Fetch(r.Context(), *remote, path, prov) + result, err := h.engine.Fetch(r.Context(), *remote, path, prov, r.Header) if err != nil { var proxyErr *proxy.ProxyError if errors.As(err, &proxyErr) { diff --git a/internal/proxy/engine.go b/internal/proxy/engine.go index 3934d34..ba63e78 100644 --- a/internal/proxy/engine.go +++ b/internal/proxy/engine.go @@ -44,7 +44,7 @@ type FetchResult struct { Source string // "cache" or "remote" } -func (e *Engine) Fetch(ctx context.Context, remote models.Remote, path string, prov provider.Provider) (*FetchResult, error) { +func (e *Engine) Fetch(ctx context.Context, remote models.Remote, path string, prov provider.Provider, clientHeaders ...http.Header) (*FetchResult, error) { classifier := NewClassifier(prov) class := classifier.Classify(remote, path) @@ -105,8 +105,13 @@ func (e *Engine) Fetch(ctx context.Context, remote models.Remote, path string, p } } + var fwdHeaders http.Header + if len(clientHeaders) > 0 && clientHeaders[0] != nil { + fwdHeaders = clientHeaders[0] + } + start := time.Now() - result, err := e.fetchFromUpstream(ctx, remote, path, prov, class, ttl) + result, err := e.fetchFromUpstream(ctx, remote, path, prov, class, ttl, fwdHeaders) upstreamMS := int(time.Since(start).Milliseconds()) if err != nil { if remote.StaleOnError && isNetworkError(err) { @@ -126,7 +131,7 @@ func (e *Engine) Fetch(ctx context.Context, remote models.Remote, path string, p return result, nil } -func (e *Engine) fetchFromUpstream(ctx context.Context, remote models.Remote, path string, prov provider.Provider, class Classification, ttl time.Duration) (*FetchResult, error) { +func (e *Engine) fetchFromUpstream(ctx context.Context, remote models.Remote, path string, prov provider.Provider, class Classification, ttl time.Duration, clientHeaders http.Header) (*FetchResult, error) { url := prov.UpstreamURL(remote, path) authHeaders, err := prov.AuthHeaders(ctx, remote) @@ -143,6 +148,11 @@ func (e *Engine) fetchFromUpstream(ctx context.Context, remote models.Remote, pa req.Header.Add(k, v) } } + if clientHeaders != nil { + if accept := clientHeaders.Get("Accept"); accept != "" { + req.Header.Set("Accept", accept) + } + } resp, err := http.DefaultClient.Do(req) if err != nil { @@ -155,6 +165,11 @@ func (e *Engine) fetchFromUpstream(ctx context.Context, remote models.Remote, pa if err == nil && token != "" { req2, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) req2.Header.Set("Authorization", "Bearer "+token) + if clientHeaders != nil { + if accept := clientHeaders.Get("Accept"); accept != "" { + req2.Header.Set("Accept", accept) + } + } resp, err = http.DefaultClient.Do(req2) if err != nil { return nil, &UpstreamError{Err: err} @@ -184,7 +199,7 @@ func (e *Engine) fetchFromUpstream(ctx context.Context, remote models.Remote, pa } contentType := prov.ContentType(path) - if ct := resp.Header.Get("Content-Type"); ct != "" && contentType == "application/octet-stream" { + if ct := resp.Header.Get("Content-Type"); ct != "" { contentType = ct }