Files
unkinben e7027c8ccc feat: cache upstream bearer tokens (#92)
Fixes #77

## Why
Each upstream 401 re-ran the full token-endpoint request, even though a single Docker pull triggers many blob/manifest requests sharing one scope.

## Changes
- Add Redis `GetToken`/`SetToken`.
- `fetchBearerToken` now also parses `expires_in` and returns a TTL.
- New `Engine.cachedBearerToken` reuses a cached token keyed by remote + challenge (hashed), caching for `expires_in` minus a safety margin (default 60s when absent).

## Validation
- `make e2e` passes.

Reviewed-on: #92
Co-authored-by: Ben Vincent <ben@unkin.net>
Co-committed-by: Ben Vincent <ben@unkin.net>
2026-07-02 21:35:46 +10:00

118 lines
3.0 KiB
Go

package cache
import (
"context"
"fmt"
"time"
"github.com/redis/go-redis/v9"
)
// Redis is a thin wrapper over a go-redis client that stores cache
// metadata under prefixed keys: "ttl:", "lock:" and "etag:" entries scoped
// by remote and path, "token:" entries, and "circuit:" failure counters.
type Redis struct {
	// client is the underlying connection pool; it is released by Close.
	client *redis.Client
}
// NewRedis parses a Redis connection URL, creates a client from it and
// verifies connectivity with a PING before returning.
//
// It returns an error wrapped with "parse redis url" if the URL is invalid,
// or with "ping redis" if the server cannot be reached. In the latter case
// the client that was just created is closed so its connection pool is not
// leaked.
func NewRedis(url string) (*Redis, error) {
	opts, err := redis.ParseURL(url)
	if err != nil {
		return nil, fmt.Errorf("parse redis url: %w", err)
	}
	client := redis.NewClient(opts)
	if err := client.Ping(context.Background()).Err(); err != nil {
		// The caller never receives this client, so release it here. The
		// close error is ignored in favour of the more useful ping error.
		_ = client.Close()
		return nil, fmt.Errorf("ping redis: %w", err)
	}
	return &Redis{client: client}, nil
}
// Close releases the underlying Redis client and its connections.
func (r *Redis) Close() error {
	err := r.client.Close()
	return err
}
// SetTTL writes the marker key "ttl:<remote>:<path>" with value "1" and an
// expiry of ttl. CheckTTL reports whether that marker still exists.
func (r *Redis) SetTTL(ctx context.Context, remote, path string, ttl time.Duration) error {
	marker := r.client.Set(ctx, "ttl:"+remote+":"+path, "1", ttl)
	return marker.Err()
}
// CheckTTL reports whether the marker key "ttl:<remote>:<path>" written by
// SetTTL currently exists. On a Redis error it returns false and the error.
func (r *Redis) CheckTTL(ctx context.Context, remote, path string) (bool, error) {
	count, err := r.client.Exists(ctx, fmt.Sprintf("ttl:%s:%s", remote, path)).Result()
	fresh := err == nil && count > 0
	return fresh, err
}
// AcquireLock attempts to create "lock:<remote>:<path>" with SETNX and an
// expiry of ttl. It returns true only when this call created the key, i.e.
// no other holder currently owns the lock.
func (r *Redis) AcquireLock(ctx context.Context, remote, path string, ttl time.Duration) (bool, error) {
	return r.client.SetNX(ctx, "lock:"+remote+":"+path, "1", ttl).Result()
}
// ReleaseLock deletes "lock:<remote>:<path>".
//
// The delete is unconditional: AcquireLock stores a constant "1", so there
// is no owner token to compare against. A caller whose lock already expired
// can therefore remove a lock that was since acquired by someone else.
func (r *Redis) ReleaseLock(ctx context.Context, remote, path string) error {
	lockKey := "lock:" + remote + ":" + path
	return r.client.Del(ctx, lockKey).Err()
}
// SetETag stores etag under "etag:<remote>:<path>" with an expiry of ttl.
func (r *Redis) SetETag(ctx context.Context, remote, path, etag string, ttl time.Duration) error {
	status := r.client.Set(ctx, "etag:"+remote+":"+path, etag, ttl)
	return status.Err()
}
// GetETag returns the ETag stored under "etag:<remote>:<path>".
// A missing key is not treated as an error: it yields "" and a nil error.
func (r *Redis) GetETag(ctx context.Context, remote, path string) (string, error) {
	etag, err := r.client.Get(ctx, "etag:"+remote+":"+path).Result()
	switch err {
	case redis.Nil:
		return "", nil
	default:
		return etag, err
	}
}
// GetToken returns the token cached under "token:"+key.
// A missing key is not treated as an error: it yields "" and a nil error.
func (r *Redis) GetToken(ctx context.Context, key string) (string, error) {
	token, err := r.client.Get(ctx, "token:"+key).Result()
	if err == redis.Nil {
		// Cache miss: normalise to the zero value with no error.
		token, err = "", nil
	}
	return token, err
}
// SetToken caches token under "token:"+key with an expiry of ttl.
func (r *Redis) SetToken(ctx context.Context, key, token string, ttl time.Duration) error {
	cacheKey := "token:" + key
	status := r.client.Set(ctx, cacheKey, token, ttl)
	return status.Err()
}
// IncrCircuitFailure increments the failure counter "circuit:<remote>" and
// sets its expiry to cooldown, sending both commands in one pipeline. It
// returns the counter value after the increment.
//
// Because every call refreshes the expiry, the counter only disappears once
// cooldown has elapsed with no further failures recorded.
func (r *Redis) IncrCircuitFailure(ctx context.Context, remote string, cooldown time.Duration) (int64, error) {
	circuitKey := fmt.Sprintf("circuit:%s", remote)
	var count *redis.IntCmd
	if _, err := r.client.Pipelined(ctx, func(p redis.Pipeliner) error {
		count = p.Incr(ctx, circuitKey)
		p.Expire(ctx, circuitKey, cooldown)
		return nil
	}); err != nil {
		return 0, err
	}
	return count.Val(), nil
}
// ResetCircuit deletes the failure counter "circuit:<remote>".
func (r *Redis) ResetCircuit(ctx context.Context, remote string) error {
	return r.client.Del(ctx, "circuit:"+remote).Err()
}
// GetCircuitFailures returns the current value of "circuit:<remote>".
// A missing key counts as zero failures and is not reported as an error.
func (r *Redis) GetCircuitFailures(ctx context.Context, remote string) (int64, error) {
	failures, err := r.client.Get(ctx, "circuit:"+remote).Int64()
	if err != redis.Nil {
		return failures, err
	}
	return 0, nil
}
// FlushRemote deletes every key matching the SCAN pattern "*:<remote>:*".
// This covers the "ttl:", "lock:" and "etag:" entries written for that
// remote.
//
// Keys are gathered from the SCAN cursor and removed in batches with one
// DEL per batch. The first SCAN or DEL error is returned; previously DEL
// failures were silently discarded, so a failed flush reported success.
//
// NOTE(review): "circuit:<remote>" has no trailing colon, and "token:" keys
// are built from a caller-supplied (reportedly hashed) key, so neither is
// matched by this pattern — confirm whether they should be flushed too.
// remote is also not glob-escaped; a remote containing '*', '?' or '['
// would match other remotes' keys.
func (r *Redis) FlushRemote(ctx context.Context, remote string) error {
	const batchSize = 100

	iter := r.client.Scan(ctx, 0, fmt.Sprintf("*:%s:*", remote), batchSize).Iterator()
	batch := make([]string, 0, batchSize)

	// flush deletes the pending keys (if any) and empties the batch.
	flush := func() error {
		if len(batch) == 0 {
			return nil
		}
		if err := r.client.Del(ctx, batch...).Err(); err != nil {
			return err
		}
		batch = batch[:0]
		return nil
	}

	for iter.Next(ctx) {
		batch = append(batch, iter.Val())
		if len(batch) == batchSize {
			if err := flush(); err != nil {
				return err
			}
		}
	}
	if err := iter.Err(); err != nil {
		return err
	}
	return flush()
}