feat: cache upstream bearer tokens (#92)
Fixes #77 ## Why Each upstream 401 re-ran the full token-endpoint request, even though a single Docker pull triggers many blob/manifest requests sharing one scope. ## Changes - Add Redis `GetToken`/`SetToken`. - `fetchBearerToken` now also parses `expires_in` and returns a TTL. - New `Engine.cachedBearerToken` reuses a cached token keyed by remote + challenge (hashed), caching for `expires_in` minus a safety margin (default 60s when absent). ## Validation - `make e2e` passes. Reviewed-on: #92 Co-authored-by: Ben Vincent <ben@unkin.net> Co-committed-by: Ben Vincent <ben@unkin.net>
This commit was merged in pull request #92.
This commit is contained in:
Vendored
+12
@@ -70,6 +70,18 @@ func (r *Redis) GetETag(ctx context.Context, remote, path string) (string, error
|
|||||||
return val, err
|
return val, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Redis) GetToken(ctx context.Context, key string) (string, error) {
|
||||||
|
val, err := r.client.Get(ctx, "token:"+key).Result()
|
||||||
|
if err == redis.Nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return val, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Redis) SetToken(ctx context.Context, key, token string, ttl time.Duration) error {
|
||||||
|
return r.client.Set(ctx, "token:"+key, token, ttl).Err()
|
||||||
|
}
|
||||||
|
|
||||||
func (r *Redis) IncrCircuitFailure(ctx context.Context, remote string, cooldown time.Duration) (int64, error) {
|
func (r *Redis) IncrCircuitFailure(ctx context.Context, remote string, cooldown time.Duration) (int64, error) {
|
||||||
key := fmt.Sprintf("circuit:%s", remote)
|
key := fmt.Sprintf("circuit:%s", remote)
|
||||||
pipe := r.client.Pipeline()
|
pipe := r.client.Pipeline()
|
||||||
|
|||||||
+44
-10
@@ -160,7 +160,7 @@ func (e *Engine) fetchFromUpstream(ctx context.Context, remote models.Remote, pa
|
|||||||
|
|
||||||
if resp.StatusCode == http.StatusUnauthorized {
|
if resp.StatusCode == http.StatusUnauthorized {
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
token, err := fetchBearerToken(ctx, resp.Header.Get("Www-Authenticate"), remote)
|
token, err := e.cachedBearerToken(ctx, resp.Header.Get("Www-Authenticate"), remote)
|
||||||
if err == nil && token != "" {
|
if err == nil && token != "" {
|
||||||
req2, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
req2, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
req2.Header.Set("Authorization", "Bearer "+token)
|
req2.Header.Set("Authorization", "Bearer "+token)
|
||||||
@@ -345,9 +345,41 @@ func (r readerAt) ReadAt(p []byte, off int64) (n int, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchBearerToken(ctx context.Context, wwwAuth string, remote models.Remote) (string, error) {
|
// bearerTokenTTLDefault/Margin bound how long a token is cached: the default
|
||||||
|
// is used when the token endpoint omits expires_in, and the margin is
|
||||||
|
// subtracted so a cached token is refreshed slightly before it actually expires.
|
||||||
|
const (
|
||||||
|
bearerTokenTTLDefault = 60 * time.Second
|
||||||
|
bearerTokenTTLMargin = 10 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// cachedBearerToken returns a bearer token for the given challenge, reusing a
|
||||||
|
// Redis-cached token for the same remote+challenge while it is still valid.
|
||||||
|
func (e *Engine) cachedBearerToken(ctx context.Context, wwwAuth string, remote models.Remote) (string, error) {
|
||||||
|
key := remote.Name + ":" + sha256Hash([]byte(wwwAuth))
|
||||||
|
if tok, err := e.cache.GetToken(ctx, key); err == nil && tok != "" {
|
||||||
|
return tok, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tok, ttl, err := fetchBearerToken(ctx, wwwAuth, remote)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if tok != "" {
|
||||||
|
if ttl <= 0 {
|
||||||
|
ttl = bearerTokenTTLDefault
|
||||||
|
}
|
||||||
|
if ttl > bearerTokenTTLMargin {
|
||||||
|
ttl -= bearerTokenTTLMargin
|
||||||
|
}
|
||||||
|
_ = e.cache.SetToken(ctx, key, tok, ttl)
|
||||||
|
}
|
||||||
|
return tok, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchBearerToken(ctx context.Context, wwwAuth string, remote models.Remote) (string, time.Duration, error) {
|
||||||
if !strings.HasPrefix(wwwAuth, "Bearer ") {
|
if !strings.HasPrefix(wwwAuth, "Bearer ") {
|
||||||
return "", fmt.Errorf("not a Bearer challenge")
|
return "", 0, fmt.Errorf("not a Bearer challenge")
|
||||||
}
|
}
|
||||||
|
|
||||||
params := map[string]string{}
|
params := map[string]string{}
|
||||||
@@ -364,7 +396,7 @@ func fetchBearerToken(ctx context.Context, wwwAuth string, remote models.Remote)
|
|||||||
|
|
||||||
realm := params["realm"]
|
realm := params["realm"]
|
||||||
if realm == "" {
|
if realm == "" {
|
||||||
return "", fmt.Errorf("no realm in Bearer challenge")
|
return "", 0, fmt.Errorf("no realm in Bearer challenge")
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenURL := realm
|
tokenURL := realm
|
||||||
@@ -379,7 +411,7 @@ func fetchBearerToken(ctx context.Context, wwwAuth string, remote models.Remote)
|
|||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, tokenURL, nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, tokenURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if remote.Username != "" && remote.Password != "" {
|
if remote.Username != "" && remote.Password != "" {
|
||||||
@@ -388,26 +420,28 @@ func fetchBearerToken(ctx context.Context, wwwAuth string, remote models.Remote)
|
|||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
resp, err := http.DefaultClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", 0, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return "", fmt.Errorf("token endpoint returned %d", resp.StatusCode)
|
return "", 0, fmt.Errorf("token endpoint returned %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
var tokenResp struct {
|
var tokenResp struct {
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
AccessToken string `json:"access_token"`
|
AccessToken string `json:"access_token"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
}
|
}
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||||
return "", err
|
return "", 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ttl := time.Duration(tokenResp.ExpiresIn) * time.Second
|
||||||
if tokenResp.Token != "" {
|
if tokenResp.Token != "" {
|
||||||
return tokenResp.Token, nil
|
return tokenResp.Token, ttl, nil
|
||||||
}
|
}
|
||||||
return tokenResp.AccessToken, nil
|
return tokenResp.AccessToken, ttl, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProxyError struct {
|
type ProxyError struct {
|
||||||
|
|||||||
Reference in New Issue
Block a user