feat: cache upstream bearer tokens (#92)

Fixes #77

## Why
Each upstream 401 re-ran the full token-endpoint request, even though a single Docker pull triggers many blob/manifest requests sharing one scope.

## Changes
- Add Redis `GetToken`/`SetToken`.
- `fetchBearerToken` now also parses `expires_in` and returns a TTL.
- New `Engine.cachedBearerToken` reuses a cached token keyed by remote + challenge (hashed), caching for `expires_in` minus a safety margin (default 60s when absent).

## Validation
- `make e2e` passes.

Reviewed-on: #92
Co-authored-by: Ben Vincent <ben@unkin.net>
Co-committed-by: Ben Vincent <ben@unkin.net>
This commit was merged in pull request #92.
This commit is contained in:
2026-07-02 21:35:46 +10:00
committed by BenVincent
parent f3680951b7
commit e7027c8ccc
2 changed files with 56 additions and 10 deletions
+12
View File
@@ -70,6 +70,18 @@ func (r *Redis) GetETag(ctx context.Context, remote, path string) (string, error
return val, err return val, err
} }
func (r *Redis) GetToken(ctx context.Context, key string) (string, error) {
val, err := r.client.Get(ctx, "token:"+key).Result()
if err == redis.Nil {
return "", nil
}
return val, err
}
func (r *Redis) SetToken(ctx context.Context, key, token string, ttl time.Duration) error {
return r.client.Set(ctx, "token:"+key, token, ttl).Err()
}
func (r *Redis) IncrCircuitFailure(ctx context.Context, remote string, cooldown time.Duration) (int64, error) { func (r *Redis) IncrCircuitFailure(ctx context.Context, remote string, cooldown time.Duration) (int64, error) {
key := fmt.Sprintf("circuit:%s", remote) key := fmt.Sprintf("circuit:%s", remote)
pipe := r.client.Pipeline() pipe := r.client.Pipeline()
+44 -10
View File
@@ -160,7 +160,7 @@ func (e *Engine) fetchFromUpstream(ctx context.Context, remote models.Remote, pa
if resp.StatusCode == http.StatusUnauthorized { if resp.StatusCode == http.StatusUnauthorized {
resp.Body.Close() resp.Body.Close()
token, err := fetchBearerToken(ctx, resp.Header.Get("Www-Authenticate"), remote) token, err := e.cachedBearerToken(ctx, resp.Header.Get("Www-Authenticate"), remote)
if err == nil && token != "" { if err == nil && token != "" {
req2, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) req2, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
req2.Header.Set("Authorization", "Bearer "+token) req2.Header.Set("Authorization", "Bearer "+token)
@@ -345,9 +345,41 @@ func (r readerAt) ReadAt(p []byte, off int64) (n int, err error) {
return return
} }
func fetchBearerToken(ctx context.Context, wwwAuth string, remote models.Remote) (string, error) { // bearerTokenTTLDefault/Margin bound how long a token is cached: the default
// is used when the token endpoint omits expires_in, and the margin is
// subtracted so a cached token is refreshed slightly before it actually expires.
const (
bearerTokenTTLDefault = 60 * time.Second
bearerTokenTTLMargin = 10 * time.Second
)
// cachedBearerToken returns a bearer token for the given challenge, reusing a
// Redis-cached token for the same remote+challenge while it is still valid.
func (e *Engine) cachedBearerToken(ctx context.Context, wwwAuth string, remote models.Remote) (string, error) {
key := remote.Name + ":" + sha256Hash([]byte(wwwAuth))
if tok, err := e.cache.GetToken(ctx, key); err == nil && tok != "" {
return tok, nil
}
tok, ttl, err := fetchBearerToken(ctx, wwwAuth, remote)
if err != nil {
return "", err
}
if tok != "" {
if ttl <= 0 {
ttl = bearerTokenTTLDefault
}
if ttl > bearerTokenTTLMargin {
ttl -= bearerTokenTTLMargin
}
_ = e.cache.SetToken(ctx, key, tok, ttl)
}
return tok, nil
}
func fetchBearerToken(ctx context.Context, wwwAuth string, remote models.Remote) (string, time.Duration, error) {
if !strings.HasPrefix(wwwAuth, "Bearer ") { if !strings.HasPrefix(wwwAuth, "Bearer ") {
return "", fmt.Errorf("not a Bearer challenge") return "", 0, fmt.Errorf("not a Bearer challenge")
} }
params := map[string]string{} params := map[string]string{}
@@ -364,7 +396,7 @@ func fetchBearerToken(ctx context.Context, wwwAuth string, remote models.Remote)
realm := params["realm"] realm := params["realm"]
if realm == "" { if realm == "" {
return "", fmt.Errorf("no realm in Bearer challenge") return "", 0, fmt.Errorf("no realm in Bearer challenge")
} }
tokenURL := realm tokenURL := realm
@@ -379,7 +411,7 @@ func fetchBearerToken(ctx context.Context, wwwAuth string, remote models.Remote)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, tokenURL, nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, tokenURL, nil)
if err != nil { if err != nil {
return "", err return "", 0, err
} }
if remote.Username != "" && remote.Password != "" { if remote.Username != "" && remote.Password != "" {
@@ -388,26 +420,28 @@ func fetchBearerToken(ctx context.Context, wwwAuth string, remote models.Remote)
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return "", err return "", 0, err
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("token endpoint returned %d", resp.StatusCode) return "", 0, fmt.Errorf("token endpoint returned %d", resp.StatusCode)
} }
var tokenResp struct { var tokenResp struct {
Token string `json:"token"` Token string `json:"token"`
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
} }
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return "", err return "", 0, err
} }
ttl := time.Duration(tokenResp.ExpiresIn) * time.Second
if tokenResp.Token != "" { if tokenResp.Token != "" {
return tokenResp.Token, nil return tokenResp.Token, ttl, nil
} }
return tokenResp.AccessToken, nil return tokenResp.AccessToken, ttl, nil
} }
type ProxyError struct { type ProxyError struct {