Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 428c6d0e97 | |||
| 1b585af14e | |||
| e7c9387bcc | |||
| 7e07eaa758 |
@@ -0,0 +1,23 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
|
||||
"git.unkin.net/unkin/artifactapi/pkg/models"
|
||||
)
|
||||
|
||||
func TestBasicHeaders(t *testing.T) {
|
||||
h := BasicHeaders(models.Remote{Username: "alice", Password: "secret"})
|
||||
got := h.Get("Authorization")
|
||||
want := "Basic " + base64.StdEncoding.EncodeToString([]byte("alice:secret"))
|
||||
if got != want {
|
||||
t.Errorf("Authorization = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBasicHeadersNoUser(t *testing.T) {
|
||||
if h := BasicHeaders(models.Remote{}); h.Get("Authorization") != "" {
|
||||
t.Error("expected no Authorization header without a username")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadDefaults(t *testing.T) {
|
||||
// Unset the vars Load reads so the fallback defaults are exercised.
|
||||
for _, k := range []string{
|
||||
"LISTEN_ADDR", "DBHOST", "DBPORT", "DBUSER", "DBPASS", "DBNAME", "DBSSL",
|
||||
"REDIS_URL", "MINIO_ENDPOINT", "MINIO_ACCESS_KEY", "MINIO_SECRET_KEY",
|
||||
"MINIO_BUCKET", "MINIO_SECURE", "MINIO_REGION",
|
||||
} {
|
||||
old, ok := os.LookupEnv(k)
|
||||
os.Unsetenv(k)
|
||||
if ok {
|
||||
t.Cleanup(func() { os.Setenv(k, old) })
|
||||
}
|
||||
}
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("load: %v", err)
|
||||
}
|
||||
if cfg.ListenAddr != ":8000" || cfg.DBPort != 5432 || cfg.DBUser != "artifacts" {
|
||||
t.Errorf("unexpected defaults: %+v", cfg)
|
||||
}
|
||||
if cfg.RedisURL != "redis://localhost:6379" || cfg.S3Bucket != "artifacts" || cfg.S3Secure {
|
||||
t.Errorf("unexpected defaults: %+v", cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadOverrides(t *testing.T) {
|
||||
t.Setenv("LISTEN_ADDR", ":9999")
|
||||
t.Setenv("DBHOST", "db.example.com")
|
||||
t.Setenv("DBPORT", "6000")
|
||||
t.Setenv("DBUSER", "u")
|
||||
t.Setenv("DBPASS", "pw")
|
||||
t.Setenv("DBNAME", "n")
|
||||
t.Setenv("DBSSL", "require")
|
||||
t.Setenv("MINIO_SECURE", "true")
|
||||
t.Setenv("MINIO_REGION", "us-east-1")
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("load: %v", err)
|
||||
}
|
||||
if cfg.ListenAddr != ":9999" || cfg.DBHost != "db.example.com" || cfg.DBPort != 6000 {
|
||||
t.Errorf("overrides not applied: %+v", cfg)
|
||||
}
|
||||
if !cfg.S3Secure {
|
||||
t.Error("MINIO_SECURE=true not parsed")
|
||||
}
|
||||
want := "postgres://u:pw@db.example.com:6000/n?sslmode=require"
|
||||
if got := cfg.DatabaseDSN(); got != want {
|
||||
t.Errorf("DSN = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadInvalidPort(t *testing.T) {
|
||||
t.Setenv("DBPORT", "not-a-number")
|
||||
if _, err := Load(); err == nil {
|
||||
t.Error("expected error for invalid DBPORT")
|
||||
}
|
||||
}
|
||||
@@ -138,16 +138,22 @@ func (db *DB) InsertAccessLogBatch(ctx context.Context, entries []AccessLogEntry
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *DB) FindOrphanedBlobs(ctx context.Context) ([]models.Blob, error) {
|
||||
// FindOrphanedBlobs returns blobs no longer referenced by any artifact or
|
||||
// local file, restricted to those created before now()-minAge. The age cutoff
|
||||
// is a grace period that avoids a TOCTOU race with in-flight dedup uploads,
|
||||
// which insert the blob row before the referencing artifact/local_files row.
|
||||
func (db *DB) FindOrphanedBlobs(ctx context.Context, minAge time.Duration) ([]models.Blob, error) {
|
||||
cutoff := time.Now().Add(-minAge)
|
||||
rows, err := db.Pool.Query(ctx, `
|
||||
SELECT b.content_hash, b.s3_key, b.size_bytes, b.content_type, b.created_at
|
||||
FROM blobs b
|
||||
WHERE b.content_hash NOT IN (
|
||||
WHERE b.created_at < $1
|
||||
AND b.content_hash NOT IN (
|
||||
SELECT content_hash FROM artifacts
|
||||
UNION
|
||||
SELECT content_hash FROM local_files
|
||||
)
|
||||
`)
|
||||
`, cutoff)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
+6
-1
@@ -9,6 +9,11 @@ import (
|
||||
"git.unkin.net/unkin/artifactapi/internal/storage"
|
||||
)
|
||||
|
||||
// blobGracePeriod is how old an orphaned blob must be before GC will delete
|
||||
// it. This avoids racing in-flight dedup uploads that insert the blob row
|
||||
// before the referencing artifact/local_files row exists.
|
||||
const blobGracePeriod = 1 * time.Hour
|
||||
|
||||
type Collector struct {
|
||||
db *database.DB
|
||||
store *storage.S3
|
||||
@@ -38,7 +43,7 @@ func (c *Collector) Run(ctx context.Context) {
|
||||
func (c *Collector) sweep(ctx context.Context) {
|
||||
start := time.Now()
|
||||
|
||||
orphaned, err := c.db.FindOrphanedBlobs(ctx)
|
||||
orphaned, err := c.db.FindOrphanedBlobs(ctx, blobGracePeriod)
|
||||
if err != nil {
|
||||
slog.Error("gc: find orphaned blobs", "error", err)
|
||||
return
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
package alpine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"git.unkin.net/unkin/artifactapi/internal/provider"
|
||||
"git.unkin.net/unkin/artifactapi/pkg/models"
|
||||
)
|
||||
|
||||
func TestType(t *testing.T) {
|
||||
if (&Provider{}).Type() != models.PackageAlpine {
|
||||
t.Fatal("wrong type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassify(t *testing.T) {
|
||||
p := &Provider{}
|
||||
if p.Classify("v3.19/main/x86_64/APKINDEX.tar.gz") != provider.Mutable {
|
||||
t.Error("APKINDEX should be mutable")
|
||||
}
|
||||
if p.Classify("v3.19/main/x86_64/curl-8.0-r0.apk") != provider.Immutable {
|
||||
t.Error("apk should be immutable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContentType(t *testing.T) {
|
||||
p := &Provider{}
|
||||
cases := map[string]string{
|
||||
"pkg.apk": "application/vnd.android.package-archive",
|
||||
"APKINDEX.tar.gz": "application/gzip",
|
||||
"something.random": "application/octet-stream",
|
||||
}
|
||||
for path, want := range cases {
|
||||
if got := p.ContentType(path); got != want {
|
||||
t.Errorf("ContentType(%q) = %q, want %q", path, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpstreamURL(t *testing.T) {
|
||||
p := &Provider{}
|
||||
got := p.UpstreamURL(models.Remote{BaseURL: "https://dl-cdn.alpinelinux.org/alpine/"}, "/v3.19/main/x86_64/curl.apk")
|
||||
if got != "https://dl-cdn.alpinelinux.org/alpine/v3.19/main/x86_64/curl.apk" {
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteResponse(t *testing.T) {
|
||||
if out, err := (&Provider{}).RewriteResponse([]byte("x"), models.Remote{}, "http://proxy"); out != nil || err != nil {
|
||||
t.Error("alpine never rewrites")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHeaders(t *testing.T) {
|
||||
h, _ := (&Provider{}).AuthHeaders(context.Background(), models.Remote{Username: "u", Password: "p"})
|
||||
if h.Get("Authorization") == "" {
|
||||
t.Error("expected auth header")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
package npm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"git.unkin.net/unkin/artifactapi/internal/provider"
|
||||
"git.unkin.net/unkin/artifactapi/pkg/models"
|
||||
)
|
||||
|
||||
func TestType(t *testing.T) {
|
||||
if (&Provider{}).Type() != models.PackageNPM {
|
||||
t.Fatal("wrong type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassify(t *testing.T) {
|
||||
p := &Provider{}
|
||||
if p.Classify("pkg/-/pkg-1.0.0.tgz") != provider.Immutable {
|
||||
t.Error("tgz should be immutable")
|
||||
}
|
||||
if p.Classify("pkg") != provider.Mutable {
|
||||
t.Error("metadata should be mutable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContentType(t *testing.T) {
|
||||
p := &Provider{}
|
||||
if p.ContentType("pkg/-/pkg-1.0.0.tgz") != "application/gzip" {
|
||||
t.Error("tgz content type")
|
||||
}
|
||||
if p.ContentType("pkg") != "application/json" {
|
||||
t.Error("metadata content type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpstreamURL(t *testing.T) {
|
||||
p := &Provider{}
|
||||
got := p.UpstreamURL(models.Remote{BaseURL: "https://registry.npmjs.org/"}, "/pkg")
|
||||
if got != "https://registry.npmjs.org/pkg" {
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteResponse(t *testing.T) {
|
||||
p := &Provider{}
|
||||
remote := models.Remote{Name: "npmjs", BaseURL: "https://registry.npmjs.org"}
|
||||
|
||||
if out, _ := p.RewriteResponse([]byte(`{"a":1}`), remote, ""); out != nil {
|
||||
t.Error("empty proxyBaseURL should be a no-op")
|
||||
}
|
||||
if out, _ := p.RewriteResponse([]byte("not json"), remote, "http://proxy"); out != nil {
|
||||
t.Error("invalid json should be a no-op")
|
||||
}
|
||||
body := []byte(`{"tarball":"https://registry.npmjs.org/pkg/-/pkg-1.0.0.tgz"}`)
|
||||
out, err := p.RewriteResponse(body, remote, "http://proxy")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(out) != `{"tarball":"http://proxy/api/v1/remote/npmjs/pkg/-/pkg-1.0.0.tgz"}` {
|
||||
t.Errorf("rewrite: %s", out)
|
||||
}
|
||||
if out, _ := p.RewriteResponse([]byte(`{"x":"unrelated"}`), remote, "http://proxy"); out != nil {
|
||||
t.Error("no matching base URL should be a no-op")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHeaders(t *testing.T) {
|
||||
p := &Provider{}
|
||||
h, _ := p.AuthHeaders(context.Background(), models.Remote{Username: "u", Password: "pw"})
|
||||
if h.Get("Authorization") == "" {
|
||||
t.Error("expected auth header when credentials set")
|
||||
}
|
||||
h, _ = p.AuthHeaders(context.Background(), models.Remote{})
|
||||
if h.Get("Authorization") != "" {
|
||||
t.Error("expected no auth header without credentials")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
package puppet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.unkin.net/unkin/artifactapi/internal/provider"
|
||||
"git.unkin.net/unkin/artifactapi/pkg/models"
|
||||
)
|
||||
|
||||
func TestType(t *testing.T) {
|
||||
if (&Provider{}).Type() != models.PackagePuppet {
|
||||
t.Fatal("wrong type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassify(t *testing.T) {
|
||||
p := &Provider{}
|
||||
if p.Classify("v3/modules/puppetlabs-stdlib") != provider.Mutable {
|
||||
t.Error("modules should be mutable")
|
||||
}
|
||||
if p.Classify("v3/releases?module=x") != provider.Mutable {
|
||||
t.Error("releases should be mutable")
|
||||
}
|
||||
if p.Classify("v3/files/puppetlabs-stdlib-1.0.0.tar.gz") != provider.Immutable {
|
||||
t.Error("files should be immutable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContentType(t *testing.T) {
|
||||
p := &Provider{}
|
||||
if p.ContentType("x/mod-1.0.0.tar.gz") != "application/gzip" {
|
||||
t.Error("tar.gz")
|
||||
}
|
||||
if p.ContentType("v3/modules/x") != "application/json" {
|
||||
t.Error("v3 json")
|
||||
}
|
||||
if p.ContentType("other") != "application/octet-stream" {
|
||||
t.Error("default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpstreamURL(t *testing.T) {
|
||||
got := (&Provider{}).UpstreamURL(models.Remote{BaseURL: "https://forgeapi.puppet.com/"}, "/v3/modules/x")
|
||||
if got != "https://forgeapi.puppet.com/v3/modules/x" {
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteResponse(t *testing.T) {
|
||||
p := &Provider{}
|
||||
remote := models.Remote{Name: "forge", BaseURL: "https://forgeapi.puppet.com"}
|
||||
|
||||
if out, _ := p.RewriteResponse([]byte("x"), remote, ""); out != nil {
|
||||
t.Error("empty proxyBaseURL is a no-op")
|
||||
}
|
||||
|
||||
body := []byte(`{"file_uri":"/v3/files/mod.tar.gz","home":"https://forgeapi.puppet.com/x"}`)
|
||||
out, err := p.RewriteResponse(body, remote, "http://proxy")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := string(out)
|
||||
if !strings.Contains(s, "http://proxy/api/v1/remote/forge/v3/files/mod.tar.gz") {
|
||||
t.Errorf("v3/files not rewritten: %s", s)
|
||||
}
|
||||
if !strings.Contains(s, "http://proxy/api/v1/remote/forge/x") {
|
||||
t.Errorf("base URL not rewritten: %s", s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHeaders(t *testing.T) {
|
||||
h, _ := (&Provider{}).AuthHeaders(context.Background(), models.Remote{})
|
||||
if h.Get("Authorization") != "" {
|
||||
t.Error("no credentials, no header")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
package pypi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.unkin.net/unkin/artifactapi/internal/provider"
|
||||
"git.unkin.net/unkin/artifactapi/pkg/models"
|
||||
)
|
||||
|
||||
// fakeFileStore is an in-memory provider.FileStore for exercising local index
|
||||
// generation without a database.
|
||||
type fakeFileStore struct {
|
||||
packages []string
|
||||
files map[string][]provider.FileEntry
|
||||
}
|
||||
|
||||
func (f *fakeFileStore) ListPackages(_ context.Context, _ string) ([]string, error) {
|
||||
return f.packages, nil
|
||||
}
|
||||
|
||||
func (f *fakeFileStore) ListFilesByPrefix(_ context.Context, _, prefix string) ([]provider.FileEntry, error) {
|
||||
return f.files[prefix], nil
|
||||
}
|
||||
|
||||
func TestTypeClassifyContentType(t *testing.T) {
|
||||
p := &Provider{}
|
||||
if p.Type() != models.PackagePyPI {
|
||||
t.Fatal("type")
|
||||
}
|
||||
if p.Classify("simple/foo/") != provider.Mutable {
|
||||
t.Error("simple index should be mutable")
|
||||
}
|
||||
if p.Classify("packages/foo-1.0.whl") != provider.Immutable {
|
||||
t.Error("wheel should be immutable")
|
||||
}
|
||||
cases := map[string]string{
|
||||
"foo-1.0-py3-none-any.whl": "application/zip",
|
||||
"foo-1.0.zip": "application/zip",
|
||||
"foo-1.0.tar.gz": "application/gzip",
|
||||
"simple/foo/": "text/html",
|
||||
"weird": "application/octet-stream",
|
||||
}
|
||||
for path, want := range cases {
|
||||
if got := p.ContentType(path); got != want {
|
||||
t.Errorf("ContentType(%q)=%q want %q", path, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpstreamURL(t *testing.T) {
|
||||
p := &Provider{}
|
||||
if got := p.UpstreamURL(models.Remote{BaseURL: "https://files.example.com"}, "packages/foo.whl"); got != "https://files.example.com/packages/foo.whl" {
|
||||
t.Errorf("got %q", got)
|
||||
}
|
||||
if got := p.UpstreamURL(models.Remote{BaseURL: "https://x"}, "simple/foo/"); got != "https://pypi.org/simple/foo/" {
|
||||
t.Errorf("simple should hit pypi.org, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUpload(t *testing.T) {
|
||||
p := &Provider{}
|
||||
sp, ct, err := p.ValidateUpload("numpy-1.26.0-cp311-cp311-linux_x86_64.whl")
|
||||
if err != nil || sp != "numpy/numpy-1.26.0-cp311-cp311-linux_x86_64.whl" || ct != "application/zip" {
|
||||
t.Errorf("wheel: sp=%q ct=%q err=%v", sp, ct, err)
|
||||
}
|
||||
sp, ct, err = p.ValidateUpload("requests-2.31.0.tar.gz")
|
||||
if err != nil || sp != "requests/requests-2.31.0.tar.gz" || ct != "application/gzip" {
|
||||
t.Errorf("sdist: sp=%q ct=%q err=%v", sp, ct, err)
|
||||
}
|
||||
if _, _, err := p.ValidateUpload("not-a-package.txt"); err == nil {
|
||||
t.Error("expected error for bad extension")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPackageNameParsing(t *testing.T) {
|
||||
if got := packageFromWheel("Foo_Bar-1.0-py3-none-any.whl"); got != "foo-bar" {
|
||||
t.Errorf("wheel name = %q", got)
|
||||
}
|
||||
if got := packageFromWheel("noseparator.whl"); got != "" {
|
||||
t.Errorf("expected empty for unparseable wheel, got %q", got)
|
||||
}
|
||||
if got := packageFromSdist("My.Pkg-2.0.tar.gz"); got != "my-pkg" {
|
||||
t.Errorf("sdist name = %q", got)
|
||||
}
|
||||
if got := packageFromSdist("noseparator.zip"); got != "" {
|
||||
t.Errorf("expected empty, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUploadResponse(t *testing.T) {
|
||||
resp := (&Provider{}).UploadResponse("foo/foo-1.0.whl", "sha256:abc", 123)
|
||||
if resp["filename"] != "foo-1.0.whl" || resp["package"] != "foo" || resp["content_hash"] != "sha256:abc" {
|
||||
t.Errorf("unexpected upload response: %v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteResponse(t *testing.T) {
|
||||
p := &Provider{}
|
||||
if out, _ := p.RewriteResponse([]byte("x"), models.Remote{Name: "pypi"}, ""); out != nil {
|
||||
t.Error("empty proxyBaseURL is a no-op")
|
||||
}
|
||||
body := []byte(`<a href="https://files.pythonhosted.org/packages/foo.whl">foo.whl</a>`)
|
||||
out, err := p.RewriteResponse(body, models.Remote{Name: "pypi"}, "http://proxy")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !strings.Contains(string(out), "http://proxy/api/v1/remote/pypi/") {
|
||||
t.Errorf("not rewritten: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateLocalIndex(t *testing.T) {
|
||||
p := &Provider{}
|
||||
fs := &fakeFileStore{
|
||||
packages: []string{"foo", "bar"},
|
||||
files: map[string][]provider.FileEntry{
|
||||
"foo/": {{FilePath: "foo/foo-1.0-py3-none-any.whl", ContentHash: "sha256:aaa"}},
|
||||
},
|
||||
}
|
||||
list, err := p.GenerateLocalIndex(context.Background(), fs, "local", "simple/")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !strings.Contains(string(list), "foo") || !strings.Contains(string(list), "bar") {
|
||||
t.Errorf("package list missing entries: %s", list)
|
||||
}
|
||||
|
||||
files, err := p.GenerateLocalIndex(context.Background(), fs, "local", "simple/foo/")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !strings.Contains(string(files), "foo-1.0-py3-none-any.whl") {
|
||||
t.Errorf("file list missing wheel: %s", files)
|
||||
}
|
||||
|
||||
if _, err := p.GenerateLocalIndex(context.Background(), fs, "local", "notsimple"); err == nil {
|
||||
t.Error("expected error for non-simple path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHeaders(t *testing.T) {
|
||||
h, _ := (&Provider{}).AuthHeaders(context.Background(), models.Remote{Username: "u", Password: "p"})
|
||||
if h.Get("Authorization") == "" {
|
||||
t.Error("expected auth header")
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,8 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -31,6 +33,7 @@ type Engine struct {
|
||||
cache *cache.Redis
|
||||
store *storage.S3
|
||||
cas *storage.CAS
|
||||
circuit *CircuitBreaker
|
||||
accessLog chan database.AccessLogEntry
|
||||
}
|
||||
|
||||
@@ -40,6 +43,7 @@ func NewEngine(db *database.DB, c *cache.Redis, s *storage.S3) *Engine {
|
||||
cache: c,
|
||||
store: s,
|
||||
cas: storage.NewCAS(s),
|
||||
circuit: NewCircuitBreaker(c),
|
||||
accessLog: make(chan database.AccessLogEntry, accessLogBufferSize),
|
||||
}
|
||||
go e.runAccessLogWriter()
|
||||
@@ -154,10 +158,26 @@ func (e *Engine) Fetch(ctx context.Context, remote models.Remote, path string, p
|
||||
fwdHeaders = clientHeaders[0]
|
||||
}
|
||||
|
||||
// Short-circuit upstream calls when the remote's breaker is open: serve
|
||||
// stale from the store if we have it, otherwise fail fast rather than
|
||||
// hammering a known-bad upstream.
|
||||
if e.circuit.IsOpen(ctx, remote.Name) {
|
||||
if stale, serr := e.serveFromStore(ctx, remote, path); serr == nil {
|
||||
slog.Warn("circuit open, serving stale", "remote", remote.Name, "path", path)
|
||||
stale.Source = "cache"
|
||||
e.logAccess(remote.Name, path, true, stale.Size, 0)
|
||||
return stale, nil
|
||||
}
|
||||
return nil, &ProxyError{Status: http.StatusServiceUnavailable, Message: "upstream circuit open"}
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
result, err := e.fetchFromUpstream(ctx, remote, path, prov, class, ttl, fwdHeaders)
|
||||
upstreamMS := int(time.Since(start).Milliseconds())
|
||||
if err != nil {
|
||||
if isNetworkError(err) {
|
||||
e.circuit.RecordFailure(ctx, remote.Name)
|
||||
}
|
||||
if remote.StaleOnError && isNetworkError(err) {
|
||||
_ = e.cache.SetTTL(ctx, remote.Name, path, ttl)
|
||||
stale, serr := e.serveFromStore(ctx, remote, path)
|
||||
@@ -171,6 +191,7 @@ func (e *Engine) Fetch(ctx context.Context, remote models.Remote, path string, p
|
||||
return nil, err
|
||||
}
|
||||
|
||||
e.circuit.RecordSuccess(ctx, remote.Name)
|
||||
e.logAccess(remote.Name, path, false, result.Size, upstreamMS)
|
||||
return result, nil
|
||||
}
|
||||
@@ -233,7 +254,7 @@ func (e *Engine) headUpstream(ctx context.Context, remote models.Remote, path st
|
||||
}
|
||||
if resp.StatusCode == http.StatusUnauthorized {
|
||||
resp.Body.Close()
|
||||
token, terr := fetchBearerToken(ctx, resp.Header.Get("Www-Authenticate"), remote)
|
||||
token, _, terr := fetchBearerToken(ctx, resp.Header.Get("Www-Authenticate"), remote)
|
||||
if terr == nil && token != "" {
|
||||
resp, err = doHead(http.Header{"Authorization": []string{"Bearer " + token}})
|
||||
if err != nil {
|
||||
@@ -514,6 +535,11 @@ const (
|
||||
bearerTokenTTLMargin = 10 * time.Second
|
||||
)
|
||||
|
||||
func sha256Hash(data []byte) string {
|
||||
h := sha256.Sum256(data)
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
// cachedBearerToken returns a bearer token for the given challenge, reusing a
|
||||
// Redis-cached token for the same remote+challenge while it is still valid.
|
||||
func (e *Engine) cachedBearerToken(ctx context.Context, wwwAuth string, remote models.Remote) (string, error) {
|
||||
|
||||
Reference in New Issue
Block a user