feat: cache upstream bearer tokens in Redis
ci/woodpecker/pr/pre-commit Pipeline was successful
ci/woodpecker/pr/test Pipeline was successful
ci/woodpecker/pr/build Pipeline was successful

Each upstream 401 re-ran the full token-endpoint dance. Cache the minted
token keyed by remote + challenge, honouring the token's expires_in (with
a safety margin, defaulting to 60s), so subsequent blobs sharing a scope
reuse it.

Refs #77
This commit is contained in:
2026-07-02 00:42:48 +10:00
parent 8d9bc1c422
commit 493b3cb906
2 changed files with 56 additions and 10 deletions
+12
View File
@@ -70,6 +70,18 @@ func (r *Redis) GetETag(ctx context.Context, remote, path string) (string, error
return val, err
}
func (r *Redis) GetToken(ctx context.Context, key string) (string, error) {
val, err := r.client.Get(ctx, "token:"+key).Result()
if err == redis.Nil {
return "", nil
}
return val, err
}
func (r *Redis) SetToken(ctx context.Context, key, token string, ttl time.Duration) error {
return r.client.Set(ctx, "token:"+key, token, ttl).Err()
}
func (r *Redis) IncrCircuitFailure(ctx context.Context, remote string, cooldown time.Duration) (int64, error) {
key := fmt.Sprintf("circuit:%s", remote)
pipe := r.client.Pipeline()
+44 -10
View File
@@ -161,7 +161,7 @@ func (e *Engine) fetchFromUpstream(ctx context.Context, remote models.Remote, pa
if resp.StatusCode == http.StatusUnauthorized {
resp.Body.Close()
token, err := fetchBearerToken(ctx, resp.Header.Get("Www-Authenticate"), remote)
token, err := e.cachedBearerToken(ctx, resp.Header.Get("Www-Authenticate"), remote)
if err == nil && token != "" {
req2, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
req2.Header.Set("Authorization", "Bearer "+token)
@@ -351,9 +351,41 @@ func (r readerAt) ReadAt(p []byte, off int64) (n int, err error) {
return
}
func fetchBearerToken(ctx context.Context, wwwAuth string, remote models.Remote) (string, error) {
// bearerTokenTTLDefault/Margin bound how long a token is cached: the default
// is used when the token endpoint omits expires_in, and the margin is
// subtracted so a cached token is refreshed slightly before it actually expires.
const (
bearerTokenTTLDefault = 60 * time.Second
bearerTokenTTLMargin = 10 * time.Second
)
// cachedBearerToken returns a bearer token for the given challenge, reusing a
// Redis-cached token for the same remote+challenge while it is still valid.
func (e *Engine) cachedBearerToken(ctx context.Context, wwwAuth string, remote models.Remote) (string, error) {
key := remote.Name + ":" + sha256Hash([]byte(wwwAuth))
if tok, err := e.cache.GetToken(ctx, key); err == nil && tok != "" {
return tok, nil
}
tok, ttl, err := fetchBearerToken(ctx, wwwAuth, remote)
if err != nil {
return "", err
}
if tok != "" {
if ttl <= 0 {
ttl = bearerTokenTTLDefault
}
if ttl > bearerTokenTTLMargin {
ttl -= bearerTokenTTLMargin
}
_ = e.cache.SetToken(ctx, key, tok, ttl)
}
return tok, nil
}
func fetchBearerToken(ctx context.Context, wwwAuth string, remote models.Remote) (string, time.Duration, error) {
if !strings.HasPrefix(wwwAuth, "Bearer ") {
return "", fmt.Errorf("not a Bearer challenge")
return "", 0, fmt.Errorf("not a Bearer challenge")
}
params := map[string]string{}
@@ -370,7 +402,7 @@ func fetchBearerToken(ctx context.Context, wwwAuth string, remote models.Remote)
realm := params["realm"]
if realm == "" {
return "", fmt.Errorf("no realm in Bearer challenge")
return "", 0, fmt.Errorf("no realm in Bearer challenge")
}
tokenURL := realm
@@ -385,7 +417,7 @@ func fetchBearerToken(ctx context.Context, wwwAuth string, remote models.Remote)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, tokenURL, nil)
if err != nil {
return "", err
return "", 0, err
}
if remote.Username != "" && remote.Password != "" {
@@ -394,26 +426,28 @@ func fetchBearerToken(ctx context.Context, wwwAuth string, remote models.Remote)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", err
return "", 0, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("token endpoint returned %d", resp.StatusCode)
return "", 0, fmt.Errorf("token endpoint returned %d", resp.StatusCode)
}
var tokenResp struct {
Token string `json:"token"`
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
}
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return "", err
return "", 0, err
}
ttl := time.Duration(tokenResp.ExpiresIn) * time.Second
if tokenResp.Token != "" {
return tokenResp.Token, nil
return tokenResp.Token, ttl, nil
}
return tokenResp.AccessToken, nil
return tokenResp.AccessToken, ttl, nil
}
type ProxyError struct {