493b3cb906
Each upstream 401 re-ran the full token-endpoint dance. Cache the minted token keyed by remote + challenge, honouring the token's expires_in (with a safety margin, defaulting to 60s), so subsequent blobs sharing a scope reuse it. Refs #77
118 lines
3.0 KiB
Go
118 lines
3.0 KiB
Go
package cache
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/redis/go-redis/v9"
|
|
)
|
|
|
|
type Redis struct {
|
|
client *redis.Client
|
|
}
|
|
|
|
// NewRedis parses url with redis.ParseURL, creates a client from the
// resulting options and verifies connectivity with a PING.
//
// It returns an error wrapping the parse or ping failure. If the PING fails,
// the client is closed before returning so its connection pool is not leaked
// (the caller never receives a handle it could close).
func NewRedis(url string) (*Redis, error) {
	opts, err := redis.ParseURL(url)
	if err != nil {
		return nil, fmt.Errorf("parse redis url: %w", err)
	}

	client := redis.NewClient(opts)
	if err := client.Ping(context.Background()).Err(); err != nil {
		// Best-effort cleanup: the ping error is the one worth reporting.
		_ = client.Close()
		return nil, fmt.Errorf("ping redis: %w", err)
	}

	return &Redis{client: client}, nil
}
|
|
|
|
func (r *Redis) Close() error {
|
|
return r.client.Close()
|
|
}
|
|
|
|
// SetTTL writes the marker key "ttl:<remote>:<path>" (value "1") with
// expiry ttl passed straight through to SET. CheckTTL reports whether the
// marker still exists.
func (r *Redis) SetTTL(ctx context.Context, remote, path string, ttl time.Duration) error {
	marker := "ttl:" + remote + ":" + path
	cmd := r.client.Set(ctx, marker, "1", ttl)
	return cmd.Err()
}
|
|
|
|
// CheckTTL reports whether the marker key "ttl:<remote>:<path>" written by
// SetTTL currently exists. On a Redis error it returns false and that error.
func (r *Redis) CheckTTL(ctx context.Context, remote, path string) (bool, error) {
	n, err := r.client.Exists(ctx, "ttl:"+remote+":"+path).Result()
	return err == nil && n > 0, err
}
|
|
|
|
// AcquireLock tries to create the key "lock:<remote>:<path>" only if it does
// not already exist (SETNX), with expiry ttl. It returns true when this call
// created the key and false when the key was already present.
func (r *Redis) AcquireLock(ctx context.Context, remote, path string, ttl time.Duration) (bool, error) {
	return r.client.SetNX(ctx, "lock:"+remote+":"+path, "1", ttl).Result()
}
|
|
|
|
// ReleaseLock deletes the key "lock:<remote>:<path>".
//
// NOTE(review): the delete is unconditional; it does not check that the
// caller is the one holding the lock. If the lock's ttl elapsed and another
// caller re-acquired it, this removes their lock. A safe release needs a
// per-holder token and a compare-and-delete, which would change the API.
func (r *Redis) ReleaseLock(ctx context.Context, remote, path string) error {
	lockKey := "lock:" + remote + ":" + path
	return r.client.Del(ctx, lockKey).Err()
}
|
|
|
|
// SetETag stores etag under "etag:<remote>:<path>" with expiry ttl passed
// straight through to SET. GetETag reads it back.
func (r *Redis) SetETag(ctx context.Context, remote, path, etag string, ttl time.Duration) error {
	k := "etag:" + remote + ":" + path
	return r.client.Set(ctx, k, etag, ttl).Err()
}
|
|
|
|
// GetETag returns the value stored under "etag:<remote>:<path>". A missing
// key is not an error: it yields "" and a nil error.
func (r *Redis) GetETag(ctx context.Context, remote, path string) (string, error) {
	if etag, err := r.client.Get(ctx, "etag:"+remote+":"+path).Result(); err != redis.Nil {
		return etag, err
	}
	return "", nil
}
|
|
|
|
// GetToken returns the token cached under "token:<key>". When nothing is
// cached it returns "" and a nil error.
func (r *Redis) GetToken(ctx context.Context, key string) (string, error) {
	tok, err := r.client.Get(ctx, "token:"+key).Result()
	if err != redis.Nil {
		return tok, err
	}
	// Cache miss.
	return "", nil
}
|
|
|
|
// SetToken caches token under "token:<key>" with expiry ttl.
//
// A non-positive ttl means "do not cache": nothing is written and nil is
// returned. go-redis only sends an expiry to Redis when ttl > 0. Zero and
// other negative values would store the token with no expiry, and
// redis.KeepTTL (-1) would keep any previous TTL. Either way a credential
// could outlive its validity.
func (r *Redis) SetToken(ctx context.Context, key, token string, ttl time.Duration) error {
	if ttl <= 0 {
		return nil
	}
	return r.client.Set(ctx, "token:"+key, token, ttl).Err()
}
|
|
|
|
// IncrCircuitFailure increments the failure counter "circuit:<remote>" and
// (re)sets its expiry to cooldown, returning the counter's new value. The
// expiry is refreshed on every call, so the counter disappears only after
// cooldown elapses with no further failures.
//
// INCR and EXPIRE are sent inside MULTI/EXEC (TxPipeline) rather than a plain
// pipeline. This means the counter can never be incremented without also
// receiving its expiry, which would otherwise keep the failure count
// indefinitely.
//
// NOTE(review): Redis deletes a key given a non-positive EXPIRE timeout, so
// with cooldown <= 0 the returned count is always 1.
func (r *Redis) IncrCircuitFailure(ctx context.Context, remote string, cooldown time.Duration) (int64, error) {
	key := "circuit:" + remote

	pipe := r.client.TxPipeline()
	incr := pipe.Incr(ctx, key)
	pipe.Expire(ctx, key, cooldown)
	if _, err := pipe.Exec(ctx); err != nil {
		return 0, err
	}

	return incr.Val(), nil
}
|
|
|
|
// ResetCircuit deletes the failure counter "circuit:<remote>", so a later
// IncrCircuitFailure starts counting again from 1.
func (r *Redis) ResetCircuit(ctx context.Context, remote string) error {
	return r.client.Del(ctx, "circuit:"+remote).Err()
}
|
|
|
|
// GetCircuitFailures returns the current value of "circuit:<remote>", or 0
// with a nil error when the key does not exist.
func (r *Redis) GetCircuitFailures(ctx context.Context, remote string) (int64, error) {
	n, err := r.client.Get(ctx, "circuit:"+remote).Int64()
	if err != redis.Nil {
		return n, err
	}
	return 0, nil
}
|
|
|
|
func (r *Redis) FlushRemote(ctx context.Context, remote string) error {
|
|
iter := r.client.Scan(ctx, 0, fmt.Sprintf("*:%s:*", remote), 100).Iterator()
|
|
for iter.Next(ctx) {
|
|
r.client.Del(ctx, iter.Val())
|
|
}
|
|
return iter.Err()
|
|
}
|