diff --git a/internal/cache/redis.go b/internal/cache/redis.go index 3e92b64..529d88e 100644 --- a/internal/cache/redis.go +++ b/internal/cache/redis.go @@ -70,6 +70,18 @@ func (r *Redis) GetETag(ctx context.Context, remote, path string) (string, error return val, err } +func (r *Redis) GetToken(ctx context.Context, key string) (string, error) { + val, err := r.client.Get(ctx, "token:"+key).Result() + if err == redis.Nil { + return "", nil + } + return val, err +} + +func (r *Redis) SetToken(ctx context.Context, key, token string, ttl time.Duration) error { + return r.client.Set(ctx, "token:"+key, token, ttl).Err() +} + func (r *Redis) IncrCircuitFailure(ctx context.Context, remote string, cooldown time.Duration) (int64, error) { key := fmt.Sprintf("circuit:%s", remote) pipe := r.client.Pipeline() diff --git a/internal/proxy/engine.go b/internal/proxy/engine.go index 6dbfbd1..59e9bb5 100644 --- a/internal/proxy/engine.go +++ b/internal/proxy/engine.go @@ -160,7 +160,7 @@ func (e *Engine) fetchFromUpstream(ctx context.Context, remote models.Remote, pa if resp.StatusCode == http.StatusUnauthorized { resp.Body.Close() - token, err := fetchBearerToken(ctx, resp.Header.Get("Www-Authenticate"), remote) + token, err := e.cachedBearerToken(ctx, resp.Header.Get("Www-Authenticate"), remote) if err == nil && token != "" { req2, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) req2.Header.Set("Authorization", "Bearer "+token) @@ -345,9 +345,41 @@ func (r readerAt) ReadAt(p []byte, off int64) (n int, err error) { return } -func fetchBearerToken(ctx context.Context, wwwAuth string, remote models.Remote) (string, error) { +// bearerTokenTTLDefault/Margin bound how long a token is cached: the default +// is used when the token endpoint omits expires_in, and the margin is +// subtracted so a cached token is refreshed slightly before it actually expires. +const ( + bearerTokenTTLDefault = 60 * time.Second + bearerTokenTTLMargin = 10 * time.Second +) + +// cachedBearerToken returns a bearer token for the given challenge, reusing a +// Redis-cached token for the same remote+challenge while it is still valid. +func (e *Engine) cachedBearerToken(ctx context.Context, wwwAuth string, remote models.Remote) (string, error) { + key := remote.Name + ":" + sha256Hash([]byte(wwwAuth)) + if tok, err := e.cache.GetToken(ctx, key); err == nil && tok != "" { + return tok, nil + } + + tok, ttl, err := fetchBearerToken(ctx, wwwAuth, remote) + if err != nil { + return "", err + } + if tok != "" { + if ttl <= 0 { + ttl = bearerTokenTTLDefault + } + if ttl > bearerTokenTTLMargin { + ttl -= bearerTokenTTLMargin + } + _ = e.cache.SetToken(ctx, key, tok, ttl) + } + return tok, nil +} + +func fetchBearerToken(ctx context.Context, wwwAuth string, remote models.Remote) (string, time.Duration, error) { if !strings.HasPrefix(wwwAuth, "Bearer ") { - return "", fmt.Errorf("not a Bearer challenge") + return "", 0, fmt.Errorf("not a Bearer challenge") } params := map[string]string{} @@ -364,7 +396,7 @@ func fetchBearerToken(ctx context.Context, wwwAuth string, remote models.Remote) realm := params["realm"] if realm == "" { - return "", fmt.Errorf("no realm in Bearer challenge") + return "", 0, fmt.Errorf("no realm in Bearer challenge") } tokenURL := realm @@ -379,7 +411,7 @@ func fetchBearerToken(ctx context.Context, wwwAuth string, remote models.Remote) req, err := http.NewRequestWithContext(ctx, http.MethodGet, tokenURL, nil) if err != nil { - return "", err + return "", 0, err } if remote.Username != "" && remote.Password != "" { @@ -388,26 +420,28 @@ func fetchBearerToken(ctx context.Context, wwwAuth string, remote models.Remote) resp, err := http.DefaultClient.Do(req) if err != nil { - return "", err + return "", 0, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("token endpoint returned %d", resp.StatusCode) + return "", 0, fmt.Errorf("token endpoint returned %d", resp.StatusCode) } var tokenResp struct { Token string `json:"token"` AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` } if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { - return "", err + return "", 0, err } + ttl := time.Duration(tokenResp.ExpiresIn) * time.Second if tokenResp.Token != "" { - return tokenResp.Token, nil + return tokenResp.Token, ttl, nil } - return tokenResp.AccessToken, nil + return tokenResp.AccessToken, ttl, nil } type ProxyError struct {