feat: add pypi remote type with URL rewriting and basic auth
- Add 'pypi' package type to config.py; simple/ paths are mutable by default.
- Refactor content-type detection into a _get_content_type() helper; add .whl.
- Add _resolve_content(), which rewrites files-host URLs in simple index HTML so they go through the proxy (pypi_files_url / pypi_files_remote config keys) and returns a text/html content-type for simple index responses.
- Add basic auth support for non-Docker remotes (username + password/token in remote config); thread auth through _upstream_reachable and check_upstream_changed so mutable TTL checks also authenticate.
- Add 'pypi' remote (pypi.org simple index) and 'pypi-files' remote (files.pythonhosted.org) to remotes.yaml; add a 'pypi-gitea' example for Gitea package registries where index and files share the same base URL.
- Add unit tests: simple index URL rewriting, HTML content-type, .whl/.tar.gz content-types, mutable index detection, and immutable pattern enforcement.
This commit is contained in:
@@ -18,6 +18,9 @@ _PACKAGE_MUTABLE_PATTERNS: dict[str, list[str]] = {
|
||||
r"/manifests/(?!sha256:)[^/]+$",
|
||||
r"/tags/list$",
|
||||
],
|
||||
"pypi": [
|
||||
r"simple/", # Per-package and top-level simple index pages
|
||||
],
|
||||
"generic": [],
|
||||
}
|
||||
|
||||
|
||||
+63
-45
@@ -1,3 +1,4 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
@@ -208,8 +209,11 @@ async def cache_single_artifact(url: str, remote_name: str, path: str) -> dict:
|
||||
remote_config = config.get_remote_config(remote_name) or {}
|
||||
is_docker = remote_config.get("package") == "docker" or "/v2/" in url
|
||||
|
||||
# Prepare headers for Docker registry requests
|
||||
# Prepare headers
|
||||
headers = {}
|
||||
username = remote_config.get("username")
|
||||
password = remote_config.get("password")
|
||||
|
||||
if is_docker:
|
||||
if "/manifests/" in url:
|
||||
headers["Accept"] = (
|
||||
@@ -220,6 +224,8 @@ async def cache_single_artifact(url: str, remote_name: str, path: str) -> dict:
|
||||
)
|
||||
elif "/blobs/" in url:
|
||||
headers["Accept"] = "application/octet-stream"
|
||||
elif username and password:
|
||||
headers["Authorization"] = "Basic " + base64.b64encode(f"{username}:{password}".encode()).decode()
|
||||
|
||||
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||
response = await client.get(url, headers=headers)
|
||||
@@ -254,11 +260,20 @@ async def cache_single_artifact(url: str, remote_name: str, path: str) -> dict:
|
||||
return {"url": url, "status": "error", "error": str(e)}
|
||||
|
||||
|
||||
async def _upstream_reachable(url: str) -> bool:
|
||||
def _basic_auth_header(remote_cfg: dict) -> dict[str, str]:
|
||||
username = remote_cfg.get("username")
|
||||
password = remote_cfg.get("password")
|
||||
if username and password:
|
||||
token = base64.b64encode(f"{username}:{password}".encode()).decode()
|
||||
return {"Authorization": f"Basic {token}"}
|
||||
return {}
|
||||
|
||||
|
||||
async def _upstream_reachable(url: str, auth_headers: dict | None = None) -> bool:
|
||||
"""HEAD with a short timeout. Returns False only on network/timeout errors."""
|
||||
try:
|
||||
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||
await client.head(url, timeout=10.0)
|
||||
await client.head(url, headers=auth_headers or {}, timeout=10.0)
|
||||
return True
|
||||
except (httpx.NetworkError, httpx.TimeoutException):
|
||||
return False
|
||||
@@ -266,19 +281,19 @@ async def _upstream_reachable(url: str) -> bool:
|
||||
return True # 4xx/5xx means backend is up
|
||||
|
||||
|
||||
async def check_upstream_changed(remote_url: str, remote_name: str, path: str) -> bool:
|
||||
async def check_upstream_changed(remote_url: str, remote_name: str, path: str, auth_headers: dict | None = None) -> bool:
|
||||
"""Conditional HEAD against upstream. Returns False only on a definitive 304.
|
||||
Raises UpstreamUnreachable if the backend cannot be contacted."""
|
||||
meta = cache.get_mutable_meta(remote_name, path)
|
||||
if not meta:
|
||||
return True
|
||||
|
||||
headers = {}
|
||||
headers = dict(auth_headers or {})
|
||||
if meta.get("etag"):
|
||||
headers["If-None-Match"] = meta["etag"]
|
||||
if meta.get("last_modified"):
|
||||
headers["If-Modified-Since"] = meta["last_modified"]
|
||||
if not headers:
|
||||
if not (meta.get("etag") or meta.get("last_modified")):
|
||||
return True
|
||||
|
||||
try:
|
||||
@@ -294,12 +309,13 @@ async def handle_expired_mutable(remote_name: str, path: str, remote_url: str) -
|
||||
mutable_ttl = config.get_cache_config(remote_name).get("mutable_ttl", 3600)
|
||||
|
||||
remote_cfg = config.get_remote_config(remote_name) or {}
|
||||
auth = _basic_auth_header(remote_cfg)
|
||||
check_updates = remote_cfg.get("check_mutable_updates", False)
|
||||
user_mutable = check_updates and cache.is_mutable_file(path, config.get_user_mutable_patterns(remote_name))
|
||||
|
||||
if user_mutable:
|
||||
try:
|
||||
changed = await check_upstream_changed(remote_url, remote_name, path)
|
||||
changed = await check_upstream_changed(remote_url, remote_name, path, auth)
|
||||
except UpstreamUnreachable:
|
||||
cache.mark_index_cached(remote_name, path, mutable_ttl)
|
||||
logger.warning(f"Mutable STALE (backend unreachable): {remote_name}/{path} - TTL extended ({mutable_ttl}s)")
|
||||
@@ -310,7 +326,7 @@ async def handle_expired_mutable(remote_name: str, path: str, remote_url: str) -
|
||||
return True
|
||||
logger.info(f"Mutable file CHANGED: {remote_name}/{path} - re-downloading")
|
||||
else:
|
||||
if not await _upstream_reachable(remote_url):
|
||||
if not await _upstream_reachable(remote_url, auth):
|
||||
cache.mark_index_cached(remote_name, path, mutable_ttl)
|
||||
logger.warning(f"Mutable STALE (backend unreachable): {remote_name}/{path} - TTL extended ({mutable_ttl}s)")
|
||||
return True
|
||||
@@ -320,8 +336,44 @@ async def handle_expired_mutable(remote_name: str, path: str, remote_url: str) -
|
||||
return False
|
||||
|
||||
|
||||
def _get_content_type(filename: str) -> str:
|
||||
if filename.endswith(".tar.gz"):
|
||||
return "application/gzip"
|
||||
if filename.endswith(".zip") or filename.endswith(".whl"):
|
||||
return "application/zip"
|
||||
if filename.endswith(".exe"):
|
||||
return "application/x-msdownload"
|
||||
if filename.endswith(".rpm"):
|
||||
return "application/x-rpm"
|
||||
if filename.endswith(".xml"):
|
||||
return "application/xml"
|
||||
if filename.endswith((".xml.gz", ".xml.bz2", ".xml.xz")):
|
||||
return "application/gzip"
|
||||
return "application/octet-stream"
|
||||
|
||||
|
||||
def _resolve_content(
    data: bytes,
    path: str,
    filename: str,
    remote_config: dict,
    request: Request,
) -> tuple[bytes, str]:
    """Return (possibly-rewritten data, content_type) for a cached artifact.

    For a remote whose ``package`` is ``"pypi"`` and a *path* containing
    ``simple/``, every occurrence of the files-host URL in *data* is replaced
    with ``<request.base_url>/api/v1/remote/<files remote>`` and the content
    type is ``text/html; charset=utf-8``. Any other artifact is returned
    unchanged with the type chosen by ``_get_content_type(filename)``.

    Config keys read from *remote_config*:
        pypi_files_url: URL to rewrite; defaults to
            ``https://files.pythonhosted.org`` when missing or null.
        pypi_files_remote: remote name used in the rewritten URL; defaults
            to ``pypi-files`` when missing, null or empty.
    """
    if remote_config.get("package") != "pypi" or "simple/" not in path:
        return data, _get_content_type(filename)

    # A key that is present but null in the config returns None from .get(),
    # bypassing a .get() default, so the fallback is applied explicitly.
    files_url = remote_config.get("pypi_files_url")
    if files_url is None:
        files_url = "https://files.pythonhosted.org"
    files_remote = remote_config.get("pypi_files_remote") or "pypi-files"

    needle = str(files_url).rstrip("/").encode()
    # bytes.replace with an empty pattern inserts the replacement between
    # every byte, so an empty or "/"-only URL must skip the rewrite.
    if needle:
        proxy_base = str(request.base_url).rstrip("/")
        data = data.replace(needle, f"{proxy_base}/api/v1/remote/{files_remote}".encode())
    return data, "text/html; charset=utf-8"
|
||||
|
||||
|
||||
@app.get("/api/v1/remote/{remote_name}/{path:path}")
|
||||
async def get_artifact(remote_name: str, path: str):
|
||||
async def get_artifact(request: Request, remote_name: str, path: str):
|
||||
# Check if remote is configured
|
||||
remote_config = config.get_remote_config(remote_name)
|
||||
if not remote_config:
|
||||
@@ -384,29 +436,11 @@ async def get_artifact(remote_name: str, path: str):
|
||||
try:
|
||||
artifact_data = storage.download_object(cached_key)
|
||||
filename = os.path.basename(path)
|
||||
artifact_data, content_type = _resolve_content(artifact_data, path, filename, remote_config, request)
|
||||
|
||||
# Log cache hit
|
||||
logger.info(f"Cache HIT: {remote_name}/{path} (size: {len(artifact_data)} bytes, key: {cached_key})")
|
||||
|
||||
# Determine content type based on file extension
|
||||
content_type = "application/octet-stream"
|
||||
if filename.endswith(".tar.gz"):
|
||||
content_type = "application/gzip"
|
||||
elif filename.endswith(".zip"):
|
||||
content_type = "application/zip"
|
||||
elif filename.endswith(".exe"):
|
||||
content_type = "application/x-msdownload"
|
||||
elif filename.endswith(".rpm"):
|
||||
content_type = "application/x-rpm"
|
||||
elif filename.endswith(".xml"):
|
||||
content_type = "application/xml"
|
||||
elif filename.endswith((".xml.gz", ".xml.bz2", ".xml.xz")):
|
||||
content_type = "application/gzip"
|
||||
|
||||
# Record cache hit metrics
|
||||
metrics.record_cache_hit(remote_name, len(artifact_data))
|
||||
|
||||
# Record artifact mapping in database if not already recorded
|
||||
database.record_artifact_mapping(cached_key, remote_name, path, len(artifact_data))
|
||||
|
||||
return Response(
|
||||
@@ -443,25 +477,9 @@ async def get_artifact(remote_name: str, path: str):
|
||||
cache_key = storage.get_object_key(remote_name, path)
|
||||
artifact_data = storage.download_object(cache_key)
|
||||
filename = os.path.basename(path)
|
||||
artifact_data, content_type = _resolve_content(artifact_data, path, filename, remote_config, request)
|
||||
|
||||
content_type = "application/octet-stream"
|
||||
if filename.endswith(".tar.gz"):
|
||||
content_type = "application/gzip"
|
||||
elif filename.endswith(".zip"):
|
||||
content_type = "application/zip"
|
||||
elif filename.endswith(".exe"):
|
||||
content_type = "application/x-msdownload"
|
||||
elif filename.endswith(".rpm"):
|
||||
content_type = "application/x-rpm"
|
||||
elif filename.endswith(".xml"):
|
||||
content_type = "application/xml"
|
||||
elif filename.endswith((".xml.gz", ".xml.bz2", ".xml.xz")):
|
||||
content_type = "application/gzip"
|
||||
|
||||
# Record cache miss metrics
|
||||
metrics.record_cache_miss(remote_name, len(artifact_data))
|
||||
|
||||
# Record artifact mapping in database
|
||||
cache_key = storage.get_object_key(remote_name, path)
|
||||
database.record_artifact_mapping(cache_key, remote_name, path, len(artifact_data))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user