feat: cache upstream bearer tokens (#92)
Fixes #77 ## Why Each upstream 401 re-ran the full token-endpoint request, even though a single Docker pull triggers many blob/manifest requests sharing one scope. ## Changes - Add Redis `GetToken`/`SetToken`. - `fetchBearerToken` now also parses `expires_in` and returns a TTL. - New `Engine.cachedBearerToken` reuses a cached token keyed by remote + challenge (hashed), caching for `expires_in` minus a safety margin (default 60s when absent). ## Validation - `make e2e` passes. Reviewed-on: #92 Co-authored-by: Ben Vincent <ben@unkin.net> Co-committed-by: Ben Vincent <ben@unkin.net>
This commit was merged in pull request #92.
This commit is contained in:
Vendored
+12
@@ -70,6 +70,18 @@ func (r *Redis) GetETag(ctx context.Context, remote, path string) (string, error
|
||||
return val, err
|
||||
}
|
||||
|
||||
func (r *Redis) GetToken(ctx context.Context, key string) (string, error) {
|
||||
val, err := r.client.Get(ctx, "token:"+key).Result()
|
||||
if err == redis.Nil {
|
||||
return "", nil
|
||||
}
|
||||
return val, err
|
||||
}
|
||||
|
||||
func (r *Redis) SetToken(ctx context.Context, key, token string, ttl time.Duration) error {
|
||||
return r.client.Set(ctx, "token:"+key, token, ttl).Err()
|
||||
}
|
||||
|
||||
func (r *Redis) IncrCircuitFailure(ctx context.Context, remote string, cooldown time.Duration) (int64, error) {
|
||||
key := fmt.Sprintf("circuit:%s", remote)
|
||||
pipe := r.client.Pipeline()
|
||||
|
||||
+44
-10
@@ -160,7 +160,7 @@ func (e *Engine) fetchFromUpstream(ctx context.Context, remote models.Remote, pa
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized {
|
||||
resp.Body.Close()
|
||||
token, err := fetchBearerToken(ctx, resp.Header.Get("Www-Authenticate"), remote)
|
||||
token, err := e.cachedBearerToken(ctx, resp.Header.Get("Www-Authenticate"), remote)
|
||||
if err == nil && token != "" {
|
||||
req2, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
req2.Header.Set("Authorization", "Bearer "+token)
|
||||
@@ -345,9 +345,41 @@ func (r readerAt) ReadAt(p []byte, off int64) (n int, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func fetchBearerToken(ctx context.Context, wwwAuth string, remote models.Remote) (string, error) {
|
||||
// bearerTokenTTLDefault/Margin bound how long a token is cached: the default
|
||||
// is used when the token endpoint omits expires_in, and the margin is
|
||||
// subtracted so a cached token is refreshed slightly before it actually expires.
|
||||
const (
|
||||
bearerTokenTTLDefault = 60 * time.Second
|
||||
bearerTokenTTLMargin = 10 * time.Second
|
||||
)
|
||||
|
||||
// cachedBearerToken returns a bearer token for the given challenge, reusing a
|
||||
// Redis-cached token for the same remote+challenge while it is still valid.
|
||||
func (e *Engine) cachedBearerToken(ctx context.Context, wwwAuth string, remote models.Remote) (string, error) {
|
||||
key := remote.Name + ":" + sha256Hash([]byte(wwwAuth))
|
||||
if tok, err := e.cache.GetToken(ctx, key); err == nil && tok != "" {
|
||||
return tok, nil
|
||||
}
|
||||
|
||||
tok, ttl, err := fetchBearerToken(ctx, wwwAuth, remote)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if tok != "" {
|
||||
if ttl <= 0 {
|
||||
ttl = bearerTokenTTLDefault
|
||||
}
|
||||
if ttl > bearerTokenTTLMargin {
|
||||
ttl -= bearerTokenTTLMargin
|
||||
}
|
||||
_ = e.cache.SetToken(ctx, key, tok, ttl)
|
||||
}
|
||||
return tok, nil
|
||||
}
|
||||
|
||||
func fetchBearerToken(ctx context.Context, wwwAuth string, remote models.Remote) (string, time.Duration, error) {
|
||||
if !strings.HasPrefix(wwwAuth, "Bearer ") {
|
||||
return "", fmt.Errorf("not a Bearer challenge")
|
||||
return "", 0, fmt.Errorf("not a Bearer challenge")
|
||||
}
|
||||
|
||||
params := map[string]string{}
|
||||
@@ -364,7 +396,7 @@ func fetchBearerToken(ctx context.Context, wwwAuth string, remote models.Remote)
|
||||
|
||||
realm := params["realm"]
|
||||
if realm == "" {
|
||||
return "", fmt.Errorf("no realm in Bearer challenge")
|
||||
return "", 0, fmt.Errorf("no realm in Bearer challenge")
|
||||
}
|
||||
|
||||
tokenURL := realm
|
||||
@@ -379,7 +411,7 @@ func fetchBearerToken(ctx context.Context, wwwAuth string, remote models.Remote)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, tokenURL, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
if remote.Username != "" && remote.Password != "" {
|
||||
@@ -388,26 +420,28 @@ func fetchBearerToken(ctx context.Context, wwwAuth string, remote models.Remote)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", 0, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("token endpoint returned %d", resp.StatusCode)
|
||||
return "", 0, fmt.Errorf("token endpoint returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
Token string `json:"token"`
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return "", err
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
ttl := time.Duration(tokenResp.ExpiresIn) * time.Second
|
||||
if tokenResp.Token != "" {
|
||||
return tokenResp.Token, nil
|
||||
return tokenResp.Token, ttl, nil
|
||||
}
|
||||
return tokenResp.AccessToken, nil
|
||||
return tokenResp.AccessToken, ttl, nil
|
||||
}
|
||||
|
||||
type ProxyError struct {
|
||||
|
||||
Reference in New Issue
Block a user